Noise Operator
Add random noise for data augmentation and regularization.
See Also
- Operators Overview - All operator types
- Dropout Operator - Random dropout
- Probabilistic Operator - Random augmentation
- Operators Tutorial
datarax.operators.modality.image.noise_operator
NoiseOperator - Operator for image noise augmentation.
This operator extends ModalityOperator to provide three types of noise:
- Gaussian: Additive Gaussian noise
- Salt & Pepper: Impulse noise (random pixels to min/max)
- Poisson: Shot noise (photon noise simulation)
Key Features:
- Three noise types via 'mode' parameter
- Stochastic mode with pre-generated noise
- Deterministic mode for reproducible noise patterns
- Full JAX compatibility with JIT compilation
Examples:
Basic usage:
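from flax import nnx
from datarax.operators.modality.image.noise_operator import NoiseOperator, NoiseOperatorConfig
config = NoiseOperatorConfig(
    field_key="image",
    mode="gaussian",
    noise_std=0.05,
    noise_mean=0.0,
)
op = NoiseOperator(config, rngs=nnx.Rngs(0))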
NoiseOperatorConfig (dataclass)
NoiseOperatorConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, stochastic: bool = False, stream_name: str | None = None, batch_strategy: str = 'vmap', *, field_key: str, target_key: str | None = None, auxiliary_fields: list[str] | None = None, clip_range: tuple[float, float] | None = (0.0, 1.0), preserve_auxiliary: bool = True, validate_domain_constraints: bool = True, mode: Literal['gaussian', 'salt_pepper', 'poisson'] = 'gaussian', noise_std: float = 0.05, noise_mean: float = 0.0, salt_prob: float = 0.01, pepper_prob: float = 0.01, salt_value: float | None = None, pepper_value: float | None = None, lam_scale: float = 1.0)
Bases: ModalityOperatorConfig
Configuration for NoiseOperator.
Extends ModalityOperatorConfig with noise-specific parameters.
Attributes:
| Name | Type | Description |
|---|---|---|
| mode | Literal['gaussian', 'salt_pepper', 'poisson'] | Type of noise to apply: "gaussian", "salt_pepper", or "poisson". Default: "gaussian" |
| **Gaussian mode parameters** | | |
| noise_std | float | Standard deviation of the Gaussian noise. Default: 0.05 |
| noise_mean | float | Mean of the Gaussian noise. Default: 0.0 |
| **Salt & Pepper mode parameters** | | |
| salt_prob | float | Probability that a pixel is set to the salt (maximum) value. Default: 0.01 |
| pepper_prob | float | Probability that a pixel is set to the pepper (minimum) value. Default: 0.01 |
| salt_value | float \| None | Value for salt pixels (None = auto-detect). Default: None |
| pepper_value | float \| None | Value for pepper pixels (None = auto-detect). Default: None |
| **Poisson mode parameters** | | |
| lam_scale | float | Scale factor for the Poisson rate: output = Poisson(input * lam_scale) / lam_scale, so the output variance is input / lam_scale and larger values give less relative noise. Default: 1.0 |
| **Common parameters** | | |
| clip_range | tuple[float, float] \| None | Range for clipping output values; set to None to disable clipping. Default: (0.0, 1.0) |
Note:
Different noise types use different parameters:
- mode="gaussian": Uses noise_std and noise_mean
- mode="salt_pepper": Uses salt_prob, pepper_prob, salt_value, pepper_value
- mode="poisson": Uses lam_scale
Class and instance attributes: mode, noise_mean, pepper_prob, salt_value, pepper_value, clip_range, precomputed_stats, target_key, auxiliary_fields, preserve_auxiliary.
NoiseOperator
NoiseOperator(config: NoiseOperatorConfig, *, rngs: Rngs)
Bases: ModalityOperator
Image noise transformation operator.
Applies noise to images using one of three modes:
- Gaussian: output = input + N(mean, std²)
- Salt & Pepper: Random pixels → salt_value or pepper_value
- Poisson: output = Poisson(input * lam_scale) / lam_scale
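The three formulas above can be checked with standalone JAX functions that act on one H x W x C image. This is a minimal sketch, not the datarax implementation: the function names are hypothetical, salt_value and pepper_value are passed explicitly instead of being auto-detected, and salt & pepper noise is drawn independently per array element (channels included), which is an assumption.

```python
import jax
import jax.numpy as jnp


def gaussian_noise(key, image, mean=0.0, std=0.05, clip=(0.0, 1.0)):
    # output = input + N(mean, std^2), clipped to the valid range
    noise = mean + std * jax.random.normal(key, image.shape, dtype=image.dtype)
    return jnp.clip(image + noise, *clip)


def salt_pepper_noise(key, image, salt_prob=0.01, pepper_prob=0.01,
                      salt_value=1.0, pepper_value=0.0):
    # One uniform draw per array element decides: salt, pepper, or unchanged
    u = jax.random.uniform(key, image.shape)
    out = jnp.where(u < salt_prob, salt_value, image)
    is_pepper = (u >= salt_prob) & (u < salt_prob + pepper_prob)
    return jnp.where(is_pepper, pepper_value, out)


def poisson_noise(key, image, lam_scale=1.0, clip=(0.0, 1.0)):
    # output = Poisson(input * lam_scale) / lam_scale
    counts = jax.random.poisson(key, image * lam_scale)
    return jnp.clip(counts / lam_scale, *clip)
```

As a worked check of the Poisson formula: with a pixel value of 0.5 and lam_scale=10, the counts follow Poisson(5), so the output has mean 5 / 10 = 0.5 and variance 5 / 100 = 0.05. In general the variance is input / lam_scale, so under this formula a larger lam_scale gives less relative noise.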
Supports three operation modes (independent of the noise type selected by mode):
1. **Deterministic**: Fixed noise pattern using fixed seed
2. **Stochastic**: Per-sample random noise from generate_random_params()
3. **External params**: Accept pre-generated random parameters
The operator works on single elements (H, W, C images) and is composed into batch processing via apply_batch() from the base class.
Examples:
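Gaussian noise - deterministic:
import jax.numpy as jnp
from flax import nnx
from datarax.operators.modality.image.noise_operator import NoiseOperator, NoiseOperatorConfig
config = NoiseOperatorConfig(
    field_key="image",
    mode="gaussian",
    noise_std=0.1,
    noise_mean=0.0,
    stochastic=False,
)
operator = NoiseOperator(config, rngs=nnx.Rngs(0))
data = {"image": jnp.full((32, 32, 3), 0.5)}  # a single H x W x C image
state, metadata = {}, {}
result, state, metadata = operator.apply(data, state, metadata)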
Salt & Pepper noise - stochastic:
config = NoiseOperatorConfig(
    field_key="image",
    mode="salt_pepper",
    salt_prob=0.02,
    pepper_prob=0.02,
    stochastic=True,
)
operator = NoiseOperator(config, rngs=nnx.Rngs(0))
result_batch = operator.apply_batch(batch)  # batch: a Batch whose elements contain "image"
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | NoiseOperatorConfig | Configuration for the noise operation | required |
| rngs | Rngs | RNG streams for stochastic operations | required |
generate_random_params
Generate random noise for stochastic mode.
In stochastic mode, this pre-generates random noise for the entire batch. This approach avoids RNG state mutations inside vmapped apply().
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rng | Array | JAX random key | required |
| data_shapes | dict[str, tuple[int, ...]] | Dictionary mapping field keys to their shapes; used to determine the batch size and element shapes. | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Array] | Dictionary with mode-specific noise data (pre-generated noise or masks, depending on mode). |
Raises:
| Type | Description |
|---|---|
| KeyError | If field_key is not in data_shapes. |
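The pre-generation approach described above (draw all randomness once per batch, then vmap a pure per-element function) can be sketched in JAX for the Gaussian mode. The helper names and the {"noise": ...} layout are hypothetical; the actual keys returned by generate_random_params() are not specified here. A missing field raises KeyError through the plain dictionary lookup, matching the documented behavior.

```python
import jax
import jax.numpy as jnp


def generate_gaussian_params(rng, data_shapes, field_key="image", mean=0.0, std=0.05):
    # Assumes data_shapes[field_key] is the batched shape (B, H, W, C);
    # the lookup raises KeyError if field_key is missing
    batch_shape = data_shapes[field_key]
    noise = mean + std * jax.random.normal(rng, batch_shape)
    return {"noise": noise}


def apply_single(image, random_params):
    # Pure per-element function: no RNG state is touched inside vmap
    return jnp.clip(image + random_params["noise"], 0.0, 1.0)


batch = jnp.full((4, 8, 8, 3), 0.5)
params = generate_gaussian_params(jax.random.PRNGKey(42), {"image": batch.shape})
# vmap maps over the leading axis of both the images and the noise dict
noisy = jax.vmap(apply_single)(batch, params)
```

Because the same key always yields the same noise, a batch can be reproduced exactly by reusing its key.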
apply
apply(data: dict[str, Array], state: dict[str, Any], metadata: dict[str, Any], random_params: dict[str, Array] | None = None, stats: dict[str, Any] | None = None) -> tuple[dict[str, Array], dict[str, Any], dict[str, Any]]
Apply noise transformation to a single element.
This operates on single elements (e.g., one image of shape [H, W, C]). For batch processing, use apply_batch() which handles random param generation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Array] | Input data dictionary. Must contain the field specified by config.field_key. | required |
| state | dict[str, Any] | Operator state (unused for noise; passed through) | required |
| metadata | dict[str, Any] | Metadata dictionary (passed through unchanged) | required |
| random_params | dict[str, Array] \| None | Optional random parameters from generate_random_params(). If config.stochastic=True and this is provided, the pre-generated noise/masks are used. | None |
| stats | dict[str, Any] \| None | Optional statistics dictionary (unused) | None |
Returns:
| Type | Description |
|---|---|
| tuple[dict[str, Array], dict[str, Any], dict[str, Any]] | Tuple of (transformed_data, state, metadata): transformed_data is the data dict with noise applied to the target field; state and metadata are returned unchanged. |
Note
CRITICAL: Always check config.stochastic flag, not whether random_params is None. apply_batch() always passes random_params even in deterministic mode.
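To see why the flag, not the presence of random_params, has to drive the branch, consider a per-element function that always receives a params dict, as apply() does when called from apply_batch(). The function and its fixed_noise argument are hypothetical; a `random_params is None` test here would wrongly take the stochastic path in deterministic mode.

```python
import jax.numpy as jnp


def apply_noise(image, random_params, *, stochastic, fixed_noise):
    # Branch on the config flag; random_params may be a placeholder even when unused
    if stochastic:
        noise = random_params["noise"]
    else:
        noise = fixed_noise
    return image + noise


image = jnp.zeros((2, 2))
placeholder = {"noise": jnp.ones((2, 2))}  # passed even in deterministic mode
fixed = jnp.full((2, 2), 0.25)
det = apply_noise(image, placeholder, stochastic=False, fixed_noise=fixed)
sto = apply_noise(image, placeholder, stochastic=True, fixed_noise=fixed)
```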
get_operation_stats
reset_operation_stats
Reset operation statistics to zero.
Note: Creates new JAX arrays to reset the counters.
compute_statistics
Compute statistics from data using batch_stats_fn.
If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | Any | Input data to compute statistics from | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] \| None | Dictionary of statistics, or None if no batch_stats_fn is configured |
get_statistics
set_statistics
reset_statistics
Reset all statistics to None.
This clears both computed statistics and marks that precomputed_stats should be ignored (via internal flag). After reset, get_statistics() will return None until new statistics are set or computed.
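The statistics lifecycle described for compute_statistics(), get_statistics(), and reset_statistics() can be sketched with a small stand-in class. Everything except the _computed_stats cache name is hypothetical, including the _ignore_precomputed flag, the set_statistics() signature, and the assumption that computed statistics take precedence over precomputed ones.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class StatsHolder:
    # Hypothetical stand-in for the statistics API described above
    batch_stats_fn: Callable[[Any], dict] | None = None
    precomputed_stats: dict | None = None
    _computed_stats: dict | None = None
    _ignore_precomputed: bool = False  # hypothetical internal reset flag

    def compute_statistics(self, data):
        if self.batch_stats_fn is None:
            return None
        self._computed_stats = self.batch_stats_fn(data)  # cache the result
        return self._computed_stats

    def set_statistics(self, stats):
        self._computed_stats = stats
        self._ignore_precomputed = False

    def get_statistics(self):
        if self._computed_stats is not None:
            return self._computed_stats
        if self._ignore_precomputed:
            return None
        return self.precomputed_stats

    def reset_statistics(self):
        # Clear the cache and stop falling back to precomputed_stats
        self._computed_stats = None
        self._ignore_precomputed = True
```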
copy
copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule
Create a copy of this module with optional config/parameter changes.
This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | DataraxModuleConfig \| None | New config (if None, uses the current config) | None |
| rngs | Rngs \| None | New RNG state (if None, uses the current rngs) | None |
| name | str \| None | New name (if None, uses the current name) | None |
Returns:
| Type | Description |
|---|---|
| DataraxModule | New module instance with updated parameters |
Examples:
# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)
# Change name only
renamed = module.copy(name="new_name")
Note
Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().
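Changing a single config field usually goes through dataclasses.replace(), which rebuilds a dataclass instance with selected fields changed and the rest copied. The ExampleConfig class below is a hypothetical stand-in for a module config.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ExampleConfig:
    # Hypothetical stand-in for a frozen module config
    noise_std: float = 0.05
    cacheable: bool = False


base = ExampleConfig()
# replace() builds a new instance with one field changed and the rest copied
tuned = dataclasses.replace(base, noise_std=0.1)
```

With a real operator, a call such as operator.copy(config=dataclasses.replace(config, noise_std=0.1)), where config is the NoiseOperatorConfig used to build the operator, would presumably follow the same pattern.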
get_state
Get module state for checkpointing.
This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A dictionary containing the internal state of the component. |
set_state
Restore module state from a checkpoint.
This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | dict[str, Any] | A dictionary containing the internal state to restore. | required |
Raises:
| Type | Description |
|---|---|
| TypeError | If state is not a dictionary. |
| ValueError | If the checkpoint structure does not match the module state. |
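The strictness rule for set_state() amounts to comparing PyTree structures before restoring anything. The check_structure() helper below is a hypothetical sketch built on jax.tree_util.tree_structure, not the NNX restore path itself.

```python
import jax
import jax.numpy as jnp


def check_structure(current_state, checkpoint):
    # Mirrors the documented errors: TypeError for non-dicts, ValueError on mismatch
    if not isinstance(checkpoint, dict):
        raise TypeError("state must be a dictionary")
    if jax.tree_util.tree_structure(checkpoint) != jax.tree_util.tree_structure(current_state):
        raise ValueError("checkpoint structure does not match module state")


current = {"count": jnp.zeros(()), "rng": {"key": jnp.zeros((2,), jnp.uint32)}}
# Same keys and nesting, different values: passes the strict check
check_structure(current, {"count": jnp.ones(()), "rng": {"key": jnp.ones((2,), jnp.uint32)}})
```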
clone
clone() -> DataraxModule
Create a new instance with the same state as this module.
Uses NNX's clone function for proper deep cloning of all state.
Returns:
| Type | Description |
|---|---|
| DataraxModule | A new module instance with the same state. |
requires_rng_streams
ensure_rng_streams
Ensure that the required RNG streams are available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stream_names | list[str] | A list of available RNG stream names. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If a required RNG stream is not available. |
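A minimal sketch of the check that ensure_rng_streams() performs. In the real class the required names presumably come from requires_rng_streams(); passing them as an explicit required argument, and the "augment" stream name, are assumptions made to keep the sketch standalone.

```python
def ensure_rng_streams(required, stream_names):
    # Fail fast if any stream the module needs is missing from the available ones
    missing = [name for name in required if name not in stream_names]
    if missing:
        raise ValueError(f"Missing required RNG streams: {missing}")


ensure_rng_streams(["augment"], ["default", "augment"])  # passes silently
```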
get_output_structure
get_output_structure(sample_data: PyTree, sample_state: PyTree) -> tuple[PyTree, PyTree]
Declare output PyTree structure for vmap axis specification.
Default uses jax.eval_shape to discover structure automatically. Override for efficiency or when eval_shape doesn't work (e.g., data-dependent shapes).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sample_data | PyTree | Single element data (not batched) | required |
| sample_state | PyTree | Single element state (not batched) | required |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree] | Tuple of (output_data_structure, output_state_structure) with None leaves. Only the structure (keys/nesting) matters; leaf values are ignored. |
Example override for an operator that adds keys:
def get_output_structure(self, sample_data, sample_state):
    out_data = {
        **jax.tree.map(lambda _: None, sample_data),
        "score": None,
        "alignment": None,
    }
    return out_data, sample_state
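The default discovery mechanism can be reproduced directly with jax.eval_shape, which traces a function on abstract inputs without running it. The add_score() function below is hypothetical and plays the role of a per-element apply() that adds a key.

```python
import jax
import jax.numpy as jnp


def add_score(element):
    # Hypothetical per-element function that adds a new key
    return {**element, "score": element["image"].mean()}


sample = {"image": jax.ShapeDtypeStruct((32, 32, 3), jnp.float32)}
# eval_shape traces abstractly: no real computation or device memory is used
abstract_out = jax.eval_shape(add_score, sample)
# Replace every leaf with None to keep only the keys/nesting
structure = jax.tree_util.tree_map(lambda _: None, abstract_out)
```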
apply_batch
Process entire batch with vmap and optional RNG generation.
This method implements the batch processing logic for both stochastic and deterministic modes. It uses static branching on self.stochastic for JIT compilation efficiency.
The implementation delegates to _vmap_apply() for the shared computational core, then wraps the result in a Batch object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Batch | Input batch (Batch[Element] structure) | required |
| stats | dict[str, Any] \| None | Optional statistics (if None, uses get_statistics()) | None |
Returns:
| Type | Description |
|---|---|
| Batch | Transformed batch with the same structure |
Note
This method is concrete (not abstract). Subclasses typically don't override it, but can if they need custom batch processing logic.
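Static branching on the stochastic flag can be sketched with jax.jit and static_argnames: because the flag is a Python bool known at trace time, only one branch is compiled per setting. The function below is a hypothetical stand-in for apply_batch(), not a reproduction of _vmap_apply().

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="stochastic")
def apply_batch_sketch(images, key, fixed_noise, stochastic):
    # `stochastic` is a static Python bool, so only one branch is traced and compiled
    if stochastic:
        noise = 0.05 * jax.random.normal(key, images.shape)  # fresh noise per sample
    else:
        noise = jnp.broadcast_to(fixed_noise, images.shape)  # same pattern for every sample
    per_element = lambda img, n: jnp.clip(img + n, 0.0, 1.0)
    return jax.vmap(per_element)(images, noise)
```

Switching the flag triggers a separate compilation, while repeated calls with the same flag reuse the cached program.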
output_spec
Return the operator's output spec given an input spec.
Most operators (normalization, additive noise, simple element-wise transforms) do not change shape; the default returns input_spec unchanged. Shape-changing operators (Resize, Crop, Reshape) MUST override this method.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_spec | PyTree | PyTree of | required |
Returns:
| Type | Description |
|---|---|
| PyTree | PyTree of |
| PyTree | By default, equal to |
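A sketch of the default behavior and of an override for a shape-changing operator. Representing specs as jax.ShapeDtypeStruct leaves is an assumption (the leaf type is not given in the table above), and resize_output_spec() is a hypothetical helper for a Resize-like operator on H x W x C leaves.

```python
import jax
import jax.numpy as jnp


def identity_output_spec(input_spec):
    # Shape-preserving operators (e.g. additive noise) return the spec unchanged
    return input_spec


def resize_output_spec(input_spec, height, width):
    # Hypothetical Resize override: replace H and W, keep channels and dtype
    def resize_leaf(spec):
        return jax.ShapeDtypeStruct((height, width) + tuple(spec.shape[2:]), spec.dtype)
    return jax.tree_util.tree_map(resize_leaf, input_spec)


spec = {"image": jax.ShapeDtypeStruct((32, 32, 3), jnp.float32)}
```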