Batch Processing Quick Reference
| Metadata | Value |
|---|---|
| Level | Beginner |
| Runtime | ~2 min |
| Prerequisites | Basic Python, numpy |
| Format | Reference card |
| Memory | ~100 MB RAM |
Overview
Batching is fundamental to efficient data processing in Datarax. This reference covers the Batch object, how batching works in pipelines, and common iteration patterns.
What is a Batch?
A Batch is a Flax NNX Module that holds a collection of data samples stacked along axis 0. It contains three parts:
| Component | Type | Description |
|---|---|---|
| data | dict[str, jax.Array] | Stacked data arrays (images, labels, etc.) |
| states | dict[str, jax.Array] | Per-element state arrays (vmapped with data) |
| metadata | list[Metadata] | Per-element metadata (Python objects, not JIT-compiled) |
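To make "stacked along axis 0" concrete, the sketch below builds the data part of a batch from individual samples using plain NumPy. It is an illustration only, not Datarax's implementation; the helper name stack_samples is hypothetical.

```python
import numpy as np


def stack_samples(samples):
    """Stack a list of per-sample dicts into one dict of batched arrays.

    Hypothetical helper for illustration; Datarax builds batches internally.
    """
    keys = samples[0].keys()
    # np.stack adds a new leading axis, so axis 0 indexes the sample.
    return {key: np.stack([sample[key] for sample in samples], axis=0) for key in keys}


# Eight samples, each a 32x32 RGB image with a scalar label.
samples = [
    {"image": np.zeros((32, 32, 3), dtype=np.float32), "label": np.int32(i)}
    for i in range(8)
]
data = stack_samples(samples)
print(data["image"].shape)  # (8, 32, 32, 3)
print(data["label"].shape)  # (8,)
```

Every array in the resulting dict shares the same leading dimension, which is what lets a batch report a single batch size.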
Creating Batches
From a pipeline (most common)
from datarax.sources import MemorySource, MemorySourceConfig
from datarax.pipeline import Pipeline
import numpy as np
from flax import nnx
data = {
    "image": np.random.randn(100, 32, 32, 3).astype(np.float32),
    "label": np.random.randint(0, 10, size=(100,)),
}
source = MemorySource(MemorySourceConfig(), data=data, rngs=nnx.Rngs(0))
# Pipeline auto-batches with the specified batch_size
pipeline = Pipeline(source=source, stages=[], batch_size=16, rngs=nnx.Rngs(0))
for batch in pipeline:
    print(batch["image"].shape)  # (16, 32, 32, 3)
    break
From pre-built arrays (direct construction)
from datarax.core.element_batch import Batch
import jax.numpy as jnp
batch = Batch.from_parts(
    data={"image": jnp.ones((8, 32, 32, 3)), "label": jnp.zeros((8,))},
    states={},
)
Accessing Batch Data
# Dict-like access (recommended)
images = batch["image"] # jax.Array, shape (B, ...)
labels = batch["label"] # jax.Array, shape (B,)
# Check if key exists
if "mask" in batch:
    mask = batch["mask"]
# Get full data dict
data_dict = batch.get_data() # {"image": ..., "label": ...}
# Batch size
n = batch.batch_size # int
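The access patterns above rely on ordinary Python container protocols. As a rough mental model only, here is a toy stand-in that supports the same four operations; ToyBatch is hypothetical and says nothing about how the real Batch is implemented.

```python
import numpy as np


class ToyBatch:
    """Hypothetical stand-in mimicking the access patterns shown above."""

    def __init__(self, data):
        self._data = dict(data)

    def __getitem__(self, key):
        # batch["image"] looks the array up by key.
        return self._data[key]

    def __contains__(self, key):
        # Enables `"mask" in batch`.
        return key in self._data

    def get_data(self):
        # Return a shallow copy so callers cannot add or remove keys.
        return dict(self._data)

    @property
    def batch_size(self):
        # All arrays share the same leading dimension.
        return next(iter(self._data.values())).shape[0]


toy = ToyBatch({"image": np.ones((8, 32, 32, 3)), "label": np.zeros((8,))})
print(toy["image"].shape)  # (8, 32, 32, 3)
print("mask" in toy)       # False
print(toy.batch_size)      # 8
```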
Iteration Patterns
Full epoch (iterate all data once)
pipeline = Pipeline(source=source, stages=[], batch_size=32, rngs=nnx.Rngs(0))
for batch in pipeline:
    # train_step is a user-defined training function
    loss = train_step(batch["image"], batch["label"])
Multiple epochs
for epoch in range(num_epochs):
    # Re-create the pipeline at the start of each epoch
    pipeline = Pipeline(source=source, stages=[], batch_size=32, rngs=nnx.Rngs(0))
    for batch in pipeline:
        loss = train_step(batch["image"], batch["label"])
Limited iteration (first N batches)
import itertools
pipeline = Pipeline(source=source, stages=[], batch_size=32, rngs=nnx.Rngs(0))
for batch in itertools.islice(pipeline, 10): # First 10 batches
    loss = train_step(batch["image"], batch["label"])
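itertools.islice works on any iterable and stops pulling items once the limit is reached. The sketch below uses a plain generator in place of a pipeline (fake_batches is a hypothetical stand-in) to show that only the requested batches are produced.

```python
import itertools

produced = []


def fake_batches(total):
    """Hypothetical stand-in for a pipeline: yields batch indices lazily."""
    for index in range(total):
        produced.append(index)  # Record which batches were actually built.
        yield index


first_three = list(itertools.islice(fake_batches(100), 3))
print(first_three)  # [0, 1, 2]
print(produced)     # [0, 1, 2]
```

Because the generator is lazy, the remaining 97 items are never computed. Whether a real pipeline with prefetching builds a few extra batches in the background is a separate question that this sketch does not address.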
How batch_size Works
The Pipeline constructor auto-batches via the batch_size argument:
- batch_size=32 groups 32 elements into each Batch
- The last batch may be smaller if num_elements % batch_size != 0
- Set enforce_batch=False to skip auto-batching (advanced use)
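The arithmetic behind the first two rules can be sketched in a few lines. This assumes the final partial batch is yielded rather than dropped, as the list above describes; batch_sizes is a hypothetical helper, not a Datarax function.

```python
def batch_sizes(num_elements, batch_size):
    """Return the size of each batch when the last batch may be partial."""
    full, remainder = divmod(num_elements, batch_size)
    sizes = [batch_size] * full
    if remainder:
        sizes.append(remainder)  # Leftover elements form a smaller final batch.
    return sizes


print(batch_sizes(100, 32))  # [32, 32, 32, 4]
print(batch_sizes(96, 32))   # [32, 32, 32]
```

With 100 elements and batch_size=32, code that assumes a fixed leading dimension of 32 would need to handle the final batch of 4.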
# Standard batching
pipeline = Pipeline(source=source, stages=[], batch_size=32, rngs=nnx.Rngs(0))
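# No auto-batching (elements yielded individually) requires enforce_batch=False;
# the option is only noted in the trailing comment and is not passed in this call
pipeline = Pipeline(source=source, stages=[], batch_size=32, rngs=nnx.Rngs(0))  # enforce_batch=False
# With prefetching (default: 2 batches ahead); prefetch_size=4 is likewise only noted, not passed
pipeline = Pipeline(source=source, stages=[], batch_size=32, rngs=nnx.Rngs(0))  # prefetch_size=4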
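Prefetching generally means preparing upcoming batches in the background while the current one is being consumed. The following is a generic sketch of that idea using a thread and a bounded queue; it is not Datarax's implementation, and prefetch is a hypothetical helper.

```python
import queue
import threading

_DONE = object()  # Sentinel marking the end of the stream.


def prefetch(iterable, size=2):
    """Yield items from `iterable`, keeping up to `size` items buffered ahead."""
    buffer = queue.Queue(maxsize=size)

    def producer():
        for item in iterable:
            buffer.put(item)  # Blocks while the buffer is full.
        buffer.put(_DONE)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buffer.get()
        if item is _DONE:
            return
        yield item


print(list(prefetch(range(5), size=2)))  # [0, 1, 2, 3, 4]
```

The bounded queue caps how far the producer can run ahead, which is the role a prefetch size plays. This sketch does not propagate exceptions from the producer thread or shut it down when the consumer stops early, both of which a production implementation would need to handle.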
Next Steps
- Data Loading Quick Reference -- Load data from various sources
- Operators Tutorial -- Transform batch data with operators
- Simple Pipeline -- Complete pipeline example