Element and Batch Types¤
Core types for individual elements and batched data.
See Also¤
- Core Overview - All core protocols
- Batching - Batch creation
- Element Operator - Element transforms
- Types & Protocols - Type definitions
datarax.core.element_batch ¤
Element and Batch modules following JAX and Flax NNX best practices.
Key design decisions:
- Element uses flax.struct for immutability and automatic pytree registration
- Batch uses Flax NNX Module pattern for state management
- No object dtype arrays (JAX limitation)
- Proper handling of static arguments in JIT compilation
- Efficient vectorized operations without Python loops
Element ¤
Immutable data element with JAX-compatible operations.
Element represents a single data point with:
- data: PyTree structure containing JAX arrays (supports nested dicts)
- state: Dictionary of arbitrary Python values
- metadata: Optional Metadata instance
All operations return new instances (immutable design).
update_state ¤
Update state with partial updates (merge behavior).
update_data ¤
Update data with partial updates (merge behavior).
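The merge behavior of update_state and update_data can be illustrated with a minimal sketch. SketchElement below is a hypothetical stand-in built on a frozen dataclass, not the datarax Element class; it assumes "merge" means that keys in the update override existing keys while all other keys are kept.

```python
from dataclasses import dataclass, field, replace
from typing import Any


@dataclass(frozen=True)
class SketchElement:
    """Hypothetical stand-in for Element; not the datarax class."""

    data: dict[str, Any] = field(default_factory=dict)
    state: dict[str, Any] = field(default_factory=dict)

    def update_state(self, updates: dict[str, Any]) -> "SketchElement":
        # Merge: keys in `updates` override, every other key is kept.
        return replace(self, state={**self.state, **updates})

    def update_data(self, updates: dict[str, Any]) -> "SketchElement":
        # Same merge rule for data; a new instance is returned.
        return replace(self, data={**self.data, **updates})


original = SketchElement(data={"x": 1.0}, state={"step": 0, "seen": True})
updated = original.update_state({"step": 1})
```

The original instance is left unchanged, which mirrors the "all operations return new instances" design stated above.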
transform ¤
Transform all data arrays with a function.
Note: This method cannot be JIT-compiled directly because fn is a Python callable and must be treated as a static argument. Use transform_element_jit with static function IDs instead.
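The static-argument constraint is a general JAX rule rather than something specific to Element. The sketch below shows the plain JAX mechanism (static_argnums) for passing a callable into a jitted function; apply_static and double are hypothetical names, and this is not how transform_element_jit is implemented.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `fn` is a Python callable, not an array, so it is marked static.
@partial(jax.jit, static_argnums=(0,))
def apply_static(fn, data):
    return jax.tree.map(fn, data)


def double(x):
    return x * 2.0


data = {"x": jnp.array([1.0, 2.0])}
result = apply_static(double, data)
```

Static arguments are compared by hash and equality, so passing a freshly created lambda on every call would trigger recompilation each time; a module-level function such as double avoids that.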
with_metadata ¤
Return new Element with updated metadata.
apply_to_data ¤
Apply differentiable transformation to all data arrays.
This method preserves gradients through JAX transformations by applying the function directly to each array in the data dictionary using jax.tree.map.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable[[Array], Array] | Differentiable function to apply to each array | required |

Returns:
| Type | Description |
|---|---|
| Element | New Element with transformed data, preserving state and metadata |
Examples:

    import jax.numpy as jnp
    from datarax.core.element_batch import Element

    element = Element(data={"x": jnp.array([1.0, 2.0])})
    scaled = element.apply_to_data(lambda x: x * 2.0)
    # Gradients flow through the scaling operation
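To make the gradient claim concrete, here is a minimal sketch that uses only jax.tree.map and jax.grad on a plain dict PyTree. It does not use Element; scale_tree and loss are hypothetical helpers that stand in for the pattern apply_to_data is described as following.

```python
import jax
import jax.numpy as jnp


def scale_tree(data, factor):
    # Apply the same differentiable function to every array leaf.
    return jax.tree.map(lambda x: x * factor, data)


def loss(factor, data):
    scaled = scale_tree(data, factor)
    return sum(jnp.sum(leaf) for leaf in jax.tree.leaves(scaled))


data = {"x": jnp.array([1.0, 2.0]), "nested": {"y": jnp.array([3.0])}}

# d/d(factor) of sum(x * factor) over all leaves is the sum of all entries.
grad_factor = jax.grad(loss)(2.0, data)
```

The gradient with respect to factor equals the sum of all data entries (1 + 2 + 3), which confirms that differentiation passes through the tree map, including nested dicts.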
Batch ¤
Bases: Module
Batch container using Flax NNX patterns.
Design rationale:
- Stores data as stacked JAX PyTrees for efficiency
- Stores states as stacked JAX PyTrees (enables vmap)
- Stores metadata as Python list (immutable, not vmapped)
- Uses NNX Variables for mutable state management
- All operations are JAX-compatible for JIT compilation
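The "stacked PyTrees enable vmap" rationale can be sketched with plain JAX. The example below stacks per-element dicts along a new axis 0 and maps a function over that axis; it works on raw PyTrees rather than Batch, and per_element_mean is a hypothetical function chosen for illustration.

```python
import jax
import jax.numpy as jnp

# Three per-element PyTrees with identical structure and shapes.
elements = [
    {"image": jnp.full((2, 2), float(i)), "label": jnp.array(i)}
    for i in range(3)
]

# Stack matching leaves along a new leading batch axis (axis 0).
stacked = jax.tree.map(lambda *leaves: jnp.stack(leaves), *elements)


def per_element_mean(element):
    # Written for a single element; vmap adds the batch dimension.
    return jnp.mean(element["image"]) + element["label"]


# vmap maps over axis 0 of every leaf, so no Python loop is needed.
means = jax.vmap(per_element_mean)(stacked)
```

Metadata kept as a Python list has no array leaves to map over, which is consistent with the note above that it is not vmapped.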
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| elements | list[Element] | List of Element instances | required |
| validate | bool | Whether to validate consistency | True |
valid_mask instance-attribute ¤
from_parts classmethod ¤
from_parts(data: PyTree, states: PyTree, metadata_list: list[Any] | None = None, batch_metadata: Any | None = None, batch_state: PyTree | None = None, *, validate: bool = True, valid_mask: Array | None = None) -> Batch
Create Batch directly from pre-built parts with validation.
This is the recommended way to construct batches from transformed data in operators, as it avoids Python loops and validates structure.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | PyTree with arrays having batch dimension as axis 0 | required |
| states | PyTree | PyTree with arrays having batch dimension as axis 0 (same structure as data, all leaves must have matching batch size) | required |
| metadata_list | list[Any] \| None | Optional list of metadata, length must match batch size | None |
| batch_metadata | Any \| None | Optional batch-level metadata (immutable) | None |
| batch_state | PyTree \| None | Optional batch-level state PyTree (no batch dimension) | None |
| validate | bool | If True, validates batch axis consistency and lengths | True |

Returns:
| Type | Description |
|---|---|
| Batch | New Batch instance |

Raises:
| Type | Description |
|---|---|
| ValueError | If validation fails (mismatched batch sizes, inconsistent shapes) |
Examples:

    import jax.numpy as jnp
    from datarax.core.element_batch import Batch

    # Simple flat PyTrees
    data = {"image": jnp.ones((32, 224, 224, 3))}
    states = {"count": jnp.zeros((32,)), "flag": jnp.ones((32,), dtype=bool)}
    batch = Batch.from_parts(data, states)

    # Nested PyTrees
    data = {
        "vision": {"image": jnp.ones((32, 224, 224, 3))},
        "text": jnp.ones((32, 512)),
    }
    states = {
        "counters": {"augment": jnp.zeros((32,)), "transform": jnp.zeros((32,))},
        "score": jnp.ones((32,)),
    }
    batch = Batch.from_parts(data, states)
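The kind of batch-axis check that validate=True describes can be approximated as follows. check_batch_axis is a hypothetical helper, not the validation code inside from_parts; it assumes every leaf is an array with at least one dimension and only compares the size of axis 0.

```python
import jax
import jax.numpy as jnp


def check_batch_axis(*trees):
    """Hypothetical helper: return the shared batch size or raise ValueError."""
    leaves = [leaf for tree in trees for leaf in jax.tree.leaves(tree)]
    sizes = {leaf.shape[0] for leaf in leaves}
    if len(sizes) != 1:
        raise ValueError(f"Mismatched batch sizes along axis 0: {sorted(sizes)}")
    return sizes.pop()


data = {"vision": {"image": jnp.ones((4, 8, 8, 3))}, "text": jnp.ones((4, 16))}
states = {"score": jnp.ones((4,))}
batch_size = check_batch_axis(data, states)

# A state leaf with a different leading dimension is rejected.
try:
    check_batch_axis(data, {"score": jnp.ones((5,))})
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
```

Because the check reads only static shapes, it runs without a Python loop over elements, in line with the stated motivation for from_parts.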
get_elements ¤
Get multiple elements by indices or slice.
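On stacked PyTrees, selecting by indices or by slice amounts to indexing axis 0 of every leaf. The sketch below shows that idea on a raw dict; it is an illustration of the indexing pattern only, and what get_elements returns (for example a list of Element instances) is not specified here.

```python
import jax
import jax.numpy as jnp

stacked = {"x": jnp.arange(12.0).reshape(4, 3), "id": jnp.arange(4)}

# A slice and an index array both select along the batch axis of every leaf.
first_two = jax.tree.map(lambda leaf: leaf[0:2], stacked)
picked = jax.tree.map(lambda leaf: leaf[jnp.array([3, 1])], stacked)
```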
split_for_devices ¤
Split batch evenly across devices.
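One common way to split a stacked PyTree evenly is to reshape axis 0 into (num_shards, per_shard). The sketch below shows that layout under the assumption that the batch size is divisible by the shard count; split_evenly is a hypothetical helper, and the actual return type and behavior of split_for_devices are not documented in this section.

```python
import jax
import jax.numpy as jnp


def split_evenly(tree, num_shards):
    """Hypothetical helper: reshape axis 0 into (num_shards, per_shard)."""

    def split_leaf(leaf):
        batch_size = leaf.shape[0]
        if batch_size % num_shards != 0:
            raise ValueError(
                f"Batch size {batch_size} is not divisible by {num_shards}"
            )
        return leaf.reshape(num_shards, batch_size // num_shards, *leaf.shape[1:])

    return jax.tree.map(split_leaf, tree)


data = {"image": jnp.ones((8, 4, 4, 3)), "label": jnp.arange(8)}
sharded = split_evenly(data, num_shards=2)
```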
BatchOps ¤
BatchView ¤
Lightweight batch container for the hot iteration path.
A plain Python object (NOT an nnx.Module) that provides the same dict-like interface as Batch but without NNX Variable overhead. Uses `__slots__` for minimal memory footprint.
Used in the fused operator chain where we need:
- get_data() for adapter materialization
- `__getitem__`, `__contains__`, `__iter__` for dict-like access
- batch_size for consistency checks
- to_batch() for conversion when full NNX Batch features are needed
Creating a BatchView is essentially free (~0μs) compared to Batch which creates 5 nnx.Variable instances (~50-100μs).
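A slotted container with a dict-like interface can be sketched in a few lines. SketchBatchView is a hypothetical stand-in, not the datarax BatchView; its constructor signature is an assumption, and to_batch() is omitted because it depends on Batch.

```python
class SketchBatchView:
    """Hypothetical stand-in for BatchView; not the datarax class."""

    # __slots__ removes the per-instance __dict__, keeping instances small.
    __slots__ = ("_data", "batch_size")

    def __init__(self, data, batch_size):
        self._data = data
        self.batch_size = batch_size

    def get_data(self):
        return self._data

    def __getitem__(self, key):
        return self._data[key]

    def __contains__(self, key):
        return key in self._data

    def __iter__(self):
        return iter(self._data)


view = SketchBatchView({"image": [[0.0], [1.0]], "label": [0, 1]}, batch_size=2)
```

Construction here is just two attribute assignments, which illustrates why such a view is cheap compared with an object that allocates several wrapper variables.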
conditional_transform ¤
conditional_transform(batch: Batch, true_fn: Callable[[Batch], Batch], false_fn: Callable[[Batch], Batch], condition: Array) -> Batch
Conditional transformation with dynamic condition.
Uses jax.lax.cond for traced boolean conditions.
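The jax.lax.cond pattern named above looks like this on a plain dict PyTree. The sketch does not call conditional_transform; brighten, identity, and maybe_brighten are hypothetical names used to show how a traced condition selects between two branches under jit.

```python
import jax
import jax.numpy as jnp


def brighten(data):
    return jax.tree.map(lambda x: x + 1.0, data)


def identity(data):
    return data


@jax.jit
def maybe_brighten(data, condition):
    # `condition` is a traced boolean scalar; both branches must return
    # PyTrees with the same structure, shapes, and dtypes.
    return jax.lax.cond(condition, brighten, identity, data)


data = {"image": jnp.zeros((2, 3))}
on = maybe_brighten(data, jnp.array(True))
off = maybe_brighten(data, jnp.array(False))
```

A Python if statement would fail here because the condition has no concrete value during tracing; jax.lax.cond defers the choice to run time.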
iterative_transform ¤
Apply iterative transformation.
while_transform ¤
while_transform(batch: Batch, cond_fn: Callable[[Batch], bool], body_fn: Callable[[Batch], Batch], max_iterations: int = 100) -> Batch
Apply while loop transformation using nnx.while_loop.
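A bounded while loop can be sketched with jax.lax.while_loop, which is swapped in here for nnx.while_loop so the example runs on plain PyTrees without Flax. bounded_while is a hypothetical helper, and treating max_iterations as a hard cap on the number of body calls is an assumption about while_transform, not documented behavior.

```python
import jax
import jax.numpy as jnp


def bounded_while(data, cond_fn, body_fn, max_iterations=100):
    """Hypothetical helper: loop while cond_fn holds, at most max_iterations times."""

    def cond(carry):
        step, value = carry
        return jnp.logical_and(step < max_iterations, cond_fn(value))

    def body(carry):
        step, value = carry
        return step + 1, body_fn(value)

    # The step counter travels in the carry alongside the data.
    _, result = jax.lax.while_loop(cond, body, (jnp.array(0), data))
    return result


def below_fifty(d):
    return d["x"] < 50.0


def double(d):
    return {"x": d["x"] * 2.0}


data = {"x": jnp.array(1.0)}
doubled = bounded_while(data, below_fifty, double)
capped = bounded_while(data, below_fifty, double, max_iterations=3)
```

With the default cap the loop stops once x reaches 64; with max_iterations=3 it stops after three doublings at 8, even though the condition still holds.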