XLA Optimization¤
JAX/XLA compilation hints and tuning.
See Also¤
- Performance Overview - All performance tools
- Roofline Model - Bottleneck analysis
- NNX Best Practices - JAX patterns
- Troubleshooting
datarax.performance.xla_optimization ¤
XLA compilation optimization for maximum performance.
This module provides XLA compilation strategies, smart compilation caching, and memory-efficient patterns for JAX operations.
XLAOptimizer ¤
XLAOptimizer(target_hardware: str = 'auto')
Configure XLA compiler for maximum performance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| target_hardware | str | Target hardware ('auto', 'gpu', 'tpu', 'cpu') | 'auto' |
setup_xla_flags ¤
Configure XLA compiler flags for maximum performance.
Delegates to apply_xla_flags after resolving "auto" to the actual backend.
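One plausible way to resolve "auto" is to ask JAX which backend it selected. The sketch below is an assumption about how that step could look, not the module's actual code; `resolve_backend` is a hypothetical helper name.

```python
import jax


def resolve_backend(target_hardware: str = "auto") -> str:
    """Hypothetical helper: map 'auto' to the backend JAX is currently using."""
    if target_hardware == "auto":
        # jax.default_backend() returns the platform name of the default backend.
        return jax.default_backend()
    # Explicit choices are passed through unchanged.
    return target_hardware


backend = resolve_backend("auto")
```

The resolved name can then be handed to apply_xla_flags, which expects a concrete backend string.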
SmartCompilation ¤
Intelligent compilation strategies for different scenarios.
adaptive_jit ¤
aot_compile ¤
Perform Ahead-of-Time (AOT) compilation for specific input shapes.
This moves the compilation cost out of the first call, which removes first-request latency jitter and makes it well suited to serving.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| func | Callable | Function to compile | required |
| args | Any | Example arguments (for shape/dtype) | () |
| kwargs | Any | Example keyword arguments | {} |
Returns:
| Type | Description |
|---|---|
| Callable | Compiled function ready for execution (lowered and compiled) |
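The lower-then-compile flow described here can be sketched with JAX's public AOT API (`jax.jit(...).lower(...).compile()`). This is a minimal illustration under that assumption, not the body of aot_compile itself; `aot_compile_sketch` and `scaled_sum` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def aot_compile_sketch(func, *args, **kwargs):
    """Hypothetical stand-in: lower for the example inputs, then compile."""
    # lower() traces func for the shapes/dtypes of the example arguments;
    # compile() runs XLA now rather than on the first real call.
    return jax.jit(func).lower(*args, **kwargs).compile()


def scaled_sum(x, scale):
    return jnp.sum(x * scale)


example_x = jnp.ones((8, 4), dtype=jnp.float32)
example_scale = jnp.asarray(2.0, dtype=jnp.float32)

compiled = aot_compile_sketch(scaled_sum, example_x, example_scale)
result = compiled(example_x, example_scale)
```

The compiled object is specialized to the shapes and dtypes of the example arguments, so a serving path would typically compile once per expected input shape at startup.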
shard_map_jit ¤
MemoryEfficientCompilation ¤
Compilation patterns optimized for memory efficiency.
donate_wrapper staticmethod ¤
Wrapper to automatically donate large arrays.
Caches the compiled function keyed by donate_argnums to avoid recompilation on every call.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| func | Callable | Function to wrap | required |
| donate_args | tuple[int, ...] \| None | Indices of arguments to donate, or None for auto-detect | None |
Returns:
| Type | Description |
|---|---|
| Callable | Memory-efficient wrapped function |
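To illustrate the caching idea, here is a minimal sketch built on `jax.jit(..., donate_argnums=...)`. It assumes a module-level dictionary as the cache and includes the function in the cache key; both are illustrative choices, and the auto-detect behaviour for None is not reproduced. `donate_wrapper_sketch` and `sgd_step` are hypothetical names.

```python
import jax
import jax.numpy as jnp

# Assumed cache layout: (function, donated indices) -> jitted function.
_jit_cache = {}


def donate_wrapper_sketch(func, donate_args=(0,)):
    """Hypothetical stand-in: jit func with buffer donation, reusing cached wrappers."""
    key = (func, tuple(donate_args))
    if key not in _jit_cache:
        # donate_argnums lets XLA reuse the donated input buffers for outputs.
        _jit_cache[key] = jax.jit(func, donate_argnums=tuple(donate_args))
    return _jit_cache[key]


def sgd_step(params, grads):
    return params - 0.1 * grads


step = donate_wrapper_sketch(sgd_step, donate_args=(0,))
params = jnp.ones((1024,))
grads = jnp.full((1024,), 2.0)
# Rebind the result: after donation the old params buffer may no longer be valid.
params = step(params, grads)
```

Because a donated input may be invalidated, the caller should not read the original array after the call. On backends that cannot use a donated buffer, JAX may emit a warning and fall back to a normal copy.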
parameter_update_pattern staticmethod ¤
with_rematerialization staticmethod ¤
Apply gradient checkpointing (rematerialization) to reduce memory.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| func | Callable | Function to checkpoint | required |
| policy | Callable \| None | Checkpoint policy (optional, e.g. jax.checkpoint_policies.nothing_saveable) | None |
Returns:
| Type | Description |
|---|---|
| Callable | Function with checkpointing applied |
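Rematerialization trades compute for memory: intermediates are recomputed during the backward pass instead of being stored. The sketch below assumes the wrapper forwards to `jax.checkpoint`, which is a guess about the implementation; `with_rematerialization_sketch` and `mlp_block` are hypothetical names, and `nothing_saveable` is one of the policies available in `jax.checkpoint_policies`.

```python
import jax
import jax.numpy as jnp


def with_rematerialization_sketch(func, policy=None):
    """Hypothetical stand-in: wrap func so its intermediates are recomputed."""
    return jax.checkpoint(func, policy=policy)


def mlp_block(x, w):
    return jnp.tanh(x @ w)


def loss(x, w):
    # nothing_saveable: keep no residuals, recompute everything on the backward pass.
    block = with_rematerialization_sketch(
        mlp_block, policy=jax.checkpoint_policies.nothing_saveable
    )
    return jnp.sum(block(x, w))


def plain_loss(x, w):
    return jnp.sum(mlp_block(x, w))


x = jnp.ones((4, 3))
w = jnp.full((3, 2), 0.5)
grad_remat = jax.grad(loss, argnums=1)(x, w)
grad_plain = jax.grad(plain_loss, argnums=1)(x, w)
```

The gradients match the unwrapped version; only the memory and compute profile of the backward pass changes.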
DistributedUtils ¤
get_xla_flags ¤
Return hardware-specific XLA compiler flags for the given backend.
Selects flags based on the backend so that backend-specific flags (e.g. GPU Triton flags) are never set on an incompatible backend, which would cause a fatal XLA parse error.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| backend | str | Target backend ('gpu', 'tpu', 'cpu'). Pass | required |
Returns:
| Type | Description |
|---|---|
| list[str] | List of |
apply_xla_flags ¤
apply_xla_flags(backend: str) -> None
Apply hardware-specific XLA flags to the XLA_FLAGS environment variable.
Merges flags returned by get_xla_flags into the existing XLA_FLAGS env var, skipping any flag whose name is already present.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| backend | str | Target backend ('gpu', 'tpu', 'cpu'). | required |
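The merge rule (keep what the user already set, append only flags whose name is not yet present) can be sketched in plain Python. This is an illustration of the described behaviour under stated assumptions, not the module's code: `merge_xla_flags` and `flag_name` are hypothetical helpers, and the flag list is passed in directly rather than obtained from get_xla_flags.

```python
import os


def flag_name(flag: str) -> str:
    """Return the part of a flag before '=', e.g. '--xla_foo' for '--xla_foo=1'."""
    return flag.split("=", 1)[0]


def merge_xla_flags(new_flags, env=os.environ):
    """Hypothetical helper: append flags to XLA_FLAGS unless the name is already set."""
    existing = env.get("XLA_FLAGS", "").split()
    present = {flag_name(f) for f in existing}
    merged = list(existing)
    for flag in new_flags:
        if flag_name(flag) not in present:
            merged.append(flag)
            present.add(flag_name(flag))
    env["XLA_FLAGS"] = " ".join(merged)


# Demonstrate on a plain dict so the real process environment is untouched.
demo_env = {"XLA_FLAGS": "--xla_gpu_enable_latency_hiding_scheduler=false"}
merge_xla_flags(
    [
        "--xla_gpu_enable_latency_hiding_scheduler=true",  # name already present: skipped
        "--xla_force_host_platform_device_count=8",  # new name: appended
    ],
    env=demo_env,
)
```

Skipping by name means a user's explicit setting always wins over the defaults. XLA reads XLA_FLAGS when the backend initializes, so flags generally need to be in place before the first JAX computation runs.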