Config Schema ¤
Validate configuration against schemas.
See Also ¤
- Config Overview - All config tools
- Loaders - Config file loaders
- Registry - Component registry
- Troubleshooting
datarax.config.schema ¤
Configuration schema definition and validation.
This module provides classes and utilities for defining and validating configuration schemas for Datarax pipelines and components.
The Datarax configuration system uses TOML for defining pipeline configurations and supports configuration of NNX-specific features, including:
- RNG streams for stochastic operations
- State persistence for stateful components
- NNX module configuration options
SchemaField dataclass ¤
SchemaField(type: SchemaType, required: bool = True, default: Any = None, validator: Callable[[Any], bool] | None = None, description: str | None = None)
Definition of a field in a configuration schema.
Attributes:
| Name | Type | Description |
|---|---|---|
| type | SchemaType | The expected type of the field |
| required | bool | Whether the field is required |
| default | Any | Default value for the field if not specified |
| validator | Callable[[Any], bool] \| None | Optional function to validate the field value |
| description | str \| None | Optional description of the field |
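To make the role of `validator` concrete, here is a minimal sketch using a local stand-in dataclass that mirrors the attributes listed above. `SchemaFieldSketch`, `is_positive`, and `accepts` are hypothetical names, not part of Datarax, and the order of checks (type first, then validator) is an assumption rather than documented behaviour.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class SchemaFieldSketch:
    """Local stand-in mirroring the documented SchemaField attributes."""

    type: type
    required: bool = True
    default: Any = None
    validator: Callable[[Any], bool] | None = None
    description: str | None = None


def is_positive(value: Any) -> bool:
    """Example validator: accept only values greater than zero."""
    return value > 0


def accepts(field: SchemaFieldSketch, value: Any) -> bool:
    """Assumed check order: the type must match before the validator runs."""
    if not isinstance(value, field.type):
        return False
    return field.validator is None or field.validator(value)


batch_size = SchemaFieldSketch(
    int,
    required=False,
    default=32,
    validator=is_positive,
    description="Batch size, must be positive",
)

print(accepts(batch_size, 64))
print(accepts(batch_size, 0))
print(accepts(batch_size, "64"))
```

Running the type check first means the validator never sees a value of an unexpected type, so `is_positive` does not need to guard against strings.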
ConfigSchema ¤
Base class for configuration schemas.
Subclasses should define schema fields as class variables using SchemaField.
Examples:
    from datarax.config.schema import ConfigSchema, SchemaField

    class MyConfigSchema(ConfigSchema):
        name: SchemaField = SchemaField(str, required=True)
        count: SchemaField = SchemaField(int, required=False, default=0)
get_schema_fields classmethod ¤
get_schema_fields() -> dict[str, SchemaField]
Get all schema fields defined in the class.
Returns:
| Type | Description |
|---|---|
| dict[str, SchemaField] | Dictionary mapping field names to SchemaField instances |
validate classmethod ¤
Validate a configuration dictionary against the schema.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | dict[str, Any] | The configuration dictionary to validate | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A validated configuration dictionary with defaults applied |
Raises:
| Type | Description |
|---|---|
| ValidationError | If the configuration fails validation |
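The contract described above (required fields enforced, defaults applied, an error raised on failure) can be sketched with self-contained stand-ins. `FieldSketch`, `SchemaSketch`, and `ValidationErrorSketch` are hypothetical local classes, not Datarax's implementation; in particular, copying defaults with `copy.deepcopy` is a design choice of this sketch, not something the documentation confirms.

```python
from __future__ import annotations

import copy
from dataclasses import dataclass
from typing import Any


class ValidationErrorSketch(ValueError):
    """Stand-in for the ValidationError named in the Raises table."""


@dataclass
class FieldSketch:
    type: type
    required: bool = True
    default: Any = None


class SchemaSketch:
    @classmethod
    def get_schema_fields(cls) -> dict[str, FieldSketch]:
        # Collect every class attribute that is a field definition.
        return {
            name: getattr(cls, name)
            for name in dir(cls)
            if isinstance(getattr(cls, name), FieldSketch)
        }

    @classmethod
    def validate(cls, config: dict[str, Any]) -> dict[str, Any]:
        result = dict(config)
        for name, field in cls.get_schema_fields().items():
            if name not in result:
                if field.required:
                    raise ValidationErrorSketch(f"missing required field: {name}")
                # Copy the default so callers never share mutable state.
                result[name] = copy.deepcopy(field.default)
            elif not isinstance(result[name], field.type):
                raise ValidationErrorSketch(f"wrong type for field: {name}")
        return result


class MySchema(SchemaSketch):
    name = FieldSketch(str)
    count = FieldSketch(int, required=False, default=0)


validated = MySchema.validate({"name": "demo"})
print(validated)
```

The input dictionary is copied before defaults are filled in, so the caller's original configuration is left unchanged.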
create classmethod ¤
Create a validated configuration dictionary from the schema.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | dict[str, Any] | The configuration dictionary to validate | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A validated configuration dictionary with defaults applied |
Raises:
| Type | Description |
|---|---|
| ValidationError | If the configuration fails validation |
PipelineSchema dataclass ¤
PipelineSchema(name: SchemaField = (lambda: SchemaField(str, required=True, description='Name of the pipeline'))(), description: SchemaField = (lambda: SchemaField(str, required=False, default='', description='Description of the pipeline'))(), version: SchemaField = (lambda: SchemaField(str, required=False, default='0.1.0', description='Version of the pipeline configuration'))(), sources: SchemaField = (lambda: SchemaField(dict, required=True, description='Data source configurations'))(), transforms: SchemaField = (lambda: SchemaField(dict, required=False, default={}, description='Data transformation configurations'))(), augmenters: SchemaField = (lambda: SchemaField(dict, required=False, default={}, description='Data augmentation configurations'))(), samplers: SchemaField = (lambda: SchemaField(dict, required=False, default={}, description='Data sampling configurations'))(), batch_size: SchemaField = (lambda: SchemaField(int, required=False, default=32, description='Default batch size for the pipeline'))(), random_seed: SchemaField = (lambda: SchemaField(int, required=False, default=42, description='Random seed for reproducibility'))(), rng_streams: SchemaField = (lambda: SchemaField(dict, required=False, default={'default': 42, 'augment': 43, 'dropout': 44}, description='RNG streams for NNX components with their seed values'))(), checkpointing: SchemaField = (lambda: SchemaField(dict, required=False, default={'enabled': False, 'directory': 'checkpoints', 'frequency': 1000}, description='Checkpointing configuration for saving pipeline state'))(), device_mesh: SchemaField = (lambda: SchemaField(dict, required=False, default={}, description='JAX device mesh configuration for distributed training'))())
Bases: ConfigSchema
Schema for pipeline configuration.
This schema defines the structure of a pipeline configuration file, including data sources, transforms, augmenters, samplers, and other pipeline-level settings.
name class-attribute instance-attribute ¤
name: SchemaField = field(default_factory=lambda: SchemaField(str, required=True, description='Name of the pipeline'))
description class-attribute instance-attribute ¤
description: SchemaField = field(default_factory=lambda: SchemaField(str, required=False, default='', description='Description of the pipeline'))
version class-attribute instance-attribute ¤
version: SchemaField = field(default_factory=lambda: SchemaField(str, required=False, default='0.1.0', description='Version of the pipeline configuration'))
sources class-attribute instance-attribute ¤
sources: SchemaField = field(default_factory=lambda: SchemaField(dict, required=True, description='Data source configurations'))
transforms class-attribute instance-attribute ¤
transforms: SchemaField = field(default_factory=lambda: SchemaField(dict, required=False, default={}, description='Data transformation configurations'))
augmenters class-attribute instance-attribute ¤
augmenters: SchemaField = field(default_factory=lambda: SchemaField(dict, required=False, default={}, description='Data augmentation configurations'))
samplers class-attribute instance-attribute ¤
samplers: SchemaField = field(default_factory=lambda: SchemaField(dict, required=False, default={}, description='Data sampling configurations'))
batch_size class-attribute instance-attribute ¤
batch_size: SchemaField = field(default_factory=lambda: SchemaField(int, required=False, default=32, description='Default batch size for the pipeline'))
random_seed class-attribute instance-attribute ¤
random_seed: SchemaField = field(default_factory=lambda: SchemaField(int, required=False, default=42, description='Random seed for reproducibility'))
rng_streams class-attribute instance-attribute ¤
rng_streams: SchemaField = field(default_factory=lambda: SchemaField(dict, required=False, default={'default': 42, 'augment': 43, 'dropout': 44}, description='RNG streams for NNX components with their seed values'))
checkpointing class-attribute instance-attribute ¤
checkpointing: SchemaField = field(default_factory=lambda: SchemaField(dict, required=False, default={'enabled': False, 'directory': 'checkpoints', 'frequency': 1000}, description='Checkpointing configuration for saving pipeline state'))
device_mesh class-attribute instance-attribute ¤
device_mesh: SchemaField = field(default_factory=lambda: SchemaField(dict, required=False, default={}, description='JAX device mesh configuration for distributed training'))
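To show what a minimal pipeline configuration could look like once the defaults listed above are filled in, the following sketch merges a user dictionary over a table of defaults copied from the field definitions. `PIPELINE_DEFAULTS`, `REQUIRED`, and `with_pipeline_defaults` are hypothetical helpers, and `ExampleSource` is a made-up component type. The sketch replaces nested dictionaries wholesale; whether Datarax merges them key by key instead is not stated here.

```python
import copy

# Defaults copied from the PipelineSchema field definitions.
PIPELINE_DEFAULTS = {
    "description": "",
    "version": "0.1.0",
    "transforms": {},
    "augmenters": {},
    "samplers": {},
    "batch_size": 32,
    "random_seed": 42,
    "rng_streams": {"default": 42, "augment": 43, "dropout": 44},
    "checkpointing": {"enabled": False, "directory": "checkpoints", "frequency": 1000},
    "device_mesh": {},
}

# The two fields declared with required=True.
REQUIRED = ("name", "sources")


def with_pipeline_defaults(config: dict) -> dict:
    """Return a new dict with documented defaults filled in (hypothetical helper)."""
    missing = [key for key in REQUIRED if key not in config]
    if missing:
        raise ValueError(f"missing required fields: {missing}")
    # Deep-copy so mutable defaults such as {} are never shared between configs.
    merged = copy.deepcopy(PIPELINE_DEFAULTS)
    merged.update(config)
    return merged


minimal = {"name": "demo", "sources": {"train": {"type": "ExampleSource"}}}
full = with_pipeline_defaults(minimal)
print(full["batch_size"], full["rng_streams"])
```

Only `name` and `sources` need to be supplied; every other key falls back to its documented default.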
get_schema_fields classmethod ¤
get_schema_fields() -> dict[str, SchemaField]
Get all schema fields defined in the class.
Returns:
| Type | Description |
|---|---|
| dict[str, SchemaField] | Dictionary mapping field names to SchemaField instances |
validate classmethod ¤
Validate a configuration dictionary against the schema.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | dict[str, Any] | The configuration dictionary to validate | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A validated configuration dictionary with defaults applied |
Raises:
| Type | Description |
|---|---|
| ValidationError | If the configuration fails validation |
create classmethod ¤
Create a validated configuration dictionary from the schema.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | dict[str, Any] | The configuration dictionary to validate | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A validated configuration dictionary with defaults applied |
Raises:
| Type | Description |
|---|---|
| ValidationError | If the configuration fails validation |
NNXComponentSchema dataclass ¤
NNXComponentSchema(type: SchemaField = (lambda: SchemaField(str, required=True, description='The type of NNX component to create'))(), params: SchemaField = (lambda: SchemaField(dict, required=False, default={}, description='Parameters to pass to the component constructor'))(), variables: SchemaField = (lambda: SchemaField(dict, required=False, default={}, description='Initial values for NNX variables in the component'))(), rngs: SchemaField = (lambda: SchemaField(dict, required=False, default={}, description='RNG stream configurations specific to this component'))(), load_state_from: SchemaField = (lambda: SchemaField(str, required=False, default=None, description='Path to load initial component state from'))())
Bases: ConfigSchema
Schema for configuring components that use NNX modules.
This schema defines the structure for components that leverage Flax NNX for state management and computation.
type class-attribute instance-attribute ¤
type: SchemaField = field(default_factory=lambda: SchemaField(str, required=True, description='The type of NNX component to create'))
params class-attribute instance-attribute ¤
params: SchemaField = field(default_factory=lambda: SchemaField(dict, required=False, default={}, description='Parameters to pass to the component constructor'))
variables class-attribute instance-attribute ¤
variables: SchemaField = field(default_factory=lambda: SchemaField(dict, required=False, default={}, description='Initial values for NNX variables in the component'))
rngs class-attribute instance-attribute ¤
rngs: SchemaField = field(default_factory=lambda: SchemaField(dict, required=False, default={}, description='RNG stream configurations specific to this component'))
load_state_from class-attribute instance-attribute ¤
load_state_from: SchemaField = field(default_factory=lambda: SchemaField(str, required=False, default=None, description='Path to load initial component state from'))
get_schema_fields classmethod ¤
get_schema_fields() -> dict[str, SchemaField]
Get all schema fields defined in the class.
Returns:
| Type | Description |
|---|---|
| dict[str, SchemaField] | Dictionary mapping field names to SchemaField instances |
validate classmethod ¤
Validate a configuration dictionary against the schema.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | dict[str, Any] | The configuration dictionary to validate | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A validated configuration dictionary with defaults applied |
Raises:
| Type | Description |
|---|---|
| ValidationError | If the configuration fails validation |
create classmethod ¤
Create a validated configuration dictionary from the schema.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | dict[str, Any] | The configuration dictionary to validate | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A validated configuration dictionary with defaults applied |
Raises:
| Type | Description |
|---|---|
| ValidationError | If the configuration fails validation |
is_schema_type_valid ¤
is_schema_type_valid(value: Any, expected_type: SchemaType) -> bool
Validate that a value matches an expected schema type.
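The exact matching rules are not spelled out on this page, so the following is only a sketch of what such a type check might involve. It assumes `SchemaType` can be a plain Python class such as `int` or `dict` (as in the `SchemaField(str, ...)` examples), and the name `is_type_valid_sketch` is hypothetical. The special cases for `bool` and `float` are choices made for this illustration, not documented Datarax behaviour.

```python
from typing import Any


def is_type_valid_sketch(value: Any, expected_type: type) -> bool:
    """Illustrative type check; not the Datarax implementation."""
    # bool is a subclass of int in Python, so reject it explicitly for
    # numeric fields to stop True/False passing as a number.
    if expected_type in (int, float) and isinstance(value, bool):
        return False
    # Allow integers where a float is expected (for example, 1 for 1.0).
    if expected_type is float:
        return isinstance(value, (int, float))
    return isinstance(value, expected_type)


print(is_type_valid_sketch(32, int))
print(is_type_valid_sketch(True, int))
print(is_type_valid_sketch({}, dict))
```

A plain `isinstance(value, int)` would accept `True`, which is why the sketch handles `bool` before the general case.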