
Element and Batch Types

Core types for individual elements and batched data.



datarax.core.element_batch

Element and Batch modules following JAX and Flax NNX best practices.

Key design decisions:

  • Element uses flax.struct for immutability and automatic pytree registration
  • Batch uses Flax NNX Module pattern for state management
  • No object dtype arrays (JAX limitation)
  • Proper handling of static arguments in JIT compilation
  • Efficient vectorized operations without Python loops
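A rough sketch of the first design decision without Flax, assuming only `jax` is installed: a hypothetical `MiniElement` frozen dataclass registered with `jax.tree_util.register_pytree_node`, which approximates what `flax.struct.dataclass` automates. This is not the datarax Element; it only shows that a frozen, pytree-registered container can pass through `jax.jit` while rejecting attribute reassignment.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class MiniElement:
    """Hypothetical stand-in for Element (not the datarax class)."""

    data: dict
    state: dict = dataclasses.field(default_factory=dict)


def _flatten(elem):
    # Both fields are pytree children; no static aux data is needed.
    return (elem.data, elem.state), None


def _unflatten(_aux, children):
    data, state = children
    return MiniElement(data=data, state=state)


jax.tree_util.register_pytree_node(MiniElement, _flatten, _unflatten)


@jax.jit
def scale(elem: MiniElement) -> MiniElement:
    # Builds a new instance; the input element is never modified.
    new_data = jax.tree_util.tree_map(lambda v: v * 2.0, elem.data)
    return dataclasses.replace(elem, data=new_data)


elem = MiniElement(data={"x": jnp.array([1.0, 2.0])}, state={"count": 0})
out = scale(elem)

try:
    elem.data = {}  # frozen dataclass: reassignment raises
except dataclasses.FrozenInstanceError:
    print("fields are read-only")
```

Registration is what lets `jax.jit` flatten the element into arrays on the way in and rebuild it on the way out.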

logger module-attribute

logger = getLogger(__name__)

Element

Immutable data element with JAX-compatible operations.

Element represents a single data point with:

  • data: PyTree structure containing JAX arrays (supports nested dicts)
  • state: Dictionary of arbitrary Python values
  • metadata: Optional Metadata instance

All operations return new instances (immutable design).

data class-attribute instance-attribute

data: PyTree = field(default_factory=dict)

state class-attribute instance-attribute

state: dict[str, Any] = field(default_factory=dict)

metadata class-attribute instance-attribute

metadata: Metadata | None = field(default=None)

update_state

update_state(updates: dict[str, Any]) -> Element

Update state with partial updates (merge behavior).

update_data

update_data(updates: dict[str, Array]) -> Element

Update data with partial updates (merge behavior).
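Both update methods are described as merges that return a new Element. A minimal stdlib sketch of that behavior, assuming a shallow (top-level) dict merge; the source does not say whether nested dicts are merged recursively, and `MiniElement` here is a hypothetical stand-in:

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class MiniElement:
    data: dict[str, Any] = dataclasses.field(default_factory=dict)
    state: dict[str, Any] = dataclasses.field(default_factory=dict)

    def update_state(self, updates: dict[str, Any]) -> "MiniElement":
        # Shallow merge: keys in `updates` win, every other key is kept.
        return dataclasses.replace(self, state={**self.state, **updates})

    def update_data(self, updates: dict[str, Any]) -> "MiniElement":
        return dataclasses.replace(self, data={**self.data, **updates})


e1 = MiniElement(data={"x": 1, "y": 2}, state={"seen": 0})
e2 = e1.update_data({"y": 20}).update_state({"seen": 1, "flag": True})
# e1 is unchanged; e2 carries the merged dicts.
```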

transform

transform(fn: Callable[[Array], Array]) -> Element

Transform all data arrays with a function.

Note: this method cannot be JIT-compiled directly because fn must be a static argument. Use transform_element_jit with static function IDs instead.
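Why fn must be static: a Python callable is not an array, so `jax.jit` can only accept it as a static argument, and each distinct function object gets its own compilation. The sketch shows that constraint with plain `jax.jit` and `static_argnums`; `transform_data` is a hypothetical name, and this is not how `transform_element_jit` is implemented.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `fn` is hashed as a static argument; a new function object means a new compile.
@partial(jax.jit, static_argnums=0)
def transform_data(fn, data):
    return jax.tree_util.tree_map(fn, data)


def double(x):
    return x * 2.0


data = {"x": jnp.array([1.0, 2.0]), "y": jnp.array([3.0])}
out = transform_data(double, data)        # compiles once for `double`
out_again = transform_data(double, data)  # same fn and shapes: cached compile reused
```

Passing a freshly created lambda on every call defeats the cache, since each lambda is a new object; stable function identifiers avoid that.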

with_metadata

with_metadata(metadata: Metadata) -> Element

Return new Element with updated metadata.

apply_to_data

apply_to_data(fn: Callable[[Array], Array]) -> Element

Apply differentiable transformation to all data arrays.

This method preserves gradients through JAX transformations by applying the function directly to each array in the data dictionary using jax.tree.map.

Parameters:

  • fn (Callable[[Array], Array], required): Differentiable function to apply to each array

Returns:

  • Element: New Element with transformed data, preserving state and metadata

Examples:

```python
import jax.numpy as jnp
from datarax.core.element_batch import Element

element = Element(data={"x": jnp.array([1.0, 2.0])})
scaled = element.apply_to_data(lambda x: x * 2.0)
# Gradients flow through the scaling operation
```
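To check the gradient claim without datarax, the sketch applies the same per-leaf mapping to a plain dict (via `jax.tree_util.tree_map`) inside a loss and differentiates it with `jax.grad`. The free function `apply_to_data` here is a hypothetical analogue of the method, not its implementation.

```python
import jax
import jax.numpy as jnp


def apply_to_data(fn, data):
    # Hypothetical analogue: map fn over every array leaf of the data PyTree.
    return jax.tree_util.tree_map(fn, data)


def loss(data):
    scaled = apply_to_data(lambda x: x * 2.0, data)
    return jnp.sum(scaled["x"])


# d/dx sum(2 * x) = 2 for every entry
grads = jax.grad(loss)({"x": jnp.array([1.0, 2.0])})
```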

replace

replace(**kwargs: Any) -> Element

Return a new Element with specified fields replaced.

This method is provided by @struct.dataclass at runtime. The stub exists solely for static type-checking support.

Batch

Batch(elements: list[Element], validate: bool = True)

Bases: Module

Batch container using Flax NNX patterns.

Design rationale:

  • Stores data as stacked JAX PyTrees for efficiency
  • Stores states as stacked JAX PyTrees (enables vmap)
  • Stores metadata as a Python list (immutable, not vmapped)
  • Uses NNX Variables for mutable state management
  • All operations are JAX-compatible for JIT compilation
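The stacked-PyTree layout can be sketched with plain dicts standing in for elements (an assumption; the real constructor takes Element instances): every leaf is stacked along a new axis 0, after which `jax.vmap` maps over the batch directly.

```python
import jax
import jax.numpy as jnp

# Plain dicts stand in for per-element data in this sketch.
elements = [
    {"image": jnp.full((2, 2), float(i)), "label": jnp.array(i)}
    for i in range(4)
]

# Stack leaf-by-leaf: each leaf gains a leading batch axis of length len(elements).
batched = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *elements)

# With axis 0 as the batch axis everywhere, vmap maps over examples directly.
per_example_mean = jax.vmap(lambda ex: jnp.mean(ex["image"]))(batched)
```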

Parameters:

  • elements (list[Element], required): List of Element instances
  • validate (bool, default True): Whether to validate consistency

batch_size instance-attribute

batch_size = len(elements)

data instance-attribute

data = Variable(batched_data)

states instance-attribute

states = Variable(batched_states)

batch_state instance-attribute

batch_state = Variable({})

valid_mask instance-attribute

valid_mask: Variable[Array] = Variable(ones((batch_size,), dtype=bool_))

from_parts classmethod

from_parts(data: PyTree, states: PyTree, metadata_list: list[Any] | None = None, batch_metadata: Any | None = None, batch_state: PyTree | None = None, *, validate: bool = True, valid_mask: Array | None = None) -> Batch

Create Batch directly from pre-built parts with validation.

This is the recommended way to construct batches from transformed data in operators, as it avoids Python loops and validates structure.

Parameters:

  • data (PyTree, required): PyTree with arrays having the batch dimension as axis 0
  • states (PyTree, required): PyTree with arrays having the batch dimension as axis 0, following the same convention as data (its keys may differ from data's, as the examples show); every leaf must match the batch size
  • metadata_list (list[Any] | None, default None): Optional list of metadata; its length must match the batch size
  • batch_metadata (Any | None, default None): Optional batch-level metadata (immutable)
  • batch_state (PyTree | None, default None): Optional batch-level state PyTree (no batch dimension)
  • validate (bool, default True): If True, validates batch-axis consistency and lengths

Returns:

  • Batch: New Batch instance

Raises:

  • ValueError: If validation fails (mismatched batch sizes, inconsistent shapes)

Examples:

```python
import jax.numpy as jnp
from datarax.core.element_batch import Batch

# Simple flat PyTrees
data = {"image": jnp.ones((32, 224, 224, 3))}
states = {"count": jnp.zeros((32,)), "flag": jnp.ones((32,), dtype=bool)}
batch = Batch.from_parts(data, states)

# Nested PyTrees
data = {
    "vision": {"image": jnp.ones((32, 224, 224, 3))},
    "text": jnp.ones((32, 512)),
}
states = {
    "counters": {"augment": jnp.zeros((32,)), "transform": jnp.zeros((32,))},
    "score": jnp.ones((32,)),
}
batch = Batch.from_parts(data, states)
```
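The source does not show how validate=True checks batch-axis consistency. One plausible check, written as a hypothetical helper `check_batch_axis` (not part of datarax), gathers every leaf of data and states and requires a single shared axis-0 size:

```python
import jax
import jax.numpy as jnp


def check_batch_axis(data, states) -> int:
    """Hypothetical validation helper; returns the shared batch size."""
    leaves = jax.tree_util.tree_leaves((data, states))
    if not leaves:
        raise ValueError("no array leaves to infer a batch size from")
    if any(jnp.ndim(leaf) == 0 for leaf in leaves):
        raise ValueError("every leaf needs a batch axis (ndim >= 1)")
    sizes = {jnp.shape(leaf)[0] for leaf in leaves}
    if len(sizes) != 1:
        raise ValueError(f"inconsistent batch sizes on axis 0: {sorted(sizes)}")
    return sizes.pop()


data = {"vision": {"image": jnp.ones((8, 4, 4, 3))}, "text": jnp.ones((8, 16))}
states = {"counters": {"augment": jnp.zeros((8,))}, "score": jnp.ones((8,))}
batch_size = check_batch_axis(data, states)  # 8
```

Because only axis 0 is compared, data and states may have different keys and trailing shapes.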

get_element

get_element(index: int) -> Element

Extract single element at index.

get_elements

get_elements(indices: slice | list[int]) -> list[Element]

Get multiple elements by indices or slice.

slice

slice(start: int, end: int) -> Batch

Create a new batch from a slice of elements (O(1) view).
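Both get_element and slice reduce to per-leaf indexing over the batch axis. A sketch on a plain dict of arrays; `take_row` and `slice_rows` are hypothetical helper names, not the library's code:

```python
import jax
import jax.numpy as jnp

data = {"x": jnp.arange(10.0), "y": jnp.arange(20.0).reshape(10, 2)}


def take_row(tree, index: int):
    # get_element-style: index every leaf, which removes the batch axis.
    return jax.tree_util.tree_map(lambda leaf: leaf[index], tree)


def slice_rows(tree, start: int, end: int):
    # slice-style: keep the batch axis but restrict it to rows [start, end).
    return jax.tree_util.tree_map(lambda leaf: leaf[start:end], tree)


row = take_row(data, 3)
sub = slice_rows(data, 2, 5)
```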

split_for_devices

split_for_devices(num_devices: int) -> list[Batch]

Split batch evenly across devices.
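A sketch of an even split on a plain dict, assuming "evenly" means the batch size must divide exactly; the source does not say whether the real method raises, pads, or drops a remainder, and it returns Batch objects rather than dicts. `split_evenly` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def split_evenly(tree, num_devices: int) -> list:
    """Hypothetical helper: one sub-PyTree per device, split along axis 0."""
    batch_size = jax.tree_util.tree_leaves(tree)[0].shape[0]
    if batch_size % num_devices != 0:
        raise ValueError(f"batch size {batch_size} is not divisible by {num_devices}")
    chunk = batch_size // num_devices
    # tree_map runs immediately for each i, so each shard uses its own index.
    return [
        jax.tree_util.tree_map(lambda leaf: leaf[i * chunk:(i + 1) * chunk], tree)
        for i in range(num_devices)
    ]


data = {"x": jnp.arange(8.0), "y": jnp.ones((8, 2))}
shards = split_evenly(data, num_devices=4)
```

For `jax.pmap`, the same idea is often expressed instead as a reshape of each leaf to `(num_devices, batch_size // num_devices, ...)`.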

compute_stats

compute_stats() -> dict[str, Array]

Compute statistics over batch dimension.
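The source does not list which statistics are computed. A hypothetical sketch, assuming a flat data dict and per-key mean and standard deviation over axis 0 (the `_mean`/`_std` key naming is invented for illustration):

```python
import jax.numpy as jnp


def compute_stats(data: dict) -> dict:
    """Hypothetical: per-key mean and standard deviation over the batch axis."""
    stats = {}
    for key, value in data.items():
        stats[f"{key}_mean"] = jnp.mean(value, axis=0)
        stats[f"{key}_std"] = jnp.std(value, axis=0)
    return stats


data = {"x": jnp.array([[1.0, 2.0], [3.0, 4.0]])}
stats = compute_stats(data)  # x_mean = [2., 3.], x_std = [1., 1.]
```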

get_data

get_data() -> dict[str, Array]

Get batched data dictionary.

get_states

get_states() -> list[dict[str, Any]]

Get list of all states.

get_batch_state

get_batch_state() -> dict[str, Any]

Get batch-level state.

get_batch_metadata

get_batch_metadata() -> Metadata | None

Get batch-level metadata.

set_batch_metadata

set_batch_metadata(metadata: Metadata) -> None

Set batch-level metadata.

update_batch_state

update_batch_state(updates: dict[str, Any]) -> None

Update batch-level state (merge behavior).

BatchOps

Utility operations for batches.

select_batch_rows staticmethod

select_batch_rows(batch: Batch, mask: Array) -> Batch

Filter batch using boolean mask.

Uses JAX indexing on PyTree structures for efficiency.
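Boolean-mask selection on a PyTree amounts to indexing every leaf with the same mask. A sketch on a plain dict (an assumption; the real method takes and returns a Batch). Boolean indexing produces a data-dependent output shape, so this runs eagerly; under `jax.jit` the selection would need a static result size, for example via `jnp.nonzero(mask, size=...)`.

```python
import jax
import jax.numpy as jnp

data = {"x": jnp.arange(5.0), "y": jnp.arange(10.0).reshape(5, 2)}
mask = jnp.array([True, False, True, False, True])

# Same mask applied to every leaf, keeping rows where mask is True (eager, outside jit).
selected = jax.tree_util.tree_map(lambda leaf: leaf[mask], data)
```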

concatenate_batch_sequence staticmethod

concatenate_batch_sequence(batches: list[Batch]) -> Batch

Concatenate multiple batches.
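Concatenating batches amounts to a leaf-wise `jnp.concatenate` along axis 0, which requires every input to share one tree structure. A sketch with plain dicts standing in for batches (an assumption; the real method works on Batch objects):

```python
import jax
import jax.numpy as jnp

batches = [
    {"x": jnp.ones((2, 3)), "meta": {"id": jnp.array([0, 1])}},
    {"x": jnp.zeros((3, 3)), "meta": {"id": jnp.array([2, 3, 4])}},
]

# Leaf-wise concatenation along the batch axis.
merged = jax.tree_util.tree_map(lambda *leaves: jnp.concatenate(leaves, axis=0), *batches)
```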

update_batch_inplace staticmethod

update_batch_inplace(batch: Batch, data_updates: dict[str, Array]) -> Batch

Update batch data in place.

BatchView

BatchView(data: dict, states: dict, batch_size: int)

Lightweight batch container for the hot iteration path.

A plain Python object (NOT an nnx.Module) that provides the same dict-like interface as Batch but without the NNX Variable overhead. It uses __slots__ to keep its memory footprint minimal.

It is used in the fused operator chain, where we need:

  • get_data() for adapter materialization
  • __getitem__, __contains__, and __iter__ for dict-like access
  • batch_size for consistency checks
  • to_batch() for conversion when full NNX Batch features are needed

Creating a BatchView is essentially free (~0μs) compared to Batch, which creates 5 nnx.Variable instances (~50-100μs).
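A stdlib sketch of a slotted, dict-like wrapper in the same spirit. `MiniBatchView` is hypothetical, and it assumes the dict-like methods read from the data dictionary, which the source does not state explicitly. With `__slots__`, instances carry no per-instance `__dict__`.

```python
class MiniBatchView:
    """Hypothetical dict-like batch wrapper using __slots__ (not the datarax class)."""

    __slots__ = ("_data", "_states", "batch_size")

    def __init__(self, data: dict, states: dict, batch_size: int):
        self._data = data
        self._states = states
        self.batch_size = batch_size

    def get_data(self) -> dict:
        return self._data

    def __getitem__(self, key):
        return self._data[key]

    def __contains__(self, key) -> bool:
        return key in self._data

    def __iter__(self):
        return iter(self._data)


view = MiniBatchView({"image": [0.1, 0.2], "label": [0, 1]}, states={}, batch_size=2)
```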

batch_size instance-attribute

batch_size = batch_size

get_data

get_data() -> dict

Get batched data dictionary (same interface as Batch).

to_batch

to_batch() -> Batch

Convert to full NNX Batch when needed.

conditional_transform

conditional_transform(batch: Batch, true_fn: Callable[[Batch], Batch], false_fn: Callable[[Batch], Batch], condition: Array) -> Batch

Conditional transformation with dynamic condition.

Uses jax.lax.cond for traced boolean conditions.
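A sketch of the jax.lax.cond pattern on a plain dict rather than a Batch; `brighten`, `identity`, and `maybe_brighten` are hypothetical names. Both branches must return PyTrees with identical structure, shapes, and dtypes, and the traced condition is passed in as an array.

```python
import jax
import jax.numpy as jnp


def brighten(data):
    return {"image": data["image"] + 1.0}


def identity(data):
    return data


@jax.jit
def maybe_brighten(data, condition):
    # The condition is traced, so a Python `if` would fail here; lax.cond handles it.
    return jax.lax.cond(condition, brighten, identity, data)


data = {"image": jnp.zeros((2, 2))}
on = maybe_brighten(data, jnp.array(True))
off = maybe_brighten(data, jnp.array(False))
```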

iterative_transform

iterative_transform(batch: Batch, fn: Callable[[Batch, int], Batch], num_iterations: int) -> Batch

Apply iterative transformation.
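The source does not say how iterative_transform loops. One option is `jax.lax.fori_loop`, which traces the body once instead of unrolling in Python. The sketch works on a plain dict, and `iterate` and `add_index` are hypothetical names; note that fori_loop calls its body as `(i, carry)`, so the argument order is adapted to match the documented `fn(batch, i)`.

```python
import jax
import jax.numpy as jnp


def add_index(data, i):
    # Same argument order as iterative_transform's fn: (batch, iteration index).
    return {"x": data["x"] + i}


def iterate(fn, data, num_iterations: int):
    # lax.fori_loop calls body(i, carry), so swap the arguments for fn.
    return jax.lax.fori_loop(0, num_iterations, lambda i, d: fn(d, i), data)


out = iterate(add_index, {"x": jnp.zeros(3)}, num_iterations=4)  # 0+1+2+3 = 6
```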

while_transform

while_transform(batch: Batch, cond_fn: Callable[[Batch], bool], body_fn: Callable[[Batch], Batch], max_iterations: int = 100) -> Batch

Apply while loop transformation using nnx.while_loop.
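The source routes this through nnx.while_loop; the sketch instead swaps in plain `jax.lax.while_loop` on a dict and assumes max_iterations is enforced by carrying a counter alongside the data (the actual mechanism is not shown). `while_with_cap` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def while_with_cap(cond_fn, body_fn, data, max_iterations: int = 100):
    # Carry (iteration count, data) so the loop also stops at the cap.
    def cond(carry):
        i, d = carry
        return jnp.logical_and(i < max_iterations, cond_fn(d))

    def body(carry):
        i, d = carry
        return i + 1, body_fn(d)

    _, out = jax.lax.while_loop(cond, body, (jnp.array(0), data))
    return out


def below_100(d):
    return d["x"] < 100.0


def double(d):
    return {"x": d["x"] * 2.0}


start = {"x": jnp.array(1.0)}
uncapped = while_with_cap(below_100, double, start)                 # stops at 128.0
capped = while_with_cap(below_100, double, start, max_iterations=3)  # 1 -> 2 -> 4 -> 8
```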

create_element

create_element(data: dict[str, Array] | None = None, state: dict[str, Any] | None = None, metadata: Metadata | None = None) -> Element

Create element with defaults.

create_batch_from_arrays

create_batch_from_arrays(data: dict[str, Array], states: list[dict[str, Any]] | None = None, metadata_list: list[Metadata] | None = None) -> Batch

Create batch directly from pre-stacked arrays.