Selector Operator
Randomly select one of several operators to apply to each batch element.
See Also
- Operators Overview - All operator types
- Field Operators - Field transforms
- Map Operator - Mapping functions
- Operators Tutorial
datarax.operators.selector_operator
SelectorOperator - Random selection from multiple operators.
This operator wraps multiple OperatorModules and randomly selects ONE to apply per batch element.
Key Features:
- Wraps multiple OperatorModules with random selection
- Configurable weights for weighted random selection (defaults to uniform)
- Uses jax.lax.switch for JIT-compatible dynamic selection
- Always stochastic (always makes a random choice)
- Full JAX compatibility (JIT, vmap)
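The selection mechanism listed above can be sketched in plain JAX. This is a minimal illustration, not the datarax implementation: the three branch functions and `select_and_apply` are hypothetical stand-ins for child operators and for the selector's batch logic. It draws one weighted index per batch element and dispatches with `jax.lax.switch` under `jax.vmap` and `jax.jit`.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for three child operators. Each maps an array to an
# array of the same shape and dtype, which jax.lax.switch requires.
def brighten(x):
    return x + 0.5

def negate(x):
    return -x

def identity(x):
    return x

BRANCHES = (brighten, negate, identity)

@jax.jit
def select_and_apply(key, batch, weights):
    # Draw one branch index per batch element from the weighted distribution.
    indices = jax.random.choice(
        key, len(BRANCHES), shape=(batch.shape[0],), p=weights
    )
    # vmap over (index, element) pairs; switch dispatches per element.
    out = jax.vmap(lambda i, x: jax.lax.switch(i, BRANCHES, x))(indices, batch)
    return out, indices

batch = jnp.ones((8, 4))
weights = jnp.array([0.5, 0.3, 0.2])
out, indices = select_and_apply(jax.random.PRNGKey(0), batch, weights)
```

Because the index is batched under `vmap`, JAX generally lowers the switch to a select, so every branch is evaluated for every element and only the chosen result is kept.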
Examples:
Basic usage:
    config = SelectorOperatorConfig(
        operators=[op1, op2, op3],
        weights=[0.5, 0.3, 0.2],
    )
    op = SelectorOperator(config, rngs=rngs)
SelectorOperatorConfig dataclass
SelectorOperatorConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, stochastic: bool = False, stream_name: str | None = None, batch_strategy: str = 'vmap', *, operators: list[OperatorModule], weights: list[float] | None = None)
Bases: OperatorConfig
Configuration for SelectorOperator.
Extends OperatorConfig with operators list and optional weights.
Attributes:
| Name | Type | Description |
|---|---|---|
| operators | list[OperatorModule] | List of operators to select from (minimum 1) |
| weights | list[float] \| None | Optional weights for random selection (defaults to uniform). Will be normalized to sum to 1.0 |
Note:
- stochastic is always True (always makes random choice)
- stream_name defaults to "augment" for random selection
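The weight handling described above (uniform by default, otherwise normalized to sum to 1.0) might look roughly like the following. `normalize_weights` is a hypothetical helper written for illustration, and its validation rules are assumptions rather than documented datarax behavior.

```python
def normalize_weights(num_operators, weights=None):
    """Return selection probabilities that sum to 1.0 (hypothetical helper)."""
    if num_operators < 1:
        raise ValueError("at least one operator is required")
    if weights is None:
        # No weights given: fall back to a uniform distribution.
        return [1.0 / num_operators] * num_operators
    if len(weights) != num_operators:
        raise ValueError("need exactly one weight per operator")
    total = float(sum(weights))
    if total <= 0.0:
        raise ValueError("weights must have a positive sum")
    # Rescale so the probabilities sum to 1.0.
    return [w / total for w in weights]

uniform = normalize_weights(3)
scaled = normalize_weights(3, [5, 3, 2])
```

Under this sketch, weights such as `[5, 3, 2]` and `[0.5, 0.3, 0.2]` describe the same distribution.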
SelectorOperator
SelectorOperator(config: SelectorOperatorConfig, *, rngs: Rngs | None = None)
Bases: OperatorModule
Wrapper operator that randomly selects ONE operator to apply.
Wraps multiple OperatorModules and uses weighted random selection to choose which one to apply per batch element.
Uses jax.lax.switch for JIT-compatible operator selection with the unified operator interface.
Examples:
    # Different transforms
    op1 = BrightnessOperator(brightness_config, rngs=nnx.Rngs(0))
    op2 = NoiseOperator(noise_config, rngs=nnx.Rngs(0))
    op3 = RotationOperator(rotation_config, rngs=nnx.Rngs(0))

    # 50% brightness, 30% noise, 20% rotation
    selector_config = SelectorOperatorConfig(
        operators=[op1, op2, op3],
        weights=[0.5, 0.3, 0.2],
    )
    selector = SelectorOperator(selector_config, rngs=nnx.Rngs(0))

    # Each element gets one randomly selected operator
    result_batch = selector(batch)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | SelectorOperatorConfig | SelectorOperatorConfig with operators list and optional weights | required |
| rngs | Rngs \| None | Random number generators (required for random selection) | None |
get_output_structure
get_output_structure(sample_data: PyTree, sample_state: PyTree) -> tuple[PyTree, PyTree]
Declare output structure using first operator.
SelectorOperator's apply() requires random_params, which are not available during jax.eval_shape tracing. Because all child operators should produce compatible output structures, the first operator's structure is used.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sample_data | PyTree | Single element data (not batched) | required |
| sample_state | PyTree | Single element state (not batched) | required |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree] | Tuple of (output_data_structure, output_state_structure) with 0 leaves. |
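For readers unfamiliar with `jax.eval_shape`, the sketch below shows how an output structure can be derived from one operator without running any computation. `first_operator` is a hypothetical stand-in for the first child operator; how datarax calls into it internally is not shown in this reference.

```python
import jax
import jax.numpy as jnp

def first_operator(data, state):
    # Hypothetical child operator: converts the image to float32 in [0, 1]
    # and passes the state through unchanged.
    return {"image": data["image"].astype(jnp.float32) / 255.0}, state

# Single-element samples (no batch dimension).
sample_data = {"image": jnp.zeros((32, 32, 3), dtype=jnp.uint8)}
sample_state = {"count": jnp.zeros((), dtype=jnp.int32)}

# eval_shape traces the function abstractly, so no random params and no
# real arithmetic are needed to learn the output shapes and dtypes.
out_data, out_state = jax.eval_shape(first_operator, sample_data, sample_state)
```

The results are PyTrees of `jax.ShapeDtypeStruct` objects that mirror the structure of the operator's real outputs.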
generate_random_params
Generate random operator selection indices for each batch element.
Creates integer indices that determine which operator to apply to each element, and delegates to every child operator for its own random params.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rng | Array | JAX random key | required |
| data_shapes | PyTree | PyTree with same structure as batch.data, containing shapes. Example: {"image": (batch_size, H, W, C)} | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dict with: |
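A rough sketch of this kind of parameter generation follows. The dictionary keys `"selected_indices"` and `"child_params"` come from the `apply` documentation on this page; everything else is assumed for illustration: the helper name, reading the batch size from the first entry of a flat `data_shapes` dict, and representing each child's params as per-element PRNG keys.

```python
import jax
import jax.numpy as jnp

def generate_selection_params(rng, data_shapes, weights, num_children):
    # Assumption: data_shapes is a flat dict and the leading dim is the batch size.
    batch_size = next(iter(data_shapes.values()))[0]
    # One key for the selection itself, one per child operator.
    select_key, *child_keys = jax.random.split(rng, num_children + 1)
    selected = jax.random.choice(
        select_key, num_children, shape=(batch_size,), p=jnp.asarray(weights)
    )
    # Assumption: each child's random params are simply per-element keys here.
    child_params = [jax.random.split(k, batch_size) for k in child_keys]
    return {"selected_indices": selected, "child_params": child_params}

params = generate_selection_params(
    jax.random.PRNGKey(0), {"image": (16, 8, 8, 3)}, [0.5, 0.3, 0.2], 3
)
```

Generating params for every child, not only the selected one, keeps shapes static, which is what allows the later dispatch to be traced by JIT.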
apply
apply(data: PyTree, state: PyTree, metadata: dict[str, Any] | None, random_params: Any = None, stats: dict[str, Any] | None = None) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply the randomly selected operator to the data.
Uses jax.lax.switch for JIT-compatible operator selection based on the pre-generated random index.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | Element data PyTree (no batch dimension) | required |
| state | PyTree | Element state PyTree | required |
| metadata | dict[str, Any] \| None | Element metadata | required |
| random_params | Any | Dict with "selected_indices" (int) and "child_params" | None |
| stats | dict[str, Any] \| None | Optional statistics | None |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree, dict[str, Any] \| None] | Tuple of (transformed_data, state, metadata) from selected operator |
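At the single-element level, the dispatch can be illustrated as below. This is a sketch under assumptions: `scale` and `shift` are hypothetical child operators, metadata is omitted for brevity, and only the `"selected_indices"` key is taken from the table above. The point it shows is that every branch passed to `jax.lax.switch` must accept the same operands and return PyTrees with identical structure, shapes, and dtypes.

```python
import jax
import jax.numpy as jnp

def scale(data, state):
    # Hypothetical child operator: doubles the image.
    return {"image": data["image"] * 2.0}, state

def shift(data, state):
    # Hypothetical child operator: adds one to the image.
    return {"image": data["image"] + 1.0}, state

def apply_selected(data, state, random_params):
    index = random_params["selected_indices"]
    # The index may be a traced value, so the choice stays JIT-compatible.
    return jax.lax.switch(index, (scale, shift), data, state)

data = {"image": jnp.full((2, 2), 3.0)}
state = {"count": jnp.zeros((), jnp.int32)}
jitted = jax.jit(apply_selected)
scaled, _ = jitted(data, state, {"selected_indices": jnp.asarray(0)})
shifted, _ = jitted(data, state, {"selected_indices": jnp.asarray(1)})
```

Both calls reuse the same compiled function, since only the value of the index changes.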
get_operation_stats
reset_operation_stats
Reset operation statistics to zero.
Note: Creates new JAX arrays to reset the counters.
compute_statistics
Compute statistics from data using batch_stats_fn.
If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | Any | Input data to compute statistics from | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] \| None | Dictionary of statistics, or None if no batch_stats_fn configured |
get_statistics
set_statistics
reset_statistics
Reset all statistics to None.
This clears both computed statistics and marks that precomputed_stats should be ignored (via internal flag). After reset, get_statistics() will return None until new statistics are set or computed.
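The interplay between computed statistics, precomputed_stats, and the reset flag can be pictured with a small stand-in class. `StatsHolder` is hypothetical and is not the datarax implementation; in particular, the precedence of computed over precomputed statistics is an assumption made for this sketch.

```python
class StatsHolder:
    """Hypothetical stand-in illustrating the documented reset semantics."""

    def __init__(self, batch_stats_fn=None, precomputed_stats=None):
        self._batch_stats_fn = batch_stats_fn
        self._precomputed_stats = precomputed_stats
        self._computed_stats = None
        self._ignore_precomputed = False  # set by reset_statistics()

    def compute_statistics(self, data):
        if self._batch_stats_fn is None:
            return None
        # Cache the result so later get_statistics() calls can return it.
        self._computed_stats = self._batch_stats_fn(data)
        return self._computed_stats

    def get_statistics(self):
        if self._computed_stats is not None:
            return self._computed_stats
        if self._ignore_precomputed:
            return None
        return self._precomputed_stats

    def reset_statistics(self):
        self._computed_stats = None
        self._ignore_precomputed = True

holder = StatsHolder(
    batch_stats_fn=lambda xs: {"mean": sum(xs) / len(xs)},
    precomputed_stats={"mean": 0.0},
)
before_reset = holder.get_statistics()
holder.reset_statistics()
after_reset = holder.get_statistics()
recomputed = holder.compute_statistics([1.0, 2.0, 3.0])
```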
copy
copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule
Create a copy of this module with optional config/parameter changes.
This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | DataraxModuleConfig \| None | New config (if None, uses current config) | None |
| rngs | Rngs \| None | New RNG state (if None, uses current rngs) | None |
| name | str \| None | New name (if None, uses current name) | None |
Returns:
| Type | Description |
|---|---|
| DataraxModule | New module instance with updated parameters |
Examples:
    # Change configuration
    new_config = DataraxModuleConfig(cacheable=True)
    new_module = module.copy(config=new_config)

    # Change name only
    renamed = module.copy(name="new_name")
Note
Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().
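The dataclass `replace()` mentioned in the note refers to the standard-library `dataclasses.replace`, which builds a modified copy of a dataclass instance. A minimal sketch, using a hypothetical `ExampleConfig` in place of a real datarax config class:

```python
from dataclasses import dataclass, replace
from typing import Optional

@dataclass(frozen=True)
class ExampleConfig:
    """Hypothetical stand-in for a config dataclass."""
    cacheable: bool = False
    stream_name: Optional[str] = None

base = ExampleConfig(stream_name="augment")
# replace() copies every field and overrides only the ones named here.
updated = replace(base, cacheable=True)
```

A config derived this way could then be passed as `module.copy(config=updated)`, following the documented signature above.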
get_state
Get module state for checkpointing.
This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A dictionary containing the internal state of the component. |
set_state
Restore module state from a checkpoint.
This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | dict[str, Any] | A dictionary containing the internal state to restore. | required |
Raises:
| Type | Description |
|---|---|
| TypeError | If state is not a dictionary. |
| ValueError | If checkpoint structure does not match module state. |
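One way to picture the strict restoration check is a PyTree structure comparison. The helper below is a hypothetical sketch, not the datarax or NNX code path; it only mirrors the two documented error conditions using `jax.tree_util.tree_structure`.

```python
import jax

def check_restorable(current_state, checkpoint_state):
    """Hypothetical validation mirroring the documented TypeError/ValueError."""
    if not isinstance(checkpoint_state, dict):
        raise TypeError("state must be a dictionary")
    current_def = jax.tree_util.tree_structure(current_state)
    checkpoint_def = jax.tree_util.tree_structure(checkpoint_state)
    if current_def != checkpoint_def:
        raise ValueError("checkpoint structure does not match module state")

current = {"count": 0, "rng": {"key": 1}}
check_restorable(current, {"count": 5, "rng": {"key": 7}})  # same structure, passes
```

A checkpoint with a missing or extra entry, such as `{"count": 5}`, would fail this comparison because its tree structure differs.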
clone
clone() -> DataraxModule
Create a new instance with the same state as this module.
Uses NNX's clone function for proper deep cloning of all state.
Returns:
| Type | Description |
|---|---|
| DataraxModule | A new module instance with the same state. |
requires_rng_streams
ensure_rng_streams
Ensure that the required RNG streams are available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stream_names | list[str] | A list of available RNG stream names. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If a required RNG stream is not available. |
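The check described here amounts to comparing required stream names against the available ones. A small sketch with a hypothetical free function (the real method lives on the module and knows its own required streams):

```python
def ensure_streams(required, available):
    """Raise ValueError if any required RNG stream is missing (hypothetical helper)."""
    missing = [name for name in required if name not in available]
    if missing:
        raise ValueError(f"Missing required RNG streams: {missing}")

# "augment" is the default stream name mentioned for SelectorOperatorConfig.
ensure_streams(["augment"], ["augment", "shuffle"])  # passes silently
```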
apply_batch
Process entire batch with vmap and optional RNG generation.
This method implements the batch processing logic for both stochastic and deterministic modes. It uses static branching on self.stochastic for JIT compilation efficiency.
The implementation delegates to _vmap_apply() for the shared computational core, then wraps the result in a Batch object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Batch | Input batch (Batch[Element] structure) | required |
| stats | dict[str, Any] \| None | Optional statistics (if None, uses get_statistics()) | None |
Returns:
| Type | Description |
|---|---|
| Batch | Transformed batch with same structure |
Note
This method is concrete (not abstract). Subclasses typically don't override it, but can if they need custom batch processing logic.
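"Static branching" here means an ordinary Python `if` on a value that is fixed at trace time, so the compiled function contains only one of the two paths. The sketch below illustrates the idea with plain arrays. `make_apply_batch`, `add_noise`, and `double` are hypothetical names, and the real method works on Batch objects and delegates to `_vmap_apply()`.

```python
import jax
import jax.numpy as jnp

def make_apply_batch(apply_fn, stochastic):
    def apply_batch(batch, key=None):
        if stochastic:  # plain Python branch, resolved when the function is traced
            keys = jax.random.split(key, batch.shape[0])  # one key per element
            return jax.vmap(apply_fn)(batch, keys)
        return jax.vmap(lambda x: apply_fn(x, None))(batch)
    return jax.jit(apply_batch)

def add_noise(x, key):
    # Hypothetical stochastic element-level transform.
    return x + jax.random.normal(key, x.shape)

def double(x, key):
    # Hypothetical deterministic element-level transform; ignores the key.
    return x * 2.0

deterministic = make_apply_batch(double, stochastic=False)
stochastic = make_apply_batch(add_noise, stochastic=True)
doubled = deterministic(jnp.ones((4, 3)))
noisy = stochastic(jnp.zeros((4, 3)), jax.random.PRNGKey(0))
```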
output_spec
Return the operator's output spec given an input spec.
Most operators (normalization, additive noise, simple element-wise
transforms) do not change shape; the default returns input_spec
unchanged. Shape-changing operators (Resize, Crop, Reshape) MUST
override this method.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_spec | PyTree | PyTree of | required |
Returns:
| Type | Description |
|---|---|
| PyTree | PyTree of |
| PyTree | By default, equal to |
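The difference between the default behavior and a shape-changing override can be sketched as follows. This example assumes the spec is a PyTree whose leaves are `jax.ShapeDtypeStruct` objects, which this page does not confirm, and both functions are hypothetical free functions written for illustration.

```python
import jax
import jax.numpy as jnp

def default_output_spec(input_spec):
    # Shape-preserving operators: the spec passes through unchanged.
    return input_spec

def resize_output_spec(input_spec, size=(64, 64)):
    # Hypothetical Resize-style override: new height and width,
    # same channel count and dtype. Assumes (H, W, C) leaves.
    def resize_leaf(leaf):
        return jax.ShapeDtypeStruct((*size, leaf.shape[-1]), leaf.dtype)
    return jax.tree_util.tree_map(resize_leaf, input_spec)

spec = {"image": jax.ShapeDtypeStruct((32, 32, 3), jnp.float32)}
same = default_output_spec(spec)
resized = resize_output_spec(spec)
```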