Data Source Protocol

Core protocol for data sources.


datarax.core.data_source

Base module for data sources in Datarax.

This module defines the base class for all Datarax data source components that use flax.nnx.Module for state management and JAX transformation compatibility.

logger module-attribute

logger = getLogger(__name__)

LocalFilesOnlyMixin

Adds a uniform local_files_only flag to data sources.

Sources that download external archives (HuggingFace, TFDS, ArrayRecord, etc.) compose this mixin and call _check_local_cache before any network attempt. The check enforces the air-gapped contract: when local_files_only=True and the cache is missing, the source raises a FileNotFoundError whose message names the dataset and the exact paths the user must populate, instead of a generic "file not found".

Subclasses must define self.local_files_only: bool (typically wired through their config dataclass).
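
A minimal, self-contained sketch of the air-gapped contract described above. The real _check_local_cache signature is not shown on this page, so the sketch uses a hypothetical check_local_cache(dataset_name, required_paths) helper and a hypothetical ToyArchiveSource; neither is part of Datarax.

```python
from dataclasses import dataclass
from pathlib import Path


class LocalFilesOnlySketch:
    """Hypothetical stand-in for the LocalFilesOnlyMixin contract (not Datarax code)."""

    local_files_only: bool  # subclasses must set this, e.g. from their config

    def check_local_cache(self, dataset_name: str, required_paths: list[Path]) -> bool:
        """Return True if every required path exists locally.

        When the cache is incomplete and local_files_only is True, raise a
        FileNotFoundError that names the dataset and the missing paths.
        """
        missing = [p for p in required_paths if not p.exists()]
        if missing and self.local_files_only:
            listing = "\n".join(f"  - {p}" for p in missing)
            raise FileNotFoundError(
                f"Dataset {dataset_name!r} is not in the local cache and "
                f"local_files_only=True. Populate these paths:\n{listing}"
            )
        return not missing


@dataclass
class ToyArchiveSource(LocalFilesOnlySketch):
    """Hypothetical source that would download toy.tar into cache_dir."""

    cache_dir: Path
    local_files_only: bool = False

    def load(self) -> str:
        archive = self.cache_dir / "toy.tar"
        if self.check_local_cache("toy", [archive]):
            return f"read {archive}"
        # Only reached when local_files_only=False; a real source would download here.
        return "download"
```

Raising before any network attempt means an offline job fails immediately with a message that lists exactly which files must be copied in.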

local_files_only instance-attribute

local_files_only: bool

DataSourceModule

DataSourceModule(config: StructuralConfig, *, rngs: Rngs | None = None, name: str | None = None)

Bases: StructuralModule

Enhanced base module for all Datarax data source components.

This class extends StructuralModule for non-parametric structural data loading. Concrete data sources define their own config classes extending StructuralConfig.

A DataSourceModule is responsible for reading data from an external source (e.g., files, memory, network) and yielding data elements as PyTrees. Each data element is typically a dictionary or other PyTree structure containing JAX arrays or Python primitives.

Important: When subclassing, if you store data containing JAX arrays in an attribute (such as self.data), wrap the value with nnx.data(value) at assignment time (or with nnx.Param, which instead registers the arrays as parameters):

Examples:

from dataclasses import dataclass

from flax import nnx

# Assumes StructuralConfig is importable alongside DataSourceModule; adjust to your Datarax version.
from datarax.core.data_source import DataSourceModule, StructuralConfig


@dataclass(frozen=True)
class MyDataSourceConfig(StructuralConfig):
    required_param: int | None = None

    def __post_init__(self):
        super().__post_init__()
        if self.required_param is None:
            raise ValueError("required_param is required")


class MyDataSource(DataSourceModule):
    data: list[dict]

    def __init__(self, config: MyDataSourceConfig, data: list[dict], *,
                 rngs: nnx.Rngs | None = None, name: str | None = None):
        super().__init__(config, rngs=rngs, name=name)
        self.data = nnx.data(data)  # Mark as pytree data, not parameters.

This prevents NNX from trying to track individual JAX Arrays within the data structure as trainable parameters.

Parameters:

  • config (StructuralConfig, required): Structural module configuration (already validated, frozen).
  • rngs (Rngs | None, default None): Random number generators (required if stochastic=True).
  • name (str | None, default None): Optional module name.

Raises:

  • ValueError: If stochastic=True but rngs is None.

config instance-attribute

config = static(config)

rngs instance-attribute

rngs = rngs

name instance-attribute

name = static(name)

stochastic instance-attribute

stochastic = stochastic

stream_name instance-attribute

stream_name = stream_name

get_batch_at

get_batch_at(start: int | Any, size: int, key: Any | None = None) -> Any

Stateless indexed batch access for Pipeline-driven iteration.

Returns size records starting at start. Implementations must be stateless (no mutation of internal counters) and JAX-traceable (must accept tracer values for start) so the call composes with nnx.scan.

Parameters:

  • start (int | Any, required): Starting index. Sources that support indexed access accept a Python int or a traced jax.Array.
  • size (int, required): Number of records to return (a Python int, because JAX shapes are static).
  • key (Any | None, default None): Optional PRNG key for shuffled or stochastic sampling.

Returns:

  • Any: A batch dict (or PyTree) whose leading dimension is size.

Raises:

  • NotImplementedError: If the source does not support indexed access (e.g. forward-only streams). In that case, Pipeline falls back to its __iter__ debug path.
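
A minimal sketch of the get_batch_at contract, assuming an in-memory source that stores a dict of equally long arrays. InMemorySourceSketch and sum_of_labels are hypothetical names, and the sketch uses jax.lax.scan directly rather than nnx.scan to stay independent of Datarax. What it shows is that a stateless method taking a traced start can be called inside a compiled loop.

```python
import jax
import jax.numpy as jnp


class InMemorySourceSketch:
    """Hypothetical in-memory source illustrating the get_batch_at contract."""

    def __init__(self, data: dict[str, jax.Array]):
        # Every leaf shares the same leading (record) dimension.
        self.data = data

    def get_batch_at(self, start, size: int, key=None):
        # Stateless: no counters are mutated, so equal inputs give equal outputs.
        # dynamic_slice_in_dim accepts a traced start; size must be a Python int.
        # key is unused because this sketch reads records in order.
        return jax.tree_util.tree_map(
            lambda leaf: jax.lax.dynamic_slice_in_dim(leaf, start, size, axis=0),
            self.data,
        )


source = InMemorySourceSketch({
    "x": jnp.arange(12.0).reshape(6, 2),
    "y": jnp.arange(6, dtype=jnp.int32),
})


@jax.jit
def sum_of_labels(starts: jax.Array) -> jax.Array:
    def step(total, start):
        batch = source.get_batch_at(start, 2)  # start is a tracer here
        return total + batch["y"].sum(), None

    total, _ = jax.lax.scan(step, jnp.zeros((), jnp.int32), starts)
    return total
```

One design caveat: jax.lax.dynamic_slice_in_dim clamps an out-of-range start so the slice stays in bounds, so a real source needs its own policy for a final partial batch.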

element_spec

element_spec() -> Any

Return a PyTree of jax.ShapeDtypeStruct describing per-element output.

Downstream consumers (operators, batchers, models) use this contract to pre-allocate buffers, auto-size learnable layers, and statically validate operator chains. Subclasses MUST override this method.

Returns:

  • Any: A PyTree (typically a dict) whose leaves are jax.ShapeDtypeStruct instances describing one emitted element.

Raises:

  • NotImplementedError: Always, on the base class. Subclasses must override.
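
A minimal sketch of an element_spec override and of one downstream use, pre-allocating a batch buffer from the spec alone. ImageSourceSketch and preallocate_batch are hypothetical, and the element layout (an HxWxC float32 image plus an int32 label) is an assumption for illustration.

```python
import jax
import jax.numpy as jnp


class ImageSourceSketch:
    """Hypothetical source emitting {"image": (H, W, C) float32, "label": () int32}."""

    def __init__(self, height: int, width: int, channels: int = 3):
        self.image_shape = (height, width, channels)

    def element_spec(self) -> dict[str, jax.ShapeDtypeStruct]:
        # Describes ONE element, so there is no batch dimension here.
        return {
            "image": jax.ShapeDtypeStruct(self.image_shape, jnp.float32),
            "label": jax.ShapeDtypeStruct((), jnp.int32),
        }


def preallocate_batch(spec, batch_size: int):
    # Size buffers from the spec alone, without reading any data.
    return jax.tree_util.tree_map(
        lambda s: jnp.zeros((batch_size, *s.shape), s.dtype), spec
    )


buffers = preallocate_batch(ImageSourceSketch(32, 32).element_spec(), batch_size=8)
```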

get_operation_stats

get_operation_stats() -> dict[str, int]

Get operation statistics.

Note: This method converts JAX arrays to Python ints for introspection. It is intended for use outside of JIT-compiled functions.

Returns:

  • dict[str, int]: Dictionary with 'applied_count' and 'skipped_count'.

reset_operation_stats

reset_operation_stats() -> None

Reset operation statistics to zero.

Note: Creates new JAX arrays to reset the counters.
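
The notes on get_operation_stats and reset_operation_stats imply counters kept as JAX arrays. A minimal sketch, assuming that layout (init_stats, record, and to_python are hypothetical helpers, not Datarax API): updates stay traceable under jax.jit, while the int() conversion happens outside compiled code.

```python
import jax
import jax.numpy as jnp


def init_stats() -> dict[str, jax.Array]:
    # Hypothetical layout using the documented keys; each counter is a 0-d array.
    return {
        "applied_count": jnp.zeros((), jnp.int32),
        "skipped_count": jnp.zeros((), jnp.int32),
    }


@jax.jit
def record(stats: dict[str, jax.Array], applied) -> dict[str, jax.Array]:
    # Branch-free arithmetic keeps the update traceable under jit.
    applied = jnp.asarray(applied, dtype=jnp.int32)
    return {
        "applied_count": stats["applied_count"] + applied,
        "skipped_count": stats["skipped_count"] + (1 - applied),
    }


def to_python(stats: dict[str, jax.Array]) -> dict[str, int]:
    # int() needs a concrete value, so call this outside jit-compiled code.
    return {name: int(value) for name, value in stats.items()}
```

In this sketch, resetting amounts to calling init_stats() again, replacing the counters with new arrays rather than mutating them in place.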

compute_statistics

compute_statistics(data: Any) -> dict[str, Any] | None

Compute statistics from data using batch_stats_fn.

If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.

Parameters:

  • data (Any, required): Input data to compute statistics from.

Returns:

  • dict[str, Any] | None: Dictionary of statistics, or None if no batch_stats_fn is configured.

get_statistics

get_statistics() -> dict[str, Any] | None

Get current statistics.

Returns precomputed_stats if configured (unless reset_statistics() has been called); otherwise returns the cached computed statistics, or None if no statistics are available.

Returns:

  • dict[str, Any] | None: Dictionary of statistics, or None if no statistics are available.

set_statistics

set_statistics(stats: dict[str, Any]) -> None

Manually set statistics.

This overwrites any previously computed statistics and clears the reset flag.

Parameters:

  • stats (dict[str, Any], required): Dictionary of statistics to set.

reset_statistics

reset_statistics() -> None

Reset all statistics to None.

This clears the computed statistics and sets an internal flag so that precomputed_stats is ignored. After a reset, get_statistics() returns None until new statistics are set or computed.

reset_cache

reset_cache() -> None

Clear the cache.

This has an effect only if cacheable=True in the config.

copy

copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule

Create a copy of this module with optional config/parameter changes.

This creates a new module instance with a modified configuration while preserving the other attributes, which is useful for hyperparameter tuning.

Parameters:

  • config (DataraxModuleConfig | None, default None): New config (if None, uses the current config).
  • rngs (Rngs | None, default None): New RNG state (if None, uses the current rngs).
  • name (str | None, default None): New name (if None, uses the current name).

Returns:

  • DataraxModule: New module instance with the updated parameters.

Examples:

# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)

# Change name only
renamed = module.copy(name="new_name")

Note

Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().

get_state

get_state() -> dict[str, Any]

Get module state for checkpointing.

This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.

Returns:

  • dict[str, Any]: A dictionary containing the internal state of the component.

set_state

set_state(state: dict[str, Any]) -> None

Restore module state from a checkpoint.

This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.

Parameters:

  • state (dict[str, Any], required): A dictionary containing the internal state to restore.

Raises:

  • TypeError: If state is not a dictionary.
  • ValueError: If the checkpoint structure does not match the module state.
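
A minimal sketch of the strict checks that set_state documents, comparing tree structures with jax.tree_util.tree_structure. check_restorable is a hypothetical helper; it compares only keys and nesting, not leaf shapes or dtypes, so a real implementation may check more.

```python
from typing import Any

import jax


def check_restorable(current: dict[str, Any], state: Any) -> None:
    """Hypothetical strict check mirroring set_state's documented errors."""
    if not isinstance(state, dict):
        raise TypeError(f"state must be a dict, got {type(state).__name__}")
    expected = jax.tree_util.tree_structure(current)
    got = jax.tree_util.tree_structure(state)
    if expected != got:
        # Strict restore: any missing, extra, or re-nested key is an error.
        raise ValueError(f"checkpoint structure mismatch: expected {expected}, got {got}")
```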

clone

clone() -> DataraxModule

Create a new instance with the same state as this module.

Uses nnx.clone to deep-copy all module state.

Returns:

  • DataraxModule: A new module instance with the same state.

requires_rng_streams

requires_rng_streams() -> list[str] | None

Get the list of RNG streams required by this module.

Returns:

  • list[str] | None: A list of required RNG stream names, or None if no RNG streams are required.

ensure_rng_streams

ensure_rng_streams(stream_names: list[str]) -> None

Ensure that the required RNG streams are available.

Parameters:

  • stream_names (list[str], required): A list of available RNG stream names.

Raises:

  • ValueError: If a required RNG stream is not available.

process

process(input: Any, *args: Any, **kwargs: Any) -> Any

Process input structure.

This method transforms the structure/organization of input data without modifying the data values themselves.

Subclasses MUST implement this method.

The input/output types depend on the specific structural processor:

  • Batcher: list[Element] -> list[Batch]
  • Sampler: int -> list[int]
  • Sharder: Batch -> Sharded[Batch]
  • Splitter: Dataset -> tuple[Dataset, Dataset]

Parameters:

  • input (Any, required): Input to process (type varies by processor).
  • *args (Any, default ()): Additional positional arguments (processor-specific).
  • **kwargs (Any, default {}): Additional keyword arguments (processor-specific).

Returns:

  • Any: Processed output (type varies by processor).

Examples:

Batcher implementation:

def process(self, elements: list[Element]) -> list[Batch]:
    batches = []
    for i in range(0, len(elements), self.config.batch_size):
        batch_elements = elements[i:i + self.config.batch_size]
        batches.append(Batch.from_elements(batch_elements))
    return batches

Sampler implementation (deterministic):

def process(self, dataset_size: int) -> list[int]:
    return list(range(min(self.config.num_samples, dataset_size)))

Sampler implementation (stochastic):

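import jax

def process(self, dataset_size: int) -> list[int]:
    # Draw a fresh PRNG key from the configured nnx.Rngs stream.
    rng = self.rngs[self.config.stream_name]()
    indices = jax.random.choice(
        rng, dataset_size, shape=(self.config.num_samples,),
        replace=self.config.replacement
    )
    # .tolist() requires a concrete array, not a tracer.
    return indices.tolist()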