Device Mesh
Configure multi-device mesh layouts for distributed training.
See Also
- Distributed Overview - All distributed tools
- Device Placement - Device detection
- Data Parallel - Data parallelism
- Distributed Training Guide
- Sharding Quick Reference
datarax.distributed.device_mesh
Device mesh management for JAX distributed training.
This module provides utilities for creating and managing JAX device meshes, which coordinate distributed computations across multiple devices.
DeviceMeshManager
Manager for creating and configuring JAX device meshes.
This class provides utilities for creating device meshes for different distributed training configurations, including data-parallel, model-parallel, and hybrid approaches.
create_device_mesh (staticmethod)
create_device_mesh(mesh_shape: dict[str, int] | list[tuple[str, int]], devices: list[Any] | None = None) -> Mesh
Create a JAX device mesh with the specified shape.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `mesh_shape` | `dict[str, int] \| list[tuple[str, int]]` | The shape of the mesh, specified either as a dictionary mapping axis names to sizes, or as a list of (name, size) tuples. | required |
| `devices` | `list[Any] \| None` | Optional list of devices to use. If None, uses all available devices. | `None` |
Returns:
| Type | Description |
|---|---|
| `Mesh` | A JAX device mesh. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If the mesh shape is incompatible with the number of devices. |
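The behaviour described above can be approximated with plain JAX. The following is a minimal sketch, not datarax's implementation: `build_mesh` is a hypothetical stand-in for `create_device_mesh`, and it assumes that "incompatible" means the product of the axis sizes differs from the number of devices.

```python
import math

import jax
import numpy as np
from jax.sharding import Mesh


def build_mesh(mesh_shape, devices=None) -> Mesh:
    """Hypothetical stand-in for create_device_mesh (illustration only)."""
    # Accept either {"data": 4} or [("data", 4)].
    if isinstance(mesh_shape, dict):
        axes = list(mesh_shape.items())
    else:
        axes = list(mesh_shape)
    names = tuple(name for name, _ in axes)
    sizes = tuple(size for _, size in axes)

    # Fall back to every device JAX can see.
    devices = list(jax.devices()) if devices is None else list(devices)

    # Assumption: the mesh must use exactly the devices it is given.
    needed = math.prod(sizes)
    if needed != len(devices):
        raise ValueError(
            f"Mesh shape {dict(axes)} needs {needed} devices, got {len(devices)}."
        )

    # Arrange the devices in an N-dimensional grid, one dimension per axis.
    device_grid = np.array(devices, dtype=object).reshape(sizes)
    return Mesh(device_grid, names)


# Usage: a one-axis mesh over all visible devices.
mesh = build_mesh({"data": len(jax.devices())})
print(mesh.axis_names, dict(mesh.shape))
```

Both input forms describe the same thing: an ordered list of named axes. Order matters because it determines how the flat device list is reshaped into the grid.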
create_data_parallel_mesh (staticmethod)
create_model_parallel_mesh (staticmethod)
create_hybrid_mesh (staticmethod)
Create a hybrid data-parallel and model-parallel device mesh.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data_parallel_size` | `int` | Number of devices to use for data parallelism. | required |
| `model_parallel_size` | `int` | Number of devices to use for model parallelism. | required |
Returns:
| Type | Description |
|---|---|
| `Mesh` | A JAX device mesh configured for hybrid parallel training. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If there aren't enough devices available. |
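A hybrid mesh is a two-dimensional device grid with one axis per kind of parallelism. The sketch below shows one way to build it with plain JAX. It is not datarax's code: `build_hybrid_mesh` is a hypothetical name, the axis names `"data"` and `"model"` are assumptions, and so is the choice to take the first `data_parallel_size * model_parallel_size` devices when more are available.

```python
import jax
import numpy as np
from jax.sharding import Mesh


def build_hybrid_mesh(data_parallel_size: int, model_parallel_size: int) -> Mesh:
    """Hypothetical stand-in for create_hybrid_mesh (illustration only)."""
    devices = list(jax.devices())
    needed = data_parallel_size * model_parallel_size

    # Mirror the documented error: fail when too few devices are available.
    if needed > len(devices):
        raise ValueError(
            f"Hybrid mesh needs {needed} devices, only {len(devices)} available."
        )

    # Assumption: use the first `needed` devices, laid out as (data, model).
    grid = np.array(devices[:needed], dtype=object).reshape(
        data_parallel_size, model_parallel_size
    )
    return Mesh(grid, ("data", "model"))


# Usage: the smallest possible hybrid mesh, which fits on a single device.
mesh = build_hybrid_mesh(1, 1)
print(mesh.axis_names, mesh.devices.shape)
```

With eight devices, for example, `build_hybrid_mesh(4, 2)` would give four data-parallel replicas, each split across two model-parallel devices.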
get_mesh_info (staticmethod)
Get information about a device mesh.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `mesh` | `Mesh` | The device mesh to inspect. | required |
Returns:
| Type | Description |
|---|---|
| `dict[str, int \| dict[str, int]]` | A dictionary containing information about the mesh, including the number of devices and the size of each axis. |
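The same information can be read directly from a `jax.sharding.Mesh`. The sketch below is illustrative only: `describe_mesh` is a hypothetical helper, and the key names `"num_devices"` and `"axis_sizes"` are assumptions, since the keys returned by `get_mesh_info` are not listed here.

```python
import jax
import numpy as np
from jax.sharding import Mesh


def describe_mesh(mesh: Mesh) -> dict:
    """Hypothetical helper mirroring get_mesh_info; key names are assumed."""
    return {
        # Total number of devices in the mesh grid.
        "num_devices": int(mesh.devices.size),
        # Mapping from axis name to axis size, e.g. {"data": 4, "model": 2}.
        "axis_sizes": {name: int(size) for name, size in mesh.shape.items()},
    }


# Usage: inspect a one-device, one-axis mesh.
single = Mesh(np.array(jax.devices()[:1], dtype=object), ("data",))
info = describe_mesh(single)
print(info)  # {'num_devices': 1, 'axis_sizes': {'data': 1}}
```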