Data Parallel
Utilities for data-parallel training across devices.
See Also
- Distributed Overview - All distributed tools
- Device Mesh - Mesh configuration
- Metrics - Distributed metrics
- Distributed Training Guide
- Sharding Quick Reference
datarax.distributed.data_parallel
Data parallelism utilities for Datarax.
This module provides functions for data-parallel training of JAX models, centered on the current SPMD APIs (nnx.jit with device meshes).
create_data_parallel_sharding
Create a Sharding object for data-parallel training.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mesh | Mesh | The device mesh to use for sharding. | required |
| data_axis | str | The name of the mesh axis to use for data parallelism. | 'data' |
Returns:
| Type | Description |
|---|---|
| Sharding | A JAX Sharding object for data-parallel training. |
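This page does not show how create_data_parallel_sharding builds its result. As a hedged sketch, assuming it returns a NamedSharding that splits the leading (batch) dimension along data_axis and replicates every other dimension, the equivalent pure-JAX construction and its effect on a batch look like this. The helper make_data_parallel_sharding is hypothetical and is not part of datarax.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def make_data_parallel_sharding(mesh: Mesh, data_axis: str = "data") -> NamedSharding:
    """Hypothetical stand-in: shard dimension 0 along `data_axis`, replicate the rest."""
    return NamedSharding(mesh, PartitionSpec(data_axis))


# A 1-D mesh with a single "data" axis spanning every visible device.
n_devices = jax.device_count()
mesh = Mesh(np.array(jax.devices()), ("data",))
sharding = make_data_parallel_sharding(mesh)

# The global batch size must be divisible by the size of the data axis.
batch = jnp.arange(8 * n_devices * 3, dtype=jnp.float32).reshape(8 * n_devices, 3)
sharded_batch = jax.device_put(batch, sharding)

# Each device now holds a contiguous block of 8 rows of the global batch.
shard_shapes = [shard.data.shape for shard in sharded_batch.addressable_shards]
```

Only the placement changes: converting sharded_batch back with np.asarray yields the original global array.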
place_batch_on_shards
spmd_train_step
spmd_train_step(model: Module, optimizer: Optimizer, loss_fn: Callable[[Module, Batch], Array], batch: Batch) -> Array
Execute a data-parallel training step using SPMD.
Uses nnx.value_and_grad for automatic differentiation. When the model parameters and the batch are placed on a device mesh (via jax.set_mesh or explicit NamedSharding), the XLA compiler inserts the gradient AllReduce automatically.
Call this function inside an @nnx.jit-decorated function while a mesh context is active.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | The NNX model to train. | required |
| optimizer | Optimizer | The NNX optimizer wrapping the model. | required |
| loss_fn | Callable[[Module, Batch], Array] | Function (model, batch) -> loss scalar. | required |
| batch | Batch | The training batch (should be pre-sharded). | required |
Returns:
| Type | Description |
|---|---|
| Array | The loss value. |
Example
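```python
import jax
from flax import nnx

from datarax.distributed.data_parallel import spmd_train_step

# model, optimizer, my_loss_fn and batch are assumed to be defined already.
mesh = jax.make_mesh((4,), ("data",))
rules = data_parallel_rules()  # provided elsewhere in datarax (import not shown here)

@nnx.jit
def train_step(model, optimizer, batch):
    return spmd_train_step(model, optimizer, my_loss_fn, batch)

with jax.set_mesh(mesh):
    loss = train_step(model, optimizer, batch)
```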
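The automatic gradient AllReduce can be observed without NNX. The following pure-JAX sketch is an analogue, not spmd_train_step itself: the linear model, loss_fn, LEARNING_RATE and the plain SGD update are illustrative assumptions. It replicates the parameters, shards the batch along "data", and checks that the jitted gradient matches the one computed on unsharded arrays.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

LEARNING_RATE = 0.1  # illustrative value

# Data parallelism: parameters replicated, batch split along the "data" axis.
mesh = Mesh(np.array(jax.devices()), ("data",))
replicated = NamedSharding(mesh, P())
batch_sharded = NamedSharding(mesh, P("data"))


def loss_fn(params, x, y):
    # Mean-squared error of a linear model, averaged over the *global* batch.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # No explicit collective here: XLA inserts the cross-device reduction.
    new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss, grads


n = jax.device_count()
x = jax.random.normal(jax.random.PRNGKey(0), (4 * n, 3))
y = jnp.sum(x, axis=1)
params = {"w": jnp.zeros(3), "b": jnp.zeros(())}

new_params, loss, grads = train_step(
    jax.device_put(params, replicated),
    jax.device_put(x, batch_sharded),
    jax.device_put(y, batch_sharded),
)
# Reference: the same loss and gradient computed without any sharding.
ref_loss, ref_grads = jax.value_and_grad(loss_fn)(params, x, y)
```

Because the loss is a mean over the global batch, each device's partial gradient is combined by the compiler, so grads equals ref_grads up to floating-point rounding.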
data_parallel_train_step
data_parallel_train_step(loss_fn: Callable, optimizer: Any, batch: Batch, state: Any, rngs: Any | None = None) -> tuple[Any, dict[str, Any]]
Execute a data-parallel training step using pmap.
Uses jax.pmap with lax.pmean to average gradients across devices. For SPMD training with nnx.jit, prefer spmd_train_step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| loss_fn | Callable | Function that computes the loss value. | required |
| optimizer | Any | The optimizer to use. | required |
| batch | Batch | The batch to process. | required |
| state | Any | The current training state. | required |
| rngs | Any \| None | Optional random number generator keys. | None |
Returns:
| Type | Description |
|---|---|
| tuple[Any, dict[str, Any]] | A tuple of (new_state, metrics). |
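The pmap-plus-pmean pattern this function is described as using can be sketched in plain JAX. This is a minimal illustration, not the datarax implementation: the linear model, loss_fn, LEARNING_RATE and the axis name "batch" (chosen to match the default axis_name of reduce_gradients_across_devices) are assumptions.

```python
import jax
import jax.numpy as jnp
from jax import lax

LEARNING_RATE = 0.1  # illustrative value
AXIS = "batch"  # assumed pmap axis name


def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def step(params, x, y):
    # Runs once per device on that device's slice of the batch.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Average the local gradients and losses across all devices on the pmap axis.
    grads = lax.pmean(grads, axis_name=AXIS)
    loss = lax.pmean(loss, axis_name=AXIS)
    new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss


p_step = jax.pmap(step, axis_name=AXIS)

n = jax.local_device_count()
# pmap inputs carry a leading device axis: [devices, per_device_batch, features].
x = jax.random.normal(jax.random.PRNGKey(0), (n, 4, 3))
y = jnp.sum(x, axis=-1)
params = {"w": jnp.zeros(3), "b": jnp.zeros(())}
# Replicate the parameters by stacking one copy per device.
params_per_device = jax.tree_util.tree_map(lambda a: jnp.stack([a] * n), params)

new_params, loss = p_step(params_per_device, x, y)

# Reference: the same step on the flattened global batch, without pmap.
ref_loss, ref_grads = jax.value_and_grad(loss_fn)(params, x.reshape(-1, 3), y.reshape(-1))
```

Since every device sees the same number of examples, the mean of the per-device mean gradients equals the global-batch gradient, and every replica ends up with identical parameters.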
place_model_state_on_shards
place_model_state_on_shards(state: Any, mesh: Mesh, param_sharding: str | dict[str, PartitionSpec] | None = None) -> Any
Shard a model's state across devices.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | Any | The model state to shard. | required |
| mesh | Mesh | The device mesh to shard across. | required |
| param_sharding | str \| dict[str, PartitionSpec] \| None | Optional parameter sharding specifications. Use "replicate" to replicate all parameters, or a dict mapping parameter paths to PartitionSpec for per-parameter sharding. | None |
Returns:
| Type | Description |
|---|---|
| Any | The sharded model state. |
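The two documented forms of param_sharding can be illustrated with a hedged sketch. The helper shard_state is hypothetical, and several details are assumptions not stated on this page: the "/"-joined path format for dict keys, full replication when param_sharding is None or a path is not listed, and the mesh axis name "model".

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def shard_state(state, mesh, param_sharding=None):
    """Hypothetical sketch of "replicate" vs. per-path PartitionSpec placement."""
    if isinstance(param_sharding, str) and param_sharding != "replicate":
        raise ValueError(f"unsupported param_sharding: {param_sharding!r}")

    def place(path, leaf):
        spec = P()  # assumed default: full replication
        if isinstance(param_sharding, dict):
            # Assumed path format: dict keys joined with "/", e.g. "dense/kernel".
            name = "/".join(str(getattr(k, "key", k)) for k in path)
            spec = param_sharding.get(name, P())
        return jax.device_put(leaf, NamedSharding(mesh, spec))

    return jax.tree_util.tree_map_with_path(place, state)


n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), ("model",))
state = {"dense": {"kernel": jnp.ones((4, 2 * n)), "bias": jnp.zeros(2 * n)}}

replicated = shard_state(state, mesh, "replicate")
# Split the kernel's output columns across devices; the bias stays replicated.
per_param = shard_state(state, mesh, {"dense/kernel": P(None, "model")})
```

Leaves not named in the dict fall back to replication in this sketch; the real function may handle unlisted paths differently.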
place_nnx_state_on_shards
place_nnx_state_on_shards(state: State, mesh: Mesh, filter_sharding: StateSharding | dict[Any, PartitionSpec | Sharding]) -> State
Shard a Flax NNX state tree using current NNX sharding helpers.
reduce_gradients_across_devices
reduce_gradients_across_devices(gradients: Any, reduce_type: str = 'mean', axis_name: str = 'batch') -> Any
All-reduce gradients across devices using collective operations.
Only valid inside a pmap or shard_map context. For SPMD training with nnx.jit, gradient reduction is handled automatically by the compiler.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| gradients | Any | The gradients to reduce. | required |
| reduce_type | str | The type of reduction ("mean" or "sum"). | 'mean' |
| axis_name | str | The axis name for the collective operation. | 'batch' |
Returns:
| Type | Description |
|---|---|
| Any | The reduced gradients. |
Raises:
| Type | Description |
|---|---|
| ValueError | If reduce_type is not "mean" or "sum". |
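The documented contract (mean or sum over a named axis, ValueError otherwise, valid only where the axis is bound) maps onto lax.pmean and lax.psum. The sketch below is a hypothetical reduce_tree, not the library source, exercised under jax.pmap with the default axis name "batch".

```python
import jax
import jax.numpy as jnp
from jax import lax


def reduce_tree(gradients, reduce_type="mean", axis_name="batch"):
    """Hypothetical sketch of the documented contract (not the library source)."""
    if reduce_type == "mean":
        collective = lax.pmean
    elif reduce_type == "sum":
        collective = lax.psum
    else:
        raise ValueError(f"reduce_type must be 'mean' or 'sum', got {reduce_type!r}")
    # The collective needs a bound axis name, so this only works under pmap/shard_map.
    return collective(gradients, axis_name)


n = jax.local_device_count()
# One row of "gradients" per device; pmap maps over the leading axis.
grads = {"w": jnp.arange(2 * n, dtype=jnp.float32).reshape(n, 2)}

mean_grads = jax.pmap(lambda g: reduce_tree(g, "mean"), axis_name="batch")(grads)
sum_grads = jax.pmap(lambda g: reduce_tree(g, "sum"), axis_name="batch")(grads)
```

After the collective, every device holds the same reduced value, so each row of mean_grads["w"] equals the column-wise mean of grads["w"].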
reduce_gradient_tree
Reduce gradients using standard JAX operations on global arrays.
Works in SPMD contexts (inside nnx.jit with mesh). The XLA compiler handles cross-device communication automatically.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| gradients | Any | The gradients to reduce (global sharded arrays). | required |
| reduce_type | str | The type of reduction ("mean" or "sum"). | 'mean' |
Returns:
| Type | Description |
|---|---|
| Any | The reduced gradients. |
Raises:
| Type | Description |
|---|---|
| ValueError | If reduce_type is not "mean" or "sum". |
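The page does not say which axis reduce_gradient_tree reduces over. As a hedged sketch, assuming each leaf carries a leading batch axis sharded across a "data" mesh axis, plain jnp.mean or jnp.sum under jax.jit is enough: no collective is written by hand, and the compiler handles the cross-device communication. The helper reduce_tree_global is hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def reduce_tree_global(gradients, reduce_type="mean"):
    """Hypothetical sketch: reduce every leaf over its leading (batch) axis."""
    if reduce_type == "mean":
        op = jnp.mean
    elif reduce_type == "sum":
        op = jnp.sum
    else:
        raise ValueError(f"reduce_type must be 'mean' or 'sum', got {reduce_type!r}")
    return jax.tree_util.tree_map(lambda g: op(g, axis=0), gradients)


n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), ("data",))
# Per-example gradients whose leading axis is split across devices.
per_example = {"w": jnp.arange(4 * n * 3, dtype=jnp.float32).reshape(4 * n, 3)}
per_example = jax.device_put(per_example, NamedSharding(mesh, P("data")))

# Ordinary jnp reductions under jit; the compiler inserts cross-device communication.
reduce_jit = jax.jit(reduce_tree_global, static_argnames="reduce_type")
mean_grads = reduce_jit(per_example)
sum_grads = reduce_jit(per_example, reduce_type="sum")
```

Unlike the pmap-based sketch for reduce_gradients_across_devices, this version needs no bound axis name, which is why it also works on global arrays inside nnx.jit.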