Prefetcher¤

Asynchronous data prefetching for pipeline optimization.

datarax.control.prefetcher ¤

Prefetcher implementation for Datarax.

This module provides thread-based prefetchers that load data in the background while the main thread processes previously loaded data.

Two-stage pipeline (P2.1): CPU prefetch buffer (size=4) → jax.device_put → device buffer (size=2)

Design follows Grain's prefetch pattern:

- Warm start: background loading begins immediately on construction
- Sentinel-based StopIteration: uses _END sentinel, not isinstance(Exception)
- Clean shutdown: stop_event + thread join with timeout
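The three design points above can be illustrated with a minimal, standard-library-only sketch. It is not the Datarax implementation: the names `SketchPrefetchIterator` and `_Failure` and the 0.05 s retry interval are assumptions made for illustration. The sentinel is compared by identity, so an exception object that appears as ordinary data is yielded rather than mistaken for end-of-stream.

```python
import queue
import threading
from typing import Iterator, TypeVar

T = TypeVar("T")

_END = object()  # unique sentinel: marks normal exhaustion of the source


class _Failure:
    """Wraps an exception raised by the source so it is not confused with data."""

    def __init__(self, exc: BaseException) -> None:
        self.exc = exc


class SketchPrefetchIterator(Iterator[T]):
    """Hypothetical closeable iterator; not the Datarax class."""

    def __init__(self, source: Iterator[T], buffer_size: int = 2) -> None:
        self._queue: queue.Queue = queue.Queue(maxsize=buffer_size)
        self._stop = threading.Event()
        self._done = False
        self._thread = threading.Thread(
            target=self._worker, args=(source,), daemon=True
        )
        self._thread.start()  # warm start: loading begins on construction

    def _put(self, item: object) -> bool:
        # Retry with a short timeout so a full buffer never blocks shutdown.
        while not self._stop.is_set():
            try:
                self._queue.put(item, timeout=0.05)
                return True
            except queue.Full:
                continue
        return False

    def _worker(self, source: Iterator[T]) -> None:
        try:
            for item in source:
                if not self._put(item):
                    return  # close() was called
            self._put(_END)
        except Exception as exc:
            self._put(_Failure(exc))  # forward source errors to the consumer

    def __next__(self) -> T:
        if self._done:
            raise StopIteration
        item = self._queue.get()
        if item is _END:  # identity check, so any value is legal as data
            self._done = True
            raise StopIteration
        if isinstance(item, _Failure):
            self._done = True
            raise item.exc
        return item

    def close(self, timeout: float = 1.0) -> None:
        """Signal the worker to stop and wait briefly for it to exit."""
        self._done = True
        self._stop.set()
        self._thread.join(timeout=timeout)


it = SketchPrefetchIterator(iter(range(5)), buffer_size=2)
print(list(it))  # [0, 1, 2, 3, 4]
it.close()
```

The bounded join in `close()` means a worker stuck inside a slow source cannot hang the caller; the thread is a daemon, so it will not keep the process alive either.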

logger module-attribute ¤

logger = getLogger(__name__)

T module-attribute ¤

T = TypeVar('T')

Prefetcher dataclass ¤

Prefetcher(buffer_size: int = 2)

A prefetcher that uses threads to load data in the background.

Starts loading immediately when prefetch() is called (warm start), not lazily on the first next() call.

buffer_size class-attribute instance-attribute ¤

buffer_size: int = 2

prefetch ¤

prefetch(iterator: Iterator[T]) -> _PrefetchIterator[T]

Prefetch items from the iterator in a background thread.

The background thread starts loading immediately (warm start). Uses a sentinel value for clean end-of-iterator signaling.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| iterator | Iterator[T] | Iterator to prefetch from. | required |

Returns:

| Type | Description |
| --- | --- |
| _PrefetchIterator[T] | A closeable iterator that yields prefetched items. |

DevicePrefetcher ¤

DevicePrefetcher(buffer_size: int = 2, device: object | None = None)

Two-stage prefetcher: async host-to-device transfer via jax.device_put.

Follows Grain's experimental/device_put/device_put.py pattern. Compose with Prefetcher for the full two-stage pipeline:

raw_iter → Prefetcher(buffer_size=4) → DevicePrefetcher(buffer_size=2) → consumer

The consumer receives JAX arrays already on device, overlapping the H2D transfer with computation on the previous batch.
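The two-stage composition can be sketched independently of Datarax as two chained background stages, the second of which applies `jax.device_put`. The `stage` helper below is hypothetical and deliberately minimal: it omits the error forwarding and clean shutdown that a real prefetcher needs. The identity fallback for `to_device` is also an assumption, included only so the sketch runs where JAX is not installed.

```python
import queue
import threading
from typing import Any, Callable, Iterable, Iterator

import numpy as np

try:
    import jax

    to_device: Callable[[Any], Any] = jax.device_put
except ImportError:  # assumption: identity fallback so the sketch runs without JAX
    to_device = lambda x: x

_END = object()  # sentinel marking the end of a stage's output


def stage(
    source: Iterable[Any],
    buffer_size: int,
    transform: Callable[[Any], Any] = lambda x: x,
) -> Iterator[Any]:
    """Start a worker thread immediately and return an iterator over its output."""
    q: queue.Queue = queue.Queue(maxsize=buffer_size)

    def worker() -> None:
        try:
            for item in source:
                q.put(transform(item))  # blocks while the buffer is full
        finally:
            q.put(_END)

    threading.Thread(target=worker, daemon=True).start()

    def drain() -> Iterator[Any]:
        while (item := q.get()) is not _END:
            yield item

    return drain()


# Raw host-side batches: a pytree (dict) of NumPy arrays.
raw_iter = ({"x": np.full((2, 3), i, dtype=np.float32)} for i in range(5))

host_iter = stage(raw_iter, buffer_size=4)  # stage 1: host prefetch
device_iter = stage(host_iter, buffer_size=2, transform=to_device)  # stage 2: transfer

sums = [float(batch["x"].sum()) for batch in device_iter]
print(sums)  # [0.0, 6.0, 12.0, 18.0, 24.0]
```

Because the second stage runs in its own thread, the transfer of the next batch is issued while the consumer is still working on the current one.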

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| buffer_size | int | Number of device-side batches to buffer ahead. | 2 |
| device | object \| None | Optional target device. | None |

buffer_size instance-attribute ¤

buffer_size = buffer_size

device instance-attribute ¤

device = device

prefetch ¤

prefetch(iterator: Iterator[T]) -> _DevicePutIterator[T]

Begin async device transfer from the given iterator.

Starts a background thread that calls jax.device_put on each item and buffers the results. The returned iterator yields device-resident JAX arrays.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| iterator | Iterator[T] | Iterator of CPU-side data (numpy arrays, dicts, pytrees). | required |

Returns:

| Type | Description |
| --- | --- |
| _DevicePutIterator[T] | A closeable iterator yielding device-resident data. |

start_prefetch ¤

start_prefetch(iterator: Iterator[T]) -> _DevicePutIterator[T]

Begin prefetching immediately (warm-start API, P2.3).

Identical to prefetch(): the background thread starts on construction. Call this during checkpoint recovery or model compilation so that the data pipeline warms up while the rest of initialization completes.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| iterator | Iterator[T] | Iterator of CPU-side data. | required |

Returns:

| Type | Description |
| --- | --- |
| _DevicePutIterator[T] | A closeable iterator to consume when ready. |
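The benefit of a warm start is that loading overlaps with other initialization. The sketch below illustrates the idea with the standard library only. `start_prefetch_sketch` and `slow_source` are hypothetical stand-ins rather than the Datarax API, and the sleeps simulate I/O and model compilation.

```python
import queue
import threading
import time
from typing import Iterable, Iterator, List

_END = object()  # sentinel marking the end of the source


def start_prefetch_sketch(source: Iterable[int], buffer_size: int = 2) -> Iterator[int]:
    """Hypothetical warm-start helper: the worker starts before this returns."""
    q: queue.Queue = queue.Queue(maxsize=buffer_size)

    def worker() -> None:
        for item in source:
            q.put(item)  # blocks once the buffer is full
        q.put(_END)

    threading.Thread(target=worker, daemon=True).start()

    def drain() -> Iterator[int]:
        while (item := q.get()) is not _END:
            yield item

    return drain()


loaded: List[int] = []  # records which batches the source has produced so far


def slow_source(n: int) -> Iterator[int]:
    for i in range(n):
        time.sleep(0.01)  # simulated I/O latency
        loaded.append(i)
        yield i


batches = start_prefetch_sketch(slow_source(5), buffer_size=2)
time.sleep(0.2)  # stands in for checkpoint restore or model compilation
loaded_before = len(loaded)  # batches produced before the first next() call
print(f"batches loaded before the first next(): {loaded_before}")

result = list(batches)
print(result)  # [0, 1, 2, 3, 4]
```

With a lazy start, `loaded_before` would be zero and the first `next()` call would pay the full loading latency.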

create_prefetch_stream ¤

create_prefetch_stream(iterator: Iterator[T], *, mode: Literal['none', 'grain', 'flax', 'thread'], size: int, device: object | None = None) -> Iterator[T] | object

Prefetch an iterator through the requested upstream-backed adapter.

The mode argument selects the backend: grain delegates to grain.experimental.device_put for Grain datasets; flax delegates to flax.jax_utils.prefetch_to_device; thread keeps Datarax's closeable thread wrapper, for iterators that need custom lifecycle behavior.