Array Sharder
Shard arrays across devices for distributed processing.
See Also
- Sharding Overview - All sharding tools
- JAX Process Sharder - Multi-host sharding
- Distributed - Distributed training
- Sharding Quick Reference
datarax.sharding.array_sharder
Array sharder implementation for Datarax.
This module provides a unified implementation of the SharderModule that handles sharding of JAX arrays across devices. Supports both static method usage and NNX module instantiation.
ArraySharder
ArraySharder(config: SharderModuleConfig | None = None, *, rngs: Rngs | None = None)
Bases: SharderModule
Unified array sharding implementation for Datarax.
This class provides methods for sharding JAX arrays across devices, with support for both static method usage and NNX module instantiation. Supports logical axis naming and advanced sharding operations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `SharderModuleConfig \| None` | Optional configuration with sharding_rules mapping from logical axis names to physical device mesh axis names. | `None` |
| `rngs` | `Rngs \| None` | Optional Rngs for random operations (typically not needed for sharders). | `None` |
shard
Distribute a batch of data across JAX devices.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Batch` | A batch of data elements (PyTree). | required |
| `sharding` | `Sharding` | A JAX Sharding object specifying how to distribute arrays. | required |

Returns:

| Type | Description |
|---|---|
| `Batch` | A sharded batch (PyTree of jax.Array objects with the specified sharding). |
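As a rough illustration of what distributing a PyTree batch across devices looks like in plain JAX, the sketch below uses `jax.device_put` with a `NamedSharding`. It is an assumption that `shard` behaves similarly; the mesh axis name `"data"` and the batch keys `"image"` and `"label"` are hypothetical and chosen only for the example.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over all available devices; "data" is a hypothetical axis name.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Split the leading (batch) dimension across the "data" axis.
sharding = NamedSharding(mesh, PartitionSpec("data"))

# The batch size must be divisible by the number of devices.
n = len(devices)
batch = {"image": jnp.ones((4 * n, 8)), "label": jnp.zeros((4 * n,))}

# device_put applies the same sharding to every leaf of the PyTree.
sharded = jax.device_put(batch, sharding)

# Each device now holds 4 rows of every array.
first_shard_shape = sharded["image"].addressable_shards[0].data.shape
```

On a single-device machine this still runs: the mesh has one device and each array has exactly one shard.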
shard_static `staticmethod`
Static method to distribute a batch of data across JAX devices.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Batch` | A batch of data elements (PyTree). | required |
| `sharding` | `Sharding` | A JAX Sharding object specifying how to distribute arrays. | required |

Returns:

| Type | Description |
|---|---|
| `Batch` | A sharded batch (PyTree of jax.Array objects with the specified sharding). |
shard_with_info
Return sharded batch with debug information.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Batch` | A batch of data elements (PyTree). | required |
| `sharding` | `Sharding` | A JAX Sharding object specifying how to distribute arrays. | required |
| `info` | `dict[str, Any] \| None` | Optional dictionary to update with sharding info. | `None` |

Returns:

| Type | Description |
|---|---|
| `Batch` | A sharded batch (PyTree of jax.Array objects with the specified sharding). |
shard_with_info_static `staticmethod`
shard_with_info_static(batch: Batch, sharding: Sharding, info: dict[str, Any] | None = None) -> Batch
Static method to return sharded batch with debug information.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Batch` | A batch of data elements (PyTree). | required |
| `sharding` | `Sharding` | A JAX Sharding object specifying how to distribute arrays. | required |
| `info` | `dict[str, Any] \| None` | Optional dictionary to update with sharding info. | `None` |

Returns:

| Type | Description |
|---|---|
| `Batch` | A sharded batch (PyTree of jax.Array objects with the specified sharding). |
shard_with_logical_names
shard_with_logical_names(batch: Batch, mesh: Mesh, logical_spec: LogicalAxisSpec | PartitionSpec) -> Batch
Shard a batch using logical axis names.
This method allows using more descriptive logical axis names instead of device mesh axis names directly.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Batch` | A batch of data elements (PyTree). | required |
| `mesh` | `Mesh` | The device mesh to use for sharding. | required |
| `logical_spec` | `LogicalAxisSpec \| PartitionSpec` | A tuple of logical axis names or a PartitionSpec. | required |

Returns:

| Type | Description |
|---|---|
| `Batch` | A sharded batch (PyTree of jax.Array objects). |
apply_parallel_transform
apply_parallel_transform(batch: Batch, transform_fn: Callable[[Batch], Batch], mesh: Mesh, in_spec: LogicalAxisSpec | PartitionSpec, out_spec: LogicalAxisSpec | PartitionSpec | None = None) -> Batch
Apply a transformation to the batch in parallel across devices.
This is particularly useful for operations that can be performed independently on each shard of the data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Batch` | The batch to transform. | required |
| `transform_fn` | `Callable[[Batch], Batch]` | The function to apply to the batch. | required |
| `mesh` | `Mesh` | The device mesh to use. | required |
| `in_spec` | `LogicalAxisSpec \| PartitionSpec` | The input sharding specification. | required |
| `out_spec` | `LogicalAxisSpec \| PartitionSpec \| None` | Optional output sharding specification. | `None` |

Returns:

| Type | Description |
|---|---|
| `Batch` | The transformed batch. |
create_sharded_param
create_sharded_param(init_fn: Callable, shape: tuple[int, ...], logical_spec: LogicalAxisSpec | PartitionSpec) -> Param
Create a parameter with explicit sharding annotation.
This utility method makes it easier to create model parameters with appropriate sharding annotations for distributed training.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `init_fn` | `Callable` | The initialization function for the parameter. | required |
| `shape` | `tuple[int, ...]` | The shape of the parameter. | required |
| `logical_spec` | `LogicalAxisSpec \| PartitionSpec` | The logical sharding specification. | required |

Returns:

| Type | Description |
|---|---|
| `Param` | An initialized parameter with sharding annotation. |
get_operation_stats
reset_operation_stats
Reset operation statistics to zero.
Note: Creates new JAX arrays to reset the counters.
compute_statistics
Compute statistics from data using batch_stats_fn.
If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `Any` | Input data to compute statistics from. | required |

Returns:

| Type | Description |
|---|---|
| `dict[str, Any] \| None` | Dictionary of statistics, or None if no batch_stats_fn is configured. |
get_statistics
set_statistics
reset_statistics
Reset all statistics to None.
This clears the computed statistics and sets an internal flag so that precomputed_stats is ignored. After a reset, get_statistics() returns None until new statistics are set or computed.
copy
copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule
Create a copy of this module with optional config/parameter changes.
This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `DataraxModuleConfig \| None` | New config (if None, uses current config). | `None` |
| `rngs` | `Rngs \| None` | New RNG state (if None, uses current rngs). | `None` |
| `name` | `str \| None` | New name (if None, uses current name). | `None` |

Returns:

| Type | Description |
|---|---|
| `DataraxModule` | New module instance with updated parameters. |
Examples:
    # Change configuration
    new_config = DataraxModuleConfig(cacheable=True)
    new_module = module.copy(config=new_config)

    # Change name only
    renamed = module.copy(name="new_name")
Note
Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().
get_state
set_state
clone
clone() -> DataraxModule
Create a new instance with the same state as this module.
Uses NNX's clone function for proper deep cloning of all state.
Returns:
| Type | Description |
|---|---|
| `DataraxModule` | A new module instance with the same state. |
requires_rng_streams
ensure_rng_streams
Ensure that the required RNG streams are available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `stream_names` | `list[str]` | A list of available RNG stream names. | required |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If a required RNG stream is not available. |
get_partition_spec
get_partition_spec(logical_spec: LogicalAxisSpec | PartitionSpec) -> PartitionSpec
Convert a logical axis specification to a physical PartitionSpec.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `logical_spec` | `LogicalAxisSpec \| PartitionSpec` | A tuple of logical axis names or a PartitionSpec. | required |

Returns:

| Type | Description |
|---|---|
| `PartitionSpec` | A PartitionSpec using physical device mesh axis names. |
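The logical-to-physical conversion described above can be pictured as a dictionary lookup over the axis names. The sketch below is a minimal stand-in, not the Datarax implementation: the helper `to_partition_spec`, the rule names, and the choice to treat unmapped axes as replicated (`None`) are all assumptions made for illustration.

```python
from jax.sharding import PartitionSpec

# Hypothetical rules: logical axis name -> physical mesh axis name.
sharding_rules = {"batch": "data", "embed": "model"}


def to_partition_spec(logical_spec, rules):
    """Map a tuple of logical axis names to a PartitionSpec (illustrative helper)."""
    # A PartitionSpec is assumed to already use physical names, so pass it through.
    if isinstance(logical_spec, PartitionSpec):
        return logical_spec
    # Axes that are None or missing from the rules are left unsharded (an assumption).
    return PartitionSpec(*(rules.get(name) for name in logical_spec))


spec = to_partition_spec(("batch", None, "embed"), sharding_rules)
passthrough = to_partition_spec(PartitionSpec("data"), sharding_rules)
```

Keeping the rules in one dictionary means the same logical spec can be reused on meshes with different axis names by swapping only the rules.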
get_named_sharding
get_named_sharding(mesh: Mesh, logical_spec: LogicalAxisSpec | PartitionSpec) -> NamedSharding
Create a NamedSharding from a logical axis specification and mesh.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `mesh` | `Mesh` | The device mesh to use for sharding. | required |
| `logical_spec` | `LogicalAxisSpec \| PartitionSpec` | A tuple of logical axis names or a PartitionSpec. | required |

Returns:

| Type | Description |
|---|---|
| `NamedSharding` | A NamedSharding using the provided mesh and converted partition spec. |
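For orientation, this is roughly what pairing a mesh with a converted spec looks like in plain JAX. The mesh axis names `"data"` and `"model"`, the `rules` dictionary, and the logical names are hypothetical; the sketch builds the `NamedSharding` by hand rather than calling `get_named_sharding`.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Two-dimensional mesh: all devices along "data", a single slot along "model".
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))

# Hypothetical logical-to-physical rules and a logical spec for a 2D array.
rules = {"batch": "data", "embed": "model"}
logical_spec = ("batch", "embed")

# Convert logical names to mesh axis names, then bind the spec to the mesh.
spec = PartitionSpec(*(rules.get(axis) for axis in logical_spec))
named = NamedSharding(mesh, spec)

# shard_shape reports the per-device block size for a given global shape.
n = devices.shape[0]
per_device = named.shard_shape((8 * n, 16))
```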
parallel_transform
parallel_transform(batch: Batch, transform_fn: Callable[[Batch], Batch], mesh: Mesh, in_spec: LogicalAxisSpec | PartitionSpec, out_spec: LogicalAxisSpec | PartitionSpec | None = None) -> Batch
Apply a transformation to each shard of the batch in parallel.
This uses nnx.shard_map to efficiently process data in parallel across multiple devices.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Batch` | The batch of data to transform. | required |
| `transform_fn` | `Callable[[Batch], Batch]` | A function that takes a batch as input and returns a transformed batch. | required |
| `mesh` | `Mesh` | The device mesh to use for sharding. | required |
| `in_spec` | `LogicalAxisSpec \| PartitionSpec` | The input sharding specification. | required |
| `out_spec` | `LogicalAxisSpec \| PartitionSpec \| None` | Optional output sharding specification. If not provided, the input spec is used. | `None` |

Returns:

| Type | Description |
|---|---|
| `Batch` | The transformed batch with the specified sharding. |
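To make "apply a transformation to each shard" concrete, here is a sketch that uses JAX's own `shard_map` as a stand-in for the `nnx.shard_map` call mentioned above. The function `center_per_shard`, the `"features"` key, and the `"data"` axis name are hypothetical. The import falls back to `jax.experimental.shard_map` on JAX versions that do not expose `jax.shard_map`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX releases

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
n = len(devices)

# Four rows per device; values 0, 1, 2, ... along the batch dimension.
batch = {"features": jnp.arange(4 * n, dtype=jnp.float32).reshape(4 * n, 1)}


def center_per_shard(b):
    # Runs once per shard and sees only that shard's rows,
    # so the mean is local to each device, not global.
    x = b["features"]
    return {"features": x - x.mean(axis=0, keepdims=True)}


# in_specs and out_specs are PyTree prefixes: one spec covers every leaf.
parallel_fn = shard_map(
    center_per_shard,
    mesh=mesh,
    in_specs=(P("data"),),
    out_specs=P("data"),
)
out = parallel_fn(batch)
```

Because the function only sees local rows, this pattern suits transforms that are independent per shard; anything needing a global statistic would require an explicit collective inside the function.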