External Utilities
Utilities for external library integration.
See Also
- Utils Overview - All utilities
- Data Sources - External data
- HF Source - HuggingFace
- TFDS Source - TensorFlow Datasets
datarax.utils.external
External utility functions for Datarax.
This module provides utility functions for working with external libraries and interfaces, particularly focused on JAX and Flax NNX integration.
ExternalAdapterConfig (dataclass)
ExternalAdapterConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, stochastic: bool = True, stream_name: str | None = 'augment', batch_strategy: str = 'vmap')
Bases: OperatorConfig
Configuration for ExternalLibraryAdapter and PureJaxAdapter.
Inherits from OperatorConfig. Stochastic by default (stochastic=True), since external functions typically require RNG keys; PureJaxAdapter expects stochastic=False.
Attributes:
| Name | Type | Description |
|---|---|---|
| stream_name | str \| None | Name of the RNG stream to use (default: "augment"). |
precomputed_stats (class attribute, instance attribute)
ExternalLibraryAdapter
ExternalLibraryAdapter(config: ExternalAdapterConfig, fn: Callable[[dict[str, Any], Array], dict[str, Any]], *, rngs: Rngs | None = None, name: str | None = None)
Bases: OperatorModule
Adapter for external libraries that require raw JAX PRNG keys.
This adapter provides a module-based approach for integrating with external libraries that require raw JAX PRNG keys, maintaining compatibility with NNX transformations like nnx.jit, nnx.vmap, etc.
Use this when you need to:
- Wrap external functions that use JAX keys in an NNX module
- Apply NNX transformations (jit, vmap) to functions using JAX keys
- Integrate external augmentation libraries into Datarax pipelines
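Examples:
```python
import jax
from flax import nnx
from datarax.utils.external import ExternalAdapterConfig, ExternalLibraryAdapter

def augment_fn(data, key):
    # fn receives the element's data dict and a raw JAX PRNG key
    noise = jax.random.normal(key, shape=data["image"].shape)
    return {**data, "image": data["image"] + noise * 0.1}

config = ExternalAdapterConfig()  # defaults: stochastic=True, stream_name="augment"
rngs = nnx.Rngs(augment=42)
adapter = ExternalLibraryAdapter(config, augment_fn, rngs=rngs)
batch = Batch(...)
augmented = adapter(batch)
```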
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ExternalAdapterConfig | Configuration for the adapter. | required |
| fn | Callable[[dict[str, Any], Array], dict[str, Any]] | Function called as fn(data_dict, key), where data_dict is the element's data dictionary and key is a raw JAX PRNG key; it returns the transformed data dictionary. | required |
| rngs | Rngs \| None | Rngs object supplying the configured RNG stream (needed when the adapter is stochastic, which is the default). | None |
| name | str \| None | Optional name for the module. | None |
get_output_structure
get_output_structure(sample_data: PyTree, sample_state: PyTree) -> tuple[PyTree, PyTree]
Declare output structure for vmap axis specification.
The default implementation discovers the structure by tracing apply() with jax.eval_shape, but ExternalLibraryAdapter.apply() requires random_params, which is not available during that tracing. This override therefore traces fn directly with a dummy key.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sample_data | PyTree | Single element data (not batched) | required |
| sample_state | PyTree | Single element state (not batched) | required |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree] | Tuple of (output_data_structure, output_state_structure) with 0 leaves. |
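Tracing through fn with a dummy key can be illustrated in plain JAX. This is a minimal sketch, not the Datarax implementation: `augment_fn` and `output_structure` are hypothetical names, and using `jax.random.PRNGKey(0)` as the dummy key is an assumption. It shows that `jax.eval_shape` only propagates shapes and dtypes, so the placeholder key never produces real randomness.

```python
import jax
import jax.numpy as jnp


def augment_fn(data, key):
    # Hypothetical external function: perturbs "image" and adds a new key.
    noise = jax.random.normal(key, shape=data["image"].shape)
    return {**data, "image": data["image"] + 0.1 * noise, "noise_scale": jnp.float32(0.1)}


def output_structure(fn, sample_data):
    # Trace fn abstractly with a placeholder key: only shapes and dtypes are
    # propagated, so no real randomness is consumed and nothing is computed.
    dummy_key = jax.random.PRNGKey(0)
    out = jax.eval_shape(fn, sample_data, dummy_key)
    # Keep only the keys/nesting by replacing every leaf with 0.
    return jax.tree_util.tree_map(lambda _: 0, out)


sample = {"image": jnp.zeros((32, 32, 3), jnp.float32)}
print(output_structure(augment_fn, sample))  # {'image': 0, 'noise_scale': 0}
```

Because the structure comes from fn itself, keys that fn adds (here the hypothetical "noise_scale") show up in the declared output structure.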
generate_random_params
apply
apply(data: PyTree, state: PyTree, metadata: dict[str, Any] | None, random_params: Array | None = None, stats: dict[str, Any] | None = None) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply the external function to a single element.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | Element data dictionary | required |
| state | PyTree | Element state (passed through) | required |
| metadata | dict[str, Any] \| None | Element metadata (passed through) | required |
| random_params | Array \| None | Random key for this element | None |
| stats | dict[str, Any] \| None | Statistics (unused) | None |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree, dict[str, Any] \| None] | Tuple of (transformed_data, state, metadata) |
get_operation_stats
reset_operation_stats
Reset operation statistics to zero.
Note: Creates new JAX arrays to reset the counters.
compute_statistics
Compute statistics from data using batch_stats_fn.
If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | Any | Input data to compute statistics from | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] \| None | Dictionary of statistics, or None if no batch_stats_fn is configured |
get_statistics
set_statistics
reset_statistics
Reset all statistics to None.
This clears the computed statistics and sets an internal flag so that precomputed_stats is ignored. After a reset, get_statistics() returns None until new statistics are set or computed.
copy
copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule
Create a copy of this module with optional config/parameter changes.
This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | DataraxModuleConfig \| None | New config (if None, uses current config) | None |
| rngs | Rngs \| None | New RNG state (if None, uses current rngs) | None |
| name | str \| None | New name (if None, uses current name) | None |
Returns:
| Type | Description |
|---|---|
| DataraxModule | New module instance with updated parameters |
Examples:
```python
# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)

# Change name only
renamed = module.copy(name="new_name")
```
Note: Subclasses can override this method to provide finer-grained control over copying, such as allowing individual config fields to be updated without calling dataclasses.replace().
get_state
Get module state for checkpointing.
This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A dictionary containing the internal state of the component. |
set_state
Restore module state from a checkpoint.
This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: the checkpoint structure must match the module state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | dict[str, Any] | A dictionary containing the internal state to restore. | required |
Raises:
| Type | Description |
|---|---|
| TypeError | If state is not a dictionary. |
| ValueError | If checkpoint structure does not match module state. |
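Strict restoration means the checkpoint's nesting must match the live state exactly. A minimal sketch of such a check on plain nested dicts follows; `check_checkpoint` is hypothetical, and representing state as nested dicts (rather than an NNX state object) is a simplifying assumption. It raises the same exception types the table above documents.

```python
from typing import Any


def _structure(tree: Any) -> Any:
    # Reduce a nested dict to its key structure, ignoring leaf values.
    if isinstance(tree, dict):
        return {k: _structure(v) for k, v in tree.items()}
    return None


def check_checkpoint(current_state: dict, checkpoint: Any) -> None:
    # Hypothetical strict check mirroring the documented set_state errors.
    if not isinstance(checkpoint, dict):
        raise TypeError(f"state must be a dict, got {type(checkpoint).__name__}")
    if _structure(checkpoint) != _structure(current_state):
        raise ValueError("checkpoint structure does not match module state")


current = {"count": 0, "rngs": {"augment": {"key": 1, "count": 0}}}
check_checkpoint(current, {"count": 5, "rngs": {"augment": {"key": 9, "count": 2}}})  # passes
```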
clone
clone() -> DataraxModule
Create a new instance with the same state as this module.
Uses NNX's clone function for proper deep cloning of all state.
Returns:
| Type | Description |
|---|---|
| DataraxModule | A new module instance with the same state. |
requires_rng_streams
ensure_rng_streams
Ensure that the required RNG streams are available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stream_names | list[str] | A list of available RNG stream names. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If a required RNG stream is not available. |
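The availability check can be sketched as a subset test. `ensure_streams_sketch` is hypothetical, and the assumption that a stochastic adapter with the default config requires exactly ["augment"] (what requires_rng_streams would report) is not confirmed by the reference above.

```python
def ensure_streams_sketch(required: list[str], stream_names: list[str]) -> None:
    # Hypothetical check: every required stream must be among the available ones.
    missing = [name for name in required if name not in stream_names]
    if missing:
        raise ValueError(f"Missing required RNG stream(s): {missing}; available: {stream_names}")


# Assumed requirement for a stochastic adapter with stream_name="augment".
ensure_streams_sketch(["augment"], ["augment", "params"])  # passes silently
```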
apply_batch
Process an entire batch with vmap and optional RNG generation.
This method implements the batch processing logic for both stochastic and deterministic modes. It uses static branching on self.stochastic for JIT compilation efficiency.
The implementation delegates to _vmap_apply() for the shared computational core, then wraps the result in a Batch object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Batch | Input batch (Batch[Element] structure) | required |
| stats | dict[str, Any] \| None | Optional statistics (if None, uses get_statistics()) | None |
Returns:
| Type | Description |
|---|---|
| Batch | Transformed batch with the same structure |
Note: This method is concrete (not abstract). Subclasses typically don't override it, but can if they need custom batch processing logic.
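The batch logic described above (one RNG key per element, vmap over the batch, a Python-level branch on stochastic) can be approximated in plain JAX. `apply_batch_sketch` is a hypothetical stand-in that operates on a raw dict of arrays rather than a Batch, and splitting one key into per-element keys is an assumed scheme, not necessarily what _vmap_apply() does.

```python
import jax
import jax.numpy as jnp


def apply_batch_sketch(fn, batch_data, key=None, *, stochastic=True):
    # `stochastic` is a plain Python bool, so this branch is resolved while
    # tracing; under jax.jit it would be marked static and only one path compiled.
    if stochastic:
        batch_size = jax.tree_util.tree_leaves(batch_data)[0].shape[0]
        keys = jax.random.split(key, batch_size)  # one key per element
        return jax.vmap(fn)(batch_data, keys)     # fn(data, key)
    return jax.vmap(fn)(batch_data)               # fn(data)


def add_noise(data, key):
    return {"image": data["image"] + jax.random.normal(key, data["image"].shape)}


def scale(data):
    return {"image": data["image"] / 255.0}


noisy = apply_batch_sketch(add_noise, {"image": jnp.zeros((4, 2, 2))}, jax.random.PRNGKey(0))
scaled = apply_batch_sketch(scale, {"image": jnp.full((4, 2, 2), 255.0)}, stochastic=False)
```

Because each element gets its own key, the noise differs across the batch; the deterministic path never touches a key at all.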
output_spec
Return the operator's output spec given an input spec.
Most operators (normalization, additive noise, simple element-wise transforms) do not change shape, so the default returns input_spec unchanged. Shape-changing operators (Resize, Crop, Reshape) MUST override this method.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_spec | PyTree | PyTree of | required |
Returns:
| Type | Description |
|---|---|
| PyTree | PyTree of |
| PyTree | By default, equal to input_spec. |
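The difference between the default and a shape-changing override can be sketched with `jax.ShapeDtypeStruct` leaves. The leaf type is an assumption (the spec type is cut off in the table above), and `crop_output_spec` is a hypothetical override for a crop-like operator, not a Datarax function.

```python
import jax
import jax.numpy as jnp


def default_output_spec(input_spec):
    # Shape-preserving operators return the input spec unchanged.
    return input_spec


def crop_output_spec(input_spec, size):
    # Hypothetical crop operator: leading two dims (H, W) become (size, size).
    def crop(spec):
        _, _, *rest = spec.shape
        return jax.ShapeDtypeStruct((size, size, *rest), spec.dtype)

    return jax.tree_util.tree_map(crop, input_spec)


spec = {"image": jax.ShapeDtypeStruct((32, 32, 3), jnp.float32)}
print(default_output_spec(spec)["image"].shape)  # (32, 32, 3)
print(crop_output_spec(spec, 24)["image"].shape)  # (24, 24, 3)
```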
PureJaxAdapter
PureJaxAdapter(config: ExternalAdapterConfig, fn: Callable[[dict[str, Any]], dict[str, Any]], *, name: str | None = None)
Bases: OperatorModule
Adapter for pure JAX functions (stateless, no RNG).
This adapter wraps pure JAX functions of the form fn(data) -> data.
It sets stochastic=False by default and does not generate random params.
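Examples:
```python
from datarax.utils.external import ExternalAdapterConfig, PureJaxAdapter

def normalize(data):
    # Scale pixel values from [0, 255] to [0, 1]; no RNG key is needed
    return {**data, "image": data["image"] / 255.0}

config = ExternalAdapterConfig(stochastic=False, stream_name=None)
adapter = PureJaxAdapter(config, normalize)
batch = Batch(...)
normalized = adapter(batch)
```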
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ExternalAdapterConfig | Configuration (must have stochastic=False). | required |
| fn | Callable[[dict[str, Any]], dict[str, Any]] | Pure function taking a data dict and returning a data dict. | required |
| name | str \| None | Optional module name. | None |
generate_random_params
generate_random_params(rng: Array, data_shapes: PyTree) -> None
No random params for pure functions; always returns None.
apply
apply(data: PyTree, state: PyTree, metadata: dict[str, Any] | None, random_params: Any = None, stats: dict[str, Any] | None = None) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply the pure function fn to a single element.
get_operation_stats, reset_operation_stats, compute_statistics, get_statistics, set_statistics, reset_statistics, copy, get_state, set_state, clone, requires_rng_streams, ensure_rng_streams: identical to the corresponding ExternalLibraryAdapter entries above.
get_output_structure
get_output_structure(sample_data: PyTree, sample_state: PyTree) -> tuple[PyTree, PyTree]
Declare output PyTree structure for vmap axis specification.
The default uses jax.eval_shape to discover the structure automatically. Override it for efficiency or when eval_shape doesn't work (e.g., data-dependent shapes).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sample_data | PyTree | Single element data (not batched) | required |
| sample_state | PyTree | Single element state (not batched) | required |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree] | Tuple of (output_data_structure, output_state_structure) with None leaves. The structure (keys/nesting) matters; leaf values are ignored. |
Example override for an operator that adds keys:
```python
def get_output_structure(self, sample_data, sample_state):
    out_data = {
        **jax.tree.map(lambda _: None, sample_data),  # keep existing keys
        "score": None,      # new keys produced by the operator
        "alignment": None,
    }
    return out_data, sample_state
```
apply_batch, output_spec: identical to the corresponding ExternalLibraryAdapter entries above.
to_datarax_operator
to_datarax_operator(fn: Callable[..., Any], stochastic: bool = True, *, stream_name: str | None = 'augment', rngs: Rngs | None = None, name: str | None = None) -> OperatorModule
Convert a function into a Datarax OperatorModule.
This utility simplifies the creation of operator adapters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable[..., Any] | The function to adapt. | required |
| stochastic | bool | Whether the function uses randomness. | True |
| stream_name | str \| None | Name of the RNG stream (required if stochastic=True). | 'augment' |
| rngs | Rngs \| None | Rngs object (required if stochastic=True). | None |
| name | str \| None | Name of the module. | None |
Returns:
| Type | Description |
|---|---|
| OperatorModule | An OperatorModule (either ExternalLibraryAdapter or PureJaxAdapter). |
Examples:
```python
# Pure function
op = to_datarax_operator(lambda d: d, stochastic=False)

# Stochastic function
op = to_datarax_operator(aug_fn, stochastic=True, rngs=rngs)
```
with_jax_key_wrapper
Wrap a function that requires a raw JAX PRNG key so that it works with an NNX RngStream.
This function takes a function that expects a raw JAX PRNG key and returns a function that accepts an NNX RngStream object instead.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable[[Any, Array], Any] | Function that takes (data, key) where key is a raw JAX PRNG key | required |
Returns:
| Type | Description |
|---|---|
| Callable[[Any, RngStream], Any] | Function that takes (data, stream) where stream is an RngStream |
Examples:
```python
import jax
import jax.numpy as jnp
from flax import nnx
from datarax.utils.external import with_jax_key_wrapper

def external_fn(data, key):
    noise = jax.random.normal(key, shape=data.shape)
    return data + noise

wrapped_fn = with_jax_key_wrapper(external_fn)
rngs = nnx.Rngs(augment=42)
data = jnp.zeros((3,))  # example input
result = wrapped_fn(data, rngs['augment'])
```
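One plausible way such a wrapper could work is to draw a fresh key from the stream and forward it to fn. This is a sketch, not the Datarax implementation: `with_jax_key_wrapper_sketch` and `FakeStream` are hypothetical, and the assumption that calling an RngStream yields a new raw key is modelled by `FakeStream`, which lets the sketch run without Flax.

```python
import jax
import jax.numpy as jnp


def with_jax_key_wrapper_sketch(fn):
    # Hypothetical re-implementation: adapt fn(data, key) to fn(data, stream).
    def wrapped(data, stream):
        key = stream()  # assumption: calling the stream draws a fresh raw key
        return fn(data, key)

    return wrapped


class FakeStream:
    # Stand-in for an NNX RngStream so the sketch runs without Flax.
    def __init__(self, seed: int):
        self.key = jax.random.PRNGKey(seed)

    def __call__(self):
        self.key, sub = jax.random.split(self.key)
        return sub


def external_fn(data, key):
    return data + jax.random.normal(key, shape=data.shape)


wrapped = with_jax_key_wrapper_sketch(external_fn)
stream = FakeStream(42)
a = wrapped(jnp.zeros(3), stream)
b = wrapped(jnp.zeros(3), stream)  # new key on each call, so b differs from a
```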
with_jax_key
Decorator version of with_jax_key_wrapper.
This decorator can be applied to functions that require raw JAX PRNG keys to make them compatible with NNX RngStream objects.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable[[Any, Array], Any] | Function that takes (data, key) where key is a raw JAX PRNG key | required |
Returns:
| Type | Description |
|---|---|
| Callable[[Any, RngStream], Any] | Function that takes (data, stream) where stream is an RngStream |
Examples:
```python
import jax
import jax.numpy as jnp
from flax import nnx
from datarax.utils.external import with_jax_key

@with_jax_key
def external_fn(data, key):
    noise = jax.random.normal(key, shape=data.shape)
    return data + noise

rngs = nnx.Rngs(augment=42)
data = jnp.zeros((3,))  # example input
result = external_fn(data, rngs['augment'])
```