Distributed Metrics
Aggregate metrics across devices and hosts.
See Also
- Distributed Overview - All distributed tools
- Data Parallel - Data parallelism
- Monitoring Metrics - Metric tracking
- Distributed Training Guide
datarax.distributed.metrics
Distributed metrics collection utilities for Datarax.
This module provides functions for collecting and aggregating metrics across multiple devices in distributed training settings.
Two API variants are provided:
- Default functions (reduce_mean, reduce_sum, etc.): use standard JAX operations (jnp.mean, jnp.sum) on global arrays. They work in SPMD contexts, such as nnx.jit with a device mesh.
- Collective functions (reduce_mean_collective, etc.): use JAX collective operations (lax.pmean, lax.psum). They are only valid inside pmap or shard_map contexts.
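The sketch below contrasts the two styles on a toy array. It does not call Datarax; it uses plain JAX only. To run on a single CPU, it emulates four devices with jax.vmap and a named axis (lax.pmean also accepts axis names bound by vmap); on real hardware the same step function would be wrapped in jax.pmap(step, axis_name="batch") or shard_map instead.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Per-device losses for 4 emulated devices, stacked along a leading axis.
per_device_loss = jnp.array([1.0, 2.0, 3.0, 6.0])

# Default style: an ordinary jnp.mean over the whole (global) array.
global_mean = jnp.mean(per_device_loss)


def step(loss):
    # Collective style: lax.pmean requires the axis name "batch" to be bound.
    return lax.pmean(loss, axis_name="batch")


# vmap binds the "batch" axis name on one host, standing in for pmap/shard_map.
# Each emulated device receives the same cross-device mean.
collective_mean = jax.vmap(step, axis_name="batch")(per_device_loss)
```

Calling step(per_device_loss) directly, without any transformation binding "batch", raises an unbound-axis-name error, which is why the collective variants are restricted to pmap or shard_map contexts.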
reduce_mean
reduce_sum
reduce_max
reduce_min
reduce_custom
reduce_custom(metrics: dict[str, Any], reduce_fn: dict[str, str | None] | None = None) -> dict[str, Any]
Apply custom reduction operations to metrics.
Uses standard JAX operations. Works in SPMD contexts.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics | dict[str, Any] | The metrics to reduce. | required |
| reduce_fn | dict[str, str \| None] \| None | A dictionary mapping metric names to reduction operations. Each operation should be one of "mean", "sum", "max", or "min". If None, defaults to "mean" for all metrics. | None |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A dictionary of reduced metrics. |
reduce_mean_collective
reduce_mean_collective(metrics: dict[str, Any], axis_name: str = 'batch') -> dict[str, Any]
Compute the mean of metrics using collective operations.
Only valid inside a pmap or shard_map context.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics | dict[str, Any] | The metrics to reduce. | required |
| axis_name | str | The name of the axis to reduce across. | 'batch' |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A dictionary of mean metrics. |
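A minimal sketch of what reduce_mean_collective does, assuming it applies lax.pmean to every leaf of the metrics dictionary; reduce_mean_collective_sketch is a hypothetical stand-in, not the Datarax implementation. Four devices are emulated with jax.vmap and axis_name="batch" so the example runs on one CPU.

```python
import jax
import jax.numpy as jnp
from jax import lax


def reduce_mean_collective_sketch(metrics, axis_name="batch"):
    # Hypothetical: average every metric across the named axis.
    return jax.tree_util.tree_map(lambda x: lax.pmean(x, axis_name=axis_name), metrics)


def eval_step(loss, accuracy):
    local = {"loss": loss, "accuracy": accuracy}  # per-device metrics
    return reduce_mean_collective_sketch(local, axis_name="batch")


# Emulate 4 devices; on hardware use jax.pmap(eval_step, axis_name="batch").
losses = jnp.array([0.2, 0.4, 0.6, 0.8])
accuracies = jnp.array([0.5, 0.5, 1.0, 1.0])
result = jax.vmap(eval_step, axis_name="batch")(losses, accuracies)
```

Every emulated device ends up holding the same averaged values, so any one of them can be logged.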
reduce_sum_collective
reduce_sum_collective(metrics: dict[str, Any], axis_name: str = 'batch') -> dict[str, Any]
Compute the sum of metrics using collective operations.
Only valid inside a pmap or shard_map context.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics | dict[str, Any] | The metrics to reduce. | required |
| axis_name | str | The name of the axis to reduce across. | 'batch' |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A dictionary of summed metrics. |
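Summing is the right reduction for counts. The sketch below assumes reduce_sum_collective applies lax.psum to every leaf (reduce_sum_collective_sketch is hypothetical) and uses it to compute a global accuracy from summed correct and total counts, emulating three devices with jax.vmap.

```python
import jax
import jax.numpy as jnp
from jax import lax


def reduce_sum_collective_sketch(metrics, axis_name="batch"):
    # Hypothetical: sum every metric across the named axis.
    return jax.tree_util.tree_map(lambda x: lax.psum(x, axis_name=axis_name), metrics)


def count_step(correct, total):
    sums = reduce_sum_collective_sketch({"correct": correct, "total": total})
    # Ratio of global counts: weights each device by how many examples it saw.
    return sums["correct"] / sums["total"]


correct = jnp.array([3.0, 5.0, 8.0])
total = jnp.array([4.0, 10.0, 10.0])
accuracy = jax.vmap(count_step, axis_name="batch")(correct, total)

# For comparison: averaging per-device ratios ignores unequal batch sizes.
naive_accuracy = jnp.mean(correct / total)
```

Here the summed-count accuracy is 16/24 (about 0.667), while the naive average of per-device ratios is about 0.683, because the first device saw fewer examples.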
all_gather
all_gather(metrics: dict[str, Any], axis_name: str = 'batch') -> dict[str, Any]
Gather metrics from all devices.
Only valid inside a pmap or shard_map context.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics | dict[str, Any] | The metrics to gather. | required |
| axis_name | str | The name of the axis to gather across. | 'batch' |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A dictionary of gathered metrics. |
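Unlike the reductions, gathering keeps every device's value. This sketch assumes all_gather applies lax.all_gather to each leaf (all_gather_sketch is hypothetical) and emulates three devices with jax.vmap; each device then holds a new leading axis containing all per-device values.

```python
import jax
import jax.numpy as jnp
from jax import lax


def all_gather_sketch(metrics, axis_name="batch"):
    # Hypothetical: stack every device's value along a new leading axis.
    return jax.tree_util.tree_map(lambda x: lax.all_gather(x, axis_name=axis_name), metrics)


def gather_step(loss):
    return all_gather_sketch({"loss": loss})


losses = jnp.array([1.0, 2.0, 3.0])
gathered = jax.vmap(gather_step, axis_name="batch")(losses)
# gathered["loss"] has shape (3, 3): row i is what emulated device i holds.
```

Gathering is useful when a statistic cannot be computed from per-device means or sums, such as a global median or a histogram of per-device values.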
collect_from_devices
collect_from_devices(metrics: dict[str, Any]) -> dict[str, list[Any] | Any]
Collect metrics from all devices.
Call this outside a pmapped function to split each array's leading device axis into per-device values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics | dict[str, Any] | The metrics from all devices, with the first dimension corresponding to the device axis. | required |
Returns:
| Type | Description |
|---|---|
| dict[str, list[Any] \| Any] | A dictionary of metrics, with array values split into per-device lists. |
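A minimal sketch of the splitting that collect_from_devices describes, written as a hypothetical collect_from_devices_sketch. It assumes arrays with at least one dimension are split along axis 0, while non-array values (and 0-d arrays) are passed through unchanged; the library's handling of those cases is not documented here. The stacked input stands in for the output of a pmapped function over four devices.

```python
import jax
import jax.numpy as jnp
import numpy as np


def collect_from_devices_sketch(metrics):
    collected = {}
    for name, value in metrics.items():
        if isinstance(value, (jax.Array, np.ndarray)) and value.ndim > 0:
            # Split the leading device axis into one entry per device.
            collected[name] = [value[i] for i in range(value.shape[0])]
        else:
            # Assumption: non-arrays and scalars are returned as-is.
            collected[name] = value
    return collected


# Stand-in for pmap output: leading axis of size 4 = one row per device.
stacked = {"loss": jnp.array([0.9, 0.7, 0.8, 0.6]), "step": 100}
per_device = collect_from_devices_sketch(stacked)
```

Per-device lists make it easy to log or inspect each device separately, for example to spot one replica whose loss diverges from the others.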