
Array Sharder

Shard arrays across devices for distributed processing.

datarax.sharding.array_sharder

Array sharder implementation for Datarax.

This module provides a unified implementation of the SharderModule that handles sharding of JAX arrays across devices. Supports both static method usage and NNX module instantiation.

logger module-attribute

logger = getLogger(__name__)

ArraySharder

ArraySharder(config: SharderModuleConfig | None = None, *, rngs: Rngs | None = None)

Bases: SharderModule

Unified array sharding implementation for Datarax.

This class provides methods for sharding JAX arrays across devices, with support for both static method usage and NNX module instantiation. It also supports logical axis naming and advanced sharding operations.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | SharderModuleConfig \| None | Optional configuration with sharding_rules mapping from logical axis names to physical device mesh axis names. | None |
| rngs | Rngs \| None | Optional Rngs for random operations (typically not needed for sharders). | None |

config instance-attribute

config: SharderModuleConfig = config

rngs instance-attribute

rngs = rngs

name instance-attribute

name = static(name)

shard

shard(batch: Batch, sharding: Sharding) -> Batch

Distribute a batch of data across JAX devices.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Batch | A batch of data elements (PyTree). | required |
| sharding | Sharding | A JAX Sharding object specifying how to distribute arrays. | required |

Returns:

| Type | Description |
| --- | --- |
| Batch | A sharded batch (PyTree of jax.Array objects with the specified sharding). |
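As a rough illustration of what distributing a PyTree batch involves, the sketch below uses plain JAX rather than ArraySharder itself. It assumes that a single jax.device_put call with a NamedSharding is a reasonable stand-in for shard; the mesh axis name "data" and the batch keys are invented for the example and are not taken from Datarax.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over every visible device; "data" is an illustrative axis name.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Split the leading (batch) axis across the "data" mesh axis; other axes are replicated.
sharding = NamedSharding(mesh, PartitionSpec("data"))

# The leading dimension must be divisible by the number of devices.
n = len(devices)
batch = {"image": jnp.ones((4 * n, 8)), "label": jnp.zeros((4 * n,))}

# device_put applies the same sharding to every leaf of the PyTree.
sharded = jax.device_put(batch, sharding)
```

Each leaf of the result keeps its global shape, while every device holds only its own slice of the leading axis.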

shard_static staticmethod

shard_static(batch: Batch, sharding: Sharding) -> Batch

Static method to distribute a batch of data across JAX devices.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Batch | A batch of data elements (PyTree). | required |
| sharding | Sharding | A JAX Sharding object specifying how to distribute arrays. | required |

Returns:

| Type | Description |
| --- | --- |
| Batch | A sharded batch (PyTree of jax.Array objects with the specified sharding). |

shard_with_info

shard_with_info(batch: Batch, sharding: Sharding, info: dict[str, Any] | None = None) -> Batch

Return sharded batch with debug information.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Batch | A batch of data elements (PyTree). | required |
| sharding | Sharding | A JAX Sharding object specifying how to distribute arrays. | required |
| info | dict[str, Any] \| None | Optional dictionary to update with sharding info. | None |

Returns:

| Type | Description |
| --- | --- |
| Batch | A sharded batch (PyTree of jax.Array objects with the specified sharding). |

shard_with_info_static staticmethod

shard_with_info_static(batch: Batch, sharding: Sharding, info: dict[str, Any] | None = None) -> Batch

Static method to return sharded batch with debug information.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Batch | A batch of data elements (PyTree). | required |
| sharding | Sharding | A JAX Sharding object specifying how to distribute arrays. | required |
| info | dict[str, Any] \| None | Optional dictionary to update with sharding info. | None |

Returns:

| Type | Description |
| --- | --- |
| Batch | A sharded batch (PyTree of jax.Array objects with the specified sharding). |

shard_with_logical_names

shard_with_logical_names(batch: Batch, mesh: Mesh, logical_spec: LogicalAxisSpec | PartitionSpec) -> Batch

Shard a batch using logical axis names.

This method allows using more descriptive logical axis names instead of device mesh axis names directly.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Batch | A batch of data elements (PyTree). | required |
| mesh | Mesh | The device mesh to use for sharding. | required |
| logical_spec | LogicalAxisSpec \| PartitionSpec | A tuple of logical axis names or a PartitionSpec. | required |

Returns:

| Type | Description |
| --- | --- |
| Batch | A sharded batch (PyTree of jax.Array objects). |

apply_parallel_transform

apply_parallel_transform(batch: Batch, transform_fn: Callable[[Batch], Batch], mesh: Mesh, in_spec: LogicalAxisSpec | PartitionSpec, out_spec: LogicalAxisSpec | PartitionSpec | None = None) -> Batch

Apply a transformation to the batch in parallel across devices.

This is particularly useful for operations that can be performed independently on each shard of the data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Batch | The batch to transform. | required |
| transform_fn | Callable[[Batch], Batch] | The function to apply to the batch. | required |
| mesh | Mesh | The device mesh to use. | required |
| in_spec | LogicalAxisSpec \| PartitionSpec | The input sharding specification. | required |
| out_spec | LogicalAxisSpec \| PartitionSpec \| None | Optional output sharding specification. | None |

Returns:

| Type | Description |
| --- | --- |
| Batch | The transformed batch. |

create_sharded_param

create_sharded_param(init_fn: Callable, shape: tuple[int, ...], logical_spec: LogicalAxisSpec | PartitionSpec) -> Param

Create a parameter with explicit sharding annotation.

This utility method makes it easier to create model parameters with appropriate sharding annotations for distributed training.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| init_fn | Callable | The initialization function for the parameter. | required |
| shape | tuple[int, ...] | The shape of the parameter. | required |
| logical_spec | LogicalAxisSpec \| PartitionSpec | The logical sharding specification. | required |

Returns:

| Type | Description |
| --- | --- |
| Param | An initialized parameter with sharding annotation. |

get_operation_stats

get_operation_stats() -> dict[str, int]

Get operation statistics.

Note: This method converts JAX arrays to Python ints for introspection. It is intended for use outside of JIT-compiled functions.

Returns:

| Type | Description |
| --- | --- |
| dict[str, int] | Dictionary with 'applied_count' and 'skipped_count'. |

reset_operation_stats

reset_operation_stats() -> None

Reset operation statistics to zero.

Note: Creates new JAX arrays to reset the counters.

compute_statistics

compute_statistics(data: Any) -> dict[str, Any] | None

Compute statistics from data using batch_stats_fn.

If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data | Any | Input data to compute statistics from. | required |

Returns:

| Type | Description |
| --- | --- |
| dict[str, Any] \| None | Dictionary of statistics, or None if no batch_stats_fn is configured. |

get_statistics

get_statistics() -> dict[str, Any] | None

Get current statistics.

Returns precomputed_stats if configured (unless reset was called); otherwise returns the cached computed statistics, or None if no statistics are available.

Returns:

| Type | Description |
| --- | --- |
| dict[str, Any] \| None | Dictionary of statistics, or None if no statistics are available. |

set_statistics

set_statistics(stats: dict[str, Any]) -> None

Manually set statistics.

This overwrites any previously computed statistics and clears the reset flag.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| stats | dict[str, Any] | Dictionary of statistics to set. | required |

reset_statistics

reset_statistics() -> None

Reset all statistics to None.

This clears both computed statistics and marks that precomputed_stats should be ignored (via internal flag). After reset, get_statistics() will return None until new statistics are set or computed.
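The interplay between compute_statistics, get_statistics, set_statistics, and reset_statistics can be pictured with a small stand-alone sketch. This is a literal reading of the descriptions above, not the Datarax implementation: the class name StatsHolder and the flag name _stats_reset are hypothetical, and the real module may resolve edge cases differently (for example, whether precomputed statistics take precedence again after set_statistics clears the flag).

```python
from typing import Any, Callable, Optional


class StatsHolder:
    """Hypothetical stand-in that mirrors the documented statistics behaviour."""

    def __init__(
        self,
        batch_stats_fn: Optional[Callable[[Any], dict]] = None,
        precomputed_stats: Optional[dict] = None,
    ) -> None:
        self.batch_stats_fn = batch_stats_fn
        self.precomputed_stats = precomputed_stats
        self._computed_stats: Optional[dict] = None
        self._stats_reset = False  # hypothetical name for the internal reset flag

    def compute_statistics(self, data: Any) -> Optional[dict]:
        # Without a configured batch_stats_fn there is nothing to compute.
        if self.batch_stats_fn is None:
            return None
        self._computed_stats = self.batch_stats_fn(data)  # cache the result
        return self._computed_stats

    def get_statistics(self) -> Optional[dict]:
        # Precomputed statistics win unless a reset has been requested.
        if self.precomputed_stats is not None and not self._stats_reset:
            return self.precomputed_stats
        return self._computed_stats

    def set_statistics(self, stats: dict) -> None:
        self._computed_stats = stats
        self._stats_reset = False

    def reset_statistics(self) -> None:
        self._computed_stats = None
        self._stats_reset = True
```

In this reading, a holder built with precomputed_stats returns them until reset_statistics is called, after which get_statistics returns None until statistics are computed or set again.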

reset_cache

reset_cache() -> None

Clear the cache.

Only has effect if cacheable=True in config.

copy

copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule

Create a copy of this module with optional config/parameter changes.

This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | DataraxModuleConfig \| None | New config (if None, uses current config). | None |
| rngs | Rngs \| None | New RNG state (if None, uses current rngs). | None |
| name | str \| None | New name (if None, uses current name). | None |

Returns:

| Type | Description |
| --- | --- |
| DataraxModule | New module instance with updated parameters. |

Examples:

    # Change configuration
    new_config = DataraxModuleConfig(cacheable=True)
    new_module = module.copy(config=new_config)

    # Change name only
    renamed = module.copy(name="new_name")

Note

Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().
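For readers unfamiliar with the dataclass replace() pattern mentioned in the note, here is a minimal sketch using only the standard library. ExampleConfig is a hypothetical stand-in; only the cacheable field appears in the documentation above, and the label field and the frozen setting are assumptions made for illustration.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ExampleConfig:
    # Hypothetical config; the real DataraxModuleConfig may have different fields.
    cacheable: bool = False
    label: str = "default"


base = ExampleConfig()

# replace() builds a new config with one field changed and the rest carried over.
tuned = replace(base, cacheable=True)
```

A config produced this way could then be handed to copy(config=...), leaving the original module and its config untouched.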

get_state

get_state() -> dict[str, Any]

Get the current state of the SharderModule for checkpointing.

Returns:

| Type | Description |
| --- | --- |
| dict[str, Any] | A dictionary containing the internal state of the SharderModule. |

set_state

set_state(state: dict[str, Any]) -> None

Restore internal state from a checkpoint.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | dict[str, Any] | A dictionary containing the internal state to restore. | required |

clone

clone() -> DataraxModule

Create a new instance with the same state as this module.

Uses NNX's clone function for proper deep cloning of all state.

Returns:

| Type | Description |
| --- | --- |
| DataraxModule | A new module instance with the same state. |

requires_rng_streams

requires_rng_streams() -> list[str] | None

Get the list of RNG streams required by this module.

Returns:

| Type | Description |
| --- | --- |
| list[str] \| None | A list of required RNG stream names, or None if no RNG streams are required. |

ensure_rng_streams

ensure_rng_streams(stream_names: list[str]) -> None

Ensure that the required RNG streams are available.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| stream_names | list[str] | A list of available RNG stream names. | required |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If a required RNG stream is not available. |
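The check described here amounts to comparing the required stream names against the available ones. The helper below is a hypothetical, stand-alone sketch of that comparison; its name, its two-argument signature, and the error message are assumptions and do not reflect how Datarax implements the method.

```python
from typing import Optional, Sequence


def check_rng_streams(
    required: Optional[Sequence[str]], available: Sequence[str]
) -> None:
    """Raise ValueError if any required stream is absent from `available`."""
    if not required:
        return  # None or an empty list means no streams are needed
    missing = [name for name in required if name not in available]
    if missing:
        raise ValueError(f"Missing required RNG streams: {missing}")
```

A module that needs no streams passes trivially, while one that requires "dropout" fails as soon as that name is missing from the available list.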

get_partition_spec

get_partition_spec(logical_spec: LogicalAxisSpec | PartitionSpec) -> PartitionSpec

Convert a logical axis specification to a physical PartitionSpec.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| logical_spec | LogicalAxisSpec \| PartitionSpec | A tuple of logical axis names or a PartitionSpec. | required |

Returns:

| Type | Description |
| --- | --- |
| PartitionSpec | A PartitionSpec using physical device mesh axis names. |
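One plausible way to picture the logical-to-physical conversion is a dictionary lookup over the logical axis names. The function below is a sketch under that assumption: to_partition_spec, the rules dictionary, and the choice to leave unmapped names unsharded (None) are illustrative and may not match what Datarax actually does.

```python
from jax.sharding import PartitionSpec


def to_partition_spec(logical_spec, sharding_rules):
    """Map logical axis names to mesh axis names (hypothetical helper)."""
    # A PartitionSpec is assumed to be physical already and is passed through.
    if isinstance(logical_spec, PartitionSpec):
        return logical_spec
    # Names without a rule (and explicit None entries) become None, i.e. unsharded.
    return PartitionSpec(*(sharding_rules.get(name) for name in logical_spec))


# Illustrative rules: logical name -> physical mesh axis name.
rules = {"batch": "data", "embed": "model"}
spec = to_partition_spec(("batch", None, "embed"), rules)
```

Here the logical tuple ("batch", None, "embed") resolves to a PartitionSpec over the mesh axes "data" and "model", with the middle dimension left unsharded.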

get_named_sharding

get_named_sharding(mesh: Mesh, logical_spec: LogicalAxisSpec | PartitionSpec) -> NamedSharding

Create a NamedSharding from a logical axis specification and mesh.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| mesh | Mesh | The device mesh to use for sharding. | required |
| logical_spec | LogicalAxisSpec \| PartitionSpec | A tuple of logical axis names or a PartitionSpec. | required |

Returns:

| Type | Description |
| --- | --- |
| NamedSharding | A NamedSharding using the provided mesh and converted partition spec. |
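To make the result concrete, the following sketch builds a NamedSharding directly with JAX and asks it how large each device's slice would be. It does not call get_named_sharding; the mesh axis name "data" and the array shape are assumptions chosen for the example.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over all visible devices; the axis name is illustrative.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
n = len(devices)

# Pair the mesh with a physical partition spec: shard axis 0, replicate axis 1.
named = NamedSharding(mesh, PartitionSpec("data", None))

# shard_shape reports the per-device block shape for a given global shape.
per_device_shape = named.shard_shape((4 * n, 8))
```

With a global leading dimension of 4 per device, each device's block has shape (4, 8) regardless of how many devices are present.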

parallel_transform

parallel_transform(batch: Batch, transform_fn: Callable[[Batch], Batch], mesh: Mesh, in_spec: LogicalAxisSpec | PartitionSpec, out_spec: LogicalAxisSpec | PartitionSpec | None = None) -> Batch

Apply a transformation to each shard of the batch in parallel.

This uses nnx.shard_map to efficiently process data in parallel across multiple devices.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Batch | The batch of data to transform. | required |
| transform_fn | Callable[[Batch], Batch] | A function that takes a batch as input and returns a transformed batch. | required |
| mesh | Mesh | The device mesh to use for sharding. | required |
| in_spec | LogicalAxisSpec \| PartitionSpec | The input sharding specification. | required |
| out_spec | LogicalAxisSpec \| PartitionSpec \| None | Optional output sharding specification. If not provided, the input spec is used. | None |

Returns:

| Type | Description |
| --- | --- |
| Batch | The transformed batch with the specified sharding. |
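The per-shard execution model can be sketched with JAX's own shard_map, used here in place of nnx.shard_map so the example needs only JAX. This is an illustration of the general technique, not of parallel_transform itself; the mesh axis name "data", the function scale_and_shift, and the array shape are assumptions, and the same spec is reused for input and output to mirror the documented default.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec

try:  # newer JAX releases export shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases keep it under jax.experimental
    from jax.experimental.shard_map import shard_map

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
n = len(devices)


def scale_and_shift(x):
    # Runs once per shard; x is the local block, not the global array.
    return x * 2.0 + 1.0


spec = PartitionSpec("data")  # shard the leading axis for input and output
parallel_fn = shard_map(scale_and_shift, mesh=mesh, in_specs=(spec,), out_specs=spec)

x = jnp.arange(4 * n * 3, dtype=jnp.float32).reshape(4 * n, 3)
y = parallel_fn(x)
```

Because the transform is elementwise, the result matches applying it to the whole array at once; transforms that reduce over the batch axis would instead see only their local shard.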