Batcher Protocol
Core protocol for batch creation.
See Also
- Core Overview - All core protocols
- Batching - Batcher implementations
- DAG Executor - Batching in pipelines
- Simple Pipeline Example
datarax.core.batcher
Base module for batcher components in Datarax.
This module defines the base class for all Datarax batcher components that use flax.nnx.Module for state management and JAX transformation compatibility.
BatcherModule
BatcherModule(config: StructuralConfig, *, rngs: Rngs | None = None, name: str | None = None)
Bases: StructuralModule
Base module for all Datarax batcher components.
A BatcherModule is responsible for grouping individual data elements into batches. It handles the accumulation and collation of elements, maintaining the PyTree structure in the batched output.
This class extends StructuralModule for non-parametric structural processing. Subclasses implement the process() method for batching logic.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `StructuralConfig` | StructuralConfig or subclass with batcher-specific parameters | required |
| `rngs` | `Rngs \| None` | Random number generators (required if stochastic=True) | `None` |
| `name` | `str \| None` | Optional name for the batcher | `None` |
Examples:
Basic Batcher implementation:
```python
from dataclasses import dataclass
from datarax.core.config import StructuralConfig
from datarax.core.batcher import BatcherModule
from flax import nnx


class DefaultBatcherConfig(StructuralConfig):
    pass


class DefaultBatcher(BatcherModule):
    def process(self, elements, batch_size, drop_remainder=False):
        # In a real implementation, this would yield actual batches
        return []


config = DefaultBatcherConfig(stochastic=False)
batcher = DefaultBatcher(config, rngs=nnx.Rngs(0))
batches = list(batcher([], batch_size=32))  # Call the batcher instance
```
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `StructuralConfig` | Structural module configuration (already validated, frozen) | required |
| `rngs` | `Rngs \| None` | Random number generators (required if stochastic=True) | `None` |
| `name` | `str \| None` | Optional module name | `None` |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If stochastic=True but rngs is None |
process
process(elements: list[Element] | Iterator[Element], *args: Any, batch_size: int, drop_remainder: bool = False, **kwargs: Any) -> list[Batch] | Iterator[Batch]
Group individual data elements into batches.
This is the main interface for batching operations. Subclasses MUST override this method to implement their batching logic.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `elements` | `list[Element] \| Iterator[Element]` | An iterator or list yielding individual data elements. | required |
| `*args` | `Any` | Additional positional arguments (processor-specific). | `()` |
| `batch_size` | `int` | The number of elements to include in each batch. | required |
| `drop_remainder` | `bool` | Whether to drop the last batch if it's smaller than batch_size. | `False` |
| `**kwargs` | `Any` | Additional keyword arguments (processor-specific). | `{}` |
Returns:
| Type | Description |
|---|---|
| `list[Batch] \| Iterator[Batch]` | An iterator or list that yields batches of data elements. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If batch_size is not positive. |
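The contract described above (fixed-size groups, an optionally dropped final partial batch, and a ValueError for a non-positive batch_size) can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the Datarax implementation: the function name batch_elements is hypothetical, and it yields plain lists rather than Batch objects.

```python
from collections.abc import Iterable, Iterator
from itertools import islice
from typing import Any


def batch_elements(
    elements: Iterable[Any], *, batch_size: int, drop_remainder: bool = False
) -> Iterator[list[Any]]:
    """Hypothetical stand-in for a process() body; yields plain lists."""
    # Because this is a generator, validation runs when iteration starts.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    iterator = iter(elements)
    while True:
        # Pull up to batch_size elements without materialising the whole input.
        chunk = list(islice(iterator, batch_size))
        if not chunk:
            return
        if len(chunk) < batch_size and drop_remainder:
            return  # discard the short final batch
        yield chunk


full = list(batch_elements(range(5), batch_size=2))
trimmed = list(batch_elements(range(5), batch_size=2, drop_remainder=True))
```

Here full keeps the short final batch, while trimmed stops after the last complete one.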
batch_spec
Return the batched-output spec given a per-element spec and batch_size.
The default implementation prepends a leading (batch_size,) dimension
to every jax.ShapeDtypeStruct leaf of element_spec and adds a
top-level valid_mask leaf of shape (batch_size,) and dtype bool.
The mask flags valid positions so end-of-epoch padding does not
contribute to mask-weighted loss.
Subclasses (e.g., MultiRateBatcher) override only when the batch
layout requires more than a simple leading-dim prepend.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `element_spec` | `Any` | PyTree of | required |
| `batch_size` | `int` | Number of elements per emitted batch. | required |
Returns:
| Type | Description |
|---|---|
| `dict[str, Any]` | A dict containing the batched element spec under the original keys plus a |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If |
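The default behaviour described for batch_spec (prepend a leading batch dimension to every leaf and add a top-level boolean valid_mask) can be illustrated without JAX. The sketch below is an assumption-laden stand-in: LeafSpec is a hypothetical replacement for jax.ShapeDtypeStruct, sketch_batch_spec is a hypothetical name, the element spec is assumed to be a (possibly nested) dict, and the error condition of the real method is not modelled.

```python
from typing import Any, NamedTuple


class LeafSpec(NamedTuple):
    """Hypothetical stand-in for jax.ShapeDtypeStruct (shape and dtype only)."""

    shape: tuple[int, ...]
    dtype: str


def sketch_batch_spec(element_spec: dict[str, Any], batch_size: int) -> dict[str, Any]:
    """Prepend (batch_size,) to every leaf spec and add a boolean valid_mask leaf."""

    def prepend(node: Any) -> Any:
        if isinstance(node, LeafSpec):
            return LeafSpec((batch_size, *node.shape), node.dtype)
        if isinstance(node, dict):
            return {key: prepend(value) for key, value in node.items()}
        return node  # assumption: anything else passes through unchanged

    batched = prepend(element_spec)
    # One flag per batch position; padded positions would be marked False.
    batched["valid_mask"] = LeafSpec((batch_size,), "bool")
    return batched


element_spec = {"image": LeafSpec((28, 28), "float32"), "label": LeafSpec((), "int32")}
batched_spec = sketch_batch_spec(element_spec, batch_size=32)
```

The input spec is left untouched; only the returned dict carries the extra valid_mask entry.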
get_operation_stats
reset_operation_stats
Reset operation statistics to zero.
Note: Creates new JAX arrays to reset the counters.
compute_statistics
Compute statistics from data using batch_stats_fn.
If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `Any` | Input data to compute statistics from | required |
Returns:
| Type | Description |
|---|---|
| `dict[str, Any] \| None` | Dictionary of statistics, or None if no batch_stats_fn configured |
get_statistics
set_statistics
reset_statistics
Reset all statistics to None.
This clears the computed statistics and sets an internal flag so that precomputed_stats is ignored. After reset, get_statistics() returns None until new statistics are set or computed.
copy
copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule
Create a copy of this module with optional config/parameter changes.
This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `DataraxModuleConfig \| None` | New config (if None, uses current config) | `None` |
| `rngs` | `Rngs \| None` | New RNG state (if None, uses current rngs) | `None` |
| `name` | `str \| None` | New name (if None, uses current name) | `None` |
Returns:
| Type | Description |
|---|---|
| `DataraxModule` | New module instance with updated parameters |
Examples:
```python
# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)

# Change name only
renamed = module.copy(name="new_name")
```
Note
Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().
get_state
Get module state for checkpointing.
This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.
Returns:
| Type | Description |
|---|---|
| `dict[str, Any]` | A dictionary containing the internal state of the component. |
set_state
Restore module state from a checkpoint.
This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `dict[str, Any]` | A dictionary containing the internal state to restore. | required |
Raises:
| Type | Description |
|---|---|
| `TypeError` | If state is not a dictionary. |
| `ValueError` | If checkpoint structure does not match module state. |
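The strict restoration rules above (a TypeError for a non-dict state, a ValueError for a structural mismatch) can be illustrated with nested dictionaries. This is a hedged sketch: restore_state and _same_structure are hypothetical helpers, and the real method works on NNX state rather than plain dicts.

```python
import copy
from typing import Any


def _same_structure(current: Any, incoming: Any) -> bool:
    """Compare nesting and keys only; leaf values are allowed to differ."""
    if isinstance(current, dict) != isinstance(incoming, dict):
        return False
    if not isinstance(current, dict):
        return True
    return current.keys() == incoming.keys() and all(
        _same_structure(current[key], incoming[key]) for key in current
    )


def restore_state(current: dict[str, Any], state: Any) -> dict[str, Any]:
    """Hypothetical strict restore: validate first, then adopt the new values."""
    if not isinstance(state, dict):
        raise TypeError(f"state must be a dictionary, got {type(state).__name__}")
    if not _same_structure(current, state):
        raise ValueError("checkpoint structure does not match module state")
    return copy.deepcopy(state)


module_state = {"counters": {"batches": 0}, "rng": 1}
restored = restore_state(module_state, {"counters": {"batches": 7}, "rng": 5})
```

Validating before copying means a rejected checkpoint leaves the current state untouched.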
clone
clone() -> DataraxModule
Create a new instance with the same state as this module.
Uses NNX's clone function for proper deep cloning of all state.
Returns:
| Type | Description |
|---|---|
| `DataraxModule` | A new module instance with the same state. |
requires_rng_streams
ensure_rng_streams
Ensure that the required RNG streams are available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `stream_names` | `list[str]` | A list of available RNG stream names. | required |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If a required RNG stream is not available. |
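The check performed by ensure_rng_streams amounts to comparing the streams a module needs against those that are available. The sketch below assumes the required names are passed in explicitly; check_rng_streams is a hypothetical helper, and how the real method obtains the required names is not stated on this page.

```python
def check_rng_streams(required: list[str], available: list[str]) -> None:
    """Hypothetical check: raise if any required stream name is unavailable."""
    missing = [name for name in required if name not in available]
    if missing:
        raise ValueError(f"Missing required RNG streams: {missing}")


# Passes silently: every required stream is available.
check_rng_streams(required=["shuffle"], available=["shuffle", "augment"])
```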