
Module Protocol

Core protocol for Datarax modules.


datarax.core.module

Base module class for all Datarax modules.

This module provides DataraxModule, the base class that all Datarax modules inherit from. It provides common functionality such as:

  • Statistics computation and management
  • Caching
  • Iteration tracking
  • Module copying
  • Flax NNX compliance

It also provides CheckpointableIteratorModule for data sources that need to track iteration state (position, epoch) for resumable training.

logger module-attribute

logger = getLogger(__name__)

T_co module-attribute

T_co = TypeVar('T_co', covariant=True)

IterationCount

Bases: Variable

Variable type for iteration counters.

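This custom Variable type wraps the JAX arrays used as iteration counters. Storing a count as jnp.array(0) instead of a plain Python int matters because:

  1. NNX classifies Python ints as static (not data).
  2. Static values cannot be mutated inside JAX transforms.
  3. NNX classifies JAX arrays as data, so they can be mutated inside transforms.
  4. As a result, counters can be updated inside nnx.jit or nnx.vmap without raising a TraceContextError.

IterationCount is a distinct type, so nnx.StateAxes can also target it directly. This controls how counters are handled under transforms:

  • Broadcast (the counter is shared rather than split across the mapped axis): nnx.StateAxes({IterationCount: None})
  • Carry (the counter is threaded through as carry state, e.g. in nnx.scan): nnx.StateAxes({IterationCount: nnx.Carry})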

DataraxModule

DataraxModule(config: DataraxModuleConfig, *, rngs: Rngs | None = None, name: str | None = None)

Bases: Module

Base class for all Datarax modules.

Provides common functionality shared by all Datarax modules including statistics management, caching, iteration tracking, and module copying.

All modules use config-based initialization with typed, validated config dataclasses.
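The config-based pattern can be illustrated with a frozen dataclass that validates itself in __post_init__. The following is a minimal sketch under that assumption. The class name ExampleModuleConfig is hypothetical, and so are its validation rules. The fields (cacheable, batch_stats_fn, precomputed_stats) reuse names mentioned on this page, but this is not the actual definition of DataraxModuleConfig.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ExampleModuleConfig:
    """Hypothetical config; the fields are illustrative, not DataraxModuleConfig's."""

    cacheable: bool = False
    batch_stats_fn: Callable[[Any], dict[str, Any]] | None = None
    precomputed_stats: dict[str, Any] | None = None

    def __post_init__(self) -> None:
        # Validate once, at construction time, so modules can trust the config.
        if not isinstance(self.cacheable, bool):
            raise TypeError("cacheable must be a bool")
        if self.batch_stats_fn is not None and not callable(self.batch_stats_fn):
            raise TypeError("batch_stats_fn must be callable or None")
        if self.precomputed_stats is not None and not isinstance(self.precomputed_stats, dict):
            raise TypeError("precomputed_stats must be a dict or None")


config = ExampleModuleConfig(cacheable=True)  # passes validation

try:
    ExampleModuleConfig(cacheable="yes")  # rejected before any module sees it
except TypeError as err:
    print(f"invalid config: {err}")
```

Because the config is frozen and checked on construction, a module can store it without re-validating. This matches the "already validated" note on the config parameter.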

Parameters:

  • config (DataraxModuleConfig, required): Module configuration, already validated in its __post_init__.
  • rngs (Rngs | None, default None): Optional random number generators.
  • name (str | None, default None): Optional name for the module.

Attributes:

  • config: Module configuration.
  • rngs: Random number generators.
  • name: Module name.
  • _cache (dict[int, Any] | None): Cache storage; a plain dict if the module is cacheable, None otherwise.
  • _computed_stats (Variable[dict[str, Any] | None]): Computed statistics, stored in an nnx.Variable.
  • _applied_count (IterationCount): Counter of applied operations.
  • _skipped_count (IterationCount): Counter of skipped operations.


config instance-attribute

config = static(config)

rngs instance-attribute

rngs = rngs

name instance-attribute

name = static(name)

get_operation_stats

get_operation_stats() -> dict[str, int]

Get operation statistics.

Note: This method converts JAX arrays to Python ints for introspection, so it is intended for use outside of JIT-compiled functions.

Returns:

  • dict[str, int]: Dictionary with the keys 'applied_count' and 'skipped_count'.

reset_operation_stats

reset_operation_stats() -> None

Reset operation statistics to zero.

Note: The counters are reset by creating new JAX arrays.

compute_statistics

compute_statistics(data: Any) -> dict[str, Any] | None

Compute statistics from data using batch_stats_fn.

If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.

Parameters:

  • data (Any, required): Input data to compute statistics from.

Returns:

  • dict[str, Any] | None: Dictionary of statistics, or None if no batch_stats_fn is configured.

get_statistics

get_statistics() -> dict[str, Any] | None

Get the current statistics.

Returns precomputed_stats if it is configured (unless reset_statistics() has been called). Otherwise it returns the cached computed statistics, or None if no statistics are available.

Returns:

  • dict[str, Any] | None: Dictionary of statistics, or None if no statistics are available.

set_statistics

set_statistics(stats: dict[str, Any]) -> None

Manually set statistics.

This overwrites any previously computed statistics and clears the reset flag.

Parameters:

  • stats (dict[str, Any], required): Dictionary of statistics to set.

reset_statistics

reset_statistics() -> None

Reset all statistics to None.

This clears the computed statistics and sets an internal flag so that precomputed_stats is ignored. After a reset, get_statistics() returns None until new statistics are set or computed.
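The lookup order described for get_statistics(), set_statistics() and reset_statistics() can be modeled in plain Python. This is a hypothetical sketch and not the Datarax implementation. The class name StatsSketch and the flag name _stats_reset are assumptions. Datarax keeps _computed_stats in an nnx.Variable, whereas the sketch uses plain attributes.

```python
from __future__ import annotations

from typing import Any, Callable


class StatsSketch:
    """Hypothetical model of the documented statistics lookup order."""

    def __init__(
        self,
        batch_stats_fn: Callable[[Any], dict[str, Any]] | None = None,
        precomputed_stats: dict[str, Any] | None = None,
    ) -> None:
        self.batch_stats_fn = batch_stats_fn
        self.precomputed_stats = precomputed_stats
        self._computed_stats: dict[str, Any] | None = None
        self._stats_reset = False  # hypothetical name for the internal reset flag

    def compute_statistics(self, data: Any) -> dict[str, Any] | None:
        if self.batch_stats_fn is None:
            return None
        self._computed_stats = self.batch_stats_fn(data)  # cache the result
        return self._computed_stats

    def get_statistics(self) -> dict[str, Any] | None:
        # precomputed_stats takes priority unless reset_statistics() was called.
        if self.precomputed_stats is not None and not self._stats_reset:
            return self.precomputed_stats
        return self._computed_stats

    def set_statistics(self, stats: dict[str, Any]) -> None:
        self._computed_stats = stats
        self._stats_reset = False

    def reset_statistics(self) -> None:
        self._computed_stats = None
        self._stats_reset = True


def mean_stats(xs: list[float]) -> dict[str, float]:
    return {"mean": sum(xs) / len(xs)}


module = StatsSketch(batch_stats_fn=mean_stats)
module.compute_statistics([1.0, 2.0, 3.0])
print(module.get_statistics())  # {'mean': 2.0}
module.reset_statistics()
print(module.get_statistics())  # None
```

Under this reading, a configured precomputed_stats keeps priority even after set_statistics() clears the reset flag. That edge case is worth confirming against the Datarax source.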

reset_cache

reset_cache() -> None

Clear the cache.

This has an effect only if cacheable=True in the config.
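The _cache attribute is documented as a plain dict keyed by int when the module is cacheable and None otherwise, so reset_cache() has nothing to do on a non-cacheable module. The following hypothetical sketch models only that behavior. The class name CacheSketch and the integer key used below are assumptions, and the sketch does not show how Datarax chooses cache keys.

```python
from __future__ import annotations

from typing import Any


class CacheSketch:
    """Hypothetical model of the _cache attribute and reset_cache()."""

    def __init__(self, cacheable: bool) -> None:
        # A plain dict when caching is enabled, None otherwise.
        self._cache: dict[int, Any] | None = {} if cacheable else None

    def reset_cache(self) -> None:
        # Nothing to clear on a non-cacheable module, so this is a no-op.
        if self._cache is not None:
            self._cache.clear()


cached = CacheSketch(cacheable=True)
cached._cache[0] = "element 0"  # hypothetical integer key
cached.reset_cache()

uncached = CacheSketch(cacheable=False)
uncached.reset_cache()  # no error and no effect
```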

copy

copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule

Create a copy of this module with optional config or parameter changes.

This creates a new module instance with a modified configuration while preserving the other attributes, which is useful for hyperparameter tuning.

Parameters:

  • config (DataraxModuleConfig | None, default None): New config; if None, the current config is used.
  • rngs (Rngs | None, default None): New RNG state; if None, the current rngs are used.
  • name (str | None, default None): New name; if None, the current name is used.

Returns:

  • DataraxModule: New module instance with the updated parameters.

Examples:

    # Change configuration
    new_config = DataraxModuleConfig(cacheable=True)
    new_module = module.copy(config=new_config)

    # Change name only
    renamed = module.copy(name="new_name")

Note: Subclasses can override this method to provide finer-grained control over copying. For example, an override could allow individual config fields to be updated without requiring dataclasses.replace().
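copy() takes a whole config object, so changing a single field of a frozen dataclass config typically means building a modified copy with dataclasses.replace() first. The following is a minimal sketch of that workflow. SketchConfig, its batch_size field and SketchModule are hypothetical stand-ins, and rngs is omitted for brevity. It only mirrors the fallback-to-current-value behavior documented above.

```python
from __future__ import annotations

import dataclasses


@dataclasses.dataclass(frozen=True)
class SketchConfig:
    """Hypothetical stand-in for DataraxModuleConfig."""

    cacheable: bool = False
    batch_size: int = 32  # hypothetical field, for illustration only


class SketchModule:
    """Hypothetical module with a copy() modeled on the signature above (rngs omitted)."""

    def __init__(self, config: SketchConfig, *, name: str | None = None) -> None:
        self.config = config
        self.name = name

    def copy(self, *, config: SketchConfig | None = None, name: str | None = None) -> SketchModule:
        # Anything not overridden falls back to the current value.
        return type(self)(
            config if config is not None else self.config,
            name=name if name is not None else self.name,
        )


module = SketchModule(SketchConfig(), name="loader")

# Change one config field: build a modified copy of the frozen config first.
tuned = module.copy(config=dataclasses.replace(module.config, batch_size=64))

# Change the name only; the config object is reused as-is.
renamed = module.copy(name="new_name")
```

A subclass override, as described in the note, could accept such field overrides directly and call dataclasses.replace() internally.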

get_state

get_state() -> dict[str, Any]

Get the module state for checkpointing.

This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.

Returns:

  • dict[str, Any]: A dictionary containing the internal state of the component.

set_state

set_state(state: dict[str, Any]) -> None

Restore the module state from a checkpoint.

This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: the checkpoint structure must match the module's state structure.

Parameters:

  • state (dict[str, Any], required): A dictionary containing the internal state to restore.

Raises:

  • TypeError: If state is not a dictionary.
  • ValueError: If the checkpoint structure does not match the module state.
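The strict restore rule can be pictured as a structural comparison made before any value is written. The following is a hypothetical stdlib sketch over nested dicts, not the NNX-based implementation Datarax uses. The names structure and StatefulSketch, and the state layout, are illustrative assumptions. It raises the same exception types documented above.

```python
from __future__ import annotations

import copy
from typing import Any


def structure(tree: Any) -> Any:
    """Return the nested key layout of a dict tree, ignoring leaf values."""
    if isinstance(tree, dict):
        return {key: structure(value) for key, value in tree.items()}
    return None  # every leaf compares equal


class StatefulSketch:
    """Hypothetical module whose state is a nested dict."""

    def __init__(self) -> None:
        self._state: dict[str, Any] = {"counters": {"applied": 0, "skipped": 0}, "epoch": 0}

    def get_state(self) -> dict[str, Any]:
        # Hand out a copy so the checkpoint cannot alias live state.
        return copy.deepcopy(self._state)

    def set_state(self, state: dict[str, Any]) -> None:
        if not isinstance(state, dict):
            raise TypeError(f"state must be a dict, got {type(state).__name__}")
        # Strict restore: keys and nesting must match exactly.
        if structure(state) != structure(self._state):
            raise ValueError("checkpoint structure does not match module state")
        self._state = copy.deepcopy(state)


source = StatefulSketch()
source._state["epoch"] = 3
checkpoint = source.get_state()

target = StatefulSketch()
target.set_state(checkpoint)
```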

clone

clone() -> DataraxModule

Create a new instance with the same state as this module.

Uses NNX's clone function to deep-copy all state.

Returns:

  • DataraxModule: A new module instance with the same state.

requires_rng_streams

requires_rng_streams() -> list[str] | None

Get the list of RNG streams required by this module.

Returns:

  • list[str] | None: A list of required RNG stream names, or None if no RNG streams are required.

ensure_rng_streams

ensure_rng_streams(stream_names: list[str]) -> None

Ensure that the required RNG streams are available.

Parameters:

  • stream_names (list[str], required): A list of available RNG stream names.

Raises:

  • ValueError: If a required RNG stream is not available.
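A minimal sketch of how requires_rng_streams() and ensure_rng_streams() could fit together follows. It assumes that ensure_rng_streams() compares the required names against the stream_names it receives. In Flax NNX, named streams are usually created as keyword arguments to nnx.Rngs, for example nnx.Rngs(default=0, augment=1). The stream name "augment" and the class RngSketch are hypothetical, and the sketch does not import Flax.

```python
from __future__ import annotations


class RngSketch:
    """Hypothetical module that needs one named RNG stream."""

    def requires_rng_streams(self) -> list[str] | None:
        return ["augment"]  # hypothetical stream name

    def ensure_rng_streams(self, stream_names: list[str]) -> None:
        # None from requires_rng_streams() means nothing is required.
        required = self.requires_rng_streams() or []
        missing = [name for name in required if name not in stream_names]
        if missing:
            raise ValueError(f"missing required RNG streams: {missing}")


module = RngSketch()
module.ensure_rng_streams(["default", "augment"])  # passes silently
```

Checking up front like this turns a missing stream into a clear error at setup time rather than a failure deep inside a transformed function.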

CheckpointableIteratorModule

CheckpointableIteratorModule(config: DataraxModuleConfig, *, rngs: Rngs | None = None, name: str | None = None)

Bases: DataraxModule, Generic[T_co]

Base class for iterator modules that can be checkpointed.

This class extends DataraxModule to implement the CheckpointableIterator protocol, providing unified state management for iterators that need to save and restore their position and internal state for resumable training.

It is suited to data sources, data loaders, and any other module that iterates through data and needs checkpoint/restore capability.

Parameters:

  • config (DataraxModuleConfig, required): DataraxModuleConfig for the module.
  • rngs (Rngs | None, default None): Optional Rngs object for randomness.
  • name (str | None, default None): Optional name for the module.

Attributes:

  • epoch (Variable[int | None]): Current epoch (nnx.Variable).
  • position (Variable[int | None]): Current position in the iteration (nnx.Variable).
  • idx (Variable[int | None]): Current index (nnx.Variable).
  • current (Variable[Any | None]): Current item being processed (nnx.Variable).


epoch instance-attribute

epoch: Variable[int | None] = Variable(None)

position instance-attribute

position: Variable[int | None] = Variable(None)

idx instance-attribute

idx: Variable[int | None] = Variable(None)

current instance-attribute

current: Variable[Any | None] = Variable(None)

reset

reset() -> None

Reset the iterator to its initial state.

Subclasses should override this method to add their own reset logic.
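The checkpoint-and-resume contract can be sketched in plain Python. ListIteratorSketch is hypothetical. It keeps epoch and position as plain attributes (Datarax stores them in nnx.Variable), omits idx and current, and uses an invariant TypeVar instead of T_co. Its reset(), get_state() and set_state() only mirror the method names documented on this page.

```python
from __future__ import annotations

from typing import Any, Generic, Iterator, TypeVar

T = TypeVar("T")


class ListIteratorSketch(Generic[T]):
    """Hypothetical resumable iterator over an in-memory list."""

    def __init__(self, items: list[T]) -> None:
        self._items = items
        self.reset()

    def reset(self) -> None:
        # Return to the start of the first epoch.
        self.epoch = 0
        self.position = 0

    def __iter__(self) -> Iterator[T]:
        return self

    def __next__(self) -> T:
        if self.position >= len(self._items):
            # End of the epoch: advance the epoch counter, rewind, and stop.
            self.epoch += 1
            self.position = 0
            raise StopIteration
        item = self._items[self.position]
        self.position += 1
        return item

    def get_state(self) -> dict[str, Any]:
        return {"epoch": self.epoch, "position": self.position}

    def set_state(self, state: dict[str, Any]) -> None:
        self.epoch = state["epoch"]
        self.position = state["position"]


data = ["a", "b", "c", "d"]
loader = ListIteratorSketch(data)
first_two = [next(loader), next(loader)]  # ['a', 'b']
checkpoint = loader.get_state()           # {'epoch': 0, 'position': 2}

# Simulate a restart: a fresh iterator resumes from the checkpoint.
resumed = ListIteratorSketch(data)
resumed.set_state(checkpoint)
rest = list(resumed)                      # ['c', 'd']
```

Because only the position and epoch are saved, the resumed iterator continues with exactly the items the interrupted run had not yet produced.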
