Skip to content

Datarax: High-Performance JAX Data Pipelines¤

CI codecov

Datarax is a high-performance, extensible data pipeline framework specifically engineered for JAX-based machine learning workflows. It simplifies and accelerates the development of efficient and scalable data loading, preprocessing, and augmentation pipelines for JAX, leveraging the full potential of JAX's Just-In-Time (JIT) compilation, automatic differentiation, and hardware acceleration capabilities.

Key Features¤

  • High Performance: Leverages JAX's JIT compilation and XLA backend to achieve near-optimal data processing speeds on CPUs, GPUs, and TPUs.
  • JAX-Native Design: All core components and operations are designed with JAX's functional programming paradigm and immutable data structures (PyTrees) in mind.
  • Scalability: Supports efficient data loading and processing for large datasets and distributed training scenarios, including multi-host and multi-device setups.
  • Extensibility: Easily define and integrate custom data sources, transformations, and augmentation operations.
  • Usability: Provides a clear, intuitive Python API and a flexible configuration system for defining and managing pipelines.
  • Determinism: Pipeline runs are deterministic by default, crucial for reproducibility.
  • Complete Feature Set: Supports common operators, advanced transformations, batching, sharding, checkpointing, and caching.
  • Ecosystem Integration: Facilitates smooth integration with other JAX libraries like Flax, Optax, and Orbax.

Why Datarax?¤

Datarax's differentiable pipeline architecture enables optimization paradigms that are impossible with traditional data loaders:

  • Learned Augmentation (DADA)


    10,000x faster augmentation policy search via gradient descent through datarax's operator library

    Advanced Guide

  • Learned ISP for Detection


    End-to-end differentiable image processing pipeline using the DAG the stages=[...] argument

    Advanced Guide

  • DDSP Audio Synthesis


    Custom OperatorModule subclasses for audio — proving extensibility beyond images

    Advanced Guide

Quick Navigation¤

Installation¤

# Install via uv (recommended)
uv pip install datarax

# Install with optional dependencies
uv pip install datarax[data]     # For data loading (HF, TFDS)
uv pip install datarax[gpu]      # For GPU support

# Or locally for development
pip install -e .

Quick Start¤

Here's a simple example of using Datarax's DAG-based architecture:

import jax
import jax.numpy as jnp
from flax import nnx
from datarax.pipeline import Pipeline
from datarax.pipeline import Pipeline
from datarax.operators import ElementOperator, ElementOperatorConfig
from datarax.sources import MemorySource, MemorySourceConfig

# 1. Define operations
def normalize(element, key=None):
    return element.update_data({"image": element.data["image"] / 255.0})

# 2. Create data source
source_config = MemorySourceConfig()
source = MemorySource(source_config, data=my_data_dict, rngs=nnx.Rngs(0))

# 3. Create operators
normalizer = ElementOperator(
    ElementOperatorConfig(),
    fn=normalize,
    rngs=nnx.Rngs(0)
)

# 4. Build pipeline
pipeline = (
    Pipeline(source=source, stages=[normalizer], batch_size=32, rngs=nnx.Rngs(0)))

# 5. Run pipeline
for batch in pipeline:
    process(batch)

Documentation Structure¤

  • API Reference - Complete API documentation
  • Module Documentation - Detailed documentation for each module
  • Examples - Usage examples and tutorials
  • Migration Guides - Guides for migrating between versions

Contributing¤

To contribute to the documentation:

  1. Add docstrings to your Python code
  2. Run the documentation generator: python scripts/generate_docs.py
  3. Build the documentation: mkdocs build
  4. Preview locally: mkdocs serve