Checkpoint Handlers
Datarax provides checkpoint handlers for saving and restoring pipeline state using Orbax, Google's checkpointing library for JAX. The OrbaxCheckpointHandler offers a high-level interface with automatic handling of PRNG keys and string values.
Key Features
| Feature | Description |
|---|---|
| Orbax integration | Built on Orbax StandardCheckpointer |
| PRNG key handling | Automatic serialization of JAX random keys |
| String support | Converts strings to serializable format |
| Versioned checkpoints | Save multiple checkpoints with step numbers |
| Metadata support | Attach custom metadata to checkpoints |
| Context manager | Proper resource cleanup |
★ Insight
- Orbax only supports int, float, np.ndarray, and jax.Array leaves. OrbaxCheckpointHandler automatically converts PRNG keys and strings.
- Use versioned checkpoints, via the step parameter, when saving during training.
- Always close handlers, or use the context manager protocol.
Quick Start
```python
from datarax.checkpoint import OrbaxCheckpointHandler

# Create handler
handler = OrbaxCheckpointHandler()

# Save checkpoint
handler.save("/checkpoints", my_pipeline, step=1000)

# Restore checkpoint
handler.restore("/checkpoints", my_pipeline)

# Cleanup
handler.close()
```
Context Manager Usage
Recommended for proper resource cleanup:
```python
with OrbaxCheckpointHandler() as handler:
    handler.save("/checkpoints", pipeline, step=100)
    handler.save("/checkpoints", pipeline, step=200)
# Automatic cleanup on exit
```
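The with form relies on Python's context manager protocol, where __exit__ runs whether the body finishes normally or raises. Below is a minimal sketch that assumes the handler's __exit__ simply calls close(). SketchHandler is a hypothetical stand-in, not the datarax class.

```python
class SketchHandler:
    """Hypothetical stand-in for a handler that holds resources."""

    def __init__(self) -> None:
        self.closed = False

    def close(self) -> None:
        # The documented close() also waits for outstanding async saves.
        self.closed = True

    def __enter__(self) -> "SketchHandler":
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        self.close()  # runs on normal exit and when the body raises
        return False  # returning False lets any exception propagate


try:
    with SketchHandler() as handler:
        raise RuntimeError("save failed mid-run")
except RuntimeError:
    pass
print(handler.closed)  # True
```

This is why the context manager is preferred: cleanup happens even when a save fails partway through a training loop.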
Saving Checkpoints
Basic Save
```python
handler = OrbaxCheckpointHandler()

# Save without a version
handler.save("/checkpoints", target)
# Creates: /checkpoints/checkpoint/

# Save with a version (step number)
handler.save("/checkpoints", target, step=1000)
# Creates: /checkpoints/ckpt-1000/
```
With Metadata
```python
handler.save(
    "/checkpoints",
    target,
    step=1000,
    metadata={
        "epoch": 10,
        "loss": 0.05,
        "config": {"lr": 1e-3},
    },
)
```
Checkpoint Retention
The keep parameter of save_to_directory controls how many versioned checkpoints are kept in a directory. It defaults to 1.
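Overwriting
The overwrite parameter of save_to_directory (default False) controls whether an existing checkpoint may be overwritten. Pass overwrite=True to replace it.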
Restoring Checkpoints
Basic Restore
```python
# Restore the latest checkpoint
state = handler.restore("/checkpoints")

# Restore a specific step
state = handler.restore("/checkpoints", step=1000)

# Restore into an existing object
handler.restore("/checkpoints", target=my_pipeline)
```
Restore Metadata Only
```python
metadata = handler.restore(
    "/checkpoints",
    metadata_only=True,
)
print(f"Epoch: {metadata['epoch']}")
```
Working with Datarax Objects
Pipeline Checkpointing
```python
import flax.nnx as nnx
from datarax.pipeline import Pipeline

# `source` and `train_step` are defined elsewhere in your training script
pipeline = Pipeline(source=source, stages=[], batch_size=32, rngs=nnx.Rngs(0))

# Train for a while, saving every 1000 steps (including step 0)
for step, batch in enumerate(pipeline):
    loss = train_step(batch)
    if step % 1000 == 0:
        handler.save("/checkpoints", pipeline, step=step)

# Later: restore and continue
handler.restore("/checkpoints", pipeline)
```
NNX Module Checkpointing
```python
import flax.nnx as nnx

class MyModel(nnx.Module):
    ...

model = MyModel()

# Save model state
handler.save("/checkpoints", model, step=1000)

# Restore into model
handler.restore("/checkpoints", model)
```
Checkpointable Protocol
Any object implementing the Checkpointable protocol (get_state and set_state methods) can be saved:
```python
from datarax.typing import Checkpointable

class MyCheckpointable:
    def __init__(self) -> None:
        self.data = 0  # state to persist across checkpoints

    def get_state(self) -> dict:
        return {"my_data": self.data}

    def set_state(self, state: dict) -> None:
        self.data = state["my_data"]

obj: Checkpointable = MyCheckpointable()
handler.save("/checkpoints", obj)
```
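Because the protocol is structural, a class only needs matching get_state and set_state methods, and no inheritance is required. The self-contained sketch below uses typing.Protocol. CheckpointableLike and Counter are hypothetical, and the in-memory round_trip stands in for a save followed by a restore.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class CheckpointableLike(Protocol):
    """Hypothetical local mirror of the get_state/set_state protocol."""

    def get_state(self) -> dict[str, Any]: ...

    def set_state(self, state: dict[str, Any]) -> None: ...


class Counter:
    def __init__(self) -> None:
        self.count = 0

    def get_state(self) -> dict[str, Any]:
        return {"count": self.count}

    def set_state(self, state: dict[str, Any]) -> None:
        self.count = int(state["count"])


def round_trip(src: CheckpointableLike, dst: CheckpointableLike) -> None:
    # Stand-in for save + restore: copy the state dict from one object to another.
    dst.set_state(dict(src.get_state()))


a, b = Counter(), Counter()
a.count = 7
print(isinstance(a, CheckpointableLike))  # True: structural check, no inheritance
round_trip(a, b)
print(b.count)  # 7
```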
Checkpoint Management
List Checkpoints
```python
# Get all checkpoint steps
steps = handler.get_checkpoint_steps("/checkpoints")
# e.g. [100, 200, 300, 400, 500]

# Get the latest step
latest = handler.latest_step("/checkpoints")
# e.g. 500

# List all checkpoints with their paths
checkpoints = handler.list_checkpoints("/checkpoints")
# e.g. {100: '/checkpoints/ckpt-100', 200: '/checkpoints/ckpt-200', ...}
```
PRNG Key Handling
PRNG keys are serialized automatically:
```python
import jax

state = {
    "params": model_params,
    "rng_key": jax.random.key(42),  # handled automatically
}
handler.save("/checkpoints", state)

# Keys are restored as proper JAX PRNG keys
restored = handler.restore("/checkpoints")
key, subkey = jax.random.split(restored["rng_key"])
```
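Typed PRNG keys cannot be written as plain arrays. The standard JAX round trip extracts the raw key words with jax.random.key_data and re-wraps them with jax.random.wrap_key_data. Below is a sketch of that round trip. Whether datarax uses exactly these calls is an assumption, and key_to_array and array_to_key are hypothetical names.

```python
import jax
import numpy as np


def key_to_array(key: jax.Array) -> np.ndarray:
    """Typed PRNG key -> plain uint32 array that Orbax can serialize."""
    return np.asarray(jax.random.key_data(key))


def array_to_key(data: np.ndarray) -> jax.Array:
    """uint32 array -> typed PRNG key (assumes the default PRNG implementation)."""
    return jax.random.wrap_key_data(data)


key = jax.random.key(42)
raw = key_to_array(key)              # plain uint32 array
restored = array_to_key(raw)
key_a, key_b = jax.random.split(restored)
# The restored key draws exactly the same numbers as the original.
print(bool(jax.random.uniform(key) == jax.random.uniform(restored)))  # True
```

A real implementation would also need to record the key's implementation if non-default PRNGs are in use. This sketch omits that step.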
String Handling
Strings are converted to character codes for serialization:
```python
state = {
    "model_name": "my_model_v2",
    "config_json": '{"lr": 0.001}',
}
handler.save("/checkpoints", state)

restored = handler.restore("/checkpoints")
assert restored["model_name"] == "my_model_v2"
```
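Character-code conversion can be sketched with NumPy. Each character becomes one integer, producing an array that Orbax can store. The int32 dtype and the one-code-per-character scheme are assumptions, since datarax's exact encoding is not documented here. encode_string and decode_string are hypothetical.

```python
import numpy as np


def encode_string(s: str) -> np.ndarray:
    # One int32 per Unicode code point, so non-ASCII text survives.
    return np.array([ord(ch) for ch in s], dtype=np.int32)


def decode_string(codes: np.ndarray) -> str:
    # Map each stored code point back to its character.
    return "".join(chr(int(c)) for c in codes)


codes = encode_string("my_model_v2")
print(codes[:3].tolist())    # [109, 121, 95]
print(decode_string(codes))  # my_model_v2
```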
See Also
- Checkpointing Guide - Complete checkpointing tutorial
- Checkpoint Quick Reference
- DAG Executor - Pipeline checkpointing
- Orbax Documentation
API Reference
datarax.checkpoint.handlers
Checkpoint handlers for Datarax.
This module provides checkpoint handlers for Datarax components using Orbax. It follows Orbax patterns, leveraging StandardCheckpointHandler and PyTreeCheckpointer rather than reimplementing serialization logic.
References:
- orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py
- orbax/checkpoint/_src/handlers/random_key_checkpoint_handler.py
OrbaxCheckpointHandler
OrbaxCheckpointHandler(async_checkpointing: bool = False)
Checkpoint handler for Datarax components using Orbax.
This handler provides a high-level interface to Orbax checkpoint capabilities, leveraging StandardCheckpointer for PyTree serialization rather than reimplementing serialization logic.
Following Orbax patterns from:
- standard_checkpoint_handler.py for PyTree checkpointing
- random_key_checkpoint_handler.py for PRNG key handling
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| async_checkpointing | bool | If True, save() returns immediately and serialization happens in the background. Call wait_until_finished() before restore() or before the next save() if you need ordering guarantees. | False |
wait_until_finished
Block until any outstanding async save completes.
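With async_checkpointing=True, the ordering rule for save() and restore() follows a schedule-then-join pattern. save() queues the work, and wait_until_finished() blocks until all of it is done. The thread-pool sketch below illustrates that contract. AsyncSaverSketch is hypothetical, and Orbax uses its own async machinery.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor


class AsyncSaverSketch:
    """Hypothetical async saver: save() schedules work and returns at once."""

    def __init__(self) -> None:
        # One worker thread, so queued writes still run in submission order.
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._pending: list[Future] = []
        self.store: dict[int, dict] = {}

    def save(self, step: int, state: dict) -> None:
        snapshot = dict(state)  # shallow copy so later top-level edits don't leak in
        self._pending.append(self._pool.submit(self._write, step, snapshot))

    def _write(self, step: int, state: dict) -> None:
        time.sleep(0.01)  # stand-in for slow serialization and I/O
        self.store[step] = state

    def wait_until_finished(self) -> None:
        # Block on every outstanding write; .result() re-raises worker errors.
        for future in self._pending:
            future.result()
        self._pending.clear()

    def close(self) -> None:
        self.wait_until_finished()
        self._pool.shutdown()


saver = AsyncSaverSketch()
saver.save(100, {"loss": 0.5})   # returns before the write completes
saver.wait_until_finished()      # now the checkpoint is guaranteed written
print(saver.store[100])          # {'loss': 0.5}
saver.close()
```

close() calls wait_until_finished() first, matching the documented behavior of the real close().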
save_to_directory
save_to_directory(directory: str | Path, target: Any, step: int | None = None, keep: int | None = 1, overwrite: bool = False, metadata: dict[str, Any] | None = None) -> str
Save a checkpoint to a directory.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| directory | str \| Path | Directory to save to. | required |
| target | Any | Object to checkpoint (Checkpointable, dict, or PyTree). | required |
| step | int \| None | Optional step number for versioned checkpoints. | None |
| keep | int \| None | Number of checkpoints to keep (for versioned checkpoints). | 1 |
| overwrite | bool | Whether to overwrite existing checkpoints. | False |
| metadata | dict[str, Any] \| None | Optional metadata to save with the checkpoint. | None |
Returns:
| Type | Description |
|---|---|
| str | Path to the saved checkpoint. |
restore
restore(directory: str | Path, target: Any | None = None, step: int | None = None, metadata_only: bool = False, _restore_args: Any | None = None) -> Any
Restore from a checkpoint.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| directory | str \| Path | Directory to restore from. | required |
| target | Any \| None | Optional target to restore into. | None |
| step | int \| None | Optional step to restore from (None = latest). | None |
| metadata_only | bool | If True, only return metadata. | False |
| _restore_args | Any \| None | Reserved for Orbax interoperability. | None |
Returns:
| Type | Description |
|---|---|
| Any | The restored object, state, or metadata. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the directory doesn't exist or no checkpoints are found. |
get_checkpoint_steps
Get all checkpoint steps in a directory.
list_checkpoints
List all checkpoints in a directory.
close
Close the checkpoint handler and release resources.
This method waits for any outstanding async operations to finish and cleans up the underlying checkpointer. Call it when you are done checkpointing, or use the context manager protocol instead.