Map Operator
Apply functions element-wise across data.
See Also
- Operators Overview - All operator types
- Element Operator - Element-level transforms
- Field Operators - Field-specific transforms
- Operators Tutorial
datarax.operators.map_operator
MapOperator - operator for applying functions to array leaves.
This module provides MapOperator, which applies a user-provided array transformation function to the leaves of the element data PyTree.
Key Features:
- Unified function signature: fn(x: Array, key: Array) -> Array
- Deterministic mode: key parameter ignored
- Stochastic mode: key parameter provides per-leaf randomness
- Full-tree mode: Apply fn to all array leaves
- Subtree mode: Apply fn only to specified subtree leaves
- Uses jax.tree.map_with_path for unified implementation
BREAKING CHANGE: User functions MUST accept key parameter even in deterministic mode.
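As a minimal sketch of what the unified signature implies for user code: a deterministic function accepts the key and discards it, while a stochastic one consumes it. The helper `with_key` is hypothetical (it is not part of datarax) and only illustrates one way to adapt an existing one-argument function.

```python
import jax
import jax.numpy as jnp


def normalize(x, key):
    # Deterministic: key is accepted to satisfy the signature, then ignored.
    del key
    return (x - 0.5) / 0.5


def add_noise(x, key):
    # Stochastic: key drives the random draw.
    return x + 0.1 * jax.random.normal(key, x.shape)


def with_key(fn):
    # Hypothetical adapter: wrap a one-argument function so it accepts a key.
    def wrapped(x, key):
        del key
        return fn(x)

    return wrapped


x = jnp.full((2, 3), 0.75)
key = jax.random.PRNGKey(0)

normalized = normalize(x, key)
noisy = add_noise(x, key)
absolute = with_key(jnp.abs)(-x, key)
```

Any of these three callables has the shape `fn(x, key) -> Array` that the operator expects.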
MapOperator
MapOperator(config: MapOperatorConfig, fn: Callable[[Array, Array], Array], *, rngs: Rngs | None = None, name: str | None = None)
Bases: OperatorModule
Unified operator for mapping functions over array leaves in data.
Applies a user-provided array transformation function to the leaves of the element.data PyTree. Supports full-tree and subtree transformations, each in deterministic or stochastic mode.
User Function Signature (ALWAYS required):
fn(x: jax.Array, key: jax.Array) -> jax.Array
- Deterministic mode (stochastic=False): Ignore key parameter
- Stochastic mode (stochastic=True): Use key for randomness
Two operational modes:
1. Full-tree mode (subtree=None): Apply fn to all array leaves
   - Unified implementation with jax.tree.map_with_path
2. Subtree mode (subtree specified): Apply fn only to subtree leaves
   - Path-based filtering via keypath matching
   - Other leaves pass through unchanged
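The path-based filtering described above can be sketched roughly as follows. This is not the datarax implementation: `map_leaves` and its `selected` argument (a set of top-level dict keys) are assumptions made for illustration, and the sketch uses `jax.tree_util.tree_map_with_path`, the older spelling of `jax.tree.map_with_path`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import DictKey, tree_flatten, tree_map_with_path, tree_unflatten


def map_leaves(fn, data, key, selected=None):
    """Apply fn(x, key) to leaves under the selected top-level dict keys.

    selected=None stands in for full-tree mode. The set of key names is an
    assumption of this sketch, not the MapOperatorConfig.subtree format.
    """
    leaves, treedef = tree_flatten(data)
    # One independent key per leaf, arranged in the same structure as data.
    leaf_keys = tree_unflatten(treedef, list(jax.random.split(key, len(leaves))))

    def visit(path, leaf, leaf_key):
        top = path[0].key if path and isinstance(path[0], DictKey) else None
        if selected is None or top in selected:
            return fn(leaf, leaf_key)
        return leaf  # unselected leaves pass through unchanged

    return tree_map_with_path(visit, data, leaf_keys)


data = {"image": jnp.ones((2, 2)), "label": jnp.array([1.0, 2.0])}
double = lambda x, key: x * 2.0  # deterministic: key ignored

subtree_out = map_leaves(double, data, jax.random.PRNGKey(0), selected={"image"})
full_out = map_leaves(double, data, jax.random.PRNGKey(0))
```

In the subtree call only `image` is doubled and `label` is returned untouched; in the full-tree call both leaves are transformed.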
Examples:
    # Deterministic full-tree (ignore key)
    config = MapOperatorConfig(subtree=None, stochastic=False)
    op = MapOperator(config, fn=lambda x, key: (x - 0.5) / 0.5, rngs=rngs)

    # Stochastic full-tree (use key for noise)
    config = MapOperatorConfig(subtree=None, stochastic=True, stream_name="augment")
    op = MapOperator(
        config,
        fn=lambda x, key: x + jax.random.normal(key, x.shape) * 0.1,
        rngs=rngs,
    )

    # Stochastic subtree (only augment image)
    config = MapOperatorConfig(
        subtree={"image": None}, stochastic=True, stream_name="augment"
    )
    op = MapOperator(
        config,
        fn=lambda x, key: x + jax.random.normal(key, x.shape) * 0.1,
        rngs=rngs,
    )
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | MapOperatorConfig | Operator configuration | required |
| fn | Callable[[Array, Array], Array] | User function with signature fn(x: Array, key: Array) -> Array. BREAKING CHANGE: must accept the key parameter even in deterministic mode. Deterministic: ignore the key parameter. Stochastic: use key for randomness. | required |
| rngs | Rngs \| None | Random number generators (required if stochastic=True, optional otherwise) | None |
| name | str \| None | Optional name for the operator | None |
generate_random_params
generate_random_params(rng: Array, data_shapes: PyTree) -> PyTree | None
Generate random parameters for batch transformation.
Generates a PyTree of RNG keys matching the data structure, with one key per batch element for each leaf. This enables per-leaf, per-element randomness.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rng | Array | JAX random key (single key for entire batch) | required |
| data_shapes | PyTree | PyTree with the same structure as batch.data, containing shapes. Example: {"image": (batch_size, H, W, C)} | required |
Returns:
| Type | Description |
|---|---|
| PyTree \| None | PyTree of keys matching the data structure, where each leaf is Array[batch_size, 2], or None for deterministic operators. Example: {"image": Array[batch_size, 2]}, where 2 is the PRNGKey shape. |
Implementation
- Flatten data_shapes to get list of shapes
- Extract batch_size from first shape
- Split rng into n_leaves keys (one per leaf type)
- For each leaf key, split into batch_size keys
- Unflatten into PyTree matching original structure
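The steps above can be sketched as follows. This is an illustrative reconstruction rather than the library's code: `make_leaf_keys` is a hypothetical name, and the sketch assumes legacy `jax.random.PRNGKey` keys (shape `(2,)`) and shape tuples as the leaves of `data_shapes`.

```python
import jax
import jax.numpy as jnp


def make_leaf_keys(rng, data_shapes):
    # Shape tuples are containers to JAX by default, so mark them as leaves.
    is_shape = lambda s: isinstance(s, tuple)
    shapes, treedef = jax.tree_util.tree_flatten(data_shapes, is_leaf=is_shape)

    # Batch size is the leading dimension of the first shape.
    batch_size = shapes[0][0]

    # One key per leaf, then one key per batch element for each leaf.
    leaf_keys = jax.random.split(rng, len(shapes))
    per_element = [jax.random.split(k, batch_size) for k in leaf_keys]

    # Restore the original structure with key arrays in place of shapes.
    return jax.tree_util.tree_unflatten(treedef, per_element)


keys = make_leaf_keys(
    jax.random.PRNGKey(0),
    {"image": (4, 8, 8, 3), "label": (4,)},
)
```

Each leaf of `keys` has shape `(batch_size, 2)`, matching the return description, and the two leaves receive different keys because they descend from different splits of `rng`.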
apply
apply(data: PyTree, state: PyTree, metadata: dict[str, Any] | None, random_params: Any = None, stats: dict[str, Any] | None = None) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply array transformation to element (unified implementation).
Single method handles all four modes:
- Full-tree × deterministic
- Full-tree × stochastic
- Subtree × deterministic
- Subtree × stochastic
Uses jax.tree.map_with_path for unified traversal with keypath filtering.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | Element data PyTree | required |
| state | PyTree | Element state PyTree (unchanged) | required |
| metadata | dict[str, Any] \| None | Element metadata dict (unchanged) | required |
| random_params | Any | PyTree of keys (stochastic) or dummy keys (deterministic) | None |
| stats | dict[str, Any] \| None | Optional batch statistics (unused) | None |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree, dict[str, Any] \| None] | Tuple of (transformed_data, state, metadata), where state and metadata are unchanged |
get_operation_stats
reset_operation_stats
Reset operation statistics to zero.
Note: Creates new JAX arrays to reset the counters.
compute_statistics
Compute statistics from data using batch_stats_fn.
If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | Any | Input data to compute statistics from | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] \| None | Dictionary of statistics, or None if no batch_stats_fn is configured |
get_statistics
set_statistics
reset_statistics
Reset all statistics to None.
This clears both computed statistics and marks that precomputed_stats should be ignored (via internal flag). After reset, get_statistics() will return None until new statistics are set or computed.
copy
copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule
Create a copy of this module with optional config/parameter changes.
This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | DataraxModuleConfig \| None | New config (if None, uses current config) | None |
| rngs | Rngs \| None | New RNG state (if None, uses current rngs) | None |
| name | str \| None | New name (if None, uses current name) | None |
Returns:
| Type | Description |
|---|---|
| DataraxModule | New module instance with updated parameters |
Examples:
    # Change configuration
    new_config = DataraxModuleConfig(cacheable=True)
    new_module = module.copy(config=new_config)

    # Change name only
    renamed = module.copy(name="new_name")
Note
Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().
get_state
Get module state for checkpointing.
This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A dictionary containing the internal state of the component. |
set_state
Restore module state from a checkpoint.
This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | dict[str, Any] | A dictionary containing the internal state to restore. | required |
Raises:
| Type | Description |
|---|---|
| TypeError | If state is not a dictionary. |
| ValueError | If checkpoint structure does not match module state. |
clone
clone() -> DataraxModule
Create a new instance with the same state as this module.
Uses NNX's clone function for proper deep cloning of all state.
Returns:
| Type | Description |
|---|---|
| DataraxModule | A new module instance with the same state. |
requires_rng_streams
ensure_rng_streams
Ensure that the required RNG streams are available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stream_names | list[str] | A list of available RNG stream names. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If a required RNG stream is not available. |
get_output_structure
get_output_structure(sample_data: PyTree, sample_state: PyTree) -> tuple[PyTree, PyTree]
Declare output PyTree structure for vmap axis specification.
Default uses jax.eval_shape to discover structure automatically. Override for efficiency or when eval_shape doesn't work (e.g., data-dependent shapes).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sample_data | PyTree | Single element data (not batched) | required |
| sample_state | PyTree | Single element state (not batched) | required |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree] | Tuple of (output_data_structure, output_state_structure) with None leaves. The structure (keys/nesting) matters; leaf values are ignored. |
Example override for an operator that adds keys:

    def get_output_structure(self, sample_data, sample_state):
        out_data = {
            **jax.tree.map(lambda _: None, sample_data),
            "score": None,
            "alignment": None,
        }
        return out_data, sample_state
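For comparison, the default discovery via `jax.eval_shape` might look roughly like the sketch below. `jax.eval_shape` is a real JAX function that traces shapes without running the computation; `apply_fn` and `discover_structure` are hypothetical stand-ins, not datarax names.

```python
import jax
import jax.numpy as jnp


def apply_fn(data, state):
    # Hypothetical per-element transform that adds a "score" entry.
    out = {**data, "score": jnp.sum(data["image"])}
    return out, state


def discover_structure(fn, sample_data, sample_state):
    # Trace fn abstractly: no real computation, only shapes and dtypes.
    out_data, out_state = jax.eval_shape(fn, sample_data, sample_state)
    # Keep only the keys and nesting; leaf values are replaced by None.
    to_none = lambda tree: jax.tree_util.tree_map(lambda _: None, tree)
    return to_none(out_data), to_none(out_state)


data_struct, state_struct = discover_structure(
    apply_fn, {"image": jnp.ones((8, 8, 3))}, {}
)
```

A manual override avoids this tracing step, which is the efficiency argument made above, and remains available when tracing fails on data-dependent shapes.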
apply_batch
Process entire batch with vmap and optional RNG generation.
This method implements the batch processing logic for both stochastic and deterministic modes. It uses static branching on self.stochastic for JIT compilation efficiency.
The implementation delegates to _vmap_apply() for the shared computational core, then wraps the result in a Batch object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Batch | Input batch (Batch[Element] structure) | required |
| stats | dict[str, Any] \| None | Optional statistics (if None, uses get_statistics()) | None |
Returns:
| Type | Description |
|---|---|
| Batch | Transformed batch with same structure |
Note
This method is concrete (not abstract). Subclasses typically don't override it, but can if they need custom batch processing logic.
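The batch path can be pictured with the sketch below, which vmaps a per-element map over the batch axis and chooses between real and dummy keys with a plain Python branch. It is a simplified illustration under stated assumptions: `apply_batch_sketch` is a hypothetical function operating on a plain dict rather than a Batch object, and the real method delegates to `_vmap_apply()`, whose internals are not shown in this reference.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_map, tree_unflatten


def apply_batch_sketch(batch_data, fn, *, stochastic, rng=None):
    leaves, treedef = tree_flatten(batch_data)
    batch_size = leaves[0].shape[0]

    if stochastic:  # Python-level branch, resolved before tracing
        leaf_keys = jax.random.split(rng, len(leaves))
        per_leaf = [jax.random.split(k, batch_size) for k in leaf_keys]
    else:
        # Dummy keys keep the fn(x, key) signature uniform.
        per_leaf = [jnp.zeros((batch_size, 2), dtype=jnp.uint32)] * len(leaves)
    keys = tree_unflatten(treedef, per_leaf)

    # Per-element function: pair each leaf with its own key.
    per_element = lambda data, element_keys: tree_map(fn, data, element_keys)
    return jax.vmap(per_element)(batch_data, keys)


batch = {"image": jnp.ones((4, 2, 2))}
noisy = apply_batch_sketch(
    batch,
    lambda x, key: x + 0.1 * jax.random.normal(key, x.shape),
    stochastic=True,
    rng=jax.random.PRNGKey(0),
)
scaled = apply_batch_sketch(batch, lambda x, key: x * 2.0, stochastic=False)
```

Because the branch on `stochastic` is ordinary Python, only one of the two paths is ever traced, which is the property the description above attributes to static branching.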
output_spec
Return the operator's output spec given an input spec.
Most operators (normalization, additive noise, simple element-wise
transforms) do not change shape; the default returns input_spec
unchanged. Shape-changing operators (Resize, Crop, Reshape) MUST
override this method.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_spec | PyTree | PyTree of | required |
Returns:
| Type | Description |
|---|---|
PyTree
|
PyTree of |
PyTree
|
By default, equal to |