
Mixed Source

Combine multiple data sources with configurable mixing.

See Also

datarax.sources.mixed_source

Mixed data source implementation for Datarax.

This module provides a data source that mixes elements from multiple child sources according to configurable weights. This is useful for combining heterogeneous data streams (e.g., different image datasets, or synthetic plus real data).

logger module-attribute

logger = getLogger(__name__)

MixDataSourcesConfig dataclass

MixDataSourcesConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, stochastic: bool = False, stream_name: str | None = None, num_sources: int | None = None, weights: tuple[float, ...] | None = None)

Bases: StructuralConfig

Configuration for MixDataSourcesNode.

Attributes:

  Name         Type                       Description
  num_sources  int | None                 Number of child sources (validated against the actual sources).
  weights      tuple[float, ...] | None   Sampling weights per source (normalized automatically).
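The attribute table says weights are normalized automatically and num_sources is validated, but it does not spell out the rules. Here is a minimal sketch of that kind of validation. It assumes two things: weights=None means uniform mixing, and negative or all-zero weights are rejected. normalize_weights is a hypothetical helper, not part of Datarax.

```python
from typing import Optional, Tuple


def normalize_weights(
    weights: Optional[Tuple[float, ...]], num_sources: int
) -> Tuple[float, ...]:
    """Hypothetical helper: validate per-source weights and scale them to sum to 1."""
    if num_sources <= 0:
        raise ValueError("num_sources must be positive")
    if weights is None:
        # Assumption: no weights means every source is equally likely.
        return tuple(1.0 / num_sources for _ in range(num_sources))
    if len(weights) != num_sources:
        raise ValueError(f"got {len(weights)} weights for {num_sources} sources")
    if any(w < 0 for w in weights):
        raise ValueError("weights must be non-negative")
    total = sum(weights)
    if total <= 0:
        raise ValueError("at least one weight must be positive")
    # Divide by the total so the weights form a probability vector.
    return tuple(w / total for w in weights)
```

For example, normalize_weights((1.0, 3.0), 2) returns (0.25, 0.75). Only the relative sizes of the weights matter.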

num_sources class-attribute instance-attribute

num_sources: int | None = None

weights class-attribute instance-attribute

weights: tuple[float, ...] | None = None

cacheable class-attribute instance-attribute

cacheable: bool = False

batch_stats_fn class-attribute instance-attribute

batch_stats_fn: Callable | Module | None = None

precomputed_stats class-attribute instance-attribute

precomputed_stats: dict[str, Any] | None = None

stochastic class-attribute instance-attribute

stochastic: bool = False

stream_name class-attribute instance-attribute

stream_name: str | None = None

MixDataSourcesNode

MixDataSourcesNode(config: MixDataSourcesConfig, sources: list[DataSourceModule], *, rngs: Rngs | None = None, name: str | None = None)

Bases: DataSourceModule

Mix multiple data sources with configurable weights.

The sampling strategy is delegated to Grain's weighted IterDataset.mix.

The total number of elements is the sum of all source lengths.
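The length contract fixes how many elements are produced, but not the order in which they appear. As a rough illustration only (this is not Grain's IterDataset.mix algorithm), the pure-Python interleave below works as follows:

- It picks the next source at random, in proportion to that source's weight.
- It skips sources that are already exhausted.
- As a result, it yields exactly the sum of all source lengths.

weighted_interleave is a hypothetical name, and the sketch assumes strictly positive weights.

```python
import random
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def weighted_interleave(
    sources: Sequence[Sequence[T]], weights: Sequence[float], seed: int = 0
) -> Iterator[T]:
    """Yield every element of every source once, choosing the next source by weight."""
    if len(sources) != len(weights):
        raise ValueError("need exactly one weight per source")
    if any(w <= 0 for w in weights):
        raise ValueError("this sketch assumes strictly positive weights")
    rng = random.Random(seed)  # seeded, so the order is reproducible
    cursors: List[int] = [0] * len(sources)  # next unread position in each source
    while True:
        # Only sources that still have elements left are eligible for the next draw.
        live = [i for i, src in enumerate(sources) if cursors[i] < len(src)]
        if not live:
            return
        (chosen,) = rng.choices(live, weights=[weights[i] for i in live], k=1)
        yield sources[chosen][cursors[chosen]]
        cursors[chosen] += 1
```

Once a source runs out, the remaining weights are implicitly renormalized. The requested proportions therefore hold only while every source still has data.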

Parameters:

  Name     Type                    Description                                   Default
  config   MixDataSourcesConfig    Configuration with num_sources and weights.   required
  sources  list[DataSourceModule]  List of data source modules to mix from.      required
  rngs     Rngs | None             Optional Flax NNX random number generators.   None
  name     str | None              Optional module name for identification.      None

index instance-attribute

index = Variable(0)

epoch instance-attribute

epoch = Variable(0)

config instance-attribute

config = static(config)

rngs instance-attribute

rngs = rngs

name instance-attribute

name = static(name)

stochastic instance-attribute

stochastic = stochastic

stream_name instance-attribute

stream_name = stream_name

to_grain_iter_dataset

to_grain_iter_dataset() -> IterDataset

Return the Grain mixed streaming dataset backing this source.

reset

reset() -> None

Reset all internal state and child sources to initial conditions.

get_batch_at

get_batch_at(start: int | Array, size: int, key: Array | None = None) -> dict[str, Array]

Stateless weighted-interleave batch access for Pipeline-driven iteration.

Each output position deterministically chooses a source via weighted categorical sampling derived from key and the absolute position, picks a local index uniformly within that source, and dispatches to that source's own get_batch_at.

Algorithm (per output position p):

  1. pos_key = jax.random.fold_in(key, start + p), which is deterministic in key and the absolute position.
  2. src_key, idx_key, fetch_key = jax.random.split(pos_key, 3).
  3. chosen_src = jax.random.categorical(src_key, log_weights).
  4. local_idx = jax.random.randint(idx_key, (), 0, len(sources[chosen_src])).
  5. record = lax.switch(chosen_src, [lambda li, fk, s=s: s.get_batch_at(li, 1, fk) for s in sources], local_idx, fetch_key).

lax.switch takes a list of branch callables followed by their operands, so every branch must return the same structure, shapes, and dtypes. The same (start, size, key) always returns the same output, and no internal counters are mutated. Mapping over positions with jax.vmap builds the full batch in one trace.
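The following is a minimal JAX sketch of the per-position algorithm above. It makes these substitutions:

- Two hypothetical in-memory arrays (SOURCE_A, SOURCE_B) stand in for the child sources.
- fetch_a and fetch_b stand in for each source's get_batch_at.
- mixed_batch_at is an illustrative name, not the Datarax method.

The sketch demonstrates the two claimed properties. First, identical inputs give identical output. Second, each record depends only on key and its absolute position, so overlapping windows agree.

```python
from typing import Optional, Sequence

import jax
import jax.numpy as jnp

# Two hypothetical child "sources": 10 records and 5 records.
SOURCE_A = jnp.arange(10, dtype=jnp.float32)
SOURCE_B = jnp.arange(100, 105, dtype=jnp.float32)
LENGTHS = jnp.array([SOURCE_A.shape[0], SOURCE_B.shape[0]])


def fetch_a(local_idx, fetch_key):
    # Stand-in for sources[0].get_batch_at(local_idx, 1, fetch_key); key unused here.
    return SOURCE_A[local_idx]


def fetch_b(local_idx, fetch_key):
    # Stand-in for sources[1].get_batch_at(local_idx, 1, fetch_key).
    return SOURCE_B[local_idx]


def mixed_batch_at(
    start: int, size: int, key: Optional[jax.Array], weights: Sequence[float]
) -> jax.Array:
    if key is None:
        raise ValueError("mixing requires a PRNG key")
    w = jnp.asarray(weights, dtype=jnp.float32)
    log_weights = jnp.log(w / jnp.sum(w))  # a zero weight becomes -inf: never chosen

    def one_position(p):
        # Step 1: the key depends only on `key` and the absolute position.
        pos_key = jax.random.fold_in(key, start + p)
        # Step 2: independent keys for source choice, index choice, and fetch.
        src_key, idx_key, fetch_key = jax.random.split(pos_key, 3)
        # Step 3: weighted categorical choice of the source.
        chosen = jax.random.categorical(src_key, log_weights)
        # Step 4: uniform local index within the chosen source.
        local_idx = jax.random.randint(idx_key, (), 0, LENGTHS[chosen])
        # Step 5: dispatch; both branches return a float32 scalar, as switch requires.
        return jax.lax.switch(chosen, [fetch_a, fetch_b], local_idx, fetch_key)

    # vmap over positions builds the whole batch in one trace.
    return jax.vmap(one_position)(jnp.arange(size))
```

Because the position key comes from fold_in on the absolute index, mixed_batch_at(2, 4, key, w) reproduces elements 2 through 5 of mixed_batch_at(0, 8, key, w).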

Parameters:

  Name   Type           Description                                                                                                   Default
  start  int | Array    Starting logical position (int or traced jax.Array).                                                          required
  size   int            Number of records to return.                                                                                  required
  key    Array | None   PRNG key for deterministic source/index selection. Required: mixing without a key has no defined semantics.   None

Returns:

  Type              Description
  dict[str, Array]  Dict mapping each data key to a JAX array with leading dimension size, drawn from the underlying sources in proportion to self._weights.

Raises:

  Type        Description
  ValueError  If key is None.

get_operation_stats

get_operation_stats() -> dict[str, int]

Get operation statistics.

Note: This method converts JAX arrays to Python ints for introspection. It is intended for use outside of JIT-compiled functions.

Returns:

  Type            Description
  dict[str, int]  Dictionary with 'applied_count' and 'skipped_count'

reset_operation_stats

reset_operation_stats() -> None

Reset operation statistics to zero.

Note: Creates new JAX arrays to reset the counters.

compute_statistics

compute_statistics(data: Any) -> dict[str, Any] | None

Compute statistics from data using batch_stats_fn.

If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.

Parameters:

  Name  Type  Description                            Default
  data  Any   Input data to compute statistics from  required

Returns:

  Type                   Description
  dict[str, Any] | None  Dictionary of statistics, or None if no batch_stats_fn configured

get_statistics

get_statistics() -> dict[str, Any] | None

Get current statistics.

Returns precomputed_stats if configured (unless the statistics have been reset); otherwise returns the cached computed statistics, or None if no statistics are available.

Returns:

  Type                   Description
  dict[str, Any] | None  Dictionary of statistics, or None if no statistics available

set_statistics

set_statistics(stats: dict[str, Any]) -> None

Manually set statistics.

This overwrites any previously computed statistics and clears the reset flag.

Parameters:

  Name   Type            Description                      Default
  stats  dict[str, Any]  Dictionary of statistics to set  required

reset_statistics

reset_statistics() -> None

Reset all statistics to None.

This clears both computed statistics and marks that precomputed_stats should be ignored (via internal flag). After reset, get_statistics() will return None until new statistics are set or computed.
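The four statistics methods interact through one precedence rule and one flag. Below is a minimal sketch that reads the descriptions literally. StatisticsBookkeeping is a hypothetical stand-in, and its computed_stats and stats_reset fields play the roles of _computed_stats and the internal flag.

On that literal reading, set_statistics() clears the reset flag. So when precomputed_stats is configured, get_statistics() returns the precomputed values again after a manual set. Check the actual implementation if that case matters to you.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional

Stats = Dict[str, Any]


@dataclass
class StatisticsBookkeeping:
    """Hypothetical stand-in mirroring the documented statistics precedence."""

    batch_stats_fn: Optional[Callable[[Any], Stats]] = None
    precomputed_stats: Optional[Stats] = None
    computed_stats: Optional[Stats] = None  # role of _computed_stats
    stats_reset: bool = False  # role of the internal reset flag

    def compute_statistics(self, data: Any) -> Optional[Stats]:
        if self.batch_stats_fn is None:
            return None
        self.computed_stats = self.batch_stats_fn(data)  # cached for get_statistics()
        return self.computed_stats

    def get_statistics(self) -> Optional[Stats]:
        # precomputed_stats wins unless the statistics have been reset.
        if self.precomputed_stats is not None and not self.stats_reset:
            return self.precomputed_stats
        return self.computed_stats

    def set_statistics(self, stats: Stats) -> None:
        self.computed_stats = stats
        self.stats_reset = False  # documented: clears the reset flag

    def reset_statistics(self) -> None:
        self.computed_stats = None
        self.stats_reset = True  # precomputed_stats is ignored from now on
```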

reset_cache

reset_cache() -> None

Clear the cache.

Only has effect if cacheable=True in config.

copy

copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule

Create a copy of this module with optional config/parameter changes.

This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.

Parameters:

  Name    Type                        Description                                 Default
  config  DataraxModuleConfig | None  New config (if None, uses current config)   None
  rngs    Rngs | None                 New RNG state (if None, uses current rngs)  None
  name    str | None                  New name (if None, uses current name)       None

Returns:

  Type           Description
  DataraxModule  New module instance with updated parameters

Examples:

# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)

# Change name only
renamed = module.copy(name="new_name")

Note

Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().
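The Note refers to dataclass replace(). For a config dataclass, dataclasses.replace changes the named fields and copies all the others into a new instance. The resulting object is what you would pass as module.copy(config=...). ExampleConfig is hypothetical and borrows a few field names from MixDataSourcesConfig.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass
class ExampleConfig:
    """Hypothetical config carrying a few fields seen on MixDataSourcesConfig."""

    cacheable: bool = False
    stochastic: bool = False
    stream_name: Optional[str] = None


base = ExampleConfig(stochastic=True, stream_name="mix")
# replace() builds a new instance: named fields change, all others are copied.
tuned = dataclasses.replace(base, cacheable=True)
```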

get_state

get_state() -> dict[str, Any]

Get module state for checkpointing.

This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.

Returns:

  Type            Description
  dict[str, Any]  A dictionary containing the internal state of the component.

set_state

set_state(state: dict[str, Any]) -> None

Restore module state from a checkpoint.

This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: the checkpoint structure must match the module state.

Parameters:

  Name   Type            Description                                             Default
  state  dict[str, Any]  A dictionary containing the internal state to restore.  required

Raises:

  Type        Description
  TypeError   If state is not a dictionary.
  ValueError  If checkpoint structure does not match module state.
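Here is a minimal sketch of what strict restoration means, using plain nested dicts instead of NNX state. strict_restore and check_same_structure are hypothetical helpers, and the real method may compare more than key sets. The sketch mirrors the documented errors: TypeError for a non-dict checkpoint, and ValueError for a structure mismatch.

```python
from typing import Any, Dict


def check_same_structure(expected: Any, actual: Any, path: str = "state") -> None:
    """Raise ValueError unless `actual` has exactly the nested dict keys of `expected`."""
    if isinstance(expected, dict):
        if not isinstance(actual, dict):
            raise ValueError(f"{path}: expected a dict, got {type(actual).__name__}")
        missing = sorted(str(k) for k in expected.keys() - actual.keys())
        unexpected = sorted(str(k) for k in actual.keys() - expected.keys())
        if missing or unexpected:
            raise ValueError(f"{path}: missing {missing}, unexpected {unexpected}")
        for name in expected:
            check_same_structure(expected[name], actual[name], f"{path}.{name}")
    elif isinstance(actual, dict):
        raise ValueError(f"{path}: unexpected nested dict")


def strict_restore(current: Dict[str, Any], checkpoint: Any) -> Dict[str, Any]:
    """Hypothetical helper: validate `checkpoint` against `current`, then return it."""
    if not isinstance(checkpoint, dict):
        raise TypeError(f"state must be a dict, got {type(checkpoint).__name__}")
    check_same_structure(current, checkpoint)
    return checkpoint
```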

clone

clone() -> DataraxModule

Create a new instance with the same state as this module.

Uses NNX's clone function for proper deep cloning of all state.

Returns:

  Type           Description
  DataraxModule  A new module instance with the same state.

requires_rng_streams

requires_rng_streams() -> list[str] | None

Get the list of RNG streams required by this module.

Returns:

  Type              Description
  list[str] | None  A list of required RNG stream names, or None if no RNG streams are required.

ensure_rng_streams

ensure_rng_streams(stream_names: list[str]) -> None

Ensure that the required RNG streams are available.

Parameters:

  Name          Type       Description                            Default
  stream_names  list[str]  A list of available RNG stream names.  required

Raises:

  Type        Description
  ValueError  If a required RNG stream is not available.
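Together, these two methods form a simple contract. requires_rng_streams() declares which stream names are needed (None means none), and ensure_rng_streams() fails if any of them is missing from the available names. A minimal sketch of that check follows. check_rng_streams is a hypothetical helper, and the stream names in the usage are only examples.

```python
from typing import List, Optional, Sequence


def check_rng_streams(required: Optional[Sequence[str]], available: Sequence[str]) -> None:
    """Hypothetical helper: raise ValueError if any required stream is unavailable."""
    if not required:
        return  # None or empty: no RNG streams are needed
    available_set = set(available)
    missing: List[str] = [name for name in required if name not in available_set]
    if missing:
        raise ValueError(f"missing RNG streams: {missing}; available: {list(available)}")
```

For instance, check_rng_streams(["mix"], ["default"]) raises ValueError, while check_rng_streams(None, []) returns without error.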

process

process(input: Any, *args: Any, **kwargs: Any) -> Any

Process input structure.

This method transforms the structure/organization of input data without modifying the data values themselves.

Subclasses MUST implement this method.

The input/output types depend on the specific structural processor:

  • Batcher: list[Element] -> list[Batch]
  • Sampler: int -> list[int]
  • Sharder: Batch -> Sharded[Batch]
  • Splitter: Dataset -> tuple[Dataset, Dataset]

Parameters:

  Name      Type  Description                                           Default
  input     Any   Input to process (type varies by processor)           required
  *args     Any   Additional positional arguments (processor-specific)  ()
  **kwargs  Any   Additional keyword arguments (processor-specific)     {}

Returns:

  Type  Description
  Any   Processed output (type varies by processor)

Examples:

Batcher implementation:

def process(self, elements: list[Element]) -> list[Batch]:
    batches = []
    for i in range(0, len(elements), self.config.batch_size):
        batch_elements = elements[i:i + self.config.batch_size]
        batches.append(Batch.from_elements(batch_elements))
    return batches

Sampler implementation (deterministic):

def process(self, dataset_size: int) -> list[int]:
    return list(range(min(self.config.num_samples, dataset_size)))

Sampler implementation (stochastic):

def process(self, dataset_size: int) -> list[int]:
    rng = self.rngs[self.config.stream_name]()
    indices = jax.random.choice(
        rng, dataset_size, shape=(self.config.num_samples,),
        replace=self.config.replacement
    )
    return indices.tolist()

element_spec

element_spec() -> Any

Return a PyTree of jax.ShapeDtypeStruct describing per-element output.

Downstream consumers (operators, batchers, models) use this contract to pre-allocate buffers, auto-size learnable layers, and statically validate operator chains. Subclasses MUST override this method.

Returns:

  Type  Description
  Any   A PyTree (typically a dict) whose leaves are jax.ShapeDtypeStruct instances describing one emitted element.

Raises:

  Type                 Description
  NotImplementedError  Always, on the base class. Subclasses must override.
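For a mixed source, the lax.switch dispatch in get_batch_at requires every branch to return the same structure, shapes, and dtypes. It is therefore reasonable to expect the children's element specs to agree. The sketch below builds such a spec and checks several specs for agreement. image_label_spec and shared_element_spec are hypothetical helpers, and whether MixDataSourcesNode performs this check is not stated here.

```python
from typing import Any, Dict, Sequence

import jax
import jax.numpy as jnp


def image_label_spec(height: int, width: int) -> Dict[str, jax.ShapeDtypeStruct]:
    """A per-element spec: one HxWx3 float32 image and a scalar int32 label."""
    return {
        "image": jax.ShapeDtypeStruct((height, width, 3), jnp.float32),
        "label": jax.ShapeDtypeStruct((), jnp.int32),
    }


def shared_element_spec(specs: Sequence[Any]) -> Any:
    """Hypothetical helper: return the common spec, or raise if children disagree."""
    first = specs[0]
    first_tree = jax.tree_util.tree_structure(first)
    first_leaves = jax.tree_util.tree_leaves(first)  # ShapeDtypeStructs are leaves
    for i, spec in enumerate(specs[1:], start=1):
        leaves = jax.tree_util.tree_leaves(spec)
        same = jax.tree_util.tree_structure(spec) == first_tree and all(
            a.shape == b.shape and a.dtype == b.dtype
            for a, b in zip(leaves, first_leaves)
        )
        if not same:
            raise ValueError(f"element spec of source {i} differs from source 0")
    return first
```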