Metadata
Attach and retrieve metadata on data elements.
See Also
- Core Overview - All core protocols
- Config - Configuration system
- Checkpointing - Checkpoint metadata
- Monitoring - Metrics metadata
datarax.core.metadata
Metadata module with proper pytree registration for JAX compatibility.
This module provides metadata for tracking experiment state and separates static from dynamic fields to avoid unnecessary JIT recompilation.
RecordMetadata (dataclass)
RecordMetadata(index: int, record_key: Any, rng_key: Array | None = None, epoch: int = 0, global_step: int = 0, batch_idx: int | None = None, shard_id: int | None = None, source_info: dict[str, Any] | None = None)
Metadata for a record in the data pipeline.
This dataclass tracks metadata about data records as they flow through the pipeline. It's designed to be PyTree-compatible for JAX transformations.
Attributes:
| Name | Type | Description |
|---|---|---|
| `index` | `int` | Monotonically increasing index for checkpointing |
| `record_key` | `Any` | Reference to the actual record (file index, offset, etc.) |
| `rng_key` | `Array \| None` | JAX random key for stateless random operations |
| `epoch` | `int` | Current epoch number |
| `global_step` | `int` | Global step across all epochs |
| `batch_idx` | `int \| None` | Optional batch index within the current epoch |
| `shard_id` | `int \| None` | Optional shard identifier for distributed processing |
| `source_info` | `dict[str, Any] \| None` | Optional dictionary for source-specific metadata |
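The `rng_key` field is described only as a key for stateless random operations. The sketch below shows what that buys, assuming a per-record key is derived by folding the record index into a base key with `jax.random.fold_in` (an assumption: this page does not say how datarax derives `rng_key`). Because the result depends only on the inputs, the same record always gets the same augmentation, even after a restart. `record_rng_key` and `random_flip` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def record_rng_key(base_key, index):
    # Hypothetical: derive a per-record key from a base key and the record index.
    return jax.random.fold_in(base_key, index)


def random_flip(image, rng_key):
    # Stateless: the outcome depends only on (image, rng_key), not on hidden RNG state.
    flip = jax.random.bernoulli(rng_key, p=0.5)
    return jnp.where(flip, image[:, ::-1], image)


base = jax.random.key(0)
img = jnp.arange(6.0).reshape(2, 3)

# Same record index -> same key -> identical augmentation on every replay.
a = random_flip(img, record_rng_key(base, 7))
b = random_flip(img, record_rng_key(base, 7))
```

Deriving keys from a counter like `index` rather than threading a mutable RNG state is what makes replay after a checkpoint deterministic.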
Metadata (dataclass)
Metadata(index: int = 0, epoch: int = 0, global_step: int = 0, batch_idx: int | None = None, shard_id: int | None = None, rng_key: Array | None = None, _encoded_key: Array | None = None, source_info: dict[str, Any] | None = None, key: InitVar[str | None] = None)
Metadata for tracking experiment state.
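Uses custom pytree registration to keep static fields out of tracing. The numeric tracking fields are traced leaves, so changing their values (for example, advancing `global_step`) does not trigger JIT recompilation as long as their shapes and dtypes stay the same.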
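Under standard JAX pytree semantics, the traced/static split has a concrete effect on `jax.jit` caching: leaves become traced arguments, so new values with the same shape and dtype reuse the compiled function, while static auxiliary data is stored in the tree structure and is part of the cache key. The sketch below makes both effects visible with a hypothetical `MiniMetadata` class and a trace counter; it is not datarax's implementation.

```python
from dataclasses import dataclass
from typing import Any

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclass
class MiniMetadata:
    """Hypothetical stand-in: numeric fields are leaves, source_info is static."""

    epoch: Any
    global_step: Any
    source_info: tuple = ()  # hashable, because aux data feeds the jit cache key

    def tree_flatten(self):
        # Children are traced; aux data lives in the treedef and is never traced.
        return (self.epoch, self.global_step), self.source_info

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        epoch, global_step = children
        return cls(epoch, global_step, aux_data)


trace_count = 0


@jax.jit
def next_step(meta):
    global trace_count
    trace_count += 1  # Python side effect: runs only when JAX (re)traces
    return MiniMetadata(meta.epoch, meta.global_step + 1, meta.source_info)


train = (("split", "train"),)
next_step(MiniMetadata(jnp.int32(0), jnp.int32(0), train))   # traces and compiles
next_step(MiniMetadata(jnp.int32(3), jnp.int32(41), train))  # new leaf values: cache hit
next_step(MiniMetadata(jnp.int32(0), jnp.int32(0), (("split", "eval"),)))  # new static data: retrace
print(trace_count)  # 2
```

Storing `source_info` as a tuple of pairs is one way to make static data hashable; how datarax handles its dict-valued `source_info` is not shown on this page.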
Dynamic fields (traced):
- index, epoch, global_step, batch_idx, shard_id, rng_key: Numeric tracking fields
- _encoded_key: Byte-encoded key for JAX compatibility
Static fields (not traced):
- source_info: Arbitrary metadata dictionary
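`_encoded_key` is described only as a byte-encoded key for JAX compatibility. JAX arrays cannot hold Python strings, so one common way to carry a string key through a pytree is as a `uint8` array of its UTF-8 bytes. The helpers `encode_key` and `decode_key` below are hypothetical and only illustrate that idea; the actual encoding in datarax may differ.

```python
import jax.numpy as jnp
import numpy as np


def encode_key(key: str) -> jnp.ndarray:
    # UTF-8 bytes as a uint8 vector: a valid, traceable JAX leaf.
    return jnp.array(list(key.encode("utf-8")), dtype=jnp.uint8)


def decode_key(encoded) -> str:
    # Bring the bytes back to the host and decode them.
    return bytes(np.asarray(encoded).tolist()).decode("utf-8")


enc = encode_key("train/000123.jpg")
print(enc.dtype, enc.shape)  # uint8 (16,)
print(decode_key(enc))       # train/000123.jpg
```

Keys of different byte lengths produce arrays of different shapes, and under `jax.jit` a new shape means a new trace; padding to a fixed length is one way around that if it matters.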
MetadataManager
MetadataManager(*, rngs: Rngs | None = None, initial_epoch: int = 0, initial_step: int = 0, track_batches: bool = False, shard_id: int | None = None)
Bases: Module
Utility module for managing metadata state in data sources.
This is a composable utility that data sources can optionally include to track and manage metadata like epochs, global steps, and indices. It's designed to be used via composition rather than inheritance.
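Examples:
```python
from flax import nnx

from datarax.core.metadata import MetadataManager


# DataSourceModule is datarax's base class for data sources; its import is not shown on this page.
class MySource(DataSourceModule):
    def __init__(self, track_metadata=False, *, rngs: nnx.Rngs):
        super().__init__(rngs=rngs)
        if track_metadata:
            # Compose a manager instead of inheriting metadata behavior.
            self.metadata_manager = MetadataManager(rngs=rngs)

    def __next__(self):
        data = self.get_data()  # placeholder for the source's own loading logic
        if hasattr(self, "metadata_manager"):
            # Build a RecordMetadata from the current state and update the manager's state.
            metadata = self.metadata_manager.create_metadata(record_key=...)
            return data, metadata
        return data
```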
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `rngs` | `Rngs \| None` | Optional `Rngs` for generating RNG keys in metadata | `None` |
| `initial_epoch` | `int` | Starting epoch number | `0` |
| `initial_step` | `int` | Starting global step count | `0` |
| `track_batches` | `bool` | Whether to track batch indices | `False` |
| `shard_id` | `int \| None` | Optional shard identifier for distributed processing | `None` |
state (instance attribute)
state = Variable({'global_step': initial_step, 'epoch': initial_epoch, 'index': 0, 'batch_idx': 0 if track_batches else None})
MetadataManager.create_metadata
create_metadata(record_key: Any, source_info: dict[str, Any] | None = None) -> RecordMetadata
Create metadata for a record and update internal state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `record_key` | `Any` | Reference to the actual record | required |
| `source_info` | `dict[str, Any] \| None` | Optional source-specific metadata | `None` |
Returns:
| Type | Description |
|---|---|
| `RecordMetadata` | `RecordMetadata` instance with the current state |
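A framework-free sketch of the bookkeeping described above, assuming `create_metadata` copies the current `state` into a record and then advances only `index` (this page does not specify which fields are updated). `SketchMetadataManager` and `RecordMeta` are hypothetical names, the `state` keys and initial values mirror the documented `state` Variable, and RNG key generation from `rngs` is omitted.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass
class RecordMeta:
    # Subset of RecordMetadata's fields, for illustration only.
    index: int
    record_key: Any
    epoch: int
    global_step: int
    batch_idx: int | None
    shard_id: int | None
    source_info: dict[str, Any] | None = None


class SketchMetadataManager:
    """Hypothetical, framework-free stand-in for MetadataManager."""

    def __init__(self, *, initial_epoch=0, initial_step=0, track_batches=False, shard_id=None):
        # Same keys and initial values as the documented `state` Variable.
        self.state = {
            "global_step": initial_step,
            "epoch": initial_epoch,
            "index": 0,
            "batch_idx": 0 if track_batches else None,
        }
        self.shard_id = shard_id

    def create_metadata(self, record_key, source_info=None):
        s = self.state
        meta = RecordMeta(
            index=s["index"],
            record_key=record_key,
            epoch=s["epoch"],
            global_step=s["global_step"],
            batch_idx=s["batch_idx"],
            shard_id=self.shard_id,
            source_info=source_info,
        )
        s["index"] += 1  # assumption: only the record index advances per record
        return meta


mgr = SketchMetadataManager(initial_epoch=2, shard_id=1)
first = mgr.create_metadata(record_key=("file_0.tfrecord", 0))
second = mgr.create_metadata(record_key=("file_0.tfrecord", 1))
```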
create_metadata (module-level function)
create_metadata(index: int = 0, epoch: int = 0, global_step: int = 0, batch_idx: int | None = None, shard_id: int | None = None, record_key: str | None = None, source_info: dict[str, Any] | None = None, rng_key: Array | None = None, seed: int | None = None) -> Metadata
Create a `Metadata` instance with optional RNG key initialization (from `rng_key` or `seed`).
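A minimal sketch of what optional RNG initialization from the `rng_key` and `seed` parameters could look like. The helper `init_rng_key` is hypothetical, and its precedence (an explicit key wins, then a seed, else no key) is an assumption; this page does not state how the two parameters interact.

```python
import jax


def init_rng_key(rng_key=None, seed=None):
    """Hypothetical helper; assumed precedence: explicit key, then seed, else None."""
    if rng_key is not None:
        return rng_key
    if seed is not None:
        # Typed PRNG key built deterministically from an integer seed.
        return jax.random.key(seed)
    return None


key_a = init_rng_key(seed=42)
key_b = init_rng_key(seed=42)
# The same seed always yields the same key, so metadata built from it is reproducible.
same = bool((jax.random.key_data(key_a) == jax.random.key_data(key_b)).all())
```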
split_rng_tree
Split an RNG key into a dictionary of named keys.
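The sketch below shows the general technique of splitting one key into a dictionary of independent named subkeys with `jax.random.split`. `split_named` and its `(rng_key, names)` signature are hypothetical; the real `split_rng_tree` signature is not shown on this page.

```python
import jax


def split_named(rng_key, names):
    """Hypothetical sketch: one independent subkey per name."""
    subkeys = jax.random.split(rng_key, len(names))
    return {name: subkeys[i] for i, name in enumerate(names)}


keys = split_named(jax.random.key(0), ["dropout", "augment", "shuffle"])
# Each consumer draws from its own stream, e.g. keys["augment"] for augmentation.
sample = jax.random.uniform(keys["augment"], (2,))
```

Giving each consumer its own named key avoids accidentally reusing one key for unrelated random operations.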
batch_metadata
Combine multiple metadata instances into batch metadata.
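One standard way to combine per-record pytrees into a batch is to stack matching leaves along a new leading axis with `jax.tree_util.tree_map`. The sketch uses a hypothetical `RecordInfo` NamedTuple (NamedTuples are pytrees out of the box) and a hypothetical `stack_metadata` helper; whether datarax's `batch_metadata` stacks fields this way is an assumption.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class RecordInfo(NamedTuple):
    # Hypothetical per-record metadata with scalar numeric fields.
    index: jnp.ndarray
    epoch: jnp.ndarray


def stack_metadata(records):
    # Stack each field across records, adding a leading batch dimension.
    return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *records)


records = [RecordInfo(jnp.int32(i), jnp.int32(0)) for i in range(3)]
batch = stack_metadata(records)
print(batch.index)  # [0 1 2]
```

Fields that are `None` in every record are empty subtrees and pass through unchanged; static, non-array data such as `source_info` would need separate handling.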