Composite Operator¤
The CompositeOperatorModule composes multiple operators into a single pipeline using one of 11 composition strategies. It is the foundation for building complex data augmentation and transformation workflows.
Composition Strategies¤
| Strategy | Description |
|---|---|
| Sequential | Chain operators: output of one → input of next |
| Conditional Sequential | Chain with per-operator conditions |
| Dynamic Sequential | Runtime-modifiable chain |
| Parallel | Apply all operators to same input, merge outputs |
| Weighted Parallel | Parallel with learnable weights |
| Conditional Parallel | Parallel with per-operator conditions |
| Ensemble Mean/Sum/Max/Min | Parallel + reduction |
| Branching | Route through different paths based on input |
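The data flow behind the sequential and ensemble families can be sketched with plain functions. This is only an illustration under assumed semantics, not the library's implementation; `scale`, `shift`, `sequential`, and `ensemble_mean` are hypothetical names.

```python
from functools import reduce

import jax.numpy as jnp


# Hypothetical stand-ins for child operators (plain functions here).
def scale(x):
    return x * 2.0


def shift(x):
    return x + 1.0


def sequential(ops, x):
    """Chain: the output of one operator is the input of the next."""
    return reduce(lambda acc, op: op(acc), ops, x)


def ensemble_mean(ops, x):
    """Apply every operator to the same input, then average element-wise."""
    return jnp.mean(jnp.stack([op(x) for op in ops]), axis=0)


x = jnp.array([1.0, 2.0])
seq_out = sequential([scale, shift], x)     # shift(scale(x))
ens_out = ensemble_mean([scale, shift], x)  # mean of scale(x) and shift(x)
```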
★ Insight ─────────────────────────────────────
- CompositeOperator uses JAX-compatible patterns throughout
- Integer-based branching with jax.lax.switch (not dict lookups)
- Fixed-shape outputs for vmap compatibility
- All strategies work inside jax.jit and jax.vmap
─────────────────────────────────────────────────
Quick Start¤
Sequential Composition¤
Chain operators where each output feeds into the next:
from datarax.operators import CompositeOperatorModule
from datarax.operators.composite_operator import (
    CompositeOperatorConfig,
    CompositionStrategy,
)
# Create child operators
normalize = create_normalize_op()
augment = create_augment_op()
config = CompositeOperatorConfig(
    strategy=CompositionStrategy.SEQUENTIAL,
    operators=[normalize, augment],
)
pipeline = CompositeOperatorModule(config)
Parallel Composition¤
Apply multiple operators to the same input and merge results:
config = CompositeOperatorConfig(
    strategy=CompositionStrategy.PARALLEL,
    operators=[op_a, op_b, op_c],
    merge_strategy="concat",  # or "stack", "sum", "mean", "dict"
    merge_axis=-1,
)
parallel_op = CompositeOperatorModule(config)
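To make the effect of each merge strategy on output shape concrete, here is a small sketch using plain `jax.numpy` calls. It assumes that "concat" and "stack" map to `jnp.concatenate` and `jnp.stack` along `merge_axis`, and that "sum" and "mean" reduce across operators; the library's actual behavior may differ in detail.

```python
import jax.numpy as jnp

# Two hypothetical operator outputs with identical shape.
outputs = [jnp.ones((2, 3)), jnp.zeros((2, 3))]

merged = {
    "concat": jnp.concatenate(outputs, axis=-1),   # joins along the last axis
    "stack": jnp.stack(outputs, axis=-1),          # adds a new trailing axis
    "sum": sum(outputs),                           # element-wise sum
    "mean": jnp.mean(jnp.stack(outputs), axis=0),  # element-wise mean
}
shapes = {name: value.shape for name, value in merged.items()}
```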
Ensemble with Reduction¤
Combine multiple model outputs with reduction:
config = CompositeOperatorConfig(
    strategy=CompositionStrategy.ENSEMBLE_MEAN,
    operators=[model_a, model_b, model_c],
)
ensemble = CompositeOperatorModule(config)
# Output is element-wise mean of all operator outputs
Conditional Branching¤
Route data through different paths based on conditions:
def router(data):
    """Return integer index of operator to use."""
    # Must return an int or a JAX integer scalar (not a string).
    # This Python conditional only works when data["type"] is a static
    # Python value; for traced arrays, compute the index with array ops.
    return 0 if data["type"] == "image" else 1
config = CompositeOperatorConfig(
    strategy=CompositionStrategy.BRANCHING,
    operators=[image_processor, text_processor],
    router=router,
)
branched = CompositeOperatorModule(config)
Weighted Parallel (Learnable)¤
Create learnable weighted combinations:
from flax import nnx
config = CompositeOperatorConfig(
    strategy=CompositionStrategy.WEIGHTED_PARALLEL,
    operators=[op_a, op_b],
    weights=[0.5, 0.5],
    learnable_weights=True,  # Weights become trainable parameters
)
weighted = CompositeOperatorModule(config, rngs=nnx.Rngs(0))
# Access weights for training
current_weights = weighted.weights.get_value()
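Why learnable weights are trainable can be seen in a plain JAX sketch: the output is a weighted sum of the operator outputs, so gradients with respect to the weights are well defined. The functions `op_a`, `op_b`, `weighted_parallel`, and `loss` below are hypothetical stand-ins, not the library's classes.

```python
import jax
import jax.numpy as jnp


def op_a(x):  # hypothetical child operator
    return x * 2.0


def op_b(x):  # hypothetical child operator
    return x + 1.0


def weighted_parallel(weights, x):
    """Weighted sum of every operator's output on the same input."""
    outputs = jnp.stack([op_a(x), op_b(x)])         # (num_ops, ...)
    return jnp.tensordot(weights, outputs, axes=1)  # contract over num_ops


def loss(weights, x):
    return jnp.sum(weighted_parallel(weights, x))


weights = jnp.array([0.5, 0.5])
x = jnp.array([1.0, 2.0])
out = weighted_parallel(weights, x)
grads = jax.grad(loss)(weights, x)  # gradient of the loss w.r.t. the weights
```

Each gradient entry is the sum of the corresponding operator's output, which is what an optimizer would use to shift weight toward one operator or the other.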
Dynamic Sequential¤
Modify the operator chain at runtime:
config = CompositeOperatorConfig(
    strategy=CompositionStrategy.DYNAMIC_SEQUENTIAL,
    operators=[op_a, op_b],
)
dynamic = CompositeOperatorModule(config)
# Modify at runtime
dynamic.add_operator(op_c)
dynamic.remove_operator(1)
# Two operators remain after the removal, so the new order has two entries
dynamic.reorder_operators([1, 0])
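The runtime edits above behave like ordinary list operations on the chain. The sketch below models that with a hypothetical `DynamicChain` class; it assumes `index=None` appends and that the reorder argument is a permutation of the current indices, neither of which is confirmed by this page.

```python
class DynamicChain:
    """Hypothetical stand-in for a runtime-modifiable operator chain."""

    def __init__(self, ops):
        self.ops = list(ops)

    def add_operator(self, op, index=None):
        # Assumption: index=None appends to the end of the chain.
        if index is None:
            self.ops.append(op)
        else:
            self.ops.insert(index, op)

    def remove_operator(self, index):
        return self.ops.pop(index)

    def reorder_operators(self, order):
        # Assumption: `order` is a permutation of the current indices.
        if sorted(order) != list(range(len(self.ops))):
            raise ValueError("order must be a permutation of current indices")
        self.ops = [self.ops[i] for i in order]

    def __call__(self, x):
        for op in self.ops:
            x = op(x)
        return x


chain = DynamicChain([lambda x: x + 1, lambda x: x * 2])
before = chain(3)                # (3 + 1) * 2
chain.reorder_operators([1, 0])
after = chain(3)                 # 3 * 2 + 1
```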
JAX Compatibility Notes¤
Important for JIT/vmap
- Router functions must return integers, not strings
- All code paths must return the same PyTree structure
- Conditions should use jax.lax.cond, not Python if
# ✅ Correct: Integer-based routing
def router(x): return 0 if condition else 1
# ❌ Wrong: String-based routing (breaks tracing)
def router(x): return "path_a" if condition else "path_b"
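A router that stays traceable can compute its index with array operations instead of a Python conditional. The sketch below uses `jnp.where` and `jax.lax.switch` directly; the `is_image` field and the two branch functions are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp


def image_branch(x):  # hypothetical branch
    return x * 2.0


def text_branch(x):  # hypothetical branch
    return x - 1.0


def router(data):
    # Traced-safe: jnp.where yields a JAX integer scalar, no Python `if`.
    return jnp.where(data["is_image"], 0, 1)


def branched(data):
    # Both branches return arrays of the same shape and dtype.
    return jax.lax.switch(router(data), [image_branch, text_branch], data["x"])


batch = {
    "x": jnp.array([[1.0, 2.0], [3.0, 4.0]]),
    "is_image": jnp.array([True, False]),
}
out = jax.jit(jax.vmap(branched))(batch)
```

The first element goes through `image_branch` and the second through `text_branch`, all inside a single jitted, vmapped call.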
See Also¤
- Element Operator - Single-element transformations
- Operator Strategies - Strategy implementations
- DAG Control Flow - DAG-level branching
- Operators Tutorial
API Reference¤
datarax.operators.composite_operator ¤
Unified composite operator module.
Implements CompositeOperatorModule with 11 composition strategies:
- Sequential (3 variants): Chain operators
- Parallel (3 variants): Apply all to same input
- Ensemble (4 reductions): Parallel with mean/sum/max/min
- Branching (1 routing): Route through different paths
WEIGHTED_PARALLEL supports three mutually exclusive weight modes:
- Static weights: Fixed at construction via weights=[0.5, 0.5]
- Learnable weights: Stored as nnx.Param via learnable_weights=True
- Dynamic external weights: Extracted from data[weight_key] at each call via weight_key="op_weights", enabling upstream modules (e.g., Gumbel-Softmax policies) to supply per-call weights with full gradient flow
JAX vmap/JIT Compatibility Patterns¤
This module implements several critical patterns for vmap and JIT compatibility:
1. Integer-Based Branching:
   - Branching uses jax.lax.switch with integer indices, not dict lookups
   - Router functions must return integers (0, 1, 2, ...), not strings
   - Why: Traced JAX values cannot be used as dict keys or in Python if statements
   - Pattern: jax.lax.switch(index, [fn0, fn1, fn2], operands)
2. Fixed-Shape Conditional Outputs:
   - Conditional strategies include ALL operator outputs (even False conditions)
   - False-condition operators return identity via jax.lax.cond noop function
   - Why: vmap requires all code paths to return the same PyTree structure
   - Pattern: No dynamic filtering, use masking in merge instead
3. PyTree Structure Preservation in Dict Merge:
   - Dict merge returns {key: {op_0: val, op_1: val}} not {op_0: {key: val}}
   - Why: Preserves input PyTree structure for vmap out_axes specification
   - Pattern: Use jax.tree.map() to transform leaves into operator dicts
4. Static Branching with jax.lax.cond:
   - Conditional execution uses jax.lax.cond(condition, true_fn, false_fn, operands)
   - Why: Python if statements break tracing, jax.lax.cond is trace-compatible
   - Pattern: Define apply_fn and noop_fn, use jax.lax.cond for selection
5. weight_key Data Stripping:
   - When weight_key is set, it is stripped from both data (in apply()) and data_shapes (in generate_random_params())
   - Why: Children's random param trees must match the clean data they receive; a shape mismatch causes vmap failures
   - Pattern: Dict comprehension {k: v for k, v in d.items() if k != weight_key}
These patterns ensure all strategies work correctly inside jax.vmap and jax.jit.
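The fixed-shape conditional output and dict-merge patterns can be sketched together in plain JAX. This is a minimal illustration under stated assumptions, not the module's code: `conditional_apply`, `dict_merge`, `double`, and `negate` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def conditional_apply(condition, apply_fn, data):
    """Run apply_fn when condition holds, otherwise return data unchanged."""

    def noop_fn(d):
        return d

    return jax.lax.cond(condition, apply_fn, noop_fn, data)


def dict_merge(outputs):
    """Turn [tree_0, tree_1, ...] into one tree whose leaves are op dicts."""
    return jax.tree.map(
        lambda *leaves: {f"op_{i}": leaf for i, leaf in enumerate(leaves)},
        *outputs,
    )


def double(d):
    return {"image": d["image"] * 2.0}


def negate(d):
    return {"image": -d["image"]}


def element_fn(data, flags):
    # Every operator contributes an output, so the structure never changes.
    outs = [
        conditional_apply(flags[0], double, data),
        conditional_apply(flags[1], negate, data),
    ]
    return dict_merge(outs)


data = {"image": jnp.array([[1.0, 2.0], [3.0, 4.0]])}
flags = jnp.array([[True, False], [False, True]])
merged = jax.vmap(element_fn)(data, flags)
```

The result keeps the input's top-level key (`image`) and nests the per-operator values beneath it, so the same `out_axes` works regardless of which conditions were true.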
CompositeOperatorConfig dataclass ¤
CompositeOperatorConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, stochastic: bool = False, stream_name: str | None = None, batch_strategy: str = 'vmap', strategy: CompositionStrategy | None = None, operators: Sequence[OperatorModule] | None = None, merge_strategy: str | None = None, merge_fn: Callable | None = None, merge_axis: int = 0, weights: list[float] | None = None, learnable_weights: bool = False, weight_key: str | None = None, conditions: Sequence[Callable[[PyTree], bool | Array]] | None = None, router: Callable[[PyTree], int | Array] | None = None, default_branch: int | None = None)
Bases: OperatorConfig
Configuration for composite operators.
Inherits from OperatorConfig:
- name: str | None
- stochastic: bool (whether any child is stochastic)
- stream_name: str (for RNG if stochastic)
WEIGHTED_PARALLEL supports three mutually exclusive weight modes:
1. Static weights (default): weights=[0.5, 0.5], fixed at construction.
2. Learnable weights: learnable_weights=True, stored as nnx.Param and optimized via gradient descent.
3. Dynamic external weights: weight_key="op_weights", extracted from data[weight_key] at each forward call. Enables upstream modules (e.g., a Gumbel-Softmax policy) to supply weights that change per call, with gradients flowing back through the weights to the upstream parameters.
When weight_key is set, the key is stripped from the data dict before passing to child operators, so children only see the actual data fields.
Attributes:
| Name | Type | Description |
|---|---|---|
| strategy | CompositionStrategy \| None | Composition strategy to use. |
| operators | Sequence[OperatorModule] \| None | List of operators for all strategies. |
| merge_strategy | str \| None | How to merge parallel outputs ("concat", "stack", "sum", "mean", "dict"). |
| merge_fn | Callable \| None | Custom merge function (overrides merge_strategy). |
| merge_axis | int | Axis for stack/concat operations. |
| weights | list[float] \| None | Weights for weighted parallel (None = equal weights). |
| learnable_weights | bool | Whether weights are learnable parameters. |
| weight_key | str \| None | Key in data dict for external dynamic weights. Mutually exclusive with |
| conditions | Sequence[Callable[[PyTree], bool \| Array]] \| None | Conditions for conditional strategies (returns JAX arrays). |
| router | Callable[[PyTree], int \| Array] \| None | Router function for branching (returns integer index). |
| default_branch | int \| None | Default branch index for fallback behavior. |
strategy class-attribute instance-attribute ¤
strategy: CompositionStrategy | None = field(default=None)
operators class-attribute instance-attribute ¤
operators: Sequence[OperatorModule] | None = field(default=None)
conditions class-attribute instance-attribute ¤
precomputed_stats class-attribute instance-attribute ¤
CompositeOperatorModule ¤
CompositeOperatorModule(config: CompositeOperatorConfig, *, rngs: Rngs | None = None)
Bases: OperatorModule
Unified composite operator supporting all composition strategies.
Uses the Strategy Pattern internally — each CompositionStrategy enum value
maps to a strategy implementation class (e.g., WeightedParallelStrategy).
For WEIGHTED_PARALLEL with weight_key, the composite extracts weights
from the data dict at each forward call, strips the key from child data, and
delegates to WeightedParallelStrategy for the weighted sum. This enables
differentiable pipelines where an upstream module (e.g., Gumbel-Softmax policy)
supplies per-call weights with full gradient flow.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | CompositeOperatorConfig | Composite operator configuration | required |
| rngs | Rngs \| None | Optional RNGs for stochastic operators | None |
generate_random_params ¤
Generate random parameters for all child operators.
When weight_key is configured, strips that key from data_shapes
before delegating to children. This ensures children's random param trees
match the clean data they receive (without the weight key), preventing
PyTree structure mismatches during vmap.
apply ¤
apply(data: PyTree, state: PyTree, metadata: dict[str, Any] | None, random_params: dict[str, Any] | None = None, stats: dict[str, Any] | None = None) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply composition based on the configured strategy.
For WEIGHTED_PARALLEL with weight_key, extracts weights from
data[weight_key], strips the key from data, and passes clean data
to the strategy. Raises ValueError if the key is missing from data.
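The extract-and-strip step described here amounts to a small dictionary operation. The following sketch shows one way it could look; `split_weights` is a hypothetical helper, and the key name `op_weights` is borrowed from the example elsewhere on this page.

```python
import jax.numpy as jnp

WEIGHT_KEY = "op_weights"  # example key name used elsewhere on this page


def split_weights(data, weight_key):
    """Return (weights, clean_data) with weight_key removed from the dict."""
    if weight_key not in data:
        raise ValueError(f"weight key {weight_key!r} missing from data")
    weights = data[weight_key]
    # Children only see the remaining fields.
    clean = {k: v for k, v in data.items() if k != weight_key}
    return weights, clean


data = {"image": jnp.ones((2, 2)), WEIGHT_KEY: jnp.array([0.3, 0.7])}
weights, clean = split_weights(data, WEIGHT_KEY)
```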
add_operator ¤
add_operator(operator: OperatorModule, index: int | None = None) -> None
Add operator to dynamic sequential.
remove_operator ¤
remove_operator(index: int) -> OperatorModule
Remove operator from dynamic sequential.
reorder_operators ¤
Reorder operators in dynamic sequential.
get_operation_stats ¤
reset_operation_stats ¤
Reset operation statistics to zero.
Note: Creates new JAX arrays to reset the counters.
compute_statistics ¤
Compute statistics from data using batch_stats_fn.
If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | Any | Input data to compute statistics from | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] \| None | Dictionary of statistics, or None if no batch_stats_fn configured |
get_statistics ¤
set_statistics ¤
reset_statistics ¤
Reset all statistics to None.
This clears both computed statistics and marks that precomputed_stats should be ignored (via internal flag). After reset, get_statistics() will return None until new statistics are set or computed.
copy ¤
copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule
Create a copy of this module with optional config/parameter changes.
This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | DataraxModuleConfig \| None | New config (if None, uses current config) | None |
| rngs | Rngs \| None | New RNG state (if None, uses current rngs) | None |
| name | str \| None | New name (if None, uses current name) | None |
Returns:
| Type | Description |
|---|---|
| DataraxModule | New module instance with updated parameters |
Examples:
# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)
# Change name only
renamed = module.copy(name="new_name")
Note
Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().
get_state ¤
Get module state for checkpointing.
This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A dictionary containing the internal state of the component. |
set_state ¤
Restore module state from a checkpoint.
This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | dict[str, Any] | A dictionary containing the internal state to restore. | required |
Raises:
| Type | Description |
|---|---|
| TypeError | If state is not a dictionary. |
| ValueError | If checkpoint structure does not match module state. |
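A strict structural check of this kind can be sketched by comparing PyTree structures before accepting a checkpoint. The `strict_restore` function below is a hypothetical illustration of the idea using `jax.tree.structure`; it is not the NNX-based mechanism the module actually uses.

```python
import jax
import jax.numpy as jnp


def strict_restore(current, checkpoint):
    """Return checkpoint only if its PyTree structure matches current."""
    if not isinstance(checkpoint, dict):
        raise TypeError("state must be a dictionary")
    if jax.tree.structure(checkpoint) != jax.tree.structure(current):
        raise ValueError("checkpoint structure does not match module state")
    return checkpoint


current = {"weights": jnp.zeros(2), "step": jnp.array(0)}
checkpoint = {"weights": jnp.ones(2), "step": jnp.array(5)}
restored = strict_restore(current, checkpoint)
```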
clone ¤
clone() -> DataraxModule
Create a new instance with the same state as this module.
Uses NNX's clone function for proper deep cloning of all state.
Returns:
| Type | Description |
|---|---|
| DataraxModule | A new module instance with the same state. |
requires_rng_streams ¤
ensure_rng_streams ¤
Ensure that the required RNG streams are available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stream_names | list[str] | A list of available RNG stream names. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If a required RNG stream is not available. |
get_output_structure ¤
get_output_structure(sample_data: PyTree, sample_state: PyTree) -> tuple[PyTree, PyTree]
Declare output PyTree structure for vmap axis specification.
Default uses jax.eval_shape to discover structure automatically. Override for efficiency or when eval_shape doesn't work (e.g., data-dependent shapes).
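How `jax.eval_shape` can discover an output structure without running any computation is sketched below. The `apply_fn` function is a hypothetical per-element transform that adds a `score` key; only the `jax.eval_shape` and `jax.tree.map` calls are real JAX APIs.

```python
import jax
import jax.numpy as jnp


def apply_fn(data, state):  # hypothetical per-element apply
    out = {**data, "score": jnp.sum(data["image"])}
    return out, state


sample_data = {"image": jnp.zeros((4, 4))}
sample_state = {}

# eval_shape traces apply_fn abstractly: only shapes and dtypes are computed.
shapes = jax.eval_shape(apply_fn, sample_data, sample_state)
# Replace every leaf with None so only keys and nesting remain.
structure = jax.tree.map(lambda _: None, shapes)
```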
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sample_data | PyTree | Single element data (not batched) | required |
| sample_state | PyTree | Single element state (not batched) | required |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree] | Tuple of (output_data_structure, output_state_structure) with None leaves. The structure (keys/nesting) matters, leaf values are ignored. |
Example override for operator that adds keys:
def get_output_structure(self, sample_data, sample_state):
    out_data = {
        **jax.tree.map(lambda _: None, sample_data),
        "score": None,
        "alignment": None,
    }
    return out_data, sample_state
apply_batch ¤
Process entire batch with vmap and optional RNG generation.
This method implements the batch processing logic for both stochastic and deterministic modes. It uses static branching on self.stochastic for JIT compilation efficiency.
The implementation delegates to _vmap_apply() for the shared computational core, then wraps the result in a Batch object.
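The static branch on the stochastic flag can be illustrated with a standalone sketch. The functions below are hypothetical and operate on a bare array rather than a Batch object; they show only the idea of choosing the vmapped call in Python, outside of tracing, and generating one RNG key per element when needed.

```python
import jax
import jax.numpy as jnp


def apply_element(x, key):
    """Hypothetical per-element transform; key is None when deterministic."""
    if key is None:  # static Python branch, resolved at trace time
        return x * 2.0
    return x + jax.random.normal(key, x.shape)


def apply_batch(batch, stochastic, seed=0):
    if stochastic:  # static flag, not a traced value
        # One independent key per batch element.
        keys = jax.random.split(jax.random.PRNGKey(seed), batch.shape[0])
        return jax.vmap(apply_element)(batch, keys)
    return jax.vmap(lambda x: apply_element(x, None))(batch)


batch = jnp.ones((3, 2))
det_out = apply_batch(batch, stochastic=False)
sto_out = apply_batch(batch, stochastic=True)
```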
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Batch | Input batch (Batch[Element] structure) | required |
| stats | dict[str, Any] \| None | Optional statistics (if None, uses get_statistics()) | None |
Returns:
| Type | Description |
|---|---|
| Batch | Transformed batch with same structure |
Note
This method is concrete (not abstract). Subclasses typically don't override it, but can if they need custom batch processing logic.
output_spec ¤
Return the operator's output spec given an input spec.
Most operators (normalization, additive noise, simple element-wise
transforms) do not change shape; the default returns input_spec
unchanged. Shape-changing operators (Resize, Crop, Reshape) MUST
override this method.
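As a rough illustration of the difference between the default and a shape-changing override, the sketch below writes both as free functions. It assumes, without confirmation from this page, that spec leaves are `jax.ShapeDtypeStruct` objects; `identity_output_spec` and `resize_output_spec` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def identity_output_spec(input_spec):
    """Default behaviour described above: the spec is returned unchanged."""
    return input_spec


def resize_output_spec(input_spec, size=(32, 32)):
    """Hypothetical shape-changing operator: resizes the 'image' leaf."""
    image = input_spec["image"]
    channels = image.shape[-1]
    # Assumption: spec leaves are jax.ShapeDtypeStruct objects.
    new_image = jax.ShapeDtypeStruct((*size, channels), image.dtype)
    return {**input_spec, "image": new_image}


spec = {"image": jax.ShapeDtypeStruct((64, 64, 3), jnp.float32)}
unchanged = identity_output_spec(spec)
resized = resize_output_spec(spec)
```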
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_spec | PyTree | PyTree of | required |
Returns:
| Type | Description |
|---|---|
| PyTree | PyTree of |
| PyTree | By default, equal to |