Element Operator¤
The ElementOperator is Datarax's most commonly used operator for element-level transformations. Unlike MapOperator (which transforms individual array leaves), ElementOperator provides access to the full Element structure - including data, state, and metadata - enabling coordinated transformations across multiple fields.
Key Concepts¤
★ Insight ─────────────────────────────────────
- ElementOperator works with entire Element objects, not individual arrays
- User functions have the signature fn(element, key) -> element
- Use element.replace() for immutable updates (Pythonic JAX pattern)
- Supports both deterministic and stochastic modes via configuration
─────────────────────────────────────────────────
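The immutable-update pattern can be illustrated without Datarax installed. The sketch below uses a hypothetical stand-in class, `SketchElement`, built on a frozen dataclass; it is not Datarax's real `Element`, and only the `replace()` call and the `fn(element, key) -> element` signature are taken from this page.

```python
import dataclasses
from typing import Any, Dict, Optional


@dataclasses.dataclass(frozen=True)
class SketchElement:
    """Hypothetical stand-in for Datarax's Element (illustration only)."""

    data: Dict[str, Any]
    state: Dict[str, Any] = dataclasses.field(default_factory=dict)
    metadata: Optional[Dict[str, Any]] = None

    def replace(self, **changes: Any) -> "SketchElement":
        # Build a new instance; the original object is never mutated.
        return dataclasses.replace(self, **changes)


def scale(element: SketchElement, key: Any) -> SketchElement:
    """Follows the fn(element, key) -> element signature; key is unused here."""
    new_data = {**element.data, "value": element.data["value"] / 255.0}
    return element.replace(data=new_data)


original = SketchElement(data={"value": 51.0}, metadata={"id": 7})
updated = scale(original, key=None)
```

After the call, `original` still holds the unscaled value, while `updated` carries the new data together with the untouched metadata.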
When to Use ElementOperator¤
| Use Case | Example |
|---|---|
| Coordinated transformations | Flip an image AND its segmentation mask together |
| Multi-field processing | Normalize image based on mask statistics |
| State tracking | Update element state based on transformation |
| Metadata-aware processing | Apply different augmentations based on metadata |
Quick Start¤
import flax.nnx as nnx
from datarax.operators import ElementOperator
from datarax.core.config import ElementOperatorConfig

# Define a transformation function
def normalize(element, key):
    """Normalize image values to [0, 1]."""
    new_data = {"image": element.data["image"] / 255.0}
    return element.replace(data=new_data)

# Create operator (deterministic mode)
config = ElementOperatorConfig(stochastic=False)
op = ElementOperator(config, fn=normalize, rngs=nnx.Rngs(0))

# Apply to an existing Element named `element`; apply() returns a (data, state, metadata) tuple
data, state, metadata = op.apply(element.data, element.state, element.metadata)
Stochastic Transformations¤
For random augmentations, use stochastic mode with a stream name:
import jax

def add_noise(element, key):
    """Add random Gaussian noise to image."""
    noise = jax.random.normal(key, element.data["image"].shape) * 0.1
    new_data = {"image": element.data["image"] + noise}
    return element.replace(data=new_data)

config = ElementOperatorConfig(stochastic=True, stream_name="augment")
op = ElementOperator(config, fn=add_noise, rngs=nnx.Rngs(42))
Coordinated Augmentations¤
One of ElementOperator's key strengths is applying the same random decision to multiple fields:
import jax.lax

def flip_both(element, key):
    """Randomly flip image and mask together."""
    should_flip = jax.random.uniform(key) < 0.5
    new_data = jax.lax.cond(
        should_flip,
        lambda: {
            "image": element.data["image"][..., ::-1],
            "mask": element.data["mask"][..., ::-1],
        },
        lambda: element.data,
    )
    return element.replace(data=new_data)

config = ElementOperatorConfig(stochastic=True, stream_name="flip")
flip_op = ElementOperator(config, fn=flip_both, rngs=nnx.Rngs(0))
Integration with DAG Pipelines¤
ElementOperator integrates seamlessly with Datarax's DAG execution:
from datarax.pipeline import Pipeline

# Build a pipeline with ElementOperator
pipeline = Pipeline(
    source=my_source,
    stages=[normalize_op, flip_op],
    batch_size=32,
    rngs=nnx.Rngs(0),
)

# Iterate over batches
for batch in pipeline:
    train_step(batch)
See Also¤
- Operators Overview - All available operators
- MapOperator - For per-array-leaf transformations
- CompositeOperator - For chaining operators
- DAG Executor - Pipeline execution
- Operators Tutorial - Hands-on examples
API Reference¤
datarax.operators.element_operator ¤
ElementOperator - operator for element-level transformations.
This module provides ElementOperator, which applies user-provided element transformation functions to entire Element structures (data + state + metadata).
Key Difference from MapOperator:
- MapOperator: fn(array_leaf, key) -> array_leaf (per-array-leaf transformation)
- ElementOperator: fn(element, key) -> element (per-element transformation)
Key Features:
- Full element access: User function sees entire Element, can modify data/state/metadata
- Coordinated transformations: Transform multiple fields together
- Deterministic mode: key parameter ignored
- Stochastic mode: key parameter provides per-element randomness
- Uses Element.replace() pattern for immutable updates
ElementOperator ¤
ElementOperator(config: ElementOperatorConfig, fn: Callable[[Element, PRNGKey], Element], *, rngs: Rngs | None = None, name: str | None = None)
Bases: OperatorModule
Unified operator for element-level transformations.
Applies user-provided element transformation function to entire Element structures. Unlike MapOperator (which transforms array leaves), ElementOperator provides access to the full element (data + state + metadata), enabling coordinated transformations.
User Function Signature:
fn(element: Element, key: jax.Array) -> Element
- element: Element with .data, .state, .metadata attributes
- key: JAX random key (use for stochastic ops, ignore for deterministic)
- Returns: New Element (use element.replace() for immutable updates)
Use Cases:
1. Coordinated transformations: Flip image AND mask together
2. State tracking: Update state based on transformation applied
3. Complex augmentation pipelines: Access multiple fields at once
4. Metadata-aware processing: Transform based on metadata values
Examples:
# Deterministic element transformation
def normalize(element, key):
    new_data = {"value": element.data["value"] / 255.0}
    return element.replace(data=new_data)

config = ElementOperatorConfig(stochastic=False)
op = ElementOperator(config, fn=normalize, rngs=rngs)

# Stochastic element augmentation
def add_noise(element, key):
    noise = jax.random.normal(key, element.data["image"].shape) * 0.1
    new_data = {"image": element.data["image"] + noise}
    return element.replace(data=new_data)

config = ElementOperatorConfig(stochastic=True, stream_name="augment")
op = ElementOperator(config, fn=add_noise, rngs=rngs)

# Coordinated augmentation
def flip_both(element, key):
    flip = jax.random.uniform(key) < 0.5
    new_data = jax.lax.cond(
        flip,
        lambda e: {"image": e.data["image"][..., ::-1],
                   "mask": e.data["mask"][..., ::-1]},
        lambda e: e.data,
        element,
    )
    return element.replace(data=new_data)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ElementOperatorConfig | Operator configuration | required |
| fn | Callable[[Element, PRNGKey], Element] | User function with signature fn(element: Element, key: Array) -> Element. Deterministic mode: ignore the key parameter. Stochastic mode: use key for randomness. | required |
| rngs | Rngs \| None | Random number generators (required if stochastic=True) | None |
| name | str \| None | Optional name for the operator | None |
generate_random_params ¤
Generate random parameters for batch transformation.
For ElementOperator, generates one RNG key per batch element. The user function receives a single key and can split it internally if multiple random operations are needed.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rng | PRNGKey | JAX random key (single key for entire batch) | required |
| data_shapes | PyTree | PyTree with same structure as batch.data, containing shapes. Example: {"image": (batch_size, H, W, C)} | required |
Returns:
| Type | Description |
|---|---|
| PRNGKey \| None | Array of shape (batch_size, 2), one PRNGKey per element, or None for deterministic operators. |
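The one-key-per-element scheme, and the advice to split the key inside the user function, can be sketched in plain JAX. The helper names `per_element_keys` and `jitter` are hypothetical, and the `(batch_size, 2)` shape assumes JAX's default legacy `uint32` key representation; this is an illustration of the idea, not Datarax's implementation.

```python
import jax
import jax.numpy as jnp


def per_element_keys(rng: jax.Array, batch_size: int) -> jax.Array:
    """Hypothetical helper: split one batch-level key into one key per element."""
    return jax.random.split(rng, batch_size)


def jitter(image: jax.Array, key: jax.Array) -> jax.Array:
    """A function needing two random draws splits its single key first."""
    noise_key, gain_key = jax.random.split(key)
    noise = jax.random.normal(noise_key, image.shape) * 0.1
    gain = jax.random.uniform(gain_key, minval=0.9, maxval=1.1)
    return image * gain + noise


batch = jnp.ones((4, 8, 8, 3))
keys = per_element_keys(jax.random.PRNGKey(0), batch.shape[0])
# vmap pairs each image with its own key along the leading axis.
out = jax.vmap(jitter)(batch, keys)
```

Splitting inside `jitter` keeps the two draws independent while the caller still only has to supply a single key per element.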
apply ¤
apply(data: PyTree, state: PyTree, metadata: dict[str, Any] | None, random_params: Any = None, stats: dict[str, Any] | None = None) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply element transformation.
Constructs an Element from data/state/metadata, passes to user function, and extracts results back.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | Element data PyTree | required |
| state | PyTree | Element state PyTree | required |
| metadata | dict[str, Any] \| None | Element metadata dict (unchanged - not vmapped) | required |
| random_params | Any | RNG key for this element (from generate_random_params) | None |
| stats | dict[str, Any] \| None | Optional batch statistics (unused) | None |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree, dict[str, Any] \| None] | Tuple of (transformed_data, transformed_state, transformed_metadata) |
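The construct, call, and extract flow described for `apply` can be sketched in a few lines. `SketchElement` and `apply_sketch` are hypothetical stand-ins written for this illustration; the real method may do more (for example, validation), which is not shown on this page.

```python
from typing import Any, NamedTuple, Optional


class SketchElement(NamedTuple):
    """Hypothetical stand-in for Datarax's Element (illustration only)."""

    data: Any
    state: Any
    metadata: Optional[dict]


def apply_sketch(fn, data, state, metadata, random_params=None):
    # 1. Bundle the three inputs into a single element.
    element = SketchElement(data, state, metadata)
    # 2. Pass the element and the (possibly None) key to the user function.
    result = fn(element, random_params)
    # 3. Unpack the returned element back into a tuple.
    return result.data, result.state, result.metadata


def double_and_count(element, key):
    """User function that updates data and state in one step."""
    new_data = {"x": element.data["x"] * 2}
    new_state = {**element.state, "calls": element.state["calls"] + 1}
    return element._replace(data=new_data, state=new_state)


data, state, metadata = apply_sketch(
    double_and_count, {"x": 3}, {"calls": 0}, {"id": "a"}
)
```

Because the user function sees data and state together, it can record what it did (here, a call counter) alongside the transformed data.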
get_operation_stats ¤
reset_operation_stats ¤
Reset operation statistics to zero.
Note: Creates new JAX arrays to reset the counters.
compute_statistics ¤
Compute statistics from data using batch_stats_fn.
If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | Any | Input data to compute statistics from | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] \| None | Dictionary of statistics, or None if no batch_stats_fn configured |
get_statistics ¤
set_statistics ¤
reset_statistics ¤
Reset all statistics to None.
This clears both computed statistics and marks that precomputed_stats should be ignored (via internal flag). After reset, get_statistics() will return None until new statistics are set or computed.
copy ¤
copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule
Create a copy of this module with optional config/parameter changes.
This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | DataraxModuleConfig \| None | New config (if None, uses current config) | None |
| rngs | Rngs \| None | New RNG state (if None, uses current rngs) | None |
| name | str \| None | New name (if None, uses current name) | None |
Returns:
| Type | Description |
|---|---|
| DataraxModule | New module instance with updated parameters |
Examples:
# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)

# Change name only
renamed = module.copy(name="new_name")
Note
Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().
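For context on the dataclass `replace()` route mentioned in the note: assuming the config is a frozen dataclass, a single field can be changed with the standard library's `dataclasses.replace` before the result is handed to `copy()`. `SketchConfig` below is a hypothetical stand-in; its field names echo ones used on this page, but the real config classes may differ.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass(frozen=True)
class SketchConfig:
    """Hypothetical stand-in for a Datarax config dataclass."""

    stochastic: bool = False
    stream_name: Optional[str] = None
    cacheable: bool = False


base = SketchConfig(stochastic=True, stream_name="augment")
# Change one field; every other field is carried over unchanged.
cached = dataclasses.replace(base, cacheable=True)
```

With a real module, the new config would then be passed as `module.copy(config=cached)`.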
get_state ¤
Get module state for checkpointing.
This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A dictionary containing the internal state of the component. |
set_state ¤
Restore module state from a checkpoint.
This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | dict[str, Any] | A dictionary containing the internal state to restore. | required |
Raises:
| Type | Description |
|---|---|
| TypeError | If state is not a dictionary. |
| ValueError | If checkpoint structure does not match module state. |
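Strict restoration can be pictured as a structural comparison that runs before any value is copied. The sketch below is a plain-Python illustration under that assumption; `restore_sketch` and `_same_structure` are hypothetical names, and the real method works on NNX state rather than nested dicts.

```python
import copy
from typing import Any


def _same_structure(current: Any, incoming: Any) -> bool:
    """True when both trees have identical keys at every nesting level."""
    if isinstance(current, dict) != isinstance(incoming, dict):
        return False
    if not isinstance(current, dict):
        return True  # both are leaves; leaf values are not compared
    if current.keys() != incoming.keys():
        return False
    return all(_same_structure(current[k], incoming[k]) for k in current)


def restore_sketch(module_state: dict, checkpoint: Any) -> dict:
    """Hypothetical strict restore: reject anything that does not match."""
    if not isinstance(checkpoint, dict):
        raise TypeError(f"state must be a dict, got {type(checkpoint).__name__}")
    if not _same_structure(module_state, checkpoint):
        raise ValueError("checkpoint structure does not match module state")
    return copy.deepcopy(checkpoint)


live = {"counters": {"applied": 0}, "rng": {"count": 0}}
restored = restore_sketch(live, {"counters": {"applied": 5}, "rng": {"count": 3}})
```

A checkpoint with a missing or extra key fails with `ValueError` instead of being partially applied, which is the practical meaning of "strict" here.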
clone ¤
clone() -> DataraxModule
Create a new instance with the same state as this module.
Uses NNX's clone function for proper deep cloning of all state.
Returns:
| Type | Description |
|---|---|
| DataraxModule | A new module instance with the same state. |
requires_rng_streams ¤
ensure_rng_streams ¤
Ensure that the required RNG streams are available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stream_names | list[str] | A list of available RNG stream names. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If a required RNG stream is not available. |
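The check amounts to comparing the streams an operator needs against the streams on offer. The function below is a hypothetical sketch of that comparison; how Datarax derives the required list (presumably from the configured `stream_name`) is an assumption, not something this page states.

```python
from typing import List


def ensure_streams_sketch(required: List[str], available: List[str]) -> None:
    """Hypothetical check: fail fast when a needed RNG stream is absent."""
    missing = [name for name in required if name not in available]
    if missing:
        raise ValueError(
            f"Missing required RNG streams {missing}; available: {available}"
        )


# Passes silently: the only required stream is present.
ensure_streams_sketch(["augment"], ["augment", "flip"])
```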
get_output_structure ¤
get_output_structure(sample_data: PyTree, sample_state: PyTree) -> tuple[PyTree, PyTree]
Declare output PyTree structure for vmap axis specification.
Default uses jax.eval_shape to discover structure automatically. Override for efficiency or when eval_shape doesn't work (e.g., data-dependent shapes).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sample_data | PyTree | Single element data (not batched) | required |
| sample_state | PyTree | Single element state (not batched) | required |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree] | Tuple of (output_data_structure, output_state_structure) with None leaves. The structure (keys/nesting) matters; leaf values are ignored. |
Example override for an operator that adds keys:
def get_output_structure(self, sample_data, sample_state):
    out_data = {
        **jax.tree.map(lambda _: None, sample_data),
        "score": None,
        "alignment": None,
    }
    return out_data, sample_state
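To show what automatic discovery with `jax.eval_shape` looks like, the sketch below traces a toy function that adds a key and then reduces the abstract result to a structure with `None` leaves. `add_score` is a hypothetical example function; the snippet demonstrates the JAX call itself, not Datarax's default implementation.

```python
import jax
import jax.numpy as jnp


def add_score(data):
    """Toy transform that adds a new scalar entry to the data dict."""
    return {**data, "score": jnp.mean(data["image"])}


sample = {"image": jnp.zeros((32, 32, 3))}
# eval_shape traces the function abstractly: no image values are computed.
abstract_out = jax.eval_shape(add_score, sample)
# Keep only keys/nesting by mapping every abstract leaf to None.
structure = jax.tree_util.tree_map(lambda _: None, abstract_out)
```

`abstract_out` holds shape and dtype descriptions for both `image` and the new `score` entry, so the extra key is discovered without running the transform on real data.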
apply_batch ¤
Process entire batch with vmap and optional RNG generation.
This method implements the batch processing logic for both stochastic and deterministic modes. It uses static branching on self.stochastic for JIT compilation efficiency.
The implementation delegates to _vmap_apply() for the shared computational core, then wraps the result in a Batch object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Batch | Input batch (Batch[Element] structure) | required |
| stats | dict[str, Any] \| None | Optional statistics (if None, uses get_statistics()) | None |
Returns:
| Type | Description |
|---|---|
| Batch | Transformed batch with same structure |
Note
This method is concrete (not abstract). Subclasses typically don't override it, but can if they need custom batch processing logic.
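The batch logic described above (a static branch on the mode, then `vmap` over the batch axis) can be sketched as follows. `apply_batch_sketch` is a hypothetical function operating on plain dicts; it omits the `Batch` wrapper, state, and statistics handling of the real method.

```python
import jax
import jax.numpy as jnp


def apply_batch_sketch(fn, batch_data, rng=None):
    """Hypothetical sketch: vmap fn over the leading batch axis."""
    if rng is None:
        # Deterministic branch: no keys are generated; fn receives key=None.
        return jax.vmap(lambda data: fn(data, None))(batch_data)
    # Stochastic branch: one key per element, mapped alongside the data.
    batch_size = jax.tree_util.tree_leaves(batch_data)[0].shape[0]
    keys = jax.random.split(rng, batch_size)
    return jax.vmap(fn)(batch_data, keys)


def normalize(data, key):
    return {"image": data["image"] / 255.0}


def add_noise(data, key):
    noise = jax.random.normal(key, data["image"].shape) * 0.1
    return {"image": data["image"] + noise}


batch = {"image": jnp.full((4, 2, 2), 255.0)}
clean = apply_batch_sketch(normalize, batch)
noisy = apply_batch_sketch(add_noise, batch, rng=jax.random.PRNGKey(0))
```

The `if rng is None` test is an ordinary Python branch, so under `jax.jit` it is resolved at trace time and only the selected path is compiled.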
output_spec ¤
Return the operator's output spec given an input spec.
Most operators (normalization, additive noise, simple element-wise transforms) do not change shape; the default returns input_spec unchanged. Shape-changing operators (Resize, Crop, Reshape) MUST override this method.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_spec | PyTree | PyTree of | required |
Returns:
| Type | Description |
|---|---|
| PyTree | PyTree of |
| PyTree | By default, equal to |
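The default-versus-override contract can be illustrated with two small classes. Both `SketchOperator` and `SketchResize` are hypothetical, and the spec leaves are represented as plain shape tuples purely for this sketch; the actual leaf type of a Datarax spec is not shown on this page.

```python
class SketchOperator:
    """Hypothetical base class: shape-preserving by default."""

    def output_spec(self, input_spec):
        # Default behaviour: the spec passes through unchanged.
        return input_spec


class SketchResize(SketchOperator):
    """Hypothetical shape-changing operator that overrides output_spec."""

    def __init__(self, height, width):
        self.height = height
        self.width = width

    def output_spec(self, input_spec):
        # Replace H and W of the image entry; keep channels and other entries.
        channels = input_spec["image"][-1]
        return {**input_spec, "image": (self.height, self.width, channels)}


spec = {"image": (256, 256, 3), "label": ()}
passthrough = SketchOperator().output_spec(spec)
resized = SketchResize(64, 64).output_spec(spec)
```

Without the override, a resize stage would report the input shape downstream, which is why shape-changing operators have to declare their own output spec.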