
PyTree Utilities

JAX pytree manipulation utilities.

datarax.utils.pytree_utils

Utility functions for working with JAX PyTrees and Batches.

This module provides a collection of helper functions to inspect, manipulate, and validate JAX PyTrees and Datarax Batch objects. It encompasses:

  • Type checking utilities (is_array, is_container, is_jax_array)
  • Batch-specific leaf predicates (is_batch_leaf, is_non_jax_leaf)
  • Batch manipulation helpers (split_batch_for_devices, concatenate_batch_sequence)
  • Dimensionality transformations (add/remove batch dimensions)
  • Structure introspection and consistency validation

logger (module attribute)

logger = getLogger(__name__)

is_array

is_array(x: Any) -> bool

Check if x is a JAX or NumPy array.

Parameters:

  • x (Any): Value to check (required).

Returns:

  • bool: True if x is a JAX or NumPy array, False otherwise.
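
As a rough illustration, the check can be approximated with isinstance against jax.Array and np.ndarray. This is a hypothetical re-implementation: is_array_sketch is not part of datarax. It also assumes that Python scalars and lists do not count as arrays, which the description above implies but does not state outright.

```python
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np


def is_array_sketch(x: Any) -> bool:
    """Hypothetical stand-in for is_array: True for JAX and NumPy arrays only."""
    # jax.Array is the public type of JAX arrays; np.ndarray covers NumPy arrays
    return isinstance(x, (jax.Array, np.ndarray))


print(is_array_sketch(jnp.ones(3)))   # True
print(is_array_sketch(np.zeros(2)))   # True
print(is_array_sketch([1, 2, 3]))     # False: a list is a container, not an array
print(is_array_sketch(3.0))           # False (assumed behaviour for Python scalars)
```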

is_container

is_container(x: Any) -> bool

Check if x is a pytree container type.

Parameters:

  • x (Any): Value to check (required).

Returns:

  • bool: True if x is a dict, list, or tuple, False otherwise.
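
A minimal sketch of the same test follows. is_container_sketch is a hypothetical name. Because it uses isinstance, subclasses such as namedtuple also pass, and the page does not say whether the datarax helper accepts them.

```python
from typing import Any


def is_container_sketch(x: Any) -> bool:
    """Hypothetical stand-in for is_container: built-in dict, list and tuple only."""
    return isinstance(x, (dict, list, tuple))


print(is_container_sketch({"a": 1}))  # True
print(is_container_sketch((1, 2)))    # True
print(is_container_sketch("abc"))     # False: strings are not treated as containers
```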

is_jax_array

is_jax_array(x: Any) -> bool

Check if x is a JAX array or a compatible numeric type.

Parameters:

  • x (Any): Value to check (required).

Returns:

  • bool: True if x is JAX-compatible, False otherwise.

is_batch_leaf

is_batch_leaf(x: Any) -> bool

Check if x should be treated as a leaf for batching operations.

Arrays and other non-container values are leaves. Containers (dict, list, tuple) are traversed as pytree nodes.

Parameters:

  • x (Any): Value to check (required).

Returns:

  • bool: True if x is a leaf, False otherwise.
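
This predicate is meant to be passed as is_leaf to JAX tree functions. The sketch below uses a hypothetical stand-in, is_batch_leaf_sketch, that follows the rule stated above. It shows that a Python list inside a batch is traversed, so every list element becomes its own leaf.

```python
from typing import Any

import jax
import jax.numpy as jnp


def is_batch_leaf_sketch(x: Any) -> bool:
    """Hypothetical stand-in for is_batch_leaf: anything but dict/list/tuple is a leaf."""
    return not isinstance(x, (dict, list, tuple))


batch = {"image": jnp.ones((4, 2)), "meta": {"ids": [0, 1, 2, 3]}}

# Containers are traversed, so the four ints inside "ids" are separate leaves
leaves = jax.tree_util.tree_leaves(batch, is_leaf=is_batch_leaf_sketch)
print(len(leaves))  # 5: the image array plus four ints
```

Contrast this with is_non_jax_leaf, which keeps the "ids" list whole.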

is_non_jax_leaf

is_non_jax_leaf(x: Any) -> bool

Check if x should be treated as a leaf for rebatching operations.

For rebatching, Python lists and tuples are treated as atomic data payloads (e.g., [0, 1, 2] as batch labels) rather than containers. Only dicts are traversed as pytree nodes.

Parameters:

  • x (Any): Value to check (required).

Returns:

  • bool: True if x should be treated as a leaf, False otherwise.

Examples:

    batch = {"data": jnp.array([1, 2]), "labels": [0, 1]}
    jax.tree.map(fn, batch, is_leaf=is_non_jax_leaf)
    # labels list is treated as a leaf, not traversed
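
The effect of this predicate is easiest to see by comparing it with JAX's default traversal. The sketch uses a hypothetical stand-in, is_non_jax_leaf_sketch, built from the rule above (only dicts are traversed). Slicing with x[:1] then works on both the array and the labels list. Without the predicate, the lambda would receive the bare ints 0 and 1 and fail.

```python
from typing import Any

import jax
import jax.numpy as jnp


def is_non_jax_leaf_sketch(x: Any) -> bool:
    """Hypothetical stand-in for is_non_jax_leaf: only dicts are traversed."""
    return not isinstance(x, dict)


batch = {"data": jnp.array([1, 2]), "labels": [0, 1]}

default_leaves = jax.tree_util.tree_leaves(batch)
rebatch_leaves = jax.tree_util.tree_leaves(batch, is_leaf=is_non_jax_leaf_sketch)
print(len(default_leaves))  # 3: the array, then 0 and 1 from the list
print(len(rebatch_leaves))  # 2: the array and the whole [0, 1] list

# Take the first example of every field, treating the list as one payload
first = jax.tree_util.tree_map(lambda x: x[:1], batch, is_leaf=is_non_jax_leaf_sketch)
print(first["labels"])  # [0]
```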

get_batch_size

get_batch_size(batch: Batch | dict) -> int | None

Extract the batch size from a Batch object or a plain dict.

Supports both Batch objects (read from the .batch_size property) and plain dicts of JAX arrays (size inferred from the leading axis).

Parameters:

  • batch (Batch | dict): Batch object or dict containing JAX arrays (required).

Returns:

  • int | None: Batch size, or None if it cannot be determined.
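
For the plain-dict case, a minimal sketch reads the leading axis of the first leaf that has one. batch_size_from_dict is a hypothetical helper, and two of its choices are assumptions, not documented behaviour: it skips scalar leaves, and it returns None when no array with a leading axis exists.

```python
from __future__ import annotations

from typing import Any

import jax
import jax.numpy as jnp


def batch_size_from_dict(batch: dict[str, Any]) -> int | None:
    """Hypothetical helper: size of axis 0 of the first array leaf, else None."""
    for leaf in jax.tree_util.tree_leaves(batch):
        # Skip leaves without a leading axis (Python scalars, 0-d arrays)
        if hasattr(leaf, "shape") and len(leaf.shape) > 0:
            return int(leaf.shape[0])
    return None  # no array with a batch axis was found


print(batch_size_from_dict({"x": jnp.zeros((8, 3)), "y": jnp.arange(8)}))  # 8
print(batch_size_from_dict({"step": 3}))  # None
```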

is_single_element

is_single_element(data: Element | Batch) -> bool

Determine whether data is a single element or a batch.

Parameters:

  • data (Element | Batch): Element or Batch object to check (required).

Returns:

  • bool: True if data is a single element, False if it is a batch.

add_batch_dimension

add_batch_dimension(element: Element) -> Batch

Add a batch dimension to a single element by wrapping it in a Batch.

Parameters:

  • element (Element): Single element (required).

Returns:

  • Batch: Batch of size 1 containing the element.

remove_batch_dimension

remove_batch_dimension(batch: Batch) -> Element

Remove the batch dimension from a batch of size 1.

Parameters:

  • batch (Batch): Batch of size 1 (required).

Returns:

  • Element: Single element extracted from the batch.

Raises:

  • ValueError: If the batch size is not 1.
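
On plain dicts of arrays, add_batch_dimension and remove_batch_dimension reduce to inserting and dropping a leading axis of size 1. The sketch below illustrates that round trip with hypothetical helpers. It does not build datarax Element or Batch objects, whose constructors are not documented on this page.

```python
import jax
import jax.numpy as jnp


def add_batch_dim_sketch(element: dict) -> dict:
    """Hypothetical: give every array leaf a leading axis of size 1."""
    return jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, axis=0), element)


def remove_batch_dim_sketch(batch: dict) -> dict:
    """Hypothetical: drop the leading axis, which must have size 1 on every leaf."""
    sizes = {x.shape[0] for x in jax.tree_util.tree_leaves(batch)}
    if sizes != {1}:
        raise ValueError(f"Expected batch size 1, got leading sizes {sorted(sizes)}")
    return jax.tree_util.tree_map(lambda x: x[0], batch)


element = {"image": jnp.zeros((28, 28)), "label": jnp.array(3)}
batch = add_batch_dim_sketch(element)
print(batch["image"].shape, batch["label"].shape)  # (1, 28, 28) (1,)
restored = remove_batch_dim_sketch(batch)
print(restored["image"].shape, restored["label"].shape)  # (28, 28) ()
```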

split_batch_for_devices

split_batch_for_devices(batch: Batch, num_splits: int) -> list[Batch]

Split a batch into num_splits smaller batches of equal size.

Delegates to Batch.split_for_devices(), so both entry points share one implementation.

Parameters:

  • batch (Batch): Batch to split (required).
  • num_splits (int): Number of splits to create (required).

Returns:

  • list[Batch]: List of smaller batches.

Raises:

  • ValueError: If the batch size is not divisible by num_splits.
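
The splitting rule can be sketched on a plain dict of arrays: check divisibility, then slice every leaf into equal chunks along axis 0. split_dict_batch is a hypothetical helper, not Batch.split_for_devices(). It also assumes every leaf shares the leading size of the first one.

```python
import jax
import jax.numpy as jnp


def split_dict_batch(batch: dict, num_splits: int) -> list[dict]:
    """Hypothetical: cut every leaf into num_splits equal chunks along axis 0."""
    size = jax.tree_util.tree_leaves(batch)[0].shape[0]
    if size % num_splits != 0:
        raise ValueError(f"Batch size {size} is not divisible by num_splits={num_splits}")
    chunk = size // num_splits
    # tree_map runs eagerly inside each iteration, so i is the current index
    return [
        jax.tree_util.tree_map(lambda x: x[i * chunk:(i + 1) * chunk], batch)
        for i in range(num_splits)
    ]


batch = {"x": jnp.arange(8).reshape(8, 1), "y": jnp.arange(8)}
shards = split_dict_batch(batch, num_splits=4)
print(len(shards), shards[0]["x"].shape)  # 4 (2, 1)
# shards[1]["y"] holds the values 2 and 3
```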

concatenate_batch_sequence

concatenate_batch_sequence(batches: list[Batch]) -> Batch

Concatenate multiple batches into a single batch.

Delegates to BatchOps.concatenate_batch_sequence(), so both entry points share one implementation.

Parameters:

  • batches (list[Batch]): List of Batch objects with the same structure (required).

Returns:

  • Batch: Single concatenated batch.

Raises:

  • ValueError: If the batches list is empty.
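
With dicts of arrays, concatenation is a multi-tree jax.tree_util.tree_map. The map zips the matching leaves of every batch and joins them along axis 0, and it raises if the structures differ. concat_dict_batches is a hypothetical helper that only mirrors the documented contract, including the ValueError on an empty list. It is not the BatchOps implementation.

```python
import jax
import jax.numpy as jnp


def concat_dict_batches(batches: list[dict]) -> dict:
    """Hypothetical: join same-structured batches leaf by leaf along axis 0."""
    if not batches:
        raise ValueError("Cannot concatenate an empty list of batches")
    # tree_map zips the leaves of all batches; their structures must match exactly
    return jax.tree_util.tree_map(lambda *xs: jnp.concatenate(xs, axis=0), *batches)


a = {"x": jnp.zeros((2, 3)), "y": jnp.array([0, 1])}
b = {"x": jnp.ones((3, 3)), "y": jnp.array([2, 3, 4])}
merged = concat_dict_batches([a, b])
print(merged["x"].shape)  # (5, 3)
```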

apply_to_batch_dimension

apply_to_batch_dimension(batch: Batch, fn: Callable[..., Array], axis: int = 0, keepdims: bool = False) -> dict[str, Array]

Apply a reduction function along the batch dimension.

Traverses the batch with jax.tree.map.

Parameters:

  • batch (Batch): Batch object (required).
  • fn (Callable[..., Array]): Reduction function to apply, e.g. jnp.mean, jnp.std, or jnp.sum (required).
  • axis (int): Axis along which to apply fn; 0 is the batch axis (default: 0).
  • keepdims (bool): Whether to keep the reduced dimension (default: False).

Returns:

  • dict[str, Array]: Dictionary mapping each data field to the result of fn applied along the batch dimension.
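
The core of this function can be sketched as a single tree_map over the data fields. The sketch takes a plain dict because the Batch accessor for data fields is not shown on this page. reduce_over_batch is a hypothetical name, and the sketch assumes fn accepts axis and keepdims keyword arguments, as jnp.mean, jnp.std, and jnp.sum do.

```python
from typing import Callable

import jax
import jax.numpy as jnp


def reduce_over_batch(data: dict, fn: Callable[..., jax.Array], axis: int = 0,
                      keepdims: bool = False) -> dict:
    """Hypothetical: apply a jnp-style reduction to every field of a data dict."""
    return jax.tree_util.tree_map(lambda x: fn(x, axis=axis, keepdims=keepdims), data)


data = {"x": jnp.arange(6.0).reshape(3, 2), "y": jnp.array([1.0, 2.0, 3.0])}
means = reduce_over_batch(data, jnp.mean)
# means["x"] is the per-feature mean over the 3 examples: [2.0, 3.0]; means["y"] is 2.0
sums = reduce_over_batch(data, jnp.sum, keepdims=True)
# keepdims=True keeps the batch axis, so sums["x"] has shape (1, 2)
```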

is_batch_consistent

is_batch_consistent(batch: Batch) -> bool

Validate that all arrays in a batch have consistent batch dimensions.

Parameters:

  • batch (Batch): Batch object to validate (required).

Returns:

  • bool: True if the batch is consistent, False otherwise.
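
One plausible reading of "consistent" is that every array shares the same leading (axis 0) size. The sketch below checks exactly that on a plain dict. leading_dims_consistent is hypothetical, and so are its rules of ignoring 0-d leaves and treating an empty tree as consistent. The datarax function may apply stricter or different rules.

```python
import jax
import jax.numpy as jnp


def leading_dims_consistent(data: dict) -> bool:
    """Hypothetical: True if every array leaf shares the same axis-0 size."""
    sizes = {
        leaf.shape[0]
        for leaf in jax.tree_util.tree_leaves(data)
        if hasattr(leaf, "shape") and len(leaf.shape) > 0
    }
    # An empty tree, or one with a single leading size, counts as consistent
    return len(sizes) <= 1


print(leading_dims_consistent({"x": jnp.zeros((4, 2)), "y": jnp.zeros(4)}))  # True
print(leading_dims_consistent({"x": jnp.zeros((4, 2)), "y": jnp.zeros(3)}))  # False
```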

get_pytree_structure_info

get_pytree_structure_info(data: Element | Batch) -> dict[str, Any]

Summarize the structure of an Element or Batch for debugging.

Parameters:

  • data (Element | Batch): Element or Batch to analyze (required).

Returns:

  • dict[str, Any]: Dictionary with structure information.
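
The page does not list the keys of the returned dictionary. The sketch below shows one way to build a similar debugging summary from jax.tree_util.tree_flatten. structure_info_sketch and its keys (num_leaves, treedef, shapes, dtypes) are hypothetical and should not be read as the datarax output format.

```python
from typing import Any

import jax
import jax.numpy as jnp


def structure_info_sketch(data: Any) -> dict[str, Any]:
    """Hypothetical debugging summary; the real function's keys are not documented here."""
    leaves, treedef = jax.tree_util.tree_flatten(data)
    return {
        "num_leaves": len(leaves),
        "treedef": str(treedef),
        # Shape and dtype per leaf, in flattening order (dict keys are sorted)
        "shapes": [tuple(getattr(x, "shape", ())) for x in leaves],
        "dtypes": [str(getattr(x, "dtype", type(x).__name__)) for x in leaves],
    }


info = structure_info_sketch({"image": jnp.zeros((2, 8, 8)), "label": jnp.array([0, 1])})
print(info["num_leaves"], info["shapes"])  # 2 [(2, 8, 8), (2,)]
```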