Memory Source
In-memory data source for testing and small datasets.
See Also
- Sources Overview - All data sources
- Data Sources Guide - In-depth guide
- Simple Pipeline Example
- HF Source - For larger datasets
datarax.sources.memory_source
In-memory data source implementation for Datarax.
This module provides a data source that serves data from in-memory collections with support for both stateless and stateful operation modes.
MemorySourceConfig (dataclass)
MemorySourceConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, stochastic: bool = False, stream_name: str | None = None, shuffle: bool = False, cache_size: int = 0, prefetch_size: int = 0, track_metadata: bool = False, shard_id: int | None = None, num_workers: int = 1)
Bases: StructuralConfig
Configuration for MemorySource (in-memory data source).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| shuffle | bool | Whether to shuffle data on each epoch | False |
| cache_size | int | Number of batches to cache (0 = no caching) | 0 |
| prefetch_size | int | Number of items to prefetch (0 = no prefetching) | 0 |
| track_metadata | bool | Whether to track metadata for each record | False |
| shard_id | int \| None | Optional shard identifier for distributed processing | None |
| num_workers | int | Number of parallel workers. When > 1, each worker (identified by shard_id) receives a disjoint partition of the globally shuffled elements: worker k gets the elements at global positions [k::num_workers]. | 1 |
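The [k::num_workers] rule can be made concrete with a minimal pure-Python sketch. The helper worker_partition is hypothetical and is not part of Datarax. It assumes every worker derives the same global permutation from a shared seed. It also uses random.Random only as a stand-in for whatever RNG MemorySource actually uses.

```python
import random


def worker_partition(num_elements: int, num_workers: int, shard_id: int, seed: int) -> list[int]:
    """Hypothetical helper: global indices read by worker `shard_id` in one epoch.

    Assumes all workers build the same global permutation from a shared seed
    (random.Random stands in for the library's RNG) and that worker k keeps
    the permuted positions [k::num_workers].
    """
    if not 0 <= shard_id < num_workers:
        raise ValueError("shard_id must be in [0, num_workers)")
    order = list(range(num_elements))
    random.Random(seed).shuffle(order)  # identical order on every worker
    return order[shard_id::num_workers]


# Usage: three workers over ten elements.
parts = [worker_partition(10, 3, k, seed=0) for k in range(3)]
print([len(p) for p in parts])  # [4, 3, 3]
```

Every worker slices the same permutation at a different offset. The partitions are therefore disjoint and together cover the whole dataset.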
precomputed_stats (class attribute, instance attribute)
MemorySource
MemorySource(config: MemorySourceConfig, data: dict[str, Any] | list[Any] | Sequence[Any], *, rngs: Rngs | None = None, name: str | None = None)
Bases: DataSourceModule
In-memory data source for Datarax.
This data source serves data from in-memory collections and supports both stateless and stateful operation modes.
Key Features:
- Dual-mode operation (stateless iteration and stateful with internal index)
- Random access via __getitem__
- Optional shuffling with RNG support
- Batch retrieval with get_batch method
- Support for dictionary and list/sequence data
- Batch-first design for efficient processing
Examples:
Create source with list data:

```python
from flax import nnx
from datarax.sources.memory_source import MemorySource, MemorySourceConfig

data = [{'x': i, 'y': i * 2} for i in range(100)]
config = MemorySourceConfig(shuffle=False)
source = MemorySource(config, data, rngs=nnx.Rngs(0))

# Stateless iteration
for item in source:
    print(item)  # stand-in for your own per-item processing

# Stateful iteration with internal index
batch = source.get_batch(32)  # Gets next 32 items

# With shuffling
config = MemorySourceConfig(shuffle=True)
source = MemorySource(config, data, rngs=nnx.Rngs(0))
```
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | MemorySourceConfig | Configuration for the MemorySource | required |
| data | dict[str, Any] \| list[Any] \| Sequence[Any] | Either a dictionary mapping keys to data arrays or a list/sequence of elements. If a dictionary is provided, all values must have the same first-dimension size. | required |
| rngs | Rngs \| None | Optional RNG state for shuffling and stateful iteration | None |
| name | str \| None | Optional name for the module (defaults to "MemorySource") | None |
Raises:
| Type | Description |
|---|---|
| ValueError | If dictionary data has inconsistent lengths |
| TypeError | If data is a string |
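The data rules and errors above can be mirrored in a small validation sketch. validate_data is a hypothetical function, not MemorySource's actual code. The final TypeError for types that are neither dict nor Sequence is an assumption. The explicit string check is presumably needed because a str is itself a Sequence.

```python
from collections.abc import Sequence
from typing import Any


def validate_data(data: Any) -> int:
    """Hypothetical check mirroring the documented rules; returns the dataset length."""
    if isinstance(data, str):
        # A str is itself a Sequence, so it has to be rejected explicitly.
        raise TypeError("data must be a dict or a sequence of elements, not a string")
    if isinstance(data, dict):
        # len() of a list or array is the size of its first dimension.
        lengths = {key: len(value) for key, value in data.items()}
        if len(set(lengths.values())) > 1:
            raise ValueError(f"inconsistent first-dimension sizes: {lengths}")
        return next(iter(lengths.values()), 0)
    if isinstance(data, Sequence):
        return len(data)
    # Assumption: anything that is neither a dict nor a Sequence is rejected.
    raise TypeError(f"unsupported data type: {type(data).__name__}")
```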
metadata_manager (instance attribute)
metadata_manager = MetadataManager(rngs=rngs, track_batches=True, shard_id=shard_id)
has_metadata (property)
has_metadata: bool
Check if this source is tracking metadata.
Returns:
| Type | Description |
|---|---|
| bool | True if metadata tracking is enabled, False otherwise |
get_batch
Get next batch of data.
This method supports both stateless (with explicit key) and stateful (with internal index tracking) operation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch_size | int | Number of elements in the batch | required |
| key | Array \| None | Optional RNG key for shuffling (stateless mode) | None |
Returns:
| Type | Description |
|---|---|
| Any | Batch of data with shape (batch_size, ...) |
get_batch_at
Stateless indexed batch access; JIT-traceable for scan-based iteration.
Returns size records starting at logical position start.
Does not advance self.index or any other internal state, so callers can drive iteration via their own nnx.Variable position counter and trace get_batch_at under nnx.scan / nnx.jit.
Two modes:
- Sequential (MemorySourceConfig(shuffle=False)): returns the contiguous slice data[start : start + size], with wrap-around at the end of the source.
- Shuffled (MemorySourceConfig(shuffle=True)): applies a deterministic permutation derived from key and returns the corresponding slice of that permutation. The same (start, size, key) always returns the same output; a different key yields a different permutation. The permutation is materialized via jax.random.permutation(key, length) on every call, which costs O(length) per batch. Future optimization: switch to a Feistel-network PRP for O(1) per-element shuffled lookup at large dataset sizes.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| start | int \| Array | Starting logical index (inclusive); accepts a concrete int or a traced array | required |
| size | int | Number of records to return (must be a Python int, because JAX shapes are static) | required |
| key | Array \| None | PRNG key for shuffled mode; required when the source is configured for shuffling | None |
Returns:
| Type | Description |
|---|---|
| Any | Batch dict with leading dimension size |
reset
reset(seed: int | None = None) -> None
Reset the source to the beginning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| seed | int \| None | Optional seed for reproducibility (ignored for MemorySource) | None |
set_random_order
set_random_order(enabled: bool) -> None
Enable or disable random-order iteration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| enabled | bool | Whether to randomize iteration order. | required |
get_with_metadata
get_with_metadata(index: int) -> tuple[Any, RecordMetadata]
Get the element at a specific index together with its metadata.
This method is only available when track_metadata=True was set during initialization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| index | int | Index of element to retrieve | required |
Returns:
| Type | Description |
|---|---|
| tuple[Any, RecordMetadata] | Tuple of (data_element, metadata) |
Raises:
| Type | Description |
|---|---|
| RuntimeError | If metadata tracking is not enabled |
| IndexError | If index is out of bounds |
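The two documented failure modes can be sketched as guard clauses. Both ToyRecordMetadata and the free function below are hypothetical. The real RecordMetadata has its own fields. The check order (RuntimeError before IndexError) and the rejection of negative indices are assumptions.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ToyRecordMetadata:
    """Hypothetical stand-in for RecordMetadata; the real class has its own fields."""
    index: int


def get_with_metadata(data: list[Any], index: int, track_metadata: bool) -> tuple[Any, ToyRecordMetadata]:
    """Hypothetical free-function version of the documented guard clauses."""
    if not track_metadata:
        # Documented: RuntimeError when metadata tracking is not enabled.
        raise RuntimeError("metadata tracking is not enabled; set track_metadata=True")
    if not 0 <= index < len(data):
        # Documented: IndexError when out of bounds (negatives rejected here by assumption).
        raise IndexError(f"index {index} out of bounds for length {len(data)}")
    return data[index], ToyRecordMetadata(index=index)
```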
get_batch_with_metadata
get_batch_with_metadata(batch_size: int, key: Array | None = None) -> tuple[Any, list[RecordMetadata]]
Get next batch of data with metadata for each element.
This method is only available when track_metadata=True was set during initialization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch_size | int | Number of elements in the batch | required |
| key | Array \| None | Optional RNG key for shuffling (stateless mode) | None |
Returns:
| Type | Description |
|---|---|
| tuple[Any, list[RecordMetadata]] | Tuple of (batch_data, list_of_metadata) |
Raises:
| Type | Description |
|---|---|
| RuntimeError | If metadata tracking is not enabled |
element_spec
element_spec() -> Any
Return per-element shape/dtype derived from the in-memory data.
Dict-mode sources strip the leading dataset-size dimension from every array to produce one jax.ShapeDtypeStruct per key. List-mode sources introspect element 0 and apply jax.tree.map to produce a matching PyTree of ShapeDtypeStruct leaves.
Returns:
| Type | Description |
|---|---|
| Any | One jax.ShapeDtypeStruct per key (dict mode) or a PyTree of ShapeDtypeStruct leaves (list mode) |
Raises:
| Type | Description |
|---|---|
| ValueError | If the source is empty (no element to introspect). |
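The dict-mode rule "strip the leading dataset-size dimension" can be sketched without JAX. ShapeDtype and dict_element_spec are hypothetical. ShapeDtype holds only a shape tuple and a dtype name, standing in for jax.ShapeDtypeStruct.

```python
from typing import NamedTuple


class ShapeDtype(NamedTuple):
    """Hypothetical stand-in for jax.ShapeDtypeStruct: just a shape and a dtype name."""
    shape: tuple[int, ...]
    dtype: str


def dict_element_spec(columns: dict[str, ShapeDtype]) -> dict[str, ShapeDtype]:
    """Per-element spec for dict-mode data: drop the leading dataset-size dimension."""
    if not columns or any(col.shape[0] == 0 for col in columns.values()):
        raise ValueError("source is empty; there is no element to introspect")
    return {name: ShapeDtype(col.shape[1:], col.dtype) for name, col in columns.items()}


spec = dict_element_spec({
    'image': ShapeDtype((100, 28, 28), 'float32'),
    'label': ShapeDtype((100,), 'int32'),
})
print(spec['image'].shape, spec['label'].shape)  # (28, 28) ()
```

With JAX available, the per-array equivalent is jax.ShapeDtypeStruct(x.shape[1:], x.dtype).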
get_operation_stats
reset_operation_stats
Reset operation statistics to zero.
Note: Creates new JAX arrays to reset the counters.
compute_statistics
Compute statistics from data using batch_stats_fn.
If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | Any | Input data to compute statistics from | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] \| None | Dictionary of statistics, or None if no batch_stats_fn is configured |
get_statistics
set_statistics
reset_statistics
Reset all statistics to None.
This clears both computed statistics and marks that precomputed_stats should be ignored (via internal flag). After reset, get_statistics() will return None until new statistics are set or computed.
copy
copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule
Create a copy of this module with optional config/parameter changes.
This creates a new module instance with a modified configuration while preserving the other attributes, which is useful for hyperparameter tuning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | DataraxModuleConfig \| None | New config (if None, uses the current config) | None |
| rngs | Rngs \| None | New RNG state (if None, uses the current rngs) | None |
| name | str \| None | New name (if None, uses the current name) | None |
Returns:
| Type | Description |
|---|---|
| DataraxModule | New module instance with updated parameters |
Examples:

```python
# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)

# Change name only
renamed = module.copy(name="new_name")
```

Note: Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclasses.replace().
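The "dataclass replace()" mentioned in the note is presumably the standard-library dataclasses.replace. It builds a new config that differs from an existing one in a single field. The sketch uses a hypothetical frozen ToyConfig in place of MemorySourceConfig; whether the real config class is frozen is not stated here.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyConfig:
    """Hypothetical config standing in for MemorySourceConfig."""
    shuffle: bool = False
    cache_size: int = 0


base = ToyConfig()
# Change one field, keep every other field as it was; `base` is left untouched.
tuned = dataclasses.replace(base, shuffle=True)
print(tuned)  # ToyConfig(shuffle=True, cache_size=0)
```

With a real module, the result would then be passed as module.copy(config=...). This assumes the current config is reachable as module.config.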
get_state
Get module state for checkpointing.
This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A dictionary containing the internal state of the component. |
set_state
Restore module state from a checkpoint.
This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | dict[str, Any] | A dictionary containing the internal state to restore. | required |
Raises:
| Type | Description |
|---|---|
| TypeError | If state is not a dictionary. |
| ValueError | If checkpoint structure does not match module state. |
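A strict restore implies a structural comparison before any values are written. The sketch below is a hypothetical check and is not NNX's or Datarax's actual code. It treats state as nested dicts and compares only key sets, accepting any leaf value.

```python
from typing import Any


def check_state_structure(expected: dict[str, Any], state: Any) -> None:
    """Hypothetical strict check: raise unless `state` has exactly the expected nested keys."""
    if not isinstance(state, dict):
        raise TypeError(f"state must be a dict, got {type(state).__name__}")
    _compare(expected, state, "")


def _compare(expected: Any, given: Any, path: str) -> None:
    if isinstance(expected, dict):
        if not isinstance(given, dict) or set(given) != set(expected):
            raise ValueError(f"checkpoint structure mismatch at '{path or '<root>'}'")
        for key in expected:
            _compare(expected[key], given[key], f"{path}/{key}")
    # Leaves are not compared here; only the structure is checked.
```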
clone
clone() -> DataraxModule
Create a new instance with the same state as this module.
Uses NNX's clone function for proper deep cloning of all state.
Returns:
| Type | Description |
|---|---|
| DataraxModule | A new module instance with the same state. |
requires_rng_streams
ensure_rng_streams
Ensure that the required RNG streams are available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stream_names | list[str] | A list of available RNG stream names. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If a required RNG stream is not available. |
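The check reduces to a subset test between required and available stream names. ensure_streams_available is a hypothetical standalone function. It assumes the required names are supplied explicitly; in the module they presumably come from requires_rng_streams. The "shuffle" stream name in the usage line is also an assumption.

```python
def ensure_streams_available(required: list[str], stream_names: list[str]) -> None:
    """Hypothetical standalone version of the check ensure_rng_streams documents."""
    missing = [name for name in required if name not in stream_names]
    if missing:
        raise ValueError(f"missing required RNG streams {missing}; available: {stream_names}")


# Usage: a source that needs a "shuffle" stream (stream name is an assumption).
ensure_streams_available(["shuffle"], ["default", "shuffle"])  # passes silently
```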
process
Process input structure.
This method transforms the structure/organization of input data without modifying the data values themselves.
Subclasses MUST implement this method.
The input/output types depend on the specific structural processor:
- Batcher: list[Element] -> list[Batch]
- Sampler: int -> list[int]
- Sharder: Batch -> Sharded[Batch]
- Splitter: Dataset -> tuple[Dataset, Dataset]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input | Any | Input to process (type varies by processor) | required |
| *args | Any | Additional positional arguments (processor-specific) | () |
| **kwargs | Any | Additional keyword arguments (processor-specific) | {} |
Returns:
| Type | Description |
|---|---|
| Any | Processed output (type varies by processor) |
Examples:
Batcher implementation:

```python
def process(self, elements: list[Element]) -> list[Batch]:
    batches = []
    # Step through the elements in chunks of batch_size; the last chunk may be shorter.
    for i in range(0, len(elements), self.config.batch_size):
        batch_elements = elements[i:i + self.config.batch_size]
        batches.append(Batch.from_elements(batch_elements))
    return batches
```

Sampler implementation (deterministic):

```python
def process(self, dataset_size: int) -> list[int]:
    # Take the first num_samples indices, capped at the dataset size.
    return list(range(min(self.config.num_samples, dataset_size)))
```
Sampler implementation (stochastic):