
Memory Source¤

In-memory data source for testing and small datasets.

datarax.sources.memory_source ¤

In-memory data source implementation for Datarax.

This module provides a data source that serves data from in-memory collections with support for both stateless and stateful operation modes.

logger module-attribute ¤

logger = getLogger(__name__)

MemorySourceConfig dataclass ¤

MemorySourceConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, stochastic: bool = False, stream_name: str | None = None, shuffle: bool = False, cache_size: int = 0, prefetch_size: int = 0, track_metadata: bool = False, shard_id: int | None = None, num_workers: int = 1)

Bases: StructuralConfig

Configuration for MemorySource (in-memory data source).

Parameters:

Name Type Description Default
shuffle bool

Whether to shuffle data on each epoch

False
cache_size int

Number of batches to cache (0 = no caching)

0
prefetch_size int

Number of items to prefetch (0 = no prefetching)

0
track_metadata bool

Whether to track metadata for each record

False
shard_id int | None

Optional shard identifier for distributed processing

None
num_workers int

Number of parallel workers (default 1). When > 1, each worker (identified by shard_id) receives a disjoint partition of the globally-shuffled elements. Worker k gets elements at global positions [k::num_workers].

1
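The strided partition described for num_workers can be illustrated with plain list slicing. This is a minimal sketch, not Datarax's implementation: the helper name partition_for_worker, the seed argument, and the use of Python's random module for the global shuffle are all assumptions made for illustration.

```python
import random


def partition_for_worker(num_elements: int, shard_id: int, num_workers: int, seed: int = 0) -> list[int]:
    """Return the global element indices assigned to one worker (hypothetical helper)."""
    if not 0 <= shard_id < num_workers:
        raise ValueError("shard_id must be in [0, num_workers)")
    # Every worker rebuilds the same global order from the same seed ...
    order = list(range(num_elements))
    random.Random(seed).shuffle(order)
    # ... and keeps every num_workers-th position, starting at its own shard_id.
    return order[shard_id::num_workers]


# Three workers split ten elements into disjoint parts that together cover everything.
parts = [partition_for_worker(10, k, num_workers=3) for k in range(3)]
```

Because all workers derive the same order before slicing, no coordination is needed to keep the partitions disjoint; they only have to agree on the seed.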

shuffle class-attribute instance-attribute ¤

shuffle: bool = False

cache_size class-attribute instance-attribute ¤

cache_size: int = 0

prefetch_size class-attribute instance-attribute ¤

prefetch_size: int = 0

track_metadata class-attribute instance-attribute ¤

track_metadata: bool = False

shard_id class-attribute instance-attribute ¤

shard_id: int | None = None

num_workers class-attribute instance-attribute ¤

num_workers: int = 1

cacheable class-attribute instance-attribute ¤

cacheable: bool = False

batch_stats_fn class-attribute instance-attribute ¤

batch_stats_fn: Callable | Module | None = None

precomputed_stats class-attribute instance-attribute ¤

precomputed_stats: dict[str, Any] | None = None

stochastic class-attribute instance-attribute ¤

stochastic: bool = False

stream_name class-attribute instance-attribute ¤

stream_name: str | None = None

MemorySource ¤

MemorySource(config: MemorySourceConfig, data: dict[str, Any] | list[Any] | Sequence[Any], *, rngs: Rngs | None = None, name: str | None = None)

Bases: DataSourceModule

In-memory data source for Datarax.

This data source serves data from in-memory collections and supports both stateless and stateful operation modes.

Key Features:

- Dual-mode operation (stateless iteration, or stateful iteration with an internal index)
- Random access via __getitem__
- Optional shuffling with RNG support
- Batch retrieval with get_batch method
- Support for dictionary and list/sequence data
- Batch-first design for efficient processing

Examples:

Create source with list data:

from flax import nnx
from datarax.sources.memory_source import MemorySource, MemorySourceConfig

# Create source with list data
data = [{'x': i, 'y': i * 2} for i in range(100)]
config = MemorySourceConfig(shuffle=False)
source = MemorySource(config, data, rngs=nnx.Rngs(0))

# Stateless iteration
for item in source:
    process(item)  # process() stands for user code

# Stateful iteration with internal index
batch = source.get_batch(32)  # Gets next 32 items

# With shuffling
config = MemorySourceConfig(shuffle=True)
source = MemorySource(config, data, rngs=nnx.Rngs(0))

Parameters:

Name Type Description Default
config MemorySourceConfig

Configuration for the MemorySource

required
data dict[str, Any] | list[Any] | Sequence[Any]

Either a dictionary mapping keys to data arrays or a list/sequence of elements. If a dictionary is provided, all values must have the same first dimension size.

required
rngs Rngs | None

Optional RNG state for shuffling and stateful iteration

None
name str | None

Optional name for the module (defaults to "MemorySource")

None

Raises:

Type Description
ValueError

If dictionary data has inconsistent lengths

TypeError

If data is a string
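The two error conditions above can be sketched as a small validation step. This is an assumption-laden illustration rather than the actual constructor logic: the helper name infer_length is hypothetical, and rejecting other unsupported types with TypeError is a choice made here, not something the documentation states.

```python
from collections.abc import Sequence
from typing import Any


def infer_length(data: Any) -> int:
    """Validate in-memory data and return its number of records (hypothetical helper)."""
    if isinstance(data, str):
        # A string is technically a sequence, but passing one is almost certainly a mistake.
        raise TypeError("data must be a dict or a sequence of elements, not a string")
    if isinstance(data, dict):
        lengths = {key: len(value) for key, value in data.items()}
        if len(set(lengths.values())) > 1:
            raise ValueError(f"inconsistent first-dimension sizes: {lengths}")
        # An empty dict is treated as an empty source in this sketch.
        return next(iter(lengths.values()), 0)
    if isinstance(data, Sequence):
        return len(data)
    # Assumption: anything else is rejected outright.
    raise TypeError(f"unsupported data type: {type(data).__name__}")
```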

config instance-attribute ¤

data instance-attribute ¤

data: dict[str, Any] | list[Any] | Sequence[Any] = data(data)

prefetch_size instance-attribute ¤

prefetch_size = prefetch_size

length instance-attribute ¤

length = lengths[0]

index instance-attribute ¤

index = Variable(0)

epoch instance-attribute ¤

epoch = Variable(0)

metadata_manager instance-attribute ¤

metadata_manager = MetadataManager(rngs=rngs, track_batches=True, shard_id=shard_id)

is_random_order property ¤

is_random_order: bool

Whether this source randomizes iteration order.

has_metadata property ¤

has_metadata: bool

Check if this source is tracking metadata.

Returns:

Type Description
bool

True if metadata tracking is enabled, False otherwise

rngs instance-attribute ¤

rngs = rngs

name instance-attribute ¤

name = static(name)

stochastic instance-attribute ¤

stochastic = stochastic

stream_name instance-attribute ¤

stream_name = stream_name

get_batch ¤

get_batch(batch_size: int, key: Array | None = None) -> Any

Get next batch of data.

This method supports both stateless (with explicit key) and stateful (with internal index tracking) operation.

Parameters:

Name Type Description Default
batch_size int

Number of elements in the batch

required
key Array | None

Optional RNG key for shuffling (stateless mode)

None

Returns:

Type Description
Any

Batch of data with shape (batch_size, ...)
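Stateful operation can be pictured as a cursor that remembers where the next batch starts. The sketch below is a hypothetical model only: BatchCursor is not a Datarax class, and the wrap-around policy and epoch bookkeeping are assumptions, since the text above does not specify what happens when a batch crosses the end of the data.

```python
class BatchCursor:
    """Hypothetical stand-in for a source's internal index and epoch counters."""

    def __init__(self, length: int) -> None:
        if length <= 0:
            raise ValueError("length must be positive")
        self.length = length
        self.index = 0  # position of the next record to serve
        self.epoch = 0  # number of completed passes over the data

    def next_indices(self, batch_size: int) -> list[int]:
        # Assumed policy: wrap around at the end instead of returning a short batch.
        indices = [(self.index + i) % self.length for i in range(batch_size)]
        end = self.index + batch_size
        self.epoch += end // self.length
        self.index = end % self.length
        return indices


cursor = BatchCursor(length=5)
first = cursor.next_indices(3)   # starts at position 0
second = cursor.next_indices(3)  # continues where the first call stopped
```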

get_batch_at ¤

get_batch_at(start: int | Array, size: int, key: Array | None = None) -> Any

Stateless indexed batch access; JIT-traceable for scan-based iteration.

Returns size records starting at logical position start. Does not advance self.index or any other internal state, so callers can drive iteration via their own nnx.Variable position counter and trace get_batch_at under nnx.scan / nnx.jit.

Two modes:

  • Sequential (MemorySourceConfig(shuffle=False)): returns the contiguous slice data[start : start + size] with wrap-around at the end of the source.
  • Shuffled (MemorySourceConfig(shuffle=True)): applies a deterministic permutation derived from key and returns the slice of that permutation. The same (start, size, key) always returns the same output; a different key yields a different permutation. The permutation is materialized via jax.random.permutation(key, length) on every call, which costs O(length) per batch. A possible future optimization is a Feistel-network PRP, which would give O(1) per-element shuffled lookup at large dataset sizes.
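The two modes can be sketched over a plain Python list. This is an illustrative stand-in, not the Datarax method: the function batch_at and its seed argument are hypothetical, Python's random module replaces jax.random.permutation, and applying wrap-around in shuffled mode is an assumption.

```python
import random
from typing import Optional


def batch_at(data: list, start: int, size: int, shuffle: bool = False, seed: Optional[int] = None) -> list:
    """Stateless batch lookup over a list; nothing is stored between calls."""
    length = len(data)
    # Logical positions, wrapping around at the end of the source.
    positions = [(start + i) % length for i in range(size)]
    if not shuffle:
        return [data[p] for p in positions]
    if seed is None:
        raise ValueError("a seed is required in shuffled mode")
    # The full permutation is rebuilt on every call: O(length) work per batch.
    permutation = list(range(length))
    random.Random(seed).shuffle(permutation)
    return [data[permutation[p]] for p in positions]


records = list(range(10))
tail = batch_at(records, start=8, size=4)  # crosses the end of the list
```

Since the function keeps no state, the caller owns the position counter, which is what makes this style of access suitable for driving from an external loop.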

Parameters:

Name Type Description Default
start int | Array

Starting logical index (inclusive); accepts concrete int or traced jax.Array.

required
size int

Number of records to return (must be a Python int — JAX shapes are static).

required
key Array | None

PRNG key for shuffled mode. Required when the source is configured with shuffle=True; ignored otherwise.

None

Returns:

Type Description
Any

Batch dict with leading dim size.
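For readers curious about the Feistel-network PRP mentioned as a future optimization for shuffled mode, the following is a rough pure-Python sketch of the idea. It is not part of Datarax: feistel_index, its integer key, and the mixing constants are assumptions chosen for illustration, and the round function makes no claim to statistical quality. A balanced Feistel network permutes a power-of-two domain, and cycle-walking restricts it to [0, length).

```python
def feistel_index(i: int, length: int, key: int, rounds: int = 4) -> int:
    """Map position i in [0, length) to a pseudo-random position in the same range."""
    if not 0 <= i < length:
        raise IndexError("position out of range")
    # Work on 2 * half_bits bits, the smallest even-width domain covering `length`.
    half_bits = max(1, ((length - 1).bit_length() + 1) // 2)
    mask = (1 << half_bits) - 1

    def round_fn(value: int, round_no: int) -> int:
        # Any deterministic function keeps the network invertible; this mix is arbitrary.
        mixed = (value + key + round_no * 0x9E3779B9) * 0x85EBCA6B
        return (mixed ^ (mixed >> 13)) & mask

    x = i
    while True:
        left, right = x >> half_bits, x & mask
        for r in range(rounds):
            left, right = right, left ^ round_fn(right, r)
        x = (left << half_bits) | right
        # Cycle-walk: apply the permutation again until the result is inside [0, length).
        if x < length:
            return x


shuffled_positions = [feistel_index(i, 100, key=7) for i in range(100)]
```

Each lookup touches only one index, so a batch of size records costs work proportional to size rather than to the dataset length, at the price of a few extra iterations when cycle-walking.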

reset ¤

reset(seed: int | None = None) -> None

Reset the source to the beginning.

Parameters:

Name Type Description Default
seed int | None

Optional seed for reproducibility (ignored for MemorySource)

None

set_random_order ¤

set_random_order(enabled: bool) -> None

Enable or disable random-order iteration.

Parameters:

Name Type Description Default
enabled bool

Whether to randomize iteration order.

required

get_with_metadata ¤

get_with_metadata(index: int) -> tuple[Any, RecordMetadata]

Get element at specific index with its metadata.

This method is only available when track_metadata=True was set during initialization.

Parameters:

Name Type Description Default
index int

Index of element to retrieve

required

Returns:

Type Description
tuple[Any, RecordMetadata]

Tuple of (data_element, metadata)

Raises:

Type Description
RuntimeError

If metadata tracking is not enabled

IndexError

If index is out of bounds

get_batch_with_metadata ¤

get_batch_with_metadata(batch_size: int, key: Array | None = None) -> tuple[Any, list[RecordMetadata]]

Get next batch of data with metadata for each element.

This method is only available when track_metadata=True was set during initialization.

Parameters:

Name Type Description Default
batch_size int

Number of elements in the batch

required
key Array | None

Optional RNG key for shuffling (stateless mode)

None

Returns:

Type Description
tuple[Any, list[RecordMetadata]]

Tuple of (batch_data, list_of_metadata)

Raises:

Type Description
RuntimeError

If metadata tracking is not enabled

element_spec ¤

element_spec() -> Any

Return per-element shape/dtype derived from the in-memory data.

Dict-mode sources strip the leading dataset-size dimension from every array to produce one jax.ShapeDtypeStruct per key. List-mode sources introspect element 0 and apply jax.tree.map to produce a matching PyTree of ShapeDtypeStruct leaves.

Returns:

Type Description
Any

jax.ShapeDtypeStruct PyTree describing one emitted element.

Raises:

Type Description
ValueError

If the source is empty (no element to introspect).

get_operation_stats ¤

get_operation_stats() -> dict[str, int]

Get operation statistics.

Note: This method converts JAX arrays to Python ints for introspection. It is intended for use outside of JIT-compiled functions.

Returns:

Type Description
dict[str, int]

Dictionary with 'applied_count' and 'skipped_count'

reset_operation_stats ¤

reset_operation_stats() -> None

Reset operation statistics to zero.

Note: Creates new JAX arrays to reset the counters.

compute_statistics ¤

compute_statistics(data: Any) -> dict[str, Any] | None

Compute statistics from data using batch_stats_fn.

If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.

Parameters:

Name Type Description Default
data Any

Input data to compute statistics from

required

Returns:

Type Description
dict[str, Any] | None

Dictionary of statistics, or None if no batch_stats_fn configured

get_statistics ¤

get_statistics() -> dict[str, Any] | None

Get current statistics.

Returns precomputed_stats if configured (unless reset was called), otherwise returns cached computed statistics, or None if no statistics available.

Returns:

Type Description
dict[str, Any] | None

Dictionary of statistics, or None if no statistics available

set_statistics ¤

set_statistics(stats: dict[str, Any]) -> None

Manually set statistics.

This overwrites any previously computed statistics and clears the reset flag.

Parameters:

Name Type Description Default
stats dict[str, Any]

Dictionary of statistics to set

required

reset_statistics ¤

reset_statistics() -> None

Reset all statistics to None.

This clears the computed statistics and marks precomputed_stats as ignored (via an internal flag). After reset, get_statistics() will return None until new statistics are set or computed.
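The lookup order across compute_statistics, get_statistics, set_statistics, and reset_statistics can be modelled with a small class. StatsHolder and its _was_reset attribute are hypothetical names, and the sketch follows the descriptions above literally; how set_statistics interacts with a configured precomputed_stats is not spelled out there, so that path is an assumption.

```python
from typing import Any, Callable, Optional

Stats = dict[str, Any]


class StatsHolder:
    """Hypothetical model of the statistics lookup order, not the Datarax class."""

    def __init__(
        self,
        batch_stats_fn: Optional[Callable[[Any], Stats]] = None,
        precomputed_stats: Optional[Stats] = None,
    ) -> None:
        self.batch_stats_fn = batch_stats_fn
        self.precomputed_stats = precomputed_stats
        self._computed_stats: Optional[Stats] = None
        self._was_reset = False

    def compute_statistics(self, data: Any) -> Optional[Stats]:
        if self.batch_stats_fn is None:
            return None
        self._computed_stats = self.batch_stats_fn(data)  # cache the result
        return self._computed_stats

    def get_statistics(self) -> Optional[Stats]:
        # Precomputed statistics win unless reset_statistics() was called.
        if self.precomputed_stats is not None and not self._was_reset:
            return self.precomputed_stats
        return self._computed_stats

    def set_statistics(self, stats: Stats) -> None:
        self._computed_stats = stats
        self._was_reset = False

    def reset_statistics(self) -> None:
        self._computed_stats = None
        self._was_reset = True


holder = StatsHolder(
    batch_stats_fn=lambda xs: {"mean": sum(xs) / len(xs)},
    precomputed_stats={"mean": 0.0},
)
```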

reset_cache ¤

reset_cache() -> None

Clear the cache.

Has an effect only if cacheable=True in the config.

copy ¤

copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule

Create a copy of this module with optional config/parameter changes.

This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.

Parameters:

Name Type Description Default
config DataraxModuleConfig | None

New config (if None, uses current config)

None
rngs Rngs | None

New RNG state (if None, uses current rngs)

None
name str | None

New name (if None, uses current name)

None

Returns:

Type Description
DataraxModule

New module instance with updated parameters

Examples:

# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)

# Change name only
renamed = module.copy(name="new_name")

Note

Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().
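The dataclass replace() mentioned in the note is the standard-library dataclasses.replace function. The sketch below shows how it derives a modified config from an existing one; ExampleConfig is a hypothetical stand-in whose field names mirror MemorySourceConfig, and declaring it frozen is an assumption.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ExampleConfig:
    """Hypothetical config used for illustration; not a Datarax class."""

    shuffle: bool = False
    cache_size: int = 0
    num_workers: int = 1


base = ExampleConfig(cache_size=8)
# replace() builds a new instance with the given fields changed and leaves `base` untouched.
shuffled = replace(base, shuffle=True)
```

A config produced this way could then be handed to copy(config=...), so only the fields that differ need to be written out.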

get_state ¤

get_state() -> dict[str, Any]

Get module state for checkpointing.

This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.

Returns:

Type Description
dict[str, Any]

A dictionary containing the internal state of the component.

set_state ¤

set_state(state: dict[str, Any]) -> None

Restore module state from a checkpoint.

This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.

Parameters:

Name Type Description Default
state dict[str, Any]

A dictionary containing the internal state to restore.

required

Raises:

Type Description
TypeError

If state is not a dictionary.

ValueError

If checkpoint structure does not match module state.
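Strict restoration means the checkpoint must have exactly the same nested keys as the live state. One way to express such a check is sketched below; key_paths and check_state_structure are hypothetical helpers, not the NNX-based logic the module actually uses.

```python
from typing import Any


def key_paths(tree: Any, prefix: tuple = ()) -> set:
    """Collect the key path to every leaf of a nested dict."""
    if isinstance(tree, dict):
        paths: set = set()
        for key, value in tree.items():
            paths |= key_paths(value, prefix + (key,))
        return paths
    return {prefix}


def check_state_structure(current: dict, restored: Any) -> None:
    """Raise if `restored` cannot be loaded into a module whose state is `current`."""
    if not isinstance(restored, dict):
        raise TypeError(f"state must be a dict, got {type(restored).__name__}")
    expected, found = key_paths(current), key_paths(restored)
    if expected != found:
        missing = sorted(expected - found)
        unexpected = sorted(found - expected)
        raise ValueError(f"state mismatch: missing={missing}, unexpected={unexpected}")


current_state = {"index": 0, "epoch": 0, "rngs": {"default": 1}}
# Same structure with different values passes silently.
check_state_structure(current_state, {"index": 5, "epoch": 1, "rngs": {"default": 2}})
```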

clone ¤

clone() -> DataraxModule

Create a new instance with the same state as this module.

Uses NNX's clone function for proper deep cloning of all state.

Returns:

Type Description
DataraxModule

A new module instance with the same state.

requires_rng_streams ¤

requires_rng_streams() -> list[str] | None

Get the list of RNG streams required by this module.

Returns:

Type Description
list[str] | None

A list of required RNG stream names, or None if no RNG streams are required.

ensure_rng_streams ¤

ensure_rng_streams(stream_names: list[str]) -> None

Ensure that the required RNG streams are available.

Parameters:

Name Type Description Default
stream_names list[str]

A list of available RNG stream names.

required

Raises:

Type Description
ValueError

If a required RNG stream is not available.
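The relationship between requires_rng_streams and ensure_rng_streams amounts to a subset check. A minimal sketch follows, assuming a free function ensure_streams that takes the required list explicitly; that name and signature are hypothetical and differ from the method above.

```python
from typing import Optional


def ensure_streams(required: Optional[list[str]], available: list[str]) -> None:
    """Raise ValueError if any required stream name is absent from `available`."""
    # None means the module needs no RNG streams at all.
    missing = [name for name in (required or []) if name not in available]
    if missing:
        raise ValueError(f"missing RNG streams: {missing}; available: {available}")


# Passes: the one required stream is present.
ensure_streams(["shuffle"], ["default", "shuffle"])
```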

process ¤

process(input: Any, *args: Any, **kwargs: Any) -> Any

Process input structure.

This method transforms the structure/organization of input data without modifying the data values themselves.

Subclasses MUST implement this method.

The input/output types depend on the specific structural processor:

  • Batcher: list[Element] -> list[Batch]
  • Sampler: int -> list[int]
  • Sharder: Batch -> Sharded[Batch]
  • Splitter: Dataset -> tuple[Dataset, Dataset]

Parameters:

Name Type Description Default
input Any

Input to process (type varies by processor)

required
*args Any

Additional positional arguments (processor-specific)

()
**kwargs Any

Additional keyword arguments (processor-specific)

{}

Returns:

Type Description
Any

Processed output (type varies by processor)

Examples:

Batcher implementation:

def process(self, elements: list[Element]) -> list[Batch]:
    batches = []
    for i in range(0, len(elements), self.config.batch_size):
        batch_elements = elements[i:i + self.config.batch_size]
        batches.append(Batch.from_elements(batch_elements))
    return batches

Sampler implementation (deterministic):

def process(self, dataset_size: int) -> list[int]:
    return list(range(min(self.config.num_samples, dataset_size)))

Sampler implementation (stochastic):

def process(self, dataset_size: int) -> list[int]:
    rng = self.rngs[self.config.stream_name]()
    indices = jax.random.choice(
        rng, dataset_size, shape=(self.config.num_samples,),
        replace=self.config.replacement
    )
    return indices.tolist()