Prefetcher¤
Asynchronous data prefetching for pipeline optimization.
See Also¤
- Control Overview - Control flow tools
- DAG Executor - Built-in prefetching
- Performance - Optimization
- Benchmarking - Measure improvement
datarax.control.prefetcher ¤
Prefetcher implementation for Datarax.
This module provides thread-based prefetchers that load data in the background while the main thread processes previously loaded data.
Two-stage pipeline (P2.1): CPU prefetch buffer (size=4) → jax.device_put → device buffer (size=2)
Design follows Grain's prefetch pattern:
- Warm start: background loading begins immediately on construction
- Sentinel-based StopIteration: uses _END sentinel, not isinstance(Exception)
- Clean shutdown: stop_event + thread join with timeout
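The three design points above can be illustrated with a minimal standard-library sketch. This is not Datarax's implementation: the class name `ThreadPrefetchIterator`, its attributes, and the timeout values are assumptions made for illustration. Only the ideas (start the thread on construction, signal the end with an `_END` sentinel, stop via an event plus a timed join) come from the description.

```python
import queue
import threading
from collections.abc import Iterable, Iterator
from typing import Generic, TypeVar

T = TypeVar("T")

_END = object()  # unique end-of-stream marker, compared by identity


class ThreadPrefetchIterator(Generic[T]):
    """Hypothetical closeable prefetch iterator (illustration only)."""

    def __init__(self, source: Iterable[T], buffer_size: int = 2) -> None:
        self._queue: queue.Queue = queue.Queue(maxsize=buffer_size)
        self._stop = threading.Event()
        self._error = None
        self._done = False
        self._thread = threading.Thread(
            target=self._fill, args=(iter(source),), daemon=True
        )
        self._thread.start()  # warm start: loading begins on construction

    def _put(self, item: object) -> bool:
        # Bounded put that re-checks the stop flag, so close() cannot deadlock
        # against a producer blocked on a full queue.
        while not self._stop.is_set():
            try:
                self._queue.put(item, timeout=0.05)
                return True
            except queue.Full:
                continue
        return False

    def _fill(self, it: Iterator[T]) -> None:
        try:
            for item in it:
                if not self._put(item):
                    return  # close() was requested
        except Exception as exc:
            self._error = exc  # re-raised on the consumer side
        self._put(_END)

    def __iter__(self) -> "ThreadPrefetchIterator[T]":
        return self

    def __next__(self) -> T:
        if self._done:
            raise StopIteration
        item = self._queue.get()
        if item is _END:  # identity check, never isinstance(item, Exception)
            self._done = True
            if self._error is not None:
                raise self._error
            raise StopIteration
        return item

    def close(self) -> None:
        # Clean shutdown: signal the worker, then join with a timeout.
        self._done = True
        self._stop.set()
        self._thread.join(timeout=1.0)


prefetched = ThreadPrefetchIterator(range(5), buffer_size=2)
items = list(prefetched)
prefetched.close()
```

Using a dedicated sentinel object rather than inspecting item types means the source iterator can legitimately yield exception instances or `None` without ending the stream early.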
Prefetcher dataclass ¤
Prefetcher(buffer_size: int = 2)
A prefetcher that uses threads to load data in the background.
Starts loading immediately when prefetch() is called (warm start), not lazily on the first next() call.
prefetch ¤
Prefetch items from the iterator in a background thread.
The background thread starts loading immediately (warm start). Uses a sentinel value for clean end-of-iterator signaling.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| iterator | Iterator[T] | Iterator to prefetch from. | required |
Returns:
| Type | Description |
|---|---|
| _PrefetchIterator[T] | A closeable iterator that yields prefetched items. |
DevicePrefetcher ¤
Two-stage prefetcher: async host-to-device transfer via jax.device_put.
Follows Grain's experimental/device_put/device_put.py pattern.
Compose with Prefetcher for the full two-stage pipeline:
raw_iter → Prefetcher(buffer_size=4) → DevicePrefetcher(buffer_size=2) → consumer
The consumer receives JAX arrays already on device, overlapping the H2D transfer with computation on the previous batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| buffer_size | int | Number of device-side batches to buffer ahead. | 2 |
| device | object \| None | Optional target device. | None |
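The two-stage composition described above can be sketched with the standard library alone. Everything named here is hypothetical: `buffered_stage`, `_Failure`, and `fake_device_put` are illustrative stand-ins, not Datarax or JAX APIs. In a real pipeline the second stage's transform would be the host-to-device copy (`jax.device_put`). Clean shutdown is omitted to keep the sketch short.

```python
import queue
import threading
from collections.abc import Callable, Iterable, Iterator
from typing import Any

_END = object()  # end-of-stream sentinel, compared by identity


class _Failure:
    """Carries a producer-side exception across the queue."""

    def __init__(self, exc: Exception) -> None:
        self.exc = exc


def buffered_stage(
    source: Iterable[Any],
    buffer_size: int,
    transform: Callable[[Any], Any] = lambda item: item,
) -> Iterator[Any]:
    """Apply `transform` in a background thread, keeping results buffered."""
    results: queue.Queue = queue.Queue(maxsize=buffer_size)

    def worker() -> None:
        try:
            for item in source:
                results.put(transform(item))
            results.put(_END)
        except Exception as exc:
            results.put(_Failure(exc))

    # Start eagerly so the buffer fills before the first next() call.
    threading.Thread(target=worker, daemon=True).start()

    def drain() -> Iterator[Any]:
        while True:
            item = results.get()
            if item is _END:
                return
            if isinstance(item, _Failure):
                raise item.exc
            yield item

    return drain()


def fake_device_put(batch: list[int]) -> tuple[str, list[int]]:
    # Stand-in for the host-to-device copy; tags the batch instead of moving it.
    return ("device", batch)


raw_iter = ([i, i + 1] for i in range(4))            # CPU-side batches
host_iter = buffered_stage(raw_iter, buffer_size=4)  # stage 1: host prefetch
device_iter = buffered_stage(                        # stage 2: transfer
    host_iter, buffer_size=2, transform=fake_device_put
)

batches = list(device_iter)
```

Each stage owns its own thread and bounded queue, so loading the next host batch, transferring the current one, and consuming the previous one can all proceed at the same time. The smaller second buffer reflects that device memory is usually the scarcer resource.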
prefetch ¤
Begin async device transfer from the given iterator.
Starts a background thread that calls jax.device_put on each
item and buffers the results. The returned iterator yields
device-resident JAX arrays.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| iterator | Iterator[T] | Iterator of CPU-side data (numpy arrays, dicts, pytrees). | required |
Returns:
| Type | Description |
|---|---|
| _DevicePutIterator[T] | A closeable iterator yielding device-resident data. |
start_prefetch ¤
Begin prefetching immediately (warm-start API, P2.3).
Identical to prefetch() — the background thread starts on
construction. Call this during checkpoint recovery or model compilation
so the data pipeline warms up while other initialization completes.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| iterator | Iterator[T] | Iterator of CPU-side data. | required |
Returns:
| Type | Description |
|---|---|
| _DevicePutIterator[T] | A closeable iterator to consume when ready. |
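The benefit of a warm start during initialization can be seen in a small timing sketch. It uses only the standard library: `warm_start` and `slow_batches` are hypothetical helpers, and `time.sleep` stands in for checkpoint recovery or model compilation. It does not call the Datarax API.

```python
import queue
import threading
import time
from collections.abc import Iterable, Iterator
from typing import Any


def slow_batches(count: int, delay: float = 0.01) -> Iterator[int]:
    """Simulated loader: each batch takes `delay` seconds to produce."""
    for index in range(count):
        time.sleep(delay)
        yield index


def warm_start(
    source: Iterable[Any], buffer_size: int = 2
) -> tuple[queue.Queue, object]:
    """Hypothetical helper: begin filling a bounded queue immediately."""
    buffer: queue.Queue = queue.Queue(maxsize=buffer_size)
    end = object()  # end-of-stream sentinel

    def fill() -> None:
        for item in source:
            buffer.put(item)  # blocks while the buffer is full
        buffer.put(end)

    threading.Thread(target=fill, daemon=True).start()
    return buffer, end


buffer, end = warm_start(slow_batches(5), buffer_size=2)

time.sleep(0.5)  # stands in for checkpoint restore or model compilation

ready_after_init = buffer.qsize()  # batches already waiting for the consumer
consumed = list(iter(buffer.get, end))  # drain until the sentinel appears
```

With these delays the buffer should already be full by the time the simulated initialization finishes, so the first training step would not wait on the loader.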
create_prefetch_stream ¤
create_prefetch_stream(iterator: Iterator[T], *, mode: Literal['none', 'grain', 'flax', 'thread'], size: int, device: object | None = None) -> Iterator[T] | object
Prefetch an iterator through the requested upstream-backed adapter.
grain mode delegates to grain.experimental.device_put for Grain
datasets, flax mode delegates to flax.jax_utils.prefetch_to_device,
and thread mode keeps Datarax's closeable thread wrapper for custom
iterator lifecycle behavior.