Device Placement¤

datarax.distributed.device_placement ¤

Device placement utilities for JAX distributed training.

This module provides utilities for explicit device placement of JAX arrays and PyTrees, enabling efficient data distribution across accelerators.

Key Features:

  • Explicit device placement with jax.device_put
  • Hardware-aware batch size recommendations
  • PyTree-aware device placement
  • Prefetching utilities for overlapping compute and data transfer

Note

Performance Guidelines (per JAX guide):

  • TPU v5e: Critical batch size >= 240 for optimal throughput
  • H100 GPU: Critical batch size >= 298 for optimal throughput
  • Always use explicit device placement for data pipeline outputs
  • Prefetch to device memory to overlap data transfer with compute

Example
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from datarax.distributed.device_placement import DevicePlacement
placement = DevicePlacement()
data = jnp.ones((256, 224, 224, 3))
placed = placement.place_on_device(data, jax.devices()[0])  # Place on device
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, PartitionSpec("data", None, None, None))
distributed = placement.distribute_batch(data, sharding)  # Distribute batch

logger module-attribute ¤

logger = getLogger(__name__)

PyTree module-attribute ¤

PyTree = Any

HardwareType ¤

Bases: Enum

Enumeration of supported hardware types.

TPU_V5E class-attribute instance-attribute ¤

TPU_V5E = 'tpu_v5e'

TPU_V5P class-attribute instance-attribute ¤

TPU_V5P = 'tpu_v5p'

TPU_V4 class-attribute instance-attribute ¤

TPU_V4 = 'tpu_v4'

H100 class-attribute instance-attribute ¤

H100 = 'h100'

A100 class-attribute instance-attribute ¤

A100 = 'a100'

V100 class-attribute instance-attribute ¤

V100 = 'v100'

CPU class-attribute instance-attribute ¤

CPU = 'cpu'

UNKNOWN class-attribute instance-attribute ¤

UNKNOWN = 'unknown'

BatchSizeRecommendation dataclass ¤

BatchSizeRecommendation(min_batch_size: int, optimal_batch_size: int, critical_batch_size: int, max_memory_batch_size: int | None = None, notes: str = '')

Hardware-specific batch size recommendations.

Attributes:

  • min_batch_size (int): Minimum batch size for reasonable efficiency.
  • optimal_batch_size (int): Optimal batch size for peak throughput.
  • critical_batch_size (int): Critical batch size for reaching roofline (per JAX guide).
  • max_memory_batch_size (int | None): Maximum batch size before OOM (estimate).
  • notes (str): Additional notes about the recommendation.

min_batch_size instance-attribute ¤

min_batch_size: int

optimal_batch_size instance-attribute ¤

optimal_batch_size: int

critical_batch_size instance-attribute ¤

critical_batch_size: int

max_memory_batch_size class-attribute instance-attribute ¤

max_memory_batch_size: int | None = None

notes class-attribute instance-attribute ¤

notes: str = ''

DevicePlacement ¤

DevicePlacement(default_device: Device | None = None)

Utility class for explicit device placement of JAX arrays.

This class provides methods for placing arrays on specific devices, distributing batches across devices using sharding, and providing hardware-aware batch size recommendations.

Example
import jax
import jax.numpy as jnp
from datarax.distributed.device_placement import DevicePlacement
placement = DevicePlacement()
data = jnp.ones((4, 8))
gpu_data = placement.place_on_device(data, jax.devices("gpu")[0])  # Place on first GPU
rec = placement.get_batch_size_recommendation()  # Get batch size recommendation
print(f"Optimal batch: {rec.optimal_batch_size}")

Parameters:

  • default_device (Device | None, default: None): Default device to use when none is specified. If None, lazily resolves to jax.devices()[0] on first access.

default_device property ¤

default_device: Device

Get the default device, lazily resolving on first access.

hardware_type property ¤

hardware_type: HardwareType

Get the detected hardware type, lazily resolving on first access.

num_devices property ¤

num_devices: int

Get the number of available devices.

place_on_device ¤

place_on_device(data: PyTree, device: Device | None = None) -> PyTree

Place data on a specific device.

This uses jax.device_put for explicit device placement, ensuring data is transferred to the target device.

Parameters:

  • data (PyTree, required): PyTree of JAX arrays to place on device.
  • device (Device | None, default: None): Target device. If None, uses the default device.

Returns:

  • PyTree: PyTree with arrays placed on the specified device.

Example
data = {"images": jnp.ones((4, 28, 28, 3))}
gpu_data = placement.place_on_device(data, jax.devices("gpu")[0])

distribute_batch ¤

distribute_batch(data: PyTree, sharding: Sharding) -> PyTree

Distribute data across devices using the specified sharding.

This applies explicit device placement using jax.device_put with a Sharding object, distributing the data across multiple devices.

Parameters:

  • data (PyTree, required): PyTree of JAX arrays to distribute.
  • sharding (Sharding, required): JAX Sharding specification.

Returns:

  • PyTree: PyTree with arrays distributed according to the sharding.

Example
mesh = Mesh(np.array(jax.devices()), ("data",))
sharding = NamedSharding(mesh, PartitionSpec("data", None))
distributed = placement.distribute_batch(data, sharding)

replicate_across_devices ¤

replicate_across_devices(data: PyTree, devices: list[Device] | None = None) -> PyTree

Replicate data across all specified devices.

Creates a copy of the data on each device, useful for broadcasting model weights or constants.

Parameters:

  • data (PyTree, required): PyTree of JAX arrays to replicate.
  • devices (list[Device] | None, default: None): List of devices to replicate to. If None, uses all devices.

Returns:

  • PyTree: PyTree with arrays replicated across devices.
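
For orientation, replication can be expressed in plain JAX as a sharding that partitions no dimension. The following is a minimal sketch of that idea and not necessarily how replicate_across_devices is implemented; the helper name `replicate` and the mesh axis name "replica" are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def replicate(data, devices=None):
    """Hypothetical helper: put a full copy of every leaf on each device."""
    devices = jax.devices() if devices is None else devices
    # A 1-D mesh over the target devices; the axis name is arbitrary here.
    mesh = Mesh(np.array(devices), axis_names=("replica",))
    # An empty PartitionSpec partitions no dimension, so every device
    # in the mesh holds the whole array.
    sharding = NamedSharding(mesh, PartitionSpec())
    # device_put applies the same sharding to every leaf of the PyTree.
    return jax.device_put(data, sharding)


weights = {"w": jnp.ones((128, 64)), "b": jnp.zeros(64)}
replicated = replicate(weights)
print(replicated["w"].sharding.is_fully_replicated)
```

Because each device stores the complete array, this pattern suits small, shared values such as weights or constants rather than large batches.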

shard_batch_dim ¤

shard_batch_dim(data: PyTree, mesh: Mesh, batch_axis: int = 0, mesh_axis: str = 'data') -> PyTree

Shard data along the batch dimension.

This is the most common sharding pattern for data-parallel training, where each device processes a slice of the batch.

Parameters:

  • data (PyTree, required): PyTree of JAX arrays to shard.
  • mesh (Mesh, required): Device mesh to shard across.
  • batch_axis (int, default: 0): The axis index representing the batch dimension.
  • mesh_axis (str, default: 'data'): The mesh axis name to shard along.

Returns:

  • PyTree: PyTree with arrays sharded along the batch dimension.
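
To make the roles of batch_axis and mesh_axis concrete, here is a rough plain-JAX sketch of batch-dimension sharding. It is an illustration under assumptions, not the library's code: `shard_along_batch` is a hypothetical helper, and the batch size is chosen as a multiple of the device count because JAX requires the sharded dimension to divide evenly across the mesh axis.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def shard_along_batch(data, mesh, batch_axis=0, mesh_axis="data"):
    """Hypothetical helper: split every leaf along its batch dimension."""

    def shard_leaf(x):
        # None for each dimension before the batch axis, then the mesh axis.
        # Dimensions left out of the spec stay unpartitioned.
        spec = PartitionSpec(*([None] * batch_axis), mesh_axis)
        return jax.device_put(x, NamedSharding(mesh, spec))

    return jax.tree_util.tree_map(shard_leaf, data)


devices = jax.devices()
mesh = Mesh(np.array(devices), axis_names=("data",))

per_device = 4
batch = per_device * len(devices)  # divisible by the mesh axis size
data = {
    "images": jnp.ones((batch, 28, 28, 3)),
    "labels": jnp.ones((batch,), dtype=jnp.int32),
}
sharded = shard_along_batch(data, mesh)
# Each device holds a slice of `per_device` examples.
print(sharded["labels"].addressable_shards[0].data.shape)
```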

prefetch_to_device ¤

prefetch_to_device(data_iterator: Any, device: Device | None = None, buffer_size: int = 2, cpu_buffer_size: int | None = None) -> Any

Create a prefetching wrapper for host-to-device transfer.

Generic Python iterators use Datarax's closeable device-put thread wrapper. Grain-backed datasets should use datarax.control.prefetcher.create_prefetch_stream(..., mode="grain") directly when they need Grain's dataset-level device_put behavior.

Parameters:

  • data_iterator (Any, required): Iterator yielding PyTrees of data.
  • device (Device | None, default: None): Target device for prefetching.
  • buffer_size (int, default: 2): Device buffer size (Stage 2). Default is 2.
  • cpu_buffer_size (int | None, default: None): CPU buffer size (Stage 1). Default is buffer_size * 2.

Returns:

  • Any: Iterator that yields device-placed data.

Note

Throughput depends on workload, host/device balance, and hardware. Add benchmark artifacts before making numeric performance claims.
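
The two buffer sizes describe a two-stage design: a host-side queue (Stage 1) feeding a device-side queue (Stage 2). The sketch below illustrates that general pattern using only the standard library; it is not Datarax's implementation. `two_stage_prefetch` and `put_fn` are hypothetical names, and in a JAX program `put_fn` would typically be something like `lambda x: jax.device_put(x, device)`.

```python
import queue
import threading
from typing import Any, Callable, Iterable, Iterator, Optional

_DONE = object()  # sentinel marking the end of the stream


def two_stage_prefetch(
    data_iterator: Iterable[Any],
    put_fn: Callable[[Any], Any],
    buffer_size: int = 2,
    cpu_buffer_size: Optional[int] = None,
) -> Iterator[Any]:
    """Hypothetical sketch: overlap loading, transfer, and consumption."""
    if cpu_buffer_size is None:
        cpu_buffer_size = buffer_size * 2
    host_q: queue.Queue = queue.Queue(maxsize=cpu_buffer_size)
    device_q: queue.Queue = queue.Queue(maxsize=buffer_size)

    def fill_host() -> None:
        # Stage 1: pull items from the source into the host buffer.
        for item in data_iterator:
            host_q.put(item)
        host_q.put(_DONE)

    def fill_device() -> None:
        # Stage 2: transfer buffered items and queue the results.
        while True:
            item = host_q.get()
            if item is _DONE:
                device_q.put(_DONE)
                return
            device_q.put(put_fn(item))

    # Threads start when the generator is first advanced.
    threading.Thread(target=fill_host, daemon=True).start()
    threading.Thread(target=fill_device, daemon=True).start()

    while True:
        item = device_q.get()
        if item is _DONE:
            return
        yield item


# Stand-in for a device transfer so the example runs anywhere.
doubled = list(two_stage_prefetch(iter(range(5)), put_fn=lambda x: x * 2))
print(doubled)
```

The bounded queues provide backpressure, so the source is never read more than the two buffers ahead of the consumer. This sketch omits error propagation and early shutdown, which a closeable wrapper would need to handle.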

get_batch_size_recommendation ¤

get_batch_size_recommendation(hardware_type: HardwareType | None = None) -> BatchSizeRecommendation

Get batch size recommendation for the current hardware.

Parameters:

  • hardware_type (HardwareType | None, default: None): Override hardware type. If None, uses detected type.

Returns:

  • BatchSizeRecommendation: BatchSizeRecommendation with hardware-specific values.

validate_batch_size ¤

validate_batch_size(batch_size: int, warn_suboptimal: bool = True) -> tuple[bool, str]

Validate batch size against hardware recommendations.

Parameters:

  • batch_size (int, required): The batch size to validate.
  • warn_suboptimal (bool, default: True): Whether to warn for suboptimal (but valid) sizes.

Returns:

  • tuple[bool, str]: Tuple of (is_valid, message).
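
The exact rules applied by validate_batch_size are not spelled out here. As a rough illustration of the (is_valid, message) contract, the sketch below assumes that non-positive sizes are invalid and that sizes below the critical batch size are valid but suboptimal. `check_batch_size` and its thresholds are hypothetical, not the library's logic.

```python
def check_batch_size(
    batch_size: int,
    critical_batch_size: int,
    warn_suboptimal: bool = True,
) -> tuple[bool, str]:
    """Hypothetical check returning (is_valid, message)."""
    # Assumption: only non-positive sizes are rejected outright.
    if batch_size <= 0:
        return False, f"Batch size must be positive, got {batch_size}."
    # Assumption: below-critical sizes are valid but flagged when requested.
    if warn_suboptimal and batch_size < critical_batch_size:
        return True, (
            f"Batch size {batch_size} is below the critical batch size "
            f"{critical_batch_size}; throughput may be suboptimal."
        )
    return True, f"Batch size {batch_size} is acceptable."


is_valid, message = check_batch_size(64, critical_batch_size=240)
print(is_valid, message)
```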

get_device_info ¤

get_device_info() -> dict[str, Any]

Get information about available devices.

Returns:

  • dict[str, Any]: Dictionary containing device information.
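
The full set of keys in the returned dictionary is not listed here. As a rough idea of what such a summary can contain, the sketch below gathers comparable information directly from JAX. `describe_devices` and the "device_kinds" key are assumptions for illustration; the actual dictionary may differ.

```python
import jax


def describe_devices() -> dict:
    """Hypothetical summary of the devices visible to JAX."""
    devices = jax.devices()
    return {
        "num_devices": len(devices),
        # Platform names such as "cpu", "gpu", or "tpu".
        "platforms": sorted({d.platform for d in devices}),
        # Vendor-specific device descriptions reported by the backend.
        "device_kinds": sorted({d.device_kind for d in devices}),
    }


info = describe_devices()
print(info)
```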

place_on_device ¤

place_on_device(data: PyTree, device: Device | None = None) -> PyTree

Convenience function for placing data on a device.

Parameters:

  • data (PyTree, required): PyTree of JAX arrays.
  • device (Device | None, default: None): Target device. If None, uses first available device.

Returns:

  • PyTree: PyTree with arrays on the specified device.

distribute_batch ¤

distribute_batch(data: PyTree, sharding: Sharding) -> PyTree

Convenience function for distributing data across devices.

Parameters:

  • data (PyTree, required): PyTree of JAX arrays.
  • sharding (Sharding, required): JAX Sharding specification.

Returns:

  • PyTree: PyTree with arrays distributed according to sharding.

get_batch_size_recommendation ¤

get_batch_size_recommendation(hardware_type: HardwareType | None = None) -> BatchSizeRecommendation

Get batch size recommendation for current or specified hardware.

Parameters:

  • hardware_type (HardwareType | None, default: None): Hardware type to get recommendation for.

Returns:

  • BatchSizeRecommendation: BatchSizeRecommendation with hardware-specific values.

prefetch_to_device ¤

prefetch_to_device(data_iterator: Any, size: int = 2, device: Device | None = None, cpu_buffer_size: int | None = None) -> Any

Prefetch iterator outputs to device memory.

Parameters:

  • data_iterator (Any, required): Iterator yielding PyTrees of data (e.g., from a pipeline).
  • size (int, default: 2): Device buffer size (Stage 2). Default is 2.
  • device (Device | None, default: None): Target device for prefetching. If None, uses the default device.
  • cpu_buffer_size (int | None, default: None): CPU buffer size (Stage 1). Default is size * 2.

Returns:

  • Any: Iterator that yields device-placed data with two-stage prefetching.

Example
from flax import nnx

from datarax import Pipeline, prefetch_to_device

pipeline = Pipeline(source=source, stages=[], batch_size=32, rngs=nnx.Rngs(0))
prefetched = prefetch_to_device(pipeline, size=3)

for batch in prefetched:
    # batch is already on device, ready for computation
    train_step(batch)

Note

Numeric throughput claims must be backed by benchmark artifacts for the target workload and hardware.

Overview¤

The device_placement module provides utilities for explicit device placement of JAX arrays and PyTrees, enabling efficient data distribution across accelerators. It includes hardware-aware batch size recommendations based on JAX performance guidelines.

Key Components¤

HardwareType¤

Enumeration of supported hardware types for batch size recommendations:

from datarax.distributed.device_placement import HardwareType

# Available hardware types:
# - TPU_V5E, TPU_V5P, TPU_V4
# - H100, A100, V100 (GPUs)
# - CPU
# - UNKNOWN (conservative defaults)

BatchSizeRecommendation¤

Dataclass containing hardware-specific batch size recommendations:

from datarax.distributed.device_placement import BatchSizeRecommendation

# Fields:
# - min_batch_size: Minimum for reasonable efficiency
# - optimal_batch_size: For peak throughput
# - critical_batch_size: For reaching roofline performance
# - max_memory_batch_size: Before OOM (estimate)
# - notes: Additional guidance

DevicePlacement Class¤

Main utility class for device placement operations.

Usage Examples¤

Basic Device Placement¤

import jax
import jax.numpy as jnp
from datarax.distributed.device_placement import DevicePlacement

# Create placement utility
placement = DevicePlacement()

# Place data on specific device
data = jnp.ones((256, 224, 224, 3))
placed = placement.place_on_device(data, jax.devices()[0])

# Check detected hardware
print(f"Hardware type: {placement.hardware_type}")
print(f"Number of devices: {placement.num_devices}")

Getting Batch Size Recommendations¤

from datarax.distributed.device_placement import (
    DevicePlacement,
    get_batch_size_recommendation,
    HardwareType
)

# Auto-detect hardware and get recommendations
placement = DevicePlacement()
rec = placement.get_batch_size_recommendation()

print(f"Minimum batch size: {rec.min_batch_size}")
print(f"Optimal batch size: {rec.optimal_batch_size}")
print(f"Critical batch size: {rec.critical_batch_size}")
print(f"Notes: {rec.notes}")

# Or use the convenience function
rec = get_batch_size_recommendation()

# Get for specific hardware
h100_rec = get_batch_size_recommendation(HardwareType.H100)
print(f"H100 optimal batch: {h100_rec.optimal_batch_size}")  # 320

Validating Batch Size¤

from datarax.distributed.device_placement import DevicePlacement

placement = DevicePlacement()

# Validate a batch size
is_valid, message = placement.validate_batch_size(64)
print(message)

# Validate without suboptimal warnings
is_valid, message = placement.validate_batch_size(64, warn_suboptimal=False)

Distributing Data Across Devices¤

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from datarax.distributed.device_placement import DevicePlacement

placement = DevicePlacement()

# Create a device mesh
devices = np.array(jax.devices()).reshape(-1)
mesh = Mesh(devices, axis_names=("data",))

# Create sharding specification
sharding = NamedSharding(mesh, PartitionSpec("data", None, None, None))

# Distribute batch across devices
data = jnp.ones((8, 28, 28, 3))
distributed = placement.distribute_batch(data, sharding)

Sharding Along Batch Dimension¤

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from datarax.distributed.device_placement import DevicePlacement

placement = DevicePlacement()

# Create mesh
devices = np.array(jax.devices()).reshape(-1)
mesh = Mesh(devices, axis_names=("data",))

# Shard PyTree along batch dimension
data = {
    "images": jnp.ones((8, 28, 28, 3)),
    "labels": jnp.ones((8,), dtype=jnp.int32)
}

sharded = placement.shard_batch_dim(data, mesh, batch_axis=0, mesh_axis="data")

Replicating Data Across Devices¤

import jax
import jax.numpy as jnp
from datarax.distributed.device_placement import DevicePlacement

placement = DevicePlacement()

# Replicate model weights across all devices
weights = {"w": jnp.ones((128, 64)), "b": jnp.zeros(64)}
replicated = placement.replicate_across_devices(weights)

Prefetching to Device¤

from datarax.distributed.device_placement import DevicePlacement

placement = DevicePlacement()

def data_generator():
    for i in range(100):
        yield {"batch": i}

# Create prefetching iterator
prefetched = placement.prefetch_to_device(
    data_generator(),
    buffer_size=2  # Prefetch 2 batches ahead
)

for batch in prefetched:
    # Process batch (already on device)
    pass

Getting Device Information¤

from datarax.distributed.device_placement import DevicePlacement

placement = DevicePlacement()
info = placement.get_device_info()

print(f"Number of devices: {info['num_devices']}")
print(f"Hardware type: {info['hardware_type']}")
print(f"Platforms: {info['platforms']}")

Hardware-Specific Recommendations¤

Based on JAX performance guidelines:

| Hardware | Min Batch | Optimal | Critical | Notes |
| --- | --- | --- | --- | --- |
| TPU v5e | 64 | 256 | 240 | Critical for roofline |
| TPU v5p | 128 | 512 | 480 | Higher throughput variant |
| TPU v4 | 64 | 256 | 192 | Similar to v5e |
| H100 | 64 | 320 | 298 | Critical for roofline |
| A100 | 32 | 256 | 240 | 80GB variant |
| V100 | 16 | 128 | 96 | Memory-limited |
| CPU | 1 | 32 | 16 | Bandwidth-bound |
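
The recommendations above can also be treated as plain data. The sketch below copies the table values into a dictionary and reports which hardware types a given batch size satisfies at the critical threshold. `Row`, `TABLE`, and `hardware_meeting_critical` are hypothetical names for illustration and are not part of the module.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Row:
    min_batch: int
    optimal: int
    critical: int


# Values copied from the recommendations table.
TABLE = {
    "TPU v5e": Row(64, 256, 240),
    "TPU v5p": Row(128, 512, 480),
    "TPU v4": Row(64, 256, 192),
    "H100": Row(64, 320, 298),
    "A100": Row(32, 256, 240),
    "V100": Row(16, 128, 96),
    "CPU": Row(1, 32, 16),
}


def hardware_meeting_critical(batch_size: int) -> list[str]:
    """Return hardware names whose critical batch size is met."""
    return [name for name, row in TABLE.items() if batch_size >= row.critical]


print(hardware_meeting_critical(256))
```

A batch size of 256 meets the critical threshold for every listed type except TPU v5p (480) and H100 (298).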

Integration with Datarax Pipelines¤

from datarax.pipeline import Pipeline
from datarax.sources import MemorySource, MemorySourceConfig
from datarax.distributed.device_placement import DevicePlacement, get_batch_size_recommendation
import jax.numpy as jnp
from flax import nnx

# Get recommended batch size
rec = get_batch_size_recommendation()
batch_size = rec.optimal_batch_size

# Create pipeline with optimal batch size
data = [{"image": jnp.ones((28, 28, 3))} for _ in range(1000)]
source = MemorySource(MemorySourceConfig(), data)
pipeline = Pipeline(source=source, stages=[], batch_size=batch_size, rngs=nnx.Rngs(0))

# Use device placement for explicit placement
placement = DevicePlacement()

for batch in pipeline:
    # Place batch on device explicitly
    placed_batch = placement.place_on_device(batch)
    # Process placed batch...
    break

Best Practices¤

  1. Use critical batch size: For maximum throughput, aim for at least the critical batch size for your hardware.

  2. Validate early: Use validate_batch_size() during pipeline setup to catch suboptimal configurations.

  3. Explicit placement: Use place_on_device() or distribute_batch() for explicit device placement of pipeline outputs.

  4. Prefetch for overlap: Use prefetch_to_device() to overlap data transfer with computation.

  5. Check hardware detection: Use get_device_info() to verify correct hardware detection.

Convenience Functions¤

The module also provides standalone functions:

from datarax.distributed.device_placement import (
    place_on_device,                # Place data on device
    distribute_batch,               # Distribute with sharding
    get_batch_size_recommendation,  # Get recommendations
    prefetch_to_device,             # Prefetch iterator outputs to device
)