HuggingFace Source
HFEagerSource integrates HuggingFace Datasets with Datarax: it loads a dataset from the Hub directly into your Datarax pipeline and converts it to JAX arrays automatically.
Note: You can also use the factory function `from_hf(name, split, ...)`, which selects between eager and streaming modes based on your configuration.
Key Features
| Feature | Description |
|---|---|
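| Automatic conversion | HuggingFace features (PIL images, NumPy arrays, Python lists) → JAX arrays |
| Streaming support | Load large datasets without downloading everything (via HFStreamingSource) |
| Shuffling | Index-based shuffle with an integer seed (eager) or buffer-based shuffle with configurable buffer size (streaming) |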
| Key filtering | Include/exclude specific dataset fields |
| Stateful iteration | Track position, epoch, and support batch retrieval |
★ Insight ─────────────────────────────────────
- HFEagerSource wraps the `datasets` library for JAX-native workflows
- PIL images are automatically converted to JAX arrays
- Use `streaming=True` (an HFStreamingConfig option, used with HFStreamingSource) for datasets larger than your disk
- The `get_batch()` method retrieves a whole batch per call
─────────────────────────────────────────────────
Installation
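HFEagerSource requires the HuggingFace `datasets` package (for example, `pip install datasets`); without it, constructing the source raises ImportError.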
Quick Start
```python
import flax.nnx as nnx
from datarax.sources import HFEagerSource
from datarax.sources.hf_source import HFEagerConfig


def process(text, label):
    """Placeholder for your own per-example logic."""
    ...


# Load the IMDB sentiment dataset
config = HFEagerConfig(name="imdb", split="train")
source = HFEagerSource(config, rngs=nnx.Rngs(0))

# Iterate over elements
for item in source:
    text = item["text"]
    label = item["label"]
    process(text, label)
```
Batch Retrieval
For training loops, use the stateful get_batch() method:
```python
# Get a batch of 32 samples
batch = source.get_batch(32)
# batch["text"] has shape (32,)
# batch["label"] has shape (32,)

# Stateful calls cycle through epochs automatically;
# train_step stands in for your own training function
for step in range(1000):
    batch = source.get_batch(32)
    train_step(batch)
```
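The stateful calls above keep a position counter and roll over into the next epoch when the data runs out. A minimal NumPy sketch of that bookkeeping follows; `StatefulBatcher` is a hypothetical stand-in, and letting a single batch span the epoch boundary is an assumption, since this page does not say how the real `get_batch` treats the last partial batch.

```python
import numpy as np


class StatefulBatcher:
    """Hypothetical stand-in for stateful get_batch() with epoch cycling."""

    def __init__(self, data: dict[str, np.ndarray]):
        self.data = data
        self.length = len(next(iter(data.values())))
        self.index = 0  # position inside the current epoch
        self.epoch = 0

    def get_batch(self, size: int) -> dict[str, np.ndarray]:
        # Positions wrap past the end, so one batch may span two epochs.
        idx = (self.index + np.arange(size)) % self.length
        self.epoch += (self.index + size) // self.length
        self.index = (self.index + size) % self.length
        return {k: v[idx] for k, v in self.data.items()}


batcher = StatefulBatcher({"label": np.arange(10)})
first = batcher.get_batch(4)   # labels 0..3
second = batcher.get_batch(4)  # labels 4..7
third = batcher.get_batch(4)   # labels 8, 9, 0, 1 (crosses into epoch 1)
```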
Streaming Large Datasets
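For datasets too large to download completely, use HFStreamingConfig with HFStreamingSource (HFEagerConfig has no `streaming` option):
```python
from datarax.sources.hf_source import HFStreamingConfig, HFStreamingSource

config = HFStreamingConfig(
    name="c4",  # C4, a cleaned Common Crawl corpus (800GB+)
    split="train",
    streaming=True,
)
source = HFStreamingSource(config)

# Data is fetched on demand
for item in source:
    process(item)
```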
Shuffling
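Eager sources shuffle indices with Grain's index_shuffle, controlled by an integer `seed`. Streaming sources shuffle through a buffer whose size is set by `shuffle_buffer_size`:
```python
from datarax.sources.hf_source import HFStreamingConfig, HFStreamingSource

# Eager: O(1)-memory index shuffle, reproducible via the integer seed
config = HFEagerConfig(name="mnist", split="train", shuffle=True, seed=42)
source = HFEagerSource(config, rngs=nnx.Rngs(42))

# Streaming: approximate shuffle through a fixed-size buffer
stream_config = HFStreamingConfig(
    name="mnist", split="train", streaming=True,
    shuffle=True, shuffle_buffer_size=10000,
)
stream_source = HFStreamingSource(stream_config)
```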
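To see what `shuffle_buffer_size` controls, here is the standard buffer-shuffle technique, as described for `IterableDataset.shuffle` in the HuggingFace `datasets` documentation: fill a buffer, emit a random buffered element, and put the next incoming element in its slot. This pure-Python sketch (`buffer_shuffle` is a hypothetical helper, not Datarax or HuggingFace code) shows why small buffers give only local mixing: the first output is always drawn from the first `buffer_size` elements.

```python
import random
from collections.abc import Iterable, Iterator
from typing import TypeVar

T = TypeVar("T")


def buffer_shuffle(stream: Iterable[T], buffer_size: int, seed: int) -> Iterator[T]:
    """Approximately shuffle a stream while holding at most buffer_size items."""
    rng = random.Random(seed)
    buffer: list[T] = []
    for item in stream:
        if len(buffer) < buffer_size:
            buffer.append(item)  # fill phase
            continue
        # Emit a random buffered element and store the new item in its slot.
        i = rng.randrange(buffer_size)
        yield buffer[i]
        buffer[i] = item
    # Stream exhausted: drain the remaining buffer in random order.
    rng.shuffle(buffer)
    yield from buffer


shuffled = list(buffer_shuffle(range(20), buffer_size=5, seed=0))
in_order = list(buffer_shuffle(range(5), buffer_size=1, seed=0))  # buffer of 1: no mixing
```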
Field Filtering
Select only the fields you need:
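```python
# Include only specific fields. GLUE needs a task config such as "sst2";
# passing it via download_kwargs assumes they are forwarded to load_dataset.
config = HFEagerConfig(
    name="glue",
    split="train",
    download_kwargs={"name": "sst2"},
    include_keys={"sentence", "label"},  # Only these fields
)

# Or exclude unwanted fields (same result for SST-2)
config = HFEagerConfig(
    name="glue",
    split="train",
    download_kwargs={"name": "sst2"},
    exclude_keys={"idx"},  # Everything except idx
)
```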
Integration with DAG Pipelines
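```python
import flax.nnx as nnx
from datarax.pipeline import Pipeline
from datarax.sources import HFEagerSource
from datarax.sources.hf_source import HFEagerConfig

# Build a pipeline; normalize_op and augment_op stand in for your own
# operators, and train_step for your training function
config = HFEagerConfig(name="mnist", split="train")
source = HFEagerSource(config)
pipeline = Pipeline(
    source=source,
    stages=[normalize_op, augment_op],
    batch_size=64,
    rngs=nnx.Rngs(0),
)

for batch in pipeline:
    train_step(batch)
```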
Dataset Information
Access metadata about the loaded dataset:
```python
# Get the HuggingFace DatasetInfo
info = source.get_dataset_info()
print(f"Description: {info.description}")
print(f"Features: {info.features}")

# Check length (if available)
print(f"Dataset length: {len(source)}")
```
See Also
- Data Sources Guide - Detailed data loading guide
- TFDS Source - TensorFlow Datasets integration
- HuggingFace Quick Reference
- HuggingFace Tutorial
API Reference
datarax.sources.hf_source
HuggingFace Datasets data source implementation for Datarax.
This module provides two distinct source types optimized for different use cases:
HFEagerSource: for small/medium datasets that fit in memory (roughly 10% of device VRAM)
- Loads ALL data to JAX arrays at initialization
- Pure JAX iteration after init (no HuggingFace overhead during training)
- O(1)-memory shuffling via Grain's index_shuffle (Feistel cipher)
- Fully checkpointable (just indices, no external state)
- Ideal for: MNIST, CIFAR-10, sentiment datasets, small custom datasets
HFStreamingSource: for large datasets that don't fit in memory
- Thin wrapper around the HuggingFace dataset iterator
- Supports HuggingFace's built-in streaming mode
- Trade-offs: external iterator state; cannot checkpoint mid-epoch
- Ideal for: The Pile, C4, large-scale datasets, memory-constrained environments
Architecture Insight
The separation between eager and streaming follows the same pattern as TFDS, ensuring consistent behavior across data backends while optimizing for each use case's specific requirements.
HFEagerConfig (dataclass)
HFEagerConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, stochastic: bool = False, stream_name: str | None = None, name: str | None = None, split: str | None = None, data_dir: str | None = None, include_keys: set[str] | None = None, exclude_keys: set[str] | None = None, shuffle: bool = False, seed: int = 42, download_kwargs: dict[str, Any] | None = None, local_files_only: bool = False)
Bases: SourceConfigBase
Configuration for HFEagerSource, which eagerly loads a HuggingFace dataset into JAX arrays at initialization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str \| None | Name of the dataset on the HuggingFace Hub (required) | None |
| split | str \| None | Split of the dataset to load, e.g., "train", "test" (required) | None |
| data_dir | str \| None | Optional directory where the dataset is stored/downloaded | None |
| shuffle | bool | Whether to shuffle the dataset during iteration | False |
| seed | int | Integer seed for Grain's index_shuffle | 42 |
| download_kwargs | dict[str, Any] \| None | Optional keyword arguments for load_dataset | None |
| include_keys | set[str] \| None | Optional set of keys to include in output (mutually exclusive with exclude_keys) | None |
| exclude_keys | set[str] \| None | Optional set of keys to exclude from output (mutually exclusive with include_keys) | None |
Note
Note
The seed parameter is an integer (not JAX RNG key) for Grain's index_shuffle. This ensures O(1) memory shuffling and reproducible per-epoch seeds.
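Grain's index_shuffle is described above as a Feistel cipher: it computes where index i lands without building the whole permutation, so memory stays O(1) regardless of dataset size. The sketch below illustrates that general technique in pure Python (a keyed Feistel network with cycle-walking to keep results in [0, n)). It is not Grain's implementation: the BLAKE2b round function, the four rounds, and the `seed + epoch` reseeding are assumptions chosen for illustration.

```python
import hashlib


def _round_fn(value: int, seed: int, rnd: int, bits: int) -> int:
    """Keyed pseudo-random round function (BLAKE2b), truncated to `bits` bits."""
    msg = f"{seed}:{rnd}:{value}".encode()
    digest = hashlib.blake2b(msg, digest_size=8).digest()
    return int.from_bytes(digest, "little") & ((1 << bits) - 1)


def feistel_index_shuffle(index: int, n: int, seed: int, rounds: int = 4) -> int:
    """Map index in [0, n) to a distinct shuffled index in [0, n) with O(1) memory."""
    if not 0 <= index < n:
        raise IndexError("index out of range")
    half = max(1, ((n - 1).bit_length() + 1) // 2)  # bits per Feistel half
    mask = (1 << half) - 1
    x = index
    while True:
        left, right = x >> half, x & mask
        for r in range(rounds):
            # Feistel step: invertible whatever the round function is.
            left, right = right, left ^ _round_fn(right, seed, r, half)
        x = (left << half) | right
        if x < n:  # cycle-walk: re-encrypt until the result is back in [0, n)
            return x


# Hypothetical per-epoch reseeding: a fresh order each epoch, reproducible from seed.
seed, epoch, n = 42, 3, 10
order = [feistel_index_shuffle(i, n, seed + epoch) for i in range(n)]
```

Only the seed and the current index are needed to resume, which is why checkpointing an index-shuffled source can store just integers.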
precomputed_stats (class attribute, instance attribute)
HFStreamingConfig (dataclass)
HFStreamingConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, stochastic: bool = False, stream_name: str | None = None, name: str | None = None, split: str | None = None, data_dir: str | None = None, include_keys: set[str] | None = None, exclude_keys: set[str] | None = None, streaming: bool = False, shuffle: bool = False, shuffle_buffer_size: int = 1000, download_kwargs: dict[str, Any] | None = None, local_files_only: bool = False)
Bases: SourceConfigBase
Configuration for HFStreamingSource (streams data from HF dataset).
Use this for datasets too large to fit in memory or when using HuggingFace's built-in streaming mode for efficient data loading.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str \| None | Name of the dataset on the HuggingFace Hub (required) | None |
| split | str \| None | Split of the dataset to load, e.g., "train", "test" (required) | None |
| data_dir | str \| None | Optional directory where the dataset is stored/downloaded | None |
| streaming | bool | Whether to use HuggingFace streaming mode | False |
| shuffle | bool | Whether to shuffle the dataset | False |
| shuffle_buffer_size | int | Buffer size for shuffling in streaming mode | 1000 |
| download_kwargs | dict[str, Any] \| None | Optional keyword arguments for load_dataset | None |
| include_keys | set[str] \| None | Optional set of keys to include in output | None |
| exclude_keys | set[str] \| None | Optional set of keys to exclude from output | None |
precomputed_stats (class attribute, instance attribute)
HFEagerSource
HFEagerSource(config: HFEagerConfig, *, rngs: Rngs | None = None, name: str | None = None)
Bases: EagerSourceBase
Eager-loading HuggingFace source for small/medium datasets.
Loads ALL data to JAX arrays at initialization, then operates like a MemorySource with pure JAX operations. Use for datasets that fit in ~10% of device VRAM.
Key Features
- One-time conversion at init (PIL→numpy→JAX for images)
- Pure JAX iteration after init
- O(1) memory shuffling via Grain's index_shuffle (Feistel cipher)
- Full checkpointing support (indices only, no external state)
- Automatic PIL Image to JAX array conversion
Performance
- Training loops can use lax.fori_loop for 100-500x speedup
- Device placement via collect_to_array() for staged training
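Because an eager source holds the whole split as JAX arrays, a loop can slice batches inside `jax.lax.fori_loop` without returning to Python between steps. A minimal sketch of the pattern, assuming a single pre-loaded array `images` and a toy "update" that just accumulates batch means; it illustrates the structure only and makes no claim about the speedup figure above.

```python
import jax
import jax.numpy as jnp


def run_epoch(images: jax.Array, batch_size: int) -> jax.Array:
    """Iterate over fixed-size batches of a pre-loaded array inside fori_loop."""
    num_batches = images.shape[0] // batch_size  # static; ragged tail is dropped

    def body(step, acc):
        start = step * batch_size
        # dynamic_slice_in_dim takes a traced start index; the slice size is static.
        batch = jax.lax.dynamic_slice_in_dim(images, start, batch_size, axis=0)
        return acc + batch.mean()  # stand-in for a real train step

    total = jax.lax.fori_loop(0, num_batches, body, jnp.float32(0.0))
    return total / num_batches


images = jnp.arange(64, dtype=jnp.float32).reshape(16, 4)
mean_of_batch_means = jax.jit(run_epoch, static_argnums=1)(images, 4)
```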
Example
```python
import jax
import flax.nnx as nnx
from datarax.sources import HFEagerSource
from datarax.sources.hf_source import HFEagerConfig

# Create an eager source for MNIST from HuggingFace
config = HFEagerConfig(name="mnist", split="train", shuffle=True)
source = HFEagerSource(config, rngs=nnx.Rngs(0))

# Iterate: pure JAX, no HF overhead (process is your own function)
for item in source:
    process(item["image"])

# Get a batch (stateful without a key, stateless with one)
batch = source.get_batch(32)  # Stateful
batch = source.get_batch(32, key=jax.random.key(0))  # Stateless
```
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | HFEagerConfig | Configuration for the source | required |
| rngs | Rngs \| None | Optional RNG state for shuffling | None |
| name | str \| None | Optional name (defaults to HFEagerSource(dataset:split)) | None |
Raises:
| Type | Description |
|---|---|
| ImportError | If the datasets package is not installed |
get_operation_stats
reset_operation_stats
Reset operation statistics to zero.
Note: Creates new JAX arrays to reset the counters.
compute_statistics
Compute statistics from data using batch_stats_fn.
If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | Any | Input data to compute statistics from | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] \| None | Dictionary of statistics, or None if no batch_stats_fn is configured |
get_statistics
set_statistics
reset_statistics
Reset all statistics to None.
This clears both computed statistics and marks that precomputed_stats should be ignored (via internal flag). After reset, get_statistics() will return None until new statistics are set or computed.
copy
copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule
Create a copy of this module with optional config/parameter changes.
This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | DataraxModuleConfig \| None | New config (if None, uses current config) | None |
| rngs | Rngs \| None | New RNG state (if None, uses current rngs) | None |
| name | str \| None | New name (if None, uses current name) | None |
Returns:
| Type | Description |
|---|---|
| DataraxModule | New module instance with updated parameters |
Examples:
```python
# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)

# Change name only
renamed = module.copy(name="new_name")
```
Note
Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().
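The config classes on this page are dataclasses, so `dataclasses.replace` can derive a modified config from an existing one, which can then be passed to `copy(config=...)`. A self-contained sketch with a hypothetical stand-in config (`ToyConfig`); the real config classes have more fields, and whether they are frozen is not stated on this page.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyConfig:  # hypothetical stand-in for a Datarax config dataclass
    name: str
    split: str
    shuffle: bool = False
    seed: int = 42


base = ToyConfig(name="mnist", split="train")
# replace() returns a new object with the given fields changed; base is untouched.
shuffled = dataclasses.replace(base, shuffle=True, seed=7)
# With a real module, the result would be passed on: module.copy(config=shuffled)
```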
get_state
Get module state for checkpointing.
This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A dictionary containing the internal state of the component. |
set_state
Restore module state from a checkpoint.
This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | dict[str, Any] | A dictionary containing the internal state to restore. | required |
Raises:
| Type | Description |
|---|---|
| TypeError | If state is not a dictionary. |
| ValueError | If checkpoint structure does not match module state. |
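A strict restore of this kind can be expressed as a tree-structure comparison before any values are written. The sketch below, assuming state is a nested dict of arrays, checks the same two conditions the Raises table lists; `check_state_structure` is hypothetical, and the real method may compare more than pytree structure (for example shapes or dtypes).

```python
import jax
import jax.numpy as jnp


def check_state_structure(current: dict, incoming: object) -> None:
    """Raise the way a strict set_state would if a checkpoint does not fit."""
    if not isinstance(incoming, dict):
        raise TypeError(f"state must be a dict, got {type(incoming).__name__}")
    # PyTreeDefs compare equal only if keys and nesting match exactly.
    if jax.tree.structure(current) != jax.tree.structure(incoming):
        raise ValueError("checkpoint structure does not match module state")


state = {"index": jnp.array(3), "epoch": jnp.array(1)}
check_state_structure(state, {"index": jnp.array(0), "epoch": jnp.array(0)})  # passes
```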
clone
clone() -> DataraxModule
Create a new instance with the same state as this module.
Uses NNX's clone function for proper deep cloning of all state.
Returns:
| Type | Description |
|---|---|
| DataraxModule | A new module instance with the same state. |
requires_rng_streams
ensure_rng_streams
Ensure that the required RNG streams are available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stream_names | list[str] | A list of available RNG stream names. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If a required RNG stream is not available. |
process
Process input structure.
This method transforms the structure/organization of input data without modifying the data values themselves.
Subclasses MUST implement this method.
The input/output types depend on the specific structural processor:
- Batcher: list[Element] -> list[Batch]
- Sampler: int -> list[int]
- Sharder: Batch -> Sharded[Batch]
- Splitter: Dataset -> tuple[Dataset, Dataset]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input | Any | Input to process (type varies by processor) | required |
| *args | Any | Additional positional arguments (processor-specific) | () |
| **kwargs | Any | Additional keyword arguments (processor-specific) | {} |
Returns:
| Type | Description |
|---|---|
| Any | Processed output (type varies by processor) |
Examples:
Batcher implementation:
```python
def process(self, elements: list[Element]) -> list[Batch]:
    batches = []
    for i in range(0, len(elements), self.config.batch_size):
        batch_elements = elements[i:i + self.config.batch_size]
        batches.append(Batch.from_elements(batch_elements))
    return batches
```
Sampler implementation (deterministic):
```python
def process(self, dataset_size: int) -> list[int]:
    return list(range(min(self.config.num_samples, dataset_size)))
```
Sampler implementation (stochastic):
get_batch_at
Stateless indexed batch access; JIT-traceable for scan-based iteration.
Returns `size` records starting at logical position `start`. Does not advance `self.index` or any other internal state, so callers (typically Pipeline) can drive iteration with their own position counter and trace `get_batch_at` under `nnx.scan` / `nnx.jit`.
Two modes:
- Sequential (`self.is_random_order == False`): returns the contiguous slice `data[start : start + size]`, wrapping around at the end of the source.
- Shuffled (`self.is_random_order == True`): applies a deterministic permutation derived from `key` and returns the corresponding slice of that permutation. The same `(start, size, key)` always returns the same output. The permutation is materialized via `jax.random.permutation(key, length)` on every call, which costs O(length) per batch.
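The two modes can be sketched as a standalone JAX function over a dict of arrays, following the description above: wrap-around positions, and in shuffled mode a fresh `jax.random.permutation(key, length)` on every call. `get_batch_at_sketch` is a hypothetical name; passing the data explicitly and choosing the mode from whether `key` is given are simplifications (the real method reads `self.is_random_order`).

```python
import jax
import jax.numpy as jnp


def get_batch_at_sketch(data, start, size, key=None):
    """Stateless batch lookup: sequential if key is None, shuffled otherwise."""
    length = next(iter(data.values())).shape[0]  # static dataset size
    # Logical positions start, start+1, ..., wrapped at the end of the source.
    positions = (start + jnp.arange(size)) % length
    if key is not None:
        # Same key gives the same permutation, so (start, size, key) is deterministic.
        perm = jax.random.permutation(key, length)
        positions = perm[positions]
    return {k: jnp.take(v, positions, axis=0) for k, v in data.items()}


data = {"x": jnp.arange(10)}
seq = get_batch_at_sketch(data, 8, 4)
# start may be traced; size must stay a Python int because shapes are static.
jitted = jax.jit(get_batch_at_sketch, static_argnums=2)
shuf = jitted(data, jnp.int32(0), 10, jax.random.key(0))
```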
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| start | int \| Array | Starting logical index; accepts concrete int or traced | required |
| size | int | Number of records to return (Python int, since JAX shapes are static). | required |
| key | Array \| None | PRNG key for shuffled mode. Required when | None |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dict mapping each data key to a JAX array with leading dimension |
element_spec
element_spec() -> Any
Derive the per-element spec from the eager dict-of-arrays storage.
EagerSourceBase subclasses store data as a dict mapping keys to arrays whose leading axis is the dataset size. This default implementation strips that leading axis from every leaf to produce one `jax.ShapeDtypeStruct` per key.
Subclasses with non-dict storage should override this method.
Raises:
| Type | Description |
|---|---|
| ValueError | If the source is empty. |
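Stripping the leading dataset axis from every leaf is a single `jax.tree.map` call. A sketch assuming dict-of-arrays storage; `eager_element_spec` is a hypothetical helper, not the Datarax method, and treating a zero-length leading axis as "empty" is an assumption.

```python
import jax
import jax.numpy as jnp


def eager_element_spec(data: dict[str, jax.Array]):
    """Per-element ShapeDtypeStruct for each key, dropping the leading axis."""
    if not data or next(iter(data.values())).shape[0] == 0:
        raise ValueError("cannot derive an element spec from an empty source")
    return jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape[1:], a.dtype), data)


data = {
    "image": jnp.zeros((100, 28, 28), jnp.uint8),
    "label": jnp.zeros((100,), jnp.int32),
}
spec = eager_element_spec(data)
```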
get_batch
Get one eager batch in stateful or stateless mode.
HFStreamingSource
HFStreamingSource(config: HFStreamingConfig, *, rngs: Rngs | None = None, name: str | None = None)
Bases: StreamingSourceBase
Streaming HuggingFace source for large datasets.
Thin wrapper around a HuggingFace dataset for data that cannot fit in memory. Supports HuggingFace's native streaming mode for efficient large-scale data loading.
Key Features
- Native HuggingFace streaming support
- Automatic PIL Image to JAX array conversion
- include_keys/exclude_keys filtering
- Revision pinning for security (B615)
Trade-offs vs Eager
- Cannot checkpoint mid-epoch (external iterator state)
- Use with prefetch_to_device() for best results
Example
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | HFStreamingConfig | Configuration for the source | required |
| rngs | Rngs \| None | Optional RNG state | None |
| name | str \| None | Optional name (defaults to HFStreamingSource(dataset:split)) | None |
Raises:
| Type | Description |
|---|---|
| ImportError | If the datasets package is not installed |
random_order_buffer_depth (property)
random_order_buffer_depth: int
Shuffle buffer size from source config.
selected_keys (property)
Optional key-include filter from source config.
rejected_keys (property)
Optional key-exclude filter from source config.
element_spec
element_spec() -> Any
Return the per-element shape/dtype derived by peeking at the backend.
Streaming sources cannot strip a leading dataset-size dimension because each iteration yields one element. Instead, the spec is derived by peeking at the first element of the underlying HuggingFace dataset (without consuming the iterator state used for normal training) and converting each top-level value into a single ShapeDtypeStruct.
Top-level dict values are treated as single arrays (HuggingFace commonly emits Python lists for vector features; those become 1-D arrays, not nested per-element scalars).
The peek operates on the cached `self._hf_dataset` (already loaded in `__init__`), so it does not re-trigger downloads and is safe to call repeatedly.
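The peek can be sketched over any iterable of dicts: open a fresh iterator, take the first element, and convert each top-level value with NumPy so that a Python list becomes one 1-D array. `peek_element_spec` is hypothetical, and using a plain list in place of the cached HuggingFace dataset is a simplification.

```python
import jax
import numpy as np


def peek_element_spec(dataset) -> dict[str, jax.ShapeDtypeStruct]:
    """Describe one element by peeking at a fresh iterator over `dataset`."""
    first = next(iter(dataset))  # new iterator, so any training iterator is untouched
    spec = {}
    for key, value in first.items():
        arr = np.asarray(value)  # a list of numbers becomes a single 1-D array
        spec[key] = jax.ShapeDtypeStruct(arr.shape, arr.dtype)
    return spec


rows = [{"input_ids": [101, 2023, 102], "label": 1}]
spec = peek_element_spec(rows)
```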
get_operation_stats
reset_operation_stats
Reset operation statistics to zero.
Note: Creates new JAX arrays to reset the counters.
compute_statistics
Compute statistics from data using batch_stats_fn.
If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | Any | Input data to compute statistics from | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] \| None | Dictionary of statistics, or None if no batch_stats_fn is configured |
get_statistics
set_statistics
reset_statistics
Reset all statistics to None.
This clears both computed statistics and marks that precomputed_stats should be ignored (via internal flag). After reset, get_statistics() will return None until new statistics are set or computed.
copy
copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule
Create a copy of this module with optional config/parameter changes.
This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | DataraxModuleConfig \| None | New config (if None, uses current config) | None |
| rngs | Rngs \| None | New RNG state (if None, uses current rngs) | None |
| name | str \| None | New name (if None, uses current name) | None |
Returns:
| Type | Description |
|---|---|
| DataraxModule | New module instance with updated parameters |
Examples:
```python
# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)

# Change name only
renamed = module.copy(name="new_name")
```
Note
Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().
get_state
Get module state for checkpointing.
This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A dictionary containing the internal state of the component. |
set_state
Restore module state from a checkpoint.
This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | dict[str, Any] | A dictionary containing the internal state to restore. | required |
Raises:
| Type | Description |
|---|---|
| TypeError | If state is not a dictionary. |
| ValueError | If checkpoint structure does not match module state. |
clone
clone() -> DataraxModule
Create a new instance with the same state as this module.
Uses NNX's clone function for proper deep cloning of all state.
Returns:
| Type | Description |
|---|---|
| DataraxModule | A new module instance with the same state. |
requires_rng_streams
ensure_rng_streams
Ensure that the required RNG streams are available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
stream_names
|
list[str]
|
A list of available RNG stream names. |
required |
Raises:
| Type | Description |
|---|---|
ValueError
|
If a required RNG stream is not available. |
process
Process input structure.
This method transforms the structure/organization of input data without modifying the data values themselves.
Subclasses MUST implement this method.
The input/output types depend on the specific structural processor:
- Batcher: list[Element] -> list[Batch]
- Sampler: int -> list[int]
- Sharder: Batch -> Sharded[Batch]
- Splitter: Dataset -> tuple[Dataset, Dataset]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input | Any | Input to process (type varies by processor) | required |
| *args | Any | Additional positional arguments (processor-specific) | () |
| **kwargs | Any | Additional keyword arguments (processor-specific) | {} |
Returns:
| Type | Description |
|---|---|
| Any | Processed output (type varies by processor) |
Examples:
Batcher implementation:
```python
def process(self, elements: list[Element]) -> list[Batch]:
    batches = []
    for i in range(0, len(elements), self.config.batch_size):
        batch_elements = elements[i:i + self.config.batch_size]
        batches.append(Batch.from_elements(batch_elements))
    return batches
```
Sampler implementation (deterministic):
```python
def process(self, dataset_size: int) -> list[int]:
    return list(range(min(self.config.num_samples, dataset_size)))
```
Sampler implementation (stochastic):
get_batch_at
Stateless indexed batch access for Pipeline-driven iteration.
Returns `size` records starting at `start`. Implementations must be stateless (no mutation of internal counters) and JAX-traceable (they must accept tracer values for `start`) so the call composes with `nnx.scan`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| start | int \| Any | Starting index. Sources that support indexed access accept a Python int or a traced | required |
| size | int | Number of records to return (Python int, since JAX shapes are static). | required |
| key | Any \| None | Optional PRNG key for shuffled or stochastic sampling. | None |
Returns:
| Type | Description |
|---|---|
| Any | A batch dict (or PyTree) with leading dim |
Raises:
| Type | Description |
|---|---|
| NotImplementedError | If the source does not support indexed access (e.g. forward-only streams). Pipeline falls back to its |
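A caller-side fallback of the kind the Raises entry mentions can be sketched as: probe `get_batch_at` once and, if it raises `NotImplementedError`, build batches from the source's iterator instead. Everything here (`iterate_batches`, the two toy sources, list-based batches) is hypothetical; the actual Pipeline fallback is not described in this section.

```python
import itertools
from collections.abc import Iterator
from typing import Any


class ForwardOnlySource:
    """Toy forward-only stream: iterable, but no indexed access."""

    def __init__(self, n: int):
        self.n = n

    def __iter__(self) -> Iterator[dict[str, Any]]:
        return ({"x": i} for i in range(self.n))

    def get_batch_at(self, start: int, size: int, key: Any = None) -> Any:
        raise NotImplementedError("forward-only stream")


class IndexedSource:
    """Toy indexed source: returns positions start..start+size-1."""

    def get_batch_at(self, start: int, size: int, key: Any = None) -> list[int]:
        return list(range(start, start + size))


def iterate_batches(source: Any, size: int, num_batches: int) -> Iterator[Any]:
    """Prefer stateless indexed access; fall back to plain iteration."""
    if num_batches <= 0:
        return
    try:
        first = source.get_batch_at(0, size)
    except NotImplementedError:
        # Forward-only: consume the stream in order, in chunks of `size`.
        it = iter(source)
        for _ in range(num_batches):
            chunk = list(itertools.islice(it, size))
            if not chunk:
                return
            yield chunk
        return
    # Indexed: the caller owns the position counter.
    yield first
    for step in range(1, num_batches):
        yield source.get_batch_at(step * size, size)


streamed = list(iterate_batches(ForwardOnlySource(5), size=2, num_batches=10))
indexed = list(iterate_batches(IndexedSource(), size=3, num_batches=2))
```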