XLA Optimization¤

JAX/XLA compilation hints and tuning.

datarax.performance.xla_optimization ¤

XLA compilation optimization for maximum performance.

This module provides XLA compilation strategies, smart compilation caching, and memory-efficient patterns for JAX operations.

logger module-attribute ¤

logger = getLogger(__name__)

XLAOptimizer ¤

XLAOptimizer(target_hardware: str = 'auto')

Configure XLA compiler for maximum performance.

Parameters:

Name Type Description Default
target_hardware str

Target hardware ('auto', 'gpu', 'tpu', 'cpu')

'auto'

target_hardware instance-attribute ¤

target_hardware = target_hardware

setup_xla_flags ¤

setup_xla_flags() -> None

Configure XLA compiler flags for maximum performance.

Delegates to apply_xla_flags() after resolving "auto" to the actual backend.

setup_jax_config ¤

setup_jax_config() -> None

Configure JAX-specific optimizations.

setup_compilation_cache ¤

setup_compilation_cache() -> None

Set up the persistent compilation cache.
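The snippet below is a minimal sketch of how a persistent cache can be enabled in plain JAX; it is not taken from this module. It assumes the jax_compilation_cache_dir config option is what the method configures, and the temporary directory is a hypothetical location chosen only for illustration.

```python
import tempfile

import jax

# Hypothetical cache location; a real deployment would use a stable path
# that survives process restarts.
cache_dir = tempfile.mkdtemp(prefix="jax_cache_")

# Point JAX's persistent compilation cache at that directory so compiled
# executables can be reused by later processes.
jax.config.update("jax_compilation_cache_dir", cache_dir)
```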

SmartCompilation ¤

SmartCompilation()

Intelligent compilation strategies for different scenarios.

compilation_cache instance-attribute ¤

compilation_cache: dict[Any, Callable] = {}

shape_signatures instance-attribute ¤

shape_signatures: dict[Any, Any] = {}

adaptive_jit ¤

adaptive_jit(func: Callable, static_threshold: int = 1000) -> Callable

Apply JIT compilation adaptively based on input characteristics.

Parameters:

Name Type Description Default
func Callable

Function to potentially compile

required
static_threshold int

Size threshold for JIT compilation

1000

Returns:

Type Description
Callable

Adaptively compiled function
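One way to read "adaptive" here is to compile only when the inputs are large enough to amortize compilation. The sketch below illustrates that idea; the helper name adaptive_jit_sketch and the rule of comparing the total element count against the threshold are assumptions, not the module's confirmed behaviour.

```python
import jax
import jax.numpy as jnp


def adaptive_jit_sketch(func, size_threshold=1000):
    """Hypothetical helper: use jax.jit only for sufficiently large inputs."""
    jitted = jax.jit(func)

    def wrapper(*args):
        # Count elements across array arguments; non-arrays count as one.
        total = sum(getattr(a, "size", 1) for a in args)
        return jitted(*args) if total >= size_threshold else func(*args)

    return wrapper


normalize = adaptive_jit_sketch(lambda x: (x - x.mean()) / (x.std() + 1e-6))
small = normalize(jnp.arange(10.0))    # 10 elements: runs eagerly
large = normalize(jnp.arange(5000.0))  # 5000 elements: runs through jax.jit
```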

aot_compile ¤

aot_compile(func: Callable, *args: Any, **kwargs: Any) -> Callable

Perform Ahead-of-Time (AOT) compilation for specific input shapes.

This moves compilation out of the first call, removing first-request latency jitter, which makes it well suited to serving.

Parameters:

Name Type Description Default
func Callable

Function to compile

required
args Any

Example arguments (for shape/dtype)

()
kwargs Any

Example keyword arguments

{}

Returns:

Type Description
Callable

Compiled function ready for execution (lowered and compiled)
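In plain JAX, ahead-of-time compilation is usually expressed as jit, then lower, then compile. The sketch below shows that sequence under the assumption that aot_compile wraps something similar; scale_and_sum is a hypothetical example function.

```python
import jax
import jax.numpy as jnp


def scale_and_sum(x, w):
    return jnp.sum(x * w)


# Example arguments fix the shapes and dtypes the executable is built for.
x = jnp.ones((8, 4), dtype=jnp.float32)
w = jnp.full((4,), 2.0, dtype=jnp.float32)

# Lower to an XLA computation for these shapes, then compile it up front.
compiled = jax.jit(scale_and_sum).lower(x, w).compile()

# Later calls reuse the executable without tracing or compiling again.
result = compiled(x, w)
```

The compiled object is specialized to the example shapes and dtypes, so arguments with a different shape or dtype are rejected rather than triggering recompilation.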

shard_map_jit ¤

shard_map_jit(mesh: Mesh, in_specs: Any, out_specs: Any) -> Callable

Create a shard_map (SPMD) function for expert parallelization.

Parameters:

Name Type Description Default
mesh Mesh

Device mesh

required
in_specs Any

Input PartitionSpecs

required
out_specs Any

Output PartitionSpecs

required

Returns:

Type Description
Callable

Decorator for the function
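A decorator of this shape can be sketched by composing jax.jit with shard_map. The helper shard_map_jit_sketch below is hypothetical and only illustrates the pattern; the import fallback is there because the location of shard_map differs between JAX releases.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases export shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases keep it under jax.experimental
    from jax.experimental.shard_map import shard_map


def shard_map_jit_sketch(mesh, in_specs, out_specs):
    """Hypothetical decorator factory: shard_map first, then jit."""
    def decorator(func):
        mapped = shard_map(func, mesh=mesh, in_specs=in_specs, out_specs=out_specs)
        return jax.jit(mapped)
    return decorator


# One mesh axis spanning every available device.
mesh = Mesh(np.array(jax.devices()), ("data",))


@shard_map_jit_sketch(mesh, in_specs=(P("data"),), out_specs=P("data"))
def double(block):
    # Each device sees only its own block of the input.
    return block * 2.0


x = jnp.arange(4.0 * len(jax.devices()))
y = double(x)
```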

MemoryEfficientCompilation ¤

Compilation patterns optimized for memory efficiency.

donate_wrapper staticmethod ¤

donate_wrapper(func: Callable, donate_args: tuple[int, ...] | None = None) -> Callable

Wrapper to automatically donate large arrays.

Caches the compiled function keyed by donate_argnums to avoid recompilation on every call.

Parameters:

Name Type Description Default
func Callable

Function to wrap

required
donate_args tuple[int, ...] | None

Indices of arguments to donate, or None for auto-detect

None

Returns:

Type Description
Callable

Memory-efficient wrapped function

parameter_update_pattern staticmethod ¤

parameter_update_pattern(learning_rate: float = 0.01) -> Callable

Optimized parameter update with buffer donation.

Parameters:

Name Type Description Default
learning_rate float

Learning rate for updates

0.01

Returns:

Type Description
Callable

Memory-efficient update function
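Buffer donation for parameter updates is typically expressed with donate_argnums on jax.jit. The sketch below assumes a plain SGD step; make_sgd_update is a hypothetical name and the update rule is an illustration, not the module's confirmed implementation.

```python
import functools

import jax
import jax.numpy as jnp


def make_sgd_update(learning_rate=0.01):
    """Hypothetical factory for an SGD step that donates the old parameters."""

    @functools.partial(jax.jit, donate_argnums=(0,))
    def update(params, grads):
        return jax.tree_util.tree_map(
            lambda p, g: p - learning_rate * g, params, grads
        )

    return update


update = make_sgd_update(learning_rate=0.1)
params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
grads = {"w": jnp.full((3,), 2.0), "b": jnp.array(1.0)}

# Rebind the name: the old parameter buffers may have been donated and
# must not be read again after this call.
params = update(params, grads)
```

On backends that do not support donation, JAX may emit a warning and fall back to an ordinary copy.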

with_rematerialization staticmethod ¤

with_rematerialization(func: Callable, policy: Callable | None = None) -> Callable

Apply gradient checkpointing (rematerialization) to reduce memory.

Parameters:

Name Type Description Default
func Callable

Function to checkpoint

required
policy Callable | None

Checkpoint policy (optional, e.g. jax.checkpoint_policies.nothing_saveable)

None

Returns:

Type Description
Callable

Function with checkpointing applied
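The underlying JAX mechanism is jax.checkpoint, which recomputes intermediates during the backward pass instead of storing them. The sketch below uses the nothing_saveable policy on a hypothetical two-layer function and shows that the gradients are unchanged; only memory use and recomputation differ.

```python
import jax
import jax.numpy as jnp


def two_layer(w, x):
    h = jnp.tanh(x @ w)
    return jnp.sum(jnp.tanh(h @ w))


# Save no intermediates; recompute them all during the backward pass.
remat_two_layer = jax.checkpoint(
    two_layer, policy=jax.checkpoint_policies.nothing_saveable
)

w = jnp.eye(4) * 0.5
x = jnp.ones((2, 4))

grad_plain = jax.grad(two_layer)(w, x)
grad_remat = jax.grad(remat_two_layer)(w, x)
```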

DistributedUtils ¤

Best practices for distributed computation and sharding.

create_mesh staticmethod ¤

create_mesh(axis_dims: tuple[int, ...], axis_names: tuple[str, ...]) -> Mesh

Create a device mesh for parallelism.

Parameters:

Name Type Description Default
axis_dims tuple[int, ...]

Dimensions for mesh (e.g. (4, 2) for 8 devices)

required
axis_names tuple[str, ...]

Names for axes (e.g. ('data', 'model'))

required

Returns:

Type Description
Mesh

jax.sharding.Mesh
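A mesh like this can be built by reshaping the device list into the requested grid. The helper create_mesh_sketch below is a hypothetical illustration; in particular, taking the first devices and raising when too few are available are assumptions.

```python
import jax
import numpy as np
from jax.sharding import Mesh


def create_mesh_sketch(axis_dims, axis_names):
    """Hypothetical helper: arrange devices into a named grid."""
    devices = jax.devices()
    needed = int(np.prod(axis_dims))
    if needed > len(devices):
        raise ValueError(f"mesh needs {needed} devices, found {len(devices)}")
    grid = np.array(devices[:needed]).reshape(axis_dims)
    return Mesh(grid, axis_names)


# Use every device on the 'data' axis and a trivial 'model' axis.
n_devices = len(jax.devices())
mesh = create_mesh_sketch((n_devices, 1), ("data", "model"))
```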

with_sharding staticmethod ¤

with_sharding(x: Any, mesh: Mesh, partition_spec: Any) -> Any

Apply sharding constraint to an array.

Parameters:

Name Type Description Default
x Any

JAX array

required
mesh Mesh

Device mesh

required
partition_spec Any

PartitionSpec for the array

required

Returns:

Type Description
Any

Sharded array
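In plain JAX, a sharding constraint is usually applied inside a jitted function with jax.lax.with_sharding_constraint. The sketch below assumes with_sharding wraps a similar call; the one-axis mesh and the scale function are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("data",))

# Split rows across the 'data' axis and replicate columns.
row_sharding = NamedSharding(mesh, P("data", None))


@jax.jit
def scale(x):
    y = x * 3.0
    # Ask the compiler to lay out y according to row_sharding.
    return jax.lax.with_sharding_constraint(y, row_sharding)


# The row count is a multiple of the device count so rows divide evenly.
x = jnp.ones((2 * len(jax.devices()), 3))
y = scale(x)
```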

CompilationProfiler ¤

CompilationProfiler()

Profile and optimize JAX compilation performance.

compilation_times instance-attribute ¤

compilation_times: dict[Any, float] = {}

execution_times instance-attribute ¤

execution_times: dict[Any, list[float]] = {}

cache_status instance-attribute ¤

cache_status: dict[Any, list[bool]] = {}

cache_hits instance-attribute ¤

cache_hits: int = 0

cache_misses instance-attribute ¤

cache_misses: int = 0

shape_profiles instance-attribute ¤

shape_profiles: dict[Any, dict] = {}

profile_function ¤

profile_function(func_name: str, enable_detailed_logging: bool = False) -> Callable

Decorator to profile function compilation and execution.

Parameters:

Name Type Description Default
func_name str

Name of the function for reporting

required
enable_detailed_logging bool

Enable detailed JAX logging

False

Returns:

Type Description
Callable

Profiling decorator
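A profiling decorator of this kind can be sketched with standard-library timing. The class ProfilerSketch below is hypothetical; it reuses the compilation_times and execution_times attribute names from this page, but treating the first call as the compilation measurement is an assumption.

```python
import functools
import time


class ProfilerSketch:
    """Hypothetical profiler: first call counts as compilation, later calls as execution."""

    def __init__(self):
        self.compilation_times = {}
        self.execution_times = {}

    def profile_function(self, func_name):
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                start = time.perf_counter()
                # For JAX functions, block on the result before stopping the
                # timer, since dispatch is asynchronous.
                result = func(*args, **kwargs)
                elapsed = time.perf_counter() - start
                if func_name not in self.compilation_times:
                    self.compilation_times[func_name] = elapsed
                else:
                    self.execution_times.setdefault(func_name, []).append(elapsed)
                return result

            return wrapper

        return decorator


profiler = ProfilerSketch()


@profiler.profile_function("square")
def square(x):
    return x * x


outputs = [square(v) for v in (2, 3, 4)]
```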

generate_report ¤

generate_report() -> dict[str, Any]

Generate a compilation report.

Returns:

Type Description
dict[str, Any]

Report dictionary with analysis and recommendations

get_xla_flags ¤

get_xla_flags(backend: str) -> list[str]

Return hardware-specific XLA compiler flags for the given backend.

Selects flags based on the backend so that backend-specific flags (e.g. GPU Triton flags) are never set on an incompatible backend, which would cause a fatal XLA parse error.

Parameters:

Name Type Description Default
backend str

Target backend ('gpu', 'tpu', 'cpu'). Pass jax.default_backend() for auto-detection.

required

Returns:

Type Description
list[str]

List of --xla_* flag strings appropriate for the backend.

apply_xla_flags ¤

apply_xla_flags(backend: str) -> None

Apply hardware-specific XLA flags to the XLA_FLAGS environment variable.

Merges flags returned by get_xla_flags() into the existing XLA_FLAGS env var, skipping any flag whose name is already present.

Parameters:

Name Type Description Default
backend str

Target backend ('gpu', 'tpu', 'cpu').

required
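The merge rule described above (keep existing flags, skip new flags whose name is already set) can be sketched with plain string handling. The helper merge_xla_flags and its env parameter are hypothetical, and the flags shown are only sample inputs, not the set this module applies.

```python
import os


def merge_xla_flags(new_flags, env=None):
    """Hypothetical helper: append flags to XLA_FLAGS unless the name is already set."""
    env = os.environ if env is None else env
    existing = env.get("XLA_FLAGS", "").split()

    # A flag's name is everything before the first '='.
    present = {flag.split("=", 1)[0] for flag in existing}

    merged = list(existing)
    for flag in new_flags:
        name = flag.split("=", 1)[0]
        if name not in present:
            merged.append(flag)
            present.add(name)

    env["XLA_FLAGS"] = " ".join(merged)
    return env["XLA_FLAGS"]


# Demonstrate on a plain dict so the real environment is left untouched.
fake_env = {"XLA_FLAGS": "--xla_cpu_multi_thread_eigen=false"}
merged_flags = merge_xla_flags(
    [
        "--xla_cpu_multi_thread_eigen=true",  # name already present: skipped
        "--xla_force_host_platform_device_count=2",  # new name: appended
    ],
    env=fake_env,
)
```

Because an existing flag wins over a new one with the same name, values the user has set explicitly are preserved.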