Structural Types¤
Structural type definitions for data layouts.
See Also¤
- Core Overview - All core protocols
- Element Batch - Element/Batch types
- Types & Protocols - Type definitions
- Config - Configuration types
datarax.core.structural ¤
StructuralModule - unified non-parametric structural processor module.
This module provides StructuralModule, which unifies BatcherModule, SamplerModule, SharderModule, and other structural processors into a single base class for all non-parametric, structural data organization operations.
Key Features:
- Config-based initialization with StructuralConfig (frozen/immutable)
- Stochastic mode (with RNG for random organization)
- Deterministic mode (fixed organization)
- Single process() method for all structural operations
- No learnable parameters (compile-time constants only)
- JIT compatibility
- Statistics system (inherited from DataraxModule)
StructuralModule ¤
StructuralModule(config: StructuralConfig, *, rngs: Rngs | None = None, name: str | None = None)
Bases: DataraxModule
Base class for non-parametric structural processors.
Structural modules organize/reorganize data without learnable parameters. Configuration is immutable (frozen dataclass) representing compile-time constants.
Structural modules change data structure/organization, not data values. They are NOT differentiable and have no learnable parameters.
The structural pattern uses a single process() method:
- process() - Transforms input structure (abstract method)
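The pattern described above (an immutable config plus a single abstract `process()` method) can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the datarax implementation: `SketchStructuralConfig`, `SketchBatcherConfig`, `SketchStructuralModule`, and `SketchBatcher` are hypothetical stand-ins, and plain lists take the place of `Element` and `Batch`.

```python
from abc import ABC, abstractmethod
from dataclasses import FrozenInstanceError, dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class SketchStructuralConfig:
    """Hypothetical stand-in for StructuralConfig (not the datarax class)."""

    stochastic: bool = False
    stream_name: Optional[str] = None


@dataclass(frozen=True)
class SketchBatcherConfig(SketchStructuralConfig):
    """Hypothetical batcher config; batch_size acts as a fixed constant."""

    batch_size: int = 2


class SketchStructuralModule(ABC):
    """Hypothetical stand-in for StructuralModule."""

    def __init__(self, config: SketchStructuralConfig) -> None:
        self.config = config

    @abstractmethod
    def process(self, input: Any, *args: Any, **kwargs: Any) -> Any:
        """Reorganize the input without changing its values."""


class SketchBatcher(SketchStructuralModule):
    def process(self, elements: list) -> list:
        # Group elements into consecutive chunks; the values are untouched.
        size = self.config.batch_size
        return [elements[i:i + size] for i in range(0, len(elements), size)]


batcher = SketchBatcher(SketchBatcherConfig(batch_size=2))
batches = batcher.process([10, 20, 30, 40, 50])

try:
    batcher.config.batch_size = 8  # frozen dataclass: assignment is rejected
    config_is_frozen = False
except FrozenInstanceError:
    config_is_frozen = True
```

Only the grouping of the elements changes; every value in `batches` is one of the original inputs, which is the sense in which a structural module changes organization rather than data.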
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | StructuralConfig | StructuralConfig (already validated via `__post_init__`, frozen) | required |
| rngs | Rngs \| None | Random number generators (required if stochastic=True) | None |
| name | str \| None | Optional name for the structural module | None |
Attributes:
| Name | Type | Description |
|---|---|---|
| config | | Structural module configuration (immutable) |
| stochastic | | Whether this module uses randomness (from config) |
| stream_name | | RNG stream name (from config, required if stochastic=True) |
Examples:
Deterministic batcher:
config = BatcherConfig(stochastic=False, batch_size=32)
batcher = BatcherModule(config)
batches = batcher.process(elements)
Stochastic sampler:
config = SamplerConfig(stochastic=True, stream_name="sampler", num_samples=100)
sampler = SamplerModule(config, rngs=nnx.Rngs(42))
indices = sampler.process(dataset_size=1000)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | StructuralConfig | Structural module configuration (already validated, frozen) | required |
| rngs | Rngs \| None | Random number generators (required if stochastic=True) | None |
| name | str \| None | Optional module name | None |
Raises:
| Type | Description |
|---|---|
| ValueError | If stochastic=True but rngs is None |
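The constructor contract above can be illustrated with a small check. This is a sketch under assumptions, not the library's code: `SketchConfig` and `SketchModule` are hypothetical stand-ins, and any non-`None` object would serve as `rngs` here.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class SketchConfig:
    """Hypothetical stand-in for StructuralConfig."""

    stochastic: bool = False
    stream_name: Optional[str] = None


class SketchModule:
    """Hypothetical stand-in for StructuralModule's constructor checks."""

    def __init__(
        self,
        config: SketchConfig,
        *,
        rngs: Optional[Any] = None,
        name: Optional[str] = None,
    ) -> None:
        # A stochastic module cannot operate without a source of randomness.
        if config.stochastic and rngs is None:
            raise ValueError("rngs is required when config.stochastic=True")
        self.config = config
        self.rngs = rngs
        self.name = name
        # Mirror the documented attributes, both read from the config.
        self.stochastic = config.stochastic
        self.stream_name = config.stream_name


deterministic = SketchModule(SketchConfig(stochastic=False))

try:
    SketchModule(SketchConfig(stochastic=True, stream_name="sampler"))
    error_message = None
except ValueError as exc:
    error_message = str(exc)
```

Failing at construction time, rather than on the first `process()` call, surfaces a missing RNG before any data has been touched.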
process ¤
Process input structure.
This method transforms the structure/organization of input data without modifying the data values themselves.
Subclasses MUST implement this method.
The input/output types depend on the specific structural processor:
- Batcher: list[Element] -> list[Batch]
- Sampler: int -> list[int]
- Sharder: Batch -> Sharded[Batch]
- Splitter: Dataset -> tuple[Dataset, Dataset]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input | Any | Input to process (type varies by processor) | required |
| `*args` | Any | Additional positional arguments (processor-specific) | () |
| `**kwargs` | Any | Additional keyword arguments (processor-specific) | {} |
Returns:
| Type | Description |
|---|---|
| Any | Processed output (type varies by processor) |
Examples:
Batcher implementation:
def process(self, elements: list[Element]) -> list[Batch]:
    batches = []
    for i in range(0, len(elements), self.config.batch_size):
        batch_elements = elements[i:i + self.config.batch_size]
        batches.append(Batch.from_elements(batch_elements))
    return batches
Sampler implementation (deterministic):
def process(self, dataset_size: int) -> list[int]:
    return list(range(min(self.config.num_samples, dataset_size)))
Sampler implementation (stochastic):
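One plausible shape for the stochastic case is to draw indices without replacement from a seeded generator. The sketch below is an assumption, not the library's implementation: it uses Python's `random.Random` in place of an NNX RNG stream, and `SketchSamplerConfig` and `SketchSampler` are hypothetical names.

```python
import random
from dataclasses import dataclass


@dataclass(frozen=True)
class SketchSamplerConfig:
    """Hypothetical stand-in for SamplerConfig."""

    num_samples: int
    stochastic: bool = True
    stream_name: str = "sampler"


class SketchSampler:
    """Hypothetical sampler; random.Random stands in for an RNG stream."""

    def __init__(self, config: SketchSamplerConfig, *, seed: int) -> None:
        self.config = config
        self._rng = random.Random(seed)

    def process(self, dataset_size: int) -> list:
        # Never request more indices than the dataset contains.
        k = min(self.config.num_samples, dataset_size)
        # sample() draws k distinct indices, i.e. without replacement.
        return self._rng.sample(range(dataset_size), k)


config = SketchSamplerConfig(num_samples=100)
indices = SketchSampler(config, seed=42).process(dataset_size=1000)
repeat = SketchSampler(config, seed=42).process(dataset_size=1000)
```

Two samplers built from the same seed yield the same indices, so the randomness stays reproducible.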
get_operation_stats ¤
reset_operation_stats ¤
Reset operation statistics to zero.
Note: Creates new JAX arrays to reset the counters.
compute_statistics ¤
Compute statistics from data using batch_stats_fn.
If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | Any | Input data to compute statistics from | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] \| None | Dictionary of statistics, or None if no batch_stats_fn configured |
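The behavior described for `compute_statistics` (return `None` without a `batch_stats_fn`, otherwise compute and cache) can be sketched as follows. `SketchStatsModule` and `mean_and_count` are hypothetical; how datarax actually wires `batch_stats_fn` into the config is not shown in this page and is assumed here to be a plain constructor argument.

```python
from typing import Any, Callable, Optional


class SketchStatsModule:
    """Hypothetical stand-in showing the compute-and-cache behavior."""

    def __init__(
        self, batch_stats_fn: Optional[Callable[[Any], dict]] = None
    ) -> None:
        self.batch_stats_fn = batch_stats_fn
        self._computed_stats: Optional[dict] = None

    def compute_statistics(self, data: Any) -> Optional[dict]:
        # Without a statistics function there is nothing to compute.
        if self.batch_stats_fn is None:
            return None
        # Cache the result so later lookups can reuse it.
        self._computed_stats = self.batch_stats_fn(data)
        return self._computed_stats


def mean_and_count(values: list) -> dict:
    """Illustrative statistics function (an assumption, not a datarax helper)."""
    return {"mean": sum(values) / len(values), "count": len(values)}


with_fn = SketchStatsModule(batch_stats_fn=mean_and_count)
stats = with_fn.compute_statistics([1.0, 2.0, 3.0, 6.0])

without_fn = SketchStatsModule()
no_stats = without_fn.compute_statistics([1.0, 2.0])
```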
get_statistics ¤
set_statistics ¤
reset_statistics ¤
Reset all statistics to None.
This clears the computed statistics and sets an internal flag so that precomputed_stats is ignored. After a reset, get_statistics() returns None until new statistics are set or computed.
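The reset semantics can be sketched with a small class. Everything here is an assumption for illustration: `SketchStatistics`, the `_ignore_precomputed` flag name, and the choice to prefer computed statistics over precomputed ones are not taken from the datarax source.

```python
from typing import Optional


class SketchStatistics:
    """Hypothetical stand-in for the statistics accessors."""

    def __init__(self, precomputed_stats: Optional[dict] = None) -> None:
        self._precomputed_stats = precomputed_stats
        self._computed_stats: Optional[dict] = None
        self._ignore_precomputed = False  # assumed name for the internal flag

    def get_statistics(self) -> Optional[dict]:
        # Assumed precedence: explicitly set or computed statistics win.
        if self._computed_stats is not None:
            return self._computed_stats
        if self._ignore_precomputed:
            return None
        return self._precomputed_stats

    def set_statistics(self, stats: dict) -> None:
        self._computed_stats = stats

    def reset_statistics(self) -> None:
        # Clear computed values and stop falling back to precomputed ones.
        self._computed_stats = None
        self._ignore_precomputed = True


stats = SketchStatistics(precomputed_stats={"mean": 0.5})
before = stats.get_statistics()

stats.reset_statistics()
after_reset = stats.get_statistics()

stats.set_statistics({"mean": 0.7})
after_set = stats.get_statistics()
```

The flag is what keeps the reset from being undone by the precomputed fallback; clearing `_computed_stats` alone would not be enough.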
copy ¤
copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule
Create a copy of this module with optional config/parameter changes.
This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | DataraxModuleConfig \| None | New config (if None, uses current config) | None |
| rngs | Rngs \| None | New RNG state (if None, uses current rngs) | None |
| name | str \| None | New name (if None, uses current name) | None |
Returns:
| Type | Description |
|---|---|
| DataraxModule | New module instance with updated parameters |
Examples:
# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)
# Change name only
renamed = module.copy(name="new_name")
Note
Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().
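The copy semantics, combined with `dataclasses.replace()` for changing a single field of a frozen config, might look like the following. This is a sketch under assumptions: `SketchModuleConfig` and `SketchModule` are hypothetical stand-ins, and only the `cacheable` field from the example above is modeled.

```python
from dataclasses import dataclass, replace
from typing import Any, Optional


@dataclass(frozen=True)
class SketchModuleConfig:
    """Hypothetical stand-in for DataraxModuleConfig."""

    cacheable: bool = False


class SketchModule:
    """Hypothetical stand-in for a module with copy()."""

    def __init__(
        self,
        config: SketchModuleConfig,
        *,
        rngs: Optional[Any] = None,
        name: Optional[str] = None,
    ) -> None:
        self.config = config
        self.rngs = rngs
        self.name = name

    def copy(self, *, config=None, rngs=None, name=None) -> "SketchModule":
        # Each argument left as None falls back to the current value.
        return type(self)(
            config if config is not None else self.config,
            rngs=rngs if rngs is not None else self.rngs,
            name=name if name is not None else self.name,
        )


original = SketchModule(SketchModuleConfig(), name="base")

# replace() builds a new frozen config with one field changed.
cached = original.copy(config=replace(original.config, cacheable=True))
renamed = original.copy(name="new_name")
```

Because the config is frozen, `replace()` never mutates the original, so `original` and `cached` can be used side by side during a hyperparameter sweep.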
get_state ¤
Get module state for checkpointing.
This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A dictionary containing the internal state of the component. |
set_state ¤
Restore module state from a checkpoint.
This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | dict[str, Any] | A dictionary containing the internal state to restore. | required |
Raises:
| Type | Description |
|---|---|
| TypeError | If state is not a dictionary. |
| ValueError | If checkpoint structure does not match module state. |
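Strict restoration can be sketched as a round trip with two failure modes. The class below is a hypothetical stand-in that keeps its state in a flat dictionary; the real module uses NNX state management, and comparing top-level keys is only an assumed, simplified notion of "structure".

```python
from typing import Any


class SketchCheckpointable:
    """Hypothetical stand-in for the get_state/set_state pair."""

    def __init__(self) -> None:
        self._state = {"count": 0, "offset": 0}

    def get_state(self) -> dict:
        # Return a copy so callers cannot mutate internal state by accident.
        return dict(self._state)

    def set_state(self, state: Any) -> None:
        if not isinstance(state, dict):
            raise TypeError("state must be a dictionary")
        # Strict: the checkpoint must have exactly the module's keys.
        if set(state) != set(self._state):
            raise ValueError("checkpoint structure does not match module state")
        self._state = dict(state)


source = SketchCheckpointable()
source.set_state({"count": 7, "offset": 3})

target = SketchCheckpointable()
target.set_state(source.get_state())  # round trip succeeds

errors = []
for bad_state in ([("count", 1)], {"count": 1}):
    try:
        target.set_state(bad_state)
    except (TypeError, ValueError) as exc:
        errors.append(type(exc).__name__)
```

A rejected checkpoint leaves `target` unchanged, since validation happens before any assignment.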
clone ¤
clone() -> DataraxModule
Create a new instance with the same state as this module.
Uses NNX's clone function for proper deep cloning of all state.
Returns:
| Type | Description |
|---|---|
| DataraxModule | A new module instance with the same state. |
requires_rng_streams ¤
ensure_rng_streams ¤
Ensure that the required RNG streams are available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stream_names | list[str] | A list of available RNG stream names. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If a required RNG stream is not available. |
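The check amounts to comparing the streams a module needs against the streams on offer. The sketch below is hypothetical: `SketchRngModule` is a stand-in, and the assumption that `requires_rng_streams()` returns a list of names is not confirmed by this page.

```python
class SketchRngModule:
    """Hypothetical stand-in for the RNG stream checks."""

    def __init__(self, required_streams: list) -> None:
        self._required = list(required_streams)

    def requires_rng_streams(self) -> list:
        # Assumed return shape: the names of the streams this module needs.
        return list(self._required)

    def ensure_rng_streams(self, stream_names: list) -> None:
        missing = [n for n in self._required if n not in stream_names]
        if missing:
            raise ValueError(f"missing required RNG streams: {missing}")


module = SketchRngModule(["sampler"])

# Extra available streams are fine; only missing ones are an error.
module.ensure_rng_streams(["sampler", "augment"])

try:
    module.ensure_rng_streams(["augment"])
    stream_error = None
except ValueError as exc:
    stream_error = str(exc)
```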