
Data Parallel¤

Utilities for data-parallel training across devices.

datarax.distributed.data_parallel ¤

Data parallelism utilities for Datarax.

This module provides functions for data-parallel training of JAX models, centered on the current SPMD APIs: nnx.jit and device meshes.

logger module-attribute ¤

logger = getLogger(__name__)

create_data_parallel_sharding ¤

create_data_parallel_sharding(mesh: Mesh, data_axis: str = 'data') -> Sharding

Create a Sharding object for data-parallel training.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| mesh | Mesh | The device mesh to use for sharding. | required |
| data_axis | str | The name of the mesh axis to use for data parallelism. | 'data' |

Returns:

| Type | Description |
| --- | --- |
| Sharding | A JAX Sharding object for data-parallel training. |
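
To make the idea concrete, here is a minimal sketch of what a data-parallel sharding can look like when built from public JAX APIs only. The helper name `make_data_sharding` is hypothetical and is not Datarax's implementation; it assumes that "data-parallel" means splitting the leading (batch) dimension over the named mesh axis and replicating all other dimensions.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def make_data_sharding(mesh: Mesh, data_axis: str = "data") -> NamedSharding:
    """Hypothetical stand-in: shard dim 0 over `data_axis`, replicate the rest."""
    if data_axis not in mesh.axis_names:
        raise ValueError(f"mesh has no axis named {data_axis!r}")
    # PartitionSpec(data_axis) only constrains the first array dimension.
    return NamedSharding(mesh, PartitionSpec(data_axis))


# One-dimensional mesh over every visible device (works with a single CPU too).
mesh = Mesh(np.array(jax.devices()), ("data",))
sharding = make_data_sharding(mesh)
```

Checking the axis name up front is a design choice of this sketch; it surfaces a misnamed axis at construction time rather than when the first array is placed.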

place_batch_on_shards ¤

place_batch_on_shards(batch: Batch, sharding: Sharding) -> Batch

Shard a batch of data across devices.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Batch | The batch to shard. | required |
| sharding | Sharding | The sharding specification to use. | required |

Returns:

| Type | Description |
| --- | --- |
| Batch | The sharded batch. |
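
As an illustration of the underlying mechanism, the sketch below places a batch with plain `jax.device_put`. It uses a dict of arrays as a stand-in for the Datarax `Batch` type, which is an assumption made for this example, and it does not show how `place_batch_on_shards` is implemented.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("data",))
sharding = NamedSharding(mesh, PartitionSpec("data"))

# The leading dimension must be divisible by the number of devices on the axis.
n_dev = jax.device_count()
batch = {
    "image": jnp.zeros((4 * n_dev, 8, 8, 3)),
    "label": jnp.arange(4 * n_dev),
}

# A single sharding is applied to every leaf of the pytree.
sharded = jax.device_put(batch, sharding)
```

Each device then holds a slice of 4 examples per leaf, while the arrays keep their global shapes.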

spmd_train_step ¤

spmd_train_step(model: Module, optimizer: Optimizer, loss_fn: Callable[[Module, Batch], Array], batch: Batch) -> Array

Execute a data-parallel training step using SPMD.

Uses nnx.value_and_grad for automatic differentiation. The XLA compiler handles gradient AllReduce automatically when model parameters are sharded across devices via jax.set_mesh or explicit NamedSharding.

This function should be called inside an @nnx.jit decorated function with a mesh context active.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Module | The NNX model to train. | required |
| optimizer | Optimizer | The NNX optimizer wrapping the model. | required |
| loss_fn | Callable[[Module, Batch], Array] | Function (model, batch) -> loss scalar. | required |
| batch | Batch | The training batch (should be pre-sharded). | required |

Returns:

| Type | Description |
| --- | --- |
| Array | The loss value. |

Example

    mesh = jax.make_mesh((4,), ("data",))
    rules = data_parallel_rules()

    @nnx.jit
    def train_step(model, optimizer, batch):
        return spmd_train_step(model, optimizer, my_loss_fn, batch)

    with jax.set_mesh(mesh):
        loss = train_step(model, optimizer, batch)
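
The claim that the compiler inserts the gradient reduction can be seen in a stripped-down setting. The sketch below deliberately uses plain `jax.jit` and a single weight vector instead of NNX modules and optimizers, so it is an analogy to `spmd_train_step` rather than its implementation; the linear model, the learning rate, and all variable names are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("data",))
data_sharding = NamedSharding(mesh, PartitionSpec("data"))
replicated = NamedSharding(mesh, PartitionSpec())

n_dev = jax.device_count()
x = jax.device_put(jnp.ones((4 * n_dev, 3)), data_sharding)  # batch split over devices
y = jax.device_put(jnp.ones((4 * n_dev,)), data_sharding)
w = jax.device_put(jnp.zeros((3,)), replicated)  # parameters copied to every device


def loss_fn(w, x, y):
    # Written against the global batch; there are no explicit collectives here.
    return jnp.mean((x @ w - y) ** 2)


@jax.jit
def train_step(w, x, y):
    loss, grad = jax.value_and_grad(loss_fn)(w, x, y)
    # Plain SGD update; any cross-device reduction is left to the compiler.
    return w - 0.1 * grad, loss


new_w, loss = train_step(w, x, y)
```

The mean over the global batch is the only place where devices need to exchange data, and it appears in the code as an ordinary `jnp.mean`.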

data_parallel_train_step ¤

data_parallel_train_step(loss_fn: Callable, optimizer: Any, batch: Batch, state: Any, rngs: Any | None = None) -> tuple[Any, dict[str, Any]]

Execute a data-parallel training step using pmap.

Uses jax.pmap with lax.pmean for gradient averaging. For modern SPMD training, prefer spmd_train_step instead.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| loss_fn | Callable | Function that computes the loss value. | required |
| optimizer | Any | The optimizer to use. | required |
| batch | Batch | The batch to process. | required |
| state | Any | The current training state. | required |
| rngs | Any \| None | Optional random number generator keys. | None |

Returns:

| Type | Description |
| --- | --- |
| tuple[Any, dict[str, Any]] | A tuple of (new_state, metrics). |
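
For readers unfamiliar with the pmap pattern, the following is a minimal sketch of a step that averages gradients with `lax.pmean`. It is written in plain JAX with a hand-rolled SGD update and a toy linear loss; the function names, the axis name "batch", and the state layout are assumptions for illustration and do not reflect the internals of `data_parallel_train_step`.

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax


def loss_fn(params, batch):
    pred = batch["x"] @ params["w"]
    return jnp.mean((pred - batch["y"]) ** 2)


@functools.partial(jax.pmap, axis_name="batch")
def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Average across devices so every replica applies the same update.
    grads = lax.pmean(grads, axis_name="batch")
    loss = lax.pmean(loss, axis_name="batch")
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return new_params, {"loss": loss}


# pmap expects a leading device axis on every input, including the parameters.
n_dev = jax.local_device_count()
params = {"w": jnp.zeros((n_dev, 3))}
batch = {"x": jnp.ones((n_dev, 4, 3)), "y": jnp.ones((n_dev, 4))}
new_params, metrics = train_step(params, batch)
```

Unlike the SPMD style, the per-device axis is explicit in every array here, which is one reason the page recommends `spmd_train_step` for new code.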

place_model_state_on_shards ¤

place_model_state_on_shards(state: Any, mesh: Mesh, param_sharding: str | dict[str, PartitionSpec] | None = None) -> Any

Shard a model's state across devices.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | Any | The model state to shard. | required |
| mesh | Mesh | The device mesh to shard across. | required |
| param_sharding | str \| dict[str, PartitionSpec] \| None | Optional parameter sharding specifications. Use "replicate" to replicate all parameters, or a dict mapping parameter paths to PartitionSpec for per-parameter sharding. | None |

Returns:

| Type | Description |
| --- | --- |
| Any | The sharded model state. |
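
The "replicate" case corresponds to a familiar JAX idiom, sketched below with public APIs only. The nested dict is a hypothetical stand-in for a model state, and the sketch covers only full replication; how Datarax resolves per-parameter paths in the dict form is not shown here.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("data",))

# Hypothetical model state: any pytree of arrays works the same way.
state = {"dense": {"kernel": jnp.ones((8, 4)), "bias": jnp.zeros((4,))}}

# An empty PartitionSpec shards no dimension, so every device holds a full copy.
replicated = NamedSharding(mesh, PartitionSpec())
placed = jax.device_put(state, replicated)
```

Replicating parameters while sharding only the batch is the usual arrangement for pure data parallelism.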

place_nnx_state_on_shards ¤

place_nnx_state_on_shards(state: State, mesh: Mesh, filter_sharding: StateSharding | dict[Any, PartitionSpec | Sharding]) -> State

Shard a Flax NNX state tree using current NNX sharding helpers.

reduce_gradients_across_devices ¤

reduce_gradients_across_devices(gradients: Any, reduce_type: str = 'mean', axis_name: str = 'batch') -> Any

All-reduce gradients across devices using collective operations.

Only valid inside a pmap or shard_map context. For SPMD training with nnx.jit, gradient reduction is handled automatically by the compiler.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| gradients | Any | The gradients to reduce. | required |
| reduce_type | str | The type of reduction ("mean" or "sum"). | 'mean' |
| axis_name | str | The axis name for the collective operation. | 'batch' |

Returns:

| Type | Description |
| --- | --- |
| Any | The reduced gradients. |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If reduce_type is not "mean" or "sum". |
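
A collective reduction of this kind can be sketched with `lax.pmean` and `lax.psum`. The helper `reduce_grads` below is a hypothetical re-creation that mirrors the documented signature and error behaviour; it is an assumption about the general approach, not the library's source.

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax


def reduce_grads(gradients, reduce_type="mean", axis_name="batch"):
    """Hypothetical sketch: all-reduce a gradient pytree over a named axis."""
    if reduce_type == "mean":
        return lax.pmean(gradients, axis_name=axis_name)
    if reduce_type == "sum":
        return lax.psum(gradients, axis_name=axis_name)
    raise ValueError(f"reduce_type must be 'mean' or 'sum', got {reduce_type!r}")


# Device i contributes the value i + 1, so the reductions are easy to verify.
n_dev = jax.local_device_count()
per_device = {"w": jnp.arange(n_dev, dtype=jnp.float32).reshape(n_dev, 1) + 1.0}

# The collectives only resolve inside a context that binds the axis name.
summed = jax.pmap(
    functools.partial(reduce_grads, reduce_type="sum"), axis_name="batch"
)(per_device)
averaged = jax.pmap(reduce_grads, axis_name="batch")(per_device)
```

Calling the helper outside `pmap` or `shard_map` with a valid `reduce_type` fails because the axis name is unbound, which matches the restriction stated above.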

reduce_gradient_tree ¤

reduce_gradient_tree(gradients: Any, reduce_type: str = 'mean') -> Any

Reduce gradients using standard JAX operations on global arrays.

Works in SPMD contexts (inside nnx.jit with mesh). The XLA compiler handles cross-device communication automatically.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| gradients | Any | The gradients to reduce (global sharded arrays). | required |
| reduce_type | str | The type of reduction ("mean" or "sum"). | 'mean' |

Returns:

| Type | Description |
| --- | --- |
| Any | The reduced gradients. |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If reduce_type is not "mean" or "sum". |