Profiler¤

GPU memory profiling, hardware-adaptive optimization, and memory analysis for Datarax pipelines.

Overview¤

This module provides three components:

  • GPUMemoryProfiler — Detects GPU availability and reports memory usage (used/total/utilization). Also analyzes memory patterns across multiple measurements to detect leaks and high utilization.
  • MemoryOptimizer — Analyzes a pipeline function's memory footprint by measuring baseline, peak, and post-GC memory. Returns optimization suggestions.
  • AdaptiveOperation — Auto-detects hardware (CPU/GPU/TPU) and configures optimal tile sizes, precision, and batch sizes. Also provides Grain auto-optimization.

Quick Start¤

Check GPU memory¤

from calibrax.profiling import GPUMemoryProfiler

profiler = GPUMemoryProfiler()
usage = profiler.get_memory_usage()
print(f"GPU memory: {usage['gpu_memory_used_mb']:.1f} / {usage['gpu_memory_total_mb']:.1f} MB")
print(f"Utilization: {usage.get('gpu_memory_utilization', 0):.1%}")

Analyze pipeline memory¤

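from calibrax.profiling import MemoryOptimizer

optimizer = MemoryOptimizer()
# pipeline_fn and sample_data are your own pipeline callable and a sample input.
analysis = optimizer.analyze_pipeline_memory(pipeline_fn, sample_data)
if analysis is not None:  # None is returned when the pipeline raises an exception
    # MemoryAnalysis is a dataclass, so fields are read as attributes.
    print(f"Peak usage: {analysis.peak_usage_mb:.1f} MB")
    print(f"Memory efficiency: {analysis.memory_efficiency:.1%}")
    for suggestion in analysis.suggestions:
        print(f"  - {suggestion}")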

calibrax.profiling ¤

Profiling: timing, resources, GPU, energy, FLOPs, hardware, roofline, compilation, complexity.

HARDWARE_SPECS module-attribute ¤

HARDWARE_SPECS: dict[str, dict[str, Any]] = {
    'tpu_v5e': {
        'peak_flops': 1.97e14,
        'peak_flops_bf16': 1.97e14,
        'memory_bandwidth': 1.6e12,
        'critical_intensity': 123.125,
    },
    'a100_80g': {
        'peak_flops': 3.12e14,
        'peak_flops_bf16': 3.12e14,
        'memory_bandwidth': 2.039e12,
        'critical_intensity': 153.0,
        'tensor_core_shapes': [(16, 16, 16), (16, 16, 8)],
    },
    'h100': {
        'peak_flops': 9.89e14,
        'peak_flops_bf16': 9.89e14,
        'memory_bandwidth': 3.35e12,
        'critical_intensity': 295.0,
        'tensor_core_shapes': [(16, 16, 16)],
    },
    'cpu_generic': {
        'peak_flops': 2e12,
        'peak_flops_bf16': 2e12,
        'memory_bandwidth': 2e11,
        'critical_intensity': 10.0,
        'simd_width': 8,
    },
}
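The `critical_intensity` entries look like the roofline ridge point, that is, peak FLOP/s divided by memory bandwidth in bytes/s. The sketch below recomputes that ratio from a hand-copied subset of the table to show how the fields relate. `ridge_point` is a hypothetical helper rather than part of calibrax, and the tabulated values appear to be rounded.

```python
# Hand-copied subset of HARDWARE_SPECS (values from the table above).
specs = {
    "tpu_v5e": {"peak_flops": 197e12, "memory_bandwidth": 1.6e12, "critical_intensity": 123.125},
    "a100_80g": {"peak_flops": 312e12, "memory_bandwidth": 2.039e12, "critical_intensity": 153.0},
    "h100": {"peak_flops": 989e12, "memory_bandwidth": 3.35e12, "critical_intensity": 295.0},
    "cpu_generic": {"peak_flops": 2e12, "memory_bandwidth": 200e9, "critical_intensity": 10.0},
}


def ridge_point(spec):
    """Hypothetical helper: FLOPs per byte at which the compute and bandwidth limits meet."""
    return spec["peak_flops"] / spec["memory_bandwidth"]


ridge_points = {name: ridge_point(spec) for name, spec in specs.items()}
for name, value in ridge_points.items():
    print(f"{name}: computed {value:.3f}, tabulated {specs[name]['critical_intensity']}")
```

An operation whose arithmetic intensity falls below this ratio is limited by memory bandwidth; above it, by compute.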

CompilationProfiler ¤

CompilationProfiler()

Analyzes JAX JIT compilation performance and optimization.

Instruments JIT-compiled functions to track compilation cache hits/misses, compilation times, and input shape consistency. Use profile_jit_compilation to wrap a function, then call get_result() for aggregated analysis.

profile_jit_compilation ¤

profile_jit_compilation(func: Callable[..., Any]) -> Callable[..., Any]

Create an instrumented wrapper that profiles JIT compilation.

The returned callable tracks cache hits/misses, compilation times, and input shape patterns. Results accumulate in this profiler instance.

Parameters:

Name Type Description Default
func Callable[..., Any]

JAX function to instrument.

required

Returns:

Type Description
Callable[..., Any]

Instrumented function with identical signature.
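To illustrate what tracking cache hits and misses by input shape can look like, here is a minimal pure-Python sketch. It is not the calibrax implementation: `profile_calls`, `FakeArray`, and the `stats` dictionary are hypothetical names, and a "miss" is simply the first call seen with a given shape/dtype signature, which approximates when `jax.jit` would need to compile.

```python
import functools


def profile_calls(func, stats):
    """Hypothetical stand-in: count a miss the first time a call signature is seen."""
    seen = set()

    @functools.wraps(func)
    def wrapper(*args):
        # Signature = shape and dtype of every argument; falls back to the Python type name.
        signature = tuple(
            (getattr(a, "shape", None), str(getattr(a, "dtype", type(a).__name__)))
            for a in args
        )
        if signature in seen:
            stats["cache_hits"] += 1
        else:
            seen.add(signature)
            stats["cache_misses"] += 1
        return func(*args)

    return wrapper


class FakeArray:
    """Minimal object exposing the shape and dtype attributes an array would have."""

    def __init__(self, shape, dtype="float32"):
        self.shape, self.dtype = shape, dtype


stats = {"cache_hits": 0, "cache_misses": 0}
step = profile_calls(lambda x: x.shape, stats)

# Three calls share a shape; the last one introduces a new shape.
for shape in [(32, 128), (32, 128), (32, 128), (17, 128)]:
    step(FakeArray(shape))

total_calls = stats["cache_hits"] + stats["cache_misses"]
cache_hit_rate = stats["cache_hits"] / total_calls
print(stats, f"hit rate = {cache_hit_rate:.2f}")
```

A low hit rate in a loop like this usually points at varying input shapes, for example a ragged final batch.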

get_result ¤

get_result() -> CompilationResult

Get aggregated compilation profiling results.

Returns:

Type Description
CompilationResult

CompilationResult with cache statistics, timing, and recommendations.

estimate_xla_optimization ¤

estimate_xla_optimization(func: Callable[..., Any], *sample_args: Any) -> XLAOptimizationResult

Estimate XLA optimization effectiveness by analyzing HLO text.

Parameters:

Name Type Description Default
func Callable[..., Any]

JAX function to analyze.

required
*sample_args Any

Example arguments for lowering/compiling.

()

Returns:

Type Description
XLAOptimizationResult

XLAOptimizationResult with HLO analysis metrics.

reset ¤

reset() -> None

Reset all profiling state.

CompilationResult dataclass ¤

CompilationResult(*, cache_hit_rate: float, total_calls: int, cache_hits: int, cache_misses: int, avg_compilation_time_ms: float, max_compilation_time_ms: float, unique_signatures: int, health_score: float, health_level: str, recommendations: tuple[str, ...] = ())

Result of compilation profiling analysis.

Attributes:

Name Type Description
cache_hit_rate float

Fraction of calls that hit the compilation cache.

total_calls int

Total number of profiled function calls.

cache_hits int

Number of cache hits.

cache_misses int

Number of cache misses (triggering compilation).

avg_compilation_time_ms float

Average compilation time in milliseconds.

max_compilation_time_ms float

Maximum compilation time in milliseconds.

unique_signatures int

Number of unique function signatures compiled.

health_score float

Overall compilation health score (0-1).

health_level str

Human-readable health level.

recommendations tuple[str, ...]

Optimization recommendations.

cache_hit_rate instance-attribute ¤

cache_hit_rate: float

total_calls instance-attribute ¤

total_calls: int

cache_hits instance-attribute ¤

cache_hits: int

cache_misses instance-attribute ¤

cache_misses: int

avg_compilation_time_ms instance-attribute ¤

avg_compilation_time_ms: float

max_compilation_time_ms instance-attribute ¤

max_compilation_time_ms: float

unique_signatures instance-attribute ¤

unique_signatures: int

health_score instance-attribute ¤

health_score: float

health_level instance-attribute ¤

health_level: str

recommendations class-attribute instance-attribute ¤

recommendations: tuple[str, ...] = ()

to_dict ¤

to_dict() -> dict[str, Any]

Serialize to a JSON-compatible dictionary.

from_dict classmethod ¤

from_dict(data: dict[str, Any]) -> CompilationResult

Deserialize from a dictionary.

Parameters:

Name Type Description Default
data dict[str, Any]

Dictionary with compilation result fields.

required

Returns:

Type Description
CompilationResult

Reconstructed CompilationResult instance.
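The `to_dict` / `from_dict` pair suggests a JSON round trip. The sketch below shows one way such a round trip can work, using a hypothetical stand-in dataclass (`ResultSketch`) with a subset of the documented fields. The tuple-to-list conversion is an assumption about how a JSON-compatible dictionary would be produced, not a description of the library's code.

```python
import json
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class ResultSketch:
    """Hypothetical stand-in with a subset of CompilationResult's fields."""

    cache_hit_rate: float
    total_calls: int
    recommendations: tuple = ()

    def to_dict(self):
        data = asdict(self)
        data["recommendations"] = list(self.recommendations)  # JSON has no tuple type
        return data

    @classmethod
    def from_dict(cls, data):
        return cls(
            cache_hit_rate=float(data["cache_hit_rate"]),
            total_calls=int(data["total_calls"]),
            recommendations=tuple(data.get("recommendations", ())),
        )


original = ResultSketch(cache_hit_rate=0.75, total_calls=4, recommendations=("pad inputs",))
payload = json.dumps(original.to_dict())
restored = ResultSketch.from_dict(json.loads(payload))
print(restored == original)
```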

XLAOptimizationResult dataclass ¤

XLAOptimizationResult(*, optimization_score: float, fusion_ratio: float, arithmetic_ratio: float, memory_ratio: float, total_kernels: int, recommendations: tuple[str, ...] = ())

Result of XLA optimization effectiveness analysis.

Attributes:

Name Type Description
optimization_score float

Overall optimization score (0-1).

fusion_ratio float

Fraction of fused kernels.

arithmetic_ratio float

Fraction of arithmetic operations.

memory_ratio float

Fraction of memory operations.

total_kernels int

Total number of HLO kernels.

recommendations tuple[str, ...]

Optimization recommendations.

optimization_score instance-attribute ¤

optimization_score: float

fusion_ratio instance-attribute ¤

fusion_ratio: float

arithmetic_ratio instance-attribute ¤

arithmetic_ratio: float

memory_ratio instance-attribute ¤

memory_ratio: float

total_kernels instance-attribute ¤

total_kernels: int

recommendations class-attribute instance-attribute ¤

recommendations: tuple[str, ...] = ()

to_dict ¤

to_dict() -> dict[str, Any]

Serialize to a JSON-compatible dictionary.

from_dict classmethod ¤

from_dict(data: dict[str, Any]) -> XLAOptimizationResult

Deserialize from a dictionary.

Parameters:

Name Type Description Default
data dict[str, Any]

Dictionary with XLA optimization result fields.

required

Returns:

Type Description
XLAOptimizationResult

Reconstructed XLAOptimizationResult instance.

ComplexityResult dataclass ¤

ComplexityResult(*, total_parameters: int, parameter_memory_mb: float, largest_layer_name: str, largest_layer_params: int, input_shape: tuple[int, ...], estimated_memory_mb: float, total_estimated_operations: int, dominant_complexity: str, scaling_characteristics: dict[str, str] = dict())

Result of model complexity analysis.

Attributes:

Name Type Description
total_parameters int

Total number of trainable parameters.

parameter_memory_mb float

Memory consumed by parameters (float32).

largest_layer_name str

Name of the layer with the most parameters.

largest_layer_params int

Parameter count of the largest layer.

input_shape tuple[int, ...]

Shape of the analyzed input.

estimated_memory_mb float

Estimated total memory (params + activations).

total_estimated_operations int

Estimated total operations count.

dominant_complexity str

Name of the dominant operation type.

scaling_characteristics dict[str, str]

Mapping of operation type to complexity class.
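As a rough illustration of how `total_parameters`, `largest_layer_name`, and `parameter_memory_mb` relate, the sketch below works through a small two-layer MLP by hand. `parameter_memory_mb` here is a hypothetical helper, and treating 1 MB as 1024**2 bytes is an assumption; the library may use a different convention.

```python
BYTES_PER_FLOAT32 = 4


def parameter_memory_mb(total_parameters, bytes_per_param=BYTES_PER_FLOAT32):
    """Hypothetical helper: parameter storage in MB, assuming 1 MB = 1024**2 bytes."""
    return total_parameters * bytes_per_param / (1024 ** 2)


# Parameter counts for a small two-layer MLP: weights plus biases per layer.
layer_params = {
    "dense_1": 784 * 256 + 256,
    "dense_2": 256 * 10 + 10,
}
total_parameters = sum(layer_params.values())
largest_layer_name = max(layer_params, key=layer_params.get)
largest_layer_params = layer_params[largest_layer_name]

print(total_parameters, largest_layer_name, f"{parameter_memory_mb(total_parameters):.3f} MB")
```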

total_parameters instance-attribute ¤

total_parameters: int

parameter_memory_mb instance-attribute ¤

parameter_memory_mb: float

largest_layer_name instance-attribute ¤

largest_layer_name: str

largest_layer_params instance-attribute ¤

largest_layer_params: int

input_shape instance-attribute ¤

input_shape: tuple[int, ...]

estimated_memory_mb instance-attribute ¤

estimated_memory_mb: float

total_estimated_operations instance-attribute ¤

total_estimated_operations: int

dominant_complexity instance-attribute ¤

dominant_complexity: str

scaling_characteristics class-attribute instance-attribute ¤

scaling_characteristics: dict[str, str] = field(default_factory=dict)

to_dict ¤

to_dict() -> dict[str, Any]

Serialize to a JSON-compatible dictionary.

from_dict classmethod ¤

from_dict(data: dict[str, Any]) -> ComplexityResult

Deserialize from a dictionary.

Parameters:

Name Type Description Default
data dict[str, Any]

Dictionary with complexity result fields.

required

Returns:

Type Description
ComplexityResult

Reconstructed ComplexityResult instance.

EnergyMonitor ¤

EnergyMonitor(sample_interval_sec: float = 0.1)

Background energy monitoring via NVML and RAPL.

Uses daemon thread sampling at configurable interval. Gracefully degrades when NVML or RAPL is unavailable.

Usage:

from calibrax.profiling import EnergyMonitor

with EnergyMonitor() as mon:
    # ... run benchmark ...
summary = mon.summary

Parameters:

Name Type Description Default
sample_interval_sec float

Seconds between energy samples.

0.1

samples property ¤

samples: list[EnergySample]

Return a copy of all collected samples.

summary property ¤

summary: EnergySummary

Compute aggregated energy summary.

Returns:

Type Description
EnergySummary

EnergySummary with totals, or None fields when unavailable.

EnergySample dataclass ¤

EnergySample(*, timestamp: float, gpu_power_watts: float | None, cpu_energy_joules: float | None, gpu_energy_joules: float | None)

Single energy measurement at a point in time.

Attributes:

Name Type Description
timestamp float

Time of measurement (perf_counter).

gpu_power_watts float | None

Instantaneous GPU power (None if unavailable).

cpu_energy_joules float | None

Cumulative CPU energy since monitoring start.

gpu_energy_joules float | None

Cumulative GPU energy since monitoring start.

timestamp instance-attribute ¤

timestamp: float

gpu_power_watts instance-attribute ¤

gpu_power_watts: float | None

cpu_energy_joules instance-attribute ¤

cpu_energy_joules: float | None

gpu_energy_joules instance-attribute ¤

gpu_energy_joules: float | None

EnergySummary dataclass ¤

EnergySummary(*, total_gpu_energy_joules: float | None, total_cpu_energy_joules: float | None, total_combined_energy_joules: float | None, mean_gpu_power_watts: float | None, peak_gpu_power_watts: float | None, duration_sec: float, num_samples: int)

Aggregated energy usage over a monitoring period.

Attributes:

Name Type Description
total_gpu_energy_joules float | None

Total GPU energy consumed.

total_cpu_energy_joules float | None

Total CPU energy consumed.

total_combined_energy_joules float | None

GPU + CPU combined.

mean_gpu_power_watts float | None

Average GPU power draw.

peak_gpu_power_watts float | None

Maximum GPU power draw.

duration_sec float

Monitoring duration.

num_samples int

Total samples collected.
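One plausible way to turn periodic power readings into the totals listed above is trapezoidal integration over time. The sketch below does this on synthetic data; `integrate_energy` is a hypothetical helper, and whether calibrax integrates power samples or reads a hardware energy counter is not stated here.

```python
def integrate_energy(samples):
    """Trapezoidal integration of (timestamp_sec, power_watts) pairs into joules."""
    energy = 0.0
    for (t0, p0), (t1, p1) in zip(samples, samples[1:]):
        energy += 0.5 * (p0 + p1) * (t1 - t0)
    return energy


# Synthetic samples: (seconds, watts).
samples = [(0.0, 100.0), (1.0, 200.0), (2.0, 200.0), (3.0, 100.0)]
powers = [p for _, p in samples]

total_gpu_energy_joules = integrate_energy(samples)
mean_gpu_power_watts = sum(powers) / len(powers)
peak_gpu_power_watts = max(powers)
duration_sec = samples[-1][0] - samples[0][0]

print(total_gpu_energy_joules, mean_gpu_power_watts, peak_gpu_power_watts, duration_sec)
```

A shorter sampling interval tightens the estimate at the cost of more sampling overhead.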

total_gpu_energy_joules instance-attribute ¤

total_gpu_energy_joules: float | None

total_cpu_energy_joules instance-attribute ¤

total_cpu_energy_joules: float | None

total_combined_energy_joules instance-attribute ¤

total_combined_energy_joules: float | None

mean_gpu_power_watts instance-attribute ¤

mean_gpu_power_watts: float | None

peak_gpu_power_watts instance-attribute ¤

peak_gpu_power_watts: float | None

duration_sec instance-attribute ¤

duration_sec: float

num_samples instance-attribute ¤

num_samples: int

FlopsCounter ¤

Count FLOPs of JAX functions via jaxpr analysis.

Uses jax.make_jaxpr to trace the function and counts FLOPs for each primitive based on operation-specific rules.

For NNX models that use stochastic operations (dropout, etc.), use flax.nnx.tabulate(model, *args, compute_flops=True) instead — it handles NNX state management internally.

count ¤

count(fn: Callable[..., Any], *args: Any, static_argnums: tuple[int, ...] = ()) -> FlopsResult

Count FLOPs for a function with given example arguments.

Parameters:

Name Type Description Default
fn Callable[..., Any]

JAX function to analyze.

required
*args Any

Example arguments for tracing.

()
static_argnums tuple[int, ...]

Argument indices to treat as static.

()

Returns:

Type Description
FlopsResult

FlopsResult with FLOP count and breakdown.
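The per-primitive rules are not listed here, so the following is only a sketch of what operation-specific FLOP rules commonly look like: 2·M·K·N for a matrix product and one FLOP per output element for elementwise operations. `matmul_flops` and `elementwise_flops` are hypothetical helpers; the dictionary keys reuse the JAX primitive names `dot_general` and `add` for readability.

```python
import math


def matmul_flops(lhs_shape, rhs_shape):
    """2*M*K*N for an (M, K) @ (K, N) product: one multiply and one add per term."""
    m, k = lhs_shape
    k2, n = rhs_shape
    if k != k2:
        raise ValueError("inner dimensions must match")
    return 2 * m * k * n


def elementwise_flops(shape):
    """One FLOP per output element for operations such as add or mul."""
    return math.prod(shape)


# y = x @ w + b with x: (32, 128), w: (128, 64), b broadcast to (32, 64).
flops_by_operation = {
    "dot_general": matmul_flops((32, 128), (128, 64)),
    "add": elementwise_flops((32, 64)),
}
total_flops = sum(flops_by_operation.values())
num_operations = len(flops_by_operation)

print(flops_by_operation, total_flops)
```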

FlopsResult dataclass ¤

FlopsResult(*, total_flops: int, flops_by_operation: dict[str, int], num_operations: int, function_name: str)

Result of FLOP counting for a function.

Attributes:

Name Type Description
total_flops int

Total estimated FLOPs.

flops_by_operation dict[str, int]

Breakdown by primitive operation name.

num_operations int

Number of JAX primitives in the trace.

function_name str

Name of the analyzed function.

total_flops instance-attribute ¤

total_flops: int

flops_by_operation instance-attribute ¤

flops_by_operation: dict[str, int]

num_operations instance-attribute ¤

num_operations: int

function_name instance-attribute ¤

function_name: str

AdaptiveOperation ¤

AdaptiveOperation()

Hardware-adaptive operations with auto-detection.

Detects the current JAX backend (CPU/GPU/TPU) and provides optimized configuration and shape padding.

config instance-attribute ¤

config = _detect_hardware()

optimize_shapes ¤

optimize_shapes(*shapes: tuple[int, ...]) -> list[tuple[int, ...]]

Pad tensor shapes to align with hardware tile size.

Parameters:

Name Type Description Default
*shapes tuple[int, ...]

Variable number of tensor shapes to optimize.

()

Returns:

Type Description
list[tuple[int, ...]]

List of optimized shapes padded to tile_size multiples.
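Padding to tile-size multiples amounts to rounding each dimension up. The sketch below assumes every dimension is rounded and uses a tile size of 128 purely for illustration; the library may pad only some dimensions and takes its tile size from the detected `HardwareConfig`. `pad_to_tile` and `optimize_shapes_sketch` are hypothetical names.

```python
def pad_to_tile(shape, tile_size):
    """Round every dimension up to the next multiple of tile_size."""
    return tuple(-(-dim // tile_size) * tile_size for dim in shape)


def optimize_shapes_sketch(*shapes, tile_size=128):
    """Hypothetical stand-in mirroring the documented signature and return type."""
    return [pad_to_tile(shape, tile_size) for shape in shapes]


padded = optimize_shapes_sketch((100, 300), (128, 256), tile_size=128)
print(padded)
```

Shapes that are already aligned pass through unchanged, so the call is safe to apply to every tensor shape in a pipeline.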

GPUMemoryProfiler ¤

GPUMemoryProfiler()

GPU memory profiling satisfying GPUProfilerProtocol.

Uses multi-fallback strategy: memory_stats -> xla_bridge -> zeros.

has_gpu instance-attribute ¤

has_gpu = len(devices('gpu')) > 0

get_memory_usage ¤

get_memory_usage() -> dict[str, float]

Get current GPU memory usage statistics.

Returns:

Type Description
dict[str, float]

Dictionary with gpu_memory_used_mb, gpu_memory_total_mb, and optionally gpu_memory_utilization.

get_utilization ¤

get_utilization() -> float

Get GPU utilization percentage for ResourceMonitor.

Returns:

Type Description
float

GPU memory utilization as percentage (0-100), or 0.0.

get_clock_info ¤

get_clock_info() -> dict[str, float]

Get current GPU clock frequencies via NVML.

Returns:

Type Description
dict[str, float]

Dictionary with 'gpu_clock_mhz' and 'mem_clock_mhz' keys. Returns zeros if NVML is unavailable or the query fails.

get_power_info ¤

get_power_info() -> dict[str, float]

Get current GPU power draw and limit via NVML.

Returns:

Type Description
dict[str, float]

Dictionary with 'power_draw_w' and 'power_limit_w' keys. Returns zeros if NVML is unavailable or the query fails.

analyze_memory_pattern ¤

analyze_memory_pattern(measurements: list[dict[str, float]]) -> list[str]

Analyze memory usage patterns and suggest optimizations.

Parameters:

Name Type Description Default
measurements list[dict[str, float]]

List of memory usage dictionaries.

required

Returns:

Type Description
list[str]

List of optimization suggestion strings.
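The exact heuristics behind `analyze_memory_pattern` are not documented here. As a sketch of the two checks mentioned in the overview (leaks and high utilization), the code below flags steadily growing usage and readings close to total capacity. The function name, both thresholds, and the suggestion wording are hypothetical.

```python
def analyze_pattern_sketch(measurements, high_util=0.9, leak_mb=100.0):
    """Hypothetical heuristics; both thresholds are illustrative, not calibrax defaults."""
    suggestions = []
    if not measurements:
        return suggestions

    used = [m["gpu_memory_used_mb"] for m in measurements]
    # Leak heuristic: usage never decreases and grows by more than leak_mb overall.
    monotonic = all(b >= a for a, b in zip(used, used[1:]))
    if len(used) > 1 and monotonic and used[-1] - used[0] > leak_mb:
        suggestions.append("Memory grows steadily across measurements; check for a leak.")

    # Utilization heuristic: any reading above the high_util fraction of total memory.
    if any(
        m["gpu_memory_used_mb"] / m["gpu_memory_total_mb"] > high_util
        for m in measurements
        if m.get("gpu_memory_total_mb")
    ):
        suggestions.append("Memory utilization is high; consider a smaller batch size.")
    return suggestions


readings = [
    {"gpu_memory_used_mb": 1000.0, "gpu_memory_total_mb": 8000.0},
    {"gpu_memory_used_mb": 4000.0, "gpu_memory_total_mb": 8000.0},
    {"gpu_memory_used_mb": 7600.0, "gpu_memory_total_mb": 8000.0},
]
suggestions = analyze_pattern_sketch(readings)
for line in suggestions:
    print(line)
```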

HardwareConfig dataclass ¤

HardwareConfig(*, platform: str, precision: str, tile_size: int, critical_batch_size: int, memory_layout: str, use_vmem_optimization: bool)

Hardware-specific optimization configuration.

Attributes:

Name Type Description
platform str

Detected platform ("cpu", "tpu", "gpu_modern", "gpu_legacy").

precision str

Recommended floating-point precision string.

tile_size int

Tile size for matrix operation alignment.

critical_batch_size int

Optimal batch size for the platform.

memory_layout str

Memory layout preference.

use_vmem_optimization bool

Whether VMEM optimization is available.

platform instance-attribute ¤

platform: str

precision instance-attribute ¤

precision: str

tile_size instance-attribute ¤

tile_size: int

critical_batch_size instance-attribute ¤

critical_batch_size: int

memory_layout instance-attribute ¤

memory_layout: str

use_vmem_optimization instance-attribute ¤

use_vmem_optimization: bool

MemoryAnalysis dataclass ¤

MemoryAnalysis(*, baseline_memory_mb: float, peak_memory_mb: float, peak_usage_mb: float, retained_memory_mb: float, memory_efficiency: float, suggestions: tuple[str, ...] = ())

Result of pipeline memory analysis.

Attributes:

Name Type Description
baseline_memory_mb float

Memory usage before pipeline execution.

peak_memory_mb float

Memory usage at peak during execution.

peak_usage_mb float

Peak usage above baseline.

retained_memory_mb float

Memory retained after GC.

memory_efficiency float

Ratio of freed memory to peak usage.

suggestions tuple[str, ...]

Optimization suggestions.
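The derived fields can be read as simple arithmetic on three readings: baseline, peak, and post-GC. The sketch below shows one interpretation. It assumes `retained_memory_mb` is measured relative to the baseline and that "freed" means peak minus post-GC; neither detail is confirmed by this page, and `summarize_memory` is a hypothetical helper.

```python
def summarize_memory(baseline_mb, peak_mb, post_gc_mb):
    """Hypothetical derivation of MemoryAnalysis-style fields from three readings."""
    peak_usage_mb = peak_mb - baseline_mb          # growth above baseline
    retained_memory_mb = post_gc_mb - baseline_mb  # assumed to be relative to baseline
    freed_mb = peak_mb - post_gc_mb                # released by garbage collection
    # Guard against division by zero when the pipeline allocates nothing.
    memory_efficiency = freed_mb / peak_usage_mb if peak_usage_mb > 0 else 1.0
    return {
        "peak_usage_mb": peak_usage_mb,
        "retained_memory_mb": retained_memory_mb,
        "memory_efficiency": memory_efficiency,
    }


summary = summarize_memory(baseline_mb=500.0, peak_mb=1500.0, post_gc_mb=700.0)
print(summary)
```

Under this reading, an efficiency near 1.0 means almost everything allocated during the run was released afterwards.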

baseline_memory_mb instance-attribute ¤

baseline_memory_mb: float

peak_memory_mb instance-attribute ¤

peak_memory_mb: float

peak_usage_mb instance-attribute ¤

peak_usage_mb: float

retained_memory_mb instance-attribute ¤

retained_memory_mb: float

memory_efficiency instance-attribute ¤

memory_efficiency: float

suggestions class-attribute instance-attribute ¤

suggestions: tuple[str, ...] = ()

MemoryOptimizer ¤

Memory optimization analysis for pipeline functions.

analyze_pipeline_memory ¤

analyze_pipeline_memory(pipeline_fn: Callable[[Any], Any], sample_data: Any) -> MemoryAnalysis | None

Analyze memory usage of a pipeline function.

Parameters:

Name Type Description Default
pipeline_fn Callable[[Any], Any]

Function to analyze.

required
sample_data Any

Sample input data.

required

Returns:

Type Description
MemoryAnalysis | None

MemoryAnalysis with measurements and suggestions, or None if the pipeline raises an exception.

GPUProfilerProtocol ¤

Bases: Protocol

Protocol for GPU profilers providing utilization and memory data.

get_utilization ¤

get_utilization() -> float

Get current GPU utilization percentage.

Returns:

Type Description
float

GPU utilization as a percentage (0-100).

get_memory_usage ¤

get_memory_usage() -> dict[str, float]

Get current GPU memory usage statistics.

Returns:

Type Description
dict[str, float]

Dictionary with at least 'gpu_memory_used_mb' key.

get_clock_info ¤

get_clock_info() -> dict[str, float]

Get current GPU clock frequencies.

Returns:

Type Description
dict[str, float]

Dictionary with 'gpu_clock_mhz' and 'mem_clock_mhz' keys.

get_power_info ¤

get_power_info() -> dict[str, float]

Get current GPU power draw and limits.

Returns:

Type Description
dict[str, float]

Dictionary with 'power_draw_w' and 'power_limit_w' keys.
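Because this is a structural protocol, any object with these four methods should be usable where a GPU profiler is expected. The sketch below defines a local mirror of the protocol (`GPUProfilerLike`, a stand-in so the example runs without calibrax) and a hypothetical `NullGPUProfiler` that reports zeros, which can be handy on CPU-only machines or in tests.

```python
from typing import Dict, Protocol, runtime_checkable


@runtime_checkable
class GPUProfilerLike(Protocol):
    """Local mirror of the four documented methods; a stand-in for GPUProfilerProtocol."""

    def get_utilization(self) -> float: ...
    def get_memory_usage(self) -> Dict[str, float]: ...
    def get_clock_info(self) -> Dict[str, float]: ...
    def get_power_info(self) -> Dict[str, float]: ...


class NullGPUProfiler:
    """Hypothetical profiler that reports zeros, e.g. for CPU-only test runs."""

    def get_utilization(self) -> float:
        return 0.0

    def get_memory_usage(self) -> Dict[str, float]:
        return {"gpu_memory_used_mb": 0.0, "gpu_memory_total_mb": 0.0}

    def get_clock_info(self) -> Dict[str, float]:
        return {"gpu_clock_mhz": 0.0, "mem_clock_mhz": 0.0}

    def get_power_info(self) -> Dict[str, float]:
        return {"power_draw_w": 0.0, "power_limit_w": 0.0}


profiler = NullGPUProfiler()
# runtime_checkable only verifies that the methods exist, not their signatures.
print(isinstance(profiler, GPUProfilerLike))
```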

ResourceMonitor ¤

ResourceMonitor(sample_interval_sec: float = 0.1, gpu_profiler: GPUProfilerProtocol | None = None)

Background 10Hz resource sampling via context manager.

Usage:

from calibrax.profiling import ResourceMonitor

with ResourceMonitor() as mon:
    # ... run benchmark ...
summary = mon.summary

Parameters:

Name Type Description Default
sample_interval_sec float

Seconds between samples (default 0.1 = 10Hz).

0.1
gpu_profiler GPUProfilerProtocol | None

Optional profiler satisfying GPUProfilerProtocol.

None

samples property ¤

samples: list[ResourceSample]

Return a copy of all collected samples.

summary property ¤

summary: ResourceSummary

Compute aggregated summary from collected samples.

Returns:

Type Description
ResourceSummary

ResourceSummary with aggregated metrics, or a zeroed summary if no samples were collected.

ResourceSample dataclass ¤

ResourceSample(*, timestamp: float, cpu_percent: float, rss_mb: float, gpu_util: float | None, gpu_mem_mb: float | None, gpu_clock_mhz: float | None = None, gpu_power_w: float | None = None)

Single resource measurement at a point in time.

Attributes:

Name Type Description
timestamp float

Time of measurement (perf_counter).

cpu_percent float

CPU utilization percentage.

rss_mb float

Resident set size in MB.

gpu_util float | None

GPU utilization percentage (None if no GPU).

gpu_mem_mb float | None

GPU memory used in MB (None if no GPU).

timestamp instance-attribute ¤

timestamp: float

cpu_percent instance-attribute ¤

cpu_percent: float

rss_mb instance-attribute ¤

rss_mb: float

gpu_util instance-attribute ¤

gpu_util: float | None

gpu_mem_mb instance-attribute ¤

gpu_mem_mb: float | None

gpu_clock_mhz class-attribute instance-attribute ¤

gpu_clock_mhz: float | None = None

gpu_power_w class-attribute instance-attribute ¤

gpu_power_w: float | None = None

ResourceSummary dataclass ¤

ResourceSummary(*, peak_rss_mb: float, mean_rss_mb: float, peak_gpu_mem_mb: float | None, mean_gpu_util: float | None, memory_growth_mb: float, num_samples: int, duration_sec: float, mean_gpu_clock_mhz: float | None = None, mean_gpu_power_w: float | None = None)

Aggregated resource usage over a monitoring period.

Attributes:

Name Type Description
peak_rss_mb float

Maximum RSS observed.

mean_rss_mb float

Average RSS across all samples.

peak_gpu_mem_mb float | None

Maximum GPU memory (None if no GPU).

mean_gpu_util float | None

Average GPU utilization (None if no GPU).

memory_growth_mb float

Last RSS minus first RSS (positive = growth).

num_samples int

Total samples collected.

duration_sec float

Time span of monitoring.
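The aggregation described by these attributes can be sketched in a few lines. The code below covers only the RSS-related fields and uses plain `(timestamp_sec, rss_mb)` tuples instead of `ResourceSample` objects; `summarize_rss` is a hypothetical helper, not the library's implementation.

```python
from statistics import mean


def summarize_rss(samples):
    """Aggregate (timestamp_sec, rss_mb) pairs given in collection order."""
    if not samples:
        # Mirrors the documented behaviour of returning a zeroed summary.
        return {"peak_rss_mb": 0.0, "mean_rss_mb": 0.0, "memory_growth_mb": 0.0,
                "num_samples": 0, "duration_sec": 0.0}
    timestamps = [t for t, _ in samples]
    rss = [r for _, r in samples]
    return {
        "peak_rss_mb": max(rss),
        "mean_rss_mb": mean(rss),
        "memory_growth_mb": rss[-1] - rss[0],  # positive means growth
        "num_samples": len(samples),
        "duration_sec": timestamps[-1] - timestamps[0],
    }


summary = summarize_rss([(0.0, 100.0), (0.1, 140.0), (0.2, 120.0)])
print(summary)
```

Note that `memory_growth_mb` compares only the first and last samples, so a transient spike shows up in `peak_rss_mb` but not in the growth figure.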

peak_rss_mb instance-attribute ¤

peak_rss_mb: float

mean_rss_mb instance-attribute ¤

mean_rss_mb: float

peak_gpu_mem_mb instance-attribute ¤

peak_gpu_mem_mb: float | None

mean_gpu_util instance-attribute ¤

mean_gpu_util: float | None

memory_growth_mb instance-attribute ¤

memory_growth_mb: float

num_samples instance-attribute ¤

num_samples: int

duration_sec instance-attribute ¤

duration_sec: float

mean_gpu_clock_mhz class-attribute instance-attribute ¤

mean_gpu_clock_mhz: float | None = None

mean_gpu_power_w class-attribute instance-attribute ¤

mean_gpu_power_w: float | None = None

to_dict ¤

to_dict() -> dict[str, Any]

Serialize to a JSON-compatible dictionary.

Optional GPU fields are included only when not None. Numeric values are converted to Python primitives for JAX scalar safety.

Returns:

Type Description
dict[str, Any]

Dictionary representation with all resource summary fields.

from_dict classmethod ¤

from_dict(data: dict[str, Any]) -> ResourceSummary

Deserialize from a dictionary.

Parameters:

Name Type Description Default
data dict[str, Any]

Dictionary with resource summary fields.

required

Returns:

Type Description
ResourceSummary

Reconstructed ResourceSummary instance.

RooflineAnalyzer dataclass ¤

RooflineAnalyzer(*, hardware_specs: dict[str, Any] = detect_hardware_specs())

Analyzes operation performance against hardware roofline limits.

Uses measured execution time and estimated FLOPs to determine whether an operation is compute-bound or memory-bound, and how efficiently it uses the available hardware resources.

Attributes:

Name Type Description
hardware_specs dict[str, Any]

Hardware specification dictionary (auto-detected if not provided).

hardware_specs class-attribute instance-attribute ¤

hardware_specs: dict[str, Any] = field(default_factory=detect_hardware_specs)

analyze_operation ¤

analyze_operation(func: Callable[..., Any], inputs: list[Array], *, flops_override: int | None = None) -> RooflineResult

Perform roofline analysis on a JAX operation.

Parameters:

Name Type Description Default
func Callable[..., Any]

JAX function to analyze.

required
inputs list[Array]

Input arrays for the function.

required
flops_override int | None

If provided, use this FLOP count instead of estimating. For accurate results, pass the output of FlopsCounter.count().

None

Returns:

Type Description
RooflineResult

RooflineResult with bottleneck classification and recommendations.

RooflineResult dataclass ¤

RooflineResult(*, arithmetic_intensity: float, critical_intensity: float, memory_bandwidth_utilization: float, flops_utilization: float, bottleneck: str, efficiency: float, execution_time_ms: float, recommendations: tuple[str, ...] = ())

Result of a roofline analysis on a JAX operation.

Attributes:

Name Type Description
arithmetic_intensity float

Achieved FLOPs per byte of memory traffic.

critical_intensity float

Hardware's ridge point (FLOPs/byte).

memory_bandwidth_utilization float

Fraction of peak memory bandwidth used.

flops_utilization float

Fraction of peak FLOPs achieved.

bottleneck str

Either "memory_bandwidth" or "compute".

efficiency float

Utilization of the binding resource.

execution_time_ms float

Measured execution time in milliseconds.

recommendations tuple[str, ...]

Optimization suggestions.
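The fields above follow the standard roofline model. The sketch below shows how they can be derived from a FLOP count, a byte count, and a measured time. `classify_roofline` is a hypothetical helper, and how the library estimates bytes moved is not described here; the example counts two reads and one write per element for a float32 elementwise add and borrows the `cpu_generic` figures from `HARDWARE_SPECS`.

```python
def classify_roofline(flops, bytes_moved, time_sec, peak_flops, memory_bandwidth):
    """Hypothetical roofline classification from raw counts and a measured time."""
    arithmetic_intensity = flops / bytes_moved
    critical_intensity = peak_flops / memory_bandwidth
    flops_utilization = (flops / time_sec) / peak_flops
    memory_bandwidth_utilization = (bytes_moved / time_sec) / memory_bandwidth
    # Below the ridge point the operation is limited by memory traffic.
    if arithmetic_intensity < critical_intensity:
        bottleneck, efficiency = "memory_bandwidth", memory_bandwidth_utilization
    else:
        bottleneck, efficiency = "compute", flops_utilization
    return {
        "arithmetic_intensity": arithmetic_intensity,
        "critical_intensity": critical_intensity,
        "flops_utilization": flops_utilization,
        "memory_bandwidth_utilization": memory_bandwidth_utilization,
        "bottleneck": bottleneck,
        "efficiency": efficiency,
    }


# Elementwise add over 1e6 float32 values: 1 FLOP and 12 bytes per element.
result = classify_roofline(
    flops=1e6,
    bytes_moved=12e6,
    time_sec=1e-4,
    peak_flops=2e12,
    memory_bandwidth=200e9,
)
print(result["bottleneck"], f"{result['efficiency']:.2f}")
```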

arithmetic_intensity instance-attribute ¤

arithmetic_intensity: float

critical_intensity instance-attribute ¤

critical_intensity: float

memory_bandwidth_utilization instance-attribute ¤

memory_bandwidth_utilization: float

flops_utilization instance-attribute ¤

flops_utilization: float

bottleneck instance-attribute ¤

bottleneck: str

efficiency instance-attribute ¤

efficiency: float

execution_time_ms instance-attribute ¤

execution_time_ms: float

recommendations class-attribute instance-attribute ¤

recommendations: tuple[str, ...] = ()

to_dict ¤

to_dict() -> dict[str, Any]

Serialize to a JSON-compatible dictionary.

from_dict classmethod ¤

from_dict(data: dict[str, Any]) -> RooflineResult

Deserialize from a dictionary.

Parameters:

Name Type Description Default
data dict[str, Any]

Dictionary with roofline result fields.

required

Returns:

Type Description
RooflineResult

Reconstructed RooflineResult instance.

TimingCollector ¤

TimingCollector(sync_fn: Callable[[Any], object] | None = None, warmup_iterations: int = 0)

Framework-agnostic timing with configurable GPU sync support.

Uses time.perf_counter() exclusively for accurate benchmarking. Supports configurable result synchronization via sync_fn and warm-up iteration exclusion for JIT-compiled workloads.

JAX dispatches operations asynchronously -- the host returns immediately while the device is still computing. Without an explicit synchronization barrier, perf_counter measures only host-side dispatch latency, not actual compute time. Pass a sync_fn that calls block_until_ready() on the workload result to force the host to wait for device completion before recording the timestamp.

Example -- JAX GPU timing with warm-up:

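import jax
from calibrax.profiling import TimingCollector

# step_fn and data_iter are your own step function and batch iterator.
jitted_step = jax.jit(step_fn)  # wrap once, outside the timed loop

def run_step(batch):
    return jitted_step(batch)

collector = TimingCollector(
    sync_fn=lambda result: result.block_until_ready(),
    warmup_iterations=2,
)
sample = collector.measure_iteration(data_iter, num_batches=50, process_fn=run_step)
# sample.per_batch_times excludes the first 2 batches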

Parameters:

Name Type Description Default
sync_fn Callable[[Any], object] | None

Synchronization function called with each batch result. For JAX: `lambda result: result.block_until_ready()`. For PyTorch: `lambda _: torch.cuda.synchronize()`. For CPU-only: None (default, no-op).

None
warmup_iterations int

Number of initial batches to exclude from per_batch_times statistics. They are still executed (important for JIT warm-up) but omitted from the timing result. Default: 0.

0

measure_iteration ¤

measure_iteration(iterator: Iterator[Any], num_batches: int | None = None, process_fn: Callable[[Any], Any] | None = None, count_fn: Callable[[Any], int] | None = None) -> TimingSample

Measure timing for batches from an iterator.

Warm-up batches (if configured) are executed but excluded from per_batch_times. wall_clock_sec covers the entire run including warm-up. num_batches reflects total batches consumed.

Parameters:

Name Type Description Default
iterator Iterator[Any]

Data iterator to measure.

required
num_batches int | None

Max batches to consume (including warmup). None exhausts iterator.

None
process_fn Callable[[Any], Any] | None

Optional per-batch function whose execution is timed. Defaults to identity (the yielded batch is treated as result).

None
count_fn Callable[[Any], int] | None

Function to count elements per batch. Default: 1 per batch.

None

Returns:

Type Description
TimingSample

TimingSample with timing measurements.
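The warm-up and synchronization behaviour described above can be sketched with the standard library alone. This is a simplified illustration, not the library's code: `measure_iteration_sketch` is a hypothetical function that returns a plain dictionary instead of a `TimingSample` and omits `num_batches` and `count_fn`.

```python
import time


def measure_iteration_sketch(iterator, process_fn, sync_fn=None, warmup_iterations=0):
    """Time each batch; warm-up batches run but are left out of per_batch_times."""
    per_batch_times = []
    num_batches = 0
    start = time.perf_counter()
    for batch in iterator:
        t0 = time.perf_counter()
        result = process_fn(batch)
        if sync_fn is not None:
            sync_fn(result)  # wait for asynchronous device work before stopping the clock
        elapsed = time.perf_counter() - t0
        num_batches += 1
        if num_batches > warmup_iterations:
            per_batch_times.append(elapsed)
    wall_clock_sec = time.perf_counter() - start  # includes warm-up batches
    return {
        "wall_clock_sec": wall_clock_sec,
        "per_batch_times": tuple(per_batch_times),
        "num_batches": num_batches,
        "warmup_batches_excluded": min(warmup_iterations, num_batches),
    }


sample = measure_iteration_sketch(
    iter(range(5)),
    process_fn=lambda batch: batch * 2,
    warmup_iterations=2,
)
print(sample["num_batches"], len(sample["per_batch_times"]))
```

Keeping the warm-up batches inside `wall_clock_sec` while dropping them from `per_batch_times` lets the same run report both end-to-end time and steady-state per-batch statistics.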

measure_compilation_time ¤

measure_compilation_time(fn: Callable[..., Any], *args: Any) -> float

Measure JIT compilation time for a JAX function.

Calls jax.jit(fn).lower(*args).compile() and times it. This measures the XLA compilation step only, not execution.

Parameters:

Name Type Description Default
fn Callable[..., Any]

JAX function to compile.

required
*args Any

Example arguments for lowering.

()

Returns:

Type Description
float

Compilation time in seconds.

TimingSample dataclass ¤

TimingSample(*, wall_clock_sec: float, per_batch_times: tuple[float, ...], first_batch_time: float, num_batches: int, num_elements: int, compilation_time_sec: float | None = None, warmup_batches_excluded: int = 0)

Result of timing an iteration through a data pipeline.

Attributes:

Name Type Description
wall_clock_sec float

Total wall-clock time for the iteration.

per_batch_times tuple[float, ...]

Per-batch durations in seconds (warmup batches excluded).

first_batch_time float

Time from iteration start to first batch completion.

num_batches int

Number of batches consumed (including warmup).

num_elements int

Total elements processed (via count_fn).

compilation_time_sec float | None

JIT compilation time, if measured separately.

warmup_batches_excluded int

Number of leading batches excluded from per_batch_times.

wall_clock_sec instance-attribute ¤

wall_clock_sec: float

per_batch_times instance-attribute ¤

per_batch_times: tuple[float, ...]

first_batch_time instance-attribute ¤

first_batch_time: float

num_batches instance-attribute ¤

num_batches: int

num_elements instance-attribute ¤

num_elements: int

compilation_time_sec class-attribute instance-attribute ¤

compilation_time_sec: float | None = None

warmup_batches_excluded class-attribute instance-attribute ¤

warmup_batches_excluded: int = 0

to_dict ¤

to_dict() -> dict[str, Any]

Serialize to a JSON-compatible dictionary.

from_dict classmethod ¤

from_dict(data: dict[str, Any]) -> TimingSample

Deserialize from a dictionary.

Parameters:

Name Type Description Default
data dict[str, Any]

Dictionary with TimingSample fields.

required

Returns:

Type Description
TimingSample

Reconstructed TimingSample instance.

TraceLinker ¤

Links JAX profiler traces to benchmark runs.

Usage:

from calibrax.profiling import TraceLinker

linker = TraceLinker()
with linker.trace("/tmp/my_trace") as ref:
    # ... run workload ...
print(ref.trace_dir)  # "/tmp/my_trace"

trace ¤

trace(log_dir: str | Path, *, run_id: str | None = None, create_perfetto_link: bool = False, create_perfetto_trace: bool = False) -> Any

Start an XLA profiling session and record output metadata.

Wraps jax.profiler.trace() and records the output directory path as a TraceReference for downstream Store linkage.

Parameters:

Name Type Description Default
log_dir str | Path

Directory for trace output files.

required
run_id str | None

Optional benchmark run ID to associate with the trace.

None
create_perfetto_link bool

Whether to create a Perfetto link (passed to JAX).

False
create_perfetto_trace bool

Whether to create a Perfetto trace (passed to JAX).

False

Yields:

Type Description
Any

TraceReference with the trace directory and optional run ID.

TraceReference dataclass ¤

TraceReference(*, trace_dir: str, run_id: str | None = None)

Reference to a JAX profiler trace output.

Attributes:

Name Type Description
trace_dir str

Directory where the trace files were written.

run_id str | None

Optional benchmark run ID to link the trace to.

trace_dir instance-attribute ¤

trace_dir: str

run_id class-attribute instance-attribute ¤

run_id: str | None = None

to_dict ¤

to_dict() -> dict[str, Any]

Serialize to a JSON-compatible dictionary.

from_dict classmethod ¤

from_dict(data: dict[str, Any]) -> TraceReference

Deserialize from a dictionary.

Parameters:

Name Type Description Default
data dict[str, Any]

Dictionary with trace reference fields.

required

Returns:

Type Description
TraceReference

Reconstructed TraceReference instance.

analyze_complexity ¤

analyze_complexity(model: Module, input_shape: tuple[int, ...]) -> ComplexityResult

Analyze complexity of a Flax NNX module.

Performs parameter counting, memory estimation, computational complexity analysis, and scaling characterization.

Parameters:

Name Type Description Default
model Module

Flax NNX model to analyze.

required
input_shape tuple[int, ...]

Shape of input data (including batch dimension).

required

Returns:

Type Description
ComplexityResult

ComplexityResult with detailed complexity metrics.

detect_hardware_specs ¤

detect_hardware_specs() -> dict[str, Any]

Detect current hardware and return appropriate specifications.

Uses jax.default_backend() to determine the accelerator type and returns pre-configured specs for that platform.

Returns:

Type Description
dict[str, Any]

Hardware specification dictionary with peak_flops, memory_bandwidth, and critical_intensity keys (among others).

measure_execution_time ¤

measure_execution_time(func: Callable[..., Any], inputs: list[Array], warmup: int = 3, iterations: int = 10) -> float

Measure execution time of a JAX function with synchronization.

JIT-compiles the function, runs the warmup iterations, then times the requested number of iterations with block_until_ready() barriers.

Parameters:

Name Type Description Default
func Callable[..., Any]

JAX function to benchmark.

required
inputs list[Array]

Input arguments as a list of arrays.

required
warmup int

Number of warmup iterations (for JIT compilation).

3
iterations int

Number of timed iterations.

10

Returns:

Type Description
float

Average execution time in seconds.

carbon ¤

Carbon emissions tracking via CodeCarbon integration.

Wraps the codecarbon.EmissionsTracker as a context manager, exposing emissions data as a frozen CarbonResult dataclass. Requires the optional codecarbon dependency (uv pip install "calibrax[codecarbon]").

CODECARBON_AVAILABLE module-attribute ¤

CODECARBON_AVAILABLE = True

logger module-attribute ¤

logger = getLogger(__name__)

CarbonResult dataclass ¤

CarbonResult(*, emissions_kg_co2: float, energy_consumed_kwh: float, duration_sec: float, country_iso_code: str | None = None)

Result of carbon emissions measurement.

Attributes:

Name Type Description
emissions_kg_co2 float

Total CO2 emissions in kilograms.

energy_consumed_kwh float

Total energy consumed in kilowatt-hours.

duration_sec float

Duration of the tracked period in seconds.

country_iso_code str | None

ISO code of the country used for carbon intensity.

emissions_kg_co2 instance-attribute ¤
emissions_kg_co2: float
energy_consumed_kwh instance-attribute ¤
energy_consumed_kwh: float
duration_sec instance-attribute ¤
duration_sec: float
country_iso_code class-attribute instance-attribute ¤
country_iso_code: str | None = None
to_dict ¤
to_dict() -> dict[str, Any]

Serialize to a JSON-compatible dictionary.

from_dict classmethod ¤
from_dict(data: dict[str, Any]) -> CarbonResult

Deserialize from a dictionary.

Parameters:

Name Type Description Default
data dict[str, Any]

Dictionary with carbon result fields.

required

Returns:

Type Description
CarbonResult

Reconstructed CarbonResult instance.

CarbonTracker ¤

CarbonTracker(country_iso_code: str | None = None, log_level: str = 'warning')

Context manager for tracking carbon emissions via CodeCarbon.

Requires the codecarbon package. Install with:

uv pip install "calibrax[codecarbon]"

Usage:

with CarbonTracker() as tracker:
    # ... run workload ...
result = tracker.result()
print(f"Emissions: {result.emissions_kg_co2:.6f} kg CO2")

Parameters:

Name Type Description Default
country_iso_code str | None

Optional ISO country code for regional carbon intensity.

None
log_level str

Logging level for CodeCarbon (default: "warning").

'warning'

Raises:

Type Description
ImportError

If codecarbon is not installed.

result ¤
result() -> CarbonResult

Get the carbon emissions result.

Call this after exiting the context manager.

Returns:

Type Description
CarbonResult

CarbonResult with emissions, energy, and duration data.

compilation ¤

JIT compilation profiler for JAX.

Analyzes JIT compilation efficiency, cache hit rates, XLA optimization effectiveness, and provides recommendations for compilation optimization.

logger module-attribute ¤

logger = getLogger(__name__)

CompilationResult dataclass ¤

CompilationResult(*, cache_hit_rate: float, total_calls: int, cache_hits: int, cache_misses: int, avg_compilation_time_ms: float, max_compilation_time_ms: float, unique_signatures: int, health_score: float, health_level: str, recommendations: tuple[str, ...] = ())

Result of compilation profiling analysis.

Attributes:

Name Type Description
cache_hit_rate float

Fraction of calls that hit the compilation cache.

total_calls int

Total number of profiled function calls.

cache_hits int

Number of cache hits.

cache_misses int

Number of cache misses (triggering compilation).

avg_compilation_time_ms float

Average compilation time in milliseconds.

max_compilation_time_ms float

Maximum compilation time in milliseconds.

unique_signatures int

Number of unique function signatures compiled.

health_score float

Overall compilation health score (0-1).

health_level str

Human-readable health level.

recommendations tuple[str, ...]

Optimization recommendations.

cache_hit_rate instance-attribute ¤
cache_hit_rate: float
total_calls instance-attribute ¤
total_calls: int
cache_hits instance-attribute ¤
cache_hits: int
cache_misses instance-attribute ¤
cache_misses: int
avg_compilation_time_ms instance-attribute ¤
avg_compilation_time_ms: float
max_compilation_time_ms instance-attribute ¤
max_compilation_time_ms: float
unique_signatures instance-attribute ¤
unique_signatures: int
health_score instance-attribute ¤
health_score: float
health_level instance-attribute ¤
health_level: str
recommendations class-attribute instance-attribute ¤
recommendations: tuple[str, ...] = ()
to_dict ¤
to_dict() -> dict[str, Any]

Serialize to a JSON-compatible dictionary.

from_dict classmethod ¤
from_dict(data: dict[str, Any]) -> CompilationResult

Deserialize from a dictionary.

Parameters:

Name Type Description Default
data dict[str, Any]

Dictionary with compilation result fields.

required

Returns:

Type Description
CompilationResult

Reconstructed CompilationResult instance.

XLAOptimizationResult dataclass ¤

XLAOptimizationResult(*, optimization_score: float, fusion_ratio: float, arithmetic_ratio: float, memory_ratio: float, total_kernels: int, recommendations: tuple[str, ...] = ())

Result of XLA optimization effectiveness analysis.

Attributes:

Name Type Description
optimization_score float

Overall optimization score (0-1).

fusion_ratio float

Fraction of fused kernels.

arithmetic_ratio float

Fraction of arithmetic operations.

memory_ratio float

Fraction of memory operations.

total_kernels int

Total number of HLO kernels.

recommendations tuple[str, ...]

Optimization recommendations.

optimization_score instance-attribute ¤
optimization_score: float
fusion_ratio instance-attribute ¤
fusion_ratio: float
arithmetic_ratio instance-attribute ¤
arithmetic_ratio: float
memory_ratio instance-attribute ¤
memory_ratio: float
total_kernels instance-attribute ¤
total_kernels: int
recommendations class-attribute instance-attribute ¤
recommendations: tuple[str, ...] = ()
to_dict ¤
to_dict() -> dict[str, Any]

Serialize to a JSON-compatible dictionary.

from_dict classmethod ¤
from_dict(data: dict[str, Any]) -> XLAOptimizationResult

Deserialize from a dictionary.

Parameters:

Name Type Description Default
data dict[str, Any]

Dictionary with XLA optimization result fields.

required

Returns:

Type Description
XLAOptimizationResult

Reconstructed XLAOptimizationResult instance.

CompilationProfiler ¤

CompilationProfiler()

Analyzes JAX JIT compilation performance and optimization.

Instruments JIT-compiled functions to track compilation cache hits/misses, compilation times, and input shape consistency. Use profile_jit_compilation to wrap a function, then call get_result() for aggregated analysis.

profile_jit_compilation ¤
profile_jit_compilation(func: Callable[..., Any]) -> Callable[..., Any]

Create an instrumented wrapper that profiles JIT compilation.

The returned callable tracks cache hits/misses, compilation times, and input shape patterns. Results accumulate in this profiler instance.

Parameters:

Name Type Description Default
func Callable[..., Any]

JAX function to instrument.

required

Returns:

Type Description
Callable[..., Any]

Instrumented function with identical signature.

get_result ¤
get_result() -> CompilationResult

Get aggregated compilation profiling results.

Returns:

Type Description
CompilationResult

CompilationResult with cache statistics, timing, and recommendations.
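One way to read the cache statistics: a call counts as a miss the first time a given input signature is seen and as a hit afterwards. A minimal sketch under that assumption; the actual bookkeeping inside CompilationProfiler is not shown in this reference, and SignatureTracker and FakeArray are hypothetical names used only for illustration:

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class FakeArray:
    """Stand-in for an array: only shape and dtype matter here."""

    shape: tuple[int, ...]
    dtype: str = "float32"


class SignatureTracker:
    """Counts cache hits and misses keyed by argument shapes and dtypes."""

    def __init__(self) -> None:
        self.seen: set[tuple[Any, ...]] = set()
        self.hits = 0
        self.misses = 0

    @staticmethod
    def signature(*args: Any) -> tuple[Any, ...]:
        # Array-like values are keyed by (shape, dtype); others by type name.
        return tuple(
            (tuple(a.shape), str(a.dtype)) if hasattr(a, "shape") else type(a).__name__
            for a in args
        )

    def record(self, *args: Any) -> bool:
        """Record one call and return True if it was a cache hit."""
        sig = self.signature(*args)
        if sig in self.seen:
            self.hits += 1
            return True
        self.seen.add(sig)
        self.misses += 1
        return False

    @property
    def cache_hit_rate(self) -> float:
        total = self.hits + self.misses
        return self.hits / total if total else 0.0


tracker = SignatureTracker()
for shape in [(2, 3), (2, 3), (4, 3), (2, 3)]:
    tracker.record(FakeArray(shape))
```

A low hit rate together with many unique signatures typically means input shapes vary from call to call, which forces recompilation.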

estimate_xla_optimization ¤
estimate_xla_optimization(func: Callable[..., Any], *sample_args: Any) -> XLAOptimizationResult

Estimate XLA optimization effectiveness by analyzing HLO text.

Parameters:

Name Type Description Default
func Callable[..., Any]

JAX function to analyze.

required
*sample_args Any

Example arguments for lowering/compiling.

()

Returns:

Type Description
XLAOptimizationResult

XLAOptimizationResult with HLO analysis metrics.

reset ¤
reset() -> None

Reset all profiling state.

complexity ¤

Model complexity analysis for Flax NNX modules.

Provides parameter counts, memory usage estimates, computational complexity analysis, and scaling characteristics for any NNX module.

ComplexityResult dataclass ¤

ComplexityResult(*, total_parameters: int, parameter_memory_mb: float, largest_layer_name: str, largest_layer_params: int, input_shape: tuple[int, ...], estimated_memory_mb: float, total_estimated_operations: int, dominant_complexity: str, scaling_characteristics: dict[str, str] = dict())

Result of model complexity analysis.

Attributes:

Name Type Description
total_parameters int

Total number of trainable parameters.

parameter_memory_mb float

Memory consumed by parameters (float32).

largest_layer_name str

Name of the layer with the most parameters.

largest_layer_params int

Parameter count of the largest layer.

input_shape tuple[int, ...]

Shape of the analyzed input.

estimated_memory_mb float

Estimated total memory (params + activations).

total_estimated_operations int

Estimated total operations count.

dominant_complexity str

Name of the dominant operation type.

scaling_characteristics dict[str, str]

Mapping of operation type to complexity class.

total_parameters instance-attribute ¤
total_parameters: int
parameter_memory_mb instance-attribute ¤
parameter_memory_mb: float
largest_layer_name instance-attribute ¤
largest_layer_name: str
largest_layer_params instance-attribute ¤
largest_layer_params: int
input_shape instance-attribute ¤
input_shape: tuple[int, ...]
estimated_memory_mb instance-attribute ¤
estimated_memory_mb: float
total_estimated_operations instance-attribute ¤
total_estimated_operations: int
dominant_complexity instance-attribute ¤
dominant_complexity: str
scaling_characteristics class-attribute instance-attribute ¤
scaling_characteristics: dict[str, str] = field(default_factory=dict)
to_dict ¤
to_dict() -> dict[str, Any]

Serialize to a JSON-compatible dictionary.

from_dict classmethod ¤
from_dict(data: dict[str, Any]) -> ComplexityResult

Deserialize from a dictionary.

Parameters:

Name Type Description Default
data dict[str, Any]

Dictionary with complexity result fields.

required

Returns:

Type Description
ComplexityResult

Reconstructed ComplexityResult instance.

analyze_complexity ¤

analyze_complexity(model: Module, input_shape: tuple[int, ...]) -> ComplexityResult

Analyze complexity of a Flax NNX module.

Performs parameter counting, memory estimation, computational complexity analysis, and scaling characterization.

Parameters:

Name Type Description Default
model Module

Flax NNX model to analyze.

required
input_shape tuple[int, ...]

Shape of input data (including batch dimension).

required

Returns:

Type Description
ComplexityResult

ComplexityResult with detailed complexity metrics.
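The parameter-related fields follow from simple arithmetic over parameter shapes. A sketch, assuming 4 bytes per float32 parameter and 1 MB = 1024 * 1024 bytes; the unit convention used by calibrax is not stated here, and the layer names are made up for illustration:

```python
import math


def parameter_memory_mb(
    param_shapes: dict[str, tuple[int, ...]], bytes_per_param: int = 4
) -> float:
    """Memory taken by parameters, assuming a fixed element size."""
    total = sum(math.prod(shape) for shape in param_shapes.values())
    return total * bytes_per_param / (1024 * 1024)


# Hypothetical parameter shapes for a single dense layer.
shapes = {"dense/kernel": (1024, 1024), "dense/bias": (1024,)}

total_parameters = sum(math.prod(s) for s in shapes.values())
largest_layer_name = max(shapes, key=lambda name: math.prod(shapes[name]))
largest_layer_params = math.prod(shapes[largest_layer_name])
memory_mb = parameter_memory_mb(shapes)
```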

energy ¤

Energy monitoring via NVML (GPU) and RAPL (CPU).

Provides EnergyMonitor context manager for tracking power consumption during benchmark execution. Gracefully degrades when hardware interfaces are unavailable.

logger module-attribute ¤

logger = getLogger(__name__)

EnergySample dataclass ¤

EnergySample(*, timestamp: float, gpu_power_watts: float | None, cpu_energy_joules: float | None, gpu_energy_joules: float | None)

Single energy measurement at a point in time.

Attributes:

Name Type Description
timestamp float

Time of measurement (perf_counter).

gpu_power_watts float | None

Instantaneous GPU power (None if unavailable).

cpu_energy_joules float | None

Cumulative CPU energy since monitoring start.

gpu_energy_joules float | None

Cumulative GPU energy since monitoring start.

timestamp instance-attribute ¤
timestamp: float
gpu_power_watts instance-attribute ¤
gpu_power_watts: float | None
cpu_energy_joules instance-attribute ¤
cpu_energy_joules: float | None
gpu_energy_joules instance-attribute ¤
gpu_energy_joules: float | None

EnergySummary dataclass ¤

EnergySummary(*, total_gpu_energy_joules: float | None, total_cpu_energy_joules: float | None, total_combined_energy_joules: float | None, mean_gpu_power_watts: float | None, peak_gpu_power_watts: float | None, duration_sec: float, num_samples: int)

Aggregated energy usage over a monitoring period.

Attributes:

Name Type Description
total_gpu_energy_joules float | None

Total GPU energy consumed.

total_cpu_energy_joules float | None

Total CPU energy consumed.

total_combined_energy_joules float | None

GPU + CPU combined.

mean_gpu_power_watts float | None

Average GPU power draw.

peak_gpu_power_watts float | None

Maximum GPU power draw.

duration_sec float

Monitoring duration.

num_samples int

Total samples collected.

total_gpu_energy_joules instance-attribute ¤
total_gpu_energy_joules: float | None
total_cpu_energy_joules instance-attribute ¤
total_cpu_energy_joules: float | None
total_combined_energy_joules instance-attribute ¤
total_combined_energy_joules: float | None
mean_gpu_power_watts instance-attribute ¤
mean_gpu_power_watts: float | None
peak_gpu_power_watts instance-attribute ¤
peak_gpu_power_watts: float | None
duration_sec instance-attribute ¤
duration_sec: float
num_samples instance-attribute ¤
num_samples: int

EnergyMonitor ¤

EnergyMonitor(sample_interval_sec: float = 0.1)

Background energy monitoring via NVML and RAPL.

Uses a daemon thread that samples at a configurable interval. Degrades gracefully when NVML or RAPL is unavailable.

Usage:

with EnergyMonitor() as mon:
    # ... run benchmark ...
summary = mon.summary

Parameters:

Name Type Description Default
sample_interval_sec float

Seconds between energy samples.

0.1
samples property ¤
samples: list[EnergySample]

Return a copy of all collected samples.

summary property ¤
summary: EnergySummary

Compute aggregated energy summary.

Returns:

Type Description
EnergySummary

EnergySummary with totals, or None fields when unavailable.
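When only instantaneous power readings are available, energy can be approximated by integrating power over time. A sketch using the trapezoidal rule over (timestamp, watts) pairs; this is one reasonable approach and not a description of how EnergyMonitor computes its totals:

```python
def integrate_power(samples: list[tuple[float, float]]) -> float:
    """Approximate energy in joules from (timestamp_sec, power_watts) samples."""
    energy = 0.0
    for (t0, p0), (t1, p1) in zip(samples, samples[1:]):
        # Area of the trapezoid between two consecutive readings.
        energy += 0.5 * (p0 + p1) * (t1 - t0)
    return energy


# Illustrative readings at a 0.1 s interval, not measured data.
readings = [(0.0, 100.0), (0.1, 120.0), (0.2, 110.0)]

total_joules = integrate_power(readings)
mean_watts = sum(p for _, p in readings) / len(readings)
peak_watts = max(p for _, p in readings)
```

With fewer than two samples the integral is zero, which is one reason a summary may need to report None rather than a total.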

flops ¤

FLOP counting via JAX's jaxpr tracing.

Provides FlopsCounter for estimating FLOPs of JAX functions by analyzing their Jaxpr intermediate representation.

logger module-attribute ¤

logger = getLogger(__name__)

FlopsResult dataclass ¤

FlopsResult(*, total_flops: int, flops_by_operation: dict[str, int], num_operations: int, function_name: str)

Result of FLOP counting for a function.

Attributes:

Name Type Description
total_flops int

Total estimated FLOPs.

flops_by_operation dict[str, int]

Breakdown by primitive operation name.

num_operations int

Number of JAX primitives in the trace.

function_name str

Name of the analyzed function.

total_flops instance-attribute ¤
total_flops: int
flops_by_operation instance-attribute ¤
flops_by_operation: dict[str, int]
num_operations instance-attribute ¤
num_operations: int
function_name instance-attribute ¤
function_name: str

FlopsCounter ¤

Count FLOPs of JAX functions via jaxpr analysis.

Uses jax.make_jaxpr to trace the function and counts FLOPs for each primitive based on operation-specific rules.

For NNX models that use stochastic operations (dropout, etc.), use flax.nnx.tabulate(model, *args, compute_flops=True) instead — it handles NNX state management internally.

count ¤
count(fn: Callable[..., Any], *args: Any, static_argnums: tuple[int, ...] = ()) -> FlopsResult

Count FLOPs for a function with given example arguments.

Parameters:

Name Type Description Default
fn Callable[..., Any]

JAX function to analyze.

required
*args Any

Example arguments for tracing.

()
static_argnums tuple[int, ...]

Argument indices to treat as static.

()

Returns:

Type Description
FlopsResult

FlopsResult with FLOP count and breakdown.
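An operation-specific rule can be as simple as a formula over operand shapes. For example, multiplying an (m, k) matrix by a (k, n) matrix is conventionally counted as 2 * m * n * k FLOPs (one multiply and one add per inner-product term). The rules below are illustrative and are not FlopsCounter's actual rule table:

```python
import math


def matmul_flops(lhs_shape: tuple[int, int], rhs_shape: tuple[int, int]) -> int:
    """FLOPs for a 2-D matrix multiply: one multiply and one add per term."""
    m, k = lhs_shape
    k2, n = rhs_shape
    if k != k2:
        raise ValueError(f"inner dimensions differ: {k} vs {k2}")
    return 2 * m * n * k


def elementwise_flops(shape: tuple[int, ...]) -> int:
    """FLOPs for a unary or binary elementwise op: one per output element."""
    return math.prod(shape)


# Dense layer y = x @ w + b with x: (32, 512), w: (512, 256), b: (256,).
flops_by_operation = {
    "dot_general": matmul_flops((32, 512), (512, 256)),
    "add": elementwise_flops((32, 256)),
}
total_flops = sum(flops_by_operation.values())
```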

gpu ¤

GPU memory profiling and hardware-adaptive operations.

Provides hardware detection, shape optimization, GPU memory profiling (satisfying GPUProfilerProtocol), and memory usage analysis. Includes NVML-based GPU clock and power monitoring when pynvml is available.

PYNVML_AVAILABLE module-attribute ¤

PYNVML_AVAILABLE = True

logger module-attribute ¤

logger = getLogger(__name__)

HardwareConfig dataclass ¤

HardwareConfig(*, platform: str, precision: str, tile_size: int, critical_batch_size: int, memory_layout: str, use_vmem_optimization: bool)

Hardware-specific optimization configuration.

Attributes:

Name Type Description
platform str

Detected platform ("cpu", "tpu", "gpu_modern", "gpu_legacy").

precision str

Recommended floating-point precision string.

tile_size int

Tile size for matrix operation alignment.

critical_batch_size int

Optimal batch size for the platform.

memory_layout str

Memory layout preference.

use_vmem_optimization bool

Whether VMEM optimization is available.

platform instance-attribute ¤
platform: str
precision instance-attribute ¤
precision: str
tile_size instance-attribute ¤
tile_size: int
critical_batch_size instance-attribute ¤
critical_batch_size: int
memory_layout instance-attribute ¤
memory_layout: str
use_vmem_optimization instance-attribute ¤
use_vmem_optimization: bool

MemoryAnalysis dataclass ¤

MemoryAnalysis(*, baseline_memory_mb: float, peak_memory_mb: float, peak_usage_mb: float, retained_memory_mb: float, memory_efficiency: float, suggestions: tuple[str, ...] = ())

Result of pipeline memory analysis.

Attributes:

Name Type Description
baseline_memory_mb float

Memory usage before pipeline execution.

peak_memory_mb float

Memory usage at peak during execution.

peak_usage_mb float

Peak usage above baseline.

retained_memory_mb float

Memory retained after GC.

memory_efficiency float

Ratio of freed memory to peak usage.

suggestions tuple[str, ...]

Optimization suggestions.

baseline_memory_mb instance-attribute ¤
baseline_memory_mb: float
peak_memory_mb instance-attribute ¤
peak_memory_mb: float
peak_usage_mb instance-attribute ¤
peak_usage_mb: float
retained_memory_mb instance-attribute ¤
retained_memory_mb: float
memory_efficiency instance-attribute ¤
memory_efficiency: float
suggestions class-attribute instance-attribute ¤
suggestions: tuple[str, ...] = ()

AdaptiveOperation ¤

AdaptiveOperation()

Hardware-adaptive operations with auto-detection.

Detects the current JAX backend (CPU/GPU/TPU) and provides optimized configuration and shape padding.

config instance-attribute ¤
config = _detect_hardware()
optimize_shapes ¤
optimize_shapes(*shapes: tuple[int, ...]) -> list[tuple[int, ...]]

Pad tensor shapes to align with hardware tile size.

Parameters:

Name Type Description Default
*shapes tuple[int, ...]

Variable number of tensor shapes to optimize.

()

Returns:

Type Description
list[tuple[int, ...]]

List of optimized shapes padded to tile_size multiples.
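Padding to a tile multiple amounts to rounding dimensions up. A sketch that assumes every dimension is rounded up to the next multiple of tile_size; which dimensions AdaptiveOperation actually pads is not specified here, and pad_to_tile and pad_shapes are hypothetical helpers:

```python
def pad_to_tile(shape: tuple[int, ...], tile_size: int) -> tuple[int, ...]:
    """Round each dimension up to the next multiple of tile_size."""
    if tile_size <= 0:
        raise ValueError("tile_size must be positive")
    # -(-d // t) is ceiling division using only integer arithmetic.
    return tuple(-(-dim // tile_size) * tile_size for dim in shape)


def pad_shapes(*shapes: tuple[int, ...], tile_size: int = 128) -> list[tuple[int, ...]]:
    """Apply pad_to_tile to several shapes at once."""
    return [pad_to_tile(shape, tile_size) for shape in shapes]


padded = pad_shapes((100, 300), (128, 1), tile_size=128)
```

Dimensions that are already aligned are left unchanged, so applying the padding twice gives the same result as applying it once.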

GPUMemoryProfiler ¤

GPUMemoryProfiler()

GPU memory profiling satisfying GPUProfilerProtocol.

Uses a multi-fallback strategy: memory_stats -> xla_bridge -> zeros.

has_gpu instance-attribute ¤
has_gpu = len(devices('gpu')) > 0
get_memory_usage ¤
get_memory_usage() -> dict[str, float]

Get current GPU memory usage statistics.

Returns:

Type Description
dict[str, float]

Dictionary with gpu_memory_used_mb, gpu_memory_total_mb, and optionally gpu_memory_utilization.

get_utilization ¤
get_utilization() -> float

Get GPU utilization percentage for ResourceMonitor.

Returns:

Type Description
float

GPU memory utilization as percentage (0-100), or 0.0.

get_clock_info ¤
get_clock_info() -> dict[str, float]

Get current GPU clock frequencies via NVML.

Returns:

Type Description
dict[str, float]

Dictionary with 'gpu_clock_mhz' and 'mem_clock_mhz' keys. Returns zeros if NVML is unavailable or the query fails.

get_power_info ¤
get_power_info() -> dict[str, float]

Get current GPU power draw and limit via NVML.

Returns:

Type Description
dict[str, float]

Dictionary with 'power_draw_w' and 'power_limit_w' keys. Returns zeros if NVML is unavailable or the query fails.

analyze_memory_pattern ¤
analyze_memory_pattern(measurements: list[dict[str, float]]) -> list[str]

Analyze memory usage patterns and suggest optimizations.

Parameters:

Name Type Description Default
measurements list[dict[str, float]]

List of memory usage dictionaries.

required

Returns:

Type Description
list[str]

List of optimization suggestion strings.

MemoryOptimizer ¤

Memory optimization analysis for pipeline functions.

analyze_pipeline_memory ¤
analyze_pipeline_memory(pipeline_fn: Callable[[Any], Any], sample_data: Any) -> MemoryAnalysis | None

Analyze memory usage of a pipeline function.

Parameters:

Name Type Description Default
pipeline_fn Callable[[Any], Any]

Function to analyze.

required
sample_data Any

Sample input data.

required

Returns:

Type Description
MemoryAnalysis | None

MemoryAnalysis with measurements and suggestions, or None if the pipeline raises an exception.
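The MemoryAnalysis fields can be related to three raw readings: before the run, at peak, and after garbage collection. The sketch below assumes all three are absolute readings in MB, that retained memory is reported relative to the baseline, and that efficiency is freed memory divided by peak usage; these are assumptions drawn from the field descriptions, not from the implementation:

```python
def derive_memory_metrics(
    baseline_mb: float, peak_mb: float, after_gc_mb: float
) -> dict[str, float]:
    """Relate three absolute memory readings to MemoryAnalysis-style fields."""
    peak_usage = peak_mb - baseline_mb   # growth caused by the pipeline
    retained = after_gc_mb - baseline_mb  # growth that survived GC (assumed)
    freed = peak_mb - after_gc_mb         # memory released by GC
    # With no measurable growth there is nothing to free; treat as efficient.
    efficiency = freed / peak_usage if peak_usage > 0 else 1.0
    return {
        "baseline_memory_mb": baseline_mb,
        "peak_memory_mb": peak_mb,
        "peak_usage_mb": peak_usage,
        "retained_memory_mb": retained,
        "memory_efficiency": efficiency,
    }


# Illustrative readings, not measured data.
metrics = derive_memory_metrics(baseline_mb=1000.0, peak_mb=1500.0, after_gc_mb=1100.0)
```

Under these assumptions, an efficiency well below 1.0 means a large share of the peak allocation is still held after collection, which is the pattern a leak check would look for.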

hardware ¤

Hardware specifications and detection for profiling.

Provides accelerator specs (TPU v5e, A100, H100, CPU) and utility functions for hardware detection and synchronized execution timing.

HARDWARE_SPECS module-attribute ¤

HARDWARE_SPECS: dict[str, dict[str, Any]] = {'tpu_v5e': {'peak_flops': 197000000000000.0, 'peak_flops_bf16': 197000000000000.0, 'memory_bandwidth': 1600000000000.0, 'critical_intensity': 123.125}, 'a100_80g': {'peak_flops': 312000000000000.0, 'peak_flops_bf16': 312000000000000.0, 'memory_bandwidth': 2039000000000.0, 'critical_intensity': 153.0, 'tensor_core_shapes': [(16, 16, 16), (16, 16, 8)]}, 'h100': {'peak_flops': 989000000000000.0, 'peak_flops_bf16': 989000000000000.0, 'memory_bandwidth': 3350000000000.0, 'critical_intensity': 295.0, 'tensor_core_shapes': [(16, 16, 16)]}, 'cpu_generic': {'peak_flops': 2000000000000.0, 'peak_flops_bf16': 2000000000000.0, 'memory_bandwidth': 200000000000.0, 'critical_intensity': 10.0, 'simd_width': 8}}
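The critical_intensity entries are consistent with the ratio peak_flops / memory_bandwidth, the roofline ridge point in FLOPs per byte: the tpu_v5e and cpu_generic values match exactly, and the GPU values match after rounding. A small check using numbers copied from the table above (the helper name ridge_point is hypothetical):

```python
def ridge_point(spec: dict[str, float]) -> float:
    """Arithmetic intensity at which compute and bandwidth limits meet."""
    return spec["peak_flops"] / spec["memory_bandwidth"]


# Values copied from HARDWARE_SPECS.
specs = {
    "tpu_v5e": {"peak_flops": 197e12, "memory_bandwidth": 1.6e12},
    "a100_80g": {"peak_flops": 312e12, "memory_bandwidth": 2.039e12},
    "h100": {"peak_flops": 989e12, "memory_bandwidth": 3.35e12},
    "cpu_generic": {"peak_flops": 2e12, "memory_bandwidth": 200e9},
}

ridge_points = {name: ridge_point(spec) for name, spec in specs.items()}
```

An operation whose arithmetic intensity falls below the ridge point is limited by memory bandwidth; above it, by compute.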

detect_hardware_specs ¤

detect_hardware_specs() -> dict[str, Any]

Detect current hardware and return appropriate specifications.

Uses jax.default_backend() to determine the accelerator type and returns pre-configured specs for that platform.

Returns:

Type Description
dict[str, Any]

Hardware specification dictionary with peak_flops, memory_bandwidth, and critical_intensity keys (among others).

measure_execution_time ¤

measure_execution_time(func: Callable[..., Any], inputs: list[Array], warmup: int = 3, iterations: int = 10) -> float

Measure execution time of a JAX function with synchronization.

JIT-compiles the function, runs the warmup iterations, then times the requested number of iterations with block_until_ready() barriers.

Parameters:

Name Type Description Default
func Callable[..., Any]

JAX function to benchmark.

required
inputs list[Array]

Input arguments as a list of arrays.

required
warmup int

Number of warmup iterations (for JIT compilation).

3
iterations int

Number of timed iterations.

10

Returns:

Type Description
float

Average execution time in seconds.
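The warmup-then-time pattern described above can be sketched without calibrax. This version omits the JIT-compilation step and only calls block_until_ready() when the result provides it, so it also runs on plain Python callables; time_function is a hypothetical name, not the library's implementation:

```python
import time
from typing import Any, Callable, Sequence


def _block(result: Any) -> Any:
    # JAX arrays expose block_until_ready(); plain Python values do not.
    wait = getattr(result, "block_until_ready", None)
    return wait() if callable(wait) else result


def time_function(
    func: Callable[..., Any],
    inputs: Sequence[Any],
    warmup: int = 3,
    iterations: int = 10,
) -> float:
    """Return the average seconds per call over the timed iterations."""
    if iterations <= 0:
        raise ValueError("iterations must be positive")
    # Warmup calls absorb one-time costs such as compilation and caching.
    for _ in range(warmup):
        _block(func(*inputs))
    start = time.perf_counter()
    for _ in range(iterations):
        # Block inside the loop so asynchronous dispatch is not mistaken
        # for completed work.
        _block(func(*inputs))
    return (time.perf_counter() - start) / iterations


avg_sec = time_function(sum, [[1, 2, 3]])
```

Without the synchronization barrier, an asynchronously dispatched JAX call can return before the computation finishes, and the measured time would reflect dispatch overhead only. For results that are pytrees of arrays, jax.block_until_ready can play the role of _block.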

resources ¤

Background resource monitoring with 10Hz sampling.

Provides ResourceMonitor context manager for tracking CPU, memory, and optional GPU utilization during benchmark execution.

GPUProfilerProtocol ¤

Bases: Protocol

Protocol for GPU profilers providing utilization and memory data.

get_utilization ¤
get_utilization() -> float

Get current GPU utilization percentage.

Returns:

Type Description
float

GPU utilization as a percentage (0-100).

get_memory_usage ¤
get_memory_usage() -> dict[str, float]

Get current GPU memory usage statistics.

Returns:

Type Description
dict[str, float]

Dictionary with at least 'gpu_memory_used_mb' key.

get_clock_info ¤
get_clock_info() -> dict[str, float]

Get current GPU clock frequencies.

Returns:

Type Description
dict[str, float]

Dictionary with 'gpu_clock_mhz' and 'mem_clock_mhz' keys.

get_power_info ¤
get_power_info() -> dict[str, float]

Get current GPU power draw and limits.

Returns:

Type Description
dict[str, float]

Dictionary with 'power_draw_w' and 'power_limit_w' keys.

ResourceSample dataclass ¤

ResourceSample(*, timestamp: float, cpu_percent: float, rss_mb: float, gpu_util: float | None, gpu_mem_mb: float | None, gpu_clock_mhz: float | None = None, gpu_power_w: float | None = None)

Single resource measurement at a point in time.

Attributes:

Name Type Description
timestamp float

Time of measurement (perf_counter).

cpu_percent float

CPU utilization percentage.

rss_mb float

Resident set size in MB.

gpu_util float | None

GPU utilization percentage (None if no GPU).

gpu_mem_mb float | None

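GPU memory used in MB (None if no GPU).

gpu_clock_mhz float | None

GPU clock frequency in MHz (None if unavailable).

gpu_power_w float | None

GPU power in watts (None if unavailable).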

timestamp instance-attribute ¤
timestamp: float
cpu_percent instance-attribute ¤
cpu_percent: float
rss_mb instance-attribute ¤
rss_mb: float
gpu_util instance-attribute ¤
gpu_util: float | None
gpu_mem_mb instance-attribute ¤
gpu_mem_mb: float | None
gpu_clock_mhz class-attribute instance-attribute ¤
gpu_clock_mhz: float | None = None
gpu_power_w class-attribute instance-attribute ¤
gpu_power_w: float | None = None

ResourceSummary dataclass ¤

ResourceSummary(*, peak_rss_mb: float, mean_rss_mb: float, peak_gpu_mem_mb: float | None, mean_gpu_util: float | None, memory_growth_mb: float, num_samples: int, duration_sec: float, mean_gpu_clock_mhz: float | None = None, mean_gpu_power_w: float | None = None)

Aggregated resource usage over a monitoring period.

Attributes:

Name Type Description
peak_rss_mb float

Maximum RSS observed.

mean_rss_mb float

Average RSS across all samples.

peak_gpu_mem_mb float | None

Maximum GPU memory (None if no GPU).

mean_gpu_util float | None

Average GPU utilization (None if no GPU).

memory_growth_mb float

Last RSS minus first RSS (positive = growth).

num_samples int

Total samples collected.

duration_sec float

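Time span of monitoring.

mean_gpu_clock_mhz float | None

Average GPU clock frequency in MHz (None if unavailable).

mean_gpu_power_w float | None

Average GPU power in watts (None if unavailable).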

peak_rss_mb instance-attribute ¤
peak_rss_mb: float
mean_rss_mb instance-attribute ¤
mean_rss_mb: float
peak_gpu_mem_mb instance-attribute ¤
peak_gpu_mem_mb: float | None
mean_gpu_util instance-attribute ¤
mean_gpu_util: float | None
memory_growth_mb instance-attribute ¤
memory_growth_mb: float
num_samples instance-attribute ¤
num_samples: int
duration_sec instance-attribute ¤
duration_sec: float
mean_gpu_clock_mhz class-attribute instance-attribute ¤
mean_gpu_clock_mhz: float | None = None
mean_gpu_power_w class-attribute instance-attribute ¤
mean_gpu_power_w: float | None = None
to_dict ¤
to_dict() -> dict[str, Any]

Serialize to a JSON-compatible dictionary.

Optional GPU fields are included only when not None. Numeric values are converted to Python primitives for JAX scalar safety.

Returns:

Type Description
dict[str, Any]

Dictionary representation with all resource summary fields.

from_dict classmethod ¤
from_dict(data: dict[str, Any]) -> ResourceSummary

Deserialize from a dictionary.

Parameters:

Name Type Description Default
data dict[str, Any]

Dictionary with resource summary fields.

required

Returns:

Type Description
ResourceSummary

Reconstructed ResourceSummary instance.

ResourceMonitor ¤

ResourceMonitor(sample_interval_sec: float = 0.1, gpu_profiler: GPUProfilerProtocol | None = None)

Background 10Hz resource sampling via context manager.

Usage:

with ResourceMonitor() as mon:
    # ... run benchmark ...
summary = mon.summary

Parameters:

Name Type Description Default
sample_interval_sec float

Seconds between samples (default 0.1 = 10Hz).

0.1
gpu_profiler GPUProfilerProtocol | None

Optional profiler satisfying GPUProfilerProtocol.

None

samples property ¤
samples: list[ResourceSample]

Return a copy of all collected samples.

summary property ¤
summary: ResourceSummary

Compute aggregated summary from collected samples.

Returns:

Type Description
ResourceSummary

ResourceSummary with aggregated metrics, or a zeroed summary if no samples were collected.
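The RSS-related summary fields are simple aggregates over the sample series. A sketch working on a plain list of RSS readings in MB, with the empty case returning zeros as described above; summarize_rss is a hypothetical helper, not the library's code:

```python
from statistics import mean


def summarize_rss(rss_samples_mb: list[float]) -> dict[str, float]:
    """Aggregate RSS readings into peak, mean, and growth figures."""
    if not rss_samples_mb:
        # Zeroed summary when nothing was sampled.
        return {
            "peak_rss_mb": 0.0,
            "mean_rss_mb": 0.0,
            "memory_growth_mb": 0.0,
            "num_samples": 0,
        }
    return {
        "peak_rss_mb": max(rss_samples_mb),
        "mean_rss_mb": mean(rss_samples_mb),
        # Last minus first: positive values indicate growth.
        "memory_growth_mb": rss_samples_mb[-1] - rss_samples_mb[0],
        "num_samples": len(rss_samples_mb),
    }


# Illustrative readings, not measured data.
rss_summary = summarize_rss([100.0, 150.0, 120.0])
```

Because growth compares only the first and last samples, a transient spike shows up in peak_rss_mb but not in memory_growth_mb.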

roofline ¤

Roofline analysis for JAX operations.

Identifies whether operations are compute-bound or memory-bound by comparing arithmetic intensity against hardware roofline limits, and generates optimization recommendations.

logger module-attribute ¤

logger = getLogger(__name__)

RooflineResult dataclass ¤

RooflineResult(*, arithmetic_intensity: float, critical_intensity: float, memory_bandwidth_utilization: float, flops_utilization: float, bottleneck: str, efficiency: float, execution_time_ms: float, recommendations: tuple[str, ...] = ())

Result of a roofline analysis on a JAX operation.

Attributes:

Name Type Description
arithmetic_intensity float

Achieved FLOPs per byte of memory traffic.

critical_intensity float

Hardware's ridge point (FLOPs/byte).

memory_bandwidth_utilization float

Fraction of peak memory bandwidth used.

flops_utilization float

Fraction of peak FLOPs achieved.

bottleneck str

Either "memory_bandwidth" or "compute".

efficiency float

Utilization of the binding resource.

execution_time_ms float

Measured execution time in milliseconds.

recommendations tuple[str, ...]

Optimization suggestions.

arithmetic_intensity instance-attribute ¤
arithmetic_intensity: float
critical_intensity instance-attribute ¤
critical_intensity: float
memory_bandwidth_utilization instance-attribute ¤
memory_bandwidth_utilization: float
flops_utilization instance-attribute ¤
flops_utilization: float
bottleneck instance-attribute ¤
bottleneck: str
efficiency instance-attribute ¤
efficiency: float
execution_time_ms instance-attribute ¤
execution_time_ms: float
recommendations class-attribute instance-attribute ¤
recommendations: tuple[str, ...] = ()
to_dict ¤
to_dict() -> dict[str, Any]

Serialize to a JSON-compatible dictionary.

from_dict classmethod ¤
from_dict(data: dict[str, Any]) -> RooflineResult

Deserialize from a dictionary.

Parameters:

Name Type Description Default
data dict[str, Any]

Dictionary with roofline result fields.

required

Returns:

Type Description
RooflineResult

Reconstructed RooflineResult instance.

RooflineAnalyzer dataclass ¤

RooflineAnalyzer(*, hardware_specs: dict[str, Any] = detect_hardware_specs())

Analyzes operation performance against hardware roofline limits.

Uses measured execution time and estimated FLOPs to determine whether an operation is compute-bound or memory-bound, and how efficiently it uses the available hardware resources.

Attributes:

Name Type Description
hardware_specs dict[str, Any]

Hardware specification dictionary (auto-detected if not provided).

hardware_specs class-attribute instance-attribute ¤
hardware_specs: dict[str, Any] = field(default_factory=detect_hardware_specs)
analyze_operation ¤
analyze_operation(func: Callable[..., Any], inputs: list[Array], *, flops_override: int | None = None) -> RooflineResult

Perform roofline analysis on a JAX operation.

Parameters:

Name Type Description Default
func Callable[..., Any]

JAX function to analyze.

required
inputs list[Array]

Input arrays for the function.

required
flops_override int | None

If provided, use this FLOP count instead of estimating. For accurate results, pass the output of FlopsCounter.count().

None

Returns:

Type Description
RooflineResult

RooflineResult with bottleneck classification and recommendations.
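
When the FLOP count of an operation is known analytically, it can be computed by hand and passed as flops_override. The sketch below does this for a dense matrix multiply. The helpers `matmul_flops` and `matmul_bytes` are hypothetical and are not part of the library. The byte count assumes each operand is read once and the result written once, which real kernels may not achieve.

```python
def matmul_flops(m: int, k: int, n: int) -> int:
    """FLOPs for (m, k) @ (k, n): one multiply and one add per inner-product term."""
    return 2 * m * k * n


def matmul_bytes(m: int, k: int, n: int, itemsize: int = 4) -> int:
    """Idealized memory traffic: read both operands once, write the result once."""
    return itemsize * (m * k + k * n + m * n)


flops = matmul_flops(1024, 1024, 1024)
intensity = flops / matmul_bytes(1024, 1024, 1024)
print(flops, round(intensity, 2))

# The count could then be supplied to the analyzer, following the signature above:
#   analyzer.analyze_operation(jnp.matmul, [a, b], flops_override=flops)
```

With float32 operands this gives roughly 170.7 FLOPs per byte, which sits above the `a100_80g` critical intensity of 153.0 listed in HARDWARE_SPECS and below the `h100` value of 295.0.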

timing ¤

Framework-agnostic timing with configurable result synchronization.

Provides TimingSample (frozen dataclass) and TimingCollector for measuring iteration throughput with per-batch timing breakdown. Uses time.perf_counter() exclusively for accurate benchmarking. Supports warm-up iteration exclusion and JIT compilation time measurement.

TimingSample dataclass ¤

TimingSample(*, wall_clock_sec: float, per_batch_times: tuple[float, ...], first_batch_time: float, num_batches: int, num_elements: int, compilation_time_sec: float | None = None, warmup_batches_excluded: int = 0)

Result of timing an iteration through a data pipeline.

Attributes:

Name Type Description
wall_clock_sec float

Total wall-clock time for the iteration.

per_batch_times tuple[float, ...]

Per-batch durations in seconds (warmup batches excluded).

first_batch_time float

Time from iteration start to first batch completion.

num_batches int

Number of batches consumed (including warmup).

num_elements int

Total elements processed (via count_fn).

compilation_time_sec float | None

JIT compilation time, if measured separately.

warmup_batches_excluded int

Number of leading batches excluded from per_batch_times.
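
The fields above are enough to derive common throughput figures. The following is an illustrative sketch using made-up values shaped like the TimingSample attributes. It deliberately avoids constructing a real TimingSample, and the derived names (`overall_throughput`, `steady_state_batch_sec`) are assumptions, not library attributes.

```python
from statistics import mean

# Hypothetical values mirroring the documented TimingSample fields.
wall_clock_sec = 2.5
per_batch_times = (0.04, 0.05, 0.04, 0.05)  # warm-up batches already excluded
num_batches = 6                              # includes warm-up batches
num_elements = 1600
warmup_batches_excluded = 2

# Overall throughput covers the whole run, warm-up included.
overall_throughput = num_elements / wall_clock_sec

# Steady-state batch time uses only the post-warm-up measurements.
steady_state_batch_sec = mean(per_batch_times)

# Expected relationship implied by the field descriptions above.
timed_batches = num_batches - warmup_batches_excluded

print(overall_throughput, timed_batches)
```

Because wall_clock_sec includes warm-up while per_batch_times does not, the two figures answer different questions: end-to-end cost versus steady-state cost.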

wall_clock_sec instance-attribute ¤
wall_clock_sec: float
per_batch_times instance-attribute ¤
per_batch_times: tuple[float, ...]
first_batch_time instance-attribute ¤
first_batch_time: float
num_batches instance-attribute ¤
num_batches: int
num_elements instance-attribute ¤
num_elements: int
compilation_time_sec class-attribute instance-attribute ¤
compilation_time_sec: float | None = None
warmup_batches_excluded class-attribute instance-attribute ¤
warmup_batches_excluded: int = 0
to_dict ¤
to_dict() -> dict[str, Any]

Serialize to a JSON-compatible dictionary.

from_dict classmethod ¤
from_dict(data: dict[str, Any]) -> TimingSample

Deserialize from a dictionary.

Parameters:

Name Type Description Default
data dict[str, Any]

Dictionary with TimingSample fields.

required

Returns:

Type Description
TimingSample

Reconstructed TimingSample instance.

TimingCollector ¤

TimingCollector(sync_fn: Callable[[Any], object] | None = None, warmup_iterations: int = 0)

Framework-agnostic timing with configurable GPU sync support.

Uses time.perf_counter() exclusively for accurate benchmarking. Supports configurable result synchronization via sync_fn and warm-up iteration exclusion for JIT-compiled workloads.

JAX dispatches operations asynchronously -- the host returns immediately while the device is still computing. Without an explicit synchronization barrier, perf_counter measures only host-side dispatch latency, not actual compute time. Pass a sync_fn that calls block_until_ready() on the workload result to force the host to wait for device completion before recording the timestamp.

Example -- JAX GPU timing with warm-up:

import jax

def run_step(batch):
    return jax.jit(step_fn)(batch)

collector = TimingCollector(
    sync_fn=lambda result: result.block_until_ready(),
    warmup_iterations=2,
)
sample = collector.measure_iteration(data_iter, num_batches=50, process_fn=run_step)
# sample.per_batch_times excludes the first 2 batches

Parameters:

Name Type Description Default
sync_fn Callable[[Any], object] | None

Synchronization function called with each batch result. For JAX: lambda result: result.block_until_ready(). For PyTorch: lambda _: torch.cuda.synchronize(). For CPU-only: None (default, no-op).

None
warmup_iterations int

Number of initial batches to exclude from per_batch_times statistics. They are still executed (important for JIT warm-up) but omitted from the timing result. Default: 0.

0

measure_iteration ¤
measure_iteration(iterator: Iterator[Any], num_batches: int | None = None, process_fn: Callable[[Any], Any] | None = None, count_fn: Callable[[Any], int] | None = None) -> TimingSample

Measure timing for batches from an iterator.

Warm-up batches (if configured) are executed but excluded from per_batch_times. wall_clock_sec covers the entire run including warm-up. num_batches reflects total batches consumed.

Parameters:

Name Type Description Default
iterator Iterator[Any]

Data iterator to measure.

required
num_batches int | None

Max batches to consume (including warmup). None exhausts iterator.

None
process_fn Callable[[Any], Any] | None

Optional per-batch function whose execution is timed. Defaults to identity (the yielded batch is treated as result).

None
count_fn Callable[[Any], int] | None

Function to count elements per batch. Default: 1 per batch.

None

Returns:

Type Description
TimingSample

TimingSample with timing measurements.
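
The warm-up and counting rules described above can be sketched as a plain timing loop. This is a hedged illustration, not the library's code: `measure_iteration_sketch` is a hypothetical stand-in that returns a dict instead of a TimingSample. It assumes each per-batch time runs from the end of the previous batch to the end of the current one, and that count_fn is applied to the yielded batch.

```python
import time
from itertools import islice


def measure_iteration_sketch(iterator, num_batches=None, process_fn=None,
                             count_fn=None, sync_fn=None, warmup_iterations=0):
    """Illustrative timing loop; not the library's implementation."""
    per_batch_times = []
    first_batch_time = 0.0
    consumed = 0
    elements = 0
    start = time.perf_counter()
    batch_start = start
    # islice(iterator, None) exhausts the iterator, matching num_batches=None.
    for batch in islice(iterator, num_batches):
        result = batch if process_fn is None else process_fn(batch)
        if sync_fn is not None:
            sync_fn(result)  # wait for device work before taking the timestamp
        now = time.perf_counter()
        if consumed == 0:
            first_batch_time = now - start
        if consumed >= warmup_iterations:
            per_batch_times.append(now - batch_start)  # warm-up batches are skipped
        elements += 1 if count_fn is None else count_fn(batch)
        consumed += 1
        batch_start = now
    return {
        "wall_clock_sec": time.perf_counter() - start,
        "per_batch_times": tuple(per_batch_times),
        "first_batch_time": first_batch_time,
        "num_batches": consumed,
        "num_elements": elements,
        "warmup_batches_excluded": min(warmup_iterations, consumed),
    }


batches = iter([[1, 2, 3]] * 5)
sample = measure_iteration_sketch(batches, process_fn=sum, count_fn=len, warmup_iterations=2)
print(sample["num_batches"], len(sample["per_batch_times"]), sample["num_elements"])
```

All five batches are executed and counted, but only the three after warm-up contribute to per_batch_times.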

measure_compilation_time ¤
measure_compilation_time(fn: Callable[..., Any], *args: Any) -> float

Measure JIT compilation time for a JAX function.

Calls jax.jit(fn).lower(*args).compile() and times it. This measures the XLA compilation step only, not execution.

Parameters:

Name Type Description Default
fn Callable[..., Any]

JAX function to compile.

required
*args Any

Example arguments for lowering.

()

Returns:

Type Description
float

Compilation time in seconds.

tracing ¤

XLA trace linking for connecting JAX profiler output to benchmark runs.

Provides a simple context manager wrapping jax.profiler.trace() that records the trace file path for association with Store run metadata. Does not parse trace files — only links file paths to benchmark results.

logger module-attribute ¤

logger = getLogger(__name__)

TraceReference dataclass ¤

TraceReference(*, trace_dir: str, run_id: str | None = None)

Reference to a JAX profiler trace output.

Attributes:

Name Type Description
trace_dir str

Directory where the trace files were written.

run_id str | None

Optional benchmark run ID to link the trace to.

trace_dir instance-attribute ¤
trace_dir: str
run_id class-attribute instance-attribute ¤
run_id: str | None = None
to_dict ¤
to_dict() -> dict[str, Any]

Serialize to a JSON-compatible dictionary.

from_dict classmethod ¤
from_dict(data: dict[str, Any]) -> TraceReference

Deserialize from a dictionary.

Parameters:

Name Type Description Default
data dict[str, Any]

Dictionary with trace reference fields.

required

Returns:

Type Description
TraceReference

Reconstructed TraceReference instance.

TraceLinker ¤

Links JAX profiler traces to benchmark runs.

Usage:

linker = TraceLinker()
with linker.trace("/tmp/my_trace") as ref:
    ...  # run workload here
print(ref.trace_dir)  # "/tmp/my_trace"
trace ¤
trace(log_dir: str | Path, *, run_id: str | None = None, create_perfetto_link: bool = False, create_perfetto_trace: bool = False) -> Any

Start an XLA profiling session and record output metadata.

Wraps jax.profiler.trace() and records the output directory path as a TraceReference for downstream Store linkage.

Parameters:

Name Type Description Default
log_dir str | Path

Directory for trace output files.

required
run_id str | None

Optional benchmark run ID to associate with the trace.

None
create_perfetto_link bool

Whether to create a Perfetto link (passed to JAX).

False
create_perfetto_trace bool

Whether to create a Perfetto trace (passed to JAX).

False

Yields:

Type Description
Any

TraceReference with the trace directory and optional run ID.
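
The path-linking pattern described above can be sketched with a small context manager. This is a minimal illustration under stated assumptions, not the library's implementation: `TraceRef` and `linked_trace` are hypothetical stand-ins for TraceReference and TraceLinker.trace, and the `profiler` argument defaults to a no-op so the example runs without JAX. In real use one would pass `jax.profiler.trace` there.

```python
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from pathlib import Path
from typing import Optional


@dataclass(frozen=True)
class TraceRef:
    """Hypothetical stand-in for TraceReference."""
    trace_dir: str
    run_id: Optional[str] = None


@contextmanager
def linked_trace(log_dir, *, run_id=None, profiler=nullcontext):
    """Yield a reference to the trace directory while the profiler is active.

    `profiler` is any callable that takes the log directory and returns a
    context manager. The default, nullcontext, records nothing.
    """
    ref = TraceRef(trace_dir=str(log_dir), run_id=run_id)
    with profiler(str(log_dir)):
        yield ref


with linked_trace(Path("/tmp/my_trace"), run_id="run-001") as ref:
    pass  # run the workload here

print(ref.trace_dir, ref.run_id)
```

The reference only stores the directory path and run ID; nothing in the sketch opens or parses the trace files.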