Device Placement
datarax.distributed.device_placement
Device placement utilities for JAX distributed training.
This module provides utilities for explicit device placement of JAX arrays and PyTrees, enabling efficient data distribution across accelerators.
Key Features:
- Explicit device placement with jax.device_put
- Hardware-aware batch size recommendations
- PyTree-aware device placement
- Prefetching utilities for overlapping compute and data transfer
Note
Performance Guidelines (per JAX guide):
- TPU v5e: Critical batch size >= 240 for optimal throughput
- H100 GPU: Critical batch size >= 298 for optimal throughput
- Always use explicit device placement for data pipeline outputs
- Prefetch to device memory to overlap data transfer with compute
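One way to read the critical batch sizes above is as a roofline argument. The derivation below is a sketch under stated assumptions: a bf16 matmul (2 bytes per element) of a B×D activation block by a D×F weight matrix, with B counting rows (for example, tokens) per device and B much smaller than D and F. The TPU v5e figures (about 1.97e14 bf16 FLOP/s and 8.2e11 bytes/s of HBM bandwidth) are assumed published specs, not values read from this module.

```latex
\begin{aligned}
\text{FLOPs} &= 2BDF, \qquad \text{bytes moved} = 2\,(BD + DF + BF) \\
\frac{\text{FLOPs}}{\text{bytes moved}} &= \frac{BDF}{BD + DF + BF} \approx B \qquad (B \ll D,\ B \ll F) \\
B_{\text{crit}} &\approx \frac{\text{peak FLOP/s}}{\text{HBM bandwidth in bytes/s}}
  \approx \frac{1.97 \times 10^{14}}{8.2 \times 10^{11}} \approx 240 \quad \text{(TPU v5e, assumed specs)}
\end{aligned}
```

Above B_crit the matmul is compute-bound; below it, HBM bandwidth limits throughput. This is an interpretation of the thresholds in the Note, not a description of how Datarax derives its values.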
Example
```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from datarax.distributed.device_placement import DevicePlacement
placement = DevicePlacement()
data = jnp.ones((256, 224, 224, 3))
placed = placement.place_on_device(data, jax.devices()[0])  # Place on one device
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, PartitionSpec("data", None, None, None))
# The batch dimension (256) must be divisible by the number of devices
distributed = placement.distribute_batch(data, sharding)  # Distribute batch
```
BatchSizeRecommendation (dataclass)
BatchSizeRecommendation(min_batch_size: int, optimal_batch_size: int, critical_batch_size: int, max_memory_batch_size: int | None = None, notes: str = '')
Hardware-specific batch size recommendations.
Attributes:
| Name | Type | Description |
|---|---|---|
| min_batch_size | int | Minimum batch size for reasonable efficiency. |
| optimal_batch_size | int | Optimal batch size for peak throughput. |
| critical_batch_size | int | Critical batch size for reaching roofline (per JAX guide). |
| max_memory_batch_size | int \| None | Maximum batch size before OOM (estimate). |
| notes | str | Additional notes about the recommendation. |
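For orientation, the documented signature corresponds to an ordinary Python dataclass. The sketch below is a local mirror for illustration only: the class name BatchSizeRecommendationSketch and the meets_critical helper are hypothetical and not part of Datarax. The example values come from the H100 row of the hardware table on this page.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class BatchSizeRecommendationSketch:
    """Hypothetical local mirror of the documented fields (not the library class)."""

    min_batch_size: int
    optimal_batch_size: int
    critical_batch_size: int
    max_memory_batch_size: int | None = None  # None when no memory estimate exists
    notes: str = ""

    def meets_critical(self, batch_size: int) -> bool:
        # Hypothetical helper: True once the batch reaches the roofline threshold.
        return batch_size >= self.critical_batch_size


# Values from the H100 row of the hardware table (min, optimal, critical).
h100 = BatchSizeRecommendationSketch(64, 320, 298, notes="Critical for roofline")
print(h100.meets_critical(256), h100.meets_critical(320))
```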
DevicePlacement
DevicePlacement(default_device: Device | None = None)
Utility class for explicit device placement of JAX arrays.
This class provides methods for placing arrays on specific devices, distributing batches across devices using sharding, and providing hardware-aware batch size recommendations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| default_device | Device \| None | Default device to use when none is specified. If None, lazily resolves to jax.devices()[0] on first access. | None |
default_device (property)
default_device: Device
Get the default device, lazily resolving it on first access.
hardware_type (property)
hardware_type: HardwareType
Get the detected hardware type, lazily resolving it on first access.
place_on_device
place_on_device(data: PyTree, device: Device | None = None) -> PyTree
Place data on a specific device.
This uses jax.device_put for explicit device placement, ensuring data is transferred to the target device.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | PyTree of JAX arrays to place on device. | required |
| device | Device \| None | Target device. If None, uses the default device. | None |
Returns:
| Type | Description |
|---|---|
| PyTree | PyTree with arrays placed on the specified device. |
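Since place_on_device is documented as a wrapper around jax.device_put, its core behavior can be sketched in a few lines. This is a minimal sketch under that assumption (place_on_device_sketch is a hypothetical name); it relies on jax.device_put accepting an entire PyTree and mapping over its leaves.

```python
import jax
import jax.numpy as jnp


def place_on_device_sketch(data, device=None):
    """Hypothetical sketch: move every array leaf of a PyTree to one device."""
    if device is None:
        device = jax.devices()[0]  # mirrors the documented default-device fallback
    # jax.device_put accepts whole PyTrees, so no manual tree_map is needed.
    return jax.device_put(data, device)


batch = {"images": jnp.ones((4, 8, 8, 3)), "labels": jnp.arange(4)}
placed = place_on_device_sketch(batch, jax.devices()[0])
```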
distribute_batch
distribute_batch(data: PyTree, sharding: Sharding) -> PyTree
Distribute data across devices using the specified sharding.
This applies explicit device placement using jax.device_put with a Sharding object, distributing the data across multiple devices.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | PyTree of JAX arrays to distribute. | required |
| sharding | Sharding | JAX Sharding specification. | required |
Returns:
| Type | Description |
|---|---|
| PyTree | PyTree with arrays distributed according to the sharding. |
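distribute_batch is documented as jax.device_put with a Sharding object, so its effect can be reproduced with plain JAX calls. The sketch below assumes a one-dimensional mesh over all visible devices and a batch whose leading dimension is divisible by the device count; it then inspects the per-device shards.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over every visible device, with its axis named "data".
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
# Split dimension 0 across the "data" axis; keep the other three dimensions whole.
sharding = NamedSharding(mesh, PartitionSpec("data", None, None, None))

# Two examples per device, so the batch divides evenly across the mesh.
batch = jnp.ones((2 * jax.device_count(), 28, 28, 3))
distributed = jax.device_put(batch, sharding)

# Each device holds a contiguous slice of the batch dimension.
shard_shapes = [shard.data.shape for shard in distributed.addressable_shards]
print(shard_shapes)
```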
replicate_across_devices
replicate_across_devices(data: PyTree, devices: list[Device] | None = None) -> PyTree
Replicate data across all specified devices.
This creates a copy of the data on each device, which is useful for broadcasting model weights or constants.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | PyTree of JAX arrays to replicate. | required |
| devices | list[Device] \| None | List of devices to replicate to. If None, uses all devices. | None |
Returns:
| Type | Description |
|---|---|
| PyTree | PyTree with arrays replicated across devices. |
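One way to obtain this kind of replication, assumed here for illustration rather than taken from Datarax's source, is a NamedSharding with an empty PartitionSpec: it splits no dimension, so every device in the mesh holds a full copy. replicate_sketch and the "replica" axis name are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def replicate_sketch(data, devices=None):
    """Hypothetical sketch: give every listed device a full copy of each leaf."""
    devices = devices if devices is not None else jax.devices()
    mesh = Mesh(np.array(devices), axis_names=("replica",))
    # An empty PartitionSpec shards no dimension, which means full replication.
    return jax.device_put(data, NamedSharding(mesh, PartitionSpec()))


weights = {"w": jnp.ones((128, 64)), "b": jnp.zeros(64)}
replicated = replicate_sketch(weights)
```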
shard_batch_dim
shard_batch_dim(data: PyTree, mesh: Mesh, batch_axis: int = 0, mesh_axis: str = 'data') -> PyTree
Shard data along the batch dimension.
This is the most common sharding pattern for data-parallel training, where each device processes a slice of the batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | PyTree of JAX arrays to shard. | required |
| mesh | Mesh | Device mesh to shard across. | required |
| batch_axis | int | The axis index representing the batch dimension. | 0 |
| mesh_axis | str | The mesh axis name to shard along. | 'data' |
Returns:
| Type | Description |
|---|---|
| PyTree | PyTree with arrays sharded along the batch dimension. |
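A plausible sketch of this pattern, assuming the method builds one PartitionSpec per leaf: the spec names mesh_axis at position batch_axis and leaves every other dimension unsharded. Because leaves can differ in rank (4-D images, 1-D labels), the spec is computed per leaf and passed to jax.device_put as a matching PyTree of shardings. shard_batch_dim_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def shard_batch_dim_sketch(data, mesh, batch_axis=0, mesh_axis="data"):
    """Hypothetical sketch: shard each leaf along batch_axis, replicate other dims."""

    def leaf_sharding(x):
        # Build (None, ..., mesh_axis, ..., None) with one entry per leaf dimension.
        spec = [None] * x.ndim
        spec[batch_axis] = mesh_axis
        return NamedSharding(mesh, PartitionSpec(*spec))

    shardings = jax.tree_util.tree_map(leaf_sharding, data)
    # device_put accepts a PyTree of shardings with the same structure as data.
    return jax.device_put(data, shardings)


mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
n = jax.device_count()
batch = {
    "images": jnp.ones((8 * n, 28, 28, 3)),
    "labels": jnp.zeros((8 * n,), dtype=jnp.int32),
}
sharded = shard_batch_dim_sketch(batch, mesh)
```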
prefetch_to_device
prefetch_to_device(data_iterator: Any, device: Device | None = None, buffer_size: int = 2, cpu_buffer_size: int | None = None) -> Any
Create a prefetching wrapper for host-to-device transfer.
Generic Python iterators use Datarax's closeable device-put thread wrapper. Grain-backed datasets should use datarax.control.prefetcher.create_prefetch_stream(..., mode="grain") directly when they need Grain's dataset-level device_put behavior.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_iterator | Any | Iterator yielding PyTrees of data. | required |
| device | Device \| None | Target device for prefetching. | None |
| buffer_size | int | Device buffer size (Stage 2). Default is 2. | 2 |
| cpu_buffer_size | int \| None | CPU buffer size (Stage 1). If None, defaults to buffer_size * 2. | None |
Returns:
| Type | Description |
|---|---|
| Any | Iterator that yields device-placed data. |
Note
Throughput depends on workload, host/device balance, and hardware. Add benchmark artifacts before making numeric performance claims.
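The two-stage design described by buffer_size and cpu_buffer_size can be sketched as follows, assuming a background thread fills a bounded host queue (stage 1) and a short deque holds batches already handed to jax.device_put (stage 2). Because jax.device_put dispatches asynchronously, transfers for upcoming batches can proceed while the consumer works on the current one. two_stage_prefetch is a hypothetical name; unlike the closeable wrapper described above, this sketch has no close() method and does not forward errors raised by the source iterator.

```python
import collections
import itertools
import queue
import threading

import jax
import numpy as np

_END = object()  # Sentinel marking exhaustion of the source iterator.


def two_stage_prefetch(data_iterator, device=None, buffer_size=2, cpu_buffer_size=None):
    """Hypothetical two-stage prefetcher: host thread queue, then on-device deque."""
    device = device if device is not None else jax.devices()[0]
    if cpu_buffer_size is None:
        cpu_buffer_size = buffer_size * 2  # mirrors the documented default
    host_queue = queue.Queue(maxsize=cpu_buffer_size)

    def producer():
        # Stage 1: pull host batches on a background thread so slow Python
        # iteration overlaps with the consumer. Errors are not forwarded here.
        try:
            for item in data_iterator:
                host_queue.put(item)
        finally:
            host_queue.put(_END)

    threading.Thread(target=producer, daemon=True).start()

    def host_items():
        while (item := host_queue.get()) is not _END:
            yield item

    source = host_items()
    # Stage 2: keep up to buffer_size batches already dispatched to the device.
    device_queue = collections.deque(
        jax.device_put(item, device) for item in itertools.islice(source, buffer_size)
    )
    while device_queue:
        yield device_queue.popleft()
        # Refill one slot so the next transfer starts before it is needed.
        for item in itertools.islice(source, 1):
            device_queue.append(jax.device_put(item, device))


batches = [{"x": np.full((2,), i)} for i in range(5)]
for batch in two_stage_prefetch(iter(batches), buffer_size=2):
    pass  # each batch holds jax.Array leaves on the target device
```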
get_batch_size_recommendation
get_batch_size_recommendation(hardware_type: HardwareType | None = None) -> BatchSizeRecommendation
Get a batch size recommendation for the current hardware.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hardware_type | HardwareType \| None | Override hardware type. If None, uses the detected type. | None |
Returns:
| Type | Description |
|---|---|
| BatchSizeRecommendation | BatchSizeRecommendation with hardware-specific values. |
validate_batch_size
place_on_device
distribute_batch
get_batch_size_recommendation
get_batch_size_recommendation(hardware_type: HardwareType | None = None) -> BatchSizeRecommendation
Get a batch size recommendation for the current or specified hardware.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hardware_type | HardwareType \| None | Hardware type to get a recommendation for. If None, uses the detected hardware. | None |
Returns:
| Type | Description |
|---|---|
| BatchSizeRecommendation | BatchSizeRecommendation with hardware-specific values. |
prefetch_to_device
prefetch_to_device(data_iterator: Any, size: int = 2, device: Device | None = None, cpu_buffer_size: int | None = None) -> Any
Prefetch iterator outputs to device memory.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_iterator | Any | Iterator yielding PyTrees of data (e.g., from a pipeline). | required |
| size | int | Device buffer size (Stage 2). Default is 2. | 2 |
| device | Device \| None | Target device for prefetching. If None, uses the default device. | None |
| cpu_buffer_size | int \| None | CPU buffer size (Stage 1). If None, defaults to size * 2. | None |
Returns:
| Type | Description |
|---|---|
| Any | Iterator that yields device-placed data with two-stage prefetching. |
Note
Numeric throughput claims must be backed by benchmark artifacts for the target workload and hardware.
Overview
The device_placement module provides utilities for explicit device placement of JAX arrays and PyTrees, enabling efficient data distribution across accelerators. It includes hardware-aware batch size recommendations based on JAX performance guidelines.
Key Components
HardwareType
Enumeration of supported hardware types for batch size recommendations:
```python
from datarax.distributed.device_placement import HardwareType
# Available hardware types:
# - TPU_V5E, TPU_V5P, TPU_V4
# - H100, A100, V100 (GPUs)
# - CPU
# - UNKNOWN (conservative defaults)
```
BatchSizeRecommendation
Dataclass containing hardware-specific batch size recommendations:
```python
from datarax.distributed.device_placement import BatchSizeRecommendation
# Fields:
# - min_batch_size: Minimum for reasonable efficiency
# - optimal_batch_size: For peak throughput
# - critical_batch_size: For reaching roofline performance
# - max_memory_batch_size: Before OOM (estimate)
# - notes: Additional guidance
```
DevicePlacement Class
Main utility class for device placement operations.
Usage Examples
Basic Device Placement
```python
import jax
import jax.numpy as jnp
from datarax.distributed.device_placement import DevicePlacement
# Create placement utility
placement = DevicePlacement()
# Place data on specific device
data = jnp.ones((256, 224, 224, 3))
placed = placement.place_on_device(data, jax.devices()[0])
# Check detected hardware
print(f"Hardware type: {placement.hardware_type}")
print(f"Number of devices: {placement.num_devices}")
```
Getting Batch Size Recommendations
```python
from datarax.distributed.device_placement import (
    DevicePlacement,
    get_batch_size_recommendation,
    HardwareType,
)
# Auto-detect hardware and get recommendations
placement = DevicePlacement()
rec = placement.get_batch_size_recommendation()
print(f"Minimum batch size: {rec.min_batch_size}")
print(f"Optimal batch size: {rec.optimal_batch_size}")
print(f"Critical batch size: {rec.critical_batch_size}")
print(f"Notes: {rec.notes}")
# Or use the convenience function
rec = get_batch_size_recommendation()
# Get for specific hardware
h100_rec = get_batch_size_recommendation(HardwareType.H100)
print(f"H100 optimal batch: {h100_rec.optimal_batch_size}")  # 320
```
Validating Batch Size
```python
from datarax.distributed.device_placement import DevicePlacement
placement = DevicePlacement()
# Validate a batch size; returns (is_valid, message)
is_valid, message = placement.validate_batch_size(64)
print(message)
# Validate without suboptimal warnings
is_valid, message = placement.validate_batch_size(64, warn_suboptimal=False)
```
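The exact rules behind validate_batch_size are not spelled out on this page. The sketch below shows one hypothetical policy consistent with the usage above: a batch size below the hardware minimum is rejected, and one below the critical batch size is accepted with a warning unless warn_suboptimal=False. The function name and the explicit threshold arguments are assumptions.

```python
def validate_batch_size_sketch(
    batch_size: int,
    min_batch_size: int,
    critical_batch_size: int,
    warn_suboptimal: bool = True,
) -> tuple[bool, str]:
    """Hypothetical validation policy returning (is_valid, message)."""
    if batch_size < 1:
        return False, "Batch size must be a positive integer."
    if batch_size < min_batch_size:
        return False, f"Batch size {batch_size} is below the minimum of {min_batch_size}."
    if warn_suboptimal and batch_size < critical_batch_size:
        return True, (
            f"Batch size {batch_size} is valid but below the critical "
            f"batch size of {critical_batch_size}."
        )
    return True, f"Batch size {batch_size} is OK."


# Thresholds from the H100 row of the hardware table on this page.
print(validate_batch_size_sketch(64, min_batch_size=64, critical_batch_size=298))
```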
Distributing Data Across Devices
```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from datarax.distributed.device_placement import DevicePlacement
placement = DevicePlacement()
# Create a device mesh
devices = np.array(jax.devices()).reshape(-1)
mesh = Mesh(devices, axis_names=("data",))
# Create sharding specification
sharding = NamedSharding(mesh, PartitionSpec("data", None, None, None))
# Distribute batch across devices (batch dim must be divisible by the device count)
data = jnp.ones((8, 28, 28, 3))
distributed = placement.distribute_batch(data, sharding)
```
Sharding Along Batch Dimension
```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from datarax.distributed.device_placement import DevicePlacement
placement = DevicePlacement()
# Create mesh
devices = np.array(jax.devices()).reshape(-1)
mesh = Mesh(devices, axis_names=("data",))
# Shard PyTree along batch dimension
data = {
    "images": jnp.ones((8, 28, 28, 3)),
    "labels": jnp.ones((8,), dtype=jnp.int32),
}
sharded = placement.shard_batch_dim(data, mesh, batch_axis=0, mesh_axis="data")
```
Replicating Data Across Devices
```python
import jax
import jax.numpy as jnp
from datarax.distributed.device_placement import DevicePlacement
placement = DevicePlacement()
# Replicate model weights across all devices
weights = {"w": jnp.ones((128, 64)), "b": jnp.zeros(64)}
replicated = placement.replicate_across_devices(weights)
```
Prefetching to Device
```python
from datarax.distributed.device_placement import DevicePlacement

placement = DevicePlacement()


def data_generator():
    for i in range(100):
        yield {"batch": i}


# Create prefetching iterator
prefetched = placement.prefetch_to_device(
    data_generator(),
    buffer_size=2,  # Prefetch 2 batches ahead
)
for batch in prefetched:
    # Process batch (already on device)
    pass
```
Getting Device Information
```python
from datarax.distributed.device_placement import DevicePlacement
placement = DevicePlacement()
info = placement.get_device_info()
print(f"Number of devices: {info['num_devices']}")
print(f"Hardware type: {info['hardware_type']}")
print(f"Platforms: {info['platforms']}")
```
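A comparable summary can be built from public JAX calls. This hypothetical sketch (device_info_sketch is not a Datarax function) reproduces only the num_devices and platforms keys; hardware_type detection is Datarax-specific and is omitted.

```python
import jax


def device_info_sketch() -> dict:
    """Hypothetical summary of visible devices using public JAX APIs."""
    devices = jax.devices()
    return {
        "num_devices": len(devices),
        # Deduplicated platform strings such as "cpu", "gpu", or "tpu".
        "platforms": sorted({d.platform for d in devices}),
    }


info = device_info_sketch()
print(f"Number of devices: {info['num_devices']}")
print(f"Platforms: {info['platforms']}")
```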
Hardware-Specific Recommendations
Based on JAX performance guidelines:
| Hardware | Min Batch | Optimal | Critical | Notes |
|---|---|---|---|---|
| TPU v5e | 64 | 256 | 240 | Critical for roofline |
| TPU v5p | 128 | 512 | 480 | Higher throughput variant |
| TPU v4 | 64 | 256 | 192 | Similar to v5e |
| H100 | 64 | 320 | 298 | Critical for roofline |
| A100 | 32 | 256 | 240 | 80GB variant |
| V100 | 16 | 128 | 96 | Memory-limited |
| CPU | 1 | 32 | 16 | Bandwidth-bound |
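The table can be read as a simple lookup keyed by HardwareType member names. The sketch below copies the rows above into a dictionary; falling back to the CPU row for unrecognized hardware is a hypothetical choice, since the values behind HardwareType.UNKNOWN are not listed on this page.

```python
from typing import NamedTuple


class TableRow(NamedTuple):
    min_batch: int
    optimal: int
    critical: int


# Rows copied from the hardware table above (keys follow HardwareType names).
RECOMMENDATIONS = {
    "TPU_V5E": TableRow(64, 256, 240),
    "TPU_V5P": TableRow(128, 512, 480),
    "TPU_V4": TableRow(64, 256, 192),
    "H100": TableRow(64, 320, 298),
    "A100": TableRow(32, 256, 240),
    "V100": TableRow(16, 128, 96),
    "CPU": TableRow(1, 32, 16),
}


def lookup_recommendation(hardware: str) -> TableRow:
    # Hypothetical fallback: unrecognized hardware gets the most conservative row.
    return RECOMMENDATIONS.get(hardware, RECOMMENDATIONS["CPU"])


print(lookup_recommendation("H100"))
```

In every row the critical batch size sits between the minimum and the optimal value, so targeting the optimal value also clears the roofline threshold.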
Integration with Datarax Pipelines
```python
from datarax.pipeline import Pipeline
from datarax.sources import MemorySource, MemorySourceConfig
from datarax.distributed.device_placement import DevicePlacement, get_batch_size_recommendation
import jax.numpy as jnp
from flax import nnx  # required for nnx.Rngs below
# Get recommended batch size
rec = get_batch_size_recommendation()
batch_size = rec.optimal_batch_size
# Create pipeline with optimal batch size
data = [{"image": jnp.ones((28, 28, 3))} for _ in range(1000)]
source = MemorySource(MemorySourceConfig(), data)
pipeline = Pipeline(source=source, stages=[], batch_size=batch_size, rngs=nnx.Rngs(0))
# Use device placement for explicit placement
placement = DevicePlacement()
for batch in pipeline:
    # Place batch on device explicitly
    placed_batch = placement.place_on_device(batch)
    # Process placed batch...
    break
```
Best Practices
- Use critical batch size: For maximum throughput, aim for at least the critical batch size for your hardware.
- Validate early: Use validate_batch_size() during pipeline setup to catch suboptimal configurations.
- Explicit placement: Use place_on_device() or distribute_batch() for explicit device placement of pipeline outputs.
- Prefetch for overlap: Use prefetch_to_device() to overlap data transfer with computation.
- Check hardware detection: Use get_device_info() to verify correct hardware detection.
Convenience Functions
The module also provides standalone functions: