Array Record Source¤
Data source for the ArrayRecord format (efficient record storage with random access by index).
See Also¤
- Sources Overview - All data sources
- Data Sources Guide - Full guide
- TFDS Source - TensorFlow Datasets
- Performance Tools - Optimization
datarax.sources.array_record_source ¤
Data source for reading from ArrayRecord format files.
ArrayRecordSourceConfig dataclass ¤
ArrayRecordSourceConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, stochastic: bool = False, stream_name: str | None = None, seed: int = 42, num_epochs: int = -1, shuffle_files: bool = False, local_files_only: bool = False)
Bases: StructuralConfig
Configuration for ArrayRecordSourceModule.
Inherits from StructuralConfig for runtime immutability.
Attributes:
| Name | Type | Description |
|---|---|---|
| seed | int | Random seed for shuffling (used internally, not by Grain). |
| num_epochs | int | Number of epochs (-1 for infinite). |
| shuffle_files | bool | Whether to shuffle file order (handled internally). |
| local_files_only | bool | If True, validate every path exists at construction time and raise |
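The runtime immutability mentioned above can be pictured with a frozen dataclass. The sketch below is a stand-in, not the datarax class: `SketchSourceConfig` is a hypothetical name, it mirrors only the four documented source fields, and it assumes immutability behaves like `dataclass(frozen=True)`, which this page does not confirm.

```python
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass(frozen=True)
class SketchSourceConfig:
    """Hypothetical stand-in mirroring the documented source fields."""

    seed: int = 42
    num_epochs: int = -1  # -1 means iterate indefinitely
    shuffle_files: bool = False
    local_files_only: bool = False


config = SketchSourceConfig(shuffle_files=True)

# Direct mutation is rejected once the instance exists.
try:
    config.seed = 7
    mutated = True
except FrozenInstanceError:
    mutated = False

# A changed configuration is a new object; the original is untouched.
two_epochs = replace(config, num_epochs=2)
```

Deriving a new config with `replace` rather than mutating in place keeps every module that shares the original config unaffected.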
precomputed_stats class-attribute instance-attribute ¤
ArrayRecordSourceModule ¤
ArrayRecordSourceModule(config: ArrayRecordSourceConfig, paths: str | list[str], *, rngs: Rngs | None = None, name: str | None = None)
Bases: DataSourceModule
Stateful wrapper for Grain's ArrayRecordDataSource.
This module wraps Grain's ArrayRecordDataSource and keeps its iteration state in NNX Variables.
Note: Grain's ArrayRecordDataSource doesn't accept a seed parameter directly. Shuffling is handled at the sampler level or through file ordering.
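Because the Grain source takes no seed, one way to get a reproducible file order is to shuffle the path list with a seeded generator before the source is built. This is a minimal sketch of that idea only: `ordered_files` is a hypothetical helper, and how the module actually applies `seed` and `shuffle_files` internally is not shown on this page.

```python
import random


def ordered_files(paths: list[str], *, shuffle_files: bool, seed: int) -> list[str]:
    """Return a file order that is reproducible for a given seed (hypothetical helper)."""
    files = sorted(paths)  # start from a canonical order
    if shuffle_files:
        random.Random(seed).shuffle(files)  # private RNG, global random state untouched
    return files


shards = ["data-00002.array_record", "data-00000.array_record", "data-00001.array_record"]
first = ordered_files(shards, shuffle_files=True, seed=42)
second = ordered_files(shards, shuffle_files=True, seed=42)
unshuffled = ordered_files(shards, shuffle_files=False, seed=42)
```

Sorting before shuffling means the result depends only on the seed and the set of files, not on the order in which the paths were listed.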
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ArrayRecordSourceConfig | Configuration for the source. | required |
| paths | str \| list[str] | Path pattern or list of paths to ArrayRecord files. | required |
| rngs | Rngs \| None | NNX Rngs for additional randomness. | None |
| name | str \| None | Optional name for the module. | None |
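Since `paths` accepts either a pattern or an explicit list, callers that need to inspect the file set up front may want to normalise both forms first. The sketch below shows one way to do that with the standard library; `expand_paths` is a hypothetical helper, and whether the module or Grain expands patterns the same way is an assumption.

```python
import glob
import os
import tempfile
from typing import Union


def expand_paths(paths: Union[str, list[str]]) -> list[str]:
    """Turn a glob pattern or an explicit list into a sorted list of paths (hypothetical helper)."""
    if isinstance(paths, str):
        return sorted(glob.glob(paths))
    return sorted(paths)


with tempfile.TemporaryDirectory() as tmp:
    for i in range(3):
        # Empty placeholder files; real shards would hold ArrayRecord data.
        open(os.path.join(tmp, f"shard-{i}.array_record"), "wb").close()
    from_pattern = expand_paths(os.path.join(tmp, "shard-*.array_record"))
    from_list = expand_paths(list(reversed(from_pattern)))
    names = [os.path.basename(p) for p in from_pattern]
```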
get_batch_at ¤
Stateless indexed batch access for Pipeline-driven iteration.
Returns size records starting at logical position start,
wrapping at the end of the dataset and applying any active
shuffle permutation. Does not advance self.current_index or
any other internal state.
ArrayRecord records are loaded host-side (Grain is a Python
library), so this method requires a concrete Python int for
start. Tier A (Python iteration) and Tier B (single step())
work today. Driving an ArrayRecord source under nnx.scan
(Tier C of the pipeline integration story) is not yet
supported: it would require wrapping the host-side fetch in
jax.experimental.io_callback, which is left as a future
enhancement.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| start | int | Concrete starting index (Python int). | required |
| size | int | Number of records to return. | required |
| key | Any \| None | Reserved for future shuffled-mode support; currently ignored (shuffle uses | None |
Returns:
| Type | Description |
|---|---|
| list[Any] | List of … Grain source. Records are typically Python dicts; callers (typically a parse / decode operator) handle structure. |
Raises:
| Type | Description |
|---|---|
| TypeError | If |
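The access pattern described for get_batch_at (wrap at the end of the dataset, apply a shuffle permutation, change no state) can be sketched over a plain Python sequence. `get_batch_at_sketch` and its `permutation` argument are hypothetical, and the check that rejects non-int values of `start` is an assumption based on the requirement for a concrete Python int.

```python
from typing import Any, Optional, Sequence


def get_batch_at_sketch(
    records: Sequence[Any],
    start: int,
    size: int,
    permutation: Optional[Sequence[int]] = None,
) -> list[Any]:
    """Return `size` records from logical position `start`, wrapping at the end.

    Hypothetical stand-in: it reads from a plain sequence and mutates nothing.
    """
    # Assumption: non-int starts (for example traced JAX values) are rejected.
    if isinstance(start, bool) or not isinstance(start, int):
        raise TypeError(f"start must be a concrete Python int, got {type(start).__name__}")
    n = len(records)
    batch = []
    for offset in range(size):
        logical = (start + offset) % n  # wrap past the last record
        physical = permutation[logical] if permutation is not None else logical
        batch.append(records[physical])
    return batch


records = [{"id": i} for i in range(5)]
wrapped = get_batch_at_sketch(records, start=3, size=4)
shuffled = get_batch_at_sketch(records, start=0, size=2, permutation=[4, 2, 0, 1, 3])
```

Because the function holds no cursor of its own, calling it twice with the same arguments returns the same batch, which is what lets an external pipeline drive the iteration.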
get_operation_stats ¤
reset_operation_stats ¤
Reset operation statistics to zero.
Note: Creates new JAX arrays to reset the counters.
compute_statistics ¤
Compute statistics from data using batch_stats_fn.
If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | Any | Input data to compute statistics from | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] \| None | Dictionary of statistics, or None if no batch_stats_fn configured |
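The compute-and-cache behaviour can be illustrated with a small stand-in class. `StatsSketch` and `mean_and_count` are hypothetical names. Only the two documented facts are modelled: the method returns None when no batch_stats_fn is configured, and the result is cached in `_computed_stats`.

```python
from typing import Any, Callable, Optional


class StatsSketch:
    """Hypothetical holder that mimics the documented compute-and-cache behaviour."""

    def __init__(self, batch_stats_fn: Optional[Callable[[Any], dict]] = None) -> None:
        self.batch_stats_fn = batch_stats_fn
        self._computed_stats: Optional[dict] = None

    def compute_statistics(self, data: Any) -> Optional[dict]:
        if self.batch_stats_fn is None:
            return None  # nothing configured, nothing to compute
        self._computed_stats = self.batch_stats_fn(data)  # cache the result
        return self._computed_stats


def mean_and_count(values: list[float]) -> dict:
    """Example statistics function: mean and number of values."""
    return {"mean": sum(values) / len(values), "count": len(values)}


with_fn = StatsSketch(batch_stats_fn=mean_and_count)
stats = with_fn.compute_statistics([1.0, 2.0, 3.0])
without_fn = StatsSketch()
missing = without_fn.compute_statistics([1.0, 2.0, 3.0])
```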
get_statistics ¤
set_statistics ¤
reset_statistics ¤
Reset all statistics to None.
This clears both computed statistics and marks that precomputed_stats should be ignored (via internal flag). After reset, get_statistics() will return None until new statistics are set or computed.
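The reset semantics are easier to follow with the bookkeeping written out. In this sketch the class name, the flag name `_ignore_precomputed`, and the rule that set or computed statistics take precedence over precomputed ones are all assumptions. Only the reset behaviour itself comes from the description above.

```python
from typing import Optional


class ResettableStats:
    """Hypothetical stand-in for the statistics bookkeeping described above."""

    def __init__(self, precomputed_stats: Optional[dict] = None) -> None:
        self.precomputed_stats = precomputed_stats
        self._computed_stats: Optional[dict] = None
        self._ignore_precomputed = False  # assumed name for the internal flag

    def get_statistics(self) -> Optional[dict]:
        # Assumption: freshly computed or set statistics win over precomputed ones.
        if self._computed_stats is not None:
            return self._computed_stats
        if self._ignore_precomputed:
            return None
        return self.precomputed_stats

    def set_statistics(self, stats: dict) -> None:
        self._computed_stats = stats

    def reset_statistics(self) -> None:
        self._computed_stats = None
        self._ignore_precomputed = True  # precomputed_stats stays set but is skipped


holder = ResettableStats(precomputed_stats={"mean": 0.5})
before = holder.get_statistics()
holder.reset_statistics()
after_reset = holder.get_statistics()
holder.set_statistics({"mean": 0.25})
after_set = holder.get_statistics()
```

Using a flag instead of overwriting `precomputed_stats` fits a config that cannot be mutated at runtime.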
copy ¤
copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule
Create a copy of this module with optional config/parameter changes.
This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | DataraxModuleConfig \| None | New config (if None, uses current config) | None |
| rngs | Rngs \| None | New RNG state (if None, uses current rngs) | None |
| name | str \| None | New name (if None, uses current name) | None |
Returns:
| Type | Description |
|---|---|
| DataraxModule | New module instance with updated parameters |
Examples:

    # Change configuration
    new_config = DataraxModuleConfig(cacheable=True)
    new_module = module.copy(config=new_config)

    # Change name only
    renamed = module.copy(name="new_name")

Note
Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().
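When only one config field should change, a common pattern is to derive the new config with `dataclasses.replace` and pass it to copy. The sketch below uses hypothetical stand-ins (`SketchConfig`, `SketchModule`) so that it runs on its own, and it assumes the real config is a dataclass that `replace` can handle. The rngs argument is left out for brevity.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class SketchConfig:
    """Hypothetical stand-in for a module config."""

    cacheable: bool = False
    seed: int = 42


class SketchModule:
    """Hypothetical module whose copy() follows the documented None-means-keep rule."""

    def __init__(self, config: SketchConfig, name: Optional[str] = None) -> None:
        self.config = config
        self.name = name

    def copy(
        self, *, config: Optional[SketchConfig] = None, name: Optional[str] = None
    ) -> "SketchModule":
        return SketchModule(
            config=self.config if config is None else config,
            name=self.name if name is None else name,
        )


module = SketchModule(SketchConfig(), name="source")
# Derive a config that differs in one field, then copy with it.
cached = module.copy(config=replace(module.config, cacheable=True))
renamed = module.copy(name="new_name")
```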
clone ¤
clone() -> DataraxModule
Create a new instance with the same state as this module.
Uses NNX's clone function for proper deep cloning of all state.
Returns:
| Type | Description |
|---|---|
| DataraxModule | A new module instance with the same state. |
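The practical effect of a deep clone is that the two instances start with equal state and then evolve independently. The sketch below shows this with `copy.deepcopy` on a hypothetical `CounterModule`. Using deepcopy in place of NNX's clone function is an assumption made so that the example needs only the standard library.

```python
import copy


class CounterModule:
    """Hypothetical stateful module; deepcopy stands in for NNX's clone here."""

    def __init__(self) -> None:
        self.state = {"current_index": 0, "epoch": 0}

    def advance(self, steps: int) -> None:
        self.state["current_index"] += steps

    def clone(self) -> "CounterModule":
        return copy.deepcopy(self)  # assumption: stands in for a deep NNX clone


original = CounterModule()
original.advance(3)
twin = original.clone()  # starts with the same state ...
twin.advance(2)  # ... and then diverges independently
```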
requires_rng_streams ¤
ensure_rng_streams ¤
Ensure that the required RNG streams are available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stream_names | list[str] | A list of available RNG stream names. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If a required RNG stream is not available. |
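The check amounts to comparing the streams a module requires against the streams that are available. A minimal sketch follows. `ensure_streams` is a hypothetical free function that takes both lists explicitly, whereas the real method presumably gets the required names from the module itself.

```python
def ensure_streams(required: list[str], available: list[str]) -> None:
    """Raise ValueError if any required stream is missing (hypothetical helper)."""
    missing = [name for name in required if name not in available]
    if missing:
        raise ValueError(f"Missing required RNG streams: {missing}; available: {available}")


ensure_streams(required=["shuffle"], available=["shuffle", "augment"])  # passes silently

try:
    ensure_streams(required=["shuffle", "dropout"], available=["shuffle"])
    error_message = ""
except ValueError as exc:
    error_message = str(exc)
```

Listing every missing stream in one message saves the caller from fixing them one failure at a time.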
process ¤
Process input structure.
This method transforms the structure/organization of input data without modifying the data values themselves.
Subclasses MUST implement this method.
The input/output types depend on the specific structural processor:
- Batcher: list[Element] -> list[Batch]
- Sampler: int -> list[int]
- Sharder: Batch -> Sharded[Batch]
- Splitter: Dataset -> tuple[Dataset, Dataset]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input | Any | Input to process (type varies by processor) | required |
| *args | Any | Additional positional arguments (processor-specific) | () |
| **kwargs | Any | Additional keyword arguments (processor-specific) | {} |
Returns:
| Type | Description |
|---|---|
| Any | Processed output (type varies by processor) |
Examples:
Batcher implementation:

    def process(self, elements: list[Element]) -> list[Batch]:
        batches = []
        for i in range(0, len(elements), self.config.batch_size):
            batch_elements = elements[i:i + self.config.batch_size]
            batches.append(Batch.from_elements(batch_elements))
        return batches

Sampler implementation (deterministic):

    def process(self, dataset_size: int) -> list[int]:
        return list(range(min(self.config.num_samples, dataset_size)))

Sampler implementation (stochastic):
element_spec ¤
element_spec() -> Any
Return a PyTree of jax.ShapeDtypeStruct describing per-element output.
Downstream consumers (operators, batchers, models) use this contract to pre-allocate buffers, auto-size learnable layers, and statically validate operator chains. Subclasses MUST override this method.
Returns:
| Type | Description |
|---|---|
| Any | A PyTree (typically a dict) whose leaves are jax.ShapeDtypeStruct instances describing one emitted element. |
Raises:
| Type | Description |
|---|---|
| NotImplementedError | Always, on the base class. Subclasses must override. |
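To make the contract concrete, the sketch below shows a subclass overriding element_spec and a consumer using the spec to size a buffer before any data is read. Everything here is hypothetical: `ShapeDtype` is a standard-library stand-in for jax.ShapeDtypeStruct (used so the example runs without JAX), and `BaseSource`, `ImageSource` and `flat_size` are illustrative names.

```python
from typing import Any, NamedTuple


class ShapeDtype(NamedTuple):
    """Hypothetical stdlib stand-in for jax.ShapeDtypeStruct."""

    shape: tuple
    dtype: str


class BaseSource:
    def element_spec(self) -> Any:
        raise NotImplementedError("Subclasses must override element_spec().")


class ImageSource(BaseSource):
    """Hypothetical source emitting 28x28 images with an integer label."""

    def element_spec(self) -> dict:
        return {
            "image": ShapeDtype(shape=(28, 28, 1), dtype="uint8"),
            "label": ShapeDtype(shape=(), dtype="int32"),
        }


def flat_size(spec: dict) -> int:
    """Count scalars per element, e.g. to size a buffer before any data is read."""
    total = 0
    for leaf in spec.values():
        count = 1
        for dim in leaf.shape:
            count *= dim
        total += count
    return total


spec = ImageSource().element_spec()
scalars_per_element = flat_size(spec)
```

With real jax.ShapeDtypeStruct leaves the same traversal would more likely go through JAX's tree utilities than a manual loop over dict values.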