Default Batcher

Standard batching with padding and remainder handling.

datarax.batching.default_batcher

Default batcher module implementation for Datarax.

This module provides a default implementation of the BatcherModule interface that handles batching of PyTrees.

logger module-attribute

logger = getLogger(__name__)

DefaultBatcherConfig dataclass

DefaultBatcherConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, stochastic: bool = False, stream_name: str | None = None)

Bases: StructuralConfig

Configuration for DefaultBatcher.

DefaultBatcher is deterministic and requires no additional configuration beyond the base StructuralConfig.

cacheable class-attribute instance-attribute

cacheable: bool = False

batch_stats_fn class-attribute instance-attribute

batch_stats_fn: Callable | Module | None = None

precomputed_stats class-attribute instance-attribute

precomputed_stats: dict[str, Any] | None = None

stochastic class-attribute instance-attribute

stochastic: bool = False

stream_name class-attribute instance-attribute

stream_name: str | None = None

DefaultBatcher

DefaultBatcher(config: DefaultBatcherConfig, *, collate_fn: Callable[[list[Element]], Batch] | None = None, rngs: Rngs | None = None, name: str | None = None)

Bases: BatcherModule

Default implementation of the BatcherModule interface.

This batcher module accumulates individual data elements and forms batches by stacking arrays along a new leading dimension. It handles PyTrees of arbitrary structure, maintaining the same structure in the batched output.
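The following minimal pure-Python sketch illustrates what "stacking along a new leading dimension while preserving structure" means. It is not DefaultBatcher's implementation. It assumes dicts are the only PyTree nodes and treats every other value, including Python lists standing in for arrays, as a leaf. The name stack_pytrees is hypothetical.

In JAX the same idea is commonly written as jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *elements). Wrapping that expression in a function of a list of elements gives the kind of callable the collate_fn parameter accepts.

```python
from typing import Any


def stack_pytrees(elements: list[Any]) -> Any:
    """Stack same-structured elements along a new leading dimension.

    Hypothetical sketch: dicts are the only PyTree nodes; any other value
    (e.g. a Python list standing in for an array) is treated as a leaf.
    """
    if not elements:
        raise ValueError("cannot stack an empty list of elements")
    first = elements[0]
    if isinstance(first, dict):
        # Every element must have exactly the same keys as the first one.
        for element in elements:
            if not isinstance(element, dict) or element.keys() != first.keys():
                raise ValueError("all elements must share the same structure")
        # Recurse per key so the output keeps the input's structure.
        return {key: stack_pytrees([e[key] for e in elements]) for key in first}
    # Leaf: the new leading dimension indexes the elements in order.
    return list(elements)


batch = stack_pytrees([
    {"x": [1, 2], "meta": {"id": 0}},
    {"x": [3, 4], "meta": {"id": 1}},
])
print(batch)  # {'x': [[1, 2], [3, 4]], 'meta': {'id': [0, 1]}}
```

The batch keeps the element's nested layout, and each leaf gains one leading axis of length len(elements).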

Parameters:

config (DefaultBatcherConfig, required): Configuration for the batcher.

collate_fn (Callable[[list[Element]], Batch] | None, default None): Optional custom function that combines a list of elements into a batch. If None, the default stacking approach is used.

rngs (Rngs | None, default None): Optional Rngs object for randomness.

name (str | None, default None): Optional name for the module.

collate_fn instance-attribute

collate_fn: Callable[[list[Element]], Batch] | None = collate_fn

config instance-attribute

config = static(config)

rngs instance-attribute

rngs = rngs

name instance-attribute

name = static(name)

stochastic instance-attribute

stochastic = stochastic

stream_name instance-attribute

stream_name = stream_name

process

process(elements: Iterator[Element], *_args: Any, batch_size: int, drop_remainder: bool = False, **_kwargs: Any) -> Iterator[Batch]

Group individual data elements into batches.

Parameters:

elements (Iterator[Element], required): An iterator yielding individual data elements.

*_args (Any): Additional positional arguments (ignored).

batch_size (int, required, keyword-only): The number of elements to include in each batch.

drop_remainder (bool, keyword-only, default False): Whether to drop the last batch if it contains fewer than batch_size elements.

**_kwargs (Any): Additional keyword arguments (ignored).

Yields:

Batch: Batches of data elements.

Raises:

ValueError: If batch_size is not positive.
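The batching contract above can be sketched as a plain Python generator. That contract covers fixed-size groups, an optional short final batch, and a ValueError for a non-positive batch_size.

This is a hypothetical illustration, not DefaultBatcher.process:

- The function name batch_elements is invented.
- The default collate_fn here simply returns a list.
- The sketch emits the short final batch as-is, without any padding.

```python
from collections.abc import Callable, Iterator
from typing import Any


def batch_elements(
    elements: Iterator[Any],
    *,
    batch_size: int,
    drop_remainder: bool = False,
    collate_fn: Callable[[list[Any]], Any] = list,
) -> Iterator[Any]:
    """Hypothetical sketch of fixed-size batching with remainder handling."""
    # Generator: this check runs when iteration starts, not at call time.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    buffer: list[Any] = []
    for element in elements:
        buffer.append(element)
        if len(buffer) == batch_size:
            yield collate_fn(buffer)
            buffer = []
    # Short final batch: keep it unless drop_remainder is set.
    if buffer and not drop_remainder:
        yield collate_fn(buffer)


print(list(batch_elements(iter(range(5)), batch_size=2)))
# [[0, 1], [2, 3], [4]]
print(list(batch_elements(iter(range(5)), batch_size=2, drop_remainder=True)))
# [[0, 1], [2, 3]]
```

Because the sketch is a generator, its ValueError surfaces on the first next() call rather than when batch_elements is called.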

get_operation_stats

get_operation_stats() -> dict[str, int]

Get operation statistics.

Note: This method converts JAX arrays to Python ints for introspection. It is intended for use outside of JIT-compiled functions.

Returns:

dict[str, int]: Dictionary with 'applied_count' and 'skipped_count'.

reset_operation_stats

reset_operation_stats() -> None

Reset operation statistics to zero.

Note: Creates new JAX arrays to reset the counters.

compute_statistics

compute_statistics(data: Any) -> dict[str, Any] | None

Compute statistics from data using batch_stats_fn.

If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.

Parameters:

data (Any, required): Input data to compute statistics from.

Returns:

dict[str, Any] | None: Dictionary of statistics, or None if no batch_stats_fn is configured.

get_statistics

get_statistics() -> dict[str, Any] | None

Get current statistics.

Returns precomputed_stats if it is configured, unless reset_statistics() has been called. Otherwise, returns the cached computed statistics, or None if no statistics are available.

Returns:

dict[str, Any] | None: Dictionary of statistics, or None if no statistics are available.

set_statistics

set_statistics(stats: dict[str, Any]) -> None

Manually set statistics.

This overwrites any previously computed statistics and clears the reset flag set by reset_statistics().

Parameters:

stats (dict[str, Any], required): Dictionary of statistics to set.

reset_statistics

reset_statistics() -> None

Reset all statistics to None.

This clears any computed statistics and sets an internal flag so that precomputed_stats is ignored. After a reset, get_statistics() returns None until new statistics are set or computed.

reset_cache

reset_cache() -> None

Clear the cache.

This has an effect only if cacheable=True in the config.

copy

copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule

Create a copy of this module with optional config/parameter changes.

This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.

Parameters:

config (DataraxModuleConfig | None, default None): New config (if None, uses the current config).

rngs (Rngs | None, default None): New RNG state (if None, uses the current rngs).

name (str | None, default None): New name (if None, uses the current name).

Returns:

DataraxModule: New module instance with updated parameters.

Examples:

    # Change configuration
    new_config = DataraxModuleConfig(cacheable=True)
    new_module = module.copy(config=new_config)

    # Change name only
    renamed = module.copy(name="new_name")

Note

Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().

get_state

get_state() -> dict[str, Any]

Get module state for checkpointing.

This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.

Returns:

dict[str, Any]: A dictionary containing the internal state of the component.

set_state

set_state(state: dict[str, Any]) -> None

Restore module state from a checkpoint.

This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: the checkpoint structure must match the module state.

Parameters:

state (dict[str, Any], required): A dictionary containing the internal state to restore.

Raises:

TypeError: If state is not a dictionary.

ValueError: If the checkpoint structure does not match the module state.
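The strict-restoration rule can be illustrated on plain nested dicts. The helper below is hypothetical. It compares only the key structure of the checkpoint against the current state, mirroring the documented errors:

- TypeError for a non-dict argument.
- ValueError for a structural mismatch.

The real set_state works on NNX state and may check more than keys.

```python
from typing import Any


def validate_checkpoint(current: dict[str, Any], state: Any) -> None:
    """Hypothetical strict check: state must mirror the dict structure of current."""
    if not isinstance(state, dict):
        raise TypeError(f"state must be a dict, got {type(state).__name__}")
    _compare(current, state, path="")


def _compare(expected: Any, actual: Any, path: str) -> None:
    # Only dict nodes are compared; leaf values (e.g. counters) may differ freely.
    if not isinstance(expected, dict):
        return
    where = path or "/"
    if not isinstance(actual, dict):
        raise ValueError(f"expected a dict at {where}, got {type(actual).__name__}")
    missing = sorted(expected.keys() - actual.keys())
    unexpected = sorted(actual.keys() - expected.keys())
    if missing or unexpected:
        raise ValueError(
            f"structure mismatch at {where}: missing={missing}, unexpected={unexpected}"
        )
    for key in expected:
        _compare(expected[key], actual[key], f"{path}/{key}")


current = {"count": 3, "rngs": {"default": {"count": 0}}}
# Same structure, different leaf values: accepted.
validate_checkpoint(current, {"count": 9, "rngs": {"default": {"count": 7}}})
```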

clone

clone() -> DataraxModule

Create a new instance with the same state as this module.

Uses NNX's clone function to deep-copy all state.

Returns:

DataraxModule: A new module instance with the same state.

requires_rng_streams

requires_rng_streams() -> list[str] | None

Get the list of RNG streams required by this module.

Returns:

list[str] | None: A list of required RNG stream names, or None if no RNG streams are required.

ensure_rng_streams

ensure_rng_streams(stream_names: list[str]) -> None

Ensure that the required RNG streams are available.

Parameters:

stream_names (list[str], required): A list of available RNG stream names.

Raises:

ValueError: If a required RNG stream is not available.

batch_spec

batch_spec(element_spec: Any, *, batch_size: int) -> dict[str, Any]

Return the batched-output spec given a per-element spec and batch_size.

The default implementation prepends a leading (batch_size,) dimension to every jax.ShapeDtypeStruct leaf of element_spec and adds a top-level valid_mask leaf of shape (batch_size,) and dtype bool. The mask flags valid positions so end-of-epoch padding does not contribute to mask-weighted loss.
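To see why valid_mask matters, consider a final batch of 3 real elements padded to batch_size=4. A mask-weighted mean ignores the padded position, while a plain mean does not.

The helpers make_valid_mask and masked_mean are hypothetical. The padded row's loss of 100.0 is an arbitrary value chosen for illustration.

```python
def make_valid_mask(num_valid: int, batch_size: int) -> list[bool]:
    """True for real elements, False for padding at the end of the batch."""
    if not 0 <= num_valid <= batch_size:
        raise ValueError("num_valid must be between 0 and batch_size")
    return [True] * num_valid + [False] * (batch_size - num_valid)


def masked_mean(values: list[float], valid_mask: list[bool]) -> float:
    """Mean over positions where valid_mask is True; padding gets weight 0."""
    if len(values) != len(valid_mask):
        raise ValueError("values and valid_mask must have the same length")
    weights = [1.0 if valid else 0.0 for valid in valid_mask]
    total_weight = sum(weights)
    if total_weight == 0:
        return 0.0  # no valid elements: avoid division by zero
    return sum(v * w for v, w in zip(values, weights)) / total_weight


# Per-example losses for a batch whose last row is padding.
losses = [2.0, 4.0, 6.0, 100.0]
mask = make_valid_mask(num_valid=3, batch_size=4)
print(masked_mean(losses, mask))  # 4.0
print(sum(losses) / len(losses))  # 28.0, polluted by the padded row
```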

Subclasses (e.g., MultiRateBatcher) override only when the batch layout requires more than a simple leading-dim prepend.

Parameters:

element_spec (Any, required): PyTree of jax.ShapeDtypeStruct describing one element (typically the output of the upstream operator's output_spec or the source's element_spec).

batch_size (int, required): Number of elements per emitted batch.

Returns:

dict[str, Any]: A dict containing the batched element spec under the original keys, plus a "valid_mask" key of shape (batch_size,) and bool dtype.

Raises:

ValueError: If batch_size is not positive.
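The default batch_spec behavior described above can be sketched without JAX. That behavior is to prepend a (batch_size,) dimension to every leaf and add a top-level valid_mask.

- SpecLeaf below is a hypothetical stand-in for jax.ShapeDtypeStruct, with dtypes as plain strings.
- The sketch handles only dict-structured specs.
- With JAX installed, each batched leaf would instead be built as jax.ShapeDtypeStruct((batch_size, *leaf.shape), leaf.dtype).

```python
from typing import Any, NamedTuple


class SpecLeaf(NamedTuple):
    """Hypothetical stand-in for jax.ShapeDtypeStruct."""

    shape: tuple[int, ...]
    dtype: str


def batch_spec_sketch(element_spec: dict[str, Any], *, batch_size: int) -> dict[str, Any]:
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    def prepend(node: Any) -> Any:
        # Leaves gain a leading batch dimension; dicts are rebuilt recursively.
        if isinstance(node, SpecLeaf):
            return SpecLeaf((batch_size, *node.shape), node.dtype)
        if isinstance(node, dict):
            return {key: prepend(value) for key, value in node.items()}
        raise TypeError(f"unsupported spec node: {type(node).__name__}")

    batched = prepend(element_spec)
    # Top-level mask marking which batch positions hold real elements.
    batched["valid_mask"] = SpecLeaf((batch_size,), "bool")
    return batched


spec = {"image": SpecLeaf((28, 28), "float32"), "label": SpecLeaf((), "int32")}
print(batch_spec_sketch(spec, batch_size=8)["image"])
# SpecLeaf(shape=(8, 28, 28), dtype='float32')
```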