TensorFlow Datasets Source¤
TFDSEagerSource provides integration with TensorFlow Datasets (TFDS), giving access to hundreds of ready-to-use datasets with automatic conversion from TensorFlow tensors to JAX arrays.
Note: You can also use the factory function from_tfds(name, split, ...), which auto-selects between eager and streaming modes based on your configuration.
Key Features¤
| Feature | Description |
|---|---|
| Automatic conversion | TensorFlow tensors → JAX arrays |
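| Built-in prefetching | Eager source needs none after init; streaming source uses a fixed prefetch buffer rather than tf.data.AUTOTUNE |
| Supervised mode | Optional (image, label) tuple unpacking |
| Shuffling | Grain index_shuffle for the eager source; TensorFlow shuffle buffer for the streaming source |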
| Caching | Optional dataset caching for repeated epochs |
★ Insight ─────────────────────────────────────
- TFDS handles download and preparation automatically
- Use as_supervised=True to get {"image": ..., "label": ...} format
- TensorFlow's prefetching is applied automatically for performance
- The source tracks epoch and index for stateful training loops
─────────────────────────────────────────────────
Installation¤
TFDSEagerSource requires TensorFlow and tensorflow-datasets to be installed.
Quick Start¤
import flax.nnx as nnx
from datarax.sources import TFDSEagerSource
from datarax.sources.tfds_source import TFDSEagerConfig
# Load MNIST dataset
config = TFDSEagerConfig(name="mnist", split="train")
source = TFDSEagerSource(config, rngs=nnx.Rngs(0))
# Iterate over elements
for item in source:
    image = item["image"]  # JAX array, shape (28, 28, 1)
    label = item["label"]  # JAX array, scalar
    process(image, label)
Supervised Mode¤
Get a cleaner {"image", "label"} structure:
config = TFDSEagerConfig(
    name="cifar10",
    split="train",
    as_supervised=True,  # Returns {"image": ..., "label": ...}
)
source = TFDSEagerSource(config)
batch = source.get_batch(32)
images = batch["image"] # Shape: (32, 32, 32, 3)
labels = batch["label"] # Shape: (32,)
Batch Retrieval¤
For training loops with automatic epoch cycling:
# Stateful batch retrieval
for step in range(10000):
    batch = source.get_batch(64)
    loss = train_step(batch)
# Check progress
print(f"Epoch {source.epoch.get_value()}, "
      f"Index {source.index.get_value()}")
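The epoch and index counters used above can be pictured with a small stand-in. The sketch below is not datarax code: ToySource and its plain-integer counters are hypothetical, and it only assumes that get_batch wraps around at the end of the data and increments the epoch counter when it does.

```python
class ToySource:
    """Hypothetical stand-in that mimics stateful batch retrieval."""

    def __init__(self, data: list[int]) -> None:
        self.data = data
        self.index = 0  # position of the next record within the current epoch
        self.epoch = 0  # number of completed passes over the data

    def get_batch(self, size: int) -> list[int]:
        batch = []
        for _ in range(size):
            batch.append(self.data[self.index])
            self.index += 1
            if self.index == len(self.data):  # end of data: start a new epoch
                self.index = 0
                self.epoch += 1
        return batch


toy = ToySource(list(range(10)))
batches = [toy.get_batch(4) for _ in range(3)]  # 12 records drawn from 10 items
print(batches[-1], toy.epoch, toy.index)
```

The third batch straddles the epoch boundary, which is why the training loop never has to restart the source by hand.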
Shuffling¤
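Enable shuffling with a fixed seed for reproducibility. Note that shuffle_buffer_size is a TFDSStreamingConfig field; TFDSEagerConfig does not accept it and shuffles by index instead:
config = TFDSEagerConfig(
    name="imagenet2012",
    split="train",
    shuffle=True,
    seed=42,
)
source = TFDSEagerSource(config, rngs=nnx.Rngs(42))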
Caching¤
Cache the dataset for faster repeated epochs:
config = TFDSEagerConfig(
    name="mnist",
    split="train",
    cacheable=True,  # Cache in memory after first epoch
)
source = TFDSEagerSource(config)
Custom Data Directory¤
Store datasets in a specific location:
config = TFDSEagerConfig(
    name="imagenet2012",
    split="train",
    data_dir="/path/to/tfds_data",
)
source = TFDSEagerSource(config)
Field Filtering¤
Select only needed fields:
config = TFDSEagerConfig(
    name="coco/2017",
    split="train",
    include_keys={"image", "objects"},  # Only these fields
)
# Or exclude unwanted fields
config = TFDSEagerConfig(
    name="mnist",
    split="train",
    exclude_keys={"id"},
)
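The two filtering options are documented as mutually exclusive. A minimal sketch of that rule, using a hypothetical helper filter_keys on a plain dict (this is not the datarax implementation, and the error message is an assumption):

```python
from typing import Any, Optional


def filter_keys(
    element: dict[str, Any],
    include_keys: Optional[set[str]] = None,
    exclude_keys: Optional[set[str]] = None,
) -> dict[str, Any]:
    """Keep or drop top-level fields; the two options cannot be combined."""
    if include_keys is not None and exclude_keys is not None:
        raise ValueError("include_keys and exclude_keys cannot both be set")
    if include_keys is not None:
        return {k: v for k, v in element.items() if k in include_keys}
    if exclude_keys is not None:
        return {k: v for k, v in element.items() if k not in exclude_keys}
    return dict(element)  # no filtering requested


example = {"image": "pixels", "label": 7, "id": "abc"}
kept = filter_keys(example, include_keys={"image", "label"})
dropped = filter_keys(example, exclude_keys={"id"})
```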
Integration with DAG Pipelines¤
from datarax.pipeline import Pipeline
config = TFDSEagerConfig(
    name="cifar10",
    split="train",
    as_supervised=True,
)
source = TFDSEagerSource(config)
pipeline = Pipeline(
    source=source,
    stages=[normalize_op, augment_op],
    batch_size=128,
    rngs=nnx.Rngs(0),
)
for batch in pipeline:
    train_step(batch)
Dataset Information¤
Access rich metadata from TFDS:
info = source.get_dataset_info()
print(f"Description: {info.description}")
print(f"Features: {info.features}")
print(f"Splits: {list(info.splits.keys())}")
print(f"Citation: {info.citation}")
# Number of examples
print(f"Train examples: {info.splits['train'].num_examples}")
See Also¤
- Data Sources Guide - Complete data loading guide
- HF Source - HuggingFace Datasets integration
- TFDS Quick Reference
- TFDS Catalog
API Reference¤
datarax.sources.tfds_source ¤
TensorFlow Datasets (TFDS) data source implementation for Datarax.
This module provides two distinct source types optimized for different use cases:
TFDSEagerSource: For small/medium datasets that fit in memory (~10% VRAM)
- Loads ALL data to JAX arrays at initialization
- Pure JAX iteration after init (no TensorFlow overhead during training)
- O(1) memory shuffling via Grain's index_shuffle (Feistel cipher)
- Fully checkpointable (just indices, no external state)
- Ideal for: MNIST, CIFAR-10, Fashion-MNIST, small custom datasets
TFDSStreamingSource: For large datasets that don't fit in memory
- Thin wrapper around TF dataset iterator
- DLPack zero-copy conversion for each batch
- Fixed prefetch buffer (no AUTOTUNE thread storms)
- Trade-offs: External iterator state, can't checkpoint mid-epoch
- Ideal for: ImageNet, large-scale datasets, memory-constrained environments
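One way to apply the ~10% VRAM guideline is a back-of-the-envelope size estimate. The helpers below are hypothetical and not part of datarax; the 16 GiB device size and the 224x224 resize for the ImageNet-like case are assumptions chosen only for illustration.

```python
import math


def estimated_bytes(num_examples: int, shape: tuple[int, ...], itemsize: int) -> int:
    """Raw array size: examples x elements per example x bytes per element."""
    return num_examples * math.prod(shape) * itemsize


def fits_eager(dataset_bytes: int, device_bytes: int, fraction: float = 0.10) -> bool:
    """True when the dataset stays within the given fraction of device memory."""
    return dataset_bytes <= fraction * device_bytes


device_bytes = 16 * 1024**3  # assumed 16 GiB accelerator
mnist_bytes = estimated_bytes(60_000, (28, 28, 1), itemsize=1)    # uint8 images
cifar_bytes = estimated_bytes(50_000, (32, 32, 3), itemsize=1)    # uint8 images
# ImageNet-like case: 1,281,167 training images, assumed resized to 224x224x3.
imagenet_bytes = estimated_bytes(1_281_167, (224, 224, 3), itemsize=1)

for name, size in [("mnist", mnist_bytes), ("cifar10", cifar_bytes),
                   ("imagenet-like", imagenet_bytes)]:
    mode = "eager" if fits_eager(size, device_bytes) else "streaming"
    print(f"{name}: {size / 1e6:.1f} MB -> {mode}")
```

The estimate covers image data only and ignores labels and any dtype widening after conversion, so treat it as a lower bound.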
Architecture Insight
The ~0.4s delay at epoch 2 in previous implementations was caused by TensorFlow's AUTOTUNE prefetch spawning background threads during epoch transitions. TFDSEagerSource eliminates this entirely by loading all data upfront. TFDSStreamingSource uses fixed prefetch to prevent thread storms.
TFDSEagerConfig dataclass ¤
TFDSEagerConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, stochastic: bool = False, stream_name: str | None = None, name: str | None = None, split: str | None = None, data_dir: str | None = None, include_keys: set[str] | None = None, exclude_keys: set[str] | None = None, try_gcs: bool = False, shuffle: bool = False, seed: int = 42, as_supervised: bool = False, download_and_prepare_kwargs: dict[str, Any] | None = None, beam_num_workers: int | None = None, local_files_only: bool = False)
Bases: SourceConfigBase
Configuration for TFDSEagerSource (loads all data to JAX at init).
Configuration for eager-loading TensorFlow Datasets into JAX arrays.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str \| None | Name of the dataset in TFDS (required) | None |
| split | str \| None | Split of the dataset to load, e.g., "train", "test" (required) | None |
| data_dir | str \| None | Optional directory where the dataset is stored/downloaded | None |
| shuffle | bool | Whether to shuffle the dataset during iteration | False |
| seed | int | Integer seed for Grain's index_shuffle (default: 42) | 42 |
| as_supervised | bool | If True, returns 'image'/'label' keys instead of original features | False |
| download_and_prepare_kwargs | dict[str, Any] \| None | Optional keyword arguments for download_and_prepare | None |
| include_keys | set[str] \| None | Optional set of keys to include in output (exclusive with exclude_keys) | None |
| exclude_keys | set[str] \| None | Optional set of keys to exclude from output (exclusive with include_keys) | None |
Note
The seed parameter is an integer (not JAX RNG key) for Grain's index_shuffle. This ensures O(1) memory shuffling and reproducible per-epoch seeds.
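To see why an integer seed is enough for O(1) memory shuffling, here is a minimal Feistel-style index permutation. It does not reproduce Grain's index_shuffle: the SHA-256 round function, the round count, and the cycle-walking step are assumptions chosen for illustration, and shuffled_index is a hypothetical name.

```python
import hashlib


def _round_fn(value: int, seed: int, rnd: int, bits: int) -> int:
    """Keyed pseudo-random function used inside each Feistel round."""
    digest = hashlib.sha256(f"{seed}:{rnd}:{value}".encode()).digest()
    return int.from_bytes(digest[:8], "big") & ((1 << bits) - 1)


def shuffled_index(index: int, length: int, seed: int, rounds: int = 4) -> int:
    """Map index to a permuted index in [0, length) without storing the permutation."""
    if not 0 <= index < length:
        raise IndexError(index)
    half = max(1, ((length - 1).bit_length() + 1) // 2)  # bits per Feistel half
    mask = (1 << half) - 1
    x = index
    while True:
        left, right = x >> half, x & mask
        for rnd in range(rounds):
            left, right = right, left ^ _round_fn(right, seed, rnd, half)
        x = (left << half) | right
        if x < length:  # cycle-walk until the value lands inside the valid range
            return x


order = [shuffled_index(i, 10, seed=42) for i in range(10)]
epoch_orders = [[shuffled_index(i, 10, seed=42 + e) for i in range(10)] for e in range(3)]
```

Each lookup needs only the index, the length, and the seed, so a checkpoint has to store nothing beyond the current position and epoch; deriving the per-epoch seed as seed + epoch is likewise just one possible convention.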
TFDSStreamingConfig dataclass ¤
TFDSStreamingConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, stochastic: bool = False, stream_name: str | None = None, name: str | None = None, split: str | None = None, data_dir: str | None = None, include_keys: set[str] | None = None, exclude_keys: set[str] | None = None, try_gcs: bool = False, shuffle: bool = False, shuffle_buffer_size: int = 1000, as_supervised: bool = False, download_and_prepare_kwargs: dict[str, Any] | None = None, beam_num_workers: int | None = None, prefetch_buffer: int = 2, local_files_only: bool = False)
Bases: SourceConfigBase
Configuration for TFDSStreamingSource (streams data from TF dataset).
Use this for datasets too large to fit in memory. The streaming source uses fixed prefetch buffers to avoid AUTOTUNE thread storms.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str \| None | Name of the dataset in TFDS (required) | None |
| split | str \| None | Split of the dataset to load, e.g., "train", "test" (required) | None |
| data_dir | str \| None | Optional directory where the dataset is stored/downloaded | None |
| shuffle | bool | Whether to shuffle the dataset | False |
| shuffle_buffer_size | int | TF shuffle buffer size (default: 1000) | 1000 |
| as_supervised | bool | If True, returns 'image'/'label' keys | False |
| download_and_prepare_kwargs | dict[str, Any] \| None | Optional keyword arguments for download_and_prepare | None |
| include_keys | set[str] \| None | Optional set of keys to include in output | None |
| exclude_keys | set[str] \| None | Optional set of keys to exclude from output | None |
| prefetch_buffer | int | Fixed prefetch buffer size (default: 2, NOT AUTOTUNE) | 2 |
Note
The prefetch_buffer uses a fixed size instead of TF AUTOTUNE to prevent background thread storms that cause delays during epoch transitions.
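The idea of a fixed prefetch buffer can be shown without TensorFlow. The sketch below is a plain-Python analogy, not how tf.data or datarax implement prefetching: one worker thread fills a bounded queue, so neither the thread count nor the buffer size changes at runtime. The prefetch helper is hypothetical.

```python
import queue
import threading
from collections.abc import Iterable, Iterator
from typing import Any

_DONE = object()  # sentinel marking the end of the stream


def prefetch(items: Iterable[Any], buffer_size: int = 2) -> Iterator[Any]:
    """Yield items while a single worker keeps at most buffer_size of them ready."""
    buf: queue.Queue = queue.Queue(maxsize=buffer_size)

    def worker() -> None:
        for item in items:
            buf.put(item)  # blocks while the buffer is full
        buf.put(_DONE)

    threading.Thread(target=worker, daemon=True).start()
    while True:
        item = buf.get()
        if item is _DONE:
            return
        yield item


result = list(prefetch(range(5), buffer_size=2))
```

A production version would also forward exceptions raised in the worker; that is omitted here to keep the sketch short.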
TFDSEagerSource ¤
TFDSEagerSource(config: TFDSEagerConfig, *, rngs: Rngs | None = None, name: str | None = None)
Bases: EagerSourceBase
Eager-loading TFDS source for small/medium datasets.
Loads ALL data to JAX arrays at initialization, then operates like a MemorySource with pure JAX operations. Use for datasets that fit in ~10% of device VRAM.
Key Features
- One-time TF→JAX conversion at init (DLPack zero-copy when possible)
- Pure JAX iteration after init (no TF threads during training)
- O(1) memory shuffling via Grain's index_shuffle (Feistel cipher)
- Full checkpointing support (indices only, no external state)
- Supports as_supervised mode and key filtering
Performance
- Eliminates ~0.4s epoch 2 delay from TF AUTOTUNE threads
- Training loops can use lax.fori_loop for 100-500x speedup
- Device placement via collect_to_array() for staged training
Example
# Create eager source for MNIST
config = TFDSEagerConfig(name="mnist", split="train", shuffle=True)
source = TFDSEagerSource(config, rngs=nnx.Rngs(0))
# Iterate - pure JAX, no TF overhead
for item in source:
    process(item["image"])
# Get batch (stateless with key, or stateful without)
batch = source.get_batch(32) # Stateful
batch = source.get_batch(32, key=jax.random.key(0)) # Stateless
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | TFDSEagerConfig | Configuration for the source | required |
| rngs | Rngs \| None | Optional RNG state for shuffling | None |
| name | str \| None | Optional name (defaults to TFDSEagerSource(dataset:split)) | None |
get_operation_stats ¤
reset_operation_stats ¤
Reset operation statistics to zero.
Note: Creates new JAX arrays to reset the counters.
compute_statistics ¤
Compute statistics from data using batch_stats_fn.
If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | Any | Input data to compute statistics from | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] \| None | Dictionary of statistics, or None if no batch_stats_fn configured |
get_statistics ¤
set_statistics ¤
reset_statistics ¤
Reset all statistics to None.
This clears both computed statistics and marks that precomputed_stats should be ignored (via internal flag). After reset, get_statistics() will return None until new statistics are set or computed.
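The interplay of compute_statistics, get_statistics, and reset_statistics can be sketched with a toy class. StatsHolder is hypothetical; the flag name _ignore_precomputed and the rule that computed statistics take precedence over precomputed ones are assumptions, since the text above only says that an internal flag exists.

```python
from typing import Any, Callable, Optional


class StatsHolder:
    """Hypothetical stand-in for the statistics behaviour described above."""

    def __init__(
        self,
        batch_stats_fn: Optional[Callable[[Any], dict[str, Any]]] = None,
        precomputed_stats: Optional[dict[str, Any]] = None,
    ) -> None:
        self.batch_stats_fn = batch_stats_fn
        self.precomputed_stats = precomputed_stats
        self._computed_stats: Optional[dict[str, Any]] = None
        self._ignore_precomputed = False  # assumed name for the internal flag

    def compute_statistics(self, data: Any) -> Optional[dict[str, Any]]:
        if self.batch_stats_fn is None:
            return None  # nothing configured, nothing cached
        self._computed_stats = self.batch_stats_fn(data)
        return self._computed_stats

    def get_statistics(self) -> Optional[dict[str, Any]]:
        if self._computed_stats is not None:
            return self._computed_stats
        return None if self._ignore_precomputed else self.precomputed_stats

    def reset_statistics(self) -> None:
        self._computed_stats = None
        self._ignore_precomputed = True  # precomputed stats are ignored from now on


holder = StatsHolder(
    batch_stats_fn=lambda xs: {"mean": sum(xs) / len(xs)},
    precomputed_stats={"mean": 0.0},
)
before = holder.get_statistics()
computed = holder.compute_statistics([1.0, 2.0, 3.0])
holder.reset_statistics()
after = holder.get_statistics()
```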
copy ¤
copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule
Create a copy of this module with optional config/parameter changes.
This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | DataraxModuleConfig \| None | New config (if None, uses current config) | None |
| rngs | Rngs \| None | New RNG state (if None, uses current rngs) | None |
| name | str \| None | New name (if None, uses current name) | None |
Returns:
| Type | Description |
|---|---|
| DataraxModule | New module instance with updated parameters |
Examples:
# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)
# Change name only
renamed = module.copy(name="new_name")
Note
Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().
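For context on the dataclass replace() route mentioned in the note, the standard-library call looks like this. ToyConfig is a hypothetical config whose field names are borrowed from the signatures above; whether datarax configs are frozen is an assumption.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyConfig:
    """Hypothetical config used only to demonstrate dataclasses.replace."""

    name: str = "mnist"
    split: str = "train"
    cacheable: bool = False


base = ToyConfig()
cached = replace(base, cacheable=True)  # new object; base is left untouched
```

The resulting config could then be passed as the config argument of copy.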
get_state ¤
Get module state for checkpointing.
This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A dictionary containing the internal state of the component. |
set_state ¤
Restore module state from a checkpoint.
This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | dict[str, Any] | A dictionary containing the internal state to restore. | required |
Raises:
| Type | Description |
|---|---|
| TypeError | If state is not a dictionary. |
| ValueError | If checkpoint structure does not match module state. |
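Strict restoration can be pictured as a type check followed by a structure check. The helper restore_strict below is hypothetical and compares only top-level keys; the real method works on NNX state and may validate nested structure as well.

```python
from typing import Any


def restore_strict(current: dict[str, Any], state: Any) -> dict[str, Any]:
    """Return the restored state, or raise if its type or keys do not match."""
    if not isinstance(state, dict):
        raise TypeError(f"state must be a dict, got {type(state).__name__}")
    if set(state) != set(current):
        missing = sorted(set(current) - set(state))
        unexpected = sorted(set(state) - set(current))
        raise ValueError(f"structure mismatch: missing={missing}, unexpected={unexpected}")
    return dict(state)


current = {"index": 0, "epoch": 0}
restored = restore_strict(current, {"index": 128, "epoch": 3})
```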
clone ¤
clone() -> DataraxModule
Create a new instance with the same state as this module.
Uses NNX's clone function for proper deep cloning of all state.
Returns:
| Type | Description |
|---|---|
| DataraxModule | A new module instance with the same state. |
requires_rng_streams ¤
ensure_rng_streams ¤
Ensure that the required RNG streams are available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stream_names | list[str] | A list of available RNG stream names. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If a required RNG stream is not available. |
process ¤
Process input structure.
This method transforms the structure/organization of input data without modifying the data values themselves.
Subclasses MUST implement this method.
The input/output types depend on the specific structural processor:
- Batcher: list[Element] -> list[Batch]
- Sampler: int -> list[int]
- Sharder: Batch -> Sharded[Batch]
- Splitter: Dataset -> tuple[Dataset, Dataset]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input | Any | Input to process (type varies by processor) | required |
| *args | Any | Additional positional arguments (processor-specific) | () |
| **kwargs | Any | Additional keyword arguments (processor-specific) | {} |
Returns:
| Type | Description |
|---|---|
| Any | Processed output (type varies by processor) |
Examples:
Batcher implementation:
def process(self, elements: list[Element]) -> list[Batch]:
    batches = []
    for i in range(0, len(elements), self.config.batch_size):
        batch_elements = elements[i:i + self.config.batch_size]
        batches.append(Batch.from_elements(batch_elements))
    return batches
Sampler implementation (deterministic):
def process(self, dataset_size: int) -> list[int]:
    return list(range(min(self.config.num_samples, dataset_size)))
Sampler implementation (stochastic):
get_batch_at ¤
Stateless indexed batch access; JIT-traceable for scan-based iteration.
Returns size records starting at logical position start.
Does not advance self.index or any other internal state, so
callers (typically Pipeline) can drive iteration via their
own position counter and trace get_batch_at under
nnx.scan / nnx.jit.
Two modes:
- Sequential (self.is_random_order == False): returns the contiguous slice data[start : start + size] with wrap-around at the end of the source.
- Shuffled (self.is_random_order == True): applies a deterministic permutation derived from key and returns the slice of that permutation. Same (start, size, key) always returns the same output. The permutation is materialized via jax.random.permutation(key, length) per call, which costs O(length) per batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| start | int \| Array | Starting logical index; accepts concrete int or traced | required |
| size | int | Number of records to return (Python int; JAX shapes are static). | required |
| key | Array \| None | PRNG key for shuffled mode. Required when | None |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dict mapping each data key to a JAX array with leading dimension |
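The two modes described for get_batch_at can be sketched as a free function over a dict of arrays. This is an illustration under stated assumptions, not the datarax method: sketch_get_batch_at is a hypothetical name, and passing key=None stands in for sequential mode instead of consulting is_random_order.

```python
import jax
import jax.numpy as jnp


def sketch_get_batch_at(data, start, size, key=None):
    """Return size records from logical position start, wrapping at the end."""
    length = next(iter(data.values())).shape[0]  # static dataset size
    positions = (start + jnp.arange(size)) % length  # wrap-around positions
    if key is not None:
        # Shuffled mode: the same key gives the same permutation, hence the same batch.
        positions = jax.random.permutation(key, length)[positions]
    return {name: array[positions] for name, array in data.items()}


data = {"x": jnp.arange(10), "y": jnp.arange(10) * 2}
sequential = sketch_get_batch_at(data, start=8, size=4)

key = jax.random.key(0)
shuffled_a = sketch_get_batch_at(data, 8, 4, key=key)
shuffled_b = sketch_get_batch_at(data, 8, 4, key=key)

# start may be a traced value; size stays a Python int because shapes are static.
traced = jax.jit(lambda s: sketch_get_batch_at(data, s, 4))(8)
```

Nothing in the function mutates state, which is what lets a caller drive it from its own position counter.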
element_spec ¤
element_spec() -> Any
Derive per-element spec from the eager dict-of-arrays storage.
EagerSourceBase subclasses store data as a dict mapping keys to arrays
whose leading axis is the dataset size. This default implementation
strips that leading axis from every leaf to produce one
jax.ShapeDtypeStruct per key.
Subclasses with non-dict storage should override.
Raises:
| Type | Description |
|---|---|
| ValueError | If the source is empty. |
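Stripping the leading dataset axis is a one-liner per key. The function below is a sketch of that default behaviour, not the datarax source; treating an empty dict as "the source is empty" is an assumption.

```python
import jax
import jax.numpy as jnp


def sketch_element_spec(data):
    """Drop the leading dataset axis from every array in a dict of arrays."""
    if not data:
        raise ValueError("source is empty")
    return {
        name: jax.ShapeDtypeStruct(array.shape[1:], array.dtype)
        for name, array in data.items()
    }


stored = {
    "image": jnp.zeros((100, 28, 28, 1), dtype=jnp.uint8),
    "label": jnp.zeros((100,), dtype=jnp.int32),
}
spec = sketch_element_spec(stored)
```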
get_batch ¤
Get one eager batch in stateful or stateless mode.
TFDSStreamingSource ¤
TFDSStreamingSource(config: TFDSStreamingConfig, *, rngs: Rngs | None = None, name: str | None = None)
Bases: StreamingSourceBase
Streaming TFDS source for large datasets.
Thin wrapper around TF dataset for data that can't fit in memory. Uses DLPack for efficient conversion and fixed prefetch buffer.
Key Features
- DLPack zero-copy for TF→JAX conversion
- Fixed prefetch buffer (no AUTOTUNE thread storms)
- Supports all TFDS datasets
- include_keys/exclude_keys filtering
Trade-offs vs Eager
- Cannot checkpoint mid-epoch (external iterator state)
- Some TF thread overhead (minimized with fixed prefetch)
- Use with Artifex train_epoch_streaming() for best results
Example
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | TFDSStreamingConfig | Configuration for the source | required |
| rngs | Rngs \| None | Optional RNG state | None |
| name | str \| None | Optional name (defaults to TFDSStreamingSource(dataset:split)) | None |
element_spec ¤
element_spec() -> Any
Return per-element shape/dtype derived by peeking the TFDS stream.
TFDS streams yield single-element dicts (or tuples in
as_supervised mode). The spec is derived by peeking the first
element from a fresh iterator on the cached self._tf_dataset
(which has already been built in __init__) so it does not
re-trigger downloads. Top-level dict values are treated as single
leaves so vector features become 1-D ShapeDtypeStruct instead of
per-scalar leaves.
get_operation_stats ¤
reset_operation_stats ¤
Reset operation statistics to zero.
Note: Creates new JAX arrays to reset the counters.
compute_statistics ¤
Compute statistics from data using batch_stats_fn.
If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | Any | Input data to compute statistics from | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] \| None | Dictionary of statistics, or None if no batch_stats_fn configured |
get_statistics ¤
set_statistics ¤
reset_statistics ¤
Reset all statistics to None.
This clears both computed statistics and marks that precomputed_stats should be ignored (via internal flag). After reset, get_statistics() will return None until new statistics are set or computed.
copy ¤
copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule
Create a copy of this module with optional config/parameter changes.
This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | DataraxModuleConfig \| None | New config (if None, uses current config) | None |
| rngs | Rngs \| None | New RNG state (if None, uses current rngs) | None |
| name | str \| None | New name (if None, uses current name) | None |
Returns:
| Type | Description |
|---|---|
| DataraxModule | New module instance with updated parameters |
Examples:
# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)
# Change name only
renamed = module.copy(name="new_name")
Note
Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().
get_state ¤
Get module state for checkpointing.
This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A dictionary containing the internal state of the component. |
set_state ¤
Restore module state from a checkpoint.
This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | dict[str, Any] | A dictionary containing the internal state to restore. | required |
Raises:
| Type | Description |
|---|---|
| TypeError | If state is not a dictionary. |
| ValueError | If checkpoint structure does not match module state. |
clone ¤
clone() -> DataraxModule
Create a new instance with the same state as this module.
Uses NNX's clone function for proper deep cloning of all state.
Returns:
| Type | Description |
|---|---|
| DataraxModule | A new module instance with the same state. |
requires_rng_streams ¤
ensure_rng_streams ¤
Ensure that the required RNG streams are available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stream_names | list[str] | A list of available RNG stream names. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If a required RNG stream is not available. |
process ¤
Process input structure.
This method transforms the structure/organization of input data without modifying the data values themselves.
Subclasses MUST implement this method.
The input/output types depend on the specific structural processor:
- Batcher: list[Element] -> list[Batch]
- Sampler: int -> list[int]
- Sharder: Batch -> Sharded[Batch]
- Splitter: Dataset -> tuple[Dataset, Dataset]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input | Any | Input to process (type varies by processor) | required |
| *args | Any | Additional positional arguments (processor-specific) | () |
| **kwargs | Any | Additional keyword arguments (processor-specific) | {} |
Returns:
| Type | Description |
|---|---|
| Any | Processed output (type varies by processor) |
Examples:
Batcher implementation:
def process(self, elements: list[Element]) -> list[Batch]:
    batches = []
    for i in range(0, len(elements), self.config.batch_size):
        batch_elements = elements[i:i + self.config.batch_size]
        batches.append(Batch.from_elements(batch_elements))
    return batches
Sampler implementation (deterministic):
def process(self, dataset_size: int) -> list[int]:
    return list(range(min(self.config.num_samples, dataset_size)))
Sampler implementation (stochastic):
get_batch_at ¤
Stateless indexed batch access for Pipeline-driven iteration.
Returns size records starting at start. Implementations must
be stateless (no mutation of internal counters) and JAX-traceable
(must accept tracer values for start) so the call composes with
nnx.scan.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| start | int \| Any | Starting index. Sources that support indexed access accept a Python int or a traced | required |
| size | int | Number of records to return (Python int; JAX shapes are static). | required |
| key | Any \| None | Optional PRNG key for shuffled or stochastic sampling. | None |
Returns:
| Type | Description |
|---|---|
| Any | A batch dict (or PyTree) with leading dim |
Raises:
| Type | Description |
|---|---|
| NotImplementedError | If the source does not support indexed access (e.g. forward-only streams). Pipeline falls back to its |