Distributed Metrics

Aggregate metrics across devices and hosts.

datarax.distributed.metrics

Distributed metrics collection utilities for Datarax.

This module provides functions for collecting and aggregating metrics across multiple devices in distributed training settings.

Two API variants are provided:

  • Default functions (reduce_mean, reduce_sum, etc.): Use standard JAX operations (jnp.mean, jnp.sum) on global arrays. Work in SPMD contexts with nnx.jit + mesh.

  • Collective functions (reduce_mean_collective, etc.): Use JAX collective operations (lax.pmean, lax.psum). Only valid inside pmap or shard_map contexts.
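The difference between the two variants can be illustrated with plain JAX. This is a minimal sketch, not Datarax's implementation: it calls `jnp.mean` and `lax.pmean` directly, and the names `per_device_loss` and `local_then_pmean` are made up for the example. It uses `jax.pmap`, so it also runs on a single device.

```python
import jax
import jax.numpy as jnp
from jax import lax

n_dev = jax.local_device_count()

# Per-device losses; the leading axis indexes devices.
per_device_loss = jnp.arange(n_dev * 4, dtype=jnp.float32).reshape(n_dev, 4)

# Default style: an ordinary reduction over the whole array.
global_mean = jnp.mean(per_device_loss)

# Collective style: each device reduces its own shard, then lax.pmean
# averages the per-device results across the named axis.
def local_then_pmean(x):
    return lax.pmean(jnp.mean(x), axis_name="batch")

# The axis name only exists inside the pmapped function.
collective_mean = jax.pmap(local_then_pmean, axis_name="batch")(per_device_loss)
```

Both paths produce the same value here because every device holds the same number of elements. The collective result has one copy per device, which is why `lax.pmean` cannot be called outside a context that binds the axis name.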

logger module-attribute

logger = getLogger(__name__)

reduce_mean

reduce_mean(metrics: dict[str, Any]) -> dict[str, Any]

Compute the mean of metrics using standard JAX operations.

Works with global arrays in SPMD contexts (nnx.jit + mesh).

Parameters:

  • metrics (dict[str, Any], required): The metrics to reduce.

Returns:

  • dict[str, Any]: A dictionary of mean-reduced metrics.
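As a rough illustration of the SPMD setting described above, the sketch below reduces a dictionary of sharded global arrays under `jax.jit`. It assumes the reduction amounts to applying `jnp.mean` to each value, which the signature suggests but the source does not confirm. `mean_of_each` is a hypothetical stand-in rather than the library function, and `jax.jit` is used in place of `nnx.jit` to keep the example free of Flax.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def mean_of_each(metrics):
    # Hypothetical stand-in: reduce every value to a scalar mean.
    return {name: jnp.mean(value) for name, value in metrics.items()}

# One-dimensional mesh over all available devices (works with one device too).
devices = np.array(jax.devices())
mesh = Mesh(devices, ("batch",))
sharding = NamedSharding(mesh, PartitionSpec("batch"))

# Global arrays whose leading axis is split across the "batch" mesh axis.
n = 8 * devices.size
metrics = {
    "loss": jax.device_put(jnp.linspace(0.0, 1.0, n), sharding),
    "accuracy": jax.device_put(jnp.ones(n), sharding),
}

# Under jit, jnp.mean sees the logical global array; the compiler inserts
# whatever cross-device communication the sharding requires.
reduced = jax.jit(mean_of_each)(metrics)
```

No axis name appears anywhere in the reduction, which is the practical difference from the collective variants.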

reduce_sum

reduce_sum(metrics: dict[str, Any]) -> dict[str, Any]

Compute the sum of metrics using standard JAX operations.

Works with global arrays in SPMD contexts (nnx.jit + mesh).

Parameters:

  • metrics (dict[str, Any], required): The metrics to reduce.

Returns:

  • dict[str, Any]: A dictionary of summed metrics.

reduce_max

reduce_max(metrics: dict[str, Any]) -> dict[str, Any]

Compute the maximum of metrics using standard JAX operations.

Works with global arrays in SPMD contexts (nnx.jit + mesh).

Parameters:

  • metrics (dict[str, Any], required): The metrics to reduce.

Returns:

  • dict[str, Any]: A dictionary of maximum metrics.

reduce_min

reduce_min(metrics: dict[str, Any]) -> dict[str, Any]

Compute the minimum of metrics using standard JAX operations.

Works with global arrays in SPMD contexts (nnx.jit + mesh).

Parameters:

  • metrics (dict[str, Any], required): The metrics to reduce.

Returns:

  • dict[str, Any]: A dictionary of minimum metrics.

reduce_custom

reduce_custom(metrics: dict[str, Any], reduce_fn: dict[str, str | None] | None = None) -> dict[str, Any]

Apply custom reduction operations to metrics.

Uses standard JAX operations. Works in SPMD contexts.

Parameters:

  • metrics (dict[str, Any], required): The metrics to reduce.

  • reduce_fn (dict[str, str | None] | None, default: None): A dictionary mapping metric names to reduction operations. Each operation should be one of {"mean", "sum", "max", "min"}. If None, defaults to "mean" for all metrics.

Returns:

  • dict[str, Any]: A dictionary of reduced metrics.
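One way to picture per-metric reduction is a lookup table from operation name to JAX function. The sketch below is an assumption-laden illustration, not the library's code: `reduce_custom_sketch` and `_REDUCERS` are hypothetical names, and the fallback to "mean" for metrics that are missing from `reduce_fn` or mapped to None is a guess; the source only states the default when `reduce_fn` itself is None.

```python
import jax.numpy as jnp

# Assumed mapping from the documented operation names to JAX reductions.
_REDUCERS = {"mean": jnp.mean, "sum": jnp.sum, "max": jnp.max, "min": jnp.min}

def reduce_custom_sketch(metrics, reduce_fn=None):
    reduce_fn = reduce_fn or {}
    reduced = {}
    for name, value in metrics.items():
        # Assumption: a missing entry or a None entry falls back to "mean".
        op = reduce_fn.get(name) or "mean"
        reduced[name] = _REDUCERS[op](value)
    return reduced

metrics = {
    "loss": jnp.array([1.0, 2.0, 3.0]),
    "num_examples": jnp.array([4, 4, 2]),
    "grad_norm": jnp.array([0.5, 2.5, 1.0]),
}
result = reduce_custom_sketch(
    metrics, {"num_examples": "sum", "grad_norm": "max"}
)
```

Here "loss" is averaged by default, "num_examples" is summed, and "grad_norm" keeps its largest value, which matches how such metrics are usually reported.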

reduce_mean_collective

reduce_mean_collective(metrics: dict[str, Any], axis_name: str = 'batch') -> dict[str, Any]

Compute the mean of metrics using collective operations.

Only valid inside a pmap or shard_map context.

Parameters:

  • metrics (dict[str, Any], required): The metrics to reduce.

  • axis_name (str, default: 'batch'): The name of the axis to reduce across.

Returns:

  • dict[str, Any]: A dictionary of mean metrics.

reduce_sum_collective

reduce_sum_collective(metrics: dict[str, Any], axis_name: str = 'batch') -> dict[str, Any]

Compute the sum of metrics using collective operations.

Only valid inside a pmap or shard_map context.

Parameters:

  • metrics (dict[str, Any], required): The metrics to reduce.

  • axis_name (str, default: 'batch'): The name of the axis to reduce across.

Returns:

  • dict[str, Any]: A dictionary of summed metrics.
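A common reason to sum rather than average across devices is to compute a correctly weighted mean when devices hold different numbers of valid examples. The sketch below calls `lax.psum` directly on a small dictionary inside `jax.pmap`. It illustrates the underlying collective, not Datarax's code; `step_metrics`, `losses`, and `mask` are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax import lax

n_dev = jax.local_device_count()

def step_metrics(losses, mask):
    # Local partial sums; mask is 0.0 for padded examples.
    local = {"loss_sum": jnp.sum(losses * mask), "count": jnp.sum(mask)}
    # lax.psum accepts a pytree and sums each leaf across the named axis.
    totals = lax.psum(local, axis_name="batch")
    return {"mean_loss": totals["loss_sum"] / totals["count"], **totals}

# Four examples per device, the last of which is padding.
losses = jnp.full((n_dev, 4), 2.0)
mask = jnp.ones((n_dev, 4)).at[:, -1].set(0.0)

out = jax.pmap(step_metrics, axis_name="batch")(losses, mask)
```

Dividing the summed loss by the summed count after the collective gives the mean over valid examples only, which averaging per-device means would not guarantee when counts differ.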

all_gather

all_gather(metrics: dict[str, Any], axis_name: str = 'batch') -> dict[str, Any]

Gather metrics from all devices.

Only valid inside a pmap or shard_map context.

Parameters:

  • metrics (dict[str, Any], required): The metrics to gather.

  • axis_name (str, default: 'batch'): The name of the axis to gather across.

Returns:

  • dict[str, Any]: A dictionary of gathered metrics.
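To show what gathering means in practice, the sketch below uses `jax.lax.all_gather` directly inside `jax.pmap`. Whether Datarax's `all_gather` wraps this primitive in exactly this way is an assumption, and `gather_losses` is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax import lax

n_dev = jax.local_device_count()

def gather_losses(local_loss):
    # Every device receives the values from all devices,
    # stacked along a new leading axis.
    return lax.all_gather(local_loss, axis_name="batch")

# One scalar loss per device.
local_losses = jnp.arange(n_dev, dtype=jnp.float32)
gathered = jax.pmap(gather_losses, axis_name="batch")(local_losses)
```

The result has shape `(n_dev, n_dev)`: the outer axis comes from `pmap`, and each row is the full set of per-device values as seen by one device. Unlike a mean or sum, gathering keeps the individual values, at the cost of more memory per device.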

collect_from_devices

collect_from_devices(metrics: dict[str, Any]) -> dict[str, list[Any] | Any]

Collect metrics from all devices.

Call outside of a pmapped function to split per-device values from the leading device axis.

Parameters:

  • metrics (dict[str, Any], required): The metrics from all devices, with the first dimension corresponding to the device axis.

Returns:

  • dict[str, list[Any] | Any]: A dictionary of metrics, with array values split into per-device lists.
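The splitting step can be sketched as follows. This is a guess at the behaviour based on the return type, not the library's implementation: `split_by_device` is a hypothetical helper, and passing non-array values through unchanged is an assumption suggested by the `list[Any] | Any` annotation.

```python
import jax
import jax.numpy as jnp

def split_by_device(metrics):
    out = {}
    for name, value in metrics.items():
        if isinstance(value, jax.Array) and value.ndim > 0:
            # Leading axis is the device axis: one entry per device.
            out[name] = [value[i] for i in range(value.shape[0])]
        else:
            # Assumed: non-array values pass through unchanged.
            out[name] = value
    return out

n_dev = jax.local_device_count()

# Output of a pmapped function: each value has a leading device axis.
per_device = jax.pmap(lambda x: {"loss": jnp.mean(x)})(jnp.ones((n_dev, 4)))
per_device["step"] = 7  # a plain Python value, not an array

collected = split_by_device(per_device)
```

The per-device list is convenient for logging or for spotting a single device whose metric diverges from the rest.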