PyTree Utilities
JAX pytree manipulation utilities.
See Also
- Utils Overview - All utilities
- NNX Best Practices - JAX patterns
- Checkpointing - Pytree serialization
- Sharding - Pytree sharding
datarax.utils.pytree_utils
Utility functions for working with JAX PyTrees and Batches.
This module provides a collection of helper functions to inspect, manipulate, and validate JAX PyTrees and Datarax Batch objects. It encompasses:
- Type checking utilities (is_array, is_container, is_jax_array)
- Batch-specific leaf predicates (is_batch_leaf, is_non_jax_leaf)
- Batch manipulation helpers (split_batch_for_devices, concatenate_batch_sequence)
- Dimensionality transformations (add/remove batch dimensions)
- Structure introspection and consistency validation
is_array
is_container
is_jax_array
is_batch_leaf
Check if x should be treated as a leaf for batching operations.
Arrays and non-containers are leaves. Containers (dict, list, tuple) are traversed as pytree structure.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Any | Value to check | required |
Returns:
| Type | Description |
|---|---|
| bool | True if x is a leaf, False otherwise |
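To make the leaf rule concrete, the sketch below re-implements it as a hypothetical `is_batch_leaf_sketch` predicate (not the library function) and passes it to `jax.tree.leaves`. It assumes only what the description states: dict, list and tuple are traversed, and everything else is a leaf.

```python
import jax
import jax.numpy as jnp


def is_batch_leaf_sketch(x) -> bool:
    """Hypothetical predicate mirroring the documented is_batch_leaf rule."""
    # dict, list and tuple are containers, so JAX traverses them;
    # arrays and every other value (ints, strings, ...) are leaves.
    return not isinstance(x, (dict, list, tuple))


nested = {"image": jnp.zeros((4, 2)), "meta": {"ids": [1, 2, 3], "tag": "train"}}
leaves = jax.tree.leaves(nested, is_leaf=is_batch_leaf_sketch)
# Dict keys are visited in sorted order: image, meta.ids[0..2], meta.tag.
print(len(leaves), leaves[1:])  # 5 [1, 2, 3, 'train']
```

Under this rule the `ids` list is a container, so its three integers become separate leaves. The rebatching predicate documented next treats that same list as a single payload instead.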
is_non_jax_leaf
Check if x should be treated as a leaf for rebatching operations.
For rebatching, Python lists and tuples are treated as atomic data payloads (e.g., [0, 1, 2] as batch labels) rather than containers. Only dicts are traversed as pytree structure.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Any | Value to check | required |
Returns:
| Type | Description |
|---|---|
| bool | True if x should be treated as a leaf, False otherwise |
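Examples:
```python
import jax
import jax.numpy as jnp
from datarax.utils.pytree_utils import is_non_jax_leaf
batch = {"data": jnp.array([1, 2]), "labels": [0, 1]}
# The labels list is passed to the mapped function as one leaf; it is not traversed
out = jax.tree.map(lambda leaf: leaf, batch, is_leaf=is_non_jax_leaf)
```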
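The difference from JAX's default traversal can be seen side by side. The sketch below uses a hypothetical `is_non_jax_leaf_sketch` that encodes only the documented rule (only dicts are traversed); it is an illustration, not the library's implementation.

```python
import jax
import jax.numpy as jnp


def is_non_jax_leaf_sketch(x) -> bool:
    """Hypothetical predicate: only dicts are containers; lists/tuples are payloads."""
    return not isinstance(x, dict)


batch = {"data": jnp.array([1, 2]), "labels": [0, 1]}

# Default traversal descends into the list: three leaves (the data array, 0 and 1).
default_leaves = jax.tree.leaves(batch)
# With the predicate the list stays whole: two leaves (the data array and [0, 1]).
rebatch_leaves = jax.tree.leaves(batch, is_leaf=is_non_jax_leaf_sketch)

# Every field can now report its batch length, including the Python list.
lengths = jax.tree.map(len, batch, is_leaf=is_non_jax_leaf_sketch)
print(lengths)  # {'data': 2, 'labels': 2}
```

Without the predicate, `jax.tree.map(len, batch)` would call `len(0)` on an individual label and raise a `TypeError`, which is why rebatching code needs list payloads kept intact.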
get_batch_size
Extract batch size from a Batch object or plain dict.
Supports both proper Batch objects (with .batch_size property) and plain dicts containing JAX arrays (infers size from first axis).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Batch \| dict | Batch object or dict containing JAX arrays | required |
Returns:
| Type | Description |
|---|---|
| int \| None | Batch size, or None if the batch size cannot be determined |
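The two lookup paths described above can be sketched as follows. `infer_batch_size` is a hypothetical stand-in, not `get_batch_size` itself. It assumes that "first axis" means the leading dimension of the first non-scalar JAX array in traversal order (dict keys sorted), and that any object exposing `.batch_size` should be trusted directly.

```python
from typing import Any, Optional

import jax
import jax.numpy as jnp


def infer_batch_size(batch: Any) -> Optional[int]:
    """Hypothetical sketch of get_batch_size's lookup order."""
    # Batch objects expose the size directly via a .batch_size property.
    if hasattr(batch, "batch_size"):
        return batch.batch_size
    # Plain dicts: use the leading axis of the first non-scalar JAX array.
    for leaf in jax.tree.leaves(batch):
        if isinstance(leaf, jax.Array) and leaf.ndim > 0:
            return int(leaf.shape[0])
    # No suitable array found, so the size cannot be determined.
    return None


example = {"image": jnp.zeros((8, 28, 28)), "label": jnp.zeros(8, dtype=jnp.int32)}
print(infer_batch_size(example))  # 8
```

Skipping 0-d arrays is a design choice of this sketch: a scalar leaf has no batch axis, so returning `None` is safer than raising an `IndexError`.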
is_single_element
add_batch_dimension
remove_batch_dimension
Remove batch dimension from a batch of size 1.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Batch | Batch with size 1 | required |
Returns:
| Type | Description |
|---|---|
| Element | Single element extracted from the batch |
Raises:
| Type | Description |
|---|---|
| ValueError | If the batch size is not 1 |
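For plain dicts of arrays, the add/remove round trip can be sketched with `jax.tree.map`. Both functions below are hypothetical dict-based illustrations: the documented functions operate on Batch and Element objects, and `add_batch_dimension` has no description on this page, so its behavior here is an assumption based on the module overview. The sketch also assumes every leaf is a JAX array.

```python
import jax
import jax.numpy as jnp


def add_batch_dim_sketch(element: dict) -> dict:
    """Hypothetical: turn a single element into a batch of size 1."""
    return jax.tree.map(lambda x: jnp.expand_dims(x, axis=0), element)


def remove_batch_dim_sketch(batch: dict) -> dict:
    """Hypothetical: extract the single element from a size-1 batch."""
    sizes = {leaf.shape[0] for leaf in jax.tree.leaves(batch)}
    if sizes != {1}:
        # Mirrors the documented ValueError for batch sizes other than 1.
        raise ValueError(f"expected batch size 1, got leading sizes {sorted(sizes)}")
    return jax.tree.map(lambda x: x[0], batch)


element = {"image": jnp.zeros((28, 28)), "label": jnp.array(3)}
batch = add_batch_dim_sketch(element)
print(batch["image"].shape, batch["label"].shape)  # (1, 28, 28) (1,)
restored = remove_batch_dim_sketch(batch)
print(restored["image"].shape, restored["label"].shape)  # (28, 28) ()
```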
split_batch_for_devices
Split a batch into multiple smaller batches.
Delegates to Batch.split_for_devices() to keep a single, consistent implementation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Batch | Batch to split | required |
| num_splits | int | Number of splits to create | required |
Returns:
| Type | Description |
|---|---|
| list[Batch] | List of smaller batches |
Raises:
| Type | Description |
|---|---|
| ValueError | If the batch size is not divisible by num_splits |
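The splitting contract (equal chunks along the batch axis, `ValueError` on a non-divisible size) can be sketched for a plain dict of arrays. `split_dict_batch` is hypothetical and not the `Batch.split_for_devices()` implementation; it assumes every leaf shares the leading batch dimension of the first leaf.

```python
import jax
import jax.numpy as jnp


def split_dict_batch(batch: dict, num_splits: int) -> list:
    """Hypothetical: split every leaf into num_splits equal chunks along axis 0."""
    batch_size = jax.tree.leaves(batch)[0].shape[0]
    if batch_size % num_splits != 0:
        raise ValueError(
            f"batch size {batch_size} is not divisible by num_splits={num_splits}"
        )
    chunk = batch_size // num_splits
    # Slice the same [start, stop) range out of every leaf so the chunks stay
    # aligned; the lambda runs immediately, so capturing i is safe here.
    return [
        jax.tree.map(lambda x: x[i * chunk:(i + 1) * chunk], batch)
        for i in range(num_splits)
    ]


batch = {"image": jnp.zeros((8, 32, 32, 3)), "label": jnp.arange(8)}
shards = split_dict_batch(batch, num_splits=4)
print(len(shards), shards[1]["label"].tolist())  # 4 [2, 3]
```

This only produces the per-device chunks; placing them on devices would be a separate step.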
concatenate_batch_sequence
Concatenate multiple batches into a single batch.
Delegates to BatchOps.concatenate_batch_sequence() to keep a single, consistent implementation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batches | list[Batch] | List of Batch objects with the same structure | required |
Returns:
| Type | Description |
|---|---|
| Batch | Single concatenated batch |
Raises:
| Type | Description |
|---|---|
| ValueError | If the batches list is empty |
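Concatenation is the inverse of splitting and maps naturally onto a multi-tree `jax.tree.map`. The sketch below uses a hypothetical `concat_dict_batches` on plain dicts; it is not the `BatchOps.concatenate_batch_sequence()` implementation, only an illustration of the documented contract (same structure required, empty input rejected).

```python
import jax
import jax.numpy as jnp


def concat_dict_batches(batches: list) -> dict:
    """Hypothetical: concatenate same-structure dict batches along axis 0."""
    if not batches:
        raise ValueError("batches list is empty")
    # jax.tree.map visits matching leaves of all batches together;
    # batches with different structures make it raise an error.
    return jax.tree.map(lambda *xs: jnp.concatenate(xs, axis=0), *batches)


first = {"image": jnp.zeros((2, 4)), "label": jnp.array([0, 1])}
second = {"image": jnp.ones((3, 4)), "label": jnp.array([2, 3, 4])}
merged = concat_dict_batches([first, second])
print(merged["image"].shape, merged["label"].tolist())  # (5, 4) [0, 1, 2, 3, 4]
```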
apply_to_batch_dimension
apply_to_batch_dimension(batch: Batch, fn: Callable[..., Array], axis: int = 0, keepdims: bool = False) -> dict[str, Array]
Apply a reduction function along the batch dimension.
Uses jax.tree.map for idiomatic PyTree traversal.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Batch | Batch object | required |
| fn | Callable[..., Array] | Reduction function to apply (e.g., jnp.mean, jnp.std, jnp.sum) | required |
| axis | int | Axis along which to apply fn (0 is the batch axis) | 0 |
| keepdims | bool | Whether to keep the reduced dimension | False |
Returns:
| Type | Description |
|---|---|
| dict[str, Array] | Dictionary mapping each data field to the result of fn applied along the batch dimension |
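The reduction pattern can be sketched with `jax.tree.map` over a plain dict of data fields. `reduce_fields` is a hypothetical helper: the documented function takes a Batch, while this sketch assumes the data fields are already a dict of arrays and that `fn` accepts `axis` and `keepdims` keyword arguments, as `jnp.mean`, `jnp.std` and `jnp.sum` do.

```python
import jax
import jax.numpy as jnp


def reduce_fields(data: dict, fn, axis: int = 0, keepdims: bool = False) -> dict:
    """Hypothetical: apply a reduction along the batch axis of every field."""
    return jax.tree.map(lambda x: fn(x, axis=axis, keepdims=keepdims), data)


data = {
    "image": jnp.arange(12.0).reshape(4, 3),  # 4 examples, 3 features each
    "score": jnp.array([1.0, 2.0, 3.0, 4.0]),
}
means = reduce_fields(data, jnp.mean)
print(means["image"].tolist(), float(means["score"]))  # [4.5, 5.5, 6.5] 2.5

# keepdims=True leaves a size-1 batch axis, handy for broadcasting back.
sums = reduce_fields(data, jnp.sum, keepdims=True)
print(sums["image"].shape)  # (1, 3)
```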