
Timing

Framework-agnostic timing with GPU synchronization support for accurate JAX benchmarking.

calibrax.profiling.timing

Framework-agnostic timing with configurable result synchronization.

Provides TimingSample (frozen dataclass) and TimingCollector for measuring iteration throughput with per-batch timing breakdown. Uses time.perf_counter() exclusively for accurate benchmarking. Supports warm-up iteration exclusion and JIT compilation time measurement.

TimingSample (dataclass)

TimingSample(*, wall_clock_sec: float, per_batch_times: tuple[float, ...], first_batch_time: float, num_batches: int, num_elements: int, compilation_time_sec: float | None = None, warmup_batches_excluded: int = 0)

Result of timing an iteration through a data pipeline.

Attributes:

- wall_clock_sec (float): Total wall-clock time for the iteration.
- per_batch_times (tuple[float, ...]): Per-batch durations in seconds (warmup batches excluded).
- first_batch_time (float): Time from iteration start to first batch completion.
- num_batches (int): Number of batches consumed (including warmup).
- num_elements (int): Total elements processed (via count_fn).
- compilation_time_sec (float | None): JIT compilation time, if measured separately.
- warmup_batches_excluded (int): Number of leading batches excluded from per_batch_times.
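The fields support two different kinds of rate, and it is easy to mix them up. The sketch below separates an end-to-end rate, which is based on wall_clock_sec and therefore includes warm-up, from steady-state figures, which use only per_batch_times. `SampleFields` and `summarize` are hypothetical helpers, not part of calibrax. The sketch also assumes that num_elements counts warm-up batches, as num_batches does.

```python
from __future__ import annotations

from dataclasses import dataclass
from statistics import mean, median


@dataclass(frozen=True)
class SampleFields:
    """Hypothetical mirror of the TimingSample fields used below."""

    wall_clock_sec: float
    per_batch_times: tuple[float, ...]
    num_batches: int
    num_elements: int
    warmup_batches_excluded: int = 0


def summarize(s: SampleFields) -> dict[str, float]:
    steady = s.per_batch_times  # warm-up batches are already excluded
    return {
        # End-to-end rate: wall_clock_sec includes warm-up, and this sketch
        # assumes num_elements does too.
        "elements_per_sec": s.num_elements / s.wall_clock_sec,
        # Steady-state figures use only the post-warm-up batch times.
        "mean_batch_sec": mean(steady),
        "median_batch_sec": median(steady),
        "steady_batches_per_sec": len(steady) / sum(steady),
    }


sample = SampleFields(
    wall_clock_sec=2.0,
    per_batch_times=(0.1, 0.2, 0.3),  # 5 batches consumed, first 2 excluded
    num_batches=5,
    num_elements=160,
    warmup_batches_excluded=2,
)
stats = summarize(sample)
print(stats)
```

When comparing configurations, the steady-state numbers are usually the ones to report. The end-to-end rate is still useful for seeing how much warm-up costs.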


to_dict

to_dict() -> dict[str, Any]

Serialize to a JSON-compatible dictionary.

from_dict (classmethod)

from_dict(data: dict[str, Any]) -> TimingSample

Deserialize from a dictionary.

Parameters:

- data (dict[str, Any], required): Dictionary with TimingSample fields.

Returns:

- TimingSample: Reconstructed TimingSample instance.
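A JSON round trip has one subtlety for this class: per_batch_times is a tuple, but JSON has only arrays, so it comes back as a list. A faithful from_dict must convert it back. Otherwise the frozen instance would not be hashable and would not compare equal to the original. The sketch below uses a hypothetical `SampleSketch` with the same field names. It is not the calibrax implementation.

```python
from __future__ import annotations

import json
from dataclasses import asdict, dataclass
from typing import Any


@dataclass(frozen=True)
class SampleSketch:
    """Hypothetical stand-in with the same fields as TimingSample."""

    wall_clock_sec: float
    per_batch_times: tuple[float, ...]
    first_batch_time: float
    num_batches: int
    num_elements: int
    compilation_time_sec: float | None = None
    warmup_batches_excluded: int = 0

    def to_dict(self) -> dict[str, Any]:
        data = asdict(self)
        # JSON has no tuple type; emit a list explicitly.
        data["per_batch_times"] = list(self.per_batch_times)
        return data

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> SampleSketch:
        kwargs = dict(data)
        # json.loads yields a list here; convert back so the frozen instance
        # stays hashable and compares equal to the original.
        kwargs["per_batch_times"] = tuple(kwargs["per_batch_times"])
        return cls(**kwargs)


original = SampleSketch(
    wall_clock_sec=1.5,
    per_batch_times=(0.1, 0.2),
    first_batch_time=0.9,
    num_batches=3,
    num_elements=96,
    warmup_batches_excluded=1,
)
restored = SampleSketch.from_dict(json.loads(json.dumps(original.to_dict())))
print(restored == original)  # True
```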

TimingCollector

TimingCollector(sync_fn: Callable[[Any], object] | None = None, warmup_iterations: int = 0)

Framework-agnostic timing with configurable GPU sync support.

Uses time.perf_counter() exclusively for accurate benchmarking. Supports configurable result synchronization via sync_fn and warm-up iteration exclusion for JIT-compiled workloads.

JAX dispatches operations asynchronously -- the host returns immediately while the device is still computing. Without an explicit synchronization barrier, perf_counter measures only host-side dispatch latency, not actual compute time. Pass a sync_fn that calls block_until_ready() on the workload result to force the host to wait for device completion before recording the timestamp.
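The effect of a missing barrier can be shown without a GPU. The pure-Python simulation below is not JAX. A hypothetical `dispatch` call returns a handle immediately while a worker thread does the work. Timing without waiting captures only dispatch latency. Calling `block_until_ready()` on the handle, which is what sync_fn is for, makes the measurement include the work itself.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor

# Hypothetical stand-in for asynchronous device dispatch: dispatch() returns
# a handle at once while a worker thread performs the "computation".
_executor = ThreadPoolExecutor(max_workers=1)


class FakeDeviceArray:
    """Mimics the block_until_ready() contract of a JAX array."""

    def __init__(self, future: Future) -> None:
        self._future = future

    def block_until_ready(self) -> "FakeDeviceArray":
        self._future.result()  # block the host until the work has finished
        return self


def dispatch(work_sec: float) -> FakeDeviceArray:
    # Enqueue the work and return immediately, like an async JAX call.
    return FakeDeviceArray(_executor.submit(time.sleep, work_sec))


def timed(sync: bool, work_sec: float = 0.2) -> float:
    start = time.perf_counter()
    result = dispatch(work_sec)
    if sync:
        result.block_until_ready()  # the role sync_fn plays in TimingCollector
    return time.perf_counter() - start


synced = timed(sync=True)     # dispatch + completion of the work
unsynced = timed(sync=False)  # dispatch latency only
print(f"synced={synced:.3f}s unsynced={unsynced:.6f}s")
```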

Example -- JAX GPU timing with warm-up:

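```python
import jax
import jax.numpy as jnp

from calibrax.profiling.timing import TimingCollector


# Placeholder workload and data, for illustration only.
def step_fn(batch):
    return jnp.tanh(batch @ batch.T)


step = jax.jit(step_fn)  # wrap once, outside the timed loop
data_iter = (jnp.ones((64, 64)) for _ in range(60))

collector = TimingCollector(
    sync_fn=lambda result: result.block_until_ready(),
    warmup_iterations=2,
)
sample = collector.measure_iteration(data_iter, num_batches=50, process_fn=step)
# sample.per_batch_times excludes the first 2 batches
```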

Parameters:

- sync_fn (Callable[[Any], object] | None, default None): Synchronization function called with each batch result. Typical choices:
  - JAX: lambda result: result.block_until_ready()
  - PyTorch: lambda _: torch.cuda.synchronize()
  - CPU-only: None (the default; no synchronization)
- warmup_iterations (int, default 0): Number of initial batches to exclude from per_batch_times statistics. They are still executed (important for JIT warm-up) but omitted from the timing result.
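The JAX sync_fn shown above assumes the result is a single array. It raises AttributeError if process_fn returns a tuple or dict. It also fails if process_fn is omitted and the batch is, for example, a NumPy array, which has no block_until_ready(). In recent JAX releases, jax.block_until_ready(result) accepts whole pytrees. The framework-free sketch below shows the same idea. `sync_tree` and `FakeLeaf` are hypothetical names used only for illustration.

```python
from typing import Any


def sync_tree(result: Any) -> Any:
    """Hypothetical sync_fn that handles nested results.

    Recurses through dicts, lists, and tuples and calls block_until_ready()
    on every leaf that provides it. Leaves without the method (Python
    scalars, None, NumPy arrays) are skipped.
    """
    if isinstance(result, dict):
        for value in result.values():
            sync_tree(value)
    elif isinstance(result, (list, tuple)):
        for item in result:
            sync_tree(item)
    elif hasattr(result, "block_until_ready"):
        result.block_until_ready()
    return result


class FakeLeaf:
    """Illustrative stand-in for an array with block_until_ready()."""

    def __init__(self) -> None:
        self.synced = False

    def block_until_ready(self) -> "FakeLeaf":
        self.synced = True
        return self


loss, grads = FakeLeaf(), FakeLeaf()
step_result = {"loss": loss, "aux": (grads, 3.0, None)}
sync_tree(step_result)
print(loss.synced, grads.synced)  # True True
```

With this helper, the collector would be built as TimingCollector(sync_fn=sync_tree).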


measure_iteration

measure_iteration(iterator: Iterator[Any], num_batches: int | None = None, process_fn: Callable[[Any], Any] | None = None, count_fn: Callable[[Any], int] | None = None) -> TimingSample

Measure timing for batches from an iterator.

Warm-up batches (if configured) are executed but excluded from per_batch_times. wall_clock_sec covers the entire run including warm-up. num_batches reflects total batches consumed.

Parameters:

- iterator (Iterator[Any], required): Data iterator to measure.
- num_batches (int | None, default None): Maximum number of batches to consume, including warm-up batches. None exhausts the iterator.
- process_fn (Callable[[Any], Any] | None, default None): Optional per-batch function whose execution is timed. Defaults to the identity, so the yielded batch itself is treated as the result.
- count_fn (Callable[[Any], int] | None, default None): Function that returns the number of elements in a batch. Defaults to counting 1 element per batch.

Returns:

- TimingSample: TimingSample with timing measurements.
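The documented semantics fit into a short loop. The sketch below is a hypothetical re-implementation, not the calibrax source. Every batch is processed and synced. The clock is read only after sync_fn returns. The first warmup intervals are dropped from per_batch_times, while wall_clock_sec, num_batches, and num_elements cover all batches consumed. Here each interval runs from the end of one batch to the end of the next, so iterator fetch time is included. The library may draw these boundaries differently.

```python
from __future__ import annotations

import time
from collections.abc import Callable, Iterator
from itertools import islice
from typing import Any


def measure_sketch(
    iterator: Iterator[Any],
    num_batches: int | None = None,
    process_fn: Callable[[Any], Any] | None = None,
    count_fn: Callable[[Any], int] | None = None,
    sync_fn: Callable[[Any], object] | None = None,
    warmup: int = 0,
) -> dict[str, Any]:
    """Hypothetical sketch of measure_iteration's documented behavior."""
    process_fn = process_fn or (lambda batch: batch)  # identity default
    count_fn = count_fn or (lambda batch: 1)          # 1 element per batch
    times: list[float] = []
    first_batch_time = 0.0
    consumed = elements = 0
    start = prev = time.perf_counter()
    for batch in islice(iterator, num_batches):  # None exhausts the iterator
        result = process_fn(batch)
        if sync_fn is not None:
            sync_fn(result)  # wait for async work before reading the clock
        now = time.perf_counter()
        if consumed == 0:
            first_batch_time = now - start
        if consumed >= warmup:  # warm-up batches run but are not recorded
            times.append(now - prev)
        prev = now
        consumed += 1
        elements += count_fn(batch)  # assumption: counts warm-up batches too
    return {
        "wall_clock_sec": time.perf_counter() - start,
        "per_batch_times": tuple(times),
        "first_batch_time": first_batch_time,
        "num_batches": consumed,
        "num_elements": elements,
        "warmup_batches_excluded": min(warmup, consumed),
    }


batches = iter([[1, 2], [3, 4, 5], [6], [7, 8, 9, 10]])
result = measure_sketch(batches, num_batches=3, count_fn=len, warmup=1)
print(result["num_batches"], result["num_elements"], len(result["per_batch_times"]))  # 3 6 2
```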

measure_compilation_time

measure_compilation_time(fn: Callable[..., Any], *args: Any) -> float

Measure JIT compilation time for a JAX function.

Calls jax.jit(fn).lower(*args).compile() and times the whole call. Because lower() runs inside the timed region, the result covers tracing and lowering as well as XLA compilation. It does not include execution.

Parameters:

- fn (Callable[..., Any], required): JAX function to compile.
- *args (Any, default ()): Example arguments for lowering.

Returns:

- float: Compilation time in seconds.
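The same ahead-of-time sequence can separate compile cost from run cost by hand. The sketch below times lower().compile() as described above. It then times one call of the resulting compiled executable, with block_until_ready() so that asynchronous dispatch does not hide the work. The `loss` function and array shapes are hypothetical placeholders. On a real device the first execution can still carry one-off costs such as memory allocation.

```python
import time

import jax
import jax.numpy as jnp


def loss(w, x):
    # Hypothetical workload used only for illustration.
    return jnp.mean(jnp.tanh(x @ w) ** 2)


w = jnp.ones((128, 64))
x = jnp.ones((32, 128))

t0 = time.perf_counter()
compiled = jax.jit(loss).lower(w, x).compile()  # trace + lower + XLA compile
compile_sec = time.perf_counter() - t0

t0 = time.perf_counter()
out = compiled(w, x).block_until_ready()  # execution only, no compilation
run_sec = time.perf_counter() - t0

print(f"compile={compile_sec:.4f}s run={run_sec:.6f}s value={float(out):.4f}")
```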