Config Schema¤

Validate configuration against schemas.

datarax.config.schema ¤

Configuration schema definition and validation.

This module provides classes and utilities for defining and validating configuration schemas for Datarax pipelines and components.

The Datarax configuration system uses TOML for defining pipeline configurations and supports configuration of NNX-specific features, including:

  • RNG streams for stochastic operations
  • State persistence for stateful components
  • NNX module configuration options

logger module-attribute ¤

logger = getLogger(__name__)

SchemaType module-attribute ¤

SchemaType = type[Any] | ConfigSchema

ValidationError ¤

Bases: Exception

Exception raised when configuration validation fails.

SchemaField dataclass ¤

SchemaField(type: SchemaType, required: bool = True, default: Any = None, validator: Callable[[Any], bool] | None = None, description: str | None = None)

Definition of a field in a configuration schema.

Attributes:

  • type (SchemaType): The expected type of the field
  • required (bool): Whether the field is required
  • default (Any): Default value for the field if not specified
  • validator (Callable[[Any], bool] | None): Optional function to validate the field value
  • description (str | None): Optional description of the field

type instance-attribute ¤

type: SchemaType

required class-attribute instance-attribute ¤

required: bool = True

default class-attribute instance-attribute ¤

default: Any = None

validator class-attribute instance-attribute ¤

validator: Callable[[Any], bool] | None = None

description class-attribute instance-attribute ¤

description: str | None = None
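
To make the validator attribute concrete, here is a minimal sketch of a field with a custom check. The SchemaField class below is a local stand-in that only mirrors the documented signature so the example runs without Datarax installed. The `is_positive_int` helper is hypothetical, and the convention that returning False means "reject" is an assumption based on the `Callable[[Any], bool]` type.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class SchemaField:
    """Local stand-in mirroring the documented SchemaField signature."""

    type: Any
    required: bool = True
    default: Any = None
    validator: Optional[Callable[[Any], bool]] = None
    description: Optional[str] = None


def is_positive_int(value: Any) -> bool:
    """Hypothetical validator: accept ints greater than zero, but not bools."""
    return isinstance(value, int) and not isinstance(value, bool) and value > 0


# An optional field with a default and a custom validator.
batch_size_field = SchemaField(
    int,
    required=False,
    default=32,
    validator=is_positive_int,
    description="Default batch size for the pipeline",
)

print(batch_size_field.validator(64))
print(batch_size_field.validator(0))
```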

ConfigSchema ¤

Base class for configuration schemas.

Subclasses should define schema fields as class variables using SchemaField.

Examples:

from datarax.config.schema import ConfigSchema, SchemaField


class MyConfigSchema(ConfigSchema):
    name: SchemaField = SchemaField(str, required=True)
    count: SchemaField = SchemaField(int, required=False, default=0)

get_schema_fields classmethod ¤

get_schema_fields() -> dict[str, SchemaField]

Get all schema fields defined in the class.

Returns:

  • dict[str, SchemaField]: Dictionary mapping field names to SchemaField instances

validate classmethod ¤

validate(config: dict[str, Any]) -> dict[str, Any]

Validate a configuration dictionary against the schema.

Parameters:

  • config (dict[str, Any], required): The configuration dictionary to validate

Returns:

  • dict[str, Any]: A validated configuration dictionary with defaults applied

Raises:

  • ValidationError: If the configuration fails validation
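
The validation flow described above (check required fields, apply defaults, raise ValidationError on failure) can be sketched as follows. This is not the Datarax implementation: every class here is a local stand-in. How fields are discovered, the order of the checks, and the pass-through of unknown keys are all assumptions.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


class ValidationError(Exception):
    """Stand-in for datarax.config.schema.ValidationError."""


@dataclass
class SchemaField:
    """Stand-in mirroring the documented SchemaField signature."""

    type: Any
    required: bool = True
    default: Any = None
    validator: Optional[Callable[[Any], bool]] = None
    description: Optional[str] = None


class ConfigSchema:
    """Stand-in base class; the real implementation may differ."""

    @classmethod
    def get_schema_fields(cls) -> dict:
        # Collect SchemaField class variables, base classes first so that
        # a subclass definition overrides an inherited one.
        fields = {}
        for klass in reversed(cls.__mro__):
            for name, value in vars(klass).items():
                if isinstance(value, SchemaField):
                    fields[name] = value
        return fields

    @classmethod
    def validate(cls, config: dict) -> dict:
        validated = dict(config)  # assumption: unknown keys pass through
        for name, spec in cls.get_schema_fields().items():
            if name not in config:
                if spec.required:
                    raise ValidationError(f"missing required field: {name}")
                validated[name] = spec.default  # apply the default
                continue
            value = config[name]
            if not isinstance(value, spec.type):
                raise ValidationError(f"{name}: expected {spec.type.__name__}")
            if spec.validator is not None and not spec.validator(value):
                raise ValidationError(f"{name}: validator rejected {value!r}")
        return validated


class MyConfigSchema(ConfigSchema):
    name: SchemaField = SchemaField(str, required=True)
    count: SchemaField = SchemaField(int, required=False, default=0)


result = MyConfigSchema.validate({"name": "demo"})
print(result)  # {'name': 'demo', 'count': 0}
```

Omitting `name`, or passing `count` as a string, raises the stand-in ValidationError in this sketch.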

create classmethod ¤

create(config: dict[str, Any]) -> dict[str, Any]

Create a validated configuration dictionary from the schema.

Parameters:

  • config (dict[str, Any], required): The configuration dictionary to validate

Returns:

  • dict[str, Any]: A validated configuration dictionary with defaults applied

Raises:

  • ValidationError: If the configuration fails validation

PipelineSchema dataclass ¤

PipelineSchema(
    name: SchemaField = (lambda: SchemaField(str, required=True, description='Name of the pipeline'))(),
    description: SchemaField = (lambda: SchemaField(str, required=False, default='', description='Description of the pipeline'))(),
    version: SchemaField = (lambda: SchemaField(str, required=False, default='0.1.0', description='Version of the pipeline configuration'))(),
    sources: SchemaField = (lambda: SchemaField(dict, required=True, description='Data source configurations'))(),
    transforms: SchemaField = (lambda: SchemaField(dict, required=False, default={}, description='Data transformation configurations'))(),
    augmenters: SchemaField = (lambda: SchemaField(dict, required=False, default={}, description='Data augmentation configurations'))(),
    samplers: SchemaField = (lambda: SchemaField(dict, required=False, default={}, description='Data sampling configurations'))(),
    batch_size: SchemaField = (lambda: SchemaField(int, required=False, default=32, description='Default batch size for the pipeline'))(),
    random_seed: SchemaField = (lambda: SchemaField(int, required=False, default=42, description='Random seed for reproducibility'))(),
    rng_streams: SchemaField = (lambda: SchemaField(dict, required=False, default={'default': 42, 'augment': 43, 'dropout': 44}, description='RNG streams for NNX components with their seed values'))(),
    checkpointing: SchemaField = (lambda: SchemaField(dict, required=False, default={'enabled': False, 'directory': 'checkpoints', 'frequency': 1000}, description='Checkpointing configuration for saving pipeline state'))(),
    device_mesh: SchemaField = (lambda: SchemaField(dict, required=False, default={}, description='JAX device mesh configuration for distributed training'))(),
)

Bases: ConfigSchema

Schema for pipeline configuration.

This schema defines the structure of a pipeline configuration file, including data sources, transforms, augmenters, samplers, and other components.

name class-attribute instance-attribute ¤

name: SchemaField = field(default_factory=lambda: SchemaField(str, required=True, description='Name of the pipeline'))

description class-attribute instance-attribute ¤

description: SchemaField = field(default_factory=lambda: SchemaField(str, required=False, default='', description='Description of the pipeline'))

version class-attribute instance-attribute ¤

version: SchemaField = field(default_factory=lambda: SchemaField(str, required=False, default='0.1.0', description='Version of the pipeline configuration'))

sources class-attribute instance-attribute ¤

sources: SchemaField = field(default_factory=lambda: SchemaField(dict, required=True, description='Data source configurations'))

transforms class-attribute instance-attribute ¤

transforms: SchemaField = field(default_factory=lambda: SchemaField(dict, required=False, default={}, description='Data transformation configurations'))

augmenters class-attribute instance-attribute ¤

augmenters: SchemaField = field(default_factory=lambda: SchemaField(dict, required=False, default={}, description='Data augmentation configurations'))

samplers class-attribute instance-attribute ¤

samplers: SchemaField = field(default_factory=lambda: SchemaField(dict, required=False, default={}, description='Data sampling configurations'))

batch_size class-attribute instance-attribute ¤

batch_size: SchemaField = field(default_factory=lambda: SchemaField(int, required=False, default=32, description='Default batch size for the pipeline'))

random_seed class-attribute instance-attribute ¤

random_seed: SchemaField = field(default_factory=lambda: SchemaField(int, required=False, default=42, description='Random seed for reproducibility'))

rng_streams class-attribute instance-attribute ¤

rng_streams: SchemaField = field(default_factory=lambda: SchemaField(dict, required=False, default={'default': 42, 'augment': 43, 'dropout': 44}, description='RNG streams for NNX components with their seed values'))

checkpointing class-attribute instance-attribute ¤

checkpointing: SchemaField = field(default_factory=lambda: SchemaField(dict, required=False, default={'enabled': False, 'directory': 'checkpoints', 'frequency': 1000}, description='Checkpointing configuration for saving pipeline state'))

device_mesh class-attribute instance-attribute ¤

device_mesh: SchemaField = field(default_factory=lambda: SchemaField(dict, required=False, default={}, description='JAX device mesh configuration for distributed training'))
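
The field definitions above imply what a minimal configuration looks like once defaults are filled in. The sketch below collects the documented defaults into a plain dictionary and merges a user configuration over them. The helper name `expected_validated_config`, the top-level (non-recursive) merge, and the `hypothetical_source` value are assumptions. Whether Datarax merges nested dictionaries such as rng_streams key by key is not stated here.

```python
import copy

# Defaults copied from the PipelineSchema field definitions.
PIPELINE_DEFAULTS = {
    "description": "",
    "version": "0.1.0",
    "transforms": {},
    "augmenters": {},
    "samplers": {},
    "batch_size": 32,
    "random_seed": 42,
    "rng_streams": {"default": 42, "augment": 43, "dropout": 44},
    "checkpointing": {"enabled": False, "directory": "checkpoints", "frequency": 1000},
    "device_mesh": {},
}

# The only two fields declared with required=True.
REQUIRED_KEYS = ("name", "sources")


def expected_validated_config(user_config: dict) -> dict:
    """Hypothetical helper: apply the documented defaults to a user config."""
    missing = [key for key in REQUIRED_KEYS if key not in user_config]
    if missing:
        raise KeyError(f"missing required keys: {missing}")
    # Deep copies keep the shared default dicts from being mutated later.
    merged = copy.deepcopy(PIPELINE_DEFAULTS)
    merged.update(copy.deepcopy(user_config))  # top-level merge (assumption)
    return merged


minimal = {"name": "example", "sources": {"train": {"type": "hypothetical_source"}}}
full = expected_validated_config(minimal)

print(full["batch_size"])
print(full["checkpointing"])
```

Copying the defaults matters because several of them are mutable dictionaries; handing the same dict object to every validated configuration would let one pipeline's edits leak into another.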

get_schema_fields, validate, and create are inherited from ConfigSchema; their signatures and descriptions are identical to those documented under ConfigSchema above.

NNXComponentSchema dataclass ¤

NNXComponentSchema(
    type: SchemaField = (lambda: SchemaField(str, required=True, description='The type of NNX component to create'))(),
    params: SchemaField = (lambda: SchemaField(dict, required=False, default={}, description='Parameters to pass to the component constructor'))(),
    variables: SchemaField = (lambda: SchemaField(dict, required=False, default={}, description='Initial values for NNX variables in the component'))(),
    rngs: SchemaField = (lambda: SchemaField(dict, required=False, default={}, description='RNG stream configurations specific to this component'))(),
    load_state_from: SchemaField = (lambda: SchemaField(str, required=False, default=None, description='Path to load initial component state from'))(),
)

Bases: ConfigSchema

Schema for configuring components that use NNX modules.

This schema defines the structure for components that leverage Flax NNX for state management and computation.

type class-attribute instance-attribute ¤

type: SchemaField = field(default_factory=lambda: SchemaField(str, required=True, description='The type of NNX component to create'))

params class-attribute instance-attribute ¤

params: SchemaField = field(default_factory=lambda: SchemaField(dict, required=False, default={}, description='Parameters to pass to the component constructor'))

variables class-attribute instance-attribute ¤

variables: SchemaField = field(default_factory=lambda: SchemaField(dict, required=False, default={}, description='Initial values for NNX variables in the component'))

rngs class-attribute instance-attribute ¤

rngs: SchemaField = field(default_factory=lambda: SchemaField(dict, required=False, default={}, description='RNG stream configurations specific to this component'))

load_state_from class-attribute instance-attribute ¤

load_state_from: SchemaField = field(default_factory=lambda: SchemaField(str, required=False, default=None, description='Path to load initial component state from'))
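
The shape of a component configuration accepted by this schema can be summarized as a typed dictionary. The sketch below is only a reading aid built from the five fields above. The `NNXComponentConfig` type, the `HypotheticalAugmenter` name, and the parameter values are assumptions and are not part of Datarax.

```python
from typing import Optional, TypedDict


class _RequiredKeys(TypedDict):
    # The only field declared with required=True.
    type: str


class NNXComponentConfig(_RequiredKeys, total=False):
    """Hypothetical typed view of an NNX component configuration."""

    params: dict                     # constructor parameters
    variables: dict                  # initial values for NNX variables
    rngs: dict                       # component-specific RNG streams
    load_state_from: Optional[str]   # path to load initial state from


# Example entry: everything except "type" may be omitted.
component: NNXComponentConfig = {
    "type": "HypotheticalAugmenter",
    "params": {"strength": 0.5},
    "rngs": {"augment": 43},
}

print(sorted(NNXComponentConfig.__required_keys__))
print(sorted(NNXComponentConfig.__optional_keys__))
```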

get_schema_fields, validate, and create are inherited from ConfigSchema; their signatures and descriptions are identical to those documented under ConfigSchema above.

is_schema_type_valid ¤

is_schema_type_valid(value: Any, expected_type: SchemaType) -> bool

Validate that a value matches an expected schema type.
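
Because SchemaType is either a plain Python type or a ConfigSchema, a check of this kind plausibly has two branches. The sketch below shows one way to write it. It is an assumption about the behavior, not the Datarax source: ConfigSchema is a simplified stand-in, and `SourceSchema` and `sketch_is_schema_type_valid` are hypothetical names.

```python
from typing import Any


class ConfigSchema:
    """Simplified stand-in: validate() only checks for required keys."""

    required_keys: tuple = ()

    @classmethod
    def validate(cls, config: dict) -> dict:
        missing = [key for key in cls.required_keys if key not in config]
        if missing:
            raise ValueError(f"missing keys: {missing}")
        return dict(config)


class SourceSchema(ConfigSchema):
    """Hypothetical nested schema used only for this example."""

    required_keys = ("path",)


def sketch_is_schema_type_valid(value: Any, expected_type: Any) -> bool:
    # Branch 1: a nested schema. The value must be a dict the schema accepts.
    if isinstance(expected_type, type) and issubclass(expected_type, ConfigSchema):
        if not isinstance(value, dict):
            return False
        try:
            expected_type.validate(value)
        except ValueError:
            return False
        return True
    # Branch 2: a plain Python type. Fall back to isinstance.
    return isinstance(value, expected_type)


print(sketch_is_schema_type_valid(3, int))
print(sketch_is_schema_type_valid({"path": "data/"}, SourceSchema))
print(sketch_is_schema_type_valid({}, SourceSchema))
```

Note that a plain isinstance check accepts True for an int field, since bool is a subclass of int in Python; a stricter check would need to exclude bool explicitly.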