Modality
Data modality definitions (image, text, audio).
See Also
- Core Overview - All core protocols
- Cross Modal - Multi-modal processing
- Operators - Modality-specific operators
- Data Sources - Multi-modal sources
datarax.core.modality
Modality-specific operator base classes.
This module provides base classes for operators that work on single modalities (specific fields within elements). Each ModalityOperator handles ONE field (e.g., image, text, audio) and can have learnable parameters via Flax NNX.
Key Features:
- Single-field transformations with coordinate auxiliary fields
- Value domain constraints and clipping
- Learnable parameters support via Flax NNX
- Compatible with JAX transformations (jit, vmap, grad)
- End-to-end differentiable data pipelines
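The sketch below illustrates the jit/vmap/grad compatibility listed above. It is standalone plain JAX, not datarax code: `brightness_op`, `gain`, and the dict layout are hypothetical. It transforms a single field with clipping, vmaps over the batch, and takes a gradient with respect to a learnable scalar.

```python
import jax
import jax.numpy as jnp


def brightness_op(gain, data, field_key="image", clip_range=(0.0, 1.0)):
    # Hypothetical single-field transform: scale one field, then clip it.
    out = jnp.clip(data[field_key] * gain, *clip_range)
    # Other fields (here "label") pass through unchanged.
    return {**data, field_key: out}


def loss(gain, batch):
    # vmap maps over the leading (batch) axis of every field; gain is shared.
    out = jax.vmap(brightness_op, in_axes=(None, 0))(gain, batch)
    return jnp.mean((out["image"] - 0.5) ** 2)


batch = {"image": jnp.full((4, 2, 2, 3), 0.25), "label": jnp.arange(4)}
grad_fn = jax.jit(jax.grad(loss))
g = grad_fn(jnp.array(1.0), batch)  # d(loss)/d(gain) as a scalar array
```

In a datarax operator, a parameter like `gain` would presumably be stored as an `nnx.Param` on the operator rather than passed in explicitly.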
Examples:
Deterministic image operator:
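```python
from datarax.core.modality import ModalityOperatorConfig

config = ModalityOperatorConfig(field_key="image", clip_range=(0.0, 1.0))
# Note: ModalityOperator is a base class; use concrete operators such as
# BrightnessOperator, ContrastOperator, etc. with this config.
```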
Stochastic audio operator with learnable parameters:
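```python
from flax import nnx
from datarax.core.modality import ModalityOperatorConfig

config = ModalityOperatorConfig(
    field_key="waveform",
    stochastic=True,
    stream_name="augment",
)
# LearnedAudioOperator stands for a user-defined ModalityOperator subclass (not shown).
# The Rngs object provides the "augment" stream named by config.stream_name.
operator = LearnedAudioOperator(config, rngs=nnx.Rngs(0, augment=1))
```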
ModalityOperatorConfig (dataclass)
ModalityOperatorConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, stochastic: bool = False, stream_name: str | None = None, batch_strategy: str = 'vmap', *, field_key: str, target_key: str | None = None, auxiliary_fields: list[str] | None = None, clip_range: tuple[float, float] | None = None, preserve_auxiliary: bool = True, validate_domain_constraints: bool = True)
Bases: OperatorConfig
Configuration for modality-specific operators.
Captures common patterns across modalities:
- Field identification (primary + auxiliary)
- Value domain constraints
- Transformation coordination settings
- Learnable parameter support via Flax NNX
Each ModalityOperator handles a SINGLE modality (a specific field in an Element). Multi-modal data is handled by composing multiple ModalityOperators via CompositeOperatorModule or MultiModalOperator.
IMPORTANT: Operators can have learnable parameters and must be compatible with JAX transformations (jit, vmap, grad).
Attributes:
| Name | Type | Description |
|---|---|---|
| field_key | str | Primary data field to transform (e.g., "image", "caption", "waveform") |
| target_key | str \| None | Optional target field (if None, overwrites the source field) |
| auxiliary_fields | list[str] \| None | Fields that should be transformed coordinately (e.g., ["mask", "bounding_boxes"] for image) |
| clip_range | tuple[float, float] \| None | Value domain constraints as a (min, max) tuple (None = no clipping) |
| preserve_auxiliary | bool | Whether to preserve auxiliary data structure during transformation |
| validate_domain_constraints | bool | Enable domain-specific validation rules (e.g., biological validity, clinical plausibility) |
Validation Rules:
- field_key must be a non-empty string
- clip_range must be a (min, max) tuple with min < max
- Inherits stochastic validation from OperatorConfig
- Inherits statistics validation from DataraxModuleConfig
Examples:
Minimal image operator:
Image with clipping and auxiliary fields:
```python
from datarax.core.modality import ModalityOperatorConfig

config = ModalityOperatorConfig(
    field_key="image",
    auxiliary_fields=["mask", "bounding_boxes"],
    clip_range=(0.0, 1.0),
)
```
Stochastic text operator:
target_key (class-attribute, instance-attribute)
auxiliary_fields (class-attribute, instance-attribute)
clip_range (class-attribute, instance-attribute)
preserve_auxiliary (class-attribute, instance-attribute)
validate_domain_constraints (class-attribute, instance-attribute)
precomputed_stats (class-attribute, instance-attribute)
ModalityOperator
ModalityOperator(config: ModalityOperatorConfig, *, rngs: Rngs | None = None, name: str | None = None)
Bases: OperatorModule
Base class for modality-specific operators with learnable parameter support.
Provides common functionality:
- Field extraction and remapping
- Value range validation and clipping
- Auxiliary field coordination
- Integration with MapOperator for transformations
- Learnable parameter support via Flax NNX
Subclasses (ImageOperator, AudioOperator, etc.) provide:
- Modality-specific configurations
- Domain-specific transformation functions
- Validation logic for modality data
- Learnable parameters (e.g., augmentation strategies, normalization params)
Key Features:
- Compatible with nnx.jit, jax.vmap, jax.grad
- Supports learnable parameters via nnx.Param
- End-to-end differentiable data pipelines
- Can be optimized jointly with model
- Operates on Batch[Element] (inherited from OperatorModule)
Inherited Features from OperatorModule:
- **apply_batch()**: Automatically handles batched operations by calling apply()
on each element. Override only if you need custom batch-level logic (e.g.,
batch normalization, cross-element operations). Default is sufficient for
most element-wise transformations.
- **Statistics system**: Optionally collect and use batch statistics via the stats
parameter of apply(). Useful for adaptive operations (e.g., batch-aware
normalization). Statistics are computed externally and passed in.
- **Caching system**: Results can be cached based on operator configuration
and input characteristics. Inherited from the base OperatorModule, this helps avoid
redundant computation for deterministic operators.
Subclass Implementation Pattern:
```python
import jax
from flax import nnx
from datarax.core.modality import ModalityOperator, ModalityOperatorConfig


class ImageOperator(ModalityOperator):
    def __init__(self, config: ModalityOperatorConfig, *, rngs: nnx.Rngs | None = None):
        super().__init__(config, rngs=rngs)
        # Add learnable parameters if needed, e.g.:
        # self.augment_strength = nnx.Param(jnp.array(0.5))

    def apply(self, data, state, metadata, random_params=None, stats=None):
        # Extract the primary field (config.field_key)
        image = self._extract_field(data, self.config.field_key)
        # Transform (can use learnable parameters); _transform_image is a
        # placeholder for the subclass's own transform logic
        transformed = self._transform_image(image)
        # Apply clipping (configured via clip_range)
        transformed = self._apply_clip_range(transformed)
        # Remap to target field
        result = self._remap_field(data, transformed)
        return result, state, metadata

    def generate_random_params(self, rng, data_shapes):
        # For stochastic operators only: one random value per batch element
        batch_size = data_shapes[self.config.field_key][0]
        return jax.random.uniform(rng, (batch_size,))
```
Examples:
Deterministic operator (no learnable params):
Learnable operator (learned augmentation strategy):
```python
import jax.numpy as jnp
from flax import nnx


class LearnedImageOperator(ImageOperator):
    def __init__(self, config, *, rngs):
        super().__init__(config, rngs=rngs)
        # Learnable augmentation parameters
        self.crop_scale = nnx.Param(jnp.array(0.8))
        self.rotation_angle = nnx.Param(jnp.array(0.1))
```
Custom batch-level operator (rare, only when needed):
```python
class BatchNormOperator(ImageOperator):
    def apply_batch(self, batch, stats=None):
        # Override for batch-level normalization:
        # compute batch statistics here, then call apply() for each
        # element with the shared stats
        pass
```
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ModalityOperatorConfig | Modality operator configuration (already validated) | required |
| rngs | Rngs \| None | Random number generators (required if stochastic=True) | None |
| name | str \| None | Optional operator name | None |
Raises:
| Type | Description |
|---|---|
| ValueError | If stochastic=True but rngs is None |
apply
apply(data: PyTree, state: PyTree, metadata: dict[str, Any] | None, random_params: Any = None, stats: dict[str, Any] | None = None) -> tuple[PyTree, PyTree, dict[str, Any] | None]
Apply modality-specific transformation to element.
MUST be implemented by subclasses to provide modality-specific behavior.
This is a PURE FUNCTION that transforms a single data element. It should not access self.rngs or generate random numbers. All randomness comes through the random_params argument.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | PyTree | Element data PyTree containing the field specified by config.field_key; typically dict[str, Array] with no batch dimension | required |
| state | PyTree | Element state PyTree (typically dict[str, Any]) | required |
| metadata | dict[str, Any] \| None | Element metadata dict | required |
| random_params | Any | Random parameters for this element (from generate_random_params) | None |
| stats | dict[str, Any] \| None | Optional batch statistics (from get_statistics() or passed explicitly) | None |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree, dict[str, Any] \| None] | Tuple of (transformed_data, new_state, new_metadata) |
Implementation Pattern:
```python
def apply(self, data, state, metadata, random_params=None, stats=None):
    # 1. Extract field
    field_value = self._extract_field(data, self.config.field_key)
    # 2. Transform (modality-specific logic; _transform is a placeholder
    #    for the subclass's own helper)
    transformed = self._transform(field_value, random_params, stats)
    # 3. Apply clipping if configured
    transformed = self._apply_clip_range(transformed)
    # 4. Remap to target field
    result = self._remap_field(data, transformed)
    return result, state, metadata
```
Raises:
| Type | Description |
|---|---|
| NotImplementedError | If not implemented by subclass |
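To make the four-step pattern concrete without depending on datarax internals, here is a sketch written as plain functions. `extract_field`, `apply_clip_range`, and `remap_field` are hypothetical stand-ins whose behavior is inferred from the config attributes (target_key=None overwrites the source field, clip_range=None disables clipping), and the transform itself (adding a random offset) is only illustrative.

```python
import jax.numpy as jnp


def extract_field(data, key):
    # Stand-in for _extract_field: read one field from the element dict.
    return data[key]


def apply_clip_range(x, clip_range):
    # Stand-in for _apply_clip_range: None means "no clipping".
    return x if clip_range is None else jnp.clip(x, *clip_range)


def remap_field(data, value, field_key, target_key=None):
    # Stand-in for _remap_field: write to target_key, or overwrite field_key.
    return {**data, (target_key or field_key): value}


def apply(data, state, metadata, random_params=None, stats=None, *,
          field_key="image", target_key=None, clip_range=(0.0, 1.0)):
    x = extract_field(data, field_key)
    # Pure transform: all randomness arrives through random_params
    # (stats is accepted but unused in this sketch).
    offset = 0.0 if random_params is None else random_params
    x = apply_clip_range(x + offset, clip_range)
    return remap_field(data, x, field_key, target_key), state, metadata
```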
generate_random_params
generate_random_params(rng: Array, data_shapes: PyTree) -> PyTree
Generate random parameters for stochastic transformations.
MUST be implemented by stochastic operators (config.stochastic=True). Deterministic operators can use the default implementation (returns None).
Generates a PyTree of random parameters matching the batch structure. For example, image rotation might generate per-element rotation angles.
This method is impure (uses RNG) and is called once per batch. The generated parameters are then passed to apply() for each element via vmap.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rng | Array | JAX random key for this batch | required |
| data_shapes | PyTree | PyTree with the same structure as batch.data, containing shapes, e.g. {"image": (batch_size, H, W, C)} | required |
Returns:
| Type | Description |
|---|---|
| PyTree | PyTree of random parameters for this batch. Structure depends on operator needs. For deterministic operators, returns None. |
Examples:
```python
import jax
import jax.numpy as jnp


def generate_random_params(self, rng, data_shapes):
    # Stochastic rotation: generate one angle per element in [0, 2*pi)
    batch_size = data_shapes[self.config.field_key][0]
    return jax.random.uniform(rng, (batch_size,), minval=0.0, maxval=2 * jnp.pi)
```
Raises:
| Type | Description |
|---|---|
| NotImplementedError | If stochastic=True but not implemented |
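The split described above (parameters drawn once per batch, then consumed per element under vmap) can be sketched in plain JAX. The `"points"` field and the `rotate` function are hypothetical; in datarax this pairing is done by apply_batch rather than by hand.

```python
import jax
import jax.numpy as jnp


def generate_random_params(rng, data_shapes, field_key="points"):
    # Called once per batch: one rotation angle per element.
    batch_size = data_shapes[field_key][0]
    return jax.random.uniform(rng, (batch_size,), minval=0.0, maxval=2 * jnp.pi)


def rotate(points, angle):
    # Pure per-element transform: rotate (N, 2) points by `angle`.
    c, s = jnp.cos(angle), jnp.sin(angle)
    rot = jnp.array([[c, -s], [s, c]])
    return points @ rot.T


points = jnp.tile(jnp.array([[1.0, 0.0]]), (8, 1, 1))  # (batch=8, N=1, 2)
shapes = {"points": points.shape}
angles = generate_random_params(jax.random.key(0), shapes)
rotated = jax.vmap(rotate)(points, angles)  # element i uses angles[i]
```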
get_operation_stats
reset_operation_stats
Reset operation statistics to zero.
Note: Creates new JAX arrays to reset the counters.
compute_statistics
Compute statistics from data using batch_stats_fn.
If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | Any | Input data to compute statistics from | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] \| None | Dictionary of statistics, or None if no batch_stats_fn is configured |
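A batch_stats_fn could look like the sketch below, assuming it receives the batch data dict and returns a dict of arrays; the per-channel mean/std choice and the `normalize` helper are hypothetical. The statistics are computed once over the batch and then used element-wise, which is how the stats argument of apply() is described.

```python
import jax.numpy as jnp


def batch_stats_fn(data):
    # Hypothetical stats function: per-channel mean/std of the "image" field.
    x = data["image"]  # (B, H, W, C)
    return {"mean": x.mean(axis=(0, 1, 2)), "std": x.std(axis=(0, 1, 2))}


def normalize(image, stats, eps=1e-6):
    # Element-wise use of batch-level statistics, as inside apply().
    return (image - stats["mean"]) / (stats["std"] + eps)


batch = {"image": jnp.arange(24, dtype=jnp.float32).reshape(2, 2, 2, 3)}
stats = batch_stats_fn(batch)
normalized = normalize(batch["image"][0], stats)  # a single (H, W, C) element
```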
get_statistics
set_statistics
reset_statistics
Reset all statistics to None.
This clears both computed statistics and marks that precomputed_stats should be ignored (via an internal flag). After reset, get_statistics() will return None until new statistics are set or computed.
copy
copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule
Create a copy of this module with optional config/parameter changes.
This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | DataraxModuleConfig \| None | New config (if None, uses current config) | None |
| rngs | Rngs \| None | New RNG state (if None, uses current rngs) | None |
| name | str \| None | New name (if None, uses current name) | None |
Returns:
| Type | Description |
|---|---|
| DataraxModule | New module instance with updated parameters |
Examples:
```python
# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)

# Change name only
renamed = module.copy(name="new_name")
```
Note
Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().
get_state
Get module state for checkpointing.
This method implements the Checkpointable protocol using NNX state management. It extracts all state variables from the module and converts them to a serializable format.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A dictionary containing the internal state of the component. |
set_state
Restore module state from a checkpoint.
This method implements the Checkpointable protocol using NNX state management. It restores the module state from a serialized format. Restoration is strict: checkpoint structure must match module state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | dict[str, Any] | A dictionary containing the internal state to restore. | required |
Raises:
| Type | Description |
|---|---|
| TypeError | If state is not a dictionary. |
| ValueError | If checkpoint structure does not match module state. |
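A strict structure check like the one set_state describes can be approximated by comparing JAX tree structures. `check_checkpoint` is a hypothetical sketch, not the datarax code; it compares keys and nesting only, not leaf shapes or dtypes.

```python
import jax


def check_checkpoint(module_state, checkpoint):
    # Mirror the documented errors: TypeError for non-dicts,
    # ValueError when keys/nesting differ from the module's own state.
    if not isinstance(checkpoint, dict):
        raise TypeError("state must be a dictionary")
    if jax.tree.structure(checkpoint) != jax.tree.structure(module_state):
        raise ValueError("checkpoint structure does not match module state")
```

Because JAX flattens dicts in sorted key order, two dicts with the same keys compare equal regardless of insertion order.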
clone
clone() -> DataraxModule
Create a new instance with the same state as this module.
Uses NNX's clone function for proper deep cloning of all state.
Returns:
| Type | Description |
|---|---|
| DataraxModule | A new module instance with the same state. |
requires_rng_streams
ensure_rng_streams
Ensure that the required RNG streams are available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stream_names | list[str] | A list of available RNG stream names. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If a required RNG stream is not available. |
get_output_structure
get_output_structure(sample_data: PyTree, sample_state: PyTree) -> tuple[PyTree, PyTree]
Declare output PyTree structure for vmap axis specification.
Default uses jax.eval_shape to discover structure automatically. Override for efficiency or when eval_shape doesn't work (e.g., data-dependent shapes).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sample_data | PyTree | Single element data (not batched) | required |
| sample_state | PyTree | Single element state (not batched) | required |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree] | Tuple of (output_data_structure, output_state_structure) with None leaves. The structure (keys/nesting) matters; leaf values are ignored. |
Example override for an operator that adds keys:
```python
import jax


def get_output_structure(self, sample_data, sample_state):
    # Keep every input key (leaf values become None), then declare new keys
    out_data = {
        **jax.tree.map(lambda _: None, sample_data),
        "score": None,
        "alignment": None,
    }
    return out_data, sample_state
```
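The default behavior (discovering structure with jax.eval_shape) can be sketched as follows: `jax.eval_shape` traces the function abstractly and returns `jax.ShapeDtypeStruct` leaves without running the computation, and mapping those leaves to None keeps only the keys and nesting. `op` is a hypothetical operator body that adds a "score" key.

```python
import jax
import jax.numpy as jnp


def op(data):
    # Hypothetical operator body that adds a "score" key.
    return {**data, "score": data["image"].mean()}


sample = {"image": jnp.ones((4, 4, 3))}
shapes = jax.eval_shape(op, sample)  # ShapeDtypeStruct leaves, no real compute
structure = jax.tree.map(lambda _: None, shapes)  # keep keys, drop leaf values
```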
apply_batch
Process entire batch with vmap and optional RNG generation.
This method implements the batch processing logic for both stochastic and deterministic modes. It uses static branching on self.stochastic for JIT compilation efficiency.
The implementation delegates to _vmap_apply() for the shared computational core, then wraps the result in a Batch object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Batch | Input batch (Batch[Element] structure) | required |
| stats | dict[str, Any] \| None | Optional statistics (if None, uses get_statistics()) | None |
Returns:
| Type | Description |
|---|---|
| Batch | Transformed batch with the same structure |
Note
This method is concrete (not abstract). Subclasses typically don't override it, but can if they need custom batch processing logic.
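The static branching can be sketched in plain JAX by marking a Python bool as static under `jax.jit`, so each mode traces its own program: stochastic mode vmaps over per-element random parameters, while deterministic mode broadcasts one shared value with `in_axes=None`. `apply`, `apply_batch`, and `offset` here are hypothetical and do not reproduce datarax's internal _vmap_apply().

```python
import jax
import jax.numpy as jnp


def apply(x, param):
    # Per-element pure transform (hypothetical): add an offset.
    return x + param


def apply_batch(batch, rng=None, *, stochastic: bool, offset=0.1):
    # `stochastic` is a static Python bool, so this `if` is resolved at trace
    # time and each mode compiles to its own specialized program.
    if stochastic:
        params = jax.random.normal(rng, (batch.shape[0],))  # one per element
        return jax.vmap(apply, in_axes=(0, 0))(batch, params)
    # Deterministic: broadcast one shared parameter to every element.
    return jax.vmap(apply, in_axes=(0, None))(batch, offset)


jitted = jax.jit(apply_batch, static_argnames="stochastic")
batch = jnp.zeros((4, 3))
det = jitted(batch, stochastic=False)
sto = jitted(batch, jax.random.key(0), stochastic=True)
```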
output_spec
Return the operator's output spec given an input spec.
Most operators (normalization, additive noise, simple element-wise
transforms) do not change shape; the default returns input_spec
unchanged. Shape-changing operators (Resize, Crop, Reshape) MUST
override this method.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_spec | PyTree | Input spec PyTree | required |
Returns:
| Type | Description |
|---|---|
| PyTree | Output spec PyTree; by default, equal to input_spec |
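A shape-changing override might look like the sketch below. It assumes spec leaves are `jax.ShapeDtypeStruct` and elements use an (H, W, C) layout; the leaf type is not stated in the tables above, and `resize_output_spec` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def resize_output_spec(input_spec, field_key="image", size=(32, 32)):
    # Hypothetical Resize-style override: only H and W of one field change.
    spec = input_spec[field_key]
    h, w = size
    new_shape = spec.shape[:-3] + (h, w) + spec.shape[-1:]
    new = jax.ShapeDtypeStruct(new_shape, spec.dtype)
    # All other fields keep their input specs unchanged.
    return {**input_spec, field_key: new}


spec_in = {
    "image": jax.ShapeDtypeStruct((64, 64, 3), jnp.float32),
    "label": jax.ShapeDtypeStruct((), jnp.int32),
}
spec_out = resize_output_spec(spec_in)
```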