Datarax: High-Performance JAX Data Pipelines¤
Datarax is an extensible data pipeline framework built for JAX-based machine learning workflows. It helps you build efficient, scalable data loading, preprocessing, and augmentation pipelines that take advantage of JAX's just-in-time (JIT) compilation, automatic differentiation, and hardware acceleration.
Key Features¤
- High Performance: Leverages JAX's JIT compilation and XLA backend to achieve near-optimal data processing speeds on CPUs, GPUs, and TPUs.
- JAX-Native Design: All core components and operations are designed with JAX's functional programming paradigm and immutable data structures (PyTrees) in mind.
- Scalability: Supports efficient data loading and processing for large datasets and distributed training scenarios, including multi-host and multi-device setups.
- Extensibility: Easily define and integrate custom data sources, transformations, and augmentation operations.
- Usability: Provides a clear, intuitive Python API and a flexible configuration system for defining and managing pipelines.
- Determinism: Pipeline runs are deterministic by default, crucial for reproducibility.
- Complete Feature Set: Supports common operators, advanced transformations, batching, sharding, checkpointing, and caching.
- Ecosystem Integration: Facilitates smooth integration with other JAX libraries like Flax, Optax, and Orbax.
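The JIT, PyTree, and determinism points above can be illustrated with plain JAX. The following is a minimal sketch, not Datarax's implementation: `normalize_batch` and `shuffled_indices` are hypothetical names used only for illustration, and the batch is a small synthetic dict of arrays.

```python
import jax
import jax.numpy as jnp


@jax.jit
def normalize_batch(batch):
    """Scale uint8 images to [0, 1]; labels pass through unchanged."""
    # A dict of arrays is a PyTree, so jax.jit can trace it directly.
    return {
        "image": batch["image"].astype(jnp.float32) / 255.0,
        "label": batch["label"],
    }


def shuffled_indices(seed, epoch, num_examples):
    """Derive a permutation from (seed, epoch) so reruns see the same order."""
    key = jax.random.fold_in(jax.random.PRNGKey(seed), epoch)
    return jax.random.permutation(key, num_examples)


# Synthetic batch (illustrative data, not a real dataset).
batch = {
    "image": jnp.full((4, 8, 8, 3), 255, dtype=jnp.uint8),
    "label": jnp.arange(4),
}
normalized = normalize_batch(batch)

# The same (seed, epoch) pair always yields the same ordering.
order_a = shuffled_indices(seed=0, epoch=1, num_examples=10)
order_b = shuffled_indices(seed=0, epoch=1, num_examples=10)
```

Because the random key is derived explicitly from the seed and epoch, the shuffle order does not depend on hidden global state, which is the property that makes reproducible pipeline runs possible.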
Why Datarax?¤
Datarax's differentiable pipeline architecture enables optimization paradigms that are impossible with traditional data loaders:
- Learned Augmentation (DADA): 10,000x faster augmentation policy search via gradient descent through Datarax's operator library
- Learned ISP for Detection: End-to-end differentiable image processing pipeline built with the DAG stages=[...] argument
- DDSP Audio Synthesis: Custom OperatorModule subclasses for audio, proving extensibility beyond images
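To show what "gradient descent through an augmentation" means in practice, here is a minimal plain-JAX sketch. It does not use Datarax's operator library: `augment`, `loss_fn`, the two parameters, and the synthetic target are all assumptions made for illustration. The idea is only that when an augmentation is a differentiable function of its parameters, those parameters can be tuned with `jax.grad` like any other weights.

```python
import jax
import jax.numpy as jnp


def augment(params, images):
    """Differentiable contrast/brightness adjustment (illustrative only)."""
    return params["contrast"] * images + params["brightness"]


def loss_fn(params, images, target):
    """Mean squared error between augmented images and a target appearance."""
    return jnp.mean((augment(params, images) - target) ** 2)


# Synthetic data: the target was produced with contrast=0.5, brightness=0.2.
images = jnp.linspace(0.0, 1.0, 16).reshape(1, 4, 4, 1)
target = 0.5 * images + 0.2

params = {"contrast": jnp.array(1.0), "brightness": jnp.array(0.0)}
grad_fn = jax.jit(jax.value_and_grad(loss_fn))

initial_loss, _ = grad_fn(params, images, target)
learning_rate = 0.5
for _ in range(200):
    _, grads = grad_fn(params, images, target)
    # Plain gradient descent applied leaf-by-leaf over the parameter PyTree.
    params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )
final_loss, _ = grad_fn(params, images, target)
```

A non-differentiable loader would have to search over such parameters by trial and error; here the gradient points directly toward better values.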
Quick Navigation¤
- Batching - Batch creation and management
- Benchmarking - Performance measurement tools
- Checkpoint - State persistence and recovery
- Command Line Interface - CLI tools
- Config - Configuration management
- Control - Pipeline control flow
- Core Components - Core abstractions
- DAG - Directed acyclic graph execution
- Distributed - Multi-device processing
- Memory - Memory management
- Monitoring - Metrics and observability
- Operators - Data transformation operators
- Performance - Performance optimization
- Root - Root module types
- Samplers - Data sampling strategies
- Sharding - Data sharding
- Sources - Data source adapters
- Utilities - Utility functions
Installation¤
# Install via uv (recommended)
uv pip install datarax
# Install with optional dependencies
uv pip install "datarax[data]"  # For data loading (HF, TFDS)
uv pip install "datarax[gpu]"   # For GPU support
# Or locally for development
pip install -e .
Quick Start¤
Here's a simple example of using Datarax's DAG-based architecture:
import jax.numpy as jnp
from flax import nnx

from datarax.pipeline import Pipeline
from datarax.operators import ElementOperator, ElementOperatorConfig
from datarax.sources import MemorySource, MemorySourceConfig

# Placeholder in-memory dataset; replace with your own dict of arrays.
my_data_dict = {"image": jnp.full((128, 32, 32, 3), 255.0)}

# 1. Define operations
def normalize(element, key=None):
    # Scale pixel values from [0, 255] to [0, 1].
    return element.update_data({"image": element.data["image"] / 255.0})

# 2. Create data source
source_config = MemorySourceConfig()
source = MemorySource(source_config, data=my_data_dict, rngs=nnx.Rngs(0))

# 3. Create operators
normalizer = ElementOperator(
    ElementOperatorConfig(),
    fn=normalize,
    rngs=nnx.Rngs(0),
)

# 4. Build pipeline
pipeline = Pipeline(source=source, stages=[normalizer], batch_size=32, rngs=nnx.Rngs(0))

# 5. Run pipeline
for batch in pipeline:
    process(batch)  # process is a placeholder for your own training or evaluation step
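For readers who want to see the batching and normalization arithmetic without installing Datarax, the sketch below reproduces the same idea in plain JAX. It is not how Datarax implements its pipeline: `iterate_batches` and `normalize_dict` are hypothetical helpers, and keeping a smaller final batch is an assumption of this sketch rather than documented Datarax behavior.

```python
import jax.numpy as jnp


def iterate_batches(data, batch_size):
    """Yield dicts of arrays sliced along the leading axis.

    The final batch is smaller when the dataset size is not a multiple
    of batch_size (an assumption of this sketch).
    """
    num_examples = next(iter(data.values())).shape[0]
    for start in range(0, num_examples, batch_size):
        yield {name: array[start:start + batch_size] for name, array in data.items()}


def normalize_dict(batch):
    """Scale the image entry from [0, 255] to [0, 1]; keep other entries as-is."""
    return {**batch, "image": batch["image"] / 255.0}


# Synthetic dataset of 70 examples (illustrative data only).
data = {"image": jnp.full((70, 8, 8, 3), 255.0), "label": jnp.arange(70)}

batches = [normalize_dict(b) for b in iterate_batches(data, batch_size=32)]
batch_sizes = [b["image"].shape[0] for b in batches]
```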
Documentation Structure¤
- API Reference - Complete API documentation
- Module Documentation - Detailed documentation for each module
- Examples - Usage examples and tutorials
- Migration Guides - Guides for migrating between versions
Contributing¤
To contribute to the documentation:
- Add docstrings to your Python code
- Run the documentation generator: python scripts/generate_docs.py
- Build the documentation: mkdocs build
- Preview locally: mkdocs serve
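As an illustration of the first step, here is what a documented function might look like. The page does not say which docstring style the generator expects, so the Google-style layout below is an assumption, and `scale_pixels` is a hypothetical function rather than part of Datarax.

```python
import inspect


def scale_pixels(values, max_value=255.0):
    """Scale raw pixel values into the range [0, 1].

    Args:
        values: Iterable of numeric pixel values.
        max_value: Largest possible raw value; must be positive.

    Returns:
        A list of floats, each equal to ``value / max_value``.

    Raises:
        ValueError: If ``max_value`` is not positive.
    """
    if max_value <= 0:
        raise ValueError("max_value must be positive")
    return [value / max_value for value in values]


# Documentation tools typically read docstrings much like inspect.getdoc does.
summary_line = inspect.getdoc(scale_pixels).splitlines()[0]
```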