HuggingFace Source

HFEagerSource integrates HuggingFace Datasets with Datarax. It loads a dataset from the Hub (which hosts 100,000+ datasets) into your Datarax pipeline and converts its fields to JAX arrays automatically. For datasets too large to hold in memory, use HFStreamingSource instead.

Note: You can also use the factory function from_hf(name, split, ...), which selects eager or streaming mode based on your configuration.

Key Features

  • Automatic conversion: Python lists, NumPy arrays, and PIL images → JAX arrays
  • Streaming support: load large datasets without downloading everything (HFStreamingSource)
  • Shuffling: O(1)-memory index shuffling for eager sources; buffer-based shuffling with a configurable buffer size for streaming sources
  • Key filtering: include or exclude specific dataset fields
  • Stateful iteration: track position and epoch, and retrieve whole batches

★ Insight ─────────────────────────────────────

  • HFEagerSource wraps the datasets library for JAX-native workflows
  • PIL images are automatically converted to JAX arrays
  • For datasets too large to download or hold in memory, use HFStreamingSource with streaming=True
  • get_batch() retrieves a whole batch in one call

─────────────────────────────────────────────────

Installation

Both HFEagerSource and HFStreamingSource require the datasets package:

pip install "datarax[hf]"
# or
pip install datasets

Quick Start

import flax.nnx as nnx
from datarax.sources import HFEagerSource
from datarax.sources.hf_source import HFEagerConfig

# Load IMDB sentiment dataset
config = HFEagerConfig(name="imdb", split="train")
source = HFEagerSource(config, rngs=nnx.Rngs(0))

# Iterate over elements
for item in source:
    text = item["text"]
    label = item["label"]
    process(text, label)  # process() stands in for your own code

Batch Retrieval

For training loops, use the stateful get_batch() method:

# Get batches of 32 samples
batch = source.get_batch(32)
# batch["text"] has shape (32,)
# batch["label"] has shape (32,)

# Automatic epoch cycling
for step in range(1000):
    batch = source.get_batch(32)
    train_step(batch)
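
The source tracks its position in index and epoch variables (listed in the API reference below). As a rough mental model only, the hypothetical StatefulCursor below shows how a stateful cursor can wrap at the end of the data and bump the epoch; it is not the datarax class, and the real get_batch() also handles shuffling and may differ in detail.

```python
from dataclasses import dataclass


@dataclass
class StatefulCursor:
    """Hypothetical model of a stateful batch cursor (not the datarax class)."""

    length: int
    index: int = 0
    epoch: int = 0

    def next_indices(self, batch_size: int) -> list[int]:
        # Take batch_size positions, wrapping to the start of the data and
        # incrementing the epoch each time the end is crossed.
        out = []
        for _ in range(batch_size):
            out.append(self.index)
            self.index += 1
            if self.index == self.length:
                self.index = 0
                self.epoch += 1
        return out


cursor = StatefulCursor(length=5)
print(cursor.next_indices(3))  # [0, 1, 2]
print(cursor.next_indices(3))  # [3, 4, 0]
print(cursor.epoch)            # 1
```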

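Streaming Large Datasets

For datasets too large to download completely, use HFStreamingSource with streaming=True. HFEagerConfig has no streaming option, because the eager source loads everything into JAX arrays at initialization:

from datarax.sources.hf_source import HFStreamingConfig, HFStreamingSource

config = HFStreamingConfig(
    name="allenai/c4",  # C4: a web-scale corpus derived from Common Crawl
    split="train",
    streaming=True,
)
source = HFStreamingSource(config)

# Data is fetched on demand
for item in source:
    process(item)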

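Shuffling

HFEagerSource shuffles with Grain's index_shuffle, so it needs no shuffle buffer; seed is a plain integer:

config = HFEagerConfig(
    name="mnist",
    split="train",
    shuffle=True,
    seed=42,  # integer seed for index_shuffle, not a JAX key
)
source = HFEagerSource(config, rngs=nnx.Rngs(42))

For HFStreamingSource, set shuffle=True and shuffle_buffer_size on HFStreamingConfig instead; shuffling then happens within a buffer of that many elements.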

Field Filtering

Select only the fields you need. include_keys and exclude_keys are mutually exclusive, so set at most one of them:

# Include only specific fields
config = HFEagerConfig(
    name="stanfordnlp/sst2",  # columns: idx, sentence, label
    split="train",
    include_keys={"sentence", "label"},  # only these fields
)

# Or exclude unwanted fields
config = HFEagerConfig(
    name="stanfordnlp/sst2",
    split="train",
    exclude_keys={"idx"},  # everything except idx
)

Integration with DAG Pipelines

import flax.nnx as nnx
from datarax.pipeline import Pipeline
from datarax.sources import HFEagerSource
from datarax.sources.hf_source import HFEagerConfig

# Build a pipeline (normalize_op, augment_op and train_step are your own)
config = HFEagerConfig(name="mnist", split="train")
source = HFEagerSource(config)

pipeline = Pipeline(
    source=source,
    stages=[normalize_op, augment_op],
    batch_size=64,
    rngs=nnx.Rngs(0),
)

for batch in pipeline:
    train_step(batch)

Dataset Information

Access metadata about the loaded dataset:

# Get HuggingFace DatasetInfo
info = source.get_dataset_info()
print(f"Description: {info.description}")
print(f"Features: {info.features}")

# Check length (if available)
print(f"Dataset length: {len(source)}")

API Reference

datarax.sources.hf_source

HuggingFace Datasets data source implementation for Datarax.

This module provides two distinct source types optimized for different use cases:

HFEagerSource: for small/medium datasets that fit in memory (up to roughly 10% of device VRAM)
  • Loads ALL data to JAX arrays at initialization
  • Pure JAX iteration after init (no HuggingFace overhead during training)
  • O(1)-memory shuffling via Grain's index_shuffle (Feistel cipher)
  • Fully checkpointable (just indices, no external state)
  • Ideal for: MNIST, CIFAR-10, sentiment datasets, small custom datasets

HFStreamingSource: for large datasets that don't fit in memory
  • Thin wrapper around the HuggingFace dataset iterator
  • Supports HuggingFace's built-in streaming mode
  • Trade-offs: external iterator state, so it cannot checkpoint mid-epoch
  • Ideal for: The Pile, C4, large-scale datasets, memory-constrained environments

Architecture Insight

The separation between eager and streaming follows the same pattern as TFDS, ensuring consistent behavior across data backends while optimizing for each use case's specific requirements.

logger module-attribute

logger = getLogger(__name__)

HFEagerConfig dataclass

HFEagerConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, stochastic: bool = False, stream_name: str | None = None, name: str | None = None, split: str | None = None, data_dir: str | None = None, include_keys: set[str] | None = None, exclude_keys: set[str] | None = None, shuffle: bool = False, seed: int = 42, download_kwargs: dict[str, Any] | None = None, local_files_only: bool = False)

Bases: SourceConfigBase

Configuration for HFEagerSource (loads all data to JAX at init).

Configuration for eager-loading HuggingFace datasets into JAX arrays.

Parameters:

  • name (str | None, default None): Name of the dataset in HuggingFace Hub (required)
  • split (str | None, default None): Split of the dataset to load, e.g., "train", "test" (required)
  • data_dir (str | None, default None): Optional directory where the dataset is stored/downloaded
  • shuffle (bool, default False): Whether to shuffle the dataset during iteration
  • seed (int, default 42): Integer seed for Grain's index_shuffle
  • download_kwargs (dict[str, Any] | None, default None): Optional keyword arguments for load_dataset
  • include_keys (set[str] | None, default None): Optional set of keys to include in output (exclusive with exclude_keys)
  • exclude_keys (set[str] | None, default None): Optional set of keys to exclude from output (exclusive with include_keys)

Note

The seed parameter is an integer (not a JAX RNG key) because it seeds Grain's index_shuffle. This enables O(1)-memory shuffling and reproducible per-epoch seeds.
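
Why an integer seed is enough for O(1)-memory shuffling: a keyed Feistel network maps each index to a shuffled index on demand, so no permutation array is ever stored. The toy sketch below uses cycle-walking to fit an arbitrary dataset length. It is not Grain's index_shuffle; the SHA-256 round function and the round count are arbitrary choices made here.

```python
import hashlib


def _round_fn(value: int, seed: int, rnd: int, half_bits: int) -> int:
    # Keyed pseudo-random round function (toy choice: SHA-256 of seed/round/value).
    digest = hashlib.sha256(f"{seed}:{rnd}:{value}".encode()).digest()
    return int.from_bytes(digest[:8], "big") & ((1 << half_bits) - 1)


def feistel_index(index: int, length: int, seed: int, rounds: int = 4) -> int:
    """Return the shuffled position of `index` in [0, length) without
    materializing a permutation (toy sketch, not Grain's index_shuffle)."""
    if not 0 <= index < length:
        raise IndexError(f"index {index} out of range for length {length}")
    # Split a (2 * half_bits)-bit word into two halves; 2**(2*half_bits) >= length.
    half_bits = max(1, ((length - 1).bit_length() + 1) // 2)
    mask = (1 << half_bits) - 1
    x = index
    while True:
        left, right = x >> half_bits, x & mask
        for rnd in range(rounds):
            # One Feistel round: invertible for any round function.
            left, right = right, left ^ _round_fn(right, seed, rnd, half_bits)
        x = (left << half_bits) | right
        # Cycle-walking: the network permutes [0, 2**(2*half_bits)), so keep
        # applying it until the value falls back inside [0, length).
        if x < length:
            return x


# Only the seed and the current index are needed: memory is O(1) in `length`.
order = [feistel_index(i, 10, seed=42) for i in range(10)]
print(sorted(order) == list(range(10)))  # True: each position appears exactly once
```

Deriving a different seed per epoch (for example seed + epoch, an assumption made here for illustration) would give each epoch a fresh but reproducible order.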

shuffle class-attribute instance-attribute

shuffle: bool = False

seed class-attribute instance-attribute

seed: int = 42

download_kwargs class-attribute instance-attribute

download_kwargs: dict[str, Any] | None = None

local_files_only class-attribute instance-attribute

local_files_only: bool = False

cacheable class-attribute instance-attribute

cacheable: bool = False

batch_stats_fn class-attribute instance-attribute

batch_stats_fn: Callable | Module | None = None

precomputed_stats class-attribute instance-attribute

precomputed_stats: dict[str, Any] | None = None

stochastic class-attribute instance-attribute

stochastic: bool = False

stream_name class-attribute instance-attribute

stream_name: str | None = None

name class-attribute instance-attribute

name: str | None = None

split class-attribute instance-attribute

split: str | None = None

data_dir class-attribute instance-attribute

data_dir: str | None = None

include_keys class-attribute instance-attribute

include_keys: set[str] | None = None

exclude_keys class-attribute instance-attribute

exclude_keys: set[str] | None = None

HFStreamingConfig dataclass

HFStreamingConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, stochastic: bool = False, stream_name: str | None = None, name: str | None = None, split: str | None = None, data_dir: str | None = None, include_keys: set[str] | None = None, exclude_keys: set[str] | None = None, streaming: bool = False, shuffle: bool = False, shuffle_buffer_size: int = 1000, download_kwargs: dict[str, Any] | None = None, local_files_only: bool = False)

Bases: SourceConfigBase

Configuration for HFStreamingSource (streams data from HF dataset).

Use this for datasets too large to fit in memory or when using HuggingFace's built-in streaming mode for efficient data loading.

Parameters:

  • name (str | None, default None): Name of the dataset in HuggingFace Hub (required)
  • split (str | None, default None): Split of the dataset to load, e.g., "train", "test" (required)
  • data_dir (str | None, default None): Optional directory where the dataset is stored/downloaded
  • streaming (bool, default False): Whether to use HuggingFace streaming mode
  • shuffle (bool, default False): Whether to shuffle the dataset
  • shuffle_buffer_size (int, default 1000): Buffer size for shuffling in streaming mode
  • download_kwargs (dict[str, Any] | None, default None): Optional keyword arguments for load_dataset
  • include_keys (set[str] | None, default None): Optional set of keys to include in output
  • exclude_keys (set[str] | None, default None): Optional set of keys to exclude from output
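
shuffle_buffer_size trades memory for shuffle quality. Below is a minimal sketch of the standard buffer-shuffle technique (the same idea HuggingFace uses for iterable datasets, but not its code); buffer_shuffle is a hypothetical helper.

```python
import random
from collections.abc import Iterable, Iterator
from typing import TypeVar

T = TypeVar("T")


def buffer_shuffle(stream: Iterable[T], buffer_size: int, seed: int) -> Iterator[T]:
    """Approximately shuffle a stream while holding at most buffer_size items."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be >= 1")
    rng = random.Random(seed)
    buffer: list[T] = []
    for item in stream:
        if len(buffer) < buffer_size:
            buffer.append(item)  # fill the buffer before emitting anything
            continue
        # Emit a random buffered element and put the new item in its slot.
        j = rng.randrange(buffer_size)
        yield buffer[j]
        buffer[j] = item
    # The stream is exhausted: shuffle and drain what is left in the buffer.
    rng.shuffle(buffer)
    yield from buffer


shuffled = list(buffer_shuffle(range(10), buffer_size=4, seed=0))
print(sorted(shuffled) == list(range(10)))  # True: nothing lost or duplicated
```

A larger buffer approximates a full shuffle more closely at the cost of memory; with buffer_size=1 the stream order is unchanged.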

streaming class-attribute instance-attribute

streaming: bool = False

shuffle class-attribute instance-attribute

shuffle: bool = False

shuffle_buffer_size class-attribute instance-attribute

shuffle_buffer_size: int = 1000

download_kwargs class-attribute instance-attribute

download_kwargs: dict[str, Any] | None = None

local_files_only class-attribute instance-attribute

local_files_only: bool = False

cacheable class-attribute instance-attribute

cacheable: bool = False

batch_stats_fn class-attribute instance-attribute

batch_stats_fn: Callable | Module | None = None

precomputed_stats class-attribute instance-attribute

precomputed_stats: dict[str, Any] | None = None

stochastic class-attribute instance-attribute

stochastic: bool = False

stream_name class-attribute instance-attribute

stream_name: str | None = None

name class-attribute instance-attribute

name: str | None = None

split class-attribute instance-attribute

split: str | None = None

data_dir class-attribute instance-attribute

data_dir: str | None = None

include_keys class-attribute instance-attribute

include_keys: set[str] | None = None

exclude_keys class-attribute instance-attribute

exclude_keys: set[str] | None = None

HFEagerSource

HFEagerSource(config: HFEagerConfig, *, rngs: Rngs | None = None, name: str | None = None)

Bases: EagerSourceBase

Eager-loading HuggingFace source for small/medium datasets.

Loads ALL data to JAX arrays at initialization, then operates like a MemorySource with pure JAX operations. Use for datasets that fit in ~10% of device VRAM.

Key Features
  • One-time conversion at init (PIL→numpy→JAX for images)
  • Pure JAX iteration after init
  • O(1) memory shuffling via Grain's index_shuffle (Feistel cipher)
  • Full checkpointing support (indices only, no external state)
  • Automatic PIL Image to JAX array conversion
Performance
  • Training loops can use lax.fori_loop for 100-500x speedup
  • Device placement via collect_to_array() for staged training
Example
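import jax
import flax.nnx as nnx
from datarax.sources.hf_source import HFEagerConfig, HFEagerSource

# Create eager source for MNIST from HuggingFace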
config = HFEagerConfig(name="mnist", split="train", shuffle=True)
source = HFEagerSource(config, rngs=nnx.Rngs(0))

# Iterate - pure JAX, no HF overhead
for item in source:
    process(item["image"])

# Get batch (stateless with key, or stateful without)
batch = source.get_batch(32)  # Stateful
batch = source.get_batch(32, key=jax.random.key(0))  # Stateless

Parameters:

  • config (HFEagerConfig, required): Configuration for the source
  • rngs (Rngs | None, default None): Optional RNG state for shuffling
  • name (str | None, default None): Optional name (defaults to HFEagerSource(dataset:split))

Raises:

  • ImportError: If the datasets package is not installed

dataset_name instance-attribute

dataset_name = name

split_name instance-attribute

split_name = split

include_keys instance-attribute

include_keys = include_keys

exclude_keys instance-attribute

exclude_keys = exclude_keys

data instance-attribute

data: dict[str, Any] = data(_load_all_from_backend_to_jax(config))

length instance-attribute

length = _infer_hf_column_length(data)

index instance-attribute

index = Variable(0)

epoch instance-attribute

epoch = Variable(0)

config instance-attribute

config = static(config)

rngs instance-attribute

rngs = rngs

name instance-attribute

name = static(name)

stochastic instance-attribute

stochastic = stochastic

stream_name instance-attribute

stream_name = stream_name

is_random_order property

is_random_order: bool

Whether iteration order is randomized.

get_operation_stats

get_operation_stats() -> dict[str, int]

Get operation statistics.

Note: This method converts JAX arrays to Python ints for introspection. It is intended for use outside of JIT-compiled functions.

Returns:

  • dict[str, int]: Dictionary with 'applied_count' and 'skipped_count'

reset_operation_stats

reset_operation_stats() -> None

Reset operation statistics to zero.

Note: Creates new JAX arrays to reset the counters.

compute_statistics

compute_statistics(data: Any) -> dict[str, Any] | None

Compute statistics from data using batch_stats_fn.

If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.

Parameters:

  • data (Any, required): Input data to compute statistics from

Returns:

  • dict[str, Any] | None: Dictionary of statistics, or None if no batch_stats_fn configured

get_statistics

get_statistics() -> dict[str, Any] | None

Get current statistics.

Returns precomputed_stats if configured (unless reset was called), otherwise returns cached computed statistics, or None if no statistics available.

Returns:

  • dict[str, Any] | None: Dictionary of statistics, or None if no statistics available

set_statistics

set_statistics(stats: dict[str, Any]) -> None

Manually set statistics.

This overwrites any previously computed statistics and clears the reset flag.

Parameters:

  • stats (dict[str, Any], required): Dictionary of statistics to set

reset_statistics

reset_statistics() -> None

Reset all statistics to None.

This clears both computed statistics and marks that precomputed_stats should be ignored (via internal flag). After reset, get_statistics() will return None until new statistics are set or computed.

reset_cache

reset_cache() -> None

Clear the cache.

Only has effect if cacheable=True in config.

copy

copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule

Create a copy of this module with optional config/parameter changes.

This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.

Parameters:

  • config (DataraxModuleConfig | None, default None): New config (if None, uses current config)
  • rngs (Rngs | None, default None): New RNG state (if None, uses current rngs)
  • name (str | None, default None): New name (if None, uses current name)

Returns:

  • DataraxModule: New module instance with updated parameters

Examples:

# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)

# Change name only
renamed = module.copy(name="new_name")

Note

Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().

get_state

get_state() -> dict[str, Any]

Get module state for checkpointing.

This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.

Returns:

  • dict[str, Any]: A dictionary containing the internal state of the component.

set_state

set_state(state: dict[str, Any]) -> None

Restore module state from a checkpoint.

This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.

Parameters:

  • state (dict[str, Any], required): A dictionary containing the internal state to restore.

Raises:

  • TypeError: If state is not a dictionary.
  • ValueError: If checkpoint structure does not match module state.
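
Strict restoration means the checkpoint must have exactly the same structure as the module's state. The hypothetical check_state_structure helper below sketches that rule on plain nested dicts; the real set_state works on NNX state and its exact checks may differ.

```python
from typing import Any


def check_state_structure(expected: dict[str, Any], state: Any, path: str = "") -> None:
    """Hypothetical strict check: raise unless `state` mirrors `expected`'s keys."""
    where = path or "/"
    if not isinstance(state, dict):
        raise TypeError(f"state at '{where}' must be a dict, got {type(state).__name__}")
    if set(state) != set(expected):
        missing = sorted(set(expected) - set(state))
        extra = sorted(set(state) - set(expected))
        raise ValueError(f"structure mismatch at '{where}': missing={missing}, extra={extra}")
    # Recurse into nested sub-states so the whole tree must match.
    for key, value in expected.items():
        if isinstance(value, dict):
            check_state_structure(value, state[key], f"{path}/{key}")


module_state = {"index": 0, "epoch": 0}
check_state_structure(module_state, {"index": 7, "epoch": 2})  # passes silently
```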

clone

clone() -> DataraxModule

Create a new instance with the same state as this module.

Uses NNX's clone function for proper deep cloning of all state.

Returns:

  • DataraxModule: A new module instance with the same state.

requires_rng_streams

requires_rng_streams() -> list[str] | None

Get the list of RNG streams required by this module.

Returns:

  • list[str] | None: A list of required RNG stream names, or None if no RNG streams are required.

ensure_rng_streams

ensure_rng_streams(stream_names: list[str]) -> None

Ensure that the required RNG streams are available.

Parameters:

  • stream_names (list[str], required): A list of available RNG stream names.

Raises:

  • ValueError: If a required RNG stream is not available.

process

process(input: Any, *args: Any, **kwargs: Any) -> Any

Process input structure.

This method transforms the structure/organization of input data without modifying the data values themselves.

Subclasses MUST implement this method.

The input/output types depend on the specific structural processor:

  • Batcher: list[Element] -> list[Batch]
  • Sampler: int -> list[int]
  • Sharder: Batch -> Sharded[Batch]
  • Splitter: Dataset -> tuple[Dataset, Dataset]

Parameters:

  • input (Any, required): Input to process (type varies by processor)
  • *args (Any, default ()): Additional positional arguments (processor-specific)
  • **kwargs (Any, default {}): Additional keyword arguments (processor-specific)

Returns:

  • Any: Processed output (type varies by processor)

Examples:

Batcher implementation:

def process(self, elements: list[Element]) -> list[Batch]:
    batches = []
    for i in range(0, len(elements), self.config.batch_size):
        batch_elements = elements[i:i + self.config.batch_size]
        batches.append(Batch.from_elements(batch_elements))
    return batches

Sampler implementation (deterministic):

def process(self, dataset_size: int) -> list[int]:
    return list(range(min(self.config.num_samples, dataset_size)))

Sampler implementation (stochastic):

def process(self, dataset_size: int) -> list[int]:
    rng = self.rngs[self.config.stream_name]()
    indices = jax.random.choice(
        rng, dataset_size, shape=(self.config.num_samples,),
        replace=self.config.replacement
    )
    return indices.tolist()

get_batch_at

get_batch_at(start: int | Array, size: int, key: Array | None = None) -> dict[str, Any]

Stateless indexed batch access; JIT-traceable for scan-based iteration.

Returns size records starting at logical position start. Does not advance self.index or any other internal state, so callers (typically Pipeline) can drive iteration via their own position counter and trace get_batch_at under nnx.scan / nnx.jit.

Two modes:

  • Sequential (self.is_random_order == False): returns the contiguous slice data[start : start + size] with wrap-around at the end of the source.
  • Shuffled (self.is_random_order == True): applies a deterministic permutation derived from key and returns the slice of that permutation. Same (start, size, key) always returns the same output. The permutation is materialized via jax.random.permutation(key, length) per call — O(length) per batch.

Parameters:

  • start (int | Array, required): Starting logical index; accepts a concrete int or a traced jax.Array.
  • size (int, required): Number of records to return (Python int; JAX shapes are static).
  • key (Array | None, default None): PRNG key for shuffled mode. Required when is_random_order=True; ignored otherwise.

Returns:

  • dict[str, Any]: Dict mapping each data key to a JAX array with leading dimension size.

element_spec

element_spec() -> Any

Derive per-element spec from the eager dict-of-arrays storage.

EagerSourceBase subclasses store data as a dict mapping keys to arrays whose leading axis is the dataset size. This default implementation strips that leading axis from every leaf to produce one jax.ShapeDtypeStruct per key.

Subclasses with non-dict storage should override.

Raises:

  • ValueError: If the source is empty.

get_batch

get_batch(batch_size: int, key: Array | None = None) -> dict[str, Any]

Get one eager batch in stateful or stateless mode.

get_dataset_info

get_dataset_info() -> Any

Return cached backend-specific dataset metadata.

reset

reset(seed: int | None = None) -> None

Reset eager-source iteration state.

set_random_order

set_random_order(enabled: bool) -> None

Update runtime random-order behavior.

HFStreamingSource

HFStreamingSource(config: HFStreamingConfig, *, rngs: Rngs | None = None, name: str | None = None)

Bases: StreamingSourceBase

Streaming HuggingFace source for large datasets.

Thin wrapper around HuggingFace dataset for data that can't fit in memory. Supports HuggingFace's native streaming mode for efficient large-scale data loading.

Key Features
  • Native HuggingFace streaming support
  • Automatic PIL Image to JAX array conversion
  • include_keys/exclude_keys filtering
  • Revision pinning for security (B615)
Trade-offs vs Eager
  • Cannot checkpoint mid-epoch (external iterator state)
  • Use with prefetch_to_device() for best results
Example
# Create streaming source for large dataset
config = HFStreamingConfig(name="allenai/c4", split="train", streaming=True)
source = HFStreamingSource(config, rngs=nnx.Rngs(0))

# Iterate with prefetching
for batch in prefetch_to_device(source, size=2):
    train_step(batch)

Parameters:

  • config (HFStreamingConfig, required): Configuration for the source
  • rngs (Rngs | None, default None): Optional RNG state
  • name (str | None, default None): Optional name (defaults to HFStreamingSource(dataset:split))

Raises:

  • ImportError: If the datasets package is not installed

config instance-attribute

epoch instance-attribute

epoch = Variable(0)

dataset_name property

dataset_name: str | None

Dataset name from source config.

split_name property

split_name: str | None

Dataset split from source config.

is_iterable_mode property

is_iterable_mode: bool

Streaming mode flag from source config.

random_order_buffer_depth property

random_order_buffer_depth: int

Shuffle buffer size from source config.

selected_keys property

selected_keys: set[str] | None

Optional key-include filter from source config.

rejected_keys property

rejected_keys: set[str] | None

Optional key-exclude filter from source config.

length property

length: int | None

Dataset length when known.

rngs instance-attribute

rngs = rngs

name instance-attribute

name = static(name)

stochastic instance-attribute

stochastic = stochastic

stream_name instance-attribute

stream_name = stream_name

is_random_order property

is_random_order: bool

Whether iteration order is randomized.

element_spec

element_spec() -> Any

Return per-element shape/dtype derived by peeking the backend.

Streaming sources cannot strip a leading dataset-size dimension because each iteration yields one element. The spec is derived by peeking the first element from the underlying HuggingFace dataset (without consuming the iterator state for normal training) and converting each top-level value into a single ShapeDtypeStruct.

Top-level dict values are treated as single arrays (HuggingFace commonly emits Python lists for vector features; those become 1-D arrays, not nested per-element scalars).

The peek operates on the cached self._hf_dataset (already loaded in __init__) so it does not re-trigger downloads and is safe to call repeatedly.
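
A hypothetical peek_spec helper illustrates the peeking approach on plain Python data: it calls iter() on the cached, re-iterable dataset to read the first element and turns nested lists into array shapes. Dtypes are left out here, whereas the real method produces jax.ShapeDtypeStruct values.

```python
from collections.abc import Iterable
from typing import Any


def infer_shape(value: Any) -> tuple[int, ...]:
    # Nested Python lists become array dimensions; scalars have shape ().
    shape: list[int] = []
    while isinstance(value, (list, tuple)):
        shape.append(len(value))
        value = value[0] if value else None
    return tuple(shape)


def peek_spec(dataset: Iterable[dict[str, Any]]) -> dict[str, tuple[int, ...]]:
    """Hypothetical: derive per-key shapes from the first element of a re-iterable dataset."""
    # iter() creates a fresh iterator over the cached dataset, so any iterator
    # already being used for training is not advanced.
    first = next(iter(dataset))
    return {key: infer_shape(value) for key, value in first.items()}


cached = [{"input_ids": [101, 2023, 102], "label": 1}, {"input_ids": [101, 102], "label": 0}]
print(peek_spec(cached))  # {'input_ids': (3,), 'label': ()}
```

Because only the first element is inspected, the spec reflects that element's shapes; in the example, the second record's shorter input_ids does not match it.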

get_operation_stats

get_operation_stats() -> dict[str, int]

Get operation statistics.

Note: This method converts JAX arrays to Python ints for introspection. It is intended for use outside of JIT-compiled functions.

Returns:

  • dict[str, int]: Dictionary with 'applied_count' and 'skipped_count'

reset_operation_stats

reset_operation_stats() -> None

Reset operation statistics to zero.

Note: Creates new JAX arrays to reset the counters.

compute_statistics

compute_statistics(data: Any) -> dict[str, Any] | None

Compute statistics from data using batch_stats_fn.

If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.

Parameters:

  • data (Any, required): Input data to compute statistics from

Returns:

  • dict[str, Any] | None: Dictionary of statistics, or None if no batch_stats_fn configured

get_statistics

get_statistics() -> dict[str, Any] | None

Get current statistics.

Returns precomputed_stats if configured (unless reset was called), otherwise returns cached computed statistics, or None if no statistics available.

Returns:

  • dict[str, Any] | None: Dictionary of statistics, or None if no statistics available

set_statistics

set_statistics(stats: dict[str, Any]) -> None

Manually set statistics.

This overwrites any previously computed statistics and clears the reset flag.

Parameters:

  • stats (dict[str, Any], required): Dictionary of statistics to set

reset_statistics

reset_statistics() -> None

Reset all statistics to None.

This clears both computed statistics and marks that precomputed_stats should be ignored (via internal flag). After reset, get_statistics() will return None until new statistics are set or computed.

reset_cache

reset_cache() -> None

Clear the cache.

Only has effect if cacheable=True in config.

copy

copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule

Create a copy of this module with optional config/parameter changes.

This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.

Parameters:

  • config (DataraxModuleConfig | None, default None): New config (if None, uses current config)
  • rngs (Rngs | None, default None): New RNG state (if None, uses current rngs)
  • name (str | None, default None): New name (if None, uses current name)

Returns:

  • DataraxModule: New module instance with updated parameters

Examples:

# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)

# Change name only
renamed = module.copy(name="new_name")

Note

Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().

get_state

get_state() -> dict[str, Any]

Get module state for checkpointing.

This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.

Returns:

  • dict[str, Any]: A dictionary containing the internal state of the component.

set_state

set_state(state: dict[str, Any]) -> None

Restore module state from a checkpoint.

This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.

Parameters:

  • state (dict[str, Any], required): A dictionary containing the internal state to restore.

Raises:

  • TypeError: If state is not a dictionary.
  • ValueError: If checkpoint structure does not match module state.

clone

clone() -> DataraxModule

Create a new instance with the same state as this module.

Uses NNX's clone function for proper deep cloning of all state.

Returns:

  • DataraxModule: A new module instance with the same state.

requires_rng_streams

requires_rng_streams() -> list[str] | None

Get the list of RNG streams required by this module.

Returns:

  • list[str] | None: A list of required RNG stream names, or None if no RNG streams are required.

ensure_rng_streams

ensure_rng_streams(stream_names: list[str]) -> None

Ensure that the required RNG streams are available.

Parameters:

  • stream_names (list[str], required): A list of available RNG stream names.

Raises:

  • ValueError: If a required RNG stream is not available.

process

process(input: Any, *args: Any, **kwargs: Any) -> Any

Process input structure.

This method transforms the structure/organization of input data without modifying the data values themselves.

Subclasses MUST implement this method.

The input/output types depend on the specific structural processor:

  • Batcher: list[Element] -> list[Batch]
  • Sampler: int -> list[int]
  • Sharder: Batch -> Sharded[Batch]
  • Splitter: Dataset -> tuple[Dataset, Dataset]

Parameters:

  • input (Any, required): Input to process (type varies by processor)
  • *args (Any, default ()): Additional positional arguments (processor-specific)
  • **kwargs (Any, default {}): Additional keyword arguments (processor-specific)

Returns:

  • Any: Processed output (type varies by processor)

Examples:

Batcher implementation:

def process(self, elements: list[Element]) -> list[Batch]:
    batches = []
    for i in range(0, len(elements), self.config.batch_size):
        batch_elements = elements[i:i + self.config.batch_size]
        batches.append(Batch.from_elements(batch_elements))
    return batches

Sampler implementation (deterministic):

def process(self, dataset_size: int) -> list[int]:
    return list(range(min(self.config.num_samples, dataset_size)))

Sampler implementation (stochastic):

def process(self, dataset_size: int) -> list[int]:
    rng = self.rngs[self.config.stream_name]()
    indices = jax.random.choice(
        rng, dataset_size, shape=(self.config.num_samples,),
        replace=self.config.replacement
    )
    return indices.tolist()

get_batch_at

get_batch_at(start: int | Any, size: int, key: Any | None = None) -> Any

Stateless indexed batch access for Pipeline-driven iteration.

Returns size records starting at start. Implementations must be stateless (no mutation of internal counters) and JAX-traceable (must accept tracer values for start) so the call composes with nnx.scan.

Parameters:

  • start (int | Any, required): Starting index. Sources that support indexed access accept a Python int or a traced jax.Array.
  • size (int, required): Number of records to return (Python int; JAX shapes are static).
  • key (Any | None, default None): Optional PRNG key for shuffled or stochastic sampling.

Returns:

  • Any: A batch dict (or PyTree) with leading dim size.

Raises:

  • NotImplementedError: If the source does not support indexed access (e.g. forward-only streams). Pipeline falls back to its __iter__ debug path in that case.

get_dataset_info

get_dataset_info() -> Any

Return cached backend-specific dataset metadata.

reset

reset(seed: int | None = None) -> None

Reset streaming iterator state.

get_batch

get_batch(batch_size: int) -> dict[str, Any]

Collect up to batch_size items from the streaming iterator.
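
The hypothetical collect_batch helper below sketches "collect up to batch_size items": it pulls records with itertools.islice and groups their values by key. It is a model only; the real method may stack values into arrays, and its behaviour once the stream is exhausted is not documented here.

```python
from collections.abc import Iterator
from itertools import islice
from typing import Any


def collect_batch(it: Iterator[dict[str, Any]], batch_size: int) -> dict[str, list[Any]]:
    """Hypothetical model: take up to batch_size records and group values by key."""
    records = list(islice(it, batch_size))  # fewer than batch_size near the end
    if not records:
        return {}  # stream exhausted (the real source's behaviour here may differ)
    return {key: [record[key] for record in records] for key in records[0]}


stream = iter([{"x": i, "y": i * i} for i in range(5)])
print(collect_batch(stream, 3))  # {'x': [0, 1, 2], 'y': [0, 1, 4]}
print(collect_batch(stream, 3))  # {'x': [3, 4], 'y': [9, 16]}
```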