
Typing

Type definitions and protocols for Datarax.

datarax.typing

Type definitions for Datarax.

Provides common type aliases, functional interface definitions, and checkpointing protocols used throughout the codebase.

logger module-attribute

logger = getLogger(__name__)

Element module-attribute

Element: TypeAlias = Element

Batch module-attribute

Batch: TypeAlias = Batch

T module-attribute

T = TypeVar('T')

T_co module-attribute

T_co = TypeVar('T_co', covariant=True)

E module-attribute

E = TypeVar('E', bound=Element)

B module-attribute

B = TypeVar('B', bound=Batch)
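A bound TypeVar such as E lets a generic helper accept any Element subtype and hand that same subtype back to the type checker. The sketch below is a minimal illustration that uses a hypothetical stand-in Element class and subclass, since the real Element definition is not shown on this page; the helper name first_element is also illustrative.

```python
from collections.abc import Sequence
from typing import TypeVar


class Element:
    """Hypothetical stand-in for datarax's Element type."""

    def __init__(self, data: dict) -> None:
        self.data = data


class ImageElement(Element):
    """Hypothetical Element subtype used to show that the subtype is preserved."""


# Mirrors the documented definition: E = TypeVar('E', bound=Element)
E = TypeVar("E", bound=Element)


def first_element(items: Sequence[E]) -> E:
    """Return the first item; a type checker infers the same subtype E."""
    if not items:
        raise ValueError("empty sequence")
    return items[0]


images = [ImageElement({"pixels": [0, 1]}), ImageElement({"pixels": [2, 3]})]
img = first_element(images)  # statically typed as ImageElement, not just Element
```

B works the same way for Batch subtypes. An unbounded T would accept anything, while the bound rejects arguments that are not Element subtypes at type-check time.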

DataDict module-attribute

DataDict: TypeAlias = dict[str, Array]

StateDict module-attribute

StateDict: TypeAlias = dict[str, Any]

MetadataDict module-attribute

MetadataDict: TypeAlias = dict[str, Any]

ArrayShape module-attribute

ArrayShape: TypeAlias = tuple[int, ...]

PRNGKey module-attribute

PRNGKey: TypeAlias = Array

ElementTransform module-attribute

ElementTransform: TypeAlias = Callable[[Element], Element]

BatchTransform module-attribute

BatchTransform: TypeAlias = Callable[[Batch], Batch]

ArrayTransform module-attribute

ArrayTransform: TypeAlias = Callable[[Array], Array]
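ElementTransform, BatchTransform, and ArrayTransform all have the shape Callable[[X], X], so transforms of the same kind can be chained into a single transform of that kind. A minimal sketch, using a hypothetical compose helper (not part of the documented module) and plain float lists standing in for arrays so it runs without JAX:

```python
from collections.abc import Callable
from functools import reduce
from typing import TypeVar

T = TypeVar("T")


def compose(*fns: Callable[[T], T]) -> Callable[[T], T]:
    """Chain same-type transforms left to right: compose(f, g)(x) == g(f(x))."""

    def chained(x: T) -> T:
        return reduce(lambda acc, fn: fn(acc), fns, x)

    return chained


# Plain float lists stand in for JAX arrays in this sketch.
def scale(xs: list[float]) -> list[float]:
    return [2.0 * x for x in xs]


def shift(xs: list[float]) -> list[float]:
    return [x + 1.0 for x in xs]


pipeline = compose(scale, shift)
result = pipeline([0.0, 1.5])  # scale first, then shift
```

Because the input and output types match, the composed pipeline is itself a valid transform of the same alias and can be composed again.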

DataProcessor module-attribute

DataProcessor: TypeAlias = Callable[[DataDict], DataDict]

StateProcessor module-attribute

StateProcessor: TypeAlias = Callable[[StateDict], StateDict]

MetadataProcessor module-attribute

MetadataProcessor: TypeAlias = Callable[[Metadata], Metadata]

ScanFn module-attribute

CondFn module-attribute

CondFn: TypeAlias = Callable[[Any], bool]

WhileBodyFn module-attribute

WhileBodyFn: TypeAlias = Callable[[Any], Any]

Metadata dataclass

Metadata(index: int = 0, epoch: int = 0, global_step: int = 0, batch_idx: int | None = None, shard_id: int | None = None, rng_key: Array | None = None, _encoded_key: Array | None = None, source_info: dict[str, Any] | None = None, key: InitVar[str | None] = None)

Metadata for tracking experiment state.

Uses custom pytree registration to exclude static fields from tracing, with the goal of preventing JIT recompilation when only static fields change.

Dynamic fields (traced):

- index, epoch, global_step, batch_idx, shard_id, rng_key: Numeric tracking fields
- _encoded_key: Byte-encoded key for JAX compatibility

Static fields (not traced):

- source_info: Arbitrary metadata dictionary

index class-attribute instance-attribute

index: int = 0

epoch class-attribute instance-attribute

epoch: int = 0

global_step class-attribute instance-attribute

global_step: int = 0

batch_idx class-attribute instance-attribute

batch_idx: int | None = None

shard_id class-attribute instance-attribute

shard_id: int | None = None

rng_key class-attribute instance-attribute

rng_key: Array | None = None

source_info class-attribute instance-attribute

source_info: dict[str, Any] | None = None

entry_key property

entry_key: str | None

Get record key as string (decodes on demand).

replace

replace(**kwargs) -> Metadata

Create a new Metadata instance with updated fields.

split_rng

split_rng(num: int = 2) -> list[Array | None]

Split RNG key into multiple keys.

next_rng

next_rng() -> Metadata

Return a new Metadata instance with the next RNG state.

increment_step

increment_step() -> Metadata

Increment global step.

increment_epoch

increment_epoch() -> Metadata

Increment epoch and reset batch index.

increment_batch

increment_batch() -> Metadata

Increment batch index.

with_shard

with_shard(shard_id: int) -> Metadata

Set shard ID.

to_dict

to_dict() -> dict[str, Any]

Convert to dictionary.

from_dict classmethod

from_dict(data: dict[str, Any]) -> Metadata

Create from dictionary.

merge

merge(other: Metadata | None) -> Metadata

Merge with another Metadata instance. Fields of other take precedence when their values are non-zero and non-None.

Checkpointable

Bases: Protocol

Protocol for objects that can be checkpointed via state dictionaries.

This protocol defines the interface for objects that support state-based checkpointing: state is extracted to a dictionary and later restored from one. This aligns with Flax NNX state management patterns.
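Because Checkpointable is a Protocol, any class with matching get_state and set_state methods satisfies it structurally, without inheriting from it. The sketch below redeclares the protocol locally, adding @runtime_checkable so isinstance works (the documented protocol may not be runtime-checkable), and implements it on a hypothetical RunningMean class.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class Checkpointable(Protocol):
    """Local copy of the documented protocol, made runtime-checkable for this demo."""

    def get_state(self) -> dict[str, Any]: ...

    def set_state(self, state: dict[str, Any]) -> None: ...


class RunningMean:
    """Hypothetical stateful object; note that it does not subclass Checkpointable."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, x: float) -> None:
        self.total += x
        self.count += 1

    def get_state(self) -> dict[str, Any]:
        # Everything needed to restore the object, as plain Python values.
        return {"total": self.total, "count": self.count}

    def set_state(self, state: dict[str, Any]) -> None:
        self.total = state["total"]
        self.count = state["count"]


def snapshot(obj: Checkpointable) -> dict[str, Any]:
    """Accepts any object that structurally matches the protocol."""
    return obj.get_state()


a = RunningMean()
for x in (1.0, 2.0, 6.0):
    a.update(x)
saved = snapshot(a)

b = RunningMean()
b.set_state(saved)  # b now continues from a's state
```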

get_state

get_state() -> dict[str, Any]

Get object state for checkpointing.

Returns:

- dict[str, Any]: Dictionary containing all state needed to restore the object.

set_state

set_state(state: dict[str, Any]) -> None

Restore object state from a checkpoint.

Parameters:

- state (dict[str, Any], required): Dictionary containing state to restore.

CheckpointableIterator

Bases: Checkpointable, Protocol[T_co]

Protocol for iterators that can be checkpointed.

Combines Iterator behavior with Checkpointable state management.
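A minimal sketch of an iterator with the shape CheckpointableIterator describes, assuming the protocol expects __iter__ and __next__ alongside get_state and set_state (this page does not list the iterator methods explicitly). The hypothetical ListIterator stores only its read position, so restoring that state resumes exactly where the snapshot was taken.

```python
from collections.abc import Iterator, Sequence
from typing import Any, Generic, TypeVar

T_co = TypeVar("T_co", covariant=True)


class ListIterator(Generic[T_co]):
    """Hypothetical resumable iterator over an in-memory sequence."""

    def __init__(self, items: Sequence[T_co]) -> None:
        self._items = items
        self._pos = 0

    def __iter__(self) -> Iterator[T_co]:
        return self

    def __next__(self) -> T_co:
        if self._pos >= len(self._items):
            raise StopIteration
        item = self._items[self._pos]
        self._pos += 1
        return item

    def get_state(self) -> dict[str, Any]:
        # The position is enough state because the underlying data is fixed.
        return {"position": self._pos}

    def set_state(self, state: dict[str, Any]) -> None:
        self._pos = state["position"]


it = ListIterator(["a", "b", "c", "d"])
first_two = [next(it), next(it)]
state = it.get_state()

resumed = ListIterator(["a", "b", "c", "d"])
resumed.set_state(state)
rest = list(resumed)  # picks up after the first two items
```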

get_state

get_state() -> dict[str, Any]

Get object state for checkpointing.

Returns:

- dict[str, Any]: Dictionary containing all state needed to restore the object.

set_state

set_state(state: dict[str, Any]) -> None

Restore object state from a checkpoint.

Parameters:

- state (dict[str, Any], required): Dictionary containing state to restore.