External Utilities

Utilities for external library integration.

See Also


datarax.utils.external

External utility functions for Datarax.

This module provides utility functions for working with external libraries and interfaces, with a particular focus on JAX and Flax NNX integration.

logger module-attribute

logger = getLogger(__name__)

T module-attribute

T = TypeVar('T')

ExternalAdapterConfig dataclass

ExternalAdapterConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, stochastic: bool = True, stream_name: str | None = 'augment', batch_strategy: str = 'vmap')

Bases: OperatorConfig

Configuration for ExternalLibraryAdapter.

Inherits from OperatorConfig. Defaults to stochastic=True, since external functions typically require RNG keys.

Attributes:

  • stream_name (str | None): Name of the RNG stream to use (default: "augment").

stochastic class-attribute instance-attribute

stochastic: bool = True

stream_name class-attribute instance-attribute

stream_name: str | None = 'augment'

cacheable class-attribute instance-attribute

cacheable: bool = False

batch_stats_fn class-attribute instance-attribute

batch_stats_fn: Callable | Module | None = None

precomputed_stats class-attribute instance-attribute

precomputed_stats: dict[str, Any] | None = None

batch_strategy class-attribute instance-attribute

batch_strategy: str = 'vmap'

ExternalLibraryAdapter

ExternalLibraryAdapter(config: ExternalAdapterConfig, fn: Callable[[dict[str, Any], Array], dict[str, Any]], *, rngs: Rngs | None = None, name: str | None = None)

Bases: OperatorModule

Adapter for external libraries that require raw JAX PRNG keys.

This adapter wraps such functions in an NNX module, so they remain compatible with NNX transformations such as nnx.jit and nnx.vmap.

Use this when you need to:

  • Wrap external functions that use JAX keys in an NNX module
  • Apply NNX transformations (jit, vmap) to functions using JAX keys
  • Integrate external augmentation libraries into Datarax pipelines

Examples:

    import jax
    from flax import nnx
    from datarax.utils.external import ExternalAdapterConfig, ExternalLibraryAdapter

    def augment_fn(data, key):
        noise = jax.random.normal(key, shape=data["image"].shape)
        return {**data, "image": data["image"] + noise * 0.1}

    config = ExternalAdapterConfig()
    rngs = nnx.Rngs(augment=42)
    adapter = ExternalLibraryAdapter(config, augment_fn, rngs=rngs)
    batch = Batch(...)  # batch construction elided
    augmented = adapter(batch)

Parameters:

  • config (ExternalAdapterConfig, required): Configuration for the adapter.
  • fn (Callable[[dict[str, Any], Array], dict[str, Any]], required): Function that takes (data_dict, key), where data_dict is the element's data dictionary and key is a raw JAX PRNG key.
  • rngs (Rngs | None, default None): Rngs object for randomness (required when stochastic=True, the default).
  • name (str | None, default None): Optional name for the module.

fn instance-attribute

fn = fn

config instance-attribute

rngs instance-attribute

rngs = rngs

name instance-attribute

name = static(name)

stochastic instance-attribute

stochastic = static(stochastic)

stream_name instance-attribute

stream_name = static(stream_name)

get_output_structure

get_output_structure(sample_data: PyTree, sample_state: PyTree) -> tuple[PyTree, PyTree]

Declare output structure for vmap axis specification.

The default implementation discovers the structure with jax.eval_shape, but ExternalLibraryAdapter.apply() requires random_params, which are not available during jax.eval_shape tracing. This override therefore traces through fn with a dummy key.

Parameters:

  • sample_data (PyTree, required): Single element data (not batched).
  • sample_state (PyTree, required): Single element state (not batched).

Returns:

  • tuple[PyTree, PyTree]: Tuple of (output_data_structure, output_state_structure) with 0 as every leaf value; only the structure matters.
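The dummy-key tracing described above can be sketched with plain JAX. This is a minimal illustration, not the Datarax implementation: augment_fn and output_structure are hypothetical names, and using PRNGKey(0) as the dummy key and 0 as the placeholder leaf are assumptions.

```python
import jax
import jax.numpy as jnp


def augment_fn(data, key):
    # Hypothetical external function: adds noise and a new "noise_scale" entry.
    noise = jax.random.normal(key, shape=data["image"].shape)
    return {**data, "image": data["image"] + 0.1 * noise, "noise_scale": jnp.float32(0.1)}


def output_structure(fn, sample_data, sample_state):
    # Any valid key works for abstract tracing; its value never matters.
    dummy_key = jax.random.PRNGKey(0)
    # eval_shape traces fn without computing anything and returns
    # jax.ShapeDtypeStruct leaves describing the output.
    out_data = jax.eval_shape(fn, sample_data, dummy_key)
    # Only the tree structure is needed, so every leaf becomes a placeholder 0.
    to_placeholder = lambda _: 0
    return (
        jax.tree_util.tree_map(to_placeholder, out_data),
        jax.tree_util.tree_map(to_placeholder, sample_state),
    )


sample = {"image": jnp.zeros((8, 8), dtype=jnp.float32)}
data_struct, state_struct = output_structure(augment_fn, sample, {})
print(data_struct)  # {'image': 0, 'noise_scale': 0}
```

Because eval_shape never runs the computation, the trace is cheap, and a key added by fn (here "noise_scale") shows up in the discovered structure.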

generate_random_params

generate_random_params(rng: Array, data_shapes: PyTree) -> Array

Generate random keys for the batch, one key per element.

Parameters:

  • rng (Array, required): JAX random key.
  • data_shapes (PyTree, required): PyTree with data shapes (used to determine the batch size).

Returns:

  • Array: Array of random keys, one per batch element.
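A sketch of one-key-per-element generation follows. It assumes the batch size is the leading dimension of the first leaf of the batched data (the real method receives data_shapes, whose exact layout is not documented here); per_element_keys is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def per_element_keys(rng, batch_data):
    # Assumption: every leaf is batched along axis 0.
    batch_size = jax.tree_util.tree_leaves(batch_data)[0].shape[0]
    # split() derives batch_size independent keys from a single key.
    return jax.random.split(rng, batch_size)


keys = per_element_keys(jax.random.PRNGKey(42), {"image": jnp.zeros((4, 32, 32, 3))})
print(keys.shape)  # (4, 2): four raw uint32 keys with the default PRNG implementation
```

Splitting once per batch, rather than reusing one key, gives each element independent randomness.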

apply

apply(data: PyTree, state: PyTree, metadata: dict[str, Any] | None, random_params: Array | None = None, stats: dict[str, Any] | None = None) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply the external function to a single element.

Parameters:

  • data (PyTree, required): Element data dictionary.
  • state (PyTree, required): Element state (passed through).
  • metadata (dict[str, Any] | None, required): Element metadata (passed through).
  • random_params (Array | None, default None): Random key for this element.
  • stats (dict[str, Any] | None, default None): Statistics (unused).

Returns:

  • tuple[PyTree, PyTree, dict[str, Any] | None]: Tuple of (transformed_data, state, metadata).
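The per-element contract can be illustrated with a small stand-in. apply_element and add_noise are hypothetical names; the sketch only assumes what the table states: data goes through fn with the element's key, while state and metadata pass through.

```python
import jax
import jax.numpy as jnp


def apply_element(fn, data, state, metadata, random_params):
    # Only the data goes through fn; state and metadata pass through untouched.
    return fn(data, random_params), state, metadata


def add_noise(data, key):
    noise = jax.random.normal(key, shape=data["x"].shape)
    return {**data, "x": data["x"] + noise}


data = {"x": jnp.zeros(3)}
key = jax.random.PRNGKey(0)
out1, state, meta = apply_element(add_noise, data, {"step": 1}, {"id": 7}, key)
out2, _, _ = apply_element(add_noise, data, {"step": 1}, {"id": 7}, key)
# The same key always yields the same noise, so results are reproducible.
print(bool(jnp.array_equal(out1["x"], out2["x"])))  # True
```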

get_operation_stats

get_operation_stats() -> dict[str, int]

Get operation statistics.

Note: This method converts JAX arrays to Python ints for introspection. It is intended for use outside of JIT-compiled functions.

Returns:

  • dict[str, int]: Dictionary with 'applied_count' and 'skipped_count'.

reset_operation_stats

reset_operation_stats() -> None

Reset operation statistics to zero.

Note: Creates new JAX arrays to reset the counters.
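The two notes above (Python-int conversion outside JIT, and resetting by creating new arrays) can be sketched with a hypothetical OperationCounters stand-in; it is not the Datarax class, only an illustration of the same pattern.

```python
import jax.numpy as jnp


class OperationCounters:
    """Hypothetical stand-in for the adapter's applied/skipped counters."""

    def __init__(self):
        self.reset()

    def reset(self):
        # JAX arrays are immutable, so resetting means creating new arrays.
        self.applied_count = jnp.array(0, dtype=jnp.int32)
        self.skipped_count = jnp.array(0, dtype=jnp.int32)

    def record(self, applied: bool):
        if applied:
            self.applied_count = self.applied_count + 1
        else:
            self.skipped_count = self.skipped_count + 1

    def stats(self) -> dict:
        # int() needs a concrete value; on a tracer inside jax.jit it raises a
        # tracer conversion error, which is why this is meant for use outside JIT.
        return {"applied_count": int(self.applied_count), "skipped_count": int(self.skipped_count)}


counters = OperationCounters()
counters.record(True)
counters.record(True)
counters.record(False)
print(counters.stats())  # {'applied_count': 2, 'skipped_count': 1}
counters.reset()
print(counters.stats())  # {'applied_count': 0, 'skipped_count': 0}
```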

compute_statistics

compute_statistics(data: Any) -> dict[str, Any] | None

Compute statistics from data using batch_stats_fn.

If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.

Parameters:

  • data (Any, required): Input data to compute statistics from.

Returns:

  • dict[str, Any] | None: Dictionary of statistics, or None if no batch_stats_fn is configured.

get_statistics

get_statistics() -> dict[str, Any] | None

Get current statistics.

Returns precomputed_stats if configured (unless reset_statistics() was called); otherwise returns the cached computed statistics, or None if no statistics are available.

Returns:

  • dict[str, Any] | None: Dictionary of statistics, or None if no statistics are available.

set_statistics

set_statistics(stats: dict[str, Any]) -> None

Manually set statistics.

This overwrites any previously computed statistics and clears the reset flag.

Parameters:

  • stats (dict[str, Any], required): Dictionary of statistics to set.

reset_statistics

reset_statistics() -> None

Reset all statistics to None.

This clears the computed statistics and sets an internal flag so that precomputed_stats is ignored. After a reset, get_statistics() returns None until new statistics are set or computed.
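The precedence rules above can be modeled in a few lines. StatisticsSketch and the _stats_reset attribute are hypothetical; the sketch follows one literal reading of these docstrings, and how set_statistics interacts with configured precomputed_stats is not specified here, so that case is left untested.

```python
from __future__ import annotations

from typing import Any, Callable


class StatisticsSketch:
    """Hypothetical model of the documented statistics precedence."""

    def __init__(self, batch_stats_fn: Callable | None = None,
                 precomputed_stats: dict[str, Any] | None = None):
        self.batch_stats_fn = batch_stats_fn
        self.precomputed_stats = precomputed_stats
        self._computed_stats: dict[str, Any] | None = None
        self._stats_reset = False  # internal "ignore precomputed_stats" flag

    def compute_statistics(self, data):
        if self.batch_stats_fn is None:
            return None
        self._computed_stats = self.batch_stats_fn(data)  # cached
        return self._computed_stats

    def get_statistics(self):
        if self.precomputed_stats is not None and not self._stats_reset:
            return self.precomputed_stats
        return self._computed_stats

    def set_statistics(self, stats):
        self._computed_stats = stats
        self._stats_reset = False

    def reset_statistics(self):
        self._computed_stats = None
        self._stats_reset = True


s = StatisticsSketch(batch_stats_fn=lambda xs: {"mean": sum(xs) / len(xs)})
s.compute_statistics([1.0, 2.0, 3.0])
print(s.get_statistics())  # {'mean': 2.0}
s.reset_statistics()
print(s.get_statistics())  # None
```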

reset_cache

reset_cache() -> None

Clear the cache.

Has an effect only if cacheable=True in the config.

copy

copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule

Create a copy of this module with optional config/parameter changes.

This creates a new module instance with a modified configuration while preserving the other attributes, which is useful for hyperparameter tuning.

Parameters:

  • config (DataraxModuleConfig | None, default None): New config (if None, uses the current config).
  • rngs (Rngs | None, default None): New RNG state (if None, uses the current rngs).
  • name (str | None, default None): New name (if None, uses the current name).

Returns:

  • DataraxModule: New module instance with updated parameters.

Examples:

    # Change configuration
    new_config = DataraxModuleConfig(cacheable=True)
    new_module = module.copy(config=new_config)

    # Change name only
    renamed = module.copy(name="new_name")

Note

Subclasses can override this method to provide finer-grained control over copying, such as allowing individual config fields to be updated without calling dataclasses.replace().
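Since ExternalAdapterConfig is a dataclass, dataclasses.replace() is one way to change a single field before passing the result to copy(config=...). The sketch uses a hypothetical AdapterConfigSketch that mirrors the documented fields, so it runs without Datarax installed; making it frozen is an assumption.

```python
from __future__ import annotations

import dataclasses
from typing import Any, Callable


@dataclasses.dataclass(frozen=True)
class AdapterConfigSketch:
    """Hypothetical stand-in mirroring the ExternalAdapterConfig fields."""
    cacheable: bool = False
    batch_stats_fn: Callable | None = None
    precomputed_stats: dict[str, Any] | None = None
    stochastic: bool = True
    stream_name: str | None = "augment"
    batch_strategy: str = "vmap"


base = AdapterConfigSketch()
# replace() builds a new config with one field changed and the rest copied.
cached = dataclasses.replace(base, cacheable=True)
print(base.cacheable, cached.cacheable, cached.stream_name)  # False True augment
```

The original config is left untouched, so the source module and its copy do not share mutable configuration.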

get_state

get_state() -> dict[str, Any]

Get module state for checkpointing.

This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.

Returns:

  • dict[str, Any]: A dictionary containing the internal state of the component.

set_state

set_state(state: dict[str, Any]) -> None

Restore module state from a checkpoint.

This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: the checkpoint structure must match the module state.

Parameters:

  • state (dict[str, Any], required): A dictionary containing the internal state to restore.

Raises:

  • TypeError: If state is not a dictionary.
  • ValueError: If the checkpoint structure does not match the module state.
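A strict restore of the kind described above can be sketched by comparing PyTree structures before accepting a checkpoint. restore_state_strict is a hypothetical helper that mirrors the documented TypeError and ValueError conditions; the real method works on NNX state.

```python
import jax
import jax.numpy as jnp


def restore_state_strict(current: dict, checkpoint) -> dict:
    """Hypothetical sketch of a strict restore: structure must match exactly."""
    if not isinstance(checkpoint, dict):
        raise TypeError(f"state must be a dict, got {type(checkpoint).__name__}")
    expected = jax.tree_util.tree_structure(current)
    found = jax.tree_util.tree_structure(checkpoint)
    if expected != found:
        raise ValueError(f"checkpoint structure {found} does not match module state {expected}")
    return checkpoint


current = {"applied_count": jnp.array(0), "skipped_count": jnp.array(0)}
restored = restore_state_strict(current, {"applied_count": jnp.array(5), "skipped_count": jnp.array(1)})
try:
    restore_state_strict(current, {"applied_count": jnp.array(5)})  # missing a key
except ValueError as err:
    print("rejected:", type(err).__name__)  # rejected: ValueError
```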

clone

clone() -> DataraxModule

Create a new instance with the same state as this module.

Uses NNX's clone function to deep-copy all state.

Returns:

  • DataraxModule: A new module instance with the same state.

requires_rng_streams

requires_rng_streams() -> list[str] | None

Get the list of RNG streams required by this module.

Returns:

  • list[str] | None: A list of required RNG stream names, or None if no RNG streams are required.

ensure_rng_streams

ensure_rng_streams(stream_names: list[str]) -> None

Ensure that the required RNG streams are available.

Parameters:

  • stream_names (list[str], required): A list of available RNG stream names.

Raises:

  • ValueError: If a required RNG stream is not available.
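The check performed by these two methods can be sketched as plain functions. Both functions are hypothetical, and mapping "stochastic with a stream_name" to a single required stream is an assumption based on the config fields, not confirmed behavior.

```python
from __future__ import annotations


def requires_rng_streams(stochastic: bool, stream_name: str | None) -> list[str] | None:
    # Assumption: a stochastic operator needs exactly its configured stream.
    if stochastic and stream_name is not None:
        return [stream_name]
    return None


def ensure_rng_streams(required: list[str] | None, available: list[str]) -> None:
    # Raise if any required stream is missing from the available ones.
    missing = [name for name in (required or []) if name not in available]
    if missing:
        raise ValueError(f"missing RNG streams: {missing}; available: {available}")


required = requires_rng_streams(stochastic=True, stream_name="augment")
ensure_rng_streams(required, ["augment", "dropout"])  # passes silently
```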

apply_batch

apply_batch(batch: Batch, stats: dict[str, Any] | None = None) -> Batch

Process an entire batch with vmap and optional RNG generation.

This method implements batch processing for both stochastic and deterministic modes. It branches statically on self.stochastic (a static attribute) for JIT compilation efficiency.

The implementation delegates to _vmap_apply() for the shared computational core, then wraps the result in a Batch object.

Parameters:

  • batch (Batch, required): Input batch (Batch[Element] structure).
  • stats (dict[str, Any] | None, default None): Optional statistics (if None, uses get_statistics()).

Returns:

  • Batch: Transformed batch with the same structure.

Note

This method is concrete (not abstract). Subclasses typically don't override it, but can if they need custom batch processing logic.
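The vmap-plus-static-branch pattern can be sketched with plain JAX. apply_batch_sketch, jitter, and scale are hypothetical; the sketch skips _vmap_apply(), state, metadata, and the Batch wrapper, and assumes every leaf is batched along axis 0.

```python
import jax
import jax.numpy as jnp


def apply_batch_sketch(fn, batch_data, rng=None, *, stochastic: bool):
    # stochastic is a plain Python bool, so this branch is resolved at trace
    # time; if passed as a static argument to jax.jit, each setting compiles
    # to its own program with no runtime branch.
    if stochastic:
        batch_size = jax.tree_util.tree_leaves(batch_data)[0].shape[0]
        keys = jax.random.split(rng, batch_size)  # one key per element
        return jax.vmap(fn)(batch_data, keys)
    return jax.vmap(fn)(batch_data)


def jitter(data, key):
    return {**data, "x": data["x"] + jax.random.normal(key, data["x"].shape)}


def scale(data):
    return {**data, "x": data["x"] * 2.0}


batch = {"x": jnp.ones((4, 3))}
noisy = apply_batch_sketch(jitter, batch, jax.random.PRNGKey(0), stochastic=True)
doubled = apply_batch_sketch(scale, batch, stochastic=False)
```

Each element sees a different key, so rows of noisy differ, while the deterministic path needs no RNG at all.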

output_spec

output_spec(input_spec: PyTree) -> PyTree

Return the operator's output spec given an input spec.

Most operators (normalization, additive noise, simple element-wise transforms) do not change shape, so the default returns input_spec unchanged. Shape-changing operators (Resize, Crop, Reshape) MUST override this method.

Parameters:

  • input_spec (PyTree, required): PyTree of jax.ShapeDtypeStruct describing the input element (matching the upstream DataSourceModule.element_spec() or another operator's output_spec).

Returns:

  • PyTree: PyTree of jax.ShapeDtypeStruct describing the operator's output. By default, equal to input_spec.
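For a shape-changing operator, an override might look like the following. downsample_output_spec is a hypothetical function for an operator that halves image height and width; it assumes an "image" entry laid out as (H, W, C), and every other entry passes through unchanged.

```python
import jax
import jax.numpy as jnp


def downsample_output_spec(input_spec: dict) -> dict:
    """Hypothetical output_spec for an operator that halves image height/width."""
    image = input_spec["image"]
    h, w, c = image.shape
    # Keep dtype; only the spatial dimensions change.
    return {**input_spec, "image": jax.ShapeDtypeStruct((h // 2, w // 2, c), image.dtype)}


spec = {
    "image": jax.ShapeDtypeStruct((64, 64, 3), jnp.float32),
    "label": jax.ShapeDtypeStruct((), jnp.int32),
}
out = downsample_output_spec(spec)
print(out["image"].shape)  # (32, 32, 3)
```

A declared spec can be checked against the real transform with jax.eval_shape, which evaluates shapes without running any computation.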

PureJaxAdapter

PureJaxAdapter(config: ExternalAdapterConfig, fn: Callable[[dict[str, Any]], dict[str, Any]], *, name: str | None = None)

Bases: OperatorModule

Adapter for pure JAX functions (stateless, no RNG).

This adapter wraps pure JAX functions of the form fn(data) -> data. It is deterministic: its config must set stochastic=False, and it does not generate random params.

Examples:

    from datarax.utils.external import ExternalAdapterConfig, PureJaxAdapter

    def normalize(data):
        return {**data, "image": data["image"] / 255.0}

    config = ExternalAdapterConfig(stochastic=False, stream_name=None)
    adapter = PureJaxAdapter(config, normalize)
    batch = Batch(...)  # batch construction elided
    normalized = adapter(batch)

Parameters:

  • config (ExternalAdapterConfig, required): Configuration (must have stochastic=False).
  • fn (Callable[[dict[str, Any]], dict[str, Any]], required): Pure function taking a data dict and returning a data dict.
  • name (str | None, default None): Optional module name.

fn instance-attribute

fn = fn

config instance-attribute

rngs instance-attribute

rngs = rngs

name instance-attribute

name = static(name)

stochastic instance-attribute

stochastic = static(stochastic)

stream_name instance-attribute

stream_name = static(stream_name)

generate_random_params

generate_random_params(rng: Array, data_shapes: PyTree) -> None

No random params for pure functions; always returns None.

apply

apply(data: PyTree, state: PyTree, metadata: dict[str, Any] | None, random_params: Any = None, stats: dict[str, Any] | None = None) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply the pure function.

get_operation_stats

get_operation_stats() -> dict[str, int]

Get operation statistics.

Note: This method converts JAX arrays to Python ints for introspection. It is intended for use outside of JIT-compiled functions.

Returns:

  • dict[str, int]: Dictionary with 'applied_count' and 'skipped_count'.

reset_operation_stats

reset_operation_stats() -> None

Reset operation statistics to zero.

Note: Creates new JAX arrays to reset the counters.

compute_statistics

compute_statistics(data: Any) -> dict[str, Any] | None

Compute statistics from data using batch_stats_fn.

If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.

Parameters:

  • data (Any, required): Input data to compute statistics from.

Returns:

  • dict[str, Any] | None: Dictionary of statistics, or None if no batch_stats_fn is configured.

get_statistics

get_statistics() -> dict[str, Any] | None

Get current statistics.

Returns precomputed_stats if configured (unless reset_statistics() was called); otherwise returns the cached computed statistics, or None if no statistics are available.

Returns:

  • dict[str, Any] | None: Dictionary of statistics, or None if no statistics are available.

set_statistics

set_statistics(stats: dict[str, Any]) -> None

Manually set statistics.

This overwrites any previously computed statistics and clears the reset flag.

Parameters:

  • stats (dict[str, Any], required): Dictionary of statistics to set.

reset_statistics

reset_statistics() -> None

Reset all statistics to None.

This clears the computed statistics and sets an internal flag so that precomputed_stats is ignored. After a reset, get_statistics() returns None until new statistics are set or computed.

reset_cache

reset_cache() -> None

Clear the cache.

Has an effect only if cacheable=True in the config.

copy

copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule

Create a copy of this module with optional config/parameter changes.

This creates a new module instance with a modified configuration while preserving the other attributes, which is useful for hyperparameter tuning.

Parameters:

  • config (DataraxModuleConfig | None, default None): New config (if None, uses the current config).
  • rngs (Rngs | None, default None): New RNG state (if None, uses the current rngs).
  • name (str | None, default None): New name (if None, uses the current name).

Returns:

  • DataraxModule: New module instance with updated parameters.

Examples:

    # Change configuration
    new_config = DataraxModuleConfig(cacheable=True)
    new_module = module.copy(config=new_config)

    # Change name only
    renamed = module.copy(name="new_name")

Note

Subclasses can override this method to provide finer-grained control over copying, such as allowing individual config fields to be updated without calling dataclasses.replace().

get_state

get_state() -> dict[str, Any]

Get module state for checkpointing.

This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.

Returns:

  • dict[str, Any]: A dictionary containing the internal state of the component.

set_state

set_state(state: dict[str, Any]) -> None

Restore module state from a checkpoint.

This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: the checkpoint structure must match the module state.

Parameters:

  • state (dict[str, Any], required): A dictionary containing the internal state to restore.

Raises:

  • TypeError: If state is not a dictionary.
  • ValueError: If the checkpoint structure does not match the module state.

clone

clone() -> DataraxModule

Create a new instance with the same state as this module.

Uses NNX's clone function to deep-copy all state.

Returns:

  • DataraxModule: A new module instance with the same state.

requires_rng_streams

requires_rng_streams() -> list[str] | None

Get the list of RNG streams required by this module.

Returns:

  • list[str] | None: A list of required RNG stream names, or None if no RNG streams are required.

ensure_rng_streams

ensure_rng_streams(stream_names: list[str]) -> None

Ensure that the required RNG streams are available.

Parameters:

  • stream_names (list[str], required): A list of available RNG stream names.

Raises:

  • ValueError: If a required RNG stream is not available.

get_output_structure

get_output_structure(sample_data: PyTree, sample_state: PyTree) -> tuple[PyTree, PyTree]

Declare output PyTree structure for vmap axis specification.

The default uses jax.eval_shape to discover the structure automatically. Override it for efficiency, or when eval_shape doesn't work (e.g., data-dependent shapes).

Parameters:

  • sample_data (PyTree, required): Single element data (not batched).
  • sample_state (PyTree, required): Single element state (not batched).

Returns:

  • tuple[PyTree, PyTree]: Tuple of (output_data_structure, output_state_structure) with None leaves. The structure (keys/nesting) matters; leaf values are ignored.

Example override for an operator that adds keys:

    import jax

    def get_output_structure(self, sample_data, sample_state):
        out_data = {
            **jax.tree.map(lambda _: None, sample_data),
            "score": None,
            "alignment": None,
        }
        return out_data, sample_state

apply_batch

apply_batch(batch: Batch, stats: dict[str, Any] | None = None) -> Batch

Process an entire batch with vmap and optional RNG generation.

This method implements batch processing for both stochastic and deterministic modes. It branches statically on self.stochastic (a static attribute) for JIT compilation efficiency.

The implementation delegates to _vmap_apply() for the shared computational core, then wraps the result in a Batch object.

Parameters:

  • batch (Batch, required): Input batch (Batch[Element] structure).
  • stats (dict[str, Any] | None, default None): Optional statistics (if None, uses get_statistics()).

Returns:

  • Batch: Transformed batch with the same structure.

Note

This method is concrete (not abstract). Subclasses typically don't override it, but can if they need custom batch processing logic.

output_spec

output_spec(input_spec: PyTree) -> PyTree

Return the operator's output spec given an input spec.

Most operators (normalization, additive noise, simple element-wise transforms) do not change shape, so the default returns input_spec unchanged. Shape-changing operators (Resize, Crop, Reshape) MUST override this method.

Parameters:

  • input_spec (PyTree, required): PyTree of jax.ShapeDtypeStruct describing the input element (matching the upstream DataSourceModule.element_spec() or another operator's output_spec).

Returns:

  • PyTree: PyTree of jax.ShapeDtypeStruct describing the operator's output. By default, equal to input_spec.

to_datarax_operator

to_datarax_operator(fn: Callable[..., Any], stochastic: bool = True, *, stream_name: str | None = 'augment', rngs: Rngs | None = None, name: str | None = None) -> OperatorModule

Convert a function into a Datarax OperatorModule.

This utility simplifies creating operator adapters: depending on stochastic, it returns either an ExternalLibraryAdapter or a PureJaxAdapter.

Parameters:

  • fn (Callable[..., Any], required): The function to adapt.
  • stochastic (bool, default True): Whether the function uses randomness.
  • stream_name (str | None, default 'augment'): Name of the RNG stream (required if stochastic=True).
  • rngs (Rngs | None, default None): Rngs object (required if stochastic=True).
  • name (str | None, default None): Name of the module.

Returns:

  • OperatorModule: An OperatorModule (either ExternalLibraryAdapter or PureJaxAdapter).

Examples:

    # Pure function
    op = to_datarax_operator(lambda d: d, stochastic=False)

    # Stochastic function (aug_fn takes (data, key); rngs is an nnx.Rngs)
    op = to_datarax_operator(aug_fn, stochastic=True, rngs=rngs)
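The dispatch described above can be sketched with stand-in classes. StochasticAdapterSketch, PureAdapterSketch, and to_operator_sketch are hypothetical; raising ValueError when rngs or stream_name is missing is an assumed way of enforcing the "required if stochastic=True" rule, not confirmed Datarax behavior.

```python
from __future__ import annotations

from typing import Any, Callable


class StochasticAdapterSketch:
    """Hypothetical stand-in for ExternalLibraryAdapter."""
    def __init__(self, fn, stream_name, rngs):
        self.fn, self.stream_name, self.rngs = fn, stream_name, rngs


class PureAdapterSketch:
    """Hypothetical stand-in for PureJaxAdapter."""
    def __init__(self, fn):
        self.fn = fn


def to_operator_sketch(fn: Callable[..., Any], stochastic: bool = True, *,
                       stream_name: str | None = "augment", rngs: Any = None):
    if stochastic:
        # Assumption: missing RNG inputs are rejected up front.
        if rngs is None or stream_name is None:
            raise ValueError("stochastic operators need both rngs and stream_name")
        return StochasticAdapterSketch(fn, stream_name, rngs)
    return PureAdapterSketch(fn)


pure_op = to_operator_sketch(lambda d: d, stochastic=False)
noisy_op = to_operator_sketch(lambda d, k: d, stochastic=True, rngs=object())
```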

with_jax_key_wrapper

with_jax_key_wrapper(fn: Callable[[Any, Array], Any]) -> Callable[[Any, RngStream], Any]

Wrap a function that requires a raw JAX PRNG key so that it works with an RngStream.

This function takes a function that expects a raw JAX PRNG key and returns a function that works with NNX RngStream objects instead.

Parameters:

  • fn (Callable[[Any, Array], Any], required): Function that takes (data, key), where key is a raw JAX PRNG key.

Returns:

  • Callable[[Any, RngStream], Any]: Function that takes (data, stream), where stream is an RngStream.

Examples:

    import jax
    import jax.numpy as jnp
    from flax import nnx
    from datarax.utils.external import with_jax_key_wrapper

    def external_fn(data, key):
        noise = jax.random.normal(key, shape=data.shape)
        return data + noise

    wrapped_fn = with_jax_key_wrapper(external_fn)
    rngs = nnx.Rngs(augment=42)
    data = jnp.ones((3,))
    result = wrapped_fn(data, rngs['augment'])
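One way such a wrapper could work is sketched below, assuming (as in Flax NNX) that calling an RngStream returns a fresh raw key. with_jax_key_wrapper_sketch is a hypothetical name, and the test uses a plain callable in place of a real RngStream so it runs without Flax. The same function also serves as a decorator, which is the role with_jax_key plays.

```python
import functools

import jax
import jax.numpy as jnp


def with_jax_key_wrapper_sketch(fn):
    """Hypothetical sketch: adapt fn(data, key) to fn(data, stream)."""
    @functools.wraps(fn)
    def wrapped(data, stream):
        # Assumption: calling the stream yields a fresh raw key, as an NNX
        # RngStream does; any zero-argument callable returning a key works here.
        return fn(data, stream())
    return wrapped


@with_jax_key_wrapper_sketch  # decorator form, analogous to with_jax_key
def external_fn(data, key):
    noise = jax.random.normal(key, shape=data.shape)
    return data + noise


fake_stream = lambda: jax.random.PRNGKey(0)  # stand-in for rngs['augment']
result = external_fn(jnp.zeros(3), fake_stream)
```

functools.wraps keeps the wrapped function's name and docstring, which helps when the adapted function shows up in logs or tracebacks.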

with_jax_key

with_jax_key(fn: Callable[[Any, Array], Any]) -> Callable[[Any, RngStream], Any]

Decorator version of with_jax_key_wrapper.

This decorator can be applied to functions that require raw JAX PRNG keys to make them compatible with NNX RngStream objects.

Parameters:

  • fn (Callable[[Any, Array], Any], required): Function that takes (data, key), where key is a raw JAX PRNG key.

Returns:

  • Callable[[Any, RngStream], Any]: Function that takes (data, stream), where stream is an RngStream.

Examples:

    import jax
    import jax.numpy as jnp
    from flax import nnx
    from datarax.utils.external import with_jax_key

    @with_jax_key
    def external_fn(data, key):
        noise = jax.random.normal(key, shape=data.shape)
        return data + noise

    rngs = nnx.Rngs(augment=42)
    data = jnp.ones((3,))
    result = external_fn(data, rngs['augment'])