
Metadata

Attach and retrieve metadata on data elements.

datarax.core.metadata

Metadata module with proper pytree registration for JAX compatibility.

This module provides metadata for tracking experiment state with proper separation of static and dynamic fields to avoid JIT recompilation.

logger module-attribute

logger = getLogger(__name__)

MAX_KEY_LENGTH module-attribute

MAX_KEY_LENGTH = 128

NULL_BYTE module-attribute

NULL_BYTE = 0

RecordMetadata dataclass

RecordMetadata(index: int, record_key: Any, rng_key: Array | None = None, epoch: int = 0, global_step: int = 0, batch_idx: int | None = None, shard_id: int | None = None, source_info: dict[str, Any] | None = None)

Metadata for a record in the data pipeline.

This dataclass tracks metadata about data records as they flow through the pipeline. It's designed to be PyTree-compatible for JAX transformations.

Attributes:

- index (int): Monotonically increasing index for checkpointing
- record_key (Any): Reference to the actual record (file index, offset, etc.)
- rng_key (Array | None): JAX random key for stateless random operations
- epoch (int): Current epoch number
- global_step (int): Global step across all epochs
- batch_idx (int | None): Optional batch index within the current epoch
- shard_id (int | None): Optional shard identifier for distributed processing
- source_info (dict[str, Any] | None): Optional dictionary for source-specific metadata

index instance-attribute

index: int

record_key instance-attribute

record_key: Any

rng_key class-attribute instance-attribute

rng_key: Array | None = None

epoch class-attribute instance-attribute

epoch: int = 0

global_step class-attribute instance-attribute

global_step: int = 0

batch_idx class-attribute instance-attribute

batch_idx: int | None = None

shard_id class-attribute instance-attribute

shard_id: int | None = None

source_info class-attribute instance-attribute

source_info: dict[str, Any] | None = None

split_rng

split_rng(num: int = 2) -> list[Array | None]

Split RNG key into multiple keys.
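A minimal sketch of what split_rng plausibly does, written as a free function over a legacy `jax.random.PRNGKey`. The return type `list[Array | None]` suggests a `None` placeholder when no key is stored; treating a missing key that way is our assumption, not documented behaviour, and `split_rng` here is a hypothetical stand-in rather than the datarax method.

```python
from __future__ import annotations

import jax


def split_rng(rng_key: jax.Array | None, num: int = 2) -> list[jax.Array | None]:
    """Hypothetical stand-in for RecordMetadata.split_rng / Metadata.split_rng."""
    if rng_key is None:
        # Assumption: no key stored -> one None placeholder per requested key.
        return [None] * num
    # jax.random.split returns an array holding `num` independent keys;
    # iterating over it yields one key per row.
    return list(jax.random.split(rng_key, num))


keys = split_rng(jax.random.PRNGKey(0), num=3)
```

Because splitting is stateless, the same input key always yields the same child keys, which keeps augmentation reproducible across restarts.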

Metadata dataclass

Metadata(index: int = 0, epoch: int = 0, global_step: int = 0, batch_idx: int | None = None, shard_id: int | None = None, rng_key: Array | None = None, _encoded_key: Array | None = None, source_info: dict[str, Any] | None = None, key: InitVar[str | None] = None)

Metadata for tracking experiment state.

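Uses custom pytree registration to separate traced fields from static ones. The numeric tracking fields are pytree leaves, so changing their values does not trigger JIT recompilation. Static fields are excluded from tracing; under standard JAX pytree semantics they become part of the tree structure, so changing one can still trigger a retrace.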

Dynamic fields (traced):

- index, epoch, global_step, batch_idx, shard_id, rng_key: Numeric tracking fields
- _encoded_key: Byte-encoded key for JAX compatibility

Static fields (not traced):

- source_info: Arbitrary metadata dictionary
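To see why the dynamic/static split matters under `jax.jit`, here is a minimal sketch using `jax.tree_util.register_pytree_node`. `MiniMetadata` is a hypothetical two-leaf class, not the datarax implementation; its `source_info` is stored as a tuple because JAX expects static (auxiliary) data to be hashable. The trace counter shows that new leaf values reuse the compiled function, while a new static value forces a retrace.

```python
from dataclasses import dataclass

import jax


@dataclass(frozen=True)
class MiniMetadata:
    """Hypothetical stand-in for Metadata with one static field."""

    index: int = 0
    epoch: int = 0
    source_info: tuple = ()  # static: hashable (key, value) pairs


def _flatten(m):
    # Leaves are traced by jit; the aux data becomes part of the tree structure.
    return (m.index, m.epoch), m.source_info


def _unflatten(source_info, leaves):
    index, epoch = leaves
    return MiniMetadata(index=index, epoch=epoch, source_info=source_info)


jax.tree_util.register_pytree_node(MiniMetadata, _flatten, _unflatten)

trace_count = 0


@jax.jit
def bump(m: MiniMetadata) -> MiniMetadata:
    global trace_count
    trace_count += 1  # Python side effect: runs only while jit is tracing
    return MiniMetadata(index=m.index + 1, epoch=m.epoch, source_info=m.source_info)


out = bump(MiniMetadata(index=0))  # first call: traces and compiles
out = bump(MiniMetadata(index=5))  # only leaf values differ: cached, no retrace
after_dynamic = trace_count
out2 = bump(MiniMetadata(source_info=(("file", "a.txt"),)))  # new static value: retrace
after_static = trace_count
```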

index class-attribute instance-attribute

index: int = 0

epoch class-attribute instance-attribute

epoch: int = 0

global_step class-attribute instance-attribute

global_step: int = 0

batch_idx class-attribute instance-attribute

batch_idx: int | None = None

shard_id class-attribute instance-attribute

shard_id: int | None = None

rng_key class-attribute instance-attribute

rng_key: Array | None = None

source_info class-attribute instance-attribute

source_info: dict[str, Any] | None = None

entry_key property

entry_key: str | None

Get record key as string (decodes on demand).
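The constants `MAX_KEY_LENGTH` and `NULL_BYTE` suggest that string keys are stored as fixed-length byte arrays (the `_encoded_key` field) and decoded back by `entry_key`. The sketch below shows one such scheme: UTF-8 bytes right-padded with `NULL_BYTE`. `encode_key` and `decode_key` are hypothetical helpers, and the padding and error handling are assumptions rather than datarax's code.

```python
import jax.numpy as jnp
import numpy as np

MAX_KEY_LENGTH = 128  # documented module constant
NULL_BYTE = 0  # documented module constant


def encode_key(key: str) -> jnp.ndarray:
    """Hypothetical: pack a string into a fixed-size uint8 array for JAX."""
    raw = key.encode("utf-8")
    if len(raw) > MAX_KEY_LENGTH:
        raise ValueError(f"key is {len(raw)} bytes; limit is {MAX_KEY_LENGTH}")
    if NULL_BYTE in raw:
        # The padding byte inside the key would be lost when decoding.
        raise ValueError("key must not contain the padding byte")
    buf = np.full(MAX_KEY_LENGTH, NULL_BYTE, dtype=np.uint8)
    buf[: len(raw)] = list(raw)  # copy the UTF-8 bytes; the rest stays padding
    return jnp.asarray(buf)


def decode_key(encoded: jnp.ndarray) -> str:
    """Hypothetical inverse: strip trailing padding and decode on demand."""
    raw = np.asarray(encoded, dtype=np.uint8).tobytes()
    return raw.rstrip(bytes([NULL_BYTE])).decode("utf-8")


encoded = encode_key("shard-3/record-42")
```

A fixed shape and dtype means every record's key has the same abstract value under `jit`, so keys of different lengths do not cause retracing.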

replace

replace(**kwargs) -> Metadata

Create a new Metadata instance with updated fields.

split_rng

split_rng(num: int = 2) -> list[Array | None]

Split RNG key into multiple keys.

next_rng

next_rng() -> Metadata

Get next RNG state.

increment_step

increment_step() -> Metadata

Increment global step.

increment_epoch

increment_epoch() -> Metadata

Increment epoch and reset batch index.

increment_batch

increment_batch() -> Metadata

Increment batch index.

with_shard

with_shard(shard_id: int) -> Metadata

Set shard ID.

to_dict

to_dict() -> dict[str, Any]

Convert to dictionary.

from_dict classmethod

from_dict(data: dict[str, Any]) -> Metadata

Create from dictionary.

merge

merge(other: Metadata | None) -> Metadata

Merge with another Metadata instance; values from other take precedence when they are non-zero and non-None.
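The functional-update style of `replace`, `increment_step`, `with_shard` and `merge` can be sketched with the standard library's `dataclasses.replace`. `CounterMeta` below is a hypothetical, integer-only stand-in for Metadata (no `rng_key`, no encoded key), and the merge rule is our reading of the docstring: a field from `other` wins only if it is neither `None` nor `0`.

```python
from __future__ import annotations

from dataclasses import dataclass, fields, replace


@dataclass(frozen=True)
class CounterMeta:
    """Hypothetical integer-only stand-in for Metadata (not the datarax class)."""

    index: int = 0
    epoch: int = 0
    global_step: int = 0
    batch_idx: int | None = None
    shard_id: int | None = None

    def increment_step(self) -> CounterMeta:
        # Returns a new instance; the original is left unchanged.
        return replace(self, global_step=self.global_step + 1)

    def with_shard(self, shard_id: int) -> CounterMeta:
        return replace(self, shard_id=shard_id)

    def merge(self, other: CounterMeta | None) -> CounterMeta:
        if other is None:
            return self
        # A field from `other` wins only when it is neither None nor 0.
        updates = {
            f.name: getattr(other, f.name)
            for f in fields(other)
            if getattr(other, f.name) not in (None, 0)
        }
        return replace(self, **updates)


base = CounterMeta(index=7, epoch=1, batch_idx=3)
merged = base.merge(CounterMeta(epoch=2, global_step=0, shard_id=4))
```

One consequence of this precedence rule is that `other` can never reset a field back to 0; an explicit `replace(...)` is needed for that.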

MetadataManager

MetadataManager(*, rngs: Rngs | None = None, initial_epoch: int = 0, initial_step: int = 0, track_batches: bool = False, shard_id: int | None = None)

Bases: Module

Utility module for managing metadata state in data sources.

This is a composable utility that data sources can optionally include to track and manage metadata like epochs, global steps, and indices. It's designed to be used via composition rather than inheritance.

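Examples:

    from flax import nnx

    from datarax.core.metadata import MetadataManager

    # DataSourceModule is datarax's data-source base class; its import path is
    # not shown on this page.


    class MySource(DataSourceModule):
        def __init__(self, track_metadata=False, *, rngs: nnx.Rngs):
            super().__init__(rngs=rngs)
            # Metadata tracking is opt-in and added by composition, not inheritance.
            if track_metadata:
                self.metadata_manager = MetadataManager(rngs=rngs)

        def __next__(self):
            data = self.get_data()  # assumed to be provided by the concrete source
            if hasattr(self, 'metadata_manager'):
                # Build RecordMetadata from the current state; the manager updates its counters.
                metadata = self.metadata_manager.create_metadata(record_key=...)
                return data, metadata
            return data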

Parameters:

- rngs (Rngs | None, default None): Optional Rngs for generating RNG keys in metadata
- initial_epoch (int, default 0): Starting epoch number
- initial_step (int, default 0): Starting global step count
- track_batches (bool, default False): Whether to track batch indices
- shard_id (int | None, default None): Optional shard identifier for distributed processing

rngs instance-attribute

rngs = rngs

shard_id instance-attribute

shard_id = shard_id

track_batches instance-attribute

track_batches = track_batches

state instance-attribute

state = Variable({'global_step': initial_step, 'epoch': initial_epoch, 'index': 0, 'batch_idx': 0 if track_batches else None})

create_metadata

create_metadata(record_key: Any, source_info: dict[str, Any] | None = None) -> RecordMetadata

Create metadata for a record and update internal state.

Parameters:

- record_key (Any, required): Reference to the actual record
- source_info (dict[str, Any] | None, default None): Optional source-specific metadata

Returns:

- RecordMetadata: RecordMetadata instance with current state

next_epoch

next_epoch() -> None

Advance to the next epoch.

next_batch

next_batch() -> None

Advance to the next batch.

reset

reset() -> None

Reset all counters to initial values.
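The bookkeeping that MetadataManager performs can be sketched without Flax. `CounterManager` and `RecordStamp` below are hypothetical plain-Python stand-ins that keep the same keys as the documented `state` variable, and `reset` restores the initial values as documented. The other update rules are assumptions: `create_metadata` is assumed to snapshot the current counters and then advance only `index`, `next_epoch` is assumed to restart the batch index at 0 when batches are tracked, and `next_batch` is assumed to advance only `batch_idx`.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass
class RecordStamp:
    """Hypothetical stand-in for RecordMetadata (RNG key omitted)."""

    index: int
    record_key: Any
    epoch: int
    global_step: int
    batch_idx: int | None
    shard_id: int | None
    source_info: dict[str, Any] | None = None


class CounterManager:
    """Hypothetical plain-Python stand-in for MetadataManager."""

    def __init__(self, *, initial_epoch: int = 0, initial_step: int = 0,
                 track_batches: bool = False, shard_id: int | None = None):
        self._initial_epoch = initial_epoch
        self._initial_step = initial_step
        self.track_batches = track_batches
        self.shard_id = shard_id
        self.reset()

    def reset(self) -> None:
        # Same keys as the documented `state` variable.
        self.state = {
            "global_step": self._initial_step,
            "epoch": self._initial_epoch,
            "index": 0,
            "batch_idx": 0 if self.track_batches else None,
        }

    def create_metadata(self, record_key: Any,
                        source_info: dict[str, Any] | None = None) -> RecordStamp:
        s = self.state
        stamp = RecordStamp(index=s["index"], record_key=record_key, epoch=s["epoch"],
                            global_step=s["global_step"], batch_idx=s["batch_idx"],
                            shard_id=self.shard_id, source_info=source_info)
        s["index"] += 1  # assumption: only the record index advances per record
        return stamp

    def next_epoch(self) -> None:
        self.state["epoch"] += 1
        if self.track_batches:
            self.state["batch_idx"] = 0  # assumption: batches restart each epoch

    def next_batch(self) -> None:
        if self.track_batches:
            self.state["batch_idx"] += 1
```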

create_metadata

create_metadata(index: int = 0, epoch: int = 0, global_step: int = 0, batch_idx: int | None = None, shard_id: int | None = None, record_key: str | None = None, source_info: dict[str, Any] | None = None, rng_key: Array | None = None, seed: int | None = None) -> Metadata

Create metadata with optional RNG initialization.
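One way the optional RNG initialization could work is sketched below. `init_rng_key` is a hypothetical helper, and the precedence it applies (an explicit `rng_key` wins over `seed`; passing neither yields `None`) is an assumption rather than documented behaviour.

```python
from __future__ import annotations

import jax


def init_rng_key(rng_key: jax.Array | None = None,
                 seed: int | None = None) -> jax.Array | None:
    """Hypothetical helper mirroring create_metadata's rng_key/seed arguments."""
    if rng_key is not None:
        return rng_key  # assumption: an explicit key takes precedence over a seed
    if seed is not None:
        return jax.random.PRNGKey(seed)  # deterministic key derived from the seed
    return None  # no RNG requested


key = init_rng_key(seed=42)
```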

split_rng_tree

split_rng_tree(metadata: Metadata, num: int = 2) -> dict[str, Array | None]

Split the RNG key into a dictionary of named keys.
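A sketch of a named split follows. It takes a key rather than a Metadata object to stay self-contained, and both the default `key_0`, `key_1`, ... names and the optional `names` parameter are hypothetical; this page does not say how the real function names its entries.

```python
from __future__ import annotations

import jax


def split_rng_tree(rng_key: jax.Array | None, num: int = 2,
                   names: list[str] | None = None) -> dict[str, jax.Array | None]:
    """Hypothetical stand-in for split_rng_tree returning a name -> key mapping."""
    names = names if names is not None else [f"key_{i}" for i in range(num)]
    if len(names) != num:
        raise ValueError(f"expected {num} names, got {len(names)}")
    if rng_key is None:
        return {name: None for name in names}  # assumption: mirror a missing key
    # Pair each name with one row of the split-key array.
    return dict(zip(names, jax.random.split(rng_key, num)))


streams = split_rng_tree(jax.random.PRNGKey(0), 2, names=["dropout", "augment"])
```

Named keys make it harder to accidentally reuse the same stream for two different random operations.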

batch_metadata

batch_metadata(metadata_list: list[Metadata]) -> Metadata

Combine a list of Metadata instances into a single batch-level Metadata.
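A common JAX pattern for turning a list of per-record pytrees into one batched pytree is to stack matching leaves with `jax.tree_util.tree_map`. The sketch applies it to a hypothetical `Counters` named tuple; whether `batch_metadata` combines fields this way is not stated here, so treat it as one possible design.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Counters(NamedTuple):
    """Hypothetical per-record pytree (NamedTuples are pytrees by default)."""

    index: int
    epoch: int


def stack_metadata(items: list[Counters]) -> Counters:
    if not items:
        raise ValueError("need at least one item to stack")
    # tree_map walks all items in parallel; each leaf becomes a length-N array.
    return jax.tree_util.tree_map(
        lambda *xs: jnp.stack([jnp.asarray(x) for x in xs]), *items
    )


batch = stack_metadata([Counters(0, 1), Counters(1, 1), Counters(2, 1)])
```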

update_metadata_batch

update_metadata_batch(metadata: Metadata, batch_size: int) -> Metadata

Update metadata after processing a batch.