Default Batcher
Standard batching with padding and remainder handling.
See Also
- Batching Overview - Batching concepts
- Core Batcher - Batcher protocol
- DAG Executor - Pipeline batching
- Simple Pipeline Example
datarax.batching.default_batcher
Default batcher module implementation for Datarax.
This module provides a default implementation of the BatcherModule interface that handles batching of PyTrees.
DefaultBatcherConfig (dataclass)
DefaultBatcherConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, stochastic: bool = False, stream_name: str | None = None)
Bases: StructuralConfig
Configuration for DefaultBatcher.
DefaultBatcher is deterministic and requires no additional configuration beyond the base StructuralConfig.
precomputed_stats (class-attribute, instance-attribute)
DefaultBatcher
DefaultBatcher(config: DefaultBatcherConfig, *, collate_fn: Callable[[list[Element]], Batch] | None = None, rngs: Rngs | None = None, name: str | None = None)
Bases: BatcherModule
Default implementation of the BatcherModule interface.
This batcher module accumulates individual data elements and forms batches by stacking arrays along a new leading dimension. It handles PyTrees of arbitrary structure, maintaining the same structure in the batched output.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `DefaultBatcherConfig` | Configuration for the batcher. | required |
| `collate_fn` | `Callable[[list[Element]], Batch] \| None` | Optional custom function for combining elements into a batch. If `None`, a default stacking approach is used. | `None` |
| `rngs` | `Rngs \| None` | Optional `Rngs` object for randomness. | `None` |
| `name` | `str \| None` | Optional name for the module. | `None` |
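The default stacking behaviour described above can be illustrated without any dependencies. The sketch below is not Datarax's implementation: `stack_elements` is a hypothetical helper, and plain Python lists stand in for arrays. With JAX arrays the leaf step would typically be something like `jnp.stack` applied across corresponding leaves. A function of this shape (a list of elements in, one batch out) is also what a custom `collate_fn` would look like.

```python
from typing import Any


def stack_elements(elements: list[Any]) -> Any:
    """Combine structurally identical elements into one batched structure.

    Dicts and tuples are treated as tree nodes. Everything else is a leaf,
    and leaves are gathered into a list that stands in for an array stacked
    along a new leading dimension.
    """
    if not elements:
        raise ValueError("cannot collate an empty list of elements")
    first = elements[0]
    if isinstance(first, dict):
        # Recurse per key so the output keeps the element's dict structure.
        return {key: stack_elements([e[key] for e in elements]) for key in first}
    if isinstance(first, tuple):
        # Recurse per position for tuple nodes.
        return tuple(
            stack_elements([e[i] for e in elements]) for i in range(len(first))
        )
    # Leaf: collect the values along a new leading axis.
    return list(elements)


elements = [
    {"image": [0.0, 0.1], "meta": {"label": 0}},
    {"image": [1.0, 1.1], "meta": {"label": 1}},
    {"image": [2.0, 2.1], "meta": {"label": 2}},
]
batch = stack_elements(elements)
```

The batched output has the same nesting as each element, with every leaf gaining a leading axis of length 3.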
process
process(elements: Iterator[Element], *_args: Any, batch_size: int, drop_remainder: bool = False, **_kwargs: Any) -> Iterator[Batch]
Group individual data elements into batches.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `elements` | `Iterator[Element]` | An iterator yielding individual data elements. | required |
| `*_args` | `Any` | Additional positional arguments (ignored). | optional |
| `batch_size` | `int` | The number of elements to include in each batch. | required |
| `drop_remainder` | `bool` | Whether to drop the last batch if it is smaller than `batch_size`. | `False` |
| `**_kwargs` | `Any` | Additional keyword arguments (ignored). | optional |
Yields:
| Type | Description |
|---|---|
| `Batch` | Batches of data elements. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If `batch_size` is not positive. |
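The grouping and remainder semantics of `process` can be sketched as a plain generator. This is an illustration only: `iter_batches` is a hypothetical name, batches are returned as lists rather than stacked PyTrees, and the real method may validate `batch_size` at a different point.

```python
from collections.abc import Iterable, Iterator
from itertools import islice
from typing import TypeVar

T = TypeVar("T")


def iter_batches(
    elements: Iterable[T], *, batch_size: int, drop_remainder: bool = False
) -> Iterator[list[T]]:
    """Yield lists of up to `batch_size` consecutive elements."""
    # In this sketch the check runs lazily, on the first iteration step,
    # because the function is a generator.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    iterator = iter(elements)
    while True:
        chunk = list(islice(iterator, batch_size))
        if not chunk:
            return  # input exhausted exactly on a batch boundary
        if len(chunk) < batch_size and drop_remainder:
            return  # discard the short final batch
        yield chunk


full = list(iter_batches(range(7), batch_size=3))
trimmed = list(iter_batches(range(7), batch_size=3, drop_remainder=True))
```

With seven elements and `batch_size=3`, `full` ends with a one-element batch while `trimmed` contains only the two complete batches.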
get_operation_stats
reset_operation_stats
Reset operation statistics to zero.
Note: Creates new JAX arrays to reset the counters.
compute_statistics
Compute statistics from data using batch_stats_fn.
If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `Any` | Input data to compute statistics from. | required |
Returns:
| Type | Description |
|---|---|
| `dict[str, Any] \| None` | Dictionary of statistics, or `None` if no `batch_stats_fn` is configured. |
get_statistics
set_statistics
reset_statistics
Reset all statistics to None.
This clears both computed statistics and marks that precomputed_stats should be ignored (via internal flag). After reset, get_statistics() will return None until new statistics are set or computed.
copy
copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule
Create a copy of this module with optional config/parameter changes.
This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `DataraxModuleConfig \| None` | New config (if `None`, uses the current config). | `None` |
| `rngs` | `Rngs \| None` | New RNG state (if `None`, uses the current rngs). | `None` |
| `name` | `str \| None` | New name (if `None`, uses the current name). | `None` |
Returns:
| Type | Description |
|---|---|
| `DataraxModule` | New module instance with updated parameters. |
Examples:
```python
# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)

# Change name only
renamed = module.copy(name="new_name")
```
Note
Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().
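The note above refers to dataclass `replace()`. As a minimal sketch of that pattern, the snippet below derives a modified config from an existing one with the standard-library `dataclasses.replace`. `ExampleConfig` is a hypothetical stand-in (only the field names `cacheable` and `stochastic` are borrowed from the signature shown earlier), and whether Datarax configs are frozen is an assumption.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ExampleConfig:
    """Hypothetical config used only for this illustration."""

    cacheable: bool = False
    stochastic: bool = False


base = ExampleConfig()
# replace() builds a new instance, changing only the named fields.
tuned = replace(base, cacheable=True)
```

The resulting object could then be handed to `module.copy(config=tuned)`, leaving the original config and module untouched.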
get_state
Get module state for checkpointing.
This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.
Returns:
| Type | Description |
|---|---|
| `dict[str, Any]` | A dictionary containing the internal state of the component. |
set_state
Restore module state from a checkpoint.
This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `dict[str, Any]` | A dictionary containing the internal state to restore. | required |
Raises:
| Type | Description |
|---|---|
| `TypeError` | If `state` is not a dictionary. |
| `ValueError` | If the checkpoint structure does not match the module state. |
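To make "strict" restoration concrete, the sketch below checks a candidate checkpoint against the current state before anything would be restored. It is an illustration under assumptions, not the NNX-based implementation: `validate_state` and `_compare` are hypothetical helpers, and state is modelled as nested dicts with arbitrary leaves.

```python
from typing import Any


def _compare(current: Any, restored: Any, path: str) -> None:
    """Recursively require that both trees have the same dict structure."""
    if isinstance(current, dict) != isinstance(restored, dict):
        raise ValueError(f"structure mismatch at '{path or '/'}'")
    if not isinstance(current, dict):
        return  # both are leaves; values may differ freely
    if current.keys() != restored.keys():
        missing = sorted(current.keys() - restored.keys())
        unexpected = sorted(restored.keys() - current.keys())
        raise ValueError(
            f"structure mismatch at '{path or '/'}': "
            f"missing={missing}, unexpected={unexpected}"
        )
    for key in current:
        _compare(current[key], restored[key], f"{path}/{key}")


def validate_state(current: dict[str, Any], restored: Any) -> None:
    """Raise TypeError for non-dict input, ValueError for a shape mismatch."""
    if not isinstance(restored, dict):
        raise TypeError(f"state must be a dict, got {type(restored).__name__}")
    _compare(current, restored, "")


current_state = {"counters": {"batches": 0, "elements": 0}, "rng": {"count": 0}}
good_checkpoint = {"counters": {"batches": 4, "elements": 128}, "rng": {"count": 7}}
bad_checkpoint = {"counters": {"batches": 4}, "rng": {"count": 7}}

validate_state(current_state, good_checkpoint)  # passes silently
```

Passing `bad_checkpoint` raises `ValueError` because the `elements` counter is missing, and passing a list raises `TypeError`.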
clone
clone() -> DataraxModule
Create a new instance with the same state as this module.
Uses NNX's clone function for proper deep cloning of all state.
Returns:
| Type | Description |
|---|---|
| `DataraxModule` | A new module instance with the same state. |
requires_rng_streams
ensure_rng_streams
Ensure that the required RNG streams are available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `stream_names` | `list[str]` | A list of available RNG stream names. | required |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If a required RNG stream is not available. |
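The availability check can be pictured as a simple membership test. The sketch below is illustrative only: `check_streams` is a hypothetical free function, and the required stream names are passed explicitly, whereas the real method presumably reads them from the module itself.

```python
def check_streams(required: list[str], available: list[str]) -> None:
    """Raise ValueError if any required stream name is not available."""
    available_set = set(available)
    # Preserve the order of `required` so the error message is stable.
    missing = [name for name in required if name not in available_set]
    if missing:
        raise ValueError(f"missing required RNG streams: {missing}")


check_streams(required=["augment"], available=["augment", "shuffle"])  # ok
```

A module that needs a stream named `"dropout"` would fail the same check when only `"augment"` and `"shuffle"` are offered.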
batch_spec
Return the batched-output spec given a per-element spec and `batch_size`.
The default implementation prepends a leading `(batch_size,)` dimension to every `jax.ShapeDtypeStruct` leaf of `element_spec` and adds a top-level `valid_mask` leaf of shape `(batch_size,)` and dtype `bool`. The mask flags valid positions so end-of-epoch padding does not contribute to mask-weighted loss.
Subclasses (e.g., `MultiRateBatcher`) override only when the batch layout requires more than a simple leading-dim prepend.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `element_spec` | `Any` | PyTree of | required |
| `batch_size` | `int` | Number of elements per emitted batch. | required |
Returns:
| Type | Description |
|---|---|
| `dict[str, Any]` | A dict containing the batched element spec under the original keys plus a |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If |
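The leading-dimension prepend and the added `valid_mask` leaf can be sketched without JAX. In the snippet below, `Spec` is a hypothetical stand-in for `jax.ShapeDtypeStruct` (just a shape and a dtype name), `batched_spec` is an illustrative helper rather than the library method, and the element spec is assumed to be a dict. No input validation is shown.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Spec:
    """Stand-in for jax.ShapeDtypeStruct: a shape plus a dtype name."""

    shape: tuple[int, ...]
    dtype: str


def batched_spec(element_spec: dict[str, Any], batch_size: int) -> dict[str, Any]:
    """Prepend (batch_size,) to every Spec leaf and add a valid_mask leaf."""

    def prepend(node: Any) -> Any:
        if isinstance(node, Spec):
            return Spec((batch_size, *node.shape), node.dtype)
        if isinstance(node, dict):
            return {key: prepend(value) for key, value in node.items()}
        return node  # non-spec leaves pass through unchanged

    out = prepend(element_spec)
    # One boolean per batch position; False would mark padded slots.
    out["valid_mask"] = Spec((batch_size,), "bool")
    return out


element_spec = {
    "image": Spec((28, 28, 1), "float32"),
    "label": Spec((), "int32"),
}
spec = batched_spec(element_spec, batch_size=32)
```

A scalar label spec becomes shape `(32,)`, the image spec becomes `(32, 28, 28, 1)`, and the mask sits alongside the original keys at the top level.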