Operators Deep Dive Tutorial
| Metadata | Value |
|---|---|
| Level | Intermediate |
| Runtime | ~45 min |
| Prerequisites | Simple Pipeline, Pipeline Tutorial |
| Format | Python + Jupyter |
Overview
Master the Datarax operator system: operators are the building blocks for data transformations. This tutorial covers built-in operators, custom operator creation, and advanced composition patterns for building production-ready data pipelines.
What You'll Learn
- Understand operator types: deterministic vs stochastic
- Use built-in image augmentation operators (Brightness, Contrast, Noise, Rotation)
- Create custom operators with proper RNG handling
- Select and transform specific data fields
- Compose operators with different strategies (sequential, parallel)
- Apply conditional transformations based on data properties
- Build full augmentation pipelines for training
Coming from PyTorch?
| PyTorch | Datarax |
|---|---|
| `transforms.Normalize(mean, std)` | `ElementOperator` with a custom normalization fn |
| `transforms.RandomHorizontalFlip(p=0.5)` | `ElementOperator` with `ElementOperatorConfig(stochastic=True)` and flip logic |
| `transforms.ColorJitter(brightness=0.2)` | `BrightnessOperator` with `brightness_range=(-0.2, 0.2)` (approximate: `ColorJitter` scales brightness multiplicatively, while `brightness_range` is an additive delta) |
| `transforms.RandomRotation(15)` | `RotationOperator` with `angle_range=(-15, 15)` |
| `transforms.Compose([T1, T2])` | `CompositeOperatorModule` with the `SEQUENTIAL` strategy |
Key difference: Datarax operators use JAX random keys and explicit RNG streams for fine-grained reproducibility. Each stochastic operator needs a unique stream_name.
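The following minimal pure-JAX sketch shows what explicit keys buy you. It uses no Datarax API, and `flip_decision` is a hypothetical helper. Reusing a key reproduces the same random decision, while subkeys from `jax.random.split` give independent decisions. By contrast, torchvision's `RandomHorizontalFlip` draws from PyTorch's global generator, so reproducing one augmentation means reseeding global state.

```python
import jax


def flip_decision(key):
    """Hypothetical helper: a 50% horizontal-flip decision drawn from `key`."""
    return bool(jax.random.bernoulli(key, 0.5))


key = jax.random.PRNGKey(42)

# Reusing the same key always reproduces the same decision
same_twice = flip_decision(key) == flip_decision(key)

# Splitting yields fresh, independent subkeys, so decisions vary across them
subkeys = jax.random.split(key, 100)
num_flips = sum(flip_decision(k) for k in subkeys)
print(same_twice, num_flips)
```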
Coming from TensorFlow?
| TensorFlow | Datarax |
|---|---|
| `tf.keras.layers.Normalization()` | `ElementOperator` with a normalization function |
| `tf.image.random_brightness(image, 0.2)` | `BrightnessOperator` with `brightness_range=(-0.2, 0.2)` |
| `tf.image.random_contrast(image, 0.8, 1.2)` | `ContrastOperator` with `contrast_range=(0.8, 1.2)` |
| `tf.image.rot90(image, k)` with a random `k` (multiples of 90° only) | `RotationOperator` with `angle_range=(0, 360)` (arbitrary angles) |
| Sequential preprocessing layers | `CompositeOperatorModule` |
Key difference: Datarax operators are JAX-first with JIT compilation support and work with the Element abstraction rather than raw tensors.
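JIT compilation is also why the stochastic examples in this tutorial branch with `jax.lax.cond` rather than a Python `if`. Under `jax.jit`, a random flag is a tracer with no concrete value, so Python control flow on it fails. A minimal pure-JAX sketch (both function names are hypothetical):

```python
import jax
import jax.numpy as jnp


def flip_with_python_if(image, should_flip):
    # Hypothetical: Python `if` needs a concrete bool, unavailable while tracing
    if should_flip:
        return jnp.flip(image, axis=1)
    return image


def flip_with_where(image, should_flip):
    # Hypothetical: jnp.where selects between both results and accepts tracers
    return jnp.where(should_flip, jnp.flip(image, axis=1), image)


image = jnp.arange(12.0).reshape(2, 3, 2)  # tiny HWC image
flag = jnp.array(True)

python_if_failed = False
try:
    jax.jit(flip_with_python_if)(image, flag)
except jax.errors.ConcretizationTypeError:
    python_if_failed = True

flipped = jax.jit(flip_with_where)(image, flag)
print(python_if_failed, bool(jnp.array_equal(flipped, image[:, ::-1, :])))
```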
Files
- Python Script: examples/core/03_operators_tutorial.py
- Jupyter Notebook: examples/core/03_operators_tutorial.ipynb
Quick Start
```bash
# Install Datarax with data dependencies
uv pip install "datarax[data]"
# Run the Python script
python examples/core/03_operators_tutorial.py
# Or launch the Jupyter notebook
jupyter lab examples/core/03_operators_tutorial.ipynb
```
Part 1: Operator Fundamentals
Operators are the transformation units in Datarax pipelines. They receive data elements and return transformed elements.
Key Concepts
| Concept | Description |
|---|---|
| Deterministic | Same input always produces same output (e.g., normalization) |
| Stochastic | Uses random keys for randomized transformations (e.g., random flip) |
| Element | Single data sample with .data dictionary |
| Batch | Collection of elements processed together |
| Stream Name | Unique identifier for RNG stream in stochastic operators |
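The Element and Batch rows can be pictured with plain dictionaries. In this sketch a batch is simply every field stacked along a new leading axis. That is an illustrative assumption, not a description of Datarax's Element class or its internal batching.

```python
import jax
import jax.numpy as jnp

# Three hypothetical elements, each a dict of per-sample arrays
elements = [
    {"image": jnp.full((32, 32, 3), float(i)), "label": jnp.array(i)}
    for i in range(3)
]

# Assumed batching model: stack matching fields; the new leading axis is the batch axis
batch = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *elements)
print(batch["image"].shape, batch["label"].shape)
```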
```python
import numpy as np
from flax import nnx
from datarax.sources import MemorySource, MemorySourceConfig

# Create sample image data
np.random.seed(42)
num_samples = 100
data = {
    "image": np.random.randint(0, 256, (num_samples, 32, 32, 3)).astype(np.float32),
    "label": np.random.randint(0, 10, (num_samples,)).astype(np.int32),
    "metadata": np.random.rand(num_samples, 4).astype(np.float32),
}
source = MemorySource(MemorySourceConfig(), data=data, rngs=nnx.Rngs(0))
```
Part 2: ElementOperator - Custom Transformations
ElementOperator is the most flexible operator: it wraps any function that takes an element (plus a random key) and returns a transformed element.
Example 1: Deterministic Normalization
```python
import jax.numpy as jnp
from datarax.operators import ElementOperator, ElementOperatorConfig


def normalize_image(element, key=None):
    """Normalize image pixels to [0, 1] range."""
    del key  # Unused - deterministic operator
    image = element.data["image"]
    normalized = image / 255.0
    return element.update_data({"image": normalized})


normalizer = ElementOperator(
    ElementOperatorConfig(stochastic=False),
    fn=normalize_image,
    rngs=nnx.Rngs(0),
)

# Test it
from datarax.pipeline import Pipeline

pipeline = Pipeline(source=source, stages=[normalizer], batch_size=16, rngs=nnx.Rngs(0))
batch = next(iter(pipeline))
print(f"Range: [{batch['image'].min():.3f}, {batch['image'].max():.3f}]")
```
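`normalize_image` only rescales to [0, 1]. The PyTorch table maps `transforms.Normalize(mean, std)` to an ElementOperator with a custom function, and the sketch below shows the array math such a function might do. `standardize_image` is a hypothetical helper, and the mean/std values are the widely used ImageNet channel statistics, chosen only for illustration.

```python
import jax.numpy as jnp

# Illustrative per-channel statistics (ImageNet values, RGB order); an assumption
MEAN = jnp.array([0.485, 0.456, 0.406])
STD = jnp.array([0.229, 0.224, 0.225])


def standardize_image(image, mean=MEAN, std=STD):
    """Hypothetical helper: scale HWC pixels from [0, 255] to [0, 1], then standardize per channel."""
    scaled = image / 255.0
    # mean/std have shape (3,), so they broadcast over the trailing channel axis
    return (scaled - mean) / std


image = jnp.full((32, 32, 3), 255.0)
out = standardize_image(image)
```

Inside an element function you would return `element.update_data({"image": standardize_image(element.data["image"])})`, following the same pattern as `normalize_image`.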
Example 2: Stochastic Horizontal Flip
```python
import jax
import jax.numpy as jnp


def random_flip(element, key):
    """Randomly flip image horizontally."""
    flip_key, _ = jax.random.split(key)
    should_flip = jax.random.bernoulli(flip_key, 0.5)
    image = element.data["image"]
    # Each element's image is HWC, so axis=1 is the width axis (horizontal flip)
    flipped = jax.lax.cond(
        should_flip,
        lambda x: jnp.flip(x, axis=1),
        lambda x: x,
        image,
    )
    return element.update_data({"image": flipped})


flipper = ElementOperator(
    ElementOperatorConfig(stochastic=True, stream_name="flip"),
    fn=random_flip,
    rngs=nnx.Rngs(flip=42),
)
```
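Every sample in a batch should get its own random decision. This tutorial does not show how Datarax derives per-element keys from the "flip" stream. The pure-JAX sketch below only illustrates the general technique: split one key into one subkey per sample, then vectorize a per-image flip with `jax.vmap`. `flip_image` is a hypothetical array-level counterpart of `random_flip`.

```python
import jax
import jax.numpy as jnp


def flip_image(image, key):
    """Hypothetical: flip one HWC image horizontally with probability 0.5."""
    should_flip = jax.random.bernoulli(key, 0.5)
    return jax.lax.cond(should_flip, lambda x: jnp.flip(x, axis=1), lambda x: x, image)


batch_size = 64
images = jax.random.uniform(jax.random.PRNGKey(0), (batch_size, 8, 8, 3))
keys = jax.random.split(jax.random.PRNGKey(42), batch_size)  # one key per sample

# vmap maps over the leading axis of both images and keys; jit compiles the result
flipped = jax.jit(jax.vmap(flip_image))(images, keys)

# In the batched array the width axis is 2
was_flipped = jnp.all(flipped == jnp.flip(images, axis=2), axis=(1, 2, 3))
unchanged = jnp.all(flipped == images, axis=(1, 2, 3))
```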
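Part 3: Built-in Image Operators
Datarax provides optimized image augmentation operators. They all follow the same pattern: build the operator-specific config (for example, BrightnessOperatorConfig) and pass it to the operator class together with rngs.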
Brightness Adjustment
```python
from datarax.operators.modality.image import BrightnessOperator, BrightnessOperatorConfig

brightness_op = BrightnessOperator(
    BrightnessOperatorConfig(
        field_key="image",
        brightness_range=(-0.3, 0.3),  # Additive delta range
        stochastic=True,
        stream_name="brightness",
    ),
    rngs=nnx.Rngs(brightness=100),
)
```
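The config comment describes `brightness_range` as an additive delta range. The sketch assumes one delta is sampled uniformly from that range per image; the distribution and any clipping are assumptions, not documented here. Under that assumption, the core computation reduces to the following; `adjust_brightness` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def adjust_brightness(image, key, brightness_range=(-0.3, 0.3)):
    """Hypothetical: add one uniformly sampled delta to every pixel (no clipping)."""
    low, high = brightness_range
    delta = jax.random.uniform(key, minval=low, maxval=high)
    return image + delta


image = jnp.full((4, 4, 3), 0.5)
out = adjust_brightness(image, jax.random.PRNGKey(100))
delta = out - image  # the same shift everywhere
```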
Contrast Adjustment
```python
from datarax.operators.modality.image import ContrastOperator, ContrastOperatorConfig

contrast_op = ContrastOperator(
    ContrastOperatorConfig(
        field_key="image",
        contrast_range=(0.8, 1.2),  # Multiplicative factor range
        stochastic=True,
        stream_name="contrast",
    ),
    rngs=nnx.Rngs(contrast=200),
)
```
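`contrast_range` is a multiplicative factor range. One common formulation, the one documented for `tf.image.adjust_contrast`, scales each channel's deviation from its mean: `(x - mean) * factor + mean`. Whether ContrastOperator uses exactly this formula is an assumption; `adjust_contrast` below is a hypothetical sketch.

```python
import jax
import jax.numpy as jnp


def adjust_contrast(image, key, contrast_range=(0.8, 1.2)):
    """Hypothetical: scale deviations from the per-channel mean by a random factor."""
    low, high = contrast_range
    factor = jax.random.uniform(key, minval=low, maxval=high)
    # Per-channel mean over height and width, kept broadcastable against HWC
    mean = image.mean(axis=(0, 1), keepdims=True)
    return (image - mean) * factor + mean


image = jax.random.uniform(jax.random.PRNGKey(0), (8, 8, 3))
out = adjust_contrast(image, jax.random.PRNGKey(200))
ratio = out.std(axis=(0, 1)) / image.std(axis=(0, 1))  # equals the sampled factor
```

Because only deviations are scaled, the per-channel mean is preserved while the spread changes by the sampled factor.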
Gaussian Noise
```python
from datarax.operators.modality.image import NoiseOperator, NoiseOperatorConfig

noise_op = NoiseOperator(
    NoiseOperatorConfig(
        field_key="image",
        mode="gaussian",
        noise_std=0.05,
        stochastic=True,
        stream_name="noise",
    ),
    rngs=nnx.Rngs(noise=300),
)
```
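This sketch assumes `mode="gaussian"` means additive zero-mean noise with standard deviation `noise_std`. That is a reading of the config names, not a documented guarantee. The operation is then a single `jax.random` call. `add_gaussian_noise` is hypothetical, and clipping is left out because the config does not mention it.

```python
import jax
import jax.numpy as jnp


def add_gaussian_noise(image, key, noise_std=0.05):
    """Hypothetical: add independent zero-mean Gaussian noise to every pixel."""
    noise = jax.random.normal(key, image.shape) * noise_std
    return image + noise


image = jnp.zeros((64, 64, 3))
out = add_gaussian_noise(image, jax.random.PRNGKey(300))
```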
Rotation
```python
from datarax.operators.modality.image import RotationOperator, RotationOperatorConfig

rotation_op = RotationOperator(
    RotationOperatorConfig(
        field_key="image",
        angle_range=(-15.0, 15.0),  # Degrees
        fill_value=0.0,  # Fill empty areas with black
    ),
    rngs=nnx.Rngs(0),
)
```
Terminal Output:
RotationOperator:
- Rotates image by random angle
- Range: [-15°, +15°]
- Uses bilinear interpolation
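The output above says the operator uses bilinear interpolation and fills uncovered areas with `fill_value`. The pure-JAX sketch below shows that technique; it is not Datarax's implementation. It uses inverse mapping with `jax.scipy.ndimage.map_coordinates`, where `order=1` is bilinear and `mode="constant"` gives out-of-bounds samples the value `cval`. `rotate_image` and `random_rotate` are hypothetical, and the rotation direction convention is an arbitrary choice.

```python
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


def rotate_image(image, angle_deg, fill_value=0.0):
    """Hypothetical: rotate an HWC image about its center with bilinear interpolation."""
    h, w = image.shape[:2]
    theta = jnp.deg2rad(angle_deg)
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    yy, xx = jnp.meshgrid(jnp.arange(h), jnp.arange(w), indexing="ij")
    dy, dx = yy - cy, xx - cx
    # Inverse mapping: for each output pixel, find its source position in the input
    src_y = cy + jnp.cos(theta) * dy - jnp.sin(theta) * dx
    src_x = cx + jnp.sin(theta) * dy + jnp.cos(theta) * dx

    def sample_channel(channel):
        # order=1 is bilinear; positions outside the image take fill_value
        return map_coordinates(channel, [src_y, src_x], order=1, mode="constant", cval=fill_value)

    # Sample each channel independently along the trailing (C) axis
    return jax.vmap(sample_channel, in_axes=2, out_axes=2)(image)


def random_rotate(image, key, angle_range=(-15.0, 15.0), fill_value=0.0):
    """Hypothetical: sample an angle in degrees from angle_range and rotate by it."""
    angle = jax.random.uniform(key, minval=angle_range[0], maxval=angle_range[1])
    return rotate_image(image, angle, fill_value)


image = jax.random.uniform(jax.random.PRNGKey(0), (8, 8, 3))
identity = rotate_image(image, 0.0)
half_turn = rotate_image(image, 180.0)
rotated = random_rotate(image, jax.random.PRNGKey(1))
```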
Part 4: Field Filtering
Use an ElementOperator to transform or filter specific fields of the element's data dictionary.
```python
def filter_fields(element, key=None):
    """Keep only image and label fields."""
    del key  # Unused - deterministic operator
    filtered = {k: v for k, v in element.data.items() if k in ["image", "label"]}
    return element.update_data(filtered)


field_filter = ElementOperator(
    ElementOperatorConfig(stochastic=False),
    fn=filter_fields,
    rngs=nnx.Rngs(0),
)

# Test field filtering
source2 = MemorySource(MemorySourceConfig(), data=data, rngs=nnx.Rngs(1))
pipeline = Pipeline(source=source2, stages=[field_filter], batch_size=8, rngs=nnx.Rngs(0))
batch = next(iter(pipeline))
print(f"Image present: {batch['image'].shape}")
print(f"Label present: {batch['label'].shape}")
print(f"Metadata removed: {'metadata' not in batch}")
```
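`filter_fields` drops fields. Transforming only selected fields follows the same dictionary pattern. `map_fields` is a hypothetical helper shown on a plain dict; in an element function you would pass its result to `element.update_data`.

```python
import numpy as np


def map_fields(data, fn, keys):
    """Hypothetical: return a copy of `data` with `fn` applied only to the listed keys."""
    return {k: (fn(v) if k in keys else v) for k, v in data.items()}


sample = {
    "image": np.full((32, 32, 3), 255.0, dtype=np.float32),
    "label": np.int32(3),
    "metadata": np.zeros(4, dtype=np.float32),
}
# Scale only the image; label and metadata pass through untouched
out = map_fields(sample, lambda x: x / 255.0, keys={"image"})
```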
Part 5: CompositeOperator - Chaining Transforms
Chain multiple operators with CompositeOperatorModule. The composition strategy controls how the child operators receive their inputs and how their outputs are combined.
Composition Strategies
| Strategy | Description |
|---|---|
| SEQUENTIAL | Chain: out₁ → in₂ → out₂ → ... |
| PARALLEL | Apply all to same input, merge outputs |
| ENSEMBLE_MEAN | Parallel + average outputs |
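The table can be read as three ways of combining functions. This plain-Python/JAX sketch works on dictionaries of arrays. The merge rule for PARALLEL (later operators overwrite earlier keys) and the one-subkey-per-child splitting are illustrative assumptions, not CompositeOperatorModule's documented behavior.

```python
import jax
import jax.numpy as jnp


def sequential(ops, data, key):
    """Feed each operator's output into the next: out1 -> in2 -> out2 ..."""
    for op, subkey in zip(ops, jax.random.split(key, len(ops))):
        data = op(data, subkey)
    return data


def parallel(ops, data, key):
    """Apply every operator to the same input and merge the output dicts."""
    merged = {}
    for op, subkey in zip(ops, jax.random.split(key, len(ops))):
        merged.update(op(data, subkey))  # assumption: later operators win on clashes
    return merged


def ensemble_mean(ops, data, key):
    """Apply every operator to the same input and average outputs field by field."""
    outputs = [op(data, k) for op, k in zip(ops, jax.random.split(key, len(ops)))]
    return jax.tree_util.tree_map(lambda *xs: jnp.mean(jnp.stack(xs), axis=0), *outputs)


# Two toy deterministic operators that ignore their key
def add_one(d, key):
    return {"x": d["x"] + 1.0}


def double(d, key):
    return {"x": d["x"] * 2.0}


data = {"x": jnp.array([1.0, 2.0])}
key = jax.random.PRNGKey(0)
seq_out = sequential([add_one, double], data, key)     # (x + 1) * 2
par_out = parallel([add_one, double], data, key)       # double's output overwrites add_one's
ens_out = ensemble_mean([add_one, double], data, key)  # mean of x + 1 and 2x
```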
```python
from datarax.operators.composite_operator import (
    CompositeOperatorConfig,
    CompositeOperatorModule,
    CompositionStrategy,
)

# Create individual operators for composition
norm_op = ElementOperator(
    ElementOperatorConfig(stochastic=False),
    fn=normalize_image,
    rngs=nnx.Rngs(0),
)
flip_op = ElementOperator(
    ElementOperatorConfig(stochastic=True, stream_name="flip"),
    fn=random_flip,
    rngs=nnx.Rngs(flip=42),
)

# Sequential composition: normalize → flip
sequential_augment = CompositeOperatorModule(
    CompositeOperatorConfig(
        strategy=CompositionStrategy.SEQUENTIAL,
        operators=[norm_op, flip_op],
        stochastic=True,
        stream_name="seq_augment",
    ),
    rngs=nnx.Rngs(seq_augment=500),
)
```
Test the composite operator:
```python
source3 = MemorySource(MemorySourceConfig(), data=data, rngs=nnx.Rngs(2))
pipeline = Pipeline(source=source3, stages=[sequential_augment], batch_size=16, rngs=nnx.Rngs(0))
batch = next(iter(pipeline))
print(f"Image shape: {batch['image'].shape}")
print(f"Image range: [{batch['image'].min():.3f}, {batch['image'].max():.3f}]")
```
Part 6: Building a Full Augmentation Pipeline
Combine everything into a production-ready augmentation pipeline.
```mermaid
flowchart LR
    subgraph Source["Data Source"]
        MS[MemorySource<br/>100 samples]
    end
    subgraph Pipeline["Augmentation Pipeline"]
        FS[Pipeline<br/>batch_size=32]
        N[Normalizer<br/>deterministic]
        F[Flipper<br/>stochastic]
        B[Brightness<br/>stochastic]
    end
    subgraph Output["Output"]
        OUT[Batched Data<br/>32 samples/batch]
    end
    MS --> FS --> N --> F --> B --> OUT
```
```python
# Create fresh operators for the full pipeline
normalizer = ElementOperator(
    ElementOperatorConfig(stochastic=False),
    fn=normalize_image,
    rngs=nnx.Rngs(0),
)
flipper = ElementOperator(
    ElementOperatorConfig(stochastic=True, stream_name="flip"),
    fn=random_flip,
    rngs=nnx.Rngs(flip=42),
)
brightness = BrightnessOperator(
    BrightnessOperatorConfig(
        field_key="image",
        brightness_range=(-0.2, 0.2),
        stochastic=True,
        stream_name="brightness",
    ),
    rngs=nnx.Rngs(brightness=100),
)

# Build pipeline with chained operators
source4 = MemorySource(MemorySourceConfig(), data=data, rngs=nnx.Rngs(3))
full_pipeline = Pipeline(
    source=source4,
    stages=[normalizer, flipper, brightness],
    batch_size=32,
    rngs=nnx.Rngs(0),
)
print("Full augmentation pipeline:")
print(" Source → Normalize → Flip → Brightness → Output")
```
Process and collect statistics:
```python
stats = {"batches": 0, "samples": 0, "weighted_mean_sum": 0.0}
for batch in full_pipeline:
    n = batch["image"].shape[0]
    stats["batches"] += 1
    stats["samples"] += n
    # Weight each batch mean by its size: with 100 samples and batch_size=32,
    # a final partial batch (if kept) would otherwise count as much as a full one
    stats["weighted_mean_sum"] += float(batch["image"].mean()) * n

print(f"\nBatches: {stats['batches']}")
print(f"Samples: {stats['samples']}")
print(f"Mean pixel value: {stats['weighted_mean_sum'] / stats['samples']:.4f}")
```
Part 7: Custom Operator Patterns
Best practices for creating robust custom operators.
Pattern 1: Multi-field Transformation
Apply the same random transform to image and corresponding mask:
```python
def augment_image_and_mask(element, key):
    """Apply same random transform to image and corresponding mask."""
    key1, _ = jax.random.split(key)
    # Random rotation angle in degrees, sampled once per element
    angle = jax.random.uniform(key1, minval=-15, maxval=15)
    image = element.data["image"]
    # Simplified: the image is returned unchanged and only the angle is recorded.
    # A real implementation would rotate both the image and its mask by `angle`
    # (e.g., with jax.scipy.ndimage.map_coordinates) so the two stay aligned.
    return element.update_data({"image": image, "rotation_angle": angle})
```
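The essential property of this pattern is that both fields see the same random parameters. To keep the example short and runnable, this pure-JAX sketch swaps the rotation for a horizontal flip. `flip_image_and_mask` and the "mask" field name are hypothetical, and a plain dict stands in for `element.data`.

```python
import jax
import jax.numpy as jnp


def flip_image_and_mask(data, key):
    """Hypothetical: flip image and mask together so they stay pixel-aligned."""
    # One draw controls both fields; separate draws could flip only one of them
    should_flip = jax.random.bernoulli(key, 0.5)

    def hflip(x):
        return jnp.where(should_flip, jnp.flip(x, axis=1), x)

    return {**data, "image": hflip(data["image"]), "mask": hflip(data["mask"])}


image = jax.random.uniform(jax.random.PRNGKey(0), (8, 8, 3))
mask = (image[..., 0] > 0.5).astype(jnp.int32)  # (H, W) mask derived from the image
data = {"image": image, "mask": mask}

outputs = [flip_image_and_mask(data, k) for k in jax.random.split(jax.random.PRNGKey(7), 20)]
```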
Pattern 2: Conditional Transformation
Vary the augmentation strength per sample based on its label:
```python
def conditional_augment(element, key):
    """Apply stronger or weaker augmentation depending on the sample's label."""
    key1, _ = jax.random.split(key)
    image = element.data["image"]
    label = element.data.get("label", 0)
    # Apply stronger augmentation to minority classes (e.g., label > 5)
    strength = jax.lax.cond(
        label > 5,
        lambda: 0.2,  # Strong augmentation
        lambda: 0.05,  # Weak augmentation
    )
    noise = jax.random.normal(key1, image.shape) * strength
    # Clipping to [0, 1] assumes the image has already been normalized
    augmented = jnp.clip(image + noise, 0.0, 1.0)
    return element.update_data({"image": augmented})


conditional_op = ElementOperator(
    ElementOperatorConfig(stochastic=True, stream_name="cond"),
    fn=conditional_augment,
    rngs=nnx.Rngs(cond=999),
)
```
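Under `jax.vmap`, a `lax.cond` whose predicate differs per sample is converted to a select that evaluates both branches. For a cheap choice like the noise strength, `jnp.where` states the same rule directly. This sketch checks that the two forms agree on a small batch of labels; `select_strength` and `strength_with_cond` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def select_strength(label, strong=0.2, weak=0.05, threshold=5):
    """Hypothetical: noise strength for one label, strong above the threshold."""
    return jnp.where(label > threshold, strong, weak)


def strength_with_cond(label):
    # Same rule written with lax.cond, as in conditional_augment
    return jax.lax.cond(label > 5, lambda: 0.2, lambda: 0.05)


labels = jnp.array([0, 3, 6, 9], dtype=jnp.int32)
where_strengths = jax.vmap(select_strength)(labels)
cond_strengths = jax.vmap(strength_with_cond)(labels)
```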
Results Summary
| Operator Type | Use Case | Stochastic | Config Required |
|---|---|---|---|
| ElementOperator | Custom transforms | Configurable | ElementOperatorConfig |
| BrightnessOperator | Image brightness | Yes | BrightnessOperatorConfig |
| ContrastOperator | Image contrast | Yes | ContrastOperatorConfig |
| NoiseOperator | Add noise | Yes | NoiseOperatorConfig |
| RotationOperator | Rotate images | Yes | RotationOperatorConfig |
| CompositeOperatorModule | Chain operators | Depends on children | CompositeOperatorConfig |
Key Takeaways
- Deterministic operators: use `stochastic=False` and ignore the `key` parameter.
- Stochastic operators: use `stochastic=True` and split `key` so each random operation gets its own subkey.
- Composition: use `CompositionStrategy.SEQUENTIAL` for chained transforms.
- Field targeting: image operators use the `field_key` parameter.
- RNG management: each stochastic operator needs a unique `stream_name`.
- Custom patterns: use JAX control flow (`jax.lax.cond`) for conditional logic.
Next Steps
- Augmentation Quick Reference - Quick guide to image operators
- CIFAR-10 Quick Reference - Apply operators to real data
- MixUp/CutMix Tutorial - Advanced augmentation techniques
- Sharding Guide - Distributed training with operators
- API Reference: Operators - Complete operator API documentation