Module Protocol
Core protocol for Datarax modules.
See Also
- Core Overview - All core protocols
- NNX Best Practices - Module patterns
- Checkpointing - Module checkpoints
- Distributed Training
datarax.core.module
Base module class for all Datarax modules.
This module provides DataraxModule, the base class from which all Datarax modules inherit. It provides common functionality such as:
- Statistics computation and management
- Caching system
- Iteration tracking
- Module copying
- NNX compliance
It also provides CheckpointableIteratorModule for data sources that need iteration state tracking (position, epoch) for resumable training.
IterationCount
Bases: Variable
Variable type for iteration counters.
This custom Variable type wraps JAX arrays for iteration counters. Using jnp.array(0) instead of a plain Python int is critical because:
- Python ints are classified as "static" by NNX (not data)
- Static values cannot be mutated inside JAX transforms
- JAX arrays are classified as "data" and CAN be mutated in transforms
- This avoids a TraceContextError when counters are mutated inside nnx.jit or nnx.vmap
The custom type also enables StateAxes control:
- Broadcast: nnx.StateAxes({IterationCount: None})
- Carry: nnx.StateAxes({IterationCount: nnx.Carry})
DataraxModule
DataraxModule(config: DataraxModuleConfig, *, rngs: Rngs | None = None, name: str | None = None)
Bases: Module
Base class for all Datarax modules.
Provides common functionality shared by all Datarax modules including statistics management, caching, iteration tracking, and module copying.
All modules use config-based initialization with typed, validated config dataclasses.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | DataraxModuleConfig | DataraxModuleConfig (already validated via `__post_init__`) | required |
| rngs | Rngs \| None | Random number generators (optional) | None |
| name | str \| None | Optional name for the module | None |

Attributes:

| Name | Type | Description |
|---|---|---|
| config | | Module configuration |
| rngs | | Random number generators |
| name | | Module name |
| _cache | dict[int, Any] \| None | Cache storage (plain dict if cacheable, None otherwise) |
| _computed_stats | Variable[dict[str, Any] \| None] | Computed statistics (nnx.Variable) |
| _applied_count | IterationCount | Applied operation counter (IterationCount) |
| _skipped_count | IterationCount | Skipped operation counter (IterationCount) |

get_operation_stats
reset_operation_stats
Reset operation statistics to zero.
Note: Creates new JAX arrays to reset the counters.
compute_statistics
Compute statistics from data using batch_stats_fn.
If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data | Any | Input data to compute statistics from | required |

Returns:

| Type | Description |
|---|---|
| dict[str, Any] \| None | Dictionary of statistics, or None if no batch_stats_fn is configured |

get_statistics
set_statistics
reset_statistics
Reset all statistics to None.
This clears the computed statistics and sets an internal flag so that precomputed_stats is ignored. After a reset, get_statistics() returns None until new statistics are set or computed.
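The statistics lifecycle described above (compute and cache via batch_stats_fn, return None when no function is configured, ignore precomputed statistics after a reset) can be modelled in plain Python. This is a minimal sketch, not Datarax's implementation: the class StatsLifecycleSketch, its constructor arguments, and the _ignore_precomputed flag are hypothetical, and the precedence of computed over precomputed statistics in get_statistics is an assumption.

```python
from __future__ import annotations

from typing import Any, Callable


class StatsLifecycleSketch:
    """Hypothetical stand-in for DataraxModule's statistics handling."""

    def __init__(
        self,
        batch_stats_fn: Callable[[Any], dict[str, Any]] | None = None,
        precomputed_stats: dict[str, Any] | None = None,
    ) -> None:
        self.batch_stats_fn = batch_stats_fn
        self.precomputed_stats = precomputed_stats
        self._computed_stats: dict[str, Any] | None = None
        self._ignore_precomputed = False  # hypothetical internal flag

    def compute_statistics(self, data: Any) -> dict[str, Any] | None:
        # No batch_stats_fn configured: nothing to compute.
        if self.batch_stats_fn is None:
            return None
        self._computed_stats = self.batch_stats_fn(data)  # cache the result
        return self._computed_stats

    def get_statistics(self) -> dict[str, Any] | None:
        # Assumption: computed statistics take precedence over precomputed ones.
        if self._computed_stats is not None:
            return self._computed_stats
        if self._ignore_precomputed:
            return None
        return self.precomputed_stats

    def reset_statistics(self) -> None:
        # Clear computed stats and stop falling back to precomputed stats.
        self._computed_stats = None
        self._ignore_precomputed = True


def mean_stats(data: list[float]) -> dict[str, Any]:
    return {"mean": sum(data) / len(data)}


module = StatsLifecycleSketch(batch_stats_fn=mean_stats, precomputed_stats={"mean": 0.0})
print(module.get_statistics())                     # {'mean': 0.0}
print(module.compute_statistics([1.0, 2.0, 3.0]))  # {'mean': 2.0}
module.reset_statistics()
print(module.get_statistics())                     # None
```

The flag, rather than deleting precomputed_stats, keeps the configured value intact while still letting get_statistics() report "no statistics" after a reset.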
copy
copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule
Create a copy of this module with optional config/parameter changes.
This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | DataraxModuleConfig \| None | New config (if None, uses current config) | None |
| rngs | Rngs \| None | New RNG state (if None, uses current rngs) | None |
| name | str \| None | New name (if None, uses current name) | None |

Returns:

| Type | Description |
|---|---|
| DataraxModule | New module instance with updated parameters |

Examples:

```python
# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)

# Change name only
renamed = module.copy(name="new_name")
```

Note
Subclasses can override this method to provide finer-grained control over copying, for example allowing individual config fields to be updated without requiring the caller to build a new config with dataclasses.replace().
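The note above suggests a subclass can accept individual config field updates instead of a full replacement config. A minimal sketch of that pattern using the standard library's dataclasses.replace(); SketchConfig (with its cacheable and batch_size fields) and SketchModule are hypothetical stand-ins, not Datarax types.

```python
from __future__ import annotations

import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class SketchConfig:
    cacheable: bool = False
    batch_size: int = 32  # hypothetical field, for illustration only


class SketchModule:
    def __init__(self, config: SketchConfig, *, name: str | None = None) -> None:
        self.config = config
        self.name = name

    def copy(
        self,
        *,
        config: SketchConfig | None = None,
        name: str | None = None,
        **config_overrides: Any,
    ) -> SketchModule:
        # Start from the explicit config if given, otherwise the current one.
        base = config if config is not None else self.config
        # Apply per-field overrides; replace() raises TypeError on unknown fields.
        new_config = dataclasses.replace(base, **config_overrides)
        return type(self)(new_config, name=name if name is not None else self.name)


original = SketchModule(SketchConfig(), name="loader")
tuned = original.copy(batch_size=64)
print(tuned.config)                # SketchConfig(cacheable=False, batch_size=64)
print(tuned.name)                  # loader
print(original.config.batch_size)  # 32
```

Because the config is a frozen dataclass, the original module's config is never mutated; each copy gets its own new config instance.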
get_state
Get module state for checkpointing.
This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.
Returns:

| Type | Description |
|---|---|
| dict[str, Any] | A dictionary containing the internal state of the component. |

set_state
Restore module state from a checkpoint.
This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | dict[str, Any] | A dictionary containing the internal state to restore. | required |

Raises:

| Type | Description |
|---|---|
| TypeError | If state is not a dictionary. |
| ValueError | If checkpoint structure does not match module state. |

clone
clone() -> DataraxModule
Create a new instance with the same state as this module.
Uses NNX's clone function for proper deep cloning of all state.
Returns:

| Type | Description |
|---|---|
| DataraxModule | A new module instance with the same state. |

requires_rng_streams
ensure_rng_streams
Ensure that the required RNG streams are available.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| stream_names | list[str] | A list of available RNG stream names. | required |

Raises:

| Type | Description |
|---|---|
| ValueError | If a required RNG stream is not available. |

CheckpointableIteratorModule
CheckpointableIteratorModule(config: DataraxModuleConfig, *, rngs: Rngs | None = None, name: str | None = None)
Bases: DataraxModule, Generic[T_co]
Base class for iterator modules that can be checkpointed.
This class extends DataraxModule to implement the CheckpointableIterator protocol, providing unified state management for iterators that need to save and restore their position and internal state for resumable training.
Useful for data sources, data loaders, and any module that iterates through data and needs checkpoint/restore capability.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | DataraxModuleConfig | DataraxModuleConfig for the module | required |
| rngs | Rngs \| None | Optional Rngs object for randomness | None |
| name | str \| None | Optional name for the module | None |

Attributes:

| Name | Type | Description |
|---|---|---|
| epoch | Variable[int \| None] | Current epoch (nnx.Variable) |
| position | Variable[int \| None] | Current position in iteration (nnx.Variable) |
| idx | Variable[int \| None] | Current index (nnx.Variable) |
| current | Variable[Any \| None] | Current item being processed (nnx.Variable) |

reset
Reset the iterator to its initial state.
Subclasses should override this to add additional reset logic.
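The epoch and position tracking that makes iteration resumable can be sketched in plain Python. Everything here is hypothetical: ResumableListIterator is not a Datarax class, it iterates over an in-memory list, and its get_state/set_state keep only epoch and position, whereas the documented module stores these as nnx.Variable attributes alongside idx and current.

```python
from __future__ import annotations

from typing import Any, Generic, Iterator, TypeVar

T = TypeVar("T")


class ResumableListIterator(Generic[T]):
    def __init__(self, data: list[T]) -> None:
        self.data = data
        self.reset()

    def reset(self) -> None:
        # Return to the initial state: first epoch, first element.
        self.epoch = 0
        self.position = 0

    def __iter__(self) -> Iterator[T]:
        return self

    def __next__(self) -> T:
        if self.position >= len(self.data):
            # End of this epoch: advance the epoch counter and rewind.
            self.epoch += 1
            self.position = 0
            raise StopIteration
        item = self.data[self.position]
        self.position += 1
        return item

    def get_state(self) -> dict[str, Any]:
        return {"epoch": self.epoch, "position": self.position}

    def set_state(self, state: dict[str, Any]) -> None:
        self.epoch = state["epoch"]
        self.position = state["position"]


it = ResumableListIterator(["a", "b", "c", "d"])
first_two = [next(it), next(it)]  # ['a', 'b']
checkpoint = it.get_state()       # {'epoch': 0, 'position': 2}

resumed = ResumableListIterator(["a", "b", "c", "d"])
resumed.set_state(checkpoint)
print(list(resumed))              # ['c', 'd']
print(resumed.get_state())        # {'epoch': 1, 'position': 0}
```

Restoring the saved position lets a fresh iterator continue exactly where the checkpointed one stopped instead of replaying elements already seen in that epoch.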
The remaining members (get_operation_stats, reset_operation_stats, compute_statistics, get_statistics, set_statistics, reset_statistics, copy, get_state, set_state, clone, requires_rng_streams, ensure_rng_streams) carry the same documentation as the corresponding DataraxModule members above.