Device Mesh¤

Configure multi-device mesh layouts for distributed training.

datarax.distributed.device_mesh ¤

Device mesh management for JAX distributed training.

This module provides utilities for creating and managing JAX device meshes, which coordinate distributed computations across multiple devices.

logger module-attribute ¤

logger = getLogger(__name__)

DeviceMeshManager ¤

Manager for creating and configuring JAX device meshes.

This class provides utilities for creating device meshes for different distributed training configurations, including data-parallel, model-parallel, and hybrid approaches.

create_device_mesh staticmethod ¤

create_device_mesh(mesh_shape: dict[str, int] | list[tuple[str, int]], devices: list[Any] | None = None) -> Mesh

Create a JAX device mesh with the specified shape.

Parameters:

Name Type Description Default
mesh_shape dict[str, int] | list[tuple[str, int]]

The shape of the mesh, specified either as a dictionary mapping axis names to sizes, or as a list of (name, size) tuples.

required
devices list[Any] | None

Optional list of devices to use. If None, uses all available devices.

None

Returns:

Type Description
Mesh

A JAX device mesh.

Raises:

Type Description
ValueError

If the mesh shape is incompatible with the number of devices.
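The two accepted forms of mesh_shape and the device-count check can be illustrated with a minimal sketch in plain Python. The helper names `normalize_mesh_shape` and `check_mesh_fits` are hypothetical and not part of datarax, and the sketch assumes that "incompatible" means the mesh needs more devices than are available; the library's actual check may be stricter.

```python
import math
from typing import Sequence, Union

MeshShape = Union[dict[str, int], Sequence[tuple[str, int]]]


def normalize_mesh_shape(mesh_shape: MeshShape) -> tuple[tuple[str, ...], tuple[int, ...]]:
    """Split either accepted form into parallel tuples of axis names and sizes."""
    if isinstance(mesh_shape, dict):
        items = list(mesh_shape.items())
    else:
        items = list(mesh_shape)
    names = tuple(name for name, _ in items)
    sizes = tuple(size for _, size in items)
    return names, sizes


def check_mesh_fits(sizes: Sequence[int], num_devices: int) -> None:
    """Raise ValueError if the mesh needs more devices than are available.

    Assumption: the real validation rule is not documented in detail.
    """
    required = math.prod(sizes)
    if required > num_devices:
        raise ValueError(
            f"Mesh shape {tuple(sizes)} needs {required} devices, "
            f"but only {num_devices} are available."
        )


# Both forms describe the same 2 x 4 mesh, which needs 8 devices.
names, sizes = normalize_mesh_shape({"data": 2, "model": 4})
same = normalize_mesh_shape([("data", 2), ("model", 4)])
```

The list-of-tuples form makes the axis order explicit, which matters because the order of the axes determines how the device array is reshaped.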

create_data_parallel_mesh staticmethod ¤

create_data_parallel_mesh(num_devices: int | None = None) -> Mesh

Create a data-parallel device mesh.

Parameters:

Name Type Description Default
num_devices int | None

Optional number of devices to use. If None, uses all available devices.

None

Returns:

Type Description
Mesh

A JAX device mesh configured for data-parallel training.
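As a rough illustration of what a data-parallel mesh looks like in plain JAX, the sketch below builds a one-dimensional mesh over all available devices and shards a batch along its leading dimension. The axis name "data" is an assumption; the axis name datarax actually uses is not stated here.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One mesh axis spanning every available device (axis name "data" is assumed).
devices = jax.devices()
mesh = Mesh(np.array(devices), ("data",))

# A batch whose leading dimension divides evenly across the devices.
batch = np.zeros((4 * len(devices), 3), dtype=np.float32)

# Shard the batch dimension over the "data" axis; the feature dimension is replicated.
sharded_batch = jax.device_put(batch, NamedSharding(mesh, PartitionSpec("data")))
```

Each device then holds an equal slice of the batch, while the array keeps its global shape from the caller's point of view.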

create_model_parallel_mesh staticmethod ¤

create_model_parallel_mesh(num_devices: int) -> Mesh

Create a model-parallel device mesh.

Parameters:

Name Type Description Default
num_devices int

Number of devices to use for model parallelism.

required

Returns:

Type Description
Mesh

A JAX device mesh configured for model-parallel training.
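A model-parallel mesh can be sketched the same way, except that parameters rather than batches are split across devices. The function name `sketch_model_parallel_mesh` and the axis name "model" are hypothetical, chosen only for this illustration.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def sketch_model_parallel_mesh(num_devices: int) -> Mesh:
    """Build a 1-D mesh over the first num_devices devices (axis name assumed)."""
    devices = jax.devices()
    if num_devices > len(devices):
        raise ValueError(
            f"Requested {num_devices} devices, but only {len(devices)} are available."
        )
    return Mesh(np.array(devices[:num_devices]), ("model",))


model_mesh = sketch_model_parallel_mesh(1)

# Split the columns of a weight matrix over the "model" axis; rows are replicated.
weights = np.ones((8, 4 * model_mesh.shape["model"]), dtype=np.float32)
sharded_weights = jax.device_put(
    weights, NamedSharding(model_mesh, PartitionSpec(None, "model"))
)
```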

create_hybrid_mesh staticmethod ¤

create_hybrid_mesh(data_parallel_size: int, model_parallel_size: int) -> Mesh

Create a hybrid data-parallel and model-parallel device mesh.

Parameters:

Name Type Description Default
data_parallel_size int

Number of devices to use for data parallelism.

required
model_parallel_size int

Number of devices to use for model parallelism.

required

Returns:

Type Description
Mesh

A JAX device mesh configured for hybrid parallel training.

Raises:

Type Description
ValueError

If there aren't enough devices available.
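A hybrid mesh needs data_parallel_size * model_parallel_size devices in total; for example, sizes 2 and 4 require 8 devices. The sketch below shows one way such a mesh could be built in plain JAX. The function name `sketch_hybrid_mesh` and the axis names "data" and "model" are assumptions, not the datarax implementation.

```python
import jax
import numpy as np
from jax.sharding import Mesh


def sketch_hybrid_mesh(data_parallel_size: int, model_parallel_size: int) -> Mesh:
    """Arrange devices in a 2-D grid: rows for data parallelism, columns for model parallelism."""
    required = data_parallel_size * model_parallel_size
    devices = jax.devices()
    if required > len(devices):
        raise ValueError(
            f"Hybrid mesh needs {required} devices, but only {len(devices)} are available."
        )
    grid = np.array(devices[:required]).reshape(data_parallel_size, model_parallel_size)
    return Mesh(grid, ("data", "model"))  # axis names are assumed


# A 1 x 1 mesh works on any machine, including a single-CPU setup.
hybrid_mesh = sketch_hybrid_mesh(1, 1)
```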

get_mesh_info staticmethod ¤

get_mesh_info(mesh: Mesh) -> dict[str, int | dict[str, int]]

Get information about a device mesh.

Parameters:

Name Type Description Default
mesh Mesh

The device mesh to inspect.

required

Returns:

Type Description
dict[str, int | dict[str, int]]

A dictionary containing information about the mesh, including the number of devices and the size of each axis.
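The kind of information described above can be read directly from a `jax.sharding.Mesh`. In the sketch below, the function name `sketch_mesh_info` and the dictionary keys "num_devices" and "axes" are hypothetical; the keys that get_mesh_info actually returns are not listed here.

```python
import jax
import numpy as np
from jax.sharding import Mesh


def sketch_mesh_info(mesh: Mesh) -> dict:
    """Summarize a mesh as a device count plus per-axis sizes (key names assumed)."""
    return {
        "num_devices": int(mesh.devices.size),
        "axes": {name: int(size) for name, size in mesh.shape.items()},
    }


# Inspect a single-device mesh so the example runs on any machine.
single_device_mesh = Mesh(np.array(jax.devices()[:1]), ("data",))
info = sketch_mesh_info(single_device_mesh)
```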