CIFAR-10 Pipeline Quick Reference
| Metadata | Value |
|---|---|
| Level | Beginner |
| Runtime | ~5 min |
| Prerequisites | Basic Datarax pipeline, TFDS setup |
| Format | Python + Jupyter |
| Memory | ~500 MB RAM |
Overview
This quick reference demonstrates loading and processing CIFAR-10 from TensorFlow Datasets (TFDS). CIFAR-10 is a classic benchmark of 60,000 32x32 color images in 10 classes (50,000 training and 10,000 test images). Its small size makes it a good first dataset for learning image classification pipelines.
What You'll Learn
- Load CIFAR-10 using `TFDSEagerSource` with the appropriate configuration
- Apply standard CIFAR-10 normalization (per-channel mean and standard deviation computed on the CIFAR-10 training set)
- Build a batched pipeline ready for training
- Understand the data shapes and the preprocessing workflow
- Verify preprocessing with statistical checks
Coming from PyTorch?
| PyTorch | Datarax |
|---|---|
| `datasets.CIFAR10(root, train=True)` | `TFDSEagerSource(TFDSEagerConfig(name="cifar10", split="train"))` |
| `transforms.ToTensor()` | Automatic conversion to JAX arrays |
| `transforms.Normalize(mean, std)` | `ElementOperator` with a custom normalization function |
| `DataLoader(dataset, batch_size=32, shuffle=True)` | `Pipeline(source=source, stages=[], batch_size=32, rngs=nnx.Rngs(0))`, with `shuffle=True` set in `TFDSEagerConfig` |
Key difference: Datarax reads datasets through TFDS and yields JAX arrays natively. The normalization constants are the same values commonly used in PyTorch CIFAR-10 examples.
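One practical porting detail: `transforms.ToTensor()` returns channel-first `(C, H, W)` tensors, so a PyTorch `DataLoader` batch is NCHW, while the batches printed later on this page are channel-last, `(32, 32, 32, 3)`. A minimal NumPy sketch of converting between the two layouts and of how the per-channel constants must broadcast in each; the random batch is only a stand-in for real CIFAR-10 data:

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in batch in the NHWC layout reported on this page:
# (batch, height, width, channels) = (32, 32, 32, 3)
nhwc = rng.random((32, 32, 32, 3), dtype=np.float32)

# PyTorch-style NCHW layout: move the channel axis to position 1.
nchw = np.transpose(nhwc, (0, 3, 1, 2))
print(nchw.shape)  # (32, 3, 32, 32)

mean = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)

# NHWC: a (3,) vector broadcasts over the trailing channel axis.
nhwc_norm = nhwc - mean
# NCHW: the channel axis is axis 1, so the vector must be reshaped.
nchw_norm = nchw - mean.reshape(1, 3, 1, 1)
```

Only the axis order differs; after transposing back, both results hold the same values.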
Coming from TensorFlow?
| TensorFlow | Datarax |
|---|---|
| `tfds.load("cifar10", split="train")` | `TFDSEagerSource(TFDSEagerConfig(name="cifar10", split="train"))` |
| `dataset.batch(32).prefetch(2)` | `Pipeline(source=source, stages=[], batch_size=32, rngs=nnx.Rngs(0))` |
| `tf.keras.layers.Rescaling(1./255)` | `ElementOperator` with division by 255 |
| `tf.keras.layers.Normalization()` | `ElementOperator` with mean/std normalization |
Key difference: Datarax yields JAX arrays and integrates with Flax NNX for training. Preprocessing is written as plain functions wrapped in operators rather than as Keras layers.
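When porting, note that `tf.keras.layers.Normalization` is parameterized by `mean` and `variance`, not standard deviation, so the CIFAR-10 std values need squaring. The rescale-then-normalize pair is also a single affine map on raw pixels. A NumPy sketch checking both points (no TensorFlow required; the pixel array is synthetic):

```python
import numpy as np

MEAN = np.array([0.4914, 0.4822, 0.4465])
STD = np.array([0.2470, 0.2435, 0.2616])

# Keras Normalization takes variance, i.e. std squared.
VARIANCE = STD ** 2

rng = np.random.default_rng(0)
x = rng.integers(0, 256, size=(4, 32, 32, 3)).astype(np.float64)  # fake raw pixels

# Two steps: Rescaling(1/255), then (x - mean) / sqrt(variance)
two_step = (x / 255.0 - MEAN) / np.sqrt(VARIANCE)

# Folded: one affine map applied directly to raw pixel values
one_step = (x - 255.0 * MEAN) / (255.0 * STD)

print(np.allclose(two_step, one_step))  # True
```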
Files
- Python script: `examples/core/04_cifar10_quickref.py`
- Jupyter notebook: `examples/core/04_cifar10_quickref.ipynb`
Quick Start
```bash
# Install datarax with TFDS support
uv pip install "datarax[tfds]"

# Run the Python script
python examples/core/04_cifar10_quickref.py

# Or launch the Jupyter notebook
jupyter lab examples/core/04_cifar10_quickref.ipynb
```
Note: The first run downloads CIFAR-10 (~170 MB).
CIFAR-10 Preprocessing Constants
Standard normalization values for CIFAR-10, computed per channel over the training set after scaling pixels to [0, 1]. Using these values keeps preprocessing consistent with many published CIFAR-10 results and with models trained on the same statistics.
| Statistic | R | G | B |
|---|---|---|---|
| Mean | 0.4914 | 0.4822 | 0.4465 |
| Std | 0.2470 | 0.2435 | 0.2616 |
```python
import jax.numpy as jnp

# CIFAR-10 normalization constants (RGB order, [0, 1] pixel scale)
CIFAR10_MEAN = jnp.array([0.4914, 0.4822, 0.4465])
CIFAR10_STD = jnp.array([0.2470, 0.2435, 0.2616])

# Class names for reference
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
print("CIFAR-10 classes:", CIFAR10_CLASSES)
```
Terminal Output:
```
CIFAR-10 classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
```
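The table values are the per-channel mean and standard deviation over every pixel of every training image, after scaling to [0, 1]. A minimal NumPy sketch of that reduction; `channel_stats` is a hypothetical helper, and the random array only stands in for the real 50,000 × 32 × 32 × 3 training set, so it will not reproduce the published numbers:

```python
import numpy as np


def channel_stats(images_uint8):
    """Per-channel mean and std of an (N, H, W, C) uint8 array, on the [0, 1] scale."""
    x = images_uint8.astype(np.float64) / 255.0
    # Reduce over all axes except the last: every pixel of every image.
    mean = x.mean(axis=(0, 1, 2))
    std = x.std(axis=(0, 1, 2))
    return mean, std


# Synthetic stand-in for the training images (not real CIFAR-10 data)
rng = np.random.default_rng(0)
fake_train = rng.integers(0, 256, size=(100, 32, 32, 3), dtype=np.uint8)
mean, std = channel_stats(fake_train)
print(mean.shape, std.shape)  # (3,) (3,)
```

Run on the actual training split, the same reduction is what yields a mean/std pair per RGB channel like the one in the table.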
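Step 1: GPU Memory Configuration
Hide the GPU from TensorFlow, which is only used here for data loading, so that it does not claim GPU memory needed for JAX training:
```python
import os

# Silence TensorFlow's C++ logging; must be set before TensorFlow is imported.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf

# Hide all GPUs from TensorFlow only. Setting CUDA_VISIBLE_DEVICES="" instead
# would hide them from JAX as well. Call this before TensorFlow uses the GPU.
tf.config.set_visible_devices([], "GPU")
```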
Step 2: Create TFDS Data Source
Configure TFDSEagerSource to load the CIFAR-10 training split. This quick reference uses only the first 1,000 training examples to keep runtime short.
```python
from flax import nnx
from datarax.sources import TFDSEagerConfig, TFDSEagerSource

# Load CIFAR-10 training data (subset for a quick demo)
config = TFDSEagerConfig(
    name="cifar10",
    split="train[:1000]",  # First 1000 training examples
    shuffle=True,
    seed=42,
    exclude_keys={"id"},  # Drop the string-valued "id" feature; keep "image" and "label"
)
source = TFDSEagerSource(config, rngs=nnx.Rngs(42))

print("Dataset: CIFAR-10")
print(f"Samples: {len(source)}")
print(f"Classes: {len(CIFAR10_CLASSES)}")
```
Terminal Output:
```
Dataset: CIFAR-10
Samples: 1000
Classes: 10
```
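Each record in the TFDS `cifar10` dataset has three features: `id` (a string), `image` (a 32×32×3 uint8 array), and `label` (an integer class). Conceptually, `exclude_keys={"id"}` removes the one feature that cannot become a numeric JAX array. A plain-Python illustration of that effect, not Datarax's actual implementation; the `id` value is a placeholder:

```python
import numpy as np

# Illustrative record mirroring the TFDS cifar10 feature layout
example = {
    "id": b"train_00000",  # placeholder string identifier
    "image": np.zeros((32, 32, 3), dtype=np.uint8),
    "label": np.int64(7),
}

exclude_keys = {"id"}
# Keep only the features that are not excluded
kept = {k: v for k, v in example.items() if k not in exclude_keys}
print(sorted(kept))  # ['image', 'label']
```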
Step 3: Define Preprocessing
Standard CIFAR-10 preprocessing:
1. Convert uint8 pixels in [0, 255] to float32 in [0, 1].
2. Apply channel-wise normalization with the CIFAR-10 statistics.
```python
from datarax.operators import ElementOperator, ElementOperatorConfig


def preprocess_cifar10(element, key=None):
    """Normalize CIFAR-10 images to standard statistics."""
    del key  # Unused: this operator is deterministic
    image = element.data["image"]
    # Convert uint8 to float32 and scale to [0, 1]
    image = image.astype(jnp.float32) / 255.0
    # Apply CIFAR-10 normalization: (x - mean) / std.
    # The (3,) constants broadcast over the trailing channel axis.
    image = (image - CIFAR10_MEAN) / CIFAR10_STD
    return element.update_data({"image": image})


normalizer = ElementOperator(
    ElementOperatorConfig(stochastic=False),
    fn=preprocess_cifar10,
    rngs=nnx.Rngs(0),
)
print("Created CIFAR-10 normalizer with standard statistics")
```
Terminal Output:
```
Created CIFAR-10 normalizer with standard statistics
```
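The arithmetic in `preprocess_cifar10` can be checked by hand. A black pixel maps to `-mean / std` and a white pixel to `(1 - mean) / std`, so every normalized value lies roughly within [-2, 2.2]. A NumPy sketch of the same math plus a hypothetical `denormalize` inverse (handy for plotting normalized images); neither function is part of Datarax:

```python
import numpy as np

MEAN = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
STD = np.array([0.2470, 0.2435, 0.2616], dtype=np.float32)


def normalize(pixels_uint8):
    """Same math as preprocess_cifar10, applied to a raw (..., 3) uint8 array."""
    return (pixels_uint8.astype(np.float32) / 255.0 - MEAN) / STD


def denormalize(x):
    """Hypothetical inverse: map normalized values back to uint8 pixels."""
    return np.clip(np.round((x * STD + MEAN) * 255.0), 0, 255).astype(np.uint8)


# Black, mid-gray, and white pixels bracket each channel's output range.
pixels = np.array([[0, 0, 0], [128, 128, 128], [255, 255, 255]], dtype=np.uint8)
z = normalize(pixels)
print(np.round(z, 3))
print(denormalize(z))  # recovers the original pixels
```

For example, a red value of 128 becomes (128 / 255 - 0.4914) / 0.2470 ≈ 0.043.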
Step 4: Build Pipeline
Chain the source and the preprocessing operator into a batched pipeline. A batch size of 32 is a common choice for CIFAR-10 training.
```mermaid
flowchart LR
    subgraph Source["TFDS Source"]
        T["TFDSEagerSource<br/>CIFAR-10<br/>1000 samples"]
    end
    subgraph Pipeline["Pipeline"]
        FS["Pipeline<br/>batch_size=32"]
        N["Normalizer<br/>(x - mean) / std"]
    end
    subgraph Output["Output"]
        B["Batched Data<br/>(32, 32, 32, 3)"]
    end
    T --> FS --> N --> B
```
```python
from datarax.pipeline import Pipeline

# Build the training pipeline
batch_size = 32
pipeline = Pipeline(
    source=source,
    stages=[normalizer],
    batch_size=batch_size,
    rngs=nnx.Rngs(0),
)

print("Pipeline: TFDSEagerSource(CIFAR-10) -> Normalize -> Output")
print(f"Batch size: {batch_size}")
# Counts full batches only (1000 // 32 = 31)
print(f"Batches per epoch: {len(source) // batch_size}")
```
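Terminal Output:
```
Pipeline: TFDSEagerSource(CIFAR-10) -> Normalize -> Output
Batch size: 32
Batches per epoch: 31
```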
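`len(source) // batch_size` counts only full batches: 1000 // 32 = 31, covering 992 samples and leaving 8 over. Whether a final partial batch is yielded or dropped depends on the pipeline's configuration, which this page does not show, so here is a library-independent sketch of both conventions; `batch_slices` is a hypothetical helper:

```python
def batch_slices(num_samples, batch_size, drop_remainder=True):
    """Hypothetical helper: (start, stop) index pairs for consecutive batches."""
    # One slice per full batch
    slices = [(stop - batch_size, stop)
              for stop in range(batch_size, num_samples + 1, batch_size)]
    # Optionally add the leftover samples as a smaller final batch
    if not drop_remainder and num_samples % batch_size:
        start = (num_samples // batch_size) * batch_size
        slices.append((start, num_samples))
    return slices


full = batch_slices(1000, 32)
print(len(full))  # 31
with_tail = batch_slices(1000, 32, drop_remainder=False)
print(len(with_tail), with_tail[-1])  # 32 (992, 1000)
```

Dropping the remainder keeps every batch the same shape, which avoids recompiling `jit`-compiled training steps for a one-off batch size.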
Step 5: Iterate Through Data
Process batches and verify that preprocessing is correct. Normalized data should have approximately zero mean and unit standard deviation per channel.
```python
# Process and verify batches
print("\nProcessing batches:")
all_means = []
all_stds = []

for i, batch in enumerate(pipeline):
    if i >= 5:  # Show first 5 batches
        break

    image_batch = batch["image"]
    label_batch = batch["label"]

    # Compute per-channel statistics
    batch_mean = image_batch.mean(axis=(0, 1, 2))
    batch_std = image_batch.std(axis=(0, 1, 2))
    all_means.append(batch_mean)
    all_stds.append(batch_std)

    if i < 3:  # Print details for first 3 batches
        print(f"Batch {i}:")
        print(f"  Image: shape={image_batch.shape}, dtype={image_batch.dtype}")
        print(f"  Labels: {label_batch[:8]}... (first 8)")
        print(f"  Per-channel mean: [{batch_mean[0]:.3f}, {batch_mean[1]:.3f}, {batch_mean[2]:.3f}]")
```
Terminal Output:
```
Processing batches:
Batch 0:
  Image: shape=(32, 32, 32, 3), dtype=float32
  Labels: [6 9 9 4 1 1 2 7]... (first 8)
  Per-channel mean: [-0.012, 0.034, -0.089]
Batch 1:
  Image: shape=(32, 32, 32, 3), dtype=float32
  Labels: [3 5 8 7 0 4 5 3]... (first 8)
  Per-channel mean: [0.045, -0.021, 0.012]
Batch 2:
  Image: shape=(32, 32, 32, 3), dtype=float32
  Labels: [2 1 6 8 9 0 4 2]... (first 8)
  Per-channel mean: [-0.089, 0.015, -0.034]
```
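The printed labels are indices into `CIFAR10_CLASSES`. Decoding the first eight labels of batch 0 from the output above:

```python
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

labels = [6, 9, 9, 4, 1, 1, 2, 7]  # first 8 labels of batch 0
# int() also accepts 0-d JAX/NumPy array elements, e.g. label_batch[:8]
names = [CIFAR10_CLASSES[int(i)] for i in labels]
print(names)
# ['frog', 'truck', 'truck', 'deer', 'automobile', 'automobile', 'bird', 'horse']
```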
Aggregate statistics across batches:
```python
mean_of_means = jnp.stack(all_means).mean(axis=0)
mean_of_stds = jnp.stack(all_stds).mean(axis=0)

print("\nAggregate Statistics (should be ~0 mean, ~1 std):")
print(f"  Mean across batches: [{mean_of_means[0]:.3f}, {mean_of_means[1]:.3f}, {mean_of_means[2]:.3f}]")
print(f"  Std across batches: [{mean_of_stds[0]:.3f}, {mean_of_stds[1]:.3f}, {mean_of_stds[2]:.3f}]")
```
Terminal Output:
```
Aggregate Statistics (should be ~0 mean, ~1 std):
  Mean across batches: [-0.015, 0.009, -0.037]
  Std across batches: [0.987, 1.012, 0.995]
```
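Averaging per-batch standard deviations is a quick sanity check, but it is not the standard deviation over all pixels, especially when batches differ in size. A sketch that accumulates per-channel sums and sums of squares instead; `RunningChannelStats` is a hypothetical helper, float64 accumulation keeps the E[x²] minus E[x]² form accurate for data of this scale, and the random batches are stand-ins:

```python
import numpy as np


class RunningChannelStats:
    """Hypothetical accumulator for exact per-channel mean/std over NHWC batches."""

    def __init__(self, num_channels=3):
        self.count = 0
        self.total = np.zeros(num_channels, dtype=np.float64)
        self.total_sq = np.zeros(num_channels, dtype=np.float64)

    def update(self, batch):
        x = np.asarray(batch, dtype=np.float64)
        self.count += x.shape[0] * x.shape[1] * x.shape[2]  # pixels per channel
        self.total += x.sum(axis=(0, 1, 2))
        self.total_sq += (x ** 2).sum(axis=(0, 1, 2))

    def result(self):
        mean = self.total / self.count
        var = self.total_sq / self.count - mean ** 2
        return mean, np.sqrt(np.maximum(var, 0.0))  # guard tiny negative rounding


# Stand-in batches of unequal size, e.g. two full batches and a partial one
rng = np.random.default_rng(0)
batches = [rng.normal(size=(n, 32, 32, 3)) for n in (32, 32, 8)]

stats = RunningChannelStats()
for b in batches:
    stats.update(b)
mean, std = stats.result()

all_pixels = np.concatenate(batches)
print(np.allclose(mean, all_pixels.mean(axis=(0, 1, 2))))  # True
print(np.allclose(std, all_pixels.std(axis=(0, 1, 2))))    # True
```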
Results Summary
| Component | Description |
|---|---|
| Dataset | CIFAR-10 (first 1,000 training samples for the demo) |
| Image Shape | (32, 32, 3) RGB |
| Batch Size | 32 |
| Normalization | Channel-wise with CIFAR-10 statistics |
| Output Statistics | Approximately zero mean and unit standard deviation per channel |
Data Format
```python
batch = {
    "image": Array[32, 32, 32, 3],  # float32, (batch, height, width, channels)
    "label": Array[32],             # (batch,) integer class labels 0-9
}
```
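The format above can be turned into a cheap runtime check on each full batch before it reaches a training step. A minimal sketch; `check_batch` is a hypothetical helper, and `np.asarray` is used so it accepts JAX or NumPy arrays:

```python
import numpy as np


def check_batch(batch, batch_size=32, image_shape=(32, 32, 3), num_classes=10):
    """Hypothetical check that a batch matches the documented format."""
    image = np.asarray(batch["image"])
    label = np.asarray(batch["label"])
    if image.shape != (batch_size, *image_shape):
        raise ValueError(f"unexpected image shape {image.shape}")
    if image.dtype != np.float32:
        raise ValueError(f"unexpected image dtype {image.dtype}")
    if label.shape != (batch_size,) or not np.issubdtype(label.dtype, np.integer):
        raise ValueError(f"unexpected labels: shape {label.shape}, dtype {label.dtype}")
    if label.min() < 0 or label.max() >= num_classes:
        raise ValueError("label outside [0, num_classes)")


# Synthetic batch with the documented shapes and dtypes
fake = {
    "image": np.zeros((32, 32, 32, 3), dtype=np.float32),
    "label": np.arange(32) % 10,
}
check_batch(fake)
print("batch format OK")
```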
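Why Normalize?
- Faster convergence: Inputs centered near zero with similar scale give better-conditioned gradients, so training typically converges faster
- Compatibility: Matches the input statistics expected by models trained with the same preprocessing (e.g., CIFAR-10 ResNet or VGG checkpoints)
- Numerical stability: Keeps input magnitudes near 1 instead of up to 255, avoiding very large activations in early layers
- Consistent scale: All channels have similar variance, so no single channel dominates the first layer's weighted sums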
Next Steps
- Augmentation Quick Reference - Add image operators for training
- Operators Tutorial - Deep dive into custom operators
- MixUp/CutMix Tutorial - Advanced batch augmentation
- Full Training Example - Complete training workflow
- API Reference: TFDSEagerSource - Complete TFDS API documentation