
Modality¤

Data modality definitions (image, text, audio).

datarax.core.modality ¤

Modality-specific operator base classes.

This module provides base classes for operators that work on single modalities (specific fields within elements). Each ModalityOperator handles ONE field (e.g., image, text, audio) and can have learnable parameters via Flax NNX.

Key Features:

  • Single-field transformations with coordinated auxiliary fields
  • Value domain constraints and clipping
  • Learnable parameters support via Flax NNX
  • Compatible with JAX transformations (jit, vmap, grad)
  • End-to-end differentiable data pipelines

Examples:

Deterministic image operator:

```python
config = ModalityOperatorConfig(field_key="image", clip_range=(0.0, 1.0))
# Note: Use specific operators like BrightnessOperator, ContrastOperator, etc.
```

Stochastic audio operator with learnable parameters:


```python
config = ModalityOperatorConfig(
    field_key="waveform",
    stochastic=True,
    stream_name="augment"
)
operator = LearnedAudioOperator(config, rngs=nnx.Rngs(0, augment=1))
```

logger module-attribute ¤

logger = getLogger(__name__)

ModalityOperatorConfig dataclass ¤

ModalityOperatorConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, stochastic: bool = False, stream_name: str | None = None, batch_strategy: str = 'vmap', *, field_key: str, target_key: str | None = None, auxiliary_fields: list[str] | None = None, clip_range: tuple[float, float] | None = None, preserve_auxiliary: bool = True, validate_domain_constraints: bool = True)

Bases: OperatorConfig

Configuration for modality-specific operators.

Captures common patterns across modalities:

- Field identification (primary + auxiliary)
- Value domain constraints
- Transformation coordination settings
- Learnable parameter support via Flax NNX

Each ModalityOperator handles a SINGLE modality (a specific field in an Element). Multi-modal data is handled by composing multiple ModalityOperators via CompositeOperatorModule or MultiModalOperator.

IMPORTANT: Operators can have learnable parameters and must be compatible with JAX transformations (jit, vmap, grad).

Attributes:

Name Type Description
field_key str

Primary data field to transform (e.g., "image", "caption", "waveform")

target_key str | None

Optional target field (if None, overwrites source field)

auxiliary_fields list[str] | None

Fields that should be transformed in coordination with the primary field (e.g., ["mask", "bounding_boxes"] for image)

clip_range tuple[float, float] | None

Value domain constraints as (min, max) tuple (None = no clipping)

preserve_auxiliary bool

Whether to preserve auxiliary data structure during transformation

validate_domain_constraints bool

Enable domain-specific validation rules (e.g., biological validity, clinical plausibility)

Validation Rules:

- field_key must be non-empty string
- clip_range must be tuple of (min, max) where min < max
- Inherits stochastic validation from OperatorConfig
- Inherits statistics validation from DataraxModuleConfig

Examples:

Minimal image operator:

config = ModalityOperatorConfig(field_key="image")

Image with clipping and auxiliary fields:

config = ModalityOperatorConfig(
    field_key="image",
    auxiliary_fields=["mask", "bounding_boxes"],
    clip_range=(0.0, 1.0)
)

Stochastic text operator:

config = ModalityOperatorConfig(
    field_key="caption",
    stochastic=True,
    stream_name="text_augment"
)

field_key class-attribute instance-attribute ¤

field_key: str = field(kw_only=True)

target_key class-attribute instance-attribute ¤

target_key: str | None = field(default=None, kw_only=True)

auxiliary_fields class-attribute instance-attribute ¤

auxiliary_fields: list[str] | None = field(default=None, kw_only=True)

clip_range class-attribute instance-attribute ¤

clip_range: tuple[float, float] | None = field(default=None, kw_only=True)

preserve_auxiliary class-attribute instance-attribute ¤

preserve_auxiliary: bool = field(default=True, kw_only=True)

validate_domain_constraints class-attribute instance-attribute ¤

validate_domain_constraints: bool = field(default=True, kw_only=True)

cacheable class-attribute instance-attribute ¤

cacheable: bool = False

batch_stats_fn class-attribute instance-attribute ¤

batch_stats_fn: Callable | Module | None = None

precomputed_stats class-attribute instance-attribute ¤

precomputed_stats: dict[str, Any] | None = None

stochastic class-attribute instance-attribute ¤

stochastic: bool = False

stream_name class-attribute instance-attribute ¤

stream_name: str | None = None

batch_strategy class-attribute instance-attribute ¤

batch_strategy: str = 'vmap'

ModalityOperator ¤

ModalityOperator(config: ModalityOperatorConfig, *, rngs: Rngs | None = None, name: str | None = None)

Bases: OperatorModule

Base class for modality-specific operators with learnable parameter support.

Provides common functionality:

- Field extraction and remapping
- Value range validation and clipping
- Auxiliary field coordination
- Integration with MapOperator for transformations
- Learnable parameter support via Flax NNX

Subclasses (ImageOperator, AudioOperator, etc.) provide:

- Modality-specific configurations
- Domain-specific transformation functions
- Validation logic for modality data
- Learnable parameters (e.g., augmentation strategies, normalization params)

Key Features:

- Compatible with nnx.jit, jax.vmap, jax.grad
- Supports learnable parameters via nnx.Param
- End-to-end differentiable data pipelines
- Can be optimized jointly with model
- Operates on Batch[Element] (inherited from OperatorModule)

Inherited Features from OperatorModule:

- **apply_batch()**: Automatically handles batched operations by calling apply()
  on each element. Override only if you need custom batch-level logic (e.g.,
  batch normalization, cross-element operations). Default is sufficient for
  most element-wise transformations.

- **Statistics system**: Optionally collect and use batch statistics via stats
  parameter in apply(). Useful for adaptive operations (e.g., batch-aware
  normalization). Statistics are computed externally and passed in.

- **Caching system**: Results can be cached based on operator configuration
  and input characteristics. Inherited from base OperatorModule, helps avoid
  redundant computation for deterministic operators.

Subclass Implementation Pattern:

class ImageOperator(ModalityOperator):
    def __init__(self, config: ModalityOperatorConfig, *, rngs: nnx.Rngs | None = None):
        super().__init__(config, rngs=rngs)
        # Add learnable parameters if needed
        # self.augment_strength = nnx.Param(jnp.array(0.5))

    def apply(self, data, state, metadata, random_params=None, stats=None):
        # Extract field
        image = self._extract_field(data, self.config.field_key)

        # Transform (can use learnable parameters)
        transformed = self._transform_image(image)

        # Apply clipping
        transformed = self._apply_clip_range(transformed)

        # Remap to target field
        result = self._remap_field(data, transformed)

        return result, state, metadata

    def generate_random_params(self, rng, data_shapes):
        # For stochastic operators only
        batch_size = data_shapes[self.config.field_key][0]
        return jax.random.uniform(rng, (batch_size,))

Examples:

Deterministic operator (no learnable params):

image_op = ImageOperator(config, rngs=nnx.Rngs(0))

Learnable operator (learned augmentation strategy):

class LearnedImageOperator(ImageOperator):
    def __init__(self, config, *, rngs):
        super().__init__(config, rngs=rngs)
        # Learnable augmentation parameters
        self.crop_scale = nnx.Param(jnp.array(0.8))
        self.rotation_angle = nnx.Param(jnp.array(0.1))

Custom batch-level operator (rare, only when needed):

class BatchNormOperator(ImageOperator):
    def apply_batch(self, batch, stats=None):
        # Override for batch-level normalization
        # Compute batch statistics here
        # Call apply() for each element with shared stats
        pass

Parameters:

Name Type Description Default
config ModalityOperatorConfig

Modality operator configuration (already validated)

required
rngs Rngs | None

Random number generators (required if stochastic=True)

None
name str | None

Optional operator name

None

Raises:

Type Description
ValueError

If stochastic=True but rngs is None

config instance-attribute ¤

config: ModalityOperatorConfig = config

rngs instance-attribute ¤

rngs = rngs

name instance-attribute ¤

name = static(name)

stochastic instance-attribute ¤

stochastic = static(stochastic)

stream_name instance-attribute ¤

stream_name = static(stream_name)

apply ¤

apply(data: PyTree, state: PyTree, metadata: dict[str, Any] | None, random_params: Any = None, stats: dict[str, Any] | None = None) -> tuple[PyTree, PyTree, dict[str, Any] | None]

Apply modality-specific transformation to element.

MUST be implemented by subclasses to provide modality-specific behavior.

This is a PURE FUNCTION that transforms a single data element. It should not access self.rngs or generate random numbers. All randomness comes through the random_params argument.

Parameters:

Name Type Description Default
data PyTree

Element data PyTree (contains the field specified by config.field_key). Typically dict[str, Array] with no batch dimension.

required
state PyTree

Element state PyTree (typically dict[str, Any])

required
metadata dict[str, Any] | None

Element metadata dict

required
random_params Any

Random parameters for this element (from generate_random_params)

None
stats dict[str, Any] | None

Optional batch statistics (from get_statistics() or passed explicitly)

None

Returns:

Type Description
tuple[PyTree, PyTree, dict[str, Any] | None]

Tuple of (transformed_data, new_state, new_metadata):

  • transformed_data: PyTree with same structure as data, containing transformed field
  • new_state: Updated state PyTree
  • new_metadata: Updated metadata dict

Implementation Pattern:

def apply(self, data, state, metadata, random_params=None, stats=None):
    # 1. Extract field
    field_value = self._extract_field(data, self.config.field_key)

    # 2. Transform (modality-specific logic)
    transformed = self._transform(field_value, random_params, stats)

    # 3. Apply clipping if configured
    transformed = self._apply_clip_range(transformed)

    # 4. Remap to target field
    result = self._remap_field(data, transformed)

    return result, state, metadata

Raises:

Type Description
NotImplementedError

If not implemented by subclass

generate_random_params ¤

generate_random_params(rng: Array, data_shapes: PyTree) -> PyTree

Generate random parameters for stochastic transformations.

MUST be implemented by stochastic operators (config.stochastic=True). Deterministic operators can use the default implementation, which returns None.

Generates a PyTree of random parameters matching the batch structure. For example, image rotation might generate per-element rotation angles.

This method is impure (it uses the RNG) and is called once per batch. The generated parameters are then passed to apply() for each element via vmap.
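
The split between an impure, once-per-batch parameter draw and a pure per-element transform can be sketched in plain Python. This is an analogy under stated assumptions, not datarax code: the standard-library `random.Random` stands in for a JAX key, a list comprehension stands in for `vmap`, and the brightness offset is an invented example transform.

```python
import random


def generate_random_params(rng: random.Random, batch_size: int) -> list[float]:
    """Impure step: draw one brightness offset per element, once per batch."""
    return [rng.uniform(-0.1, 0.1) for _ in range(batch_size)]


def apply(pixel_values: list[float], offset: float) -> list[float]:
    """Pure step: the result depends only on the arguments."""
    return [v + offset for v in pixel_values]


batch = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
params = generate_random_params(random.Random(0), len(batch))

# A list comprehension stands in for jax.vmap over the batch axis.
outputs = [apply(x, p) for x, p in zip(batch, params)]

# Re-running the pure step with the same params reproduces the same outputs.
replayed = [apply(x, p) for x, p in zip(batch, params)]
```

Keeping the random draw outside `apply` is what makes the per-element step safe to map and trace: the same parameters always yield the same result.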

Parameters:

Name Type Description Default
rng Array

JAX random key for this batch

required
data_shapes PyTree

PyTree with the same structure as batch.data, containing shapes. Example: {"image": (batch_size, H, W, C)}

required

Returns:

Type Description
PyTree

PyTree of random parameters for this batch. Structure depends on operator needs. For deterministic operators, returns None.

Examples:

def generate_random_params(self, rng, data_shapes):
    # Stochastic rotation: generate per-element angles
    batch_size = data_shapes[self.config.field_key][0]
    return jax.random.uniform(rng, (batch_size,), minval=0, maxval=2*jnp.pi)

Raises:

Type Description
NotImplementedError

If stochastic=True but not implemented

get_operation_stats ¤

get_operation_stats() -> dict[str, int]

Get operation statistics.

Note: This method converts JAX arrays to Python ints for introspection. It is intended for use outside of JIT-compiled functions.

Returns:

Type Description
dict[str, int]

Dictionary with 'applied_count' and 'skipped_count'

reset_operation_stats ¤

reset_operation_stats() -> None

Reset operation statistics to zero.

Note: Creates new JAX arrays to reset the counters.

compute_statistics ¤

compute_statistics(data: Any) -> dict[str, Any] | None

Compute statistics from data using batch_stats_fn.

If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.

Parameters:

Name Type Description Default
data Any

Input data to compute statistics from

required

Returns:

Type Description
dict[str, Any] | None

Dictionary of statistics, or None if no batch_stats_fn configured

get_statistics ¤

get_statistics() -> dict[str, Any] | None

Get current statistics.

Returns precomputed_stats if configured (unless reset was called), otherwise returns cached computed statistics, or None if no statistics available.

Returns:

Type Description
dict[str, Any] | None

Dictionary of statistics, or None if no statistics available

set_statistics ¤

set_statistics(stats: dict[str, Any]) -> None

Manually set statistics.

This overwrites any previously computed statistics and clears the reset flag.

Parameters:

Name Type Description Default
stats dict[str, Any]

Dictionary of statistics to set

required

reset_statistics ¤

reset_statistics() -> None

Reset all statistics to None.

This clears the computed statistics and sets an internal flag so that precomputed_stats is ignored. After reset, get_statistics() returns None until new statistics are set or computed.
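
The lookup order described for get_statistics(), set_statistics(), and reset_statistics() can be modelled with a small class. `StatsSketch` is a hypothetical model that follows the sentences above literally. The attribute names are invented, and the real module may differ in cases the text does not spell out, such as calling set_statistics() when precomputed_stats is also configured.

```python
class StatsSketch:
    """Hypothetical model of the statistics lookup order (not datarax code)."""

    def __init__(self, precomputed=None):
        self._precomputed = precomputed
        self._computed = None
        self._reset = False  # when True, precomputed stats are ignored

    def get_statistics(self):
        # Precomputed stats win unless a reset has been requested.
        if self._precomputed is not None and not self._reset:
            return self._precomputed
        return self._computed

    def set_statistics(self, stats):
        self._computed = stats
        self._reset = False  # setting stats clears the reset flag

    def reset_statistics(self):
        self._computed = None
        self._reset = True


with_pre = StatsSketch(precomputed={"mean": 0.5})
before_reset = with_pre.get_statistics()
with_pre.reset_statistics()
after_reset = with_pre.get_statistics()

plain = StatsSketch()
plain.set_statistics({"mean": 0.1})
manual = plain.get_statistics()
```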

reset_cache ¤

reset_cache() -> None

Clear the cache.

Only has effect if cacheable=True in config.

copy ¤

copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule

Create a copy of this module with optional config/parameter changes.

This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.

Parameters:

Name Type Description Default
config DataraxModuleConfig | None

New config (if None, uses current config)

None
rngs Rngs | None

New RNG state (if None, uses current rngs)

None
name str | None

New name (if None, uses current name)

None

Returns:

Type Description
DataraxModule

New module instance with updated parameters

Examples:

# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)

# Change name only
renamed = module.copy(name="new_name")

Note

Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().
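
The dataclass `replace()` mentioned in the note is the standard-library way to change one config field while keeping the rest. A minimal sketch, assuming the config is an ordinary dataclass; `SketchConfig` is a hypothetical stand-in, not a datarax class:

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class SketchConfig:
    """Hypothetical stand-in for a DataraxModuleConfig subclass."""

    field_key: str
    cacheable: bool = False
    stochastic: bool = False


base = SketchConfig(field_key="image")

# replace() builds a new instance; every field not named keeps its old value.
cached = replace(base, cacheable=True)
```

A config derived this way could then be passed as `module.copy(config=...)`, so fields such as `field_key` do not have to be restated by hand.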

get_state ¤

get_state() -> dict[str, Any]

Get module state for checkpointing.

This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.

Returns:

Type Description
dict[str, Any]

A dictionary containing the internal state of the component.

set_state ¤

set_state(state: dict[str, Any]) -> None

Restore module state from a checkpoint.

This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.

Parameters:

Name Type Description Default
state dict[str, Any]

A dictionary containing the internal state to restore.

required

Raises:

Type Description
TypeError

If state is not a dictionary.

ValueError

If checkpoint structure does not match module state.
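
What "strict" restoration could look like is sketched below for nested dicts. This is an illustration only: `check_checkpoint`, `_compare`, and `outcome` are hypothetical helpers, and the real method relies on NNX state management rather than a hand-written comparison.

```python
def check_checkpoint(current: dict, restored: object) -> None:
    """Raise TypeError or ValueError unless restored mirrors current's structure."""
    if not isinstance(restored, dict):
        raise TypeError("state must be a dictionary")
    _compare(current, restored, "state")


def _compare(current: object, restored: object, path: str) -> None:
    # A dict on one side and a leaf on the other is a structure mismatch.
    if isinstance(current, dict) != isinstance(restored, dict):
        raise ValueError(f"structure mismatch at {path}")
    if isinstance(current, dict):
        if current.keys() != restored.keys():
            raise ValueError(f"key mismatch at {path}")
        for key in current:
            _compare(current[key], restored[key], f"{path}.{key}")


def outcome(current: dict, restored: object) -> str:
    """Report which branch a restore attempt would take."""
    try:
        check_checkpoint(current, restored)
    except (TypeError, ValueError) as err:
        return type(err).__name__
    return "ok"


module_state = {"crop_scale": 0.8, "counters": {"applied_count": 0}}
```

Leaf values may differ between the module and the checkpoint. Only the keys and nesting have to match.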

clone ¤

clone() -> DataraxModule

Create a new instance with the same state as this module.

Uses NNX's clone function for proper deep cloning of all state.

Returns:

Type Description
DataraxModule

A new module instance with the same state.

requires_rng_streams ¤

requires_rng_streams() -> list[str] | None

Get the list of RNG streams required by this module.

Returns:

Type Description
list[str] | None

A list of required RNG stream names, or None if no RNG streams are required.

ensure_rng_streams ¤

ensure_rng_streams(stream_names: list[str]) -> None

Ensure that the required RNG streams are available.

Parameters:

Name Type Description Default
stream_names list[str]

A list of available RNG stream names.

required

Raises:

Type Description
ValueError

If a required RNG stream is not available.

get_output_structure ¤

get_output_structure(sample_data: PyTree, sample_state: PyTree) -> tuple[PyTree, PyTree]

Declare output PyTree structure for vmap axis specification.

Default uses jax.eval_shape to discover structure automatically. Override for efficiency or when eval_shape doesn't work (e.g., data-dependent shapes).

Parameters:

Name Type Description Default
sample_data PyTree

Single element data (not batched)

required
sample_state PyTree

Single element state (not batched)

required

Returns:

Type Description
tuple[PyTree, PyTree]

Tuple of (output_data_structure, output_state_structure) with None leaves. The structure (keys/nesting) matters; leaf values are ignored.

Example override for an operator that adds keys:

def get_output_structure(self, sample_data, sample_state):
    out_data = {
        **jax.tree.map(lambda _: None, sample_data),
        "score": None,
        "alignment": None,
    }
    return out_data, sample_state

apply_batch ¤

apply_batch(batch: Batch, stats: dict[str, Any] | None = None) -> Batch

Process entire batch with vmap and optional RNG generation.

This method implements the batch processing logic for both stochastic and deterministic modes. It uses static branching on self.stochastic for JIT compilation efficiency.

The implementation delegates to _vmap_apply() for the shared computational core, then wraps the result in a Batch object.

Parameters:

Name Type Description Default
batch Batch

Input batch (Batch[Element] structure)

required
stats dict[str, Any] | None

Optional statistics (if None, uses get_statistics())

None

Returns:

Type Description
Batch

Transformed batch with same structure

Note

This method is concrete (not abstract). Subclasses typically don't override it, but can if they need custom batch processing logic.

output_spec ¤

output_spec(input_spec: PyTree) -> PyTree

Return the operator's output spec given an input spec.

Most operators (normalization, additive noise, simple element-wise transforms) do not change shape; the default returns input_spec unchanged. Shape-changing operators (Resize, Crop, Reshape) MUST override this method.

Parameters:

Name Type Description Default
input_spec PyTree

PyTree of jax.ShapeDtypeStruct describing the input element (matching the upstream DataSourceModule.element_spec() or another operator's output_spec).

required

Returns:

Type Description
PyTree

PyTree of jax.ShapeDtypeStruct describing the operator's output. By default, equal to input_spec.
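
The difference between the default and a shape-changing override can be sketched without JAX. In the sketch below, `Spec` is a hypothetical stand-in for jax.ShapeDtypeStruct that holds only a shape and a dtype name, and `resize_output_spec` is an invented example of what a Resize-style operator might return. Neither is part of datarax.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Spec:
    """Stand-in for jax.ShapeDtypeStruct: just a shape and a dtype name."""

    shape: tuple[int, ...]
    dtype: str


def identity_output_spec(input_spec: dict[str, Spec]) -> dict[str, Spec]:
    """Default behaviour: shape-preserving operators return the spec unchanged."""
    return input_spec


def resize_output_spec(
    input_spec: dict[str, Spec], field_key: str, size: tuple[int, int]
) -> dict[str, Spec]:
    """Shape-changing override: new height and width, same channels and dtype."""
    old = input_spec[field_key]
    new = Spec(shape=(*size, old.shape[-1]), dtype=old.dtype)
    # Other fields pass through untouched.
    return {**input_spec, field_key: new}


spec_in = {"image": Spec((32, 32, 3), "float32"), "label": Spec((), "int32")}
spec_out = resize_output_spec(spec_in, "image", (64, 64))
```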