
TensorFlow Datasets Source

TFDSEagerSource integrates TensorFlow Datasets (TFDS) with Datarax, giving access to hundreds of ready-to-use datasets. It converts TensorFlow tensors to JAX arrays once, at initialization, and keeps the whole split in memory, so iteration afterwards runs in pure JAX.

Note: You can also use the factory function from_tfds(name, split, ...), which auto-selects between eager and streaming modes based on your configuration.

Key Features

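| Feature | Description |
| --- | --- |
| Automatic conversion | TensorFlow tensors → JAX arrays, done once at initialization |
| No TF work during training | All data is loaded up front, so iteration runs in pure JAX with no tf.data prefetch threads |
| Supervised mode | Optional (image, label) unpacking into "image"/"label" keys |
| Shuffling | O(1)-memory index shuffle (Grain's index_shuffle) driven by an integer seed |
| Caching | Optional module-level caching via cacheable=True |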

★ Insight ─────────────────────────────────────

  • TFDS handles download and preparation automatically
  • Use as_supervised=True to get {"image": ..., "label": ...} format
  • TFDSEagerSource does its TensorFlow work once at initialization; iteration afterwards is pure JAX, with no TF prefetch threads
  • The source tracks epoch and index for stateful training loops

─────────────────────────────────────────────────

Installation

TFDSEagerSource requires TensorFlow and tensorflow-datasets:

pip install "datarax[data]"  # quotes stop shells such as zsh from expanding the brackets
# or
pip install tensorflow tensorflow-datasets

Quick Start

import flax.nnx as nnx
from datarax.sources import TFDSEagerSource
from datarax.sources.tfds_source import TFDSEagerConfig

# Load MNIST dataset
config = TFDSEagerConfig(name="mnist", split="train")
source = TFDSEagerSource(config, rngs=nnx.Rngs(0))

# Iterate over elements
for item in source:
    image = item["image"]  # JAX array, shape (28, 28, 1)
    label = item["label"]  # JAX array, scalar
    process(image, label)

Supervised Mode

Set as_supervised=True to receive each example as an {"image": ..., "label": ...} dict built from TFDS's (image, label) pairs:

config = TFDSEagerConfig(
    name="cifar10",
    split="train",
    as_supervised=True,  # Returns {"image": ..., "label": ...}
)
source = TFDSEagerSource(config)

batch = source.get_batch(32)
images = batch["image"]  # Shape: (32, 32, 32, 3)
labels = batch["label"]  # Shape: (32,)

Batch Retrieval

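Without a key, get_batch is stateful: each call advances the source's index and rolls over into the next epoch when the data runs out, which suits long training loops: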

# Stateful batch retrieval
for step in range(10000):
    batch = source.get_batch(64)
    loss = train_step(batch)

    # Check progress
    print(f"Epoch {source.epoch.get_value()}, "
          f"Index {source.index.get_value()}")
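The stateful behavior above can be pictured with a small pure-Python model of an index/epoch counter pair. This is a sketch under stated assumptions, not Datarax's implementation: ToyEagerSource is hypothetical, it uses plain lists instead of JAX arrays, and it assumes a batch that runs past the end wraps to the start of the data and bumps the epoch (the real get_batch may treat the epoch boundary differently).

```python
from dataclasses import dataclass, field


@dataclass
class ToyEagerSource:
    """Hypothetical stand-in for an eager source: a dict of equal-length lists."""

    data: dict[str, list]
    index: int = 0  # position of the next record to return
    epoch: int = 0  # number of completed passes over the data
    length: int = field(init=False)

    def __post_init__(self) -> None:
        lengths = {len(values) for values in self.data.values()}
        if len(lengths) != 1:
            raise ValueError("all fields must have the same length")
        self.length = lengths.pop()
        if self.length == 0:
            raise ValueError("source is empty")

    def get_batch(self, batch_size: int) -> dict[str, list]:
        """Stateful: return batch_size records and advance index/epoch."""
        positions = []
        for _ in range(batch_size):
            positions.append(self.index)
            self.index += 1
            if self.index == self.length:  # assumed: wrap around, start a new epoch
                self.index = 0
                self.epoch += 1
        return {key: [values[p] for p in positions] for key, values in self.data.items()}


source = ToyEagerSource({"image": list(range(10)), "label": [i % 2 for i in range(10)]})
for step in range(3):
    batch = source.get_batch(4)
    print(step, batch["image"], f"epoch={source.epoch}, index={source.index}")
# 0 [0, 1, 2, 3] epoch=0, index=4
# 1 [4, 5, 6, 7] epoch=0, index=8
# 2 [8, 9, 0, 1] epoch=1, index=2
```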

Shuffling

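TFDSEagerSource shuffles with Grain's index_shuffle, an O(1)-memory permutation controlled by the integer seed; it has no shuffle buffer. Buffer-based shuffling (shuffle_buffer_size) belongs to TFDSStreamingConfig:

from datarax.sources.tfds_source import TFDSStreamingConfig, TFDSStreamingSource

# Eager source: index shuffle controlled by an integer seed
config = TFDSEagerConfig(
    name="cifar10",
    split="train",
    shuffle=True,
    seed=42,
)
source = TFDSEagerSource(config, rngs=nnx.Rngs(42))

# Streaming source: TF-native shuffle with a fixed-size buffer
streaming_config = TFDSStreamingConfig(
    name="imagenet2012", split="train", shuffle=True, shuffle_buffer_size=10000
)
streaming_source = TFDSStreamingSource(streaming_config, rngs=nnx.Rngs(42))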
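A buffer-based shuffle, the mechanism behind shuffle_buffer_size, works roughly like the pure-Python sketch below: fill a buffer, emit a random buffered element, and refill its slot from the stream. It illustrates the general technique and is only assumed to approximate tf.data's shuffle semantics; buffer_shuffle is a hypothetical helper, not TensorFlow or Datarax code.

```python
import random
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def buffer_shuffle(items: Iterable[T], buffer_size: int, seed: int) -> Iterator[T]:
    """Approximately shuffle a stream while holding at most buffer_size items."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be >= 1")
    rng = random.Random(seed)
    buffer: list = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # fill phase
            continue
        slot = rng.randrange(buffer_size)
        yield buffer[slot]  # emit a random buffered element
        buffer[slot] = item  # and put the incoming element in its place
    rng.shuffle(buffer)  # drain whatever is left at the end of the stream
    yield from buffer


print(list(buffer_shuffle(range(10), buffer_size=1, seed=0)))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

A buffer of 1 leaves the order unchanged, and larger buffers mix elements across a wider window, which is why the buffer size trades memory for shuffle quality.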

Caching

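Set cacheable=True to enable the module-level cache; reset_cache() clears it. Note that TFDSEagerSource already holds the whole split in memory after initialization:

config = TFDSEagerConfig(
    name="mnist",
    split="train",
    cacheable=True,  # enable the module-level cache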
)
source = TFDSEagerSource(config)

Custom Data Directory

Store datasets in a specific location:

config = TFDSEagerConfig(
    name="cifar10",  # eager loading suits small/medium datasets
    split="train",
    data_dir="/path/to/tfds_data",
)
source = TFDSEagerSource(config)

Field Filtering

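Select only the fields you need with include_keys, or drop unwanted ones with exclude_keys (the two options are mutually exclusive). Both config types accept them; a large dataset such as COCO is better served by the streaming config:

from datarax.sources.tfds_source import TFDSStreamingConfig

config = TFDSStreamingConfig(
    name="coco/2017",
    split="train",
    include_keys={"image", "objects"},  # Only these fields
)

# Or exclude unwanted fields (CIFAR-10 examples carry an "id" feature)
config = TFDSEagerConfig(
    name="cifar10",
    split="train",
    exclude_keys={"id"},
)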
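The include/exclude semantics can be summarized with a small pure-Python sketch. It assumes that passing both options is an error (the API reference calls them exclusive) and that filtering applies to the top-level feature keys of each example; filter_keys is a hypothetical helper, not part of Datarax.

```python
from __future__ import annotations

from typing import Any


def filter_keys(
    example: dict[str, Any],
    include_keys: set[str] | None = None,
    exclude_keys: set[str] | None = None,
) -> dict[str, Any]:
    """Keep only include_keys, or drop exclude_keys, from one example."""
    if include_keys is not None and exclude_keys is not None:
        raise ValueError("include_keys and exclude_keys are mutually exclusive")
    if include_keys is not None:
        return {k: v for k, v in example.items() if k in include_keys}
    if exclude_keys is not None:
        return {k: v for k, v in example.items() if k not in exclude_keys}
    return dict(example)  # no filtering requested


example = {"id": "example-0", "image": [[0]], "label": 3}
print(filter_keys(example, exclude_keys={"id"}))  # {'image': [[0]], 'label': 3}
```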

Integration with DAG Pipelines

from datarax.pipeline import Pipeline

config = TFDSEagerConfig(
    name="cifar10",
    split="train",
    as_supervised=True,
)
source = TFDSEagerSource(config)

# normalize_op and augment_op stand for your own operators
pipeline = Pipeline(
    source=source,
    stages=[normalize_op, augment_op],
    batch_size=128,
    rngs=nnx.Rngs(0),
)

for batch in pipeline:
    train_step(batch)

Dataset Information

Access rich metadata from TFDS:

info = source.get_dataset_info()
print(f"Description: {info.description}")
print(f"Features: {info.features}")
print(f"Splits: {list(info.splits.keys())}")
print(f"Citation: {info.citation}")

# Number of examples
print(f"Train examples: {info.splits['train'].num_examples}")
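Split sizes from info.splits are handy for sizing a training loop. The sketch below computes how many batches one pass over a split yields; steps_per_epoch is a hypothetical helper, and whether a final partial batch is kept depends on your pipeline, so both cases are shown.

```python
def steps_per_epoch(num_examples: int, batch_size: int, drop_remainder: bool = True) -> int:
    """Number of batches produced by one pass over num_examples records."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if drop_remainder:
        return num_examples // batch_size
    return -(-num_examples // batch_size)  # ceiling division keeps the partial batch


# MNIST's train split has 60,000 examples (info.splits["train"].num_examples)
print(steps_per_epoch(60_000, 64))  # 937
print(steps_per_epoch(60_000, 64, drop_remainder=False))  # 938
```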

API Reference

datarax.sources.tfds_source

TensorFlow Datasets (TFDS) data source implementation for Datarax.

This module provides two distinct source types optimized for different use cases:

TFDSEagerSource: for small/medium datasets that fit in memory (roughly 10% of device VRAM)
  • Loads ALL data to JAX arrays at initialization
  • Pure JAX iteration after init (no TensorFlow overhead during training)
  • O(1) memory shuffling via Grain's index_shuffle (Feistel cipher)
  • Fully checkpointable (just indices, no external state)
  • Ideal for: MNIST, CIFAR-10, Fashion-MNIST, small custom datasets

TFDSStreamingSource: for large datasets that don't fit in memory
  • Thin wrapper around a TF dataset iterator
  • DLPack zero-copy conversion for each batch
  • Fixed prefetch buffer (no AUTOTUNE thread storms)
  • Trade-offs: external iterator state; can't checkpoint mid-epoch
  • Ideal for: ImageNet, large-scale datasets, memory-constrained environments
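The sizing rule of thumb can be written down as a quick check. This is a hypothetical helper, not a Datarax API: it assumes you already know the approximate in-memory size of the split and the device memory, and it applies the roughly-10%-of-VRAM guideline from the module description.

```python
def suggest_tfds_source(
    dataset_bytes: int, device_memory_bytes: int, fraction: float = 0.10
) -> str:
    """Return "eager" if the split fits in fraction of device memory, else "streaming"."""
    if dataset_bytes < 0 or device_memory_bytes <= 0:
        raise ValueError("need dataset_bytes >= 0 and device_memory_bytes > 0")
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in (0, 1]")
    return "eager" if dataset_bytes <= fraction * device_memory_bytes else "streaming"


GiB = 1024**3
mnist_train_bytes = 60_000 * 28 * 28  # uint8 images only; labels add a little more
print(suggest_tfds_source(mnist_train_bytes, 16 * GiB))  # eager
print(suggest_tfds_source(150 * GiB, 16 * GiB))  # streaming (a hypothetical 150 GiB split)
```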

Architecture Insight

The ~0.4s delay at epoch 2 in previous implementations was caused by TensorFlow's AUTOTUNE prefetch spawning background threads during epoch transitions. TFDSEagerSource eliminates this entirely by loading all data upfront. TFDSStreamingSource uses fixed prefetch to prevent thread storms.

logger module-attribute

logger = getLogger(__name__)

TFDSEagerConfig dataclass

TFDSEagerConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, stochastic: bool = False, stream_name: str | None = None, name: str | None = None, split: str | None = None, data_dir: str | None = None, include_keys: set[str] | None = None, exclude_keys: set[str] | None = None, try_gcs: bool = False, shuffle: bool = False, seed: int = 42, as_supervised: bool = False, download_and_prepare_kwargs: dict[str, Any] | None = None, beam_num_workers: int | None = None, local_files_only: bool = False)

Bases: SourceConfigBase

Configuration for TFDSEagerSource (loads all data to JAX at init).

Configuration for eager-loading TensorFlow Datasets into JAX arrays.

Parameters:

  • name (str | None, default: None): Name of the dataset in TFDS (required)
  • split (str | None, default: None): Split of the dataset to load, e.g., "train", "test" (required)
  • data_dir (str | None, default: None): Optional directory where the dataset is stored/downloaded
  • shuffle (bool, default: False): Whether to shuffle the dataset during iteration
  • seed (int, default: 42): Integer seed for Grain's index_shuffle
  • as_supervised (bool, default: False): If True, returns 'image'/'label' keys instead of original features
  • download_and_prepare_kwargs (dict[str, Any] | None, default: None): Optional keyword arguments for download_and_prepare
  • include_keys (set[str] | None, default: None): Optional set of keys to include in output (exclusive with exclude_keys)
  • exclude_keys (set[str] | None, default: None): Optional set of keys to exclude from output (exclusive with include_keys)

Note

The seed parameter is an integer (not a JAX RNG key) for Grain's index_shuffle. This ensures O(1) memory shuffling and reproducible per-epoch seeds.
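To show why an index shuffle needs only O(1) memory, here is a pure-Python sketch of a Feistel-network permutation with cycle walking: any position in [0, n) maps to a unique shuffled position without materializing a permuted array. It illustrates the technique named above and is not Grain's index_shuffle; shuffled_index, the SHA-256 round function, and mixing the epoch into the key are all assumptions made for this sketch.

```python
import hashlib


def _round(value: int, seed: int, epoch: int, round_idx: int, half_bits: int) -> int:
    """Keyed pseudo-random round function (SHA-256 is an arbitrary choice here)."""
    digest = hashlib.sha256(f"{seed}:{epoch}:{round_idx}:{value}".encode()).digest()
    return int.from_bytes(digest[:8], "little") & ((1 << half_bits) - 1)


def shuffled_index(index: int, n: int, seed: int, epoch: int = 0, rounds: int = 4) -> int:
    """Map position index to its shuffled position in [0, n) using O(1) memory."""
    if not 0 <= index < n:
        raise ValueError("index must be in [0, n)")
    bits = max(2, (n - 1).bit_length())
    bits += bits % 2  # split the domain into two equal halves
    half = bits // 2
    mask = (1 << half) - 1
    x = index
    while True:
        left, right = x >> half, x & mask
        for r in range(rounds):  # each Feistel round is a bijection on 2**bits values
            left, right = right, left ^ _round(right, seed, epoch, r, half)
        x = (left << half) | right
        if x < n:  # cycle-walk until the result lands back inside [0, n)
            return x


order = [shuffled_index(i, 10, seed=42) for i in range(10)]
print(sorted(order) == list(range(10)))  # True: a permutation of 0..9
```

Because each position is computed independently, a checkpoint only needs the seed, the epoch, and the current position.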

try_gcs class-attribute instance-attribute

try_gcs: bool = False

shuffle class-attribute instance-attribute

shuffle: bool = False

seed class-attribute instance-attribute

seed: int = 42

as_supervised class-attribute instance-attribute

as_supervised: bool = False

download_and_prepare_kwargs class-attribute instance-attribute

download_and_prepare_kwargs: dict[str, Any] | None = None

beam_num_workers class-attribute instance-attribute

beam_num_workers: int | None = None

local_files_only class-attribute instance-attribute

local_files_only: bool = False

cacheable class-attribute instance-attribute

cacheable: bool = False

batch_stats_fn class-attribute instance-attribute

batch_stats_fn: Callable | Module | None = None

precomputed_stats class-attribute instance-attribute

precomputed_stats: dict[str, Any] | None = None

stochastic class-attribute instance-attribute

stochastic: bool = False

stream_name class-attribute instance-attribute

stream_name: str | None = None

name class-attribute instance-attribute

name: str | None = None

split class-attribute instance-attribute

split: str | None = None

data_dir class-attribute instance-attribute

data_dir: str | None = None

include_keys class-attribute instance-attribute

include_keys: set[str] | None = None

exclude_keys class-attribute instance-attribute

exclude_keys: set[str] | None = None

TFDSStreamingConfig dataclass

TFDSStreamingConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, stochastic: bool = False, stream_name: str | None = None, name: str | None = None, split: str | None = None, data_dir: str | None = None, include_keys: set[str] | None = None, exclude_keys: set[str] | None = None, try_gcs: bool = False, shuffle: bool = False, shuffle_buffer_size: int = 1000, as_supervised: bool = False, download_and_prepare_kwargs: dict[str, Any] | None = None, beam_num_workers: int | None = None, prefetch_buffer: int = 2, local_files_only: bool = False)

Bases: SourceConfigBase

Configuration for TFDSStreamingSource (streams data from TF dataset).

Use this for datasets too large to fit in memory. The streaming source uses fixed prefetch buffers to avoid AUTOTUNE thread storms.

Parameters:

  • name (str | None, default: None): Name of the dataset in TFDS (required)
  • split (str | None, default: None): Split of the dataset to load, e.g., "train", "test" (required)
  • data_dir (str | None, default: None): Optional directory where the dataset is stored/downloaded
  • shuffle (bool, default: False): Whether to shuffle the dataset
  • shuffle_buffer_size (int, default: 1000): TF shuffle buffer size
  • as_supervised (bool, default: False): If True, returns 'image'/'label' keys
  • download_and_prepare_kwargs (dict[str, Any] | None, default: None): Optional keyword arguments for download_and_prepare
  • include_keys (set[str] | None, default: None): Optional set of keys to include in output
  • exclude_keys (set[str] | None, default: None): Optional set of keys to exclude from output
  • prefetch_buffer (int, default: 2): Fixed prefetch buffer size (not AUTOTUNE)

Note

The prefetch_buffer uses a fixed size instead of TF AUTOTUNE to prevent background thread storms that cause delays during epoch transitions.
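The idea behind a fixed prefetch buffer can be sketched in plain Python: one producer thread fills a bounded queue, so at most buffer_size items wait at any time and no extra threads are spawned at epoch boundaries. This is a hypothetical illustration built on threading and queue, not TFDSStreamingSource's actual code; the helper name fixed_prefetch is an assumption.

```python
import queue
import threading
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")
_DONE = object()  # sentinel marking the end of the stream


def fixed_prefetch(items: Iterable[T], buffer_size: int = 2) -> Iterator[T]:
    """Yield items in order; at most buffer_size items wait in the queue."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be >= 1")
    buf: queue.Queue = queue.Queue(maxsize=buffer_size)  # put() blocks when full
    errors: list = []

    def producer() -> None:
        try:
            for item in items:
                buf.put(item)
        except BaseException as exc:  # hand producer errors to the consumer
            errors.append(exc)
        finally:
            buf.put(_DONE)

    worker = threading.Thread(target=producer, daemon=True)  # exactly one thread
    worker.start()
    while (item := buf.get()) is not _DONE:
        yield item
    worker.join()
    if errors:
        raise errors[0]


print(list(fixed_prefetch(range(5), buffer_size=2)))  # [0, 1, 2, 3, 4]
```

In this sketch, a consumer that stops early leaves the daemon producer blocked on a full queue; a production version would need an explicit shutdown path.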

try_gcs class-attribute instance-attribute

try_gcs: bool = False

shuffle class-attribute instance-attribute

shuffle: bool = False

shuffle_buffer_size class-attribute instance-attribute

shuffle_buffer_size: int = 1000

as_supervised class-attribute instance-attribute

as_supervised: bool = False

download_and_prepare_kwargs class-attribute instance-attribute

download_and_prepare_kwargs: dict[str, Any] | None = None

beam_num_workers class-attribute instance-attribute

beam_num_workers: int | None = None

prefetch_buffer class-attribute instance-attribute

prefetch_buffer: int = 2

local_files_only class-attribute instance-attribute

local_files_only: bool = False

cacheable class-attribute instance-attribute

cacheable: bool = False

batch_stats_fn class-attribute instance-attribute

batch_stats_fn: Callable | Module | None = None

precomputed_stats class-attribute instance-attribute

precomputed_stats: dict[str, Any] | None = None

stochastic class-attribute instance-attribute

stochastic: bool = False

stream_name class-attribute instance-attribute

stream_name: str | None = None

name class-attribute instance-attribute

name: str | None = None

split class-attribute instance-attribute

split: str | None = None

data_dir class-attribute instance-attribute

data_dir: str | None = None

include_keys class-attribute instance-attribute

include_keys: set[str] | None = None

exclude_keys class-attribute instance-attribute

exclude_keys: set[str] | None = None

TFDSEagerSource

TFDSEagerSource(config: TFDSEagerConfig, *, rngs: Rngs | None = None, name: str | None = None)

Bases: EagerSourceBase

Eager-loading TFDS source for small/medium datasets.

Loads ALL data to JAX arrays at initialization, then operates like a MemorySource with pure JAX operations. Use for datasets that fit in ~10% of device VRAM.

Key Features
  • One-time TF→JAX conversion at init (DLPack zero-copy when possible)
  • Pure JAX iteration after init (no TF threads during training)
  • O(1) memory shuffling via Grain's index_shuffle (Feistel cipher)
  • Full checkpointing support (indices only, no external state)
  • Supports as_supervised mode and key filtering
Performance
  • Eliminates ~0.4s epoch 2 delay from TF AUTOTUNE threads
  • Training loops can use lax.fori_loop for 100-500x speedup
  • Device placement via collect_to_array() for staged training
Example
import jax
import flax.nnx as nnx
from datarax.sources.tfds_source import TFDSEagerConfig, TFDSEagerSource

# Create eager source for MNIST
config = TFDSEagerConfig(name="mnist", split="train", shuffle=True)
source = TFDSEagerSource(config, rngs=nnx.Rngs(0))

# Iterate - pure JAX, no TF overhead (process is your own function)
for item in source:
    process(item["image"])

# Get batch (stateless with key, or stateful without)
batch = source.get_batch(32)  # Stateful
batch = source.get_batch(32, key=jax.random.key(0))  # Stateless

Parameters:

  • config (TFDSEagerConfig, required): Configuration for the source
  • rngs (Rngs | None, default: None): Optional RNG state for shuffling
  • name (str | None, default: None): Optional name (defaults to TFDSEagerSource(dataset:split))

dataset_name instance-attribute

dataset_name = name

split_name instance-attribute

split_name = split

as_supervised instance-attribute

as_supervised = as_supervised

include_keys instance-attribute

include_keys = include_keys

exclude_keys instance-attribute

exclude_keys = exclude_keys

data instance-attribute

data: dict[str, Array] = data(_load_all_from_backend_to_jax(config))

length instance-attribute

length = shape[0]

index instance-attribute

index = Variable(0)

epoch instance-attribute

epoch = Variable(0)

config instance-attribute

config = static(config)

rngs instance-attribute

rngs = rngs

name instance-attribute

name = static(name)

stochastic instance-attribute

stochastic = stochastic

stream_name instance-attribute

stream_name = stream_name

is_random_order property

is_random_order: bool

Whether iteration order is randomized.

get_operation_stats

get_operation_stats() -> dict[str, int]

Get operation statistics.

Note: This method converts JAX arrays to Python ints for introspection. It is intended for use outside of JIT-compiled functions.

Returns:

  • dict[str, int]: Dictionary with 'applied_count' and 'skipped_count'

reset_operation_stats

reset_operation_stats() -> None

Reset operation statistics to zero.

Note: Creates new JAX arrays to reset the counters.

compute_statistics

compute_statistics(data: Any) -> dict[str, Any] | None

Compute statistics from data using batch_stats_fn.

If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.

Parameters:

  • data (Any, required): Input data to compute statistics from

Returns:

  • dict[str, Any] | None: Dictionary of statistics, or None if no batch_stats_fn configured

get_statistics

get_statistics() -> dict[str, Any] | None

Get current statistics.

Returns precomputed_stats if configured (unless reset was called), otherwise returns cached computed statistics, or None if no statistics available.

Returns:

  • dict[str, Any] | None: Dictionary of statistics, or None if no statistics available

set_statistics

set_statistics(stats: dict[str, Any]) -> None

Manually set statistics.

This overwrites any previously computed statistics and clears the reset flag.

Parameters:

  • stats (dict[str, Any], required): Dictionary of statistics to set

reset_statistics

reset_statistics() -> None

Reset all statistics to None.

This clears both computed statistics and marks that precomputed_stats should be ignored (via internal flag). After reset, get_statistics() will return None until new statistics are set or computed.
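The precedence rules described for compute_statistics, get_statistics, set_statistics, and reset_statistics amount to a small state machine, modeled below in pure Python. StatsHolder and its attribute names are hypothetical. Read literally, the docstrings imply that precomputed_stats takes priority again once set_statistics clears the reset flag; the sketch follows that reading, so check the real implementation if the distinction matters to you.

```python
from __future__ import annotations

from typing import Any, Callable


class StatsHolder:
    """Hypothetical model of the documented statistics precedence rules."""

    def __init__(
        self,
        precomputed_stats: dict[str, Any] | None = None,
        batch_stats_fn: Callable[[Any], dict[str, Any]] | None = None,
    ) -> None:
        self.precomputed_stats = precomputed_stats
        self.batch_stats_fn = batch_stats_fn
        self._computed_stats: dict[str, Any] | None = None
        self._ignore_precomputed = False  # set by reset_statistics()

    def compute_statistics(self, data: Any) -> dict[str, Any] | None:
        if self.batch_stats_fn is None:
            return None
        self._computed_stats = self.batch_stats_fn(data)  # cache the result
        return self._computed_stats

    def get_statistics(self) -> dict[str, Any] | None:
        if self.precomputed_stats is not None and not self._ignore_precomputed:
            return self.precomputed_stats
        return self._computed_stats

    def set_statistics(self, stats: dict[str, Any]) -> None:
        self._computed_stats = stats
        self._ignore_precomputed = False  # "clears the reset flag"

    def reset_statistics(self) -> None:
        self._computed_stats = None
        self._ignore_precomputed = True


holder = StatsHolder(
    precomputed_stats={"mean": 0.5},
    batch_stats_fn=lambda xs: {"mean": sum(xs) / len(xs)},
)
print(holder.get_statistics())  # {'mean': 0.5}
holder.reset_statistics()
print(holder.get_statistics())  # None
holder.compute_statistics([1.0, 2.0, 3.0])
print(holder.get_statistics())  # {'mean': 2.0}
```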

reset_cache

reset_cache() -> None

Clear the cache.

Only has effect if cacheable=True in config.

copy

copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule

Create a copy of this module with optional config/parameter changes.

This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.

Parameters:

  • config (DataraxModuleConfig | None, default: None): New config (if None, uses current config)
  • rngs (Rngs | None, default: None): New RNG state (if None, uses current rngs)
  • name (str | None, default: None): New name (if None, uses current name)

Returns:

  • DataraxModule: New module instance with updated parameters

Examples:

# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)

# Change name only
renamed = module.copy(name="new_name")

Note

Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().

get_state

get_state() -> dict[str, Any]

Get module state for checkpointing.

This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.

Returns:

  • dict[str, Any]: A dictionary containing the internal state of the component.

set_state

set_state(state: dict[str, Any]) -> None

Restore module state from a checkpoint.

This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.

Parameters:

  • state (dict[str, Any], required): A dictionary containing the internal state to restore.

Raises:

  • TypeError: If state is not a dictionary.
  • ValueError: If checkpoint structure does not match module state.
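Strict restoration means the checkpoint must have exactly the same nested keys as the module's current state. A minimal pure-Python sketch of such a check, raising the exception types listed above, could look like this; check_state_structure and the example keys are hypothetical, and Datarax performs its own check through NNX state.

```python
from typing import Any


def _compare(expected: Any, actual: Any, path: str) -> None:
    """Recursively require identical dict keys at every nesting level."""
    if not isinstance(expected, dict):
        return  # leaf: values may differ, only the structure is checked
    if not isinstance(actual, dict):
        raise ValueError(f"{path}: expected a dict, got {type(actual).__name__}")
    missing = sorted(expected.keys() - actual.keys())
    unexpected = sorted(actual.keys() - expected.keys())
    if missing or unexpected:
        raise ValueError(f"{path}: missing {missing}, unexpected {unexpected}")
    for key in expected:
        _compare(expected[key], actual[key], f"{path}.{key}")


def check_state_structure(expected: dict, state: Any) -> None:
    """TypeError for non-dict input, ValueError for any structural mismatch."""
    if not isinstance(state, dict):
        raise TypeError(f"state must be a dict, got {type(state).__name__}")
    _compare(expected, state, "state")


current = {"index": 0, "epoch": 0}
check_state_structure(current, {"index": 5, "epoch": 1})  # same keys: passes
try:
    check_state_structure(current, {"index": 5})
except ValueError as err:
    print(err)  # state: missing ['epoch'], unexpected []
```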

clone

clone() -> DataraxModule

Create a new instance with the same state as this module.

Uses NNX's clone function for proper deep cloning of all state.

Returns:

  • DataraxModule: A new module instance with the same state.

requires_rng_streams

requires_rng_streams() -> list[str] | None

Get the list of RNG streams required by this module.

Returns:

  • list[str] | None: A list of required RNG stream names, or None if no RNG streams are required.

ensure_rng_streams

ensure_rng_streams(stream_names: list[str]) -> None

Ensure that the required RNG streams are available.

Parameters:

  • stream_names (list[str], required): A list of available RNG stream names.

Raises:

  • ValueError: If a required RNG stream is not available.
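The check ensure_rng_streams describes boils down to comparing required stream names against the available ones. The free function below is a hypothetical sketch; it assumes a module reports its needs as a list, or None for no streams, as requires_rng_streams documents, and the stream names "shuffle" and "augment" are illustrative only.

```python
from __future__ import annotations


def ensure_streams(required: list[str] | None, available: list[str]) -> None:
    """Raise ValueError listing every required stream absent from available."""
    if not required:  # None or [] means no RNG streams are needed
        return
    present = set(available)
    missing = [name for name in required if name not in present]
    if missing:
        raise ValueError(f"missing RNG streams {missing}; available: {sorted(present)}")


ensure_streams(["shuffle"], ["default", "shuffle"])  # passes silently
try:
    ensure_streams(["shuffle", "augment"], ["default"])
except ValueError as err:
    print(err)  # missing RNG streams ['shuffle', 'augment']; available: ['default']
```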

process

process(input: Any, *args: Any, **kwargs: Any) -> Any

Process input structure.

This method transforms the structure/organization of input data without modifying the data values themselves.

Subclasses MUST implement this method.

The input/output types depend on the specific structural processor:

  • Batcher: list[Element] -> list[Batch]
  • Sampler: int -> list[int]
  • Sharder: Batch -> Sharded[Batch]
  • Splitter: Dataset -> tuple[Dataset, Dataset]

Parameters:

  • input (Any, required): Input to process (type varies by processor)
  • *args (Any, default: ()): Additional positional arguments (processor-specific)
  • **kwargs (Any, default: {}): Additional keyword arguments (processor-specific)

Returns:

  • Any: Processed output (type varies by processor)

Examples:

Batcher implementation:

def process(self, elements: list[Element]) -> list[Batch]:
    batches = []
    for i in range(0, len(elements), self.config.batch_size):
        batch_elements = elements[i:i + self.config.batch_size]
        batches.append(Batch.from_elements(batch_elements))
    return batches

Sampler implementation (deterministic):

def process(self, dataset_size: int) -> list[int]:
    return list(range(min(self.config.num_samples, dataset_size)))

Sampler implementation (stochastic):

def process(self, dataset_size: int) -> list[int]:
    rng = self.rngs[self.config.stream_name]()
    indices = jax.random.choice(
        rng, dataset_size, shape=(self.config.num_samples,),
        replace=self.config.replacement
    )
    return indices.tolist()

get_batch_at

get_batch_at(start: int | Array, size: int, key: Array | None = None) -> dict[str, Any]

Stateless indexed batch access; JIT-traceable for scan-based iteration.

Returns size records starting at logical position start. Does not advance self.index or any other internal state, so callers (typically Pipeline) can drive iteration via their own position counter and trace get_batch_at under nnx.scan / nnx.jit.

Two modes:

  • Sequential (self.is_random_order == False): returns the contiguous slice data[start : start + size] with wrap-around at the end of the source.
  • Shuffled (self.is_random_order == True): applies a deterministic permutation derived from key and returns the slice of that permutation. Same (start, size, key) always returns the same output. The permutation is materialized via jax.random.permutation(key, length) per call — O(length) per batch.

Parameters:

  • start (int | Array, required): Starting logical index; accepts concrete int or traced jax.Array.
  • size (int, required): Number of records to return (Python int, since JAX shapes are static).
  • key (Array | None, default: None): PRNG key for shuffled mode. Required when is_random_order=True; ignored otherwise.

Returns:

  • dict[str, Any]: Dict mapping each data key to a JAX array with leading dimension of length size.
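The two documented modes of get_batch_at can be mirrored in a few lines of pure Python. The sketch assumes the same contract (wrap-around sequential slices; a permutation fully determined by the key) but substitutes random.Random(key).shuffle for jax.random.permutation and lists for JAX arrays, so it is not traceable under jit; toy_get_batch_at is a hypothetical name.

```python
from __future__ import annotations

import random
from typing import Any


def toy_get_batch_at(
    data: dict[str, list[Any]], start: int, size: int, key: int | None = None
) -> dict[str, list[Any]]:
    """Stateless batch access: the same (start, size, key) gives the same batch."""
    length = len(next(iter(data.values()), []))
    if length == 0:
        raise ValueError("source is empty")
    order = list(range(length))  # sequential mode: identity order
    if key is not None:
        random.Random(key).shuffle(order)  # shuffled mode: O(length) work per call
    positions = [order[(start + i) % length] for i in range(size)]  # wrap-around
    return {name: [values[p] for p in positions] for name, values in data.items()}


data = {"x": list(range(6))}
print(toy_get_batch_at(data, start=4, size=4))  # {'x': [4, 5, 0, 1]}
```

Because nothing is mutated, a caller can drive iteration from its own position counter, which is what lets the real method compose with nnx.scan.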

element_spec

element_spec() -> Any

Derive per-element spec from the eager dict-of-arrays storage.

EagerSourceBase subclasses store data as a dict mapping keys to arrays whose leading axis is the dataset size. This default implementation strips that leading axis from every leaf to produce one jax.ShapeDtypeStruct per key.

Subclasses with non-dict storage should override.

Raises:

  • ValueError: If the source is empty.
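Deriving a per-element spec from dict-of-arrays storage amounts to dropping the leading (dataset-size) axis from each entry. The sketch below shows that step with plain (shape, dtype) tuples standing in for jax.ShapeDtypeStruct; element_spec_from_shapes is hypothetical.

```python
from __future__ import annotations


def element_spec_from_shapes(
    storage: dict[str, tuple[tuple[int, ...], str]],
) -> dict[str, tuple[tuple[int, ...], str]]:
    """Map each key's whole-dataset (shape, dtype) to one element's (shape, dtype)."""
    if not storage:
        raise ValueError("source is empty")
    spec = {}
    for key, (shape, dtype) in storage.items():
        if len(shape) == 0 or shape[0] == 0:
            raise ValueError(f"{key!r} has no records")
        spec[key] = (shape[1:], dtype)  # strip the leading dataset-size axis
    return spec


mnist_like = {"image": ((60000, 28, 28, 1), "uint8"), "label": ((60000,), "int64")}
print(element_spec_from_shapes(mnist_like))
# {'image': ((28, 28, 1), 'uint8'), 'label': ((), 'int64')}
```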

get_batch

get_batch(batch_size: int, key: Array | None = None) -> dict[str, Any]

Get one eager batch in stateful or stateless mode.

get_dataset_info

get_dataset_info() -> Any

Return cached backend-specific dataset metadata.

reset

reset(seed: int | None = None) -> None

Reset eager-source iteration state.

set_random_order

set_random_order(enabled: bool) -> None

Update runtime random-order behavior.

TFDSStreamingSource

TFDSStreamingSource(config: TFDSStreamingConfig, *, rngs: Rngs | None = None, name: str | None = None)

Bases: StreamingSourceBase

Streaming TFDS source for large datasets.

Thin wrapper around TF dataset for data that can't fit in memory. Uses DLPack for efficient conversion and fixed prefetch buffer.

Key Features
  • DLPack zero-copy for TF→JAX conversion
  • Fixed prefetch buffer (no AUTOTUNE thread storms)
  • Supports all TFDS datasets
  • include_keys/exclude_keys filtering
Trade-offs vs Eager
  • Cannot checkpoint mid-epoch (external iterator state)
  • Some TF thread overhead (minimized with fixed prefetch)
  • Use with Artifex train_epoch_streaming() for best results
Example
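import flax.nnx as nnx
from datarax.sources.tfds_source import TFDSStreamingConfig, TFDSStreamingSource

# Create streaming source for large dataset
config = TFDSStreamingConfig(name="imagenet2012", split="train")
source = TFDSStreamingSource(config, rngs=nnx.Rngs(0))

# Iterate with prefetching (prefetch_to_device's import is not shown in this reference)
for batch in prefetch_to_device(source, size=2):
    train_step(batch)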

Parameters:

  • config (TFDSStreamingConfig, required): Configuration for the source
  • rngs (Rngs | None, default: None): Optional RNG state
  • name (str | None, default: None): Optional name (defaults to TFDSStreamingSource(dataset:split))

dataset_name instance-attribute

dataset_name = name

split_name instance-attribute

split_name = split

as_supervised instance-attribute

as_supervised = as_supervised

include_keys instance-attribute

include_keys = include_keys

exclude_keys instance-attribute

exclude_keys = exclude_keys

length instance-attribute

length: int | None = num_examples

epoch instance-attribute

epoch = Variable(0)

config instance-attribute

config = static(config)

rngs instance-attribute

rngs = rngs

name instance-attribute

name = static(name)

stochastic instance-attribute

stochastic = stochastic

stream_name instance-attribute

stream_name = stream_name

is_random_order property

is_random_order: bool

Whether iteration order is randomized.

element_spec

element_spec() -> Any

Return per-element shape/dtype derived by peeking the TFDS stream.

TFDS streams yield single-element dicts (or tuples in as_supervised mode). The spec is derived by peeking the first element from a fresh iterator on the cached self._tf_dataset (which has already been built in __init__) so it does not re-trigger downloads. Top-level dict values are treated as single leaves so vector features become 1-D ShapeDtypeStruct instead of per-scalar leaves.

get_operation_stats

get_operation_stats() -> dict[str, int]

Get operation statistics.

Note: This method converts JAX arrays to Python ints for introspection. It is intended for use outside of JIT-compiled functions.

Returns:

  • dict[str, int]: Dictionary with 'applied_count' and 'skipped_count'

reset_operation_stats

reset_operation_stats() -> None

Reset operation statistics to zero.

Note: Creates new JAX arrays to reset the counters.

compute_statistics

compute_statistics(data: Any) -> dict[str, Any] | None

Compute statistics from data using batch_stats_fn.

If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.

Parameters:

  • data (Any, required): Input data to compute statistics from

Returns:

  • dict[str, Any] | None: Dictionary of statistics, or None if no batch_stats_fn configured

get_statistics

get_statistics() -> dict[str, Any] | None

Get current statistics.

Returns precomputed_stats if configured (unless reset was called), otherwise returns cached computed statistics, or None if no statistics available.

Returns:

  • dict[str, Any] | None: Dictionary of statistics, or None if no statistics available

set_statistics

set_statistics(stats: dict[str, Any]) -> None

Manually set statistics.

This overwrites any previously computed statistics and clears the reset flag.

Parameters:

  • stats (dict[str, Any], required): Dictionary of statistics to set

reset_statistics

reset_statistics() -> None

Reset all statistics to None.

This clears both computed statistics and marks that precomputed_stats should be ignored (via internal flag). After reset, get_statistics() will return None until new statistics are set or computed.

reset_cache

reset_cache() -> None

Clear the cache.

Only has effect if cacheable=True in config.

copy

copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule

Create a copy of this module with optional config/parameter changes.

This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.

Parameters:

  • config (DataraxModuleConfig | None, default: None): New config (if None, uses current config)
  • rngs (Rngs | None, default: None): New RNG state (if None, uses current rngs)
  • name (str | None, default: None): New name (if None, uses current name)

Returns:

  • DataraxModule: New module instance with updated parameters

Examples:

# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)

# Change name only
renamed = module.copy(name="new_name")

Note

Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().

get_state

get_state() -> dict[str, Any]

Get module state for checkpointing.

This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.

Returns:

  • dict[str, Any]: A dictionary containing the internal state of the component.

set_state

set_state(state: dict[str, Any]) -> None

Restore module state from a checkpoint.

This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.

Parameters:

  • state (dict[str, Any], required): A dictionary containing the internal state to restore.

Raises:

  • TypeError: If state is not a dictionary.
  • ValueError: If checkpoint structure does not match module state.

clone

clone() -> DataraxModule

Create a new instance with the same state as this module.

Uses NNX's clone function for proper deep cloning of all state.

Returns:

  • DataraxModule: A new module instance with the same state.

requires_rng_streams

requires_rng_streams() -> list[str] | None

Get the list of RNG streams required by this module.

Returns:

  • list[str] | None: A list of required RNG stream names, or None if no RNG streams are required.

ensure_rng_streams

ensure_rng_streams(stream_names: list[str]) -> None

Ensure that the required RNG streams are available.

Parameters:

  • stream_names (list[str], required): A list of available RNG stream names.

Raises:

  • ValueError: If a required RNG stream is not available.

process

process(input: Any, *args: Any, **kwargs: Any) -> Any

Process input structure.

This method transforms the structure/organization of input data without modifying the data values themselves.

Subclasses MUST implement this method.

The input/output types depend on the specific structural processor:

  • Batcher: list[Element] -> list[Batch]
  • Sampler: int -> list[int]
  • Sharder: Batch -> Sharded[Batch]
  • Splitter: Dataset -> tuple[Dataset, Dataset]

Parameters:

  • input (Any, required): Input to process (type varies by processor)
  • *args (Any, default: ()): Additional positional arguments (processor-specific)
  • **kwargs (Any, default: {}): Additional keyword arguments (processor-specific)

Returns:

  • Any: Processed output (type varies by processor)

Examples:

Batcher implementation:

def process(self, elements: list[Element]) -> list[Batch]:
    batches = []
    for i in range(0, len(elements), self.config.batch_size):
        batch_elements = elements[i:i + self.config.batch_size]
        batches.append(Batch.from_elements(batch_elements))
    return batches

Sampler implementation (deterministic):

def process(self, dataset_size: int) -> list[int]:
    return list(range(min(self.config.num_samples, dataset_size)))

Sampler implementation (stochastic):

def process(self, dataset_size: int) -> list[int]:
    rng = self.rngs[self.config.stream_name]()
    indices = jax.random.choice(
        rng, dataset_size, shape=(self.config.num_samples,),
        replace=self.config.replacement
    )
    return indices.tolist()

get_batch_at

get_batch_at(start: int | Any, size: int, key: Any | None = None) -> Any

Stateless indexed batch access for Pipeline-driven iteration.

Returns size records starting at start. Implementations must be stateless (no mutation of internal counters) and JAX-traceable (must accept tracer values for start) so the call composes with nnx.scan.

Parameters:

  • start (int | Any, required): Starting index. Sources that support indexed access accept a Python int or a traced jax.Array.
  • size (int, required): Number of records to return (Python int, since JAX shapes are static).
  • key (Any | None, default: None): Optional PRNG key for shuffled or stochastic sampling.

Returns:

  • Any: A batch dict (or PyTree) with leading dimension of length size.

Raises:

  • NotImplementedError: If the source does not support indexed access (e.g. forward-only streams). Pipeline falls back to its __iter__ debug path in that case.
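The fallback (indexed access when available, plain iteration otherwise) can be sketched as follows. iterate_batches, StreamOnly, and the dict-of-lists batch format are hypothetical; Pipeline's own fallback path is not shown in this reference, so this only mirrors the idea.

```python
from __future__ import annotations

from itertools import islice
from typing import Any, Iterator


def iterate_batches(source: Any, size: int, num_batches: int) -> Iterator[dict[str, list]]:
    """Prefer stateless get_batch_at; fall back to forward iteration if unsupported."""
    if num_batches <= 0:
        return
    try:
        first = source.get_batch_at(0, size)
    except NotImplementedError:
        first = None
    if first is not None:  # indexed path: driven by our own position counter
        yield first
        for step in range(1, num_batches):
            yield source.get_batch_at(step * size, size)
        return
    stream = iter(source)  # fallback path: forward-only iteration
    for _ in range(num_batches):
        items = list(islice(stream, size))
        if not items:
            return
        yield {key: [item[key] for item in items] for key in items[0]}


class StreamOnly:
    """Hypothetical forward-only source, similar in spirit to a streaming TFDS source."""

    def __init__(self, n: int) -> None:
        self.n = n

    def __iter__(self):
        return ({"x": i} for i in range(self.n))

    def get_batch_at(self, start, size, key=None):
        raise NotImplementedError("forward-only stream")


print(list(iterate_batches(StreamOnly(5), size=2, num_batches=3)))
# [{'x': [0, 1]}, {'x': [2, 3]}, {'x': [4]}]
```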

get_dataset_info

get_dataset_info() -> Any

Return cached backend-specific dataset metadata.

reset

reset(seed: int | None = None) -> None

Reset streaming iterator state.

get_batch

get_batch(batch_size: int) -> dict[str, Any]

Collect up to batch_size items from the streaming iterator.