Profiler¤
GPU memory profiling, hardware-adaptive optimization, and memory analysis for Datarax pipelines.
See Also¤
- Benchmarking Overview - All benchmarking tools
- Performance Tools - Optimization
- Benchmarking Guide
Overview¤
This module provides three components:
- GPUMemoryProfiler — Detects GPU availability and reports memory usage (used/total/utilization). Also analyzes memory patterns across multiple measurements to detect leaks and high utilization.
- MemoryOptimizer — Analyzes a pipeline function's memory footprint by measuring baseline, peak, and post-GC memory. Returns optimization suggestions.
- AdaptiveOperation — Auto-detects hardware (CPU/GPU/TPU) and configures optimal tile sizes, precision, and batch sizes. Also provides Grain auto-optimization.
Quick Start¤
Check GPU memory¤
from calibrax.profiling import GPUMemoryProfiler
profiler = GPUMemoryProfiler()
usage = profiler.get_memory_usage()
print(f"GPU memory: {usage['gpu_memory_used_mb']:.1f} / {usage['gpu_memory_total_mb']:.1f} MB")
print(f"Utilization: {usage.get('gpu_memory_utilization', 0):.1%}")
Analyze pipeline memory¤
from calibrax.profiling import MemoryOptimizer
optimizer = MemoryOptimizer()
# pipeline_fn and sample_data stand for your own pipeline callable and a representative input.
analysis = optimizer.analyze_pipeline_memory(pipeline_fn, sample_data)
if analysis is not None:  # None is returned when the pipeline raises an exception
    print(f"Peak usage: {analysis.peak_usage_mb:.1f} MB")
    print(f"Memory efficiency: {analysis.memory_efficiency:.1%}")
    for suggestion in analysis.suggestions:
        print(f" - {suggestion}")
calibrax.profiling ¤
Profiling: timing, resources, GPU, energy, FLOPs, hardware, roofline, compilation, complexity.
HARDWARE_SPECS module-attribute ¤
HARDWARE_SPECS: dict[str, dict[str, Any]] = {'tpu_v5e': {'peak_flops': 197000000000000.0, 'peak_flops_bf16': 197000000000000.0, 'memory_bandwidth': 1600000000000.0, 'critical_intensity': 123.125}, 'a100_80g': {'peak_flops': 312000000000000.0, 'peak_flops_bf16': 312000000000000.0, 'memory_bandwidth': 2039000000000.0, 'critical_intensity': 153.0, 'tensor_core_shapes': [(16, 16, 16), (16, 16, 8)]}, 'h100': {'peak_flops': 989000000000000.0, 'peak_flops_bf16': 989000000000000.0, 'memory_bandwidth': 3350000000000.0, 'critical_intensity': 295.0, 'tensor_core_shapes': [(16, 16, 16)]}, 'cpu_generic': {'peak_flops': 2000000000000.0, 'peak_flops_bf16': 2000000000000.0, 'memory_bandwidth': 200000000000.0, 'critical_intensity': 10.0, 'simd_width': 8}}
CompilationProfiler ¤
Analyzes JAX JIT compilation performance and optimization.
Instruments JIT-compiled functions to track compilation cache hits/misses,
compilation times, and input shape consistency. Use profile_jit_compilation
to wrap a function, then call get_result() for aggregated analysis.
profile_jit_compilation ¤
Create an instrumented wrapper that profiles JIT compilation.
The returned callable tracks cache hits/misses, compilation times, and input shape patterns. Results accumulate in this profiler instance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| func | Callable[..., Any] | JAX function to instrument. | required |

Returns:
| Type | Description |
|---|---|
| Callable[..., Any] | Instrumented function with identical signature. |
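One way a wrapper of this kind could track cache behaviour is to key each call by the shapes and dtypes of its arguments and count a first-seen key as a miss. The sketch below is an assumption about the general technique, not the calibrax implementation; `ShapeCacheTracker` and `FakeArray` are hypothetical names, and only the standard library is used.

```python
import functools
from typing import Any, Callable


class ShapeCacheTracker:
    """Hypothetical stand-in: counts first-seen argument signatures as misses."""

    def __init__(self) -> None:
        self.seen: set[tuple] = set()
        self.hits = 0
        self.misses = 0

    @staticmethod
    def _signature(args: tuple) -> tuple:
        # Array-like values are keyed by (shape, dtype); anything else by type name.
        return tuple(
            (tuple(a.shape), str(a.dtype))
            if hasattr(a, "shape") and hasattr(a, "dtype")
            else type(a).__name__
            for a in args
        )

    def wrap(self, func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def wrapper(*args: Any) -> Any:
            sig = self._signature(args)
            if sig in self.seen:
                self.hits += 1
            else:
                self.seen.add(sig)
                self.misses += 1
            return func(*args)

        return wrapper

    @property
    def cache_hit_rate(self) -> float:
        total = self.hits + self.misses
        return self.hits / total if total else 0.0


class FakeArray:
    """Minimal array-like object so the sketch runs without JAX."""

    def __init__(self, shape: tuple[int, ...], dtype: str = "float32") -> None:
        self.shape, self.dtype = shape, dtype


tracker = ShapeCacheTracker()
shape_of = tracker.wrap(lambda x: x.shape)
for shape in [(8, 4), (8, 4), (16, 4), (8, 4)]:
    shape_of(FakeArray(shape))
print(tracker.hits, tracker.misses, tracker.cache_hit_rate)
```

Varying input shapes show up as extra misses, which is the pattern the `unique_signatures` and `cache_hit_rate` fields are meant to surface.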
get_result ¤
get_result() -> CompilationResult
Get aggregated compilation profiling results.
Returns:
| Type | Description |
|---|---|
| CompilationResult | CompilationResult with cache statistics, timing, and recommendations. |
estimate_xla_optimization ¤
estimate_xla_optimization(func: Callable[..., Any], *sample_args: Any) -> XLAOptimizationResult
Estimate XLA optimization effectiveness by analyzing HLO text.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| func | Callable[..., Any] | JAX function to analyze. | required |
| *sample_args | Any | Example arguments for lowering/compiling. | () |

Returns:
| Type | Description |
|---|---|
| XLAOptimizationResult | XLAOptimizationResult with HLO analysis metrics. |
CompilationResult dataclass ¤
CompilationResult(*, cache_hit_rate: float, total_calls: int, cache_hits: int, cache_misses: int, avg_compilation_time_ms: float, max_compilation_time_ms: float, unique_signatures: int, health_score: float, health_level: str, recommendations: tuple[str, ...] = ())
Result of compilation profiling analysis.
Attributes:
| Name | Type | Description |
|---|---|---|
| cache_hit_rate | float | Fraction of calls that hit the compilation cache. |
| total_calls | int | Total number of profiled function calls. |
| cache_hits | int | Number of cache hits. |
| cache_misses | int | Number of cache misses (triggering compilation). |
| avg_compilation_time_ms | float | Average compilation time in milliseconds. |
| max_compilation_time_ms | float | Maximum compilation time in milliseconds. |
| unique_signatures | int | Number of unique function signatures compiled. |
| health_score | float | Overall compilation health score (0-1). |
| health_level | str | Human-readable health level. |
| recommendations | tuple[str, ...] | Optimization recommendations. |
from_dict classmethod ¤
from_dict(data: dict[str, Any]) -> CompilationResult
Deserialize from a dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Dictionary with compilation result fields. | required |

Returns:
| Type | Description |
|---|---|
| CompilationResult | Reconstructed CompilationResult instance. |
XLAOptimizationResult dataclass ¤
XLAOptimizationResult(*, optimization_score: float, fusion_ratio: float, arithmetic_ratio: float, memory_ratio: float, total_kernels: int, recommendations: tuple[str, ...] = ())
Result of XLA optimization effectiveness analysis.
Attributes:
| Name | Type | Description |
|---|---|---|
| optimization_score | float | Overall optimization score (0-1). |
| fusion_ratio | float | Fraction of fused kernels. |
| arithmetic_ratio | float | Fraction of arithmetic operations. |
| memory_ratio | float | Fraction of memory operations. |
| total_kernels | int | Total number of HLO kernels. |
| recommendations | tuple[str, ...] | Optimization recommendations. |
from_dict classmethod ¤
from_dict(data: dict[str, Any]) -> XLAOptimizationResult
Deserialize from a dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Dictionary with XLA optimization result fields. | required |

Returns:
| Type | Description |
|---|---|
| XLAOptimizationResult | Reconstructed XLAOptimizationResult instance. |
ComplexityResult dataclass ¤
ComplexityResult(*, total_parameters: int, parameter_memory_mb: float, largest_layer_name: str, largest_layer_params: int, input_shape: tuple[int, ...], estimated_memory_mb: float, total_estimated_operations: int, dominant_complexity: str, scaling_characteristics: dict[str, str] = dict())
Result of model complexity analysis.
Attributes:
| Name | Type | Description |
|---|---|---|
| total_parameters | int | Total number of trainable parameters. |
| parameter_memory_mb | float | Memory consumed by parameters (float32). |
| largest_layer_name | str | Name of the layer with the most parameters. |
| largest_layer_params | int | Parameter count of the largest layer. |
| input_shape | tuple[int, ...] | Shape of the analyzed input. |
| estimated_memory_mb | float | Estimated total memory (params + activations). |
| total_estimated_operations | int | Estimated total operations count. |
| dominant_complexity | str | Name of the dominant operation type. |
| scaling_characteristics | dict[str, str] | Mapping of operation type to complexity class. |
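Since `parameter_memory_mb` is stated for float32, it can be approximated as four bytes per parameter. The sketch below assumes 1 MB = 1024² bytes (the library may use 10⁶ instead) and uses hypothetical layer sizes for a small MLP.

```python
BYTES_PER_FLOAT32 = 4
BYTES_PER_MB = 1024 ** 2  # assumption; the library may use 10**6 instead


def parameter_memory_mb(total_parameters: int) -> float:
    """Approximate float32 parameter memory in MB."""
    return total_parameters * BYTES_PER_FLOAT32 / BYTES_PER_MB


# Hypothetical per-layer parameter counts (weights + biases).
layer_params = {"dense_1": 784 * 256 + 256, "dense_2": 256 * 10 + 10}
total = sum(layer_params.values())
largest = max(layer_params, key=layer_params.get)
print(total, largest, round(parameter_memory_mb(total), 3))
```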
scaling_characteristics class-attribute instance-attribute ¤
from_dict classmethod ¤
from_dict(data: dict[str, Any]) -> ComplexityResult
Deserialize from a dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Dictionary with complexity result fields. | required |

Returns:
| Type | Description |
|---|---|
| ComplexityResult | Reconstructed ComplexityResult instance. |
EnergyMonitor ¤
EnergyMonitor(sample_interval_sec: float = 0.1)
Background energy monitoring via NVML and RAPL.
Uses daemon thread sampling at configurable interval. Gracefully degrades when NVML or RAPL is unavailable.
Usage:
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sample_interval_sec | float | Seconds between energy samples. | 0.1 |

summary property ¤
summary: EnergySummary
Compute aggregated energy summary.
Returns:
| Type | Description |
|---|---|
| EnergySummary | EnergySummary with totals, or None fields when unavailable. |
EnergySample dataclass ¤
EnergySample(*, timestamp: float, gpu_power_watts: float | None, cpu_energy_joules: float | None, gpu_energy_joules: float | None)
Single energy measurement at a point in time.
Attributes:
| Name | Type | Description |
|---|---|---|
| timestamp | float | Time of measurement (perf_counter). |
| gpu_power_watts | float \| None | Instantaneous GPU power (None if unavailable). |
| cpu_energy_joules | float \| None | Cumulative CPU energy since monitoring start. |
| gpu_energy_joules | float \| None | Cumulative GPU energy since monitoring start. |
EnergySummary dataclass ¤
EnergySummary(*, total_gpu_energy_joules: float | None, total_cpu_energy_joules: float | None, total_combined_energy_joules: float | None, mean_gpu_power_watts: float | None, peak_gpu_power_watts: float | None, duration_sec: float, num_samples: int)
Aggregated energy usage over a monitoring period.
Attributes:
| Name | Type | Description |
|---|---|---|
| total_gpu_energy_joules | float \| None | Total GPU energy consumed. |
| total_cpu_energy_joules | float \| None | Total CPU energy consumed. |
| total_combined_energy_joules | float \| None | GPU + CPU combined. |
| mean_gpu_power_watts | float \| None | Average GPU power draw. |
| peak_gpu_power_watts | float \| None | Maximum GPU power draw. |
| duration_sec | float | Monitoring duration. |
| num_samples | int | Total samples collected. |
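When only instantaneous power readings are available, total energy can be approximated by integrating power over the sample timestamps. The sketch below uses the trapezoidal rule on hypothetical samples; `integrate_power` is an illustrative helper, and EnergyMonitor may derive its totals differently (for example from cumulative hardware counters).

```python
def integrate_power(samples: list[tuple[float, float]]) -> float:
    """Trapezoidal energy estimate in joules from (timestamp_sec, power_watts) pairs."""
    energy = 0.0
    for (t0, p0), (t1, p1) in zip(samples, samples[1:]):
        energy += 0.5 * (p0 + p1) * (t1 - t0)
    return energy


# Hypothetical GPU power samples taken every 0.1 s.
samples = [(0.0, 100.0), (0.1, 120.0), (0.2, 140.0), (0.3, 120.0)]
total_joules = integrate_power(samples)
mean_watts = sum(p for _, p in samples) / len(samples)
peak_watts = max(p for _, p in samples)
duration_sec = samples[-1][0] - samples[0][0]
print(f"{total_joules:.1f} J over {duration_sec:.1f} s, mean {mean_watts:.0f} W, peak {peak_watts:.0f} W")
```

A shorter sampling interval reduces the integration error at the cost of more sampling overhead.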
FlopsCounter ¤
Count FLOPs of JAX functions via jaxpr analysis.
Uses jax.make_jaxpr to trace the function and counts FLOPs
for each primitive based on operation-specific rules.
For NNX models that use stochastic operations (dropout, etc.),
use flax.nnx.tabulate(model, *args, compute_flops=True)
instead — it handles NNX state management internally.
count ¤
Count FLOPs for a function with given example arguments.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable[..., Any] | JAX function to analyze. | required |
| *args | Any | Example arguments for tracing. | () |
| static_argnums | tuple[int, ...] | Argument indices to treat as static. | () |

Returns:
| Type | Description |
|---|---|
| FlopsResult | FlopsResult with FLOP count and breakdown. |
FlopsResult dataclass ¤
FlopsResult(*, total_flops: int, flops_by_operation: dict[str, int], num_operations: int, function_name: str)
Result of FLOP counting for a function.
Attributes:
| Name | Type | Description |
|---|---|---|
| total_flops | int | Total estimated FLOPs. |
| flops_by_operation | dict[str, int] | Breakdown by primitive operation name. |
| num_operations | int | Number of JAX primitives in the trace. |
| function_name | str | Name of the analyzed function. |
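To give a feel for what a per-primitive breakdown looks like, the sketch below applies two common counting conventions by hand: 2·M·K·N FLOPs for a matrix multiply and one FLOP per output element for an elementwise add. These rules are assumptions for illustration; the rules FlopsCounter applies to each primitive are not shown in this reference, and the layer sizes are hypothetical.

```python
def matmul_flops(m: int, k: int, n: int) -> int:
    """Multiply-add pairs counted as two FLOPs each for an (m, k) @ (k, n) product."""
    return 2 * m * k * n


def elementwise_flops(*shape: int) -> int:
    """One FLOP per output element."""
    total = 1
    for dim in shape:
        total *= dim
    return total


# Hypothetical dense layer: batch 32, 784 inputs, 256 outputs, plus bias add.
flops_by_operation = {
    "dot_general": matmul_flops(32, 784, 256),
    "add": elementwise_flops(32, 256),
}
total_flops = sum(flops_by_operation.values())
print(flops_by_operation, total_flops)
```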
AdaptiveOperation ¤
Hardware-adaptive operations with auto-detection.
Detects the current JAX backend (CPU/GPU/TPU) and provides optimized configuration and shape padding.
optimize_shapes ¤
GPUMemoryProfiler ¤
GPU memory profiling satisfying GPUProfilerProtocol.
Uses multi-fallback strategy: memory_stats -> xla_bridge -> zeros.
HardwareConfig dataclass ¤
HardwareConfig(*, platform: str, precision: str, tile_size: int, critical_batch_size: int, memory_layout: str, use_vmem_optimization: bool)
Hardware-specific optimization configuration.
Attributes:
| Name | Type | Description |
|---|---|---|
| platform | str | Detected platform ("cpu", "tpu", "gpu_modern", "gpu_legacy"). |
| precision | str | Recommended floating-point precision string. |
| tile_size | int | Tile size for matrix operation alignment. |
| critical_batch_size | int | Optimal batch size for the platform. |
| memory_layout | str | Memory layout preference. |
| use_vmem_optimization | bool | Whether VMEM optimization is available. |
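Tile alignment usually means rounding each dimension up to the next multiple of the tile size. The helpers below are a hypothetical sketch of that padding rule; they are not the `optimize_shapes` implementation, whose signature is not shown in this reference.

```python
def pad_to_tile(dim: int, tile_size: int) -> int:
    """Round dim up to the nearest multiple of tile_size."""
    if tile_size <= 0:
        raise ValueError("tile_size must be positive")
    return -(-dim // tile_size) * tile_size  # ceiling division, then scale back up


def padded_shape(shape: tuple[int, ...], tile_size: int) -> tuple[int, ...]:
    """Apply pad_to_tile to every dimension of a shape."""
    return tuple(pad_to_tile(d, tile_size) for d in shape)


print(padded_shape((100, 300), 128))
```

Dimensions that are already aligned are left unchanged, so the padding cost is paid only for irregular shapes.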
MemoryAnalysis dataclass ¤
MemoryAnalysis(*, baseline_memory_mb: float, peak_memory_mb: float, peak_usage_mb: float, retained_memory_mb: float, memory_efficiency: float, suggestions: tuple[str, ...] = ())
Result of pipeline memory analysis.
Attributes:
| Name | Type | Description |
|---|---|---|
| baseline_memory_mb | float | Memory usage before pipeline execution. |
| peak_memory_mb | float | Memory usage at peak during execution. |
| peak_usage_mb | float | Peak usage above baseline. |
| retained_memory_mb | float | Memory retained after GC. |
| memory_efficiency | float | Ratio of freed memory to peak usage. |
| suggestions | tuple[str, ...] | Optimization suggestions. |
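The fields above can be read as derived from three measurements: baseline, peak, and post-GC memory. The sketch below shows one plausible set of relationships and measures Python-heap usage with `tracemalloc`; both helper functions are hypothetical, the formulas are assumptions, and the measurement covers host memory only, not device memory.

```python
import gc
import tracemalloc
from typing import Any, Callable

MB = 1024 ** 2


def derive_metrics(baseline_mb: float, peak_mb: float, after_gc_mb: float) -> dict[str, float]:
    """Assumed relationships between the MemoryAnalysis fields."""
    peak_usage = peak_mb - baseline_mb
    retained = max(after_gc_mb - baseline_mb, 0.0)
    freed = peak_usage - retained
    efficiency = freed / peak_usage if peak_usage > 0 else 1.0
    return {
        "peak_usage_mb": peak_usage,
        "retained_memory_mb": retained,
        "memory_efficiency": efficiency,
    }


def measure_host_memory(fn: Callable[[Any], Any], data: Any) -> dict[str, float]:
    """Measure Python-heap usage around fn(data) with tracemalloc."""
    tracemalloc.start()
    try:
        baseline, _ = tracemalloc.get_traced_memory()
        result = fn(data)
        _, peak = tracemalloc.get_traced_memory()
        del result  # drop the output so GC can reclaim it
        gc.collect()
        after_gc, _ = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return derive_metrics(baseline / MB, peak / MB, after_gc / MB)


metrics = measure_host_memory(lambda n: [0.0] * n, 1_000_000)
print(metrics)
```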
MemoryOptimizer ¤
Memory optimization analysis for pipeline functions.
analyze_pipeline_memory ¤
analyze_pipeline_memory(pipeline_fn: Callable[[Any], Any], sample_data: Any) -> MemoryAnalysis | None
Analyze memory usage of a pipeline function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pipeline_fn | Callable[[Any], Any] | Function to analyze. | required |
| sample_data | Any | Sample input data. | required |

Returns:
| Type | Description |
|---|---|
| MemoryAnalysis \| None | MemoryAnalysis with measurements and suggestions, or None if the pipeline raises an exception. |
GPUProfilerProtocol ¤
ResourceMonitor ¤
ResourceMonitor(sample_interval_sec: float = 0.1, gpu_profiler: GPUProfilerProtocol | None = None)
Background 10Hz resource sampling via context manager.
Usage:
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sample_interval_sec | float | Seconds between samples (default 0.1 = 10Hz). | 0.1 |
| gpu_profiler | GPUProfilerProtocol \| None | Optional profiler satisfying GPUProfilerProtocol. | None |
summary property ¤
summary: ResourceSummary
Compute aggregated summary from collected samples.
Returns:
| Type | Description |
|---|---|
| ResourceSummary | ResourceSummary with aggregated metrics, or zeroed summary if no samples were collected. |
ResourceSample dataclass ¤
ResourceSample(*, timestamp: float, cpu_percent: float, rss_mb: float, gpu_util: float | None, gpu_mem_mb: float | None, gpu_clock_mhz: float | None = None, gpu_power_w: float | None = None)
Single resource measurement at a point in time.
Attributes:
| Name | Type | Description |
|---|---|---|
| timestamp | float | Time of measurement (perf_counter). |
| cpu_percent | float | CPU utilization percentage. |
| rss_mb | float | Resident set size in MB. |
| gpu_util | float \| None | GPU utilization percentage (None if no GPU). |
| gpu_mem_mb | float \| None | GPU memory used in MB (None if no GPU). |
ResourceSummary dataclass ¤
ResourceSummary(*, peak_rss_mb: float, mean_rss_mb: float, peak_gpu_mem_mb: float | None, mean_gpu_util: float | None, memory_growth_mb: float, num_samples: int, duration_sec: float, mean_gpu_clock_mhz: float | None = None, mean_gpu_power_w: float | None = None)
Aggregated resource usage over a monitoring period.
Attributes:
| Name | Type | Description |
|---|---|---|
| peak_rss_mb | float | Maximum RSS observed. |
| mean_rss_mb | float | Average RSS across all samples. |
| peak_gpu_mem_mb | float \| None | Maximum GPU memory (None if no GPU). |
| mean_gpu_util | float \| None | Average GPU utilization (None if no GPU). |
| memory_growth_mb | float | Last RSS minus first RSS (positive = growth). |
| num_samples | int | Total samples collected. |
| duration_sec | float | Time span of monitoring. |
to_dict ¤
from_dict classmethod ¤
from_dict(data: dict[str, Any]) -> ResourceSummary
Deserialize from a dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Dictionary with resource summary fields. | required |

Returns:
| Type | Description |
|---|---|
| ResourceSummary | Reconstructed ResourceSummary instance. |
RooflineAnalyzer dataclass ¤
RooflineAnalyzer(*, hardware_specs: dict[str, Any] = detect_hardware_specs())
Analyzes operation performance against hardware roofline limits.
Uses measured execution time and estimated FLOPs to determine whether an operation is compute-bound or memory-bound, and how efficiently it uses the available hardware resources.
Attributes:
| Name | Type | Description |
|---|---|---|
| hardware_specs | dict[str, Any] | Hardware specification dictionary (auto-detected if not provided). |

hardware_specs class-attribute instance-attribute ¤
hardware_specs: dict[str, Any] = field(default_factory=detect_hardware_specs)
analyze_operation ¤
analyze_operation(func: Callable[..., Any], inputs: list[Array], *, flops_override: int | None = None) -> RooflineResult
Perform roofline analysis on a JAX operation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| func | Callable[..., Any] | JAX function to analyze. | required |
| inputs | list[Array] | Input arrays for the function. | required |
| flops_override | int \| None | If provided, use this FLOP count instead of estimating. For accurate results, pass the output of | None |

Returns:
| Type | Description |
|---|---|
| RooflineResult | RooflineResult with bottleneck classification and recommendations. |
RooflineResult dataclass ¤
RooflineResult(*, arithmetic_intensity: float, critical_intensity: float, memory_bandwidth_utilization: float, flops_utilization: float, bottleneck: str, efficiency: float, execution_time_ms: float, recommendations: tuple[str, ...] = ())
Result of a roofline analysis on a JAX operation.
Attributes:
| Name | Type | Description |
|---|---|---|
| arithmetic_intensity | float | Achieved FLOPs per byte of memory traffic. |
| critical_intensity | float | Hardware's ridge point (FLOPs/byte). |
| memory_bandwidth_utilization | float | Fraction of peak memory bandwidth used. |
| flops_utilization | float | Fraction of peak FLOPs achieved. |
| bottleneck | str | Either "memory_bandwidth" or "compute". |
| efficiency | float | Utilization of the binding resource. |
| execution_time_ms | float | Measured execution time in milliseconds. |
| recommendations | tuple[str, ...] | Optimization suggestions. |
from_dict classmethod ¤
from_dict(data: dict[str, Any]) -> RooflineResult
Deserialize from a dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Dictionary with roofline result fields. | required |

Returns:
| Type | Description |
|---|---|
| RooflineResult | Reconstructed RooflineResult instance. |
TimingCollector ¤
Framework-agnostic timing with configurable GPU sync support.
Uses time.perf_counter() exclusively for accurate benchmarking. Supports configurable result synchronization via sync_fn and warm-up iteration exclusion for JIT-compiled workloads.
JAX dispatches operations asynchronously -- the host returns
immediately while the device is still computing. Without an explicit
synchronization barrier, perf_counter measures only host-side
dispatch latency, not actual compute time. Pass a sync_fn that
calls block_until_ready() on the workload result to force the host
to wait for device completion before recording the timestamp.
Example -- JAX GPU timing with warm-up:
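import jax
from calibrax.profiling import TimingCollector

# step_fn and data_iter stand for your own step function and batch iterator.
jitted_step = jax.jit(step_fn)  # wrap once, outside the timed loop

def run_step(batch):
    return jitted_step(batch)

collector = TimingCollector(
    sync_fn=lambda result: result.block_until_ready(),
    warmup_iterations=2,
)
sample = collector.measure_iteration(data_iter, num_batches=50, process_fn=run_step)
# sample.per_batch_times excludes the first 2 batches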
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sync_fn | Callable[[Any], object] \| None | Synchronization function called with each batch result. For JAX, pass a function that calls block_until_ready() on the result. | None |
| warmup_iterations | int | Number of initial batches to exclude from per_batch_times statistics. They are still executed (important for JIT warm-up) but omitted from the timing result. | 0 |
measure_iteration ¤
measure_iteration(iterator: Iterator[Any], num_batches: int | None = None, process_fn: Callable[[Any], Any] | None = None, count_fn: Callable[[Any], int] | None = None) -> TimingSample
Measure timing for batches from an iterator.
Warm-up batches (if configured) are executed but excluded from
per_batch_times. wall_clock_sec covers the entire run
including warm-up. num_batches reflects total batches consumed.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| iterator | Iterator[Any] | Data iterator to measure. | required |
| num_batches | int \| None | Max batches to consume (including warmup). None exhausts iterator. | None |
| process_fn | Callable[[Any], Any] \| None | Optional per-batch function whose execution is timed. Defaults to identity (the yielded batch is treated as result). | None |
| count_fn | Callable[[Any], int] \| None | Function to count elements per batch. Default: 1 per batch. | None |

Returns:
| Type | Description |
|---|---|
| TimingSample | TimingSample with timing measurements. |
measure_compilation_time ¤
Measure JIT compilation time for a JAX function.
Calls jax.jit(fn).lower(*args).compile() and times it.
This measures the XLA compilation step only, not execution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable[..., Any] | JAX function to compile. | required |
| *args | Any | Example arguments for lowering. | () |

Returns:
| Type | Description |
|---|---|
| float | Compilation time in seconds. |
TimingSample dataclass ¤
TimingSample(*, wall_clock_sec: float, per_batch_times: tuple[float, ...], first_batch_time: float, num_batches: int, num_elements: int, compilation_time_sec: float | None = None, warmup_batches_excluded: int = 0)
Result of timing an iteration through a data pipeline.
Attributes:
| Name | Type | Description |
|---|---|---|
| wall_clock_sec | float | Total wall-clock time for the iteration. |
| per_batch_times | tuple[float, ...] | Per-batch durations in seconds (warmup batches excluded). |
| first_batch_time | float | Time from iteration start to first batch completion. |
| num_batches | int | Number of batches consumed (including warmup). |
| num_elements | int | Total elements processed (via count_fn). |
| compilation_time_sec | float \| None | JIT compilation time, if measured separately. |
| warmup_batches_excluded | int | Number of leading batches excluded from per_batch_times. |
from_dict classmethod ¤
from_dict(data: dict[str, Any]) -> TimingSample
Deserialize from a dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Dictionary with TimingSample fields. | required |

Returns:
| Type | Description |
|---|---|
| TimingSample | Reconstructed TimingSample instance. |
TraceLinker ¤
Links JAX profiler traces to benchmark runs.
Usage:
from calibrax.profiling import TraceLinker

linker = TraceLinker()
with linker.trace("/tmp/my_trace") as ref:
    # ... run workload ...
    print(ref.trace_dir)  # "/tmp/my_trace"
trace ¤
trace(log_dir: str | Path, *, run_id: str | None = None, create_perfetto_link: bool = False, create_perfetto_trace: bool = False) -> Any
Start an XLA profiling session and record output metadata.
Wraps jax.profiler.trace() and records the output directory
path as a TraceReference for downstream Store linkage.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| log_dir | str \| Path | Directory for trace output files. | required |
| run_id | str \| None | Optional benchmark run ID to associate with the trace. | None |
| create_perfetto_link | bool | Whether to create a Perfetto link (passed to JAX). | False |
| create_perfetto_trace | bool | Whether to create a Perfetto trace (passed to JAX). | False |

Yields:
| Type | Description |
|---|---|
| Any | TraceReference with the trace directory and optional run ID. |
TraceReference dataclass ¤
Reference to a JAX profiler trace output.
Attributes:
| Name | Type | Description |
|---|---|---|
| trace_dir | str | Directory where the trace files were written. |
| run_id | str \| None | Optional benchmark run ID to link the trace to. |
from_dict classmethod ¤
from_dict(data: dict[str, Any]) -> TraceReference
Deserialize from a dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Dictionary with trace reference fields. | required |

Returns:
| Type | Description |
|---|---|
| TraceReference | Reconstructed TraceReference instance. |
analyze_complexity ¤
analyze_complexity(model: Module, input_shape: tuple[int, ...]) -> ComplexityResult
Analyze complexity of a Flax NNX module.
Performs parameter counting, memory estimation, computational complexity analysis, and scaling characterization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | Flax NNX model to analyze. | required |
| input_shape | tuple[int, ...] | Shape of input data (including batch dimension). | required |

Returns:
| Type | Description |
|---|---|
| ComplexityResult | ComplexityResult with detailed complexity metrics. |
detect_hardware_specs ¤
Detect current hardware and return appropriate specifications.
Uses jax.default_backend() to determine the accelerator type
and returns pre-configured specs for that platform.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Hardware specification dictionary with peak_flops, memory_bandwidth, and critical_intensity keys (among others). |
measure_execution_time ¤
measure_execution_time(func: Callable[..., Any], inputs: list[Array], warmup: int = 3, iterations: int = 10) -> float
Measure execution time of a JAX function with synchronization.
JIT-compiles the function, runs warmup iterations, then times
iterations executions with block_until_ready() barriers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| func | Callable[..., Any] | JAX function to benchmark. | required |
| inputs | list[Array] | Input arguments as a list of arrays. | required |
| warmup | int | Number of warmup iterations (for JIT compilation). | 3 |
| iterations | int | Number of timed iterations. | 10 |

Returns:
| Type | Description |
|---|---|
| float | Average execution time in seconds. |
carbon ¤
Carbon emissions tracking via CodeCarbon integration.
Wraps the codecarbon.EmissionsTracker as a context manager, exposing
emissions data as a frozen CarbonResult dataclass. Requires the
optional codecarbon dependency (uv pip install "calibrax[codecarbon]").
CarbonResult dataclass ¤
CarbonResult(*, emissions_kg_co2: float, energy_consumed_kwh: float, duration_sec: float, country_iso_code: str | None = None)
Result of carbon emissions measurement.
Attributes:
| Name | Type | Description |
|---|---|---|
| emissions_kg_co2 | float | Total CO2 emissions in kilograms. |
| energy_consumed_kwh | float | Total energy consumed in kilowatt-hours. |
| duration_sec | float | Duration of the tracked period in seconds. |
| country_iso_code | str \| None | ISO code of the country used for carbon intensity. |
from_dict classmethod ¤
from_dict(data: dict[str, Any]) -> CarbonResult
Deserialize from a dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Dictionary with carbon result fields. | required |

Returns:
| Type | Description |
|---|---|
| CarbonResult | Reconstructed CarbonResult instance. |
CarbonTracker ¤
Context manager for tracking carbon emissions via CodeCarbon.
Requires the codecarbon package. Install with: uv pip install "calibrax[codecarbon]"
Usage:
with CarbonTracker() as tracker:
    ...  # run workload
result = tracker.result()  # call after the context manager exits
print(f"Emissions: {result.emissions_kg_co2:.6f} kg CO2")
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| country_iso_code | str \| None | Optional ISO country code for regional carbon intensity. | None |
| log_level | str | Logging level for CodeCarbon (default: "warning"). | 'warning' |

Raises:
| Type | Description |
|---|---|
| ImportError | If codecarbon is not installed. |
result ¤
result() -> CarbonResult
Get the carbon emissions result.
Call this after exiting the context manager.
Returns:
| Type | Description |
|---|---|
| CarbonResult | CarbonResult with emissions, energy, and duration data. |
compilation ¤
JIT compilation profiler for JAX.
Analyzes JIT compilation efficiency, cache hit rates, XLA optimization effectiveness, and provides recommendations for compilation optimization.
CompilationResult dataclass ¤
CompilationResult(*, cache_hit_rate: float, total_calls: int, cache_hits: int, cache_misses: int, avg_compilation_time_ms: float, max_compilation_time_ms: float, unique_signatures: int, health_score: float, health_level: str, recommendations: tuple[str, ...] = ())
Result of compilation profiling analysis.
Attributes:
| Name | Type | Description |
|---|---|---|
| cache_hit_rate | float | Fraction of calls that hit the compilation cache. |
| total_calls | int | Total number of profiled function calls. |
| cache_hits | int | Number of cache hits. |
| cache_misses | int | Number of cache misses (triggering compilation). |
| avg_compilation_time_ms | float | Average compilation time in milliseconds. |
| max_compilation_time_ms | float | Maximum compilation time in milliseconds. |
| unique_signatures | int | Number of unique function signatures compiled. |
| health_score | float | Overall compilation health score (0-1). |
| health_level | str | Human-readable health level. |
| recommendations | tuple[str, ...] | Optimization recommendations. |
from_dict classmethod ¤
from_dict(data: dict[str, Any]) -> CompilationResult
Deserialize from a dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Dictionary with compilation result fields. | required |

Returns:
| Type | Description |
|---|---|
| CompilationResult | Reconstructed CompilationResult instance. |
XLAOptimizationResult dataclass ¤
XLAOptimizationResult(*, optimization_score: float, fusion_ratio: float, arithmetic_ratio: float, memory_ratio: float, total_kernels: int, recommendations: tuple[str, ...] = ())
Result of XLA optimization effectiveness analysis.
Attributes:
| Name | Type | Description |
|---|---|---|
| optimization_score | float | Overall optimization score (0-1). |
| fusion_ratio | float | Fraction of fused kernels. |
| arithmetic_ratio | float | Fraction of arithmetic operations. |
| memory_ratio | float | Fraction of memory operations. |
| total_kernels | int | Total number of HLO kernels. |
| recommendations | tuple[str, ...] | Optimization recommendations. |
from_dict classmethod ¤
from_dict(data: dict[str, Any]) -> XLAOptimizationResult
Deserialize from a dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Dictionary with XLA optimization result fields. | required |

Returns:
| Type | Description |
|---|---|
| XLAOptimizationResult | Reconstructed XLAOptimizationResult instance. |
CompilationProfiler ¤
Analyzes JAX JIT compilation performance and optimization.
Instruments JIT-compiled functions to track compilation cache hits/misses,
compilation times, and input shape consistency. Use profile_jit_compilation
to wrap a function, then call get_result() for aggregated analysis.
profile_jit_compilation ¤
Create an instrumented wrapper that profiles JIT compilation.
The returned callable tracks cache hits/misses, compilation times, and input shape patterns. Results accumulate in this profiler instance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| func | Callable[..., Any] | JAX function to instrument. | required |
Returns:
| Type | Description |
|---|---|
| Callable[..., Any] | Instrumented function with identical signature. |
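The instrumentation idea can be sketched without JAX. The wrapper below is a simplified assumption about how such tracking might work, not the CompilationProfiler implementation: it treats each new combination of argument shapes as a cache miss (a likely recompilation) and each repeat as a hit. The names ShapeCacheTracker, shape_signature, and FakeArray are hypothetical.

```python
import functools
from typing import Any, Callable


def shape_signature(args: tuple[Any, ...]) -> tuple[Any, ...]:
    """Describe arguments by shape when available, else by type name."""
    return tuple(getattr(a, "shape", type(a).__name__) for a in args)


class ShapeCacheTracker:
    """Hypothetical tracker: counts new vs. repeated input shape signatures."""

    def __init__(self) -> None:
        self.seen: set[tuple[Any, ...]] = set()
        self.hits = 0
        self.misses = 0

    def wrap(self, func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)  # keep the wrapped function's name and docstring
        def wrapper(*args: Any) -> Any:
            signature = shape_signature(args)
            if signature in self.seen:
                self.hits += 1
            else:
                self.misses += 1
                self.seen.add(signature)
            return func(*args)

        return wrapper


class FakeArray:
    """Minimal stand-in for an array with a shape attribute."""

    def __init__(self, shape: tuple[int, ...]) -> None:
        self.shape = shape


tracker = ShapeCacheTracker()
step = tracker.wrap(lambda x: x.shape)
step(FakeArray((32, 128)))  # miss: first time this shape is seen
step(FakeArray((32, 128)))  # hit: same shape again
step(FakeArray((17, 128)))  # miss: a new shape would trigger recompilation
```

A high miss count relative to hits is the kind of signal that would point to inconsistent input shapes, for example a ragged final batch.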
get_result ¤
get_result() -> CompilationResult
Get aggregated compilation profiling results.
Returns:
| Type | Description |
|---|---|
| CompilationResult | CompilationResult with cache statistics, timing, and recommendations. |
estimate_xla_optimization ¤
estimate_xla_optimization(func: Callable[..., Any], *sample_args: Any) -> XLAOptimizationResult
Estimate XLA optimization effectiveness by analyzing HLO text.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| func | Callable[..., Any] | JAX function to analyze. | required |
| *sample_args | Any | Example arguments for lowering/compiling. | () |
Returns:
| Type | Description |
|---|---|
| XLAOptimizationResult | XLAOptimizationResult with HLO analysis metrics. |
complexity ¤
Model complexity analysis for Flax NNX modules.
Provides parameter counts, memory usage estimates, computational complexity analysis, and scaling characteristics for any NNX module.
ComplexityResult
dataclass
¤
ComplexityResult(*, total_parameters: int, parameter_memory_mb: float, largest_layer_name: str, largest_layer_params: int, input_shape: tuple[int, ...], estimated_memory_mb: float, total_estimated_operations: int, dominant_complexity: str, scaling_characteristics: dict[str, str] = dict())
Result of model complexity analysis.
Attributes:
| Name | Type | Description |
|---|---|---|
| total_parameters | int | Total number of trainable parameters. |
| parameter_memory_mb | float | Memory consumed by parameters (float32). |
| largest_layer_name | str | Name of the layer with the most parameters. |
| largest_layer_params | int | Parameter count of the largest layer. |
| input_shape | tuple[int, ...] | Shape of the analyzed input. |
| estimated_memory_mb | float | Estimated total memory (params + activations). |
| total_estimated_operations | int | Estimated total operations count. |
| dominant_complexity | str | Name of the dominant operation type. |
| scaling_characteristics | dict[str, str] | Mapping of operation type to complexity class. |
scaling_characteristics
class-attribute
instance-attribute
¤
from_dict
classmethod
¤
from_dict(data: dict[str, Any]) -> ComplexityResult
Deserialize from a dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Dictionary with complexity result fields. | required |
Returns:
| Type | Description |
|---|---|
| ComplexityResult | Reconstructed ComplexityResult instance. |
analyze_complexity ¤
analyze_complexity(model: Module, input_shape: tuple[int, ...]) -> ComplexityResult
Analyze complexity of a Flax NNX module.
Performs parameter counting, memory estimation, computational complexity analysis, and scaling characterization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | Flax NNX model to analyze. | required |
| input_shape | tuple[int, ...] | Shape of input data (including batch dimension). | required |
Returns:
| Type | Description |
|---|---|
| ComplexityResult | ComplexityResult with detailed complexity metrics. |
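For orientation, the float32 parameter-memory figure can be reproduced by hand. The sketch below assumes 4 bytes per parameter and 1 MB = 1024 * 1024 bytes; the library's exact MB convention is not stated above, so treat the divisor as an assumption. The layer names and shapes are hypothetical.

```python
import math

# Hypothetical layer name -> parameter shapes (kernel and bias).
layer_shapes = {
    "dense_1": [(784, 256), (256,)],
    "dense_2": [(256, 10), (10,)],
}

BYTES_PER_FLOAT32 = 4
BYTES_PER_MB = 1024 * 1024  # assumption: binary megabytes

# Parameter count per layer is the sum of the element counts of its arrays.
params_per_layer = {
    name: sum(math.prod(shape) for shape in shapes)
    for name, shapes in layer_shapes.items()
}
total_parameters = sum(params_per_layer.values())
parameter_memory_mb = total_parameters * BYTES_PER_FLOAT32 / BYTES_PER_MB
largest_layer_name = max(params_per_layer, key=params_per_layer.get)
largest_layer_params = params_per_layer[largest_layer_name]

print(total_parameters, round(parameter_memory_mb, 3), largest_layer_name)
```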
energy ¤
Energy monitoring via NVML (GPU) and RAPL (CPU).
Provides EnergyMonitor context manager for tracking power consumption during benchmark execution. Gracefully degrades when hardware interfaces are unavailable.
EnergySample
dataclass
¤
EnergySample(*, timestamp: float, gpu_power_watts: float | None, cpu_energy_joules: float | None, gpu_energy_joules: float | None)
Single energy measurement at a point in time.
Attributes:
| Name | Type | Description |
|---|---|---|
| timestamp | float | Time of measurement (perf_counter). |
| gpu_power_watts | float \| None | Instantaneous GPU power (None if unavailable). |
| cpu_energy_joules | float \| None | Cumulative CPU energy since monitoring start. |
| gpu_energy_joules | float \| None | Cumulative GPU energy since monitoring start. |
EnergySummary
dataclass
¤
EnergySummary(*, total_gpu_energy_joules: float | None, total_cpu_energy_joules: float | None, total_combined_energy_joules: float | None, mean_gpu_power_watts: float | None, peak_gpu_power_watts: float | None, duration_sec: float, num_samples: int)
Aggregated energy usage over a monitoring period.
Attributes:
| Name | Type | Description |
|---|---|---|
| total_gpu_energy_joules | float \| None | Total GPU energy consumed. |
| total_cpu_energy_joules | float \| None | Total CPU energy consumed. |
| total_combined_energy_joules | float \| None | GPU + CPU combined. |
| mean_gpu_power_watts | float \| None | Average GPU power draw. |
| peak_gpu_power_watts | float \| None | Maximum GPU power draw. |
| duration_sec | float | Monitoring duration. |
| num_samples | int | Total samples collected. |
EnergyMonitor ¤
EnergyMonitor(sample_interval_sec: float = 0.1)
Background energy monitoring via NVML and RAPL.
Uses daemon thread sampling at configurable interval. Gracefully degrades when NVML or RAPL is unavailable.
Usage:
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sample_interval_sec | float | Seconds between energy samples. | 0.1 |
summary
property
¤
summary: EnergySummary
Compute aggregated energy summary.
Returns:
| Type | Description |
|---|---|
| EnergySummary | EnergySummary with totals; fields are None when the corresponding source (NVML or RAPL) is unavailable. |
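One plausible way to turn power samples into the totals reported here is trapezoidal integration of power over time. The sketch below is an assumption about the aggregation, not the EnergyMonitor implementation, and the sample values are made up.

```python
# (timestamp in seconds, GPU power in watts): made-up samples at roughly 10 Hz.
samples = [(0.0, 100.0), (0.1, 120.0), (0.2, 140.0), (0.3, 120.0)]

# Trapezoidal rule: energy (J) = sum over intervals of mean power * time step.
total_gpu_energy_joules = sum(
    0.5 * (p0 + p1) * (t1 - t0)
    for (t0, p0), (t1, p1) in zip(samples, samples[1:])
)
powers = [power for _, power in samples]
mean_gpu_power_watts = sum(powers) / len(powers)
peak_gpu_power_watts = max(powers)
duration_sec = samples[-1][0] - samples[0][0]
num_samples = len(samples)
```

A shorter sampling interval tightens the energy estimate for bursty workloads at the cost of more sampling overhead.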
flops ¤
FLOP counting via JAX's jaxpr tracing.
Provides FlopsCounter for estimating FLOPs of JAX functions by analyzing their Jaxpr intermediate representation.
FlopsResult
dataclass
¤
FlopsResult(*, total_flops: int, flops_by_operation: dict[str, int], num_operations: int, function_name: str)
Result of FLOP counting for a function.
Attributes:
| Name | Type | Description |
|---|---|---|
| total_flops | int | Total estimated FLOPs. |
| flops_by_operation | dict[str, int] | Breakdown by primitive operation name. |
| num_operations | int | Number of JAX primitives in the trace. |
| function_name | str | Name of the analyzed function. |
FlopsCounter ¤
Count FLOPs of JAX functions via jaxpr analysis.
Uses jax.make_jaxpr to trace the function and counts FLOPs
for each primitive based on operation-specific rules.
For NNX models that use stochastic operations (dropout, etc.),
use flax.nnx.tabulate(model, *args, compute_flops=True)
instead — it handles NNX state management internally.
count ¤
Count FLOPs for a function with given example arguments.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable[..., Any] | JAX function to analyze. | required |
| *args | Any | Example arguments for tracing. | () |
| static_argnums | tuple[int, ...] | Argument indices to treat as static. | () |
Returns:
| Type | Description |
|---|---|
| FlopsResult | FlopsResult with FLOP count and breakdown. |
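To make "operation-specific rules" concrete, here is a hedged sketch of a per-primitive FLOP table applied to a hand-written list of operations. It does not call jax.make_jaxpr, and the rules shown (2 * m * k * n for a matrix multiply, one FLOP per output element for an elementwise add) are common conventions, not necessarily the ones FlopsCounter uses. The trace format and estimate_flops are hypothetical.

```python
import math
from collections import Counter

# Hypothetical trace: (primitive name, input shapes, output shape).
trace = [
    ("dot_general", [(64, 128), (128, 256)], (64, 256)),
    ("add", [(64, 256), (64, 256)], (64, 256)),
]


def estimate_flops(primitive, in_shapes, out_shape):
    """Apply a per-primitive rule to estimate FLOPs for one operation."""
    if primitive == "dot_general":
        (m, k), (_, n) = in_shapes
        return 2 * m * k * n  # one multiply and one add per inner-product term
    # Assumed default: one FLOP per output element for elementwise primitives.
    return math.prod(out_shape)


flops_by_operation = Counter()
for primitive, in_shapes, out_shape in trace:
    flops_by_operation[primitive] += estimate_flops(primitive, in_shapes, out_shape)

total_flops = sum(flops_by_operation.values())
num_operations = len(trace)
```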
gpu ¤
GPU memory profiling and hardware-adaptive operations.
Provides hardware detection, shape optimization, GPU memory profiling (satisfying GPUProfilerProtocol), and memory usage analysis. Includes NVML-based GPU clock and power monitoring when pynvml is available.
HardwareConfig
dataclass
¤
HardwareConfig(*, platform: str, precision: str, tile_size: int, critical_batch_size: int, memory_layout: str, use_vmem_optimization: bool)
Hardware-specific optimization configuration.
Attributes:
| Name | Type | Description |
|---|---|---|
| platform | str | Detected platform ("cpu", "tpu", "gpu_modern", "gpu_legacy"). |
| precision | str | Recommended floating-point precision string. |
| tile_size | int | Tile size for matrix operation alignment. |
| critical_batch_size | int | Optimal batch size for the platform. |
| memory_layout | str | Memory layout preference. |
| use_vmem_optimization | bool | Whether VMEM optimization is available. |
MemoryAnalysis
dataclass
¤
MemoryAnalysis(*, baseline_memory_mb: float, peak_memory_mb: float, peak_usage_mb: float, retained_memory_mb: float, memory_efficiency: float, suggestions: tuple[str, ...] = ())
Result of pipeline memory analysis.
Attributes:
| Name | Type | Description |
|---|---|---|
| baseline_memory_mb | float | Memory usage before pipeline execution. |
| peak_memory_mb | float | Memory usage at peak during execution. |
| peak_usage_mb | float | Peak usage above baseline. |
| retained_memory_mb | float | Memory retained after GC. |
| memory_efficiency | float | Ratio of freed memory to peak usage. |
| suggestions | tuple[str, ...] | Optimization suggestions. |
AdaptiveOperation ¤
Hardware-adaptive operations with auto-detection.
Detects the current JAX backend (CPU/GPU/TPU) and provides optimized configuration and shape padding.
optimize_shapes ¤
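Shape padding for tile alignment usually means rounding each dimension up to the next multiple of the tile size. The helper below, pad_to_tile, is a hypothetical illustration of that idea and is not the optimize_shapes implementation, whose signature is not shown here.

```python
def pad_to_tile(shape: tuple[int, ...], tile_size: int) -> tuple[int, ...]:
    """Round every dimension up to the nearest multiple of tile_size."""
    if tile_size <= 0:
        raise ValueError("tile_size must be positive")
    # -(-dim // tile_size) is ceiling division using integers only.
    return tuple(-(-dim // tile_size) * tile_size for dim in shape)


padded = pad_to_tile((100, 300), tile_size=128)
already_aligned = pad_to_tile((256, 128), tile_size=128)
```

Dimensions that are already multiples of the tile size are left unchanged, so the padding is a no-op for aligned shapes.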
GPUMemoryProfiler ¤
GPU memory profiling satisfying GPUProfilerProtocol.
Uses a multi-step fallback strategy: memory_stats -> xla_bridge -> zeros.
MemoryOptimizer ¤
Memory optimization analysis for pipeline functions.
analyze_pipeline_memory ¤
analyze_pipeline_memory(pipeline_fn: Callable[[Any], Any], sample_data: Any) -> MemoryAnalysis | None
Analyze memory usage of a pipeline function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pipeline_fn | Callable[[Any], Any] | Function to analyze. | required |
| sample_data | Any | Sample input data. | required |
Returns:
| Type | Description |
|---|---|
| MemoryAnalysis \| None | MemoryAnalysis with measurements and suggestions, or None if the pipeline raises an exception. |
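The derived fields of MemoryAnalysis can be illustrated with three readings. The formulas below follow the attribute descriptions (peak usage above baseline, memory retained after GC, ratio of freed memory to peak usage), but the exact arithmetic inside MemoryOptimizer is an assumption, and the readings are invented.

```python
baseline_memory_mb = 500.0   # before running the pipeline
peak_memory_mb = 900.0       # highest reading during execution
post_gc_memory_mb = 550.0    # after garbage collection

peak_usage_mb = peak_memory_mb - baseline_memory_mb          # 400.0
retained_memory_mb = post_gc_memory_mb - baseline_memory_mb  # 50.0

# Guard against division by zero when the pipeline allocates nothing.
if peak_usage_mb > 0:
    memory_efficiency = (peak_usage_mb - retained_memory_mb) / peak_usage_mb
else:
    memory_efficiency = 1.0
```

Under this reading, an efficiency near 1.0 means almost everything allocated during the run was freed afterwards, while a low value suggests retained buffers worth investigating.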
hardware ¤
Hardware specifications and detection for profiling.
Provides accelerator specs (TPU v5e, A100, H100, CPU) and utility functions for hardware detection and synchronized execution timing.
HARDWARE_SPECS
module-attribute
¤
HARDWARE_SPECS: dict[str, dict[str, Any]] = {'tpu_v5e': {'peak_flops': 197000000000000.0, 'peak_flops_bf16': 197000000000000.0, 'memory_bandwidth': 1600000000000.0, 'critical_intensity': 123.125}, 'a100_80g': {'peak_flops': 312000000000000.0, 'peak_flops_bf16': 312000000000000.0, 'memory_bandwidth': 2039000000000.0, 'critical_intensity': 153.0, 'tensor_core_shapes': [(16, 16, 16), (16, 16, 8)]}, 'h100': {'peak_flops': 989000000000000.0, 'peak_flops_bf16': 989000000000000.0, 'memory_bandwidth': 3350000000000.0, 'critical_intensity': 295.0, 'tensor_core_shapes': [(16, 16, 16)]}, 'cpu_generic': {'peak_flops': 2000000000000.0, 'peak_flops_bf16': 2000000000000.0, 'memory_bandwidth': 200000000000.0, 'critical_intensity': 10.0, 'simd_width': 8}}
detect_hardware_specs ¤
Detect current hardware and return appropriate specifications.
Uses jax.default_backend() to determine the accelerator type
and returns pre-configured specs for that platform.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Hardware specification dictionary with peak_flops, memory_bandwidth, and critical_intensity keys (among others). |
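The critical_intensity values in HARDWARE_SPECS are consistent with the usual roofline ridge point: peak FLOP/s divided by memory bandwidth. The check below copies the relevant numbers from the table above; the helper name ridge_point is hypothetical.

```python
import math

# Values copied from HARDWARE_SPECS above (FLOP/s, bytes/s, FLOPs/byte).
specs = {
    "tpu_v5e": {"peak_flops": 197e12, "memory_bandwidth": 1.6e12, "critical_intensity": 123.125},
    "a100_80g": {"peak_flops": 312e12, "memory_bandwidth": 2.039e12, "critical_intensity": 153.0},
    "h100": {"peak_flops": 989e12, "memory_bandwidth": 3.35e12, "critical_intensity": 295.0},
    "cpu_generic": {"peak_flops": 2e12, "memory_bandwidth": 200e9, "critical_intensity": 10.0},
}


def ridge_point(spec: dict) -> float:
    """Arithmetic intensity at which compute and bandwidth limits meet."""
    return spec["peak_flops"] / spec["memory_bandwidth"]


# The tabulated GPU values appear rounded, so compare with a 1% tolerance.
consistent = {
    name: math.isclose(ridge_point(spec), spec["critical_intensity"], rel_tol=0.01)
    for name, spec in specs.items()
}
```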
measure_execution_time ¤
measure_execution_time(func: Callable[..., Any], inputs: list[Array], warmup: int = 3, iterations: int = 10) -> float
Measure execution time of a JAX function with synchronization.
JIT-compiles the function, runs warmup iterations, then times
iterations executions with block_until_ready() barriers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| func | Callable[..., Any] | JAX function to benchmark. | required |
| inputs | list[Array] | Input arguments as a list of arrays. | required |
| warmup | int | Number of warmup iterations (for JIT compilation). | 3 |
| iterations | int | Number of timed iterations. | 10 |
Returns:
| Type | Description |
|---|---|
| float | Average execution time in seconds. |
resources ¤
Background resource monitoring with 10 Hz sampling by default.
Provides ResourceMonitor context manager for tracking CPU, memory, and optional GPU utilization during benchmark execution.
GPUProfilerProtocol ¤
ResourceSample
dataclass
¤
ResourceSample(*, timestamp: float, cpu_percent: float, rss_mb: float, gpu_util: float | None, gpu_mem_mb: float | None, gpu_clock_mhz: float | None = None, gpu_power_w: float | None = None)
Single resource measurement at a point in time.
Attributes:
| Name | Type | Description |
|---|---|---|
| timestamp | float | Time of measurement (perf_counter). |
| cpu_percent | float | CPU utilization percentage. |
| rss_mb | float | Resident set size in MB. |
| gpu_util | float \| None | GPU utilization percentage (None if no GPU). |
| gpu_mem_mb | float \| None | GPU memory used in MB (None if no GPU). |
ResourceSummary
dataclass
¤
ResourceSummary(*, peak_rss_mb: float, mean_rss_mb: float, peak_gpu_mem_mb: float | None, mean_gpu_util: float | None, memory_growth_mb: float, num_samples: int, duration_sec: float, mean_gpu_clock_mhz: float | None = None, mean_gpu_power_w: float | None = None)
Aggregated resource usage over a monitoring period.
Attributes:
| Name | Type | Description |
|---|---|---|
| peak_rss_mb | float | Maximum RSS observed. |
| mean_rss_mb | float | Average RSS across all samples. |
| peak_gpu_mem_mb | float \| None | Maximum GPU memory (None if no GPU). |
| mean_gpu_util | float \| None | Average GPU utilization (None if no GPU). |
| memory_growth_mb | float | Last RSS minus first RSS (positive = growth). |
| num_samples | int | Total samples collected. |
| duration_sec | float | Time span of monitoring. |
to_dict ¤
from_dict
classmethod
¤
from_dict(data: dict[str, Any]) -> ResourceSummary
Deserialize from a dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Dictionary with resource summary fields. | required |
Returns:
| Type | Description |
|---|---|
| ResourceSummary | Reconstructed ResourceSummary instance. |
ResourceMonitor ¤
ResourceMonitor(sample_interval_sec: float = 0.1, gpu_profiler: GPUProfilerProtocol | None = None)
Background resource sampling (10 Hz by default) via a context manager.
Usage:
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sample_interval_sec | float | Seconds between samples (default 0.1 = 10 Hz). | 0.1 |
| gpu_profiler | GPUProfilerProtocol \| None | Optional profiler satisfying GPUProfilerProtocol. | None |
summary
property
¤
summary: ResourceSummary
Compute aggregated summary from collected samples.
Returns:
| Type | Description |
|---|---|
| ResourceSummary | ResourceSummary with aggregated metrics, or a zeroed summary if no samples were collected. |
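The context-manager pattern described for ResourceMonitor can be sketched with the standard library alone. IntervalSampler below is a hypothetical stand-in, not the library class: it polls a user-supplied callable on a daemon thread and derives memory_growth_mb as the last reading minus the first, matching the ResourceSummary description.

```python
import threading
import time
from typing import Callable, Optional


class IntervalSampler:
    """Hypothetical sampler: polls read_fn on a daemon thread."""

    def __init__(self, read_fn: Callable[[], float], sample_interval_sec: float = 0.1) -> None:
        self._read_fn = read_fn
        self._interval = sample_interval_sec
        self._stop = threading.Event()
        self._thread: Optional[threading.Thread] = None
        self.samples: list[tuple[float, float]] = []

    def _run(self) -> None:
        while True:
            self.samples.append((time.perf_counter(), self._read_fn()))
            # wait() returns True once stop is requested, ending the loop.
            if self._stop.wait(self._interval):
                break

    def __enter__(self) -> "IntervalSampler":
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        return self

    def __exit__(self, *exc_info: object) -> None:
        self._stop.set()
        if self._thread is not None:
            self._thread.join()

    @property
    def memory_growth_mb(self) -> float:
        if not self.samples:
            return 0.0  # zeroed value when nothing was collected
        return self.samples[-1][1] - self.samples[0][1]


readings = iter(range(100, 10_000))  # fake, steadily growing RSS values in MB
with IntervalSampler(lambda: float(next(readings)), sample_interval_sec=0.01) as monitor:
    time.sleep(0.05)  # stand-in for the benchmarked workload
```

Using an Event with a timed wait, rather than time.sleep, lets the sampler stop promptly when the block exits instead of waiting out a full interval.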
roofline ¤
Roofline analysis for JAX operations.
Identifies whether operations are compute-bound or memory-bound by comparing arithmetic intensity against hardware roofline limits, and generates optimization recommendations.
RooflineResult
dataclass
¤
RooflineResult(*, arithmetic_intensity: float, critical_intensity: float, memory_bandwidth_utilization: float, flops_utilization: float, bottleneck: str, efficiency: float, execution_time_ms: float, recommendations: tuple[str, ...] = ())
Result of a roofline analysis on a JAX operation.
Attributes:
| Name | Type | Description |
|---|---|---|
| arithmetic_intensity | float | Achieved FLOPs per byte of memory traffic. |
| critical_intensity | float | Hardware's ridge point (FLOPs/byte). |
| memory_bandwidth_utilization | float | Fraction of peak memory bandwidth used. |
| flops_utilization | float | Fraction of peak FLOPs achieved. |
| bottleneck | str | Either "memory_bandwidth" or "compute". |
| efficiency | float | Utilization of the binding resource. |
| execution_time_ms | float | Measured execution time in milliseconds. |
| recommendations | tuple[str, ...] | Optimization suggestions. |
from_dict
classmethod
¤
from_dict(data: dict[str, Any]) -> RooflineResult
Deserialize from a dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Dictionary with roofline result fields. | required |
Returns:
| Type | Description |
|---|---|
| RooflineResult | Reconstructed RooflineResult instance. |
RooflineAnalyzer
dataclass
¤
RooflineAnalyzer(*, hardware_specs: dict[str, Any] = detect_hardware_specs())
Analyzes operation performance against hardware roofline limits.
Uses measured execution time and estimated FLOPs to determine whether an operation is compute-bound or memory-bound, and how efficiently it uses the available hardware resources.
Attributes:
| Name | Type | Description |
|---|---|---|
| hardware_specs | dict[str, Any] | Hardware specification dictionary (auto-detected if not provided). |
hardware_specs
class-attribute
instance-attribute
¤
hardware_specs: dict[str, Any] = field(default_factory=detect_hardware_specs)
analyze_operation ¤
analyze_operation(func: Callable[..., Any], inputs: list[Array], *, flops_override: int | None = None) -> RooflineResult
Perform roofline analysis on a JAX operation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| func | Callable[..., Any] | JAX function to analyze. | required |
| inputs | list[Array] | Input arrays for the function. | required |
| flops_override | int \| None | If provided, use this FLOP count instead of estimating. For accurate results, pass the output of | None |
Returns:
| Type | Description |
|---|---|
| RooflineResult | RooflineResult with bottleneck classification and recommendations. |
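The bottleneck classification follows the standard roofline model. The sketch below is an assumption about the arithmetic, not the RooflineAnalyzer source: it takes a FLOP count, bytes moved, and a measured time, then compares arithmetic intensity with the ridge point. The workload numbers are invented; the hardware numbers are the a100_80g entry from HARDWARE_SPECS.

```python
peak_flops = 312e12          # FLOP/s, from HARDWARE_SPECS["a100_80g"]
memory_bandwidth = 2039e9    # bytes/s
critical_intensity = 153.0   # FLOPs/byte (ridge point)

# Invented workload: an elementwise op over 1e8 float32 values.
flops = 1e8
bytes_moved = 8e8            # read 4 bytes and write 4 bytes per element
execution_time_sec = 1e-3

arithmetic_intensity = flops / bytes_moved
flops_utilization = (flops / execution_time_sec) / peak_flops
memory_bandwidth_utilization = (bytes_moved / execution_time_sec) / memory_bandwidth

# Below the ridge point the memory system is the binding resource.
if arithmetic_intensity < critical_intensity:
    bottleneck = "memory_bandwidth"
    efficiency = memory_bandwidth_utilization
else:
    bottleneck = "compute"
    efficiency = flops_utilization
```

In this example the intensity of 0.125 FLOPs/byte is far below the ridge point, so the operation is memory-bound and its efficiency is judged against bandwidth rather than peak FLOPs.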
timing ¤
Framework-agnostic timing with configurable result synchronization.
Provides TimingSample (frozen dataclass) and TimingCollector for measuring iteration throughput with per-batch timing breakdown. Uses time.perf_counter() exclusively for accurate benchmarking. Supports warm-up iteration exclusion and JIT compilation time measurement.
TimingSample
dataclass
¤
TimingSample(*, wall_clock_sec: float, per_batch_times: tuple[float, ...], first_batch_time: float, num_batches: int, num_elements: int, compilation_time_sec: float | None = None, warmup_batches_excluded: int = 0)
Result of timing an iteration through a data pipeline.
Attributes:
| Name | Type | Description |
|---|---|---|
| wall_clock_sec | float | Total wall-clock time for the iteration. |
| per_batch_times | tuple[float, ...] | Per-batch durations in seconds (warmup batches excluded). |
| first_batch_time | float | Time from iteration start to first batch completion. |
| num_batches | int | Number of batches consumed (including warmup). |
| num_elements | int | Total elements processed (via count_fn). |
| compilation_time_sec | float \| None | JIT compilation time, if measured separately. |
| warmup_batches_excluded | int | Number of leading batches excluded from per_batch_times. |
from_dict
classmethod
¤
from_dict(data: dict[str, Any]) -> TimingSample
Deserialize from a dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Dictionary with TimingSample fields. | required |
Returns:
| Type | Description |
|---|---|
| TimingSample | Reconstructed TimingSample instance. |
TimingCollector ¤
Framework-agnostic timing with configurable GPU sync support.
Uses time.perf_counter() exclusively for accurate benchmarking. Supports configurable result synchronization via sync_fn and warm-up iteration exclusion for JIT-compiled workloads.
JAX dispatches operations asynchronously -- the host returns
immediately while the device is still computing. Without an explicit
synchronization barrier, perf_counter measures only host-side
dispatch latency, not actual compute time. Pass a sync_fn that
calls block_until_ready() on the workload result to force the host
to wait for device completion before recording the timestamp.
Example -- JAX GPU timing with warm-up:
import jax

# step_fn and data_iter are placeholders for your own function and iterator.
jitted_step = jax.jit(step_fn)  # wrap once, outside the per-batch call

def run_step(batch):
    return jitted_step(batch)

collector = TimingCollector(
    sync_fn=lambda result: result.block_until_ready(),
    warmup_iterations=2,
)
sample = collector.measure_iteration(data_iter, num_batches=50, process_fn=run_step)
# sample.per_batch_times excludes the first 2 batches
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sync_fn | Callable[[Any], object] \| None | Synchronization function called with each batch result. For JAX, pass a function that calls block_until_ready() on the result. | None |
| warmup_iterations | int | Number of initial batches to exclude from per_batch_times statistics. They are still executed (important for JIT warm-up) but omitted from the timing result. Default: 0. | 0 |
measure_iteration ¤
measure_iteration(iterator: Iterator[Any], num_batches: int | None = None, process_fn: Callable[[Any], Any] | None = None, count_fn: Callable[[Any], int] | None = None) -> TimingSample
Measure timing for batches from an iterator.
Warm-up batches (if configured) are executed but excluded from
per_batch_times. wall_clock_sec covers the entire run
including warm-up. num_batches reflects total batches consumed.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| iterator | Iterator[Any] | Data iterator to measure. | required |
| num_batches | int \| None | Max batches to consume (including warmup). None exhausts iterator. | None |
| process_fn | Callable[[Any], Any] \| None | Optional per-batch function whose execution is timed. Defaults to identity (the yielded batch is treated as result). | None |
| count_fn | Callable[[Any], int] \| None | Function to count elements per batch. Default: 1 per batch. | None |
Returns:
| Type | Description |
|---|---|
| TimingSample | TimingSample with timing measurements. |
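The warm-up semantics can be sketched in plain Python. measure_iteration_sketch below is a hypothetical re-implementation under the rules stated above (warm-up batches run but are excluded from per_batch_times, wall-clock time covers everything, num_batches counts all batches). It returns a dict rather than a TimingSample and is not the library code.

```python
import itertools
import time
from typing import Any, Callable, Iterator, Optional


def measure_iteration_sketch(
    iterator: Iterator[Any],
    num_batches: Optional[int] = None,
    process_fn: Optional[Callable[[Any], Any]] = None,
    sync_fn: Optional[Callable[[Any], object]] = None,
    warmup_iterations: int = 0,
) -> dict[str, Any]:
    per_batch_times: list[float] = []
    consumed = 0
    start = time.perf_counter()
    batch_start = start
    # islice with stop=None consumes the iterator until it is exhausted.
    for batch in itertools.islice(iterator, num_batches):
        result = process_fn(batch) if process_fn is not None else batch
        if sync_fn is not None:
            sync_fn(result)  # e.g. block until the device has finished
        now = time.perf_counter()
        if consumed >= warmup_iterations:
            per_batch_times.append(now - batch_start)
        consumed += 1
        batch_start = now
    return {
        "wall_clock_sec": time.perf_counter() - start,
        "per_batch_times": tuple(per_batch_times),
        "num_batches": consumed,
        "warmup_batches_excluded": min(warmup_iterations, consumed),
    }


sample = measure_iteration_sketch(
    iter(range(10)), num_batches=5, process_fn=lambda x: x * x, warmup_iterations=2
)
```

Here five batches are consumed, the first two count as warm-up, and only the remaining three contribute per-batch timings.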
measure_compilation_time ¤
Measure JIT compilation time for a JAX function.
Calls jax.jit(fn).lower(*args).compile() and times it.
This measures the XLA compilation step only, not execution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable[..., Any] | JAX function to compile. | required |
| *args | Any | Example arguments for lowering. | () |
Returns:
| Type | Description |
|---|---|
| float | Compilation time in seconds. |
tracing ¤
XLA trace linking for connecting JAX profiler output to benchmark runs.
Provides a simple context manager wrapping jax.profiler.trace()
that records the trace file path for association with Store run metadata.
Does not parse trace files — only links file paths to benchmark results.
TraceReference
dataclass
¤
Reference to a JAX profiler trace output.
Attributes:
| Name | Type | Description |
|---|---|---|
| trace_dir | str | Directory where the trace files were written. |
| run_id | str \| None | Optional benchmark run ID to link the trace to. |
from_dict
classmethod
¤
from_dict(data: dict[str, Any]) -> TraceReference
Deserialize from a dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Dictionary with trace reference fields. | required |
Returns:
| Type | Description |
|---|---|
| TraceReference | Reconstructed TraceReference instance. |
TraceLinker ¤
Links JAX profiler traces to benchmark runs.
Usage:
linker = TraceLinker()
with linker.trace("/tmp/my_trace") as ref:
    # ... run workload ...
    print(ref.trace_dir)  # "/tmp/my_trace"
trace ¤
trace(log_dir: str | Path, *, run_id: str | None = None, create_perfetto_link: bool = False, create_perfetto_trace: bool = False) -> Any
Start an XLA profiling session and record output metadata.
Wraps jax.profiler.trace() and records the output directory
path as a TraceReference for downstream Store linkage.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| log_dir | str \| Path | Directory for trace output files. | required |
| run_id | str \| None | Optional benchmark run ID to associate with the trace. | None |
| create_perfetto_link | bool | Whether to create a Perfetto link (passed to JAX). | False |
| create_perfetto_trace | bool | Whether to create a Perfetto trace (passed to JAX). | False |
Yields:
| Type | Description |
|---|---|
| Any | TraceReference with the trace directory and optional run ID. |