Data Source Protocol¤
Core protocol for data sources.
See Also¤
- Core Overview - All core protocols
- Sources - Source implementations
- HF Source - HuggingFace integration
- Data Sources Guide
datarax.core.data_source ¤
Base module for data sources in Datarax.
This module defines the base class for all Datarax data source components that use flax.nnx.Module for state management and JAX transformation compatibility.
LocalFilesOnlyMixin ¤
Adds a uniform local_files_only flag to data sources.
Sources that download external archives (HuggingFace, TFDS, ArrayRecord,
etc.) compose this mixin and call _check_local_cache before any
network attempt. The check enforces the air-gapped contract: when
local_files_only=True and the cache is missing, the source raises a
FileNotFoundError whose message names the dataset and the exact paths
the user must populate, instead of a generic "file not found".
Subclasses must define self.local_files_only: bool (typically wired
through their config dataclass).
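The air-gapped contract described above can be sketched as follows. This is a minimal illustration, not Datarax's implementation: the class names, the `_check_local_cache(dataset_name, required_paths)` signature, and the boolean return value are assumptions made for the example.

```python
from pathlib import Path


class LocalCacheCheckSketch:
    """Hypothetical stand-in for the mixin (names and signature are assumed)."""

    local_files_only: bool = False

    def _check_local_cache(self, dataset_name: str, required_paths: list[Path]) -> bool:
        """Return True if all paths exist; raise when offline and something is missing."""
        missing = [p for p in required_paths if not p.exists()]
        if not missing:
            return True
        if self.local_files_only:
            # Name the dataset and the exact paths so the user knows what to populate.
            listing = "\n".join(f"  - {p}" for p in missing)
            raise FileNotFoundError(
                f"Dataset '{dataset_name}' is not cached locally and "
                f"local_files_only=True. Populate these paths:\n{listing}"
            )
        return False  # Cache is missing, but the caller may attempt a download.


class OfflineSourceSketch(LocalCacheCheckSketch):
    """Hypothetical source that wires the flag through its constructor."""

    def __init__(self, local_files_only: bool):
        self.local_files_only = local_files_only
```

A source would call the check before any network attempt and only download when it returns `False`.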
DataSourceModule ¤
DataSourceModule(config: StructuralConfig, *, rngs: Rngs | None = None, name: str | None = None)
Bases: StructuralModule
Enhanced base module for all Datarax data source components.
This class extends StructuralModule for non-parametric structural data loading. Concrete data sources define their own config classes extending StructuralConfig.
A DataSourceModule is responsible for reading data from an external source (e.g., files, memory, network) and yielding data elements as PyTrees. Each data element is typically a dictionary or other PyTree structure containing JAX arrays or Python primitives.
Important: When subclassing, if you store data containing JAX Arrays in
an attribute (like self.data), wrap the assigned value with nnx.Param,
or with nnx.data(value) at assignment time:
Examples:
    @dataclass(frozen=True)
    class MyDataSourceConfig(StructuralConfig):
        required_param: int | None = None

        def __post_init__(self):
            super().__post_init__()
            if self.required_param is None:
                raise ValueError("required_param is required")


    class MyDataSource(DataSourceModule):
        data: list[dict]

        def __init__(self, config: MyDataSourceConfig, data: list[dict], *,
                     rngs: nnx.Rngs | None = None, name: str | None = None):
            super().__init__(config, rngs=rngs, name=name)
            self.data = nnx.data(data)  # Mark as pytree data, not parameters.
This prevents NNX from trying to track individual JAX Arrays within the data structure as trainable parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `StructuralConfig` | Structural module configuration (already validated, frozen) | required |
| `rngs` | `Rngs \| None` | Random number generators (required if stochastic=True) | `None` |
| `name` | `str \| None` | Optional module name | `None` |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If stochastic=True but rngs is None |
get_batch_at ¤
Stateless indexed batch access for Pipeline-driven iteration.
Returns size records starting at start. Implementations must
be stateless (no mutation of internal counters) and JAX-traceable
(must accept tracer values for start) so the call composes with
nnx.scan.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `start` | `int \| Any` | Starting index. Sources that support indexed access accept a Python int or a traced value. | required |
| `size` | `int` | Number of records to return (a Python int, because JAX shapes are static). | required |
| `key` | `Any \| None` | Optional PRNG key for shuffled or stochastic sampling. | `None` |
Returns:
| Type | Description |
|---|---|
| `Any` | A batch dict (or PyTree) with leading dimension `size`. |
Raises:
| Type | Description |
|---|---|
| `NotImplementedError` | If the source does not support indexed access (e.g. forward-only streams). Pipeline falls back to its |
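The stateless, traceable contract can be illustrated with a small in-memory source. This is a sketch under assumptions: `InMemorySourceSketch` is a hypothetical class, it ignores `key`, and it uses `jax.lax.dynamic_slice_in_dim` as one possible way to accept a traced `start` while keeping `size` static.

```python
import jax
import jax.numpy as jnp


class InMemorySourceSketch:
    """Hypothetical in-memory source; only illustrates the get_batch_at contract."""

    def __init__(self, data: dict):
        self.data = data  # dict of arrays that share a leading dimension

    def get_batch_at(self, start, size: int, key=None):
        # No internal counter is read or written, so the call is stateless.
        # dynamic_slice_in_dim accepts a traced start index; size must remain a
        # Python int because output shapes are static under JAX transformations.
        return {
            name: jax.lax.dynamic_slice_in_dim(arr, start, size, axis=0)
            for name, arr in self.data.items()
        }


source = InMemorySourceSketch({"x": jnp.arange(10.0), "y": jnp.arange(10) * 2})

# Plain Python int as the start index.
batch = source.get_batch_at(2, 4)

# Traced start index: the same method works inside jax.jit.
jitted = jax.jit(lambda start: source.get_batch_at(start, 4))
traced_batch = jitted(jnp.asarray(2))
```

Note that `dynamic_slice_in_dim` clamps out-of-range start indices so the slice stays within the array, so a real source may need its own handling for the final partial batch.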
element_spec ¤
element_spec() -> Any
Return a PyTree of jax.ShapeDtypeStruct describing per-element output.
Downstream consumers (operators, batchers, models) use this contract to pre-allocate buffers, auto-size learnable layers, and statically validate operator chains. Subclasses MUST override this method.
Returns:
| Type | Description |
|---|---|
| `Any` | A PyTree (typically a dict) whose leaves are jax.ShapeDtypeStruct instances describing one emitted element. |
Raises:
| Type | Description |
|---|---|
| `NotImplementedError` | Always, on the base class. Subclasses must override. |
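One way a subclass might build such a spec is to map a representative element to `jax.ShapeDtypeStruct` leaves. The helper name `spec_from_element` and the example element below are hypothetical; a real source may instead declare its spec directly from its config.

```python
import jax
import jax.numpy as jnp


def spec_from_element(element):
    """Replace each array leaf of one element with a jax.ShapeDtypeStruct."""
    return jax.tree_util.tree_map(
        lambda leaf: jax.ShapeDtypeStruct(leaf.shape, leaf.dtype),
        element,
    )


# Hypothetical element layout, chosen only for illustration.
example_element = {
    "image": jnp.zeros((28, 28, 1), dtype=jnp.float32),
    "label": jnp.zeros((), dtype=jnp.int32),
}
spec = spec_from_element(example_element)
```

A consumer can then read `spec["image"].shape` and `spec["image"].dtype` to size buffers without touching any data.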
get_operation_stats ¤
reset_operation_stats ¤
Reset operation statistics to zero.
Note: Creates new JAX arrays to reset the counters.
compute_statistics ¤
Compute statistics from data using batch_stats_fn.
If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `Any` | Input data to compute statistics from | required |
Returns:
| Type | Description |
|---|---|
| `dict[str, Any] \| None` | Dictionary of statistics, or None if no batch_stats_fn configured |
get_statistics ¤
set_statistics ¤
reset_statistics ¤
Reset all statistics to None.
This clears the computed statistics and sets an internal flag so that precomputed_stats is ignored. After reset, get_statistics() returns None until new statistics are set or computed.
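The statistics lifecycle described above (compute, cache, reset) can be sketched in plain Python. `StatsHolderSketch`, the `_ignore_precomputed` flag name, and the priority of computed over precomputed statistics are assumptions for illustration, not the library's implementation.

```python
class StatsHolderSketch:
    """Hypothetical holder that mimics the documented statistics behaviour."""

    def __init__(self, batch_stats_fn=None, precomputed_stats=None):
        self.batch_stats_fn = batch_stats_fn
        self.precomputed_stats = precomputed_stats
        self._computed_stats = None
        self._ignore_precomputed = False  # assumed name for the internal flag

    def compute_statistics(self, data):
        if self.batch_stats_fn is None:
            return None  # nothing configured, nothing cached
        self._computed_stats = self.batch_stats_fn(data)
        return self._computed_stats

    def get_statistics(self):
        # Assumption: computed statistics take priority over precomputed ones.
        if self._computed_stats is not None:
            return self._computed_stats
        if not self._ignore_precomputed:
            return self.precomputed_stats
        return None

    def reset_statistics(self):
        self._computed_stats = None
        self._ignore_precomputed = True


holder = StatsHolderSketch(
    batch_stats_fn=lambda xs: {"mean": sum(xs) / len(xs)},
    precomputed_stats={"mean": 0.0},
)
stats = holder.compute_statistics([1.0, 2.0, 3.0])
holder.reset_statistics()
after_reset = holder.get_statistics()
```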
copy ¤
copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule
Create a copy of this module with optional config/parameter changes.
This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `DataraxModuleConfig \| None` | New config (if None, uses current config) | `None` |
| `rngs` | `Rngs \| None` | New RNG state (if None, uses current rngs) | `None` |
| `name` | `str \| None` | New name (if None, uses current name) | `None` |
Returns:
| Type | Description |
|---|---|
| `DataraxModule` | New module instance with updated parameters |
Examples:
    # Change configuration
    new_config = DataraxModuleConfig(cacheable=True)
    new_module = module.copy(config=new_config)

    # Change name only
    renamed = module.copy(name="new_name")
Note
Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().
get_state ¤
Get module state for checkpointing.
This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.
Returns:
| Type | Description |
|---|---|
| `dict[str, Any]` | A dictionary containing the internal state of the component. |
set_state ¤
Restore module state from a checkpoint.
This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `dict[str, Any]` | A dictionary containing the internal state to restore. | required |
Raises:
| Type | Description |
|---|---|
| `TypeError` | If state is not a dictionary. |
| `ValueError` | If checkpoint structure does not match module state. |
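The strict restoration rule can be sketched without NNX. The class below is hypothetical and compares only top-level keys; the real method works on NNX state and may validate structure more deeply.

```python
from typing import Any


class CheckpointableSketch:
    """Hypothetical component with two pieces of iteration state."""

    def __init__(self):
        self.position = 0
        self.epoch = 0

    def get_state(self) -> dict[str, Any]:
        return {"position": self.position, "epoch": self.epoch}

    def set_state(self, state: dict[str, Any]) -> None:
        if not isinstance(state, dict):
            raise TypeError(f"state must be a dict, got {type(state).__name__}")
        expected = set(self.get_state())
        if set(state) != expected:
            # Strict restore: reject missing or unexpected keys.
            raise ValueError(
                f"checkpoint keys {sorted(state)} do not match {sorted(expected)}"
            )
        self.position = state["position"]
        self.epoch = state["epoch"]


original = CheckpointableSketch()
original.position = 7
saved = original.get_state()

restored = CheckpointableSketch()
restored.set_state(saved)
```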
clone ¤
clone() -> DataraxModule
Create a new instance with the same state as this module.
Uses NNX's clone function for proper deep cloning of all state.
Returns:
| Type | Description |
|---|---|
| `DataraxModule` | A new module instance with the same state. |
requires_rng_streams ¤
ensure_rng_streams ¤
Ensure that the required RNG streams are available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `stream_names` | `list[str]` | A list of available RNG stream names. | required |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If a required RNG stream is not available. |
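A minimal sketch of the check, assuming `requires_rng_streams` returns the list of stream names the module needs. That return type, the class name, and the stream names are assumptions for illustration.

```python
class RngAwareSketch:
    """Hypothetical module that needs two named RNG streams."""

    def requires_rng_streams(self) -> list[str]:
        return ["shuffle", "augment"]  # hypothetical stream names

    def ensure_rng_streams(self, stream_names: list[str]) -> None:
        # Compare what the module needs against what the caller can provide.
        missing = [s for s in self.requires_rng_streams() if s not in stream_names]
        if missing:
            raise ValueError(
                f"Missing required RNG streams: {missing}; available: {stream_names}"
            )


module = RngAwareSketch()
module.ensure_rng_streams(["shuffle", "augment", "dropout"])  # passes silently
```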
process ¤
Process input structure.
This method transforms the structure/organization of input data without modifying the data values themselves.
Subclasses MUST implement this method.
The input/output types depend on the specific structural processor:
- Batcher: list[Element] -> list[Batch]
- Sampler: int -> list[int]
- Sharder: Batch -> Sharded[Batch]
- Splitter: Dataset -> tuple[Dataset, Dataset]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `input` | `Any` | Input to process (type varies by processor) | required |
| `*args` | `Any` | Additional positional arguments (processor-specific) | `()` |
| `**kwargs` | `Any` | Additional keyword arguments (processor-specific) | `{}` |
Returns:
| Type | Description |
|---|---|
| `Any` | Processed output (type varies by processor) |
Examples:
Batcher implementation:
    def process(self, elements: list[Element]) -> list[Batch]:
        batches = []
        for i in range(0, len(elements), self.config.batch_size):
            batch_elements = elements[i:i + self.config.batch_size]
            batches.append(Batch.from_elements(batch_elements))
        return batches
Sampler implementation (deterministic):
    def process(self, dataset_size: int) -> list[int]:
        return list(range(min(self.config.num_samples, dataset_size)))
Sampler implementation (stochastic):