Typing
Type definitions and protocols for Datarax.
See Also
- Types & Protocols Overview - All types
- Core - Core protocols
- Config - Config types
- Checkpointing - Checkpointable protocol
datarax.typing
Type definitions for Datarax.
Provides common type aliases, functional interface definitions, and checkpointing protocols used throughout the codebase.
Metadata (dataclass)
Metadata(index: int = 0, epoch: int = 0, global_step: int = 0, batch_idx: int | None = None, shard_id: int | None = None, rng_key: Array | None = None, _encoded_key: Array | None = None, source_info: dict[str, Any] | None = None, key: InitVar[str | None] = None)
Metadata for tracking experiment state.
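Uses custom pytree registration so that only the dynamic fields below are pytree leaves; static fields are excluded from tracing. Under jax.jit, traced leaves are abstracted to their shape and dtype, so changing their values (for example, advancing index or global_step) does not trigger JIT recompilation.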
Dynamic fields (traced):
- index, epoch, global_step, batch_idx, shard_id: Numeric tracking fields
- rng_key: JAX PRNG key array
- _encoded_key: Byte-encoded key for JAX compatibility
Static fields (not traced):
- source_info: Arbitrary metadata dictionary
Checkpointable
Bases: Protocol
Protocol for objects that can be checkpointed via state dictionaries.
This protocol defines the interface for objects that support state-based checkpointing: state is extracted to a dictionary and later restored from a dictionary. This aligns with Flax NNX state management patterns.
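Because the contract is a `typing.Protocol`, any object exposing the state methods fits it structurally, with no inheritance required. A minimal sketch, where `CheckpointableSketch` and `RunningMean` are hypothetical names and `set_state` is an assumed name for the restore method (only `get_state` is listed in this reference):

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class CheckpointableSketch(Protocol):
    """Hypothetical stand-in for datarax.typing.Checkpointable."""

    def get_state(self) -> dict[str, Any]: ...

    # Assumed name for the restore side of the contract.
    def set_state(self, state: dict[str, Any]) -> None: ...


class RunningMean:
    """Toy stateful object that satisfies the protocol structurally."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, x: float) -> None:
        self.total += x
        self.count += 1

    def get_state(self) -> dict[str, Any]:
        # A plain dict of primitives is easy to serialize with other state.
        return {"total": self.total, "count": self.count}

    def set_state(self, state: dict[str, Any]) -> None:
        self.total = state["total"]
        self.count = state["count"]


a = RunningMean()
for x in (1.0, 2.0, 3.0):
    a.update(x)
snapshot = a.get_state()

b = RunningMean()
b.set_state(snapshot)  # b continues from where a left off
```

Note that `isinstance` checks against a `runtime_checkable` protocol only verify that the methods exist, not their signatures or behavior.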
get_state
CheckpointableIterator
Bases: Checkpointable, Protocol[T_co]
Protocol for iterators that can be checkpointed.
Combines Iterator behavior with Checkpointable state management.
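A hypothetical resumable iterator illustrating the combination: it implements `__iter__`/`__next__` plus `get_state` and an assumed `set_state` restore method (the restore method's name is not shown in this reference). `RangeIterator` is an illustrative name, not part of Datarax.

```python
from typing import Any


class RangeIterator:
    """Hypothetical checkpointable iterator over range(n)."""

    def __init__(self, n: int) -> None:
        self.n = n
        self.pos = 0

    def __iter__(self) -> "RangeIterator":
        return self

    def __next__(self) -> int:
        if self.pos >= self.n:
            raise StopIteration
        value = self.pos
        self.pos += 1
        return value

    def get_state(self) -> dict[str, Any]:
        # Only the cursor is saved; n is construction-time configuration.
        return {"pos": self.pos}

    def set_state(self, state: dict[str, Any]) -> None:
        self.pos = state["pos"]


it = RangeIterator(5)
first = [next(it), next(it)]  # [0, 1]
saved = it.get_state()        # {"pos": 2}

resumed = RangeIterator(5)
resumed.set_state(saved)
rest = list(resumed)          # [2, 3, 4]
```

Only the cursor goes into the state dict; the caller supplies `n` again when rebuilding the iterator, so a restored iterator yields exactly the elements the original had not yet produced.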