Batch Mix Operator
Mix samples within batches (e.g., MixUp, CutMix augmentation).
See Also
- Operators Overview - All operator types
- Batching - Batching utilities
- Probabilistic Operator - Random augmentation
- NNX Best Practices - Training techniques
datarax.operators.batch_mix_operator
BatchMixOperator - MixUp and CutMix batch augmentation.
This module provides BatchMixOperator, which performs batch-level sample mixing that cannot be decomposed into element-level operations.
Key Difference from Other Operators:
- Standard operators use vmap to process elements independently
- BatchMixOperator overrides apply_batch() to access full batch
- Mixing requires cross-element access (sample A mixed with sample B)
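Here is a minimal JAX sketch of this difference. It does not use datarax, and the function names are illustrative. An element-level function can be mapped over the batch with jax.vmap because each call sees only one sample. Picking a mixing partner is different: it needs a permutation of the batch axis, which only code that receives the whole batch can compute.

```python
import jax
import jax.numpy as jnp


def scale_element(x):
    # Element-level op: depends on a single sample only, so jax.vmap can map it.
    return x * 2.0


def pair_with_partner(batch, key):
    # Batch-level op: each sample x_a needs a partner x_b from the same batch,
    # which requires indexing across the batch axis.
    perm = jax.random.permutation(key, batch.shape[0])
    return batch, batch[perm]


batch = jnp.arange(8.0).reshape(4, 2)        # 4 samples with 2 features each
scaled = jax.vmap(scale_element)(batch)      # per-sample, vmap-friendly
x_a, x_b = pair_with_partner(batch, jax.random.PRNGKey(0))
```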
Supported Modes:
- mixup: Linear interpolation between pairs of samples
- cutmix: Cut and paste rectangular patches between images
Key Features:
- Unified API for both MixUp and CutMix
- Beta distribution for mixing ratio (alpha parameter)
- Optional label mixing (weighted by the mixing ratio; for CutMix, by the pasted area)
- Compatible with JAX transformations (jit, grad)
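The alpha parameter supplies both shape parameters of Beta(α, α). The plain-JAX sketch below does not use datarax. It draws λ for alpha = 0.4 and for alpha = 1.0. Both distributions are symmetric around 0.5. The smaller alpha, however, puts more mass near 0 and 1, so more mixed samples stay close to one of their two sources. Beta(1, 1) is uniform.

```python
import jax
import jax.numpy as jnp

key_small, key_one = jax.random.split(jax.random.PRNGKey(0))

# 10,000 draws of lambda ~ Beta(alpha, alpha) for two alpha values.
lam_small = jax.random.beta(key_small, 0.4, 0.4, shape=(10_000,))
lam_one = jax.random.beta(key_one, 1.0, 1.0, shape=(10_000,))  # Beta(1, 1) is uniform on [0, 1]


def frac_near_edges(lam, eps=0.1):
    # Share of draws within eps of 0 or 1, i.e. samples dominated by one source.
    return jnp.mean((lam < eps) | (lam > 1.0 - eps))


print(frac_near_edges(lam_small), frac_near_edges(lam_one))
```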
BatchMixOperator
BatchMixOperator(config: BatchMixOperatorConfig, *, rngs: Rngs | None = None, name: str | None = None)
Bases: OperatorModule
Unified operator for batch-level MixUp and CutMix augmentation.
Performs batch-level sample mixing that requires access to multiple samples simultaneously. This operator overrides apply_batch() to work at the batch level instead of using vmap.
Modes:
MixUp Mode
Creates virtual training examples by linear interpolation: x_mixed = λ * x_a + (1 - λ) * x_b where λ ~ Beta(α, α)
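Below is a minimal JAX sketch of the MixUp formula above. It assumes inputs x of shape (B, ...) and one-hot labels y of shape (B, K). Three further choices are assumptions of the sketch and do not describe BatchMixOperator's internals: x_b is paired with x_a through a random permutation of the batch, a single λ is drawn per batch, and the labels are mixed with the same λ.

```python
import jax
import jax.numpy as jnp


def mixup(key, x, y, alpha=0.4):
    """MixUp sketch: x is (B, ...), y is one-hot (B, K)."""
    lam_key, perm_key = jax.random.split(key)
    lam = jax.random.beta(lam_key, alpha, alpha)          # lambda ~ Beta(alpha, alpha)
    perm = jax.random.permutation(perm_key, x.shape[0])   # x_b[i] = x[perm[i]]
    x_mixed = lam * x + (1.0 - lam) * x[perm]             # lambda * x_a + (1 - lambda) * x_b
    y_mixed = lam * y + (1.0 - lam) * y[perm]             # labels weighted by the same lambda
    return x_mixed, y_mixed, lam, perm


x = jax.random.normal(jax.random.PRNGKey(1), (8, 4, 4, 3))
y = jax.nn.one_hot(jnp.arange(8) % 3, 3)
x_mixed, y_mixed, lam, perm = jax.jit(mixup)(jax.random.PRNGKey(0), x, y)
```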
CutMix Mode
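Cuts a rectangular patch from one image and pastes it into another: x_mixed = mask * x_a + (1 - mask) * x_b, where mask is 1 where x_a is kept and 0 inside the rectangle taken from x_b. Labels are mixed in proportion to the area each image contributes.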
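Below is a corresponding CutMix sketch in plain JAX. It assumes (B, H, W, C) images and one-hot labels. The box-sampling scheme follows the common CutMix recipe: side lengths are scaled by sqrt(1 - λ) around a uniformly drawn center, then clipped to the image. This scheme is an assumption here. So is the choice to recompute λ from the clipped box, so that the labels match the pasted area exactly. The mask is built from jnp.arange comparisons, so every array shape is static and the function stays jit-compatible.

```python
import jax
import jax.numpy as jnp


def cutmix(key, x, y, alpha=1.0):
    """CutMix sketch: x is (B, H, W, C) images, y is one-hot (B, K)."""
    lam_key, perm_key, cy_key, cx_key = jax.random.split(key, 4)
    b, h, w = x.shape[0], x.shape[1], x.shape[2]
    lam = jax.random.beta(lam_key, alpha, alpha)
    perm = jax.random.permutation(perm_key, b)

    # Box of roughly (1 - lam) * H * W pixels around a random center, clipped to the image.
    cut_h = jnp.floor(h * jnp.sqrt(1.0 - lam))
    cut_w = jnp.floor(w * jnp.sqrt(1.0 - lam))
    cy = jax.random.randint(cy_key, (), 0, h)
    cx = jax.random.randint(cx_key, (), 0, w)
    y0, y1 = jnp.clip(cy - cut_h // 2, 0, h), jnp.clip(cy + cut_h // 2, 0, h)
    x0, x1 = jnp.clip(cx - cut_w // 2, 0, w), jnp.clip(cx + cut_w // 2, 0, w)

    # Static-shape mask: 1 keeps x_a, 0 marks the patch taken from x_b.
    rows = jnp.arange(h)[:, None]
    cols = jnp.arange(w)[None, :]
    inside = (rows >= y0) & (rows < y1) & (cols >= x0) & (cols < x1)
    mask = (1.0 - inside.astype(x.dtype))[None, :, :, None]

    x_mixed = mask * x + (1.0 - mask) * x[perm]
    lam_area = jnp.mean(mask)                              # fraction of pixels still from x_a
    y_mixed = lam_area * y + (1.0 - lam_area) * y[perm]    # labels mixed by area
    return x_mixed, y_mixed, lam_area, perm


# Image i is filled with the constant i, so each mixed image's mean reveals its area split.
x = jnp.broadcast_to(jnp.arange(4.0)[:, None, None, None], (4, 8, 8, 3))
y = jax.nn.one_hot(jnp.arange(4), 4)
x_mixed, y_mixed, lam_area, perm = jax.jit(cutmix)(jax.random.PRNGKey(0), x, y)
```

In this usage example, the mean of each mixed image equals its label-weighted class index. That gives a quick check that the labels follow the pasted area.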
Examples:
```python
# Assumes `rngs` (e.g. a flax.nnx.Rngs instance) and `batch` are already defined.

# MixUp augmentation
config = BatchMixOperatorConfig(mode="mixup", alpha=0.4)
op = BatchMixOperator(config, rngs=rngs)
mixed_batch = op(batch)

# CutMix augmentation
config = BatchMixOperatorConfig(mode="cutmix", alpha=1.0)
op = BatchMixOperator(config, rngs=rngs)
mixed_batch = op(batch)
```
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | BatchMixOperatorConfig | Operator configuration (mode, alpha, field names) | required |
| rngs | Rngs \| None | Random number generators. Required in practice because the operator is always stochastic, even though the signature defaults to None. | None |
| name | str \| None | Optional name for the operator | None |
generate_random_params
Generate random parameters - not used for batch-level ops.
BatchMixOperator overrides apply_batch() completely, so this method is not called. Implemented to satisfy the interface.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rng | Array | JAX random key | required |
| data_shapes | PyTree | PyTree with shapes | required |
Returns:
| Type | Description |
|---|---|
| Array | The input rng unchanged (not used) |
apply
apply(data: PyTree, state: PyTree, metadata: dict[str, Any] | None, random_params: Any = None, stats: dict[str, Any] | None = None) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply operator to single element - not used for batch-level ops.
BatchMixOperator overrides apply_batch() completely, so this method is not called. Batch mixing cannot be decomposed into element-level operations. Implemented to satisfy the interface.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | Element data PyTree | required |
| state | PyTree | Element state PyTree | required |
| metadata | dict[str, Any] \| None | Element metadata | required |
| random_params | Any | Unused | None |
| stats | dict[str, Any] \| None | Unused | None |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree, dict[str, Any] \| None] | Input unchanged (not used in practice) |
apply_batch
Apply batch-level mixing augmentation.
This method overrides the base class to work at batch level instead of using vmap. Batch mixing requires cross-element access that cannot be expressed with vmap.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Batch | Input batch to mix | required |
| stats | dict[str, Any] \| None | Optional statistics (unused) | None |
Returns:
| Type | Description |
|---|---|
| Batch | Mixed batch with the same structure as the input |
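The sketch below illustrates batch-level processing that preserves structure. It is written as a plain function, not as the datarax method. Several details are assumptions for illustration: the dict layout, the field names image and label, and the choice to pass every other field through untouched. This page does not show the config's actual field-name options.

```python
import jax
import jax.numpy as jnp


def mix_batch(batch, key, alpha=0.4, image_field="image", label_field="label"):
    """Hypothetical batch-level MixUp over a dict batch; keeps keys and shapes."""
    lam_key, perm_key = jax.random.split(key)
    lam = jax.random.beta(lam_key, alpha, alpha)
    perm = jax.random.permutation(perm_key, batch[image_field].shape[0])
    out = dict(batch)  # fields that are not mixed pass through unchanged
    for field in (image_field, label_field):
        out[field] = lam * batch[field] + (1.0 - lam) * batch[field][perm]
    return out


mix_batch_jit = jax.jit(mix_batch, static_argnames=("image_field", "label_field"))

batch = {
    "image": jnp.ones((4, 8, 8, 3)),
    "label": jax.nn.one_hot(jnp.array([0, 1, 2, 1]), 3),
    "id": jnp.arange(4),
}
mixed = mix_batch_jit(batch, jax.random.PRNGKey(0))
```

This sketch does not settle what should happen to unmixed per-sample fields such as id: keep them, permute them, or drop them. That is a design decision.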
get_operation_stats
reset_operation_stats
Reset operation statistics to zero.
Note: Creates new JAX arrays to reset the counters.
compute_statistics
Compute statistics from data using batch_stats_fn.
If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | Any | Input data to compute statistics from | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] \| None | Dictionary of statistics, or None if no batch_stats_fn configured |
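Below are a hypothetical batch_stats_fn and a small stand-in class that mimics the behavior described above. The class returns None when no function is configured; otherwise it computes the statistics and caches them. The function signature, the StatsHolder class, and the per-channel statistics are illustrative assumptions, not datarax code.

```python
import jax.numpy as jnp


def channel_stats(data):
    """Hypothetical batch_stats_fn: per-channel mean and std of an (N, H, W, C) array."""
    return {"mean": jnp.mean(data, axis=(0, 1, 2)), "std": jnp.std(data, axis=(0, 1, 2))}


class StatsHolder:
    """Stand-in mimicking the documented behavior; not the datarax class."""

    def __init__(self, batch_stats_fn=None):
        self.batch_stats_fn = batch_stats_fn
        self._computed_stats = None

    def compute_statistics(self, data):
        if self.batch_stats_fn is None:
            return None                                   # no function configured
        self._computed_stats = self.batch_stats_fn(data)  # cache the result
        return self._computed_stats


# Half the pixels are 0 and half are 2, so each channel has mean 1 and std 1.
data = jnp.stack([jnp.zeros((4, 4, 3)), 2.0 * jnp.ones((4, 4, 3))])
holder = StatsHolder(channel_stats)
stats = holder.compute_statistics(data)
```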
get_statistics
set_statistics
reset_statistics
Reset all statistics to None.
This clears the computed statistics and sets an internal flag so that precomputed_stats are ignored. After a reset, get_statistics() returns None until new statistics are set or computed.
copy
copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule
Create a copy of this module with optional config/parameter changes.
This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | DataraxModuleConfig \| None | New config (if None, uses current config) | None |
| rngs | Rngs \| None | New RNG state (if None, uses current rngs) | None |
| name | str \| None | New name (if None, uses current name) | None |
|
Returns:
| Type | Description |
|---|---|
| DataraxModule | New module instance with updated parameters |
Examples:
```python
# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)

# Change name only
renamed = module.copy(name="new_name")
```
Note
Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().
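With a dataclass-based config, a hyperparameter sweep usually builds each variant with dataclasses.replace(). Each variant could then be passed to copy(config=...). The Note suggests, but this page does not confirm, that configs such as BatchMixOperatorConfig are dataclasses. The sketch therefore uses a hypothetical stand-in, ToyMixConfig.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyMixConfig:
    # Hypothetical stand-in for a config dataclass; not datarax's BatchMixOperatorConfig.
    mode: str = "mixup"
    alpha: float = 0.4


base = ToyMixConfig()
# dataclasses.replace returns a new config with one field changed; base is left untouched.
sweep = [dataclasses.replace(base, alpha=a) for a in (0.2, 0.4, 1.0)]
```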
get_state
Get module state for checkpointing.
This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A dictionary containing the internal state of the component. |
set_state
Restore module state from a checkpoint.
This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | dict[str, Any] | A dictionary containing the internal state to restore. | required |
Raises:
| Type | Description |
|---|---|
| TypeError | If state is not a dictionary. |
| ValueError | If checkpoint structure does not match module state. |
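The following hypothetical sketch shows what strict restoration implies. It uses jax.tree.structure to compare the checkpoint's tree layout with the current state. check_restorable is illustrative and only mirrors the two documented error cases (TypeError and ValueError). The real method also restores the state into the module.

```python
import jax


def check_restorable(current_state, checkpoint):
    """Hypothetical strict check mirroring the documented set_state errors."""
    if not isinstance(checkpoint, dict):
        raise TypeError(f"state must be a dict, got {type(checkpoint).__name__}")
    if jax.tree.structure(checkpoint) != jax.tree.structure(current_state):
        raise ValueError("checkpoint structure does not match module state")


current = {"count": 0, "stats": {"mean": 0.0}}
check_restorable(current, {"count": 5, "stats": {"mean": 1.5}})  # same structure: passes
```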
clone
clone() -> DataraxModule
Create a new instance with the same state as this module.
Uses NNX's clone function for proper deep cloning of all state.
Returns:
| Type | Description |
|---|---|
| DataraxModule | A new module instance with the same state. |
requires_rng_streams
ensure_rng_streams
Ensure that the required RNG streams are available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stream_names | list[str] | A list of available RNG stream names. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If a required RNG stream is not available. |
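Here is a minimal sketch of the check. It assumes the check compares the module's required stream names (presumably what requires_rng_streams reports) against stream_names. Both the function and the stream name "augment" are hypothetical.

```python
def ensure_rng_streams(required, stream_names):
    """Hypothetical check: raise ValueError if any required stream is missing."""
    missing = [name for name in required if name not in stream_names]
    if missing:
        raise ValueError(f"missing required RNG streams: {missing}")


ensure_rng_streams(["augment"], ["default", "augment"])  # all required streams present
```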
get_output_structure
get_output_structure(sample_data: PyTree, sample_state: PyTree) -> tuple[PyTree, PyTree]
Declare output PyTree structure for vmap axis specification.
Default uses jax.eval_shape to discover structure automatically. Override for efficiency or when eval_shape doesn't work (e.g., data-dependent shapes).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sample_data | PyTree | Single element data (not batched) | required |
| sample_state | PyTree | Single element state (not batched) | required |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree] | Tuple of (output_data_structure, output_state_structure) with None leaves. The structure (keys/nesting) matters; leaf values are ignored. |
Example override for an operator that adds keys:
```python
def get_output_structure(self, sample_data, sample_state):
    out_data = {
        **jax.tree.map(lambda _: None, sample_data),
        "score": None,
        "alignment": None,
    }
    return out_data, sample_state
```
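The default behavior can be approximated with jax.eval_shape. This function traces a function on abstract inputs without running the computation. Mapping every leaf of the result to None then leaves only the structure. In the sketch, per_element is a hypothetical operator body that adds a score key.

```python
import jax
import jax.numpy as jnp


def per_element(data, state):
    # Hypothetical element-level operator that adds a "score" key to the data.
    return {**data, "score": jnp.sum(data["x"])}, state


sample_data = {"x": jnp.ones((3,))}
sample_state = {"seen": jnp.zeros(())}

# eval_shape returns ShapeDtypeStruct leaves without doing the real computation.
abstract_out = jax.eval_shape(per_element, sample_data, sample_state)
out_data_structure, out_state_structure = jax.tree.map(lambda _: None, abstract_out)
```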
output_spec
Return the operator's output spec given an input spec.
Most operators (normalization, additive noise, simple element-wise transforms) do not change shape; the default returns input_spec unchanged. Shape-changing operators (Resize, Crop, Reshape) MUST override this method.
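For a shape-changing operator, an override typically rewrites the shape of each leaf. The sketch below assumes that specs are jax.ShapeDtypeStruct leaves, since the element type is cut off in the tables on this page. It uses a hypothetical resize to (height, width). BatchMixOperator mixes values without changing shapes, so it presumably relies on the default.

```python
import jax
import jax.numpy as jnp


def resize_output_spec(input_spec, height, width):
    """Hypothetical output_spec for a resize op on (H, W, C) images."""

    def resize_leaf(spec):
        # Replace the two spatial dims; keep the trailing dims and the dtype.
        return jax.ShapeDtypeStruct((height, width) + tuple(spec.shape[2:]), spec.dtype)

    return jax.tree.map(resize_leaf, input_spec)


input_spec = {"image": jax.ShapeDtypeStruct((64, 64, 3), jnp.float32)}
output_spec = resize_output_spec(input_spec, 32, 32)
```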
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_spec | PyTree | PyTree of | required |
Returns:
| Type | Description |
|---|---|
| PyTree | PyTree of |
| PyTree | By default, equal to input_spec. |