
Batch Mix Operator¤

Mix samples within batches (e.g., MixUp, CutMix augmentation).

datarax.operators.batch_mix_operator ¤

BatchMixOperator - MixUp and CutMix batch augmentation.

This module provides BatchMixOperator, which performs batch-level sample mixing that cannot be decomposed into element-level operations.

Key Difference from Other Operators:

  • Standard operators use vmap to process elements independently
  • BatchMixOperator overrides apply_batch() to access the full batch
  • Mixing requires cross-element access (sample A mixed with sample B)
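The difference can be illustrated with plain JAX. This is a minimal sketch rather than datarax code: `scale_element` and `mix_with_reversed` are hypothetical helpers, and pairing each sample with the reversed batch is an arbitrary choice made only to keep the example deterministic.

```python
import jax
import jax.numpy as jnp


def scale_element(x):
    # Element-level transform: sees one sample at a time.
    return x * 2.0


def mix_with_reversed(batch, lam):
    # Batch-level transform: each sample is blended with another sample
    # from the same batch (here, the batch in reversed order).
    return lam * batch + (1.0 - lam) * batch[::-1]


batch = jnp.arange(6.0).reshape(3, 2)  # 3 samples, 2 features each

# vmap maps the element-level function over axis 0 independently.
scaled = jax.vmap(scale_element)(batch)

# Mixing needs the whole batch at once, so it is applied without vmap.
mixed = mix_with_reversed(batch, 0.5)
```

Inside `scale_element` there is no way to reach a neighbouring sample, which is why a mixing operator has to work on the batch as a whole.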

Supported Modes:

  • mixup: Linear interpolation between pairs of samples
  • cutmix: Cut and paste rectangular patches between images

Key Features:

  • Unified API for both MixUp and CutMix
  • Beta distribution for mixing ratio (alpha parameter)
  • Optional label mixing (proportional to mixed area)
  • Full JAX compatibility (JIT, grad)

logger module-attribute ¤

logger = getLogger(__name__)

BatchMixOperator ¤

BatchMixOperator(config: BatchMixOperatorConfig, *, rngs: Rngs | None = None, name: str | None = None)

Bases: OperatorModule

Unified operator for batch-level MixUp and CutMix augmentation.

Performs batch-level sample mixing that requires access to multiple samples simultaneously. This operator overrides apply_batch() to work at the batch level instead of using vmap.

Modes:

MixUp Mode

Creates virtual training examples by linear interpolation:

x_mixed = λ * x_a + (1 - λ) * x_b

where λ ~ Beta(α, α)
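The MixUp formula can be sketched in plain JAX as follows. The function `mixup` is a hypothetical illustration, not the operator's actual implementation: it assumes a single λ shared by the whole batch and a random permutation to choose each sample's partner, neither of which is confirmed by this page.

```python
import jax
import jax.numpy as jnp


def mixup(key, images, labels, alpha=0.4):
    """Blend each sample with a randomly chosen partner from the batch."""
    lam_key, perm_key = jax.random.split(key)
    # Assumption: one mixing ratio for the whole batch, from Beta(alpha, alpha).
    lam = jax.random.beta(lam_key, alpha, alpha)
    # Assumption: sample i is mixed with sample perm[i].
    perm = jax.random.permutation(perm_key, images.shape[0])
    mixed_images = lam * images + (1.0 - lam) * images[perm]
    mixed_labels = lam * labels + (1.0 - lam) * labels[perm]
    return mixed_images, mixed_labels, lam


key = jax.random.PRNGKey(0)
images = jnp.ones((4, 8, 8, 3))
labels = jax.nn.one_hot(jnp.arange(4), 4)
mixed_images, mixed_labels, lam = mixup(key, images, labels)
```

Because the labels are mixed with the same λ as the inputs, each mixed label row still sums to 1.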

CutMix Mode

Cuts rectangular patches and pastes between images:

x_mixed = mask * x_a + (1 - mask) * x_b

Labels are mixed proportionally to the cut area.
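A CutMix step along these lines might look like the sketch below. Everything in it is an assumption made for illustration: the hypothetical `cutmix` function, the NHWC image layout, one shared patch per batch, and a patch whose side scales with sqrt(1 - λ). The real operator may choose patches differently.

```python
import jax
import jax.numpy as jnp


def cutmix(key, images, labels, alpha=1.0):
    """Paste one rectangular patch from a partner image into each image."""
    lam_key, perm_key, cy_key, cx_key = jax.random.split(key, 4)
    n, h, w, _ = images.shape  # assumes NHWC layout
    lam = jax.random.beta(lam_key, alpha, alpha)
    perm = jax.random.permutation(perm_key, n)

    # The patch covers roughly a (1 - lam) fraction of the image area.
    cut_ratio = jnp.sqrt(1.0 - lam)
    cut_h, cut_w = cut_ratio * h, cut_ratio * w
    cy = jax.random.uniform(cy_key, minval=0.0, maxval=h)
    cx = jax.random.uniform(cx_key, minval=0.0, maxval=w)

    # Build the mask with comparisons so array shapes stay static under jit.
    rows = jnp.arange(h)[:, None]
    cols = jnp.arange(w)[None, :]
    inside = (
        (rows >= cy - cut_h / 2) & (rows < cy + cut_h / 2)
        & (cols >= cx - cut_w / 2) & (cols < cx + cut_w / 2)
    )
    mask = 1.0 - inside.astype(images.dtype)  # 1 keeps x_a, 0 takes x_b
    mask4 = mask[None, :, :, None]

    mixed_images = mask4 * images + (1.0 - mask4) * images[perm]
    # Weight labels by the area actually kept (the patch may be clipped).
    kept = mask.mean()
    mixed_labels = kept * labels + (1.0 - kept) * labels[perm]
    return mixed_images, mixed_labels


key = jax.random.PRNGKey(0)
images = jnp.arange(4.0).reshape(4, 1, 1, 1) * jnp.ones((4, 8, 8, 3))
labels = jax.nn.one_hot(jnp.arange(4), 4)
mixed_images, mixed_labels = cutmix(key, images, labels)
```

The label weight is recomputed from the mask rather than taken from λ directly, so it stays proportional to the pasted area even when the patch is clipped at the image border.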

Examples:

# `rngs` (an Rngs instance) and `batch` (a Batch) are assumed to exist already.

# MixUp augmentation
config = BatchMixOperatorConfig(mode="mixup", alpha=0.4)
op = BatchMixOperator(config, rngs=rngs)
mixed_batch = op(batch)

# CutMix augmentation
config = BatchMixOperatorConfig(mode="cutmix", alpha=1.0)
op = BatchMixOperator(config, rngs=rngs)
mixed_batch = op(batch)

Parameters:

Name    Type                    Description                                              Default
config  BatchMixOperatorConfig  Operator configuration (mode, alpha, field names)        required
rngs    Rngs | None             Random number generators (required - always stochastic)  None
name    str | None              Optional name for the operator                           None

config instance-attribute ¤

config: BatchMixOperatorConfig = config

rngs instance-attribute ¤

rngs = rngs

name instance-attribute ¤

name = static(name)

stochastic instance-attribute ¤

stochastic = static(stochastic)

stream_name instance-attribute ¤

stream_name = static(stream_name)

generate_random_params ¤

generate_random_params(rng: Array, data_shapes: PyTree) -> Array

Generate random parameters - not used for batch-level ops.

BatchMixOperator overrides apply_batch() completely, so this method is not called. Implemented to satisfy the interface.

Parameters:

Name         Type    Description         Default
rng          Array   JAX random key      required
data_shapes  PyTree  PyTree with shapes  required

Returns:

Type   Description
Array  The input rng unchanged (not used)

apply ¤

apply(data: PyTree, state: PyTree, metadata: dict[str, Any] | None, random_params: Any = None, stats: dict[str, Any] | None = None) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply operator to single element - not used for batch-level ops.

BatchMixOperator overrides apply_batch() completely, so this method is not called. Batch mixing cannot be decomposed into element-level operations. Implemented to satisfy the interface.

Parameters:

Name           Type                   Description           Default
data           PyTree                 Element data PyTree   required
state          PyTree                 Element state PyTree  required
metadata       dict[str, Any] | None  Element metadata      required
random_params  Any                    Unused                None
stats          dict[str, Any] | None  Unused                None

Returns:

Type                                          Description
tuple[PyTree, PyTree, dict[str, Any] | None]  Input unchanged (not used in practice)

apply_batch ¤

apply_batch(batch: Batch, stats: dict[str, Any] | None = None) -> Batch

Apply batch-level mixing augmentation.

This method overrides the base class to work at batch level instead of using vmap. Batch mixing requires cross-element access that cannot be expressed with vmap.

Parameters:

Name   Type                   Description                   Default
batch  Batch                  Input batch to mix            required
stats  dict[str, Any] | None  Optional statistics (unused)  None

Returns:

Type   Description
Batch  Mixed batch with same structure

get_operation_stats ¤

get_operation_stats() -> dict[str, int]

Get operation statistics.

Note: This method converts JAX arrays to Python ints for introspection. It is intended for use outside of JIT-compiled functions.
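The reason for this restriction can be shown with a small JAX sketch. The counter and the `step` function are hypothetical stand-ins rather than datarax internals. Inside a jitted function the counter is a traced value, so it can only be converted to a Python int after the function has returned.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(applied_count, x):
    # Inside jit the counter is a tracer; calling int() on it here would fail.
    return applied_count + 1, x * 2.0


applied_count = jnp.zeros((), dtype=jnp.int32)
for _ in range(3):
    applied_count, _ = step(applied_count, jnp.ones(2))

# Outside jit the array holds a concrete value, so int() is safe.
stats = {"applied_count": int(applied_count), "skipped_count": 0}
```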

Returns:

Type            Description
dict[str, int]  Dictionary with 'applied_count' and 'skipped_count'

reset_operation_stats ¤

reset_operation_stats() -> None

Reset operation statistics to zero.

Note: Creates new JAX arrays to reset the counters.

compute_statistics ¤

compute_statistics(data: Any) -> dict[str, Any] | None

Compute statistics from data using batch_stats_fn.

If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.

Parameters:

Name  Type  Description                            Default
data  Any   Input data to compute statistics from  required

Returns:

Type                   Description
dict[str, Any] | None  Dictionary of statistics, or None if no batch_stats_fn configured

get_statistics ¤

get_statistics() -> dict[str, Any] | None

Get current statistics.

Returns precomputed_stats if configured (unless reset was called); otherwise returns the cached computed statistics, or None if no statistics are available.

Returns:

Type                   Description
dict[str, Any] | None  Dictionary of statistics, or None if no statistics available

set_statistics ¤

set_statistics(stats: dict[str, Any]) -> None

Manually set statistics.

This overwrites any previously computed statistics and clears the reset flag.

Parameters:

Name   Type            Description                      Default
stats  dict[str, Any]  Dictionary of statistics to set  required

reset_statistics ¤

reset_statistics() -> None

Reset all statistics to None.

This clears both computed statistics and marks that precomputed_stats should be ignored (via internal flag). After reset, get_statistics() will return None until new statistics are set or computed.

reset_cache ¤

reset_cache() -> None

Clear the cache.

Only has an effect if cacheable=True in the config.

copy ¤

copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule

Create a copy of this module with optional config/parameter changes.

This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.

Parameters:

Name    Type                        Description                                 Default
config  DataraxModuleConfig | None  New config (if None, uses current config)   None
rngs    Rngs | None                 New RNG state (if None, uses current rngs)  None
name    str | None                  New name (if None, uses current name)       None

Returns:

Type           Description
DataraxModule  New module instance with updated parameters

Examples:

# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)

# Change name only
renamed = module.copy(name="new_name")

Note

Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().

get_state ¤

get_state() -> dict[str, Any]

Get module state for checkpointing.

This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.

Returns:

Type            Description
dict[str, Any]  A dictionary containing the internal state of the component.

set_state ¤

set_state(state: dict[str, Any]) -> None

Restore module state from a checkpoint.

This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.

Parameters:

Name   Type            Description                                             Default
state  dict[str, Any]  A dictionary containing the internal state to restore.  required

Raises:

Type        Description
TypeError   If state is not a dictionary.
ValueError  If checkpoint structure does not match module state.
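As an illustration of what strict restoration means, the sketch below compares PyTree structures before accepting a checkpoint. The helper `restore_strict` is hypothetical. The actual method relies on NNX state management, whose internals are not described on this page.

```python
import jax
import jax.numpy as jnp


def restore_strict(current_state, checkpoint):
    """Return `checkpoint` only if its PyTree structure matches `current_state`."""
    if not isinstance(checkpoint, dict):
        raise TypeError(f"state must be a dict, got {type(checkpoint).__name__}")
    expected = jax.tree_util.tree_structure(current_state)
    actual = jax.tree_util.tree_structure(checkpoint)
    if expected != actual:
        raise ValueError(f"checkpoint structure {actual} does not match {expected}")
    return checkpoint


current = {"applied_count": jnp.array(0), "skipped_count": jnp.array(0)}
saved = {"applied_count": jnp.array(5), "skipped_count": jnp.array(2)}
restored = restore_strict(current, saved)
```

A checkpoint with a missing or extra key has a different tree structure, so it is rejected with a ValueError instead of being partially applied.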

clone ¤

clone() -> DataraxModule

Create a new instance with the same state as this module.

Uses NNX's clone function for proper deep cloning of all state.

Returns:

Type           Description
DataraxModule  A new module instance with the same state.

requires_rng_streams ¤

requires_rng_streams() -> list[str] | None

Get the list of RNG streams required by this module.

Returns:

Type              Description
list[str] | None  A list of required RNG stream names, or None if no RNG streams are required.

ensure_rng_streams ¤

ensure_rng_streams(stream_names: list[str]) -> None

Ensure that the required RNG streams are available.

Parameters:

Name          Type       Description                            Default
stream_names  list[str]  A list of available RNG stream names.  required

Raises:

Type        Description
ValueError  If a required RNG stream is not available.
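The check described here amounts to comparing the required stream names against the available ones. A minimal sketch follows, with a hypothetical free function `check_rng_streams` and made-up stream names standing in for the method and its module state.

```python
def check_rng_streams(required, available):
    """Raise ValueError if any required stream name is missing from `available`."""
    # `required` may be None when the module needs no RNG streams.
    missing = [name for name in (required or []) if name not in available]
    if missing:
        raise ValueError(f"Missing required RNG streams: {missing}")


# Passes: every required stream is available.
check_rng_streams(["augment"], ["augment", "shuffle"])
# Passes: nothing is required.
check_rng_streams(None, [])
```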

get_output_structure ¤

get_output_structure(sample_data: PyTree, sample_state: PyTree) -> tuple[PyTree, PyTree]

Declare output PyTree structure for vmap axis specification.

Default uses jax.eval_shape to discover structure automatically. Override for efficiency or when eval_shape doesn't work (e.g., data-dependent shapes).

Parameters:

Name          Type    Description                         Default
sample_data   PyTree  Single element data (not batched)   required
sample_state  PyTree  Single element state (not batched)  required

Returns:

Type                   Description
tuple[PyTree, PyTree]  Tuple of (output_data_structure, output_state_structure) with None leaves. The structure (keys/nesting) matters; leaf values are ignored.

Example override for operator that adds keys

def get_output_structure(self, sample_data, sample_state):
    out_data = {
        **jax.tree.map(lambda _: None, sample_data),
        "score": None,
        "alignment": None,
    }
    return out_data, sample_state
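The automatic discovery that the default performs with jax.eval_shape can be approximated as in the sketch below. The element function `add_score` and the sample data are hypothetical, and the real default may differ in detail.

```python
import jax
import jax.numpy as jnp


def add_score(data, state):
    # Hypothetical element-level function that adds a new "score" key.
    out = {**data, "score": jnp.sum(data["image"])}
    return out, state


sample_data = {"image": jnp.zeros((8, 8, 3))}
sample_state = {}

# eval_shape traces the function abstractly: no real computation runs.
out_shapes = jax.eval_shape(add_score, sample_data, sample_state)

# Replace every leaf with None so only keys and nesting remain.
data_struct, state_struct = jax.tree_util.tree_map(lambda _: None, out_shapes)
```

Here `data_struct` contains the keys "image" and "score" with None values, which is the same kind of structure the manual override above builds by hand.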

output_spec ¤

output_spec(input_spec: PyTree) -> PyTree

Return the operator's output spec given an input spec.

Most operators (normalization, additive noise, simple element-wise transforms) do not change shape; the default returns input_spec unchanged. Shape-changing operators (Resize, Crop, Reshape) MUST override this method.

Parameters:

Name        Type    Description  Default
input_spec  PyTree  PyTree of jax.ShapeDtypeStruct describing the input element (matching the upstream DataSourceModule.element_spec() or another operator's output_spec).  required

Returns:

Type    Description
PyTree  PyTree of jax.ShapeDtypeStruct describing the operator's output. By default, equal to input_spec.
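For a shape-changing operator, an override could look roughly like the sketch below. It is written as a standalone hypothetical function, `resize_output_spec`, and assumes a dict element with an "image" leaf in HWC layout. A real operator would implement the same logic in its output_spec method.

```python
import jax
import jax.numpy as jnp


def resize_output_spec(input_spec, size=(32, 32)):
    """Hypothetical output_spec for an operator that resizes the "image" leaf."""
    image = input_spec["image"]
    channels = image.shape[-1]  # assumes HWC layout
    return {
        **input_spec,
        # Only the spatial dimensions change; dtype and channels are kept.
        "image": jax.ShapeDtypeStruct((*size, channels), image.dtype),
    }


input_spec = {
    "image": jax.ShapeDtypeStruct((64, 64, 3), jnp.float32),
    "label": jax.ShapeDtypeStruct((), jnp.int32),
}
output_spec = resize_output_spec(input_spec)
```

Leaves that the operator does not touch, such as "label" here, are passed through unchanged.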