
Composite Operator

The CompositeOperatorModule enables composing multiple operators into sophisticated pipelines using 11 different composition strategies. It's the foundation for building complex data augmentation and transformation workflows.

Composition Strategies

| Strategy | Description |
| --- | --- |
| Sequential | Chain operators: output of one → input of next |
| Conditional Sequential | Chain with per-operator conditions |
| Dynamic Sequential | Runtime-modifiable chain |
| Parallel | Apply all operators to the same input, merge outputs |
| Weighted Parallel | Parallel with a weighted sum (static, learnable, or per-call external weights) |
| Conditional Parallel | Parallel with per-operator conditions |
| Ensemble Mean/Sum/Max/Min | Parallel + reduction |
| Branching | Route through different paths based on input |

★ Insight ─────────────────────────────────────

  • CompositeOperator uses JAX-compatible patterns throughout
  • Integer-based branching with jax.lax.switch (not dict lookups)
  • Fixed-shape outputs for vmap compatibility
  • All strategies work inside jax.jit and jax.vmap

─────────────────────────────────────────────────

Quick Start

Sequential Composition

Chain operators where each output feeds into the next:

from datarax.operators import CompositeOperatorModule
from datarax.operators.composite_operator import (
    CompositeOperatorConfig,
    CompositionStrategy,
)

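# Create child operators (create_normalize_op and create_augment_op are
# placeholders for factories that return your own OperatorModule instances)
normalize = create_normalize_op()
augment = create_augment_op()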

config = CompositeOperatorConfig(
    strategy=CompositionStrategy.SEQUENTIAL,
    operators=[normalize, augment],
)
pipeline = CompositeOperatorModule(config)
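Conceptually, SEQUENTIAL behaves like left-to-right function composition. The following is a minimal plain-JAX sketch of that idea, not datarax's implementation; normalize_fn and augment_fn are hypothetical stand-ins for real operators:

```python
import functools

import jax
import jax.numpy as jnp


def sequential(fns):
    """Compose functions left to right: each output feeds the next function."""
    def composed(x):
        return functools.reduce(lambda acc, f: f(acc), fns, x)
    return composed


# Hypothetical stand-ins for the normalize and augment operators.
def normalize_fn(x):
    return (x - x.mean()) / (x.std() + 1e-6)


def augment_fn(x):
    return jnp.flip(x, axis=-1)


pipeline_fn = jax.jit(sequential([normalize_fn, augment_fn]))
x = jnp.array([1.0, 2.0, 3.0])
out = pipeline_fn(x)  # same as augment_fn(normalize_fn(x))
```

Because the composed function is pure JAX, the whole chain can be wrapped in a single jax.jit.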

Parallel Composition

Apply multiple operators to the same input and merge results:

config = CompositeOperatorConfig(
    strategy=CompositionStrategy.PARALLEL,
    operators=[op_a, op_b, op_c],
    merge_strategy="concat",  # or "stack", "sum", "mean", "dict"
    merge_axis=-1,
)
parallel_op = CompositeOperatorModule(config)
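The merge strategies can be pictured with plain jax.numpy calls. This sketch assumes every operator returns an array of the same shape; merge_outputs and dict_merge are illustrative helpers rather than datarax functions, and the "op_0"/"op_1" key names are an assumption modeled on the {key: {op_0: val, op_1: val}} dict-merge layout this module documents:

```python
import jax
import jax.numpy as jnp


def merge_outputs(outputs, strategy, axis=0):
    """Merge same-shaped operator outputs (illustrative, not datarax's code)."""
    if strategy == "concat":
        return jnp.concatenate(outputs, axis=axis)
    if strategy == "stack":
        return jnp.stack(outputs, axis=axis)
    if strategy == "sum":
        return jnp.sum(jnp.stack(outputs), axis=0)
    if strategy == "mean":
        return jnp.mean(jnp.stack(outputs), axis=0)
    raise ValueError(f"unknown merge strategy: {strategy!r}")


def dict_merge(outputs):
    """Combine a list of PyTrees into one PyTree whose leaves are operator dicts.

    The result keeps the input structure: {key: {"op_0": ..., "op_1": ...}}.
    """
    def leaves_to_op_dict(*leaves):
        return {f"op_{i}": leaf for i, leaf in enumerate(leaves)}

    return jax.tree.map(leaves_to_op_dict, *outputs)


a = jnp.ones((2, 3))
b = jnp.zeros((2, 3))
concat = merge_outputs([a, b], "concat", axis=-1)  # shape (2, 6)
stacked = merge_outputs([a, b], "stack", axis=0)   # shape (2, 2, 3)
merged = dict_merge([{"image": a}, {"image": b}])
```

Mapping over the leaves (rather than building {op_i: output} directly) keeps the top-level keys identical to the input, which is what lets vmap out_axes be specified from the input structure.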

Ensemble with Reduction

Combine multiple model outputs with reduction:

config = CompositeOperatorConfig(
    strategy=CompositionStrategy.ENSEMBLE_MEAN,
    operators=[model_a, model_b, model_c],
)
ensemble = CompositeOperatorModule(config)
# Output is element-wise mean of all operator outputs
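A minimal sketch of the four ensemble reductions, assuming (as the comment above states for ENSEMBLE_MEAN) that outputs are reduced element-wise across operators; the ensemble helper and REDUCTIONS table are illustrative, not datarax internals:

```python
import jax.numpy as jnp

# Reduction applied across the operator axis for each ensemble strategy.
REDUCTIONS = {
    "ENSEMBLE_MEAN": jnp.mean,
    "ENSEMBLE_SUM": jnp.sum,
    "ENSEMBLE_MAX": jnp.max,
    "ENSEMBLE_MIN": jnp.min,
}


def ensemble(outputs, strategy):
    """Stack operator outputs on a new leading axis, then reduce it away."""
    stacked = jnp.stack(outputs, axis=0)  # (num_operators, *output_shape)
    return REDUCTIONS[strategy](stacked, axis=0)


outputs = [jnp.array([1.0, 4.0]), jnp.array([2.0, 5.0]), jnp.array([3.0, 6.0])]
mean = ensemble(outputs, "ENSEMBLE_MEAN")  # [2.0, 5.0]
```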

Conditional Branching

Route data through different paths based on conditions:

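import jax.numpy as jnp

def router(data):
    """Return the integer index of the operator to use."""
    # Must return an int or JAX integer scalar (not a string!).
    # Assumes a boolean data["is_image"] field; jnp.where stays traceable,
    # whereas a Python `if` on a traced value fails under jit/vmap.
    return jnp.where(data["is_image"], 0, 1)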

config = CompositeOperatorConfig(
    strategy=CompositionStrategy.BRANCHING,
    operators=[image_processor, text_processor],
    router=router,
)
branched = CompositeOperatorModule(config)
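The JAX primitive behind integer routing is jax.lax.switch. This standalone sketch (hypothetical branch functions and an assumed integer-coded "modality" field, not datarax code) shows why an integer index works inside jax.jit and jax.vmap, where each batch element can take a different branch:

```python
import jax
import jax.numpy as jnp


def image_branch(x):
    return x * 2.0


def text_branch(x):
    return x + 100.0


def router(data):
    # Hypothetical integer-coded field: 0 = image, 1 = text.
    return data["modality"]


def branching_apply(data):
    index = router(data)
    # jax.lax.switch clamps out-of-range indices to [0, len(branches) - 1].
    return jax.lax.switch(index, [image_branch, text_branch], data["x"])


batch = {"modality": jnp.array([0, 1, 0]), "x": jnp.array([1.0, 2.0, 3.0])}
out = jax.jit(jax.vmap(branching_apply))(batch)  # [2.0, 102.0, 6.0]
```

Under vmap with a batched index, JAX evaluates every branch and selects per element, so all branches must return the same shape and dtype.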

Weighted Parallel (Learnable)

Create learnable weighted combinations:

from flax import nnx

config = CompositeOperatorConfig(
    strategy=CompositionStrategy.WEIGHTED_PARALLEL,
    operators=[op_a, op_b],
    weights=[0.5, 0.5],
    learnable_weights=True,  # Weights become trainable parameters
)
weighted = CompositeOperatorModule(config, rngs=nnx.Rngs(0))

# Access weights for training
current_weights = weighted.weights.get_value()
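Why learnable weights train: the weighted combination is differentiable in the weights. A plain-JAX sketch, assuming an unnormalized weighted sum of two hypothetical operators (datarax stores the weights as an nnx.Param; a plain array stands in for it here):

```python
import jax
import jax.numpy as jnp


def op_a(x):
    return x * 2.0


def op_b(x):
    return x + 1.0


def weighted_parallel(weights, x):
    outputs = jnp.stack([op_a(x), op_b(x)])  # (num_operators, *x.shape)
    # Contract the weight vector against the operator axis.
    return jnp.tensordot(weights, outputs, axes=1)


def loss(weights, x, target):
    return jnp.mean((weighted_parallel(weights, x) - target) ** 2)


weights = jnp.array([0.5, 0.5])
x = jnp.array([1.0, 2.0])
target = jnp.array([3.0, 5.0])
grads = jax.grad(loss)(weights, x, target)  # [-8.0, -6.5]
```

An optimizer step on grads is what updating a learnable weight parameter amounts to.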

Dynamic Sequential

Modify the operator chain at runtime:

config = CompositeOperatorConfig(
    strategy=CompositionStrategy.DYNAMIC_SEQUENTIAL,
    operators=[op_a, op_b],
)
dynamic = CompositeOperatorModule(config)

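# Modify at runtime (comments show the expected chain after each call)
dynamic.add_operator(op_c)         # [op_a, op_b, op_c]
dynamic.remove_operator(1)         # [op_a, op_c]
dynamic.reorder_operators([1, 0])  # [op_c, op_a]; must permute the 2 remaining indices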

JAX Compatibility Notes

Important for JIT/vmap

  • Router functions must return integers, not strings
  • All code paths must return the same PyTree structure
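  • Condition and router functions should compute results with JAX ops (e.g. jnp.where), not Python if statements on traced values
import jax.numpy as jnp

# ✅ Correct: integer-based, trace-compatible routing (x["flag"] is an illustrative boolean field)
def router(x): return jnp.where(x["flag"], 0, 1)

# ❌ Wrong: string-based routing with a Python if (breaks tracing)
def router(x): return "path_a" if x["flag"] else "path_b"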
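To see the failure mode concretely, this sketch jit-compiles both router styles on a hypothetical "flag" field. The Python-if version raises a concretization error at trace time, while the jnp.where version compiles and returns an integer scalar:

```python
import jax
import jax.numpy as jnp


def bad_router(data):
    # Python `if` on a traced value: fails under jax.jit / jax.vmap.
    return 0 if data["flag"] else 1


def good_router(data):
    # jnp.where stays traceable and returns an integer scalar array.
    return jnp.where(data["flag"], 0, 1)


data = {"flag": jnp.array(True)}
good_index = jax.jit(good_router)(data)

try:
    jax.jit(bad_router)(data)
    bad_failed = False
except jax.errors.ConcretizationTypeError:
    # Newer JAX raises TracerBoolConversionError, a subclass of this error.
    bad_failed = True
```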

API Reference

datarax.operators.composite_operator

Unified composite operator module.

Implements CompositeOperatorModule with 11 composition strategies:

  • Sequential (3 variants): Chain operators
  • Parallel (3 variants): Apply all to same input
  • Ensemble (4 reductions): Parallel with mean/sum/max/min
  • Branching (1 routing): Route through different paths

WEIGHTED_PARALLEL supports three mutually exclusive weight modes:

  • Static weights: Fixed at construction via weights=[0.5, 0.5]
  • Learnable weights: Stored as nnx.Param via learnable_weights=True
  • Dynamic external weights: Extracted from data[weight_key] at each call via weight_key="op_weights", enabling upstream modules (e.g., Gumbel-Softmax policies) to supply per-call weights with full gradient flow

JAX vmap/JIT Compatibility Patterns

This module implements several critical patterns for vmap and JIT compatibility:

  1. Integer-Based Branching:
      • Branching uses jax.lax.switch with integer indices, not dict lookups
      • Router functions must return integers (0, 1, 2, ...), not strings
      • Why: Traced JAX values cannot be used as dict keys or in Python if statements
      • Pattern: jax.lax.switch(index, [fn0, fn1, fn2], operands)

  2. Fixed-Shape Conditional Outputs:
      • Conditional strategies include ALL operator outputs (even those whose condition is False)
      • False-condition operators return identity via a jax.lax.cond noop function
      • Why: vmap requires all code paths to return the same PyTree structure
      • Pattern: No dynamic filtering; use masking in the merge instead

  3. PyTree Structure Preservation in Dict Merge:
      • Dict merge returns {key: {op_0: val, op_1: val}}, not {op_0: {key: val}}
      • Why: Preserves the input PyTree structure for vmap out_axes specification
      • Pattern: Use jax.tree.map() to transform leaves into operator dicts

  4. Static Branching with jax.lax.cond:
      • Conditional execution uses jax.lax.cond(condition, true_fn, false_fn, operands)
      • Why: Python if statements on traced values break tracing; jax.lax.cond is trace-compatible
      • Pattern: Define apply_fn and noop_fn, use jax.lax.cond for selection

  5. weight_key Data Stripping:
      • When weight_key is set, it is stripped from both data (in apply()) and data_shapes (in generate_random_params())
      • Why: Children's random param trees must match the clean data they receive; a shape mismatch causes vmap failures
      • Pattern: Dict comprehension {k: v for k, v in d.items() if k != weight_key}

These patterns ensure all strategies work correctly inside jax.vmap and jax.jit.
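The apply_fn/noop_fn pattern from items 2 and 4 can be sketched in plain JAX. Here brighten is a hypothetical operator; the noop returns its input unchanged, so both branches produce the same shape and dtype. Under jax.vmap with a batched predicate, jax.lax.cond is lowered to a select that evaluates both branches, which is one more reason their outputs must match:

```python
import jax
import jax.numpy as jnp


def brighten(x):
    # Hypothetical child operator.
    return x + 1.0


def conditional_apply(condition, x):
    def apply_fn(v):
        return brighten(v)

    def noop_fn(v):
        return v  # identity: same structure, shape, and dtype as apply_fn

    return jax.lax.cond(condition, apply_fn, noop_fn, x)


conditions = jnp.array([True, False, True])
xs = jnp.array([[0.0, 0.0], [0.0, 0.0], [5.0, 5.0]])
out = jax.vmap(conditional_apply)(conditions, xs)  # [[1, 1], [0, 0], [6, 6]]
```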

logger module-attribute

logger = getLogger(__name__)

CompositionStrategy

Bases: Enum

Strategy for composing multiple operators.

SEQUENTIAL class-attribute instance-attribute

SEQUENTIAL = auto()

CONDITIONAL_SEQUENTIAL class-attribute instance-attribute

CONDITIONAL_SEQUENTIAL = auto()

DYNAMIC_SEQUENTIAL class-attribute instance-attribute

DYNAMIC_SEQUENTIAL = auto()

PARALLEL class-attribute instance-attribute

PARALLEL = auto()

WEIGHTED_PARALLEL class-attribute instance-attribute

WEIGHTED_PARALLEL = auto()

CONDITIONAL_PARALLEL class-attribute instance-attribute

CONDITIONAL_PARALLEL = auto()

ENSEMBLE_MEAN class-attribute instance-attribute

ENSEMBLE_MEAN = auto()

ENSEMBLE_SUM class-attribute instance-attribute

ENSEMBLE_SUM = auto()

ENSEMBLE_MAX class-attribute instance-attribute

ENSEMBLE_MAX = auto()

ENSEMBLE_MIN class-attribute instance-attribute

ENSEMBLE_MIN = auto()

BRANCHING class-attribute instance-attribute

BRANCHING = auto()

CompositeOperatorConfig dataclass

CompositeOperatorConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, stochastic: bool = False, stream_name: str | None = None, batch_strategy: str = 'vmap', strategy: CompositionStrategy | None = None, operators: Sequence[OperatorModule] | None = None, merge_strategy: str | None = None, merge_fn: Callable | None = None, merge_axis: int = 0, weights: list[float] | None = None, learnable_weights: bool = False, weight_key: str | None = None, conditions: Sequence[Callable[[PyTree], bool | Array]] | None = None, router: Callable[[PyTree], int | Array] | None = None, default_branch: int | None = None)

Bases: OperatorConfig

Configuration for composite operators.

Inherits from OperatorConfig:

- name: str | None
- stochastic: bool (whether any child is stochastic)
- stream_name: str | None (RNG stream name, used if stochastic)

WEIGHTED_PARALLEL supports three mutually exclusive weight modes:

1. **Static weights** (default): ``weights=[0.5, 0.5]`` — fixed at construction.
2. **Learnable weights**: ``learnable_weights=True`` — stored as ``nnx.Param``,
   optimized via gradient descent.
3. **Dynamic external weights**: ``weight_key="op_weights"`` — extracted from
   ``data[weight_key]`` at each forward call. Enables upstream modules (e.g.,
   a Gumbel-Softmax policy) to supply weights that change per call, with
   gradients flowing back through the weights to the upstream parameters.

When weight_key is set, the key is stripped from the data dict before passing to child operators, so children only see the actual data fields.
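A plain-JAX sketch of the dynamic external weight mode, under stated assumptions: op_a and op_b are hypothetical children, WEIGHT_KEY mirrors weight_key="op_weights", and jax.nn.softmax stands in for an upstream Gumbel-Softmax policy. It shows the three steps described above (extract, strip, weighted sum) and that gradients reach the upstream logits:

```python
import jax
import jax.numpy as jnp

WEIGHT_KEY = "op_weights"  # mirrors weight_key="op_weights" in the config


# Hypothetical child operators acting on dict-shaped data.
def op_a(data):
    return {"x": data["x"] * 2.0}


def op_b(data):
    return {"x": data["x"] + 1.0}


def weighted_parallel_external(data):
    """Weighted sum of child outputs using per-call weights from the data."""
    if WEIGHT_KEY not in data:
        raise ValueError(f"weight_key {WEIGHT_KEY!r} missing from data")
    weights = data[WEIGHT_KEY]
    # Strip the weight key so children only see the actual data fields.
    clean = {k: v for k, v in data.items() if k != WEIGHT_KEY}
    outputs = [op_a(clean), op_b(clean)]

    def weighted_sum(*leaves):
        return sum(weights[i] * leaf for i, leaf in enumerate(leaves))

    return jax.tree.map(weighted_sum, *outputs)


def upstream_loss(logits, x):
    weights = jax.nn.softmax(logits)  # stand-in for a Gumbel-Softmax policy
    out = weighted_parallel_external({"x": x, WEIGHT_KEY: weights})
    return jnp.sum(out["x"])


grads = jax.grad(upstream_loss)(jnp.zeros(2), jnp.array([1.0, 2.0]))  # [0.25, -0.25]
```

The key-membership check inspects the dict structure, not array values, so it is safe under tracing.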

Attributes:

  • strategy (CompositionStrategy | None): Composition strategy to use.
  • operators (Sequence[OperatorModule] | None): List of operators for all strategies.
  • merge_strategy (str | None): How to merge parallel outputs ("concat", "stack", "sum", "mean", "dict").
  • merge_fn (Callable | None): Custom merge function (overrides merge_strategy).
  • merge_axis (int): Axis for stack/concat operations.
  • weights (list[float] | None): Weights for weighted parallel (None = equal weights).
  • learnable_weights (bool): Whether weights are learnable parameters.
  • weight_key (str | None): Key in data dict for external dynamic weights. Mutually exclusive with weights and learnable_weights. When set, weights are extracted from data[weight_key] at each call and the key is stripped from child data.
  • conditions (Sequence[Callable[[PyTree], bool | Array]] | None): Conditions for conditional strategies (returns JAX arrays).
  • router (Callable[[PyTree], int | Array] | None): Router function for branching (returns integer index).
  • default_branch (int | None): Default branch index for fallback behavior.

strategy class-attribute instance-attribute

strategy: CompositionStrategy | None = field(default=None)

operators class-attribute instance-attribute

operators: Sequence[OperatorModule] | None = field(default=None)

merge_strategy class-attribute instance-attribute

merge_strategy: str | None = None

merge_fn class-attribute instance-attribute

merge_fn: Callable | None = None

merge_axis class-attribute instance-attribute

merge_axis: int = 0

weights class-attribute instance-attribute

weights: list[float] | None = None

learnable_weights class-attribute instance-attribute

learnable_weights: bool = False

weight_key class-attribute instance-attribute

weight_key: str | None = None

conditions class-attribute instance-attribute

conditions: Sequence[Callable[[PyTree], bool | Array]] | None = None

router class-attribute instance-attribute

router: Callable[[PyTree], int | Array] | None = None

default_branch class-attribute instance-attribute

default_branch: int | None = None

cacheable class-attribute instance-attribute

cacheable: bool = False

batch_stats_fn class-attribute instance-attribute

batch_stats_fn: Callable | Module | None = None

precomputed_stats class-attribute instance-attribute

precomputed_stats: dict[str, Any] | None = None

stochastic class-attribute instance-attribute

stochastic: bool = False

stream_name class-attribute instance-attribute

stream_name: str | None = None

batch_strategy class-attribute instance-attribute

batch_strategy: str = 'vmap'

CompositeOperatorModule

CompositeOperatorModule(config: CompositeOperatorConfig, *, rngs: Rngs | None = None)

Bases: OperatorModule

Unified composite operator supporting all composition strategies.

Uses the Strategy Pattern internally — each CompositionStrategy enum value maps to a strategy implementation class (e.g., WeightedParallelStrategy).

For WEIGHTED_PARALLEL with weight_key, the composite extracts weights from the data dict at each forward call, strips the key from child data, and delegates to WeightedParallelStrategy for the weighted sum. This enables differentiable pipelines where an upstream module (e.g., Gumbel-Softmax policy) supplies per-call weights with full gradient flow.

Parameters:

  • config (CompositeOperatorConfig, required): Composite operator configuration.
  • rngs (Rngs | None, default None): Optional RNGs for stochastic operators.

config instance-attribute

config: CompositeOperatorConfig = config

operators instance-attribute

operators = Dict(operators)

weights instance-attribute

weights = Param(array(weights))

operator_statistics instance-attribute

operator_statistics = Variable({})

rngs instance-attribute

rngs = rngs

name instance-attribute

name = static(name)

stochastic instance-attribute

stochastic = static(stochastic)

stream_name instance-attribute

stream_name = static(stream_name)

generate_random_params

generate_random_params(rng: Array, data_shapes: PyTree) -> dict[str, Any]

Generate random parameters for all child operators.

When weight_key is configured, strips that key from data_shapes before delegating to children. This ensures children's random param trees match the clean data they receive (without the weight key), preventing PyTree structure mismatches during vmap.

apply

apply(data: PyTree, state: PyTree, metadata: dict[str, Any] | None, random_params: dict[str, Any] | None = None, stats: dict[str, Any] | None = None) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply composition based on the configured strategy.

For WEIGHTED_PARALLEL with weight_key, extracts weights from data[weight_key], strips the key from data, and passes clean data to the strategy. Raises ValueError if the key is missing from data.

add_operator

add_operator(operator: OperatorModule, index: int | None = None) -> None

Add an operator to the dynamic sequential chain.

remove_operator

remove_operator(index: int) -> OperatorModule

Remove the operator at index from the dynamic sequential chain.

clear_operators

clear_operators() -> None

Clear all operators from the dynamic sequential chain.

reorder_operators

reorder_operators(new_order: list[int]) -> None

Reorder the operators in the dynamic sequential chain.

get_operation_stats

get_operation_stats() -> dict[str, int]

Get operation statistics.

Note: This method converts JAX arrays to Python ints for introspection. It is intended for use outside of JIT-compiled functions.

Returns:

  • dict[str, int]: Dictionary with 'applied_count' and 'skipped_count'.

reset_operation_stats

reset_operation_stats() -> None

Reset operation statistics to zero.

Note: Creates new JAX arrays to reset the counters.

compute_statistics

compute_statistics(data: Any) -> dict[str, Any] | None

Compute statistics from data using batch_stats_fn.

If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.

Parameters:

  • data (Any, required): Input data to compute statistics from.

Returns:

  • dict[str, Any] | None: Dictionary of statistics, or None if no batch_stats_fn is configured.

get_statistics

get_statistics() -> dict[str, Any] | None

Get current statistics.

Returns precomputed_stats if configured (unless reset was called), otherwise returns cached computed statistics, or None if no statistics available.

Returns:

  • dict[str, Any] | None: Dictionary of statistics, or None if no statistics are available.

set_statistics

set_statistics(stats: dict[str, Any]) -> None

Manually set statistics.

This overwrites any previously computed statistics and clears the reset flag.

Parameters:

  • stats (dict[str, Any], required): Dictionary of statistics to set.

reset_statistics

reset_statistics() -> None

Reset all statistics to None.

This clears both computed statistics and marks that precomputed_stats should be ignored (via internal flag). After reset, get_statistics() will return None until new statistics are set or computed.

reset_cache

reset_cache() -> None

Clear the cache.

Only has effect if cacheable=True in config.

copy

copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule

Create a copy of this module with optional config/parameter changes.

This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.

Parameters:

  • config (DataraxModuleConfig | None, default None): New config (if None, uses the current config).
  • rngs (Rngs | None, default None): New RNG state (if None, uses the current rngs).
  • name (str | None, default None): New name (if None, uses the current name).

Returns:

  • DataraxModule: New module instance with updated parameters.

Examples:

# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)

# Change name only
renamed = module.copy(name="new_name")

Note

Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().

get_state

get_state() -> dict[str, Any]

Get module state for checkpointing.

This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.

Returns:

  • dict[str, Any]: A dictionary containing the internal state of the component.

set_state

set_state(state: dict[str, Any]) -> None

Restore module state from a checkpoint.

This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.

Parameters:

  • state (dict[str, Any], required): A dictionary containing the internal state to restore.

Raises:

  • TypeError: If state is not a dictionary.
  • ValueError: If the checkpoint structure does not match the module state.

clone

clone() -> DataraxModule

Create a new instance with the same state as this module.

Uses NNX's clone function for proper deep cloning of all state.

Returns:

  • DataraxModule: A new module instance with the same state.

requires_rng_streams

requires_rng_streams() -> list[str] | None

Get the list of RNG streams required by this module.

Returns:

  • list[str] | None: A list of required RNG stream names, or None if no RNG streams are required.

ensure_rng_streams

ensure_rng_streams(stream_names: list[str]) -> None

Ensure that the required RNG streams are available.

Parameters:

  • stream_names (list[str], required): A list of available RNG stream names.

Raises:

  • ValueError: If a required RNG stream is not available.

get_output_structure

get_output_structure(sample_data: PyTree, sample_state: PyTree) -> tuple[PyTree, PyTree]

Declare output PyTree structure for vmap axis specification.

Default uses jax.eval_shape to discover structure automatically. Override for efficiency or when eval_shape doesn't work (e.g., data-dependent shapes).

Parameters:

  • sample_data (PyTree, required): Single element data (not batched).
  • sample_state (PyTree, required): Single element state (not batched).

Returns:

  • tuple[PyTree, PyTree]: Tuple of (output_data_structure, output_state_structure) with None leaves. The structure (keys/nesting) matters; leaf values are ignored.

Example override for an operator that adds keys:

def get_output_structure(self, sample_data, sample_state):
    out_data = {
        **jax.tree.map(lambda _: None, sample_data),
        "score": None,
        "alignment": None,
    }
    return out_data, sample_state
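The default path relies on jax.eval_shape, which traces a function on abstract inputs without running it. A minimal sketch of that discovery step, using a hypothetical operator function that adds a "score" key:

```python
import jax
import jax.numpy as jnp


def scoring_op(data):
    # Hypothetical operator that adds a "score" key to its input dict.
    return {**data, "score": jnp.sum(data["image"])}


sample = {"image": jax.ShapeDtypeStruct((32, 32, 3), jnp.float32)}
out_shapes = jax.eval_shape(scoring_op, sample)  # no FLOPs are executed
out_structure = jax.tree.map(lambda _: None, out_shapes)  # keep keys, drop leaves
```

Overriding get_output_structure avoids this trace when the added keys are known in advance, or when shapes depend on data values and eval_shape cannot resolve them.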

apply_batch

apply_batch(batch: Batch, stats: dict[str, Any] | None = None) -> Batch

Process entire batch with vmap and optional RNG generation.

This method implements the batch processing logic for both stochastic and deterministic modes. It uses static branching on self.stochastic for JIT compilation efficiency.

The implementation delegates to _vmap_apply() for the shared computational core, then wraps the result in a Batch object.

Parameters:

  • batch (Batch, required): Input batch (Batch[Element] structure).
  • stats (dict[str, Any] | None, default None): Optional statistics (if None, uses get_statistics()).

Returns:

  • Batch: Transformed batch with the same structure.

Note

This method is concrete (not abstract). Subclasses typically don't override it, but can if they need custom batch processing logic.
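A minimal sketch of the batch-processing idea described above, assuming a per-element stochastic operator (add_noise is hypothetical, and the Batch wrapper is omitted): a static Python branch on the stochastic flag, one RNG key per element via jax.random.split, and jax.vmap over the batch axis:

```python
import functools

import jax
import jax.numpy as jnp


def add_noise(key, x):
    """Hypothetical per-element stochastic operator."""
    return x + 0.1 * jax.random.normal(key, x.shape)


@functools.partial(jax.jit, static_argnames="stochastic")
def apply_batch_sketch(rng, batch, stochastic):
    if stochastic:  # Python branch on a static argument, resolved at trace time
        keys = jax.random.split(rng, batch.shape[0])  # one key per element
        return jax.vmap(add_noise)(keys, batch)
    return jax.vmap(lambda x: x)(batch)  # deterministic path: no keys needed


batch = jnp.zeros((4, 3))
noisy = apply_batch_sketch(jax.random.key(0), batch, stochastic=True)
clean = apply_batch_sketch(jax.random.key(0), batch, stochastic=False)
```

Because stochastic is static, each setting compiles to its own specialized program instead of carrying a runtime branch.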

output_spec

output_spec(input_spec: PyTree) -> PyTree

Return the operator's output spec given an input spec.

Most operators (normalization, additive noise, simple element-wise transforms) do not change shape; the default returns input_spec unchanged. Shape-changing operators (Resize, Crop, Reshape) MUST override this method.

Parameters:

  • input_spec (PyTree, required): PyTree of jax.ShapeDtypeStruct describing the input element (matching the upstream DataSourceModule.element_spec() or another operator's output_spec).

Returns:

  • PyTree: PyTree of jax.ShapeDtypeStruct describing the operator's output. By default, equal to input_spec.
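For a shape-changing operator, the override computes new ShapeDtypeStructs from the input spec. A sketch for a hypothetical resize operator acting on an assumed {"image": (H, W, C)} element, cross-checked against the shapes jax.eval_shape infers for a real jax.image.resize call:

```python
import jax
import jax.numpy as jnp


def resize_output_spec(input_spec, height, width):
    """Output spec for a hypothetical resize operator on {"image": (H, W, C)}."""
    image = input_spec["image"]
    return {
        **input_spec,  # other fields (e.g. "label") pass through unchanged
        "image": jax.ShapeDtypeStruct((height, width, image.shape[-1]), image.dtype),
    }


spec = {
    "image": jax.ShapeDtypeStruct((64, 64, 3), jnp.float32),
    "label": jax.ShapeDtypeStruct((), jnp.int32),
}
out_spec = resize_output_spec(spec, 32, 32)

# Cross-check against the shapes JAX infers for an actual resize, without running it.
inferred = jax.eval_shape(
    lambda d: {**d, "image": jax.image.resize(d["image"], (32, 32, 3), "bilinear")},
    spec,
)
```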