# Operator Cheat Sheet
Copy-paste-ready patterns for all Datarax operator types.
## MapOperator
Applies a function to specific fields in the data dictionary.
# `nnx` provides the RNG container handed to every operator below and `jnp`
# provides the dtype used in the cast example; both must be imported before
# the snippets can be pasted and run.
import flax.nnx as nnx
import jax.numpy as jnp

from datarax.operators import MapOperator, MapOperatorConfig

# Transform a single field: `subtree` names the "image" entry, and `fn`
# divides it by 255.0.
normalize = MapOperator(
    MapOperatorConfig(subtree={"image": None}),
    fn=lambda x: x / 255.0,
    rngs=nnx.Rngs(0),
)

# Transform multiple fields: the same `fn` (a cast to float32) is used for
# both "image" and "mask".
scale = MapOperator(
    MapOperatorConfig(subtree={"image": None, "mask": None}),
    fn=lambda x: x.astype(jnp.float32),
    rngs=nnx.Rngs(0),
)

# Full-tree mode (applies to entire data dict): `subtree=None` selects no
# specific field; `fn` here is the identity.
identity = MapOperator(
    MapOperatorConfig(subtree=None),
    fn=lambda x: x,
    rngs=nnx.Rngs(0),
)
## ElementOperator
Applies a function to the entire Element (data + state + metadata).
from datarax.operators import ElementOperator, ElementOperatorConfig
# Deterministic element transform: no random stream is involved.
def add_length(element):
    """Return `element` with a "length" field added next to "text".

    "length" is the size of the first axis of the "text" array, wrapped in
    a `jnp` array. "text" itself is written back unchanged.
    """
    tokens = element.data["text"]
    return element.update_data(
        {"text": tokens, "length": jnp.array(tokens.shape[0])}
    )


# Wrap the function in an operator using the default element config.
length_op = ElementOperator(
    ElementOperatorConfig(),
    fn=add_length,
    rngs=nnx.Rngs(0),
)
# Stochastic element transform
def random_crop(element, *, rngs):
    """Template for a random-crop transform on the "image" field.

    Args:
        element: Element whose data dict holds an "image" entry.
        rngs: RNG container; a key is drawn from its "augment" stream.

    Returns:
        The element with "image" replaced by the cropped image.
    """
    key = rngs.augment()
    image = element.data["image"]
    # ... crop logic using key
    # TODO: replace this pass-through with the real crop driven by `key`.
    # It exists only so the template runs as pasted instead of raising
    # NameError on an undefined `cropped`.
    cropped = image
    return element.update_data({"image": cropped})


# `stream_name="augment"` uses the same name as the `rngs.augment()` call
# above (presumably this is how the stream is selected; confirm against the
# ElementOperator reference).
crop_op = ElementOperator(
    ElementOperatorConfig(stochastic=True, stream_name="augment"),
    fn=random_crop,
    rngs=nnx.Rngs(0),
)
## Custom OperatorModule Subclass
For operators with learnable parameters or complex state.
import flax.nnx as nnx
import jax.numpy as jnp  # needed for `jnp.ones` below

from datarax.core.config import OperatorConfig
from datarax.core.element_batch import Element
from datarax.core.operator import OperatorModule


class MyOperator(OperatorModule):
    """Example operator that multiplies the "image" field by a learnable scalar."""

    def __init__(self, config, *, rngs):
        """Initialise the base operator and create the `scale` parameter.

        Args:
            config: Operator configuration forwarded to `OperatorModule`.
            rngs: RNG container forwarded to `OperatorModule`.
        """
        super().__init__(config, rngs=rngs)
        self.scale = nnx.Param(jnp.ones(()))  # Learnable parameter (scalar, starts at 1)

    def apply(self, element: Element, *, rngs=None) -> Element:
        """Return `element` with "image" multiplied by the current scale.

        `rngs` is accepted but not used by this operator.
        """
        scaled = element.data["image"] * self.scale.value
        return element.update_data({"image": scaled})


op = MyOperator(OperatorConfig(), rngs=nnx.Rngs(0))
## Image Operators
Built-in operators for common image augmentations.
from datarax.operators.modality.image import (
BrightnessOperator, BrightnessOperatorConfig,
ContrastOperator, ContrastOperatorConfig,
RotationOperator, RotationOperatorConfig,
NoiseOperator, NoiseOperatorConfig,
)
# Each operator below targets the "image" field, is marked stochastic, and
# is given its own RNG stream name.

# Brightness adjustment with range (-0.2, 0.2).
brightness = BrightnessOperator(
    BrightnessOperatorConfig(
        field_key="image",
        brightness_range=(-0.2, 0.2),
        stochastic=True,
        stream_name="brightness",
    ),
    rngs=nnx.Rngs(0),
)

# Contrast adjustment with range (0.8, 1.2).
contrast = ContrastOperator(
    ContrastOperatorConfig(
        field_key="image",
        contrast_range=(0.8, 1.2),
        stochastic=True,
        stream_name="contrast",
    ),
    rngs=nnx.Rngs(0),
)

# Random rotation with angle range (-15, 15).
rotation = RotationOperator(
    RotationOperatorConfig(
        field_key="image",
        angle_range=(-15, 15),
        stochastic=True,
        stream_name="rotation",
    ),
    rngs=nnx.Rngs(0),
)

# Gaussian noise with noise_std=0.05.
noise = NoiseOperator(
    NoiseOperatorConfig(
        field_key="image",
        mode="gaussian",
        noise_std=0.05,
        stochastic=True,
        stream_name="noise",
    ),
    rngs=nnx.Rngs(0),
)
## Stochastic vs Deterministic
| Mode | Config | RNG | Use Case |
|---|---|---|---|
| Deterministic | `stochastic=False` | Not needed | Normalization, type casting |
| Stochastic | `stochastic=True, stream_name="aug"` | Required | Augmentation, dropout |
# Deterministic mode: the operator uses no randomness.
det_config = OperatorConfig(
    stochastic=False,
)

# Stochastic mode: a stream_name must be supplied alongside stochastic=True.
stoch_config = OperatorConfig(
    stochastic=True,
    stream_name="augment",
)
## batch_strategy: vmap vs scan
Control how operators process batch elements.
| Strategy | Memory | Speed | Use When |
|---|---|---|---|
| `"vmap"` (default) | O(B) | Fast (parallel) | Small operators, training |
| `"scan"` | O(1) | Slower (sequential) | Memory-heavy operators (CREPE, large CNNs) |
# "vmap" is the default strategy: parallel over the batch, higher memory.
config = OperatorConfig(
    batch_strategy="vmap",
)

# "scan" trades speed for memory: sequential over the batch, O(1) memory.
config = OperatorConfig(
    batch_strategy="scan",
)
## Composition
Chain operators together using CompositeOperatorModule.
from datarax.operators import (
    CompositeOperatorConfig,
    CompositeOperatorModule,
    CompositionStrategy,
    ProbabilisticOperator,
    ProbabilisticOperatorConfig,
)

# Sequential composition: brightness -> contrast -> noise, in list order.
composite = CompositeOperatorModule(
    CompositeOperatorConfig(
        strategy=CompositionStrategy.SEQUENTIAL,
        operators=[brightness, contrast, noise],
    ),
    rngs=nnx.Rngs(0),
)

# Wrap an operator so it is applied with the given probability (0.5 here).
maybe_noise = ProbabilisticOperator(
    ProbabilisticOperatorConfig(
        operator=noise,
        probability=0.5,
    ),
    rngs=nnx.Rngs(0),
)
## Using Operators in Pipelines
from datarax.pipeline import Pipeline

# Operators are handed to the Pipeline as `stages` at construction time.
# `source` is a data source created beforehand (not shown in this snippet);
# `brightness` and `contrast` are the image operators defined earlier on
# this page.
pipeline = Pipeline(
    source=source,
    stages=[brightness, contrast],
    batch_size=32,
    rngs=nnx.Rngs(0),
)

# Iterating the pipeline yields batches that can be indexed by field name.
for batch in pipeline:
    augmented_images = batch["image"]