Checkpoint Handlers¤

Datarax provides checkpoint handlers for saving and restoring pipeline state using Orbax, Google's checkpointing library for JAX. The OrbaxCheckpointHandler offers a high-level interface with automatic handling of PRNG keys and string values.

Key Features¤

| Feature | Description |
| --- | --- |
| Orbax integration | Built on Orbax StandardCheckpointer |
| PRNG key handling | Automatic serialization of JAX random keys |
| String support | Converts strings to serializable format |
| Versioned checkpoints | Save multiple checkpoints with step numbers |
| Metadata support | Attach custom metadata to checkpoints |
| Context manager | Proper resource cleanup |

★ Insight ─────────────────────────────────────

  • Orbax's standard checkpointing only supports PyTree leaves of type int, float, np.ndarray, and jax.Array
  • OrbaxCheckpointHandler automatically converts PRNG keys and strings into supported types
  • For training runs, use versioned checkpoints by passing the step parameter
  • Always close handlers, or use the context manager protocol

─────────────────────────────────────────────────

Quick Start¤

from datarax.checkpoint import OrbaxCheckpointHandler

# Create handler
handler = OrbaxCheckpointHandler()

# Save checkpoint
handler.save("/checkpoints", my_pipeline, step=1000)

# Restore checkpoint
handler.restore("/checkpoints", my_pipeline)

# Cleanup
handler.close()

Context Manager Usage¤

Recommended for proper resource cleanup:

with OrbaxCheckpointHandler() as handler:
    handler.save("/checkpoints", pipeline, step=100)
    handler.save("/checkpoints", pipeline, step=200)
    # Automatic cleanup on exit
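
The `with` form above behaves roughly like wrapping the calls in `try`/`finally`, so `close()` runs even if a save raises. Below is a minimal sketch of that protocol. It uses a hypothetical `ToyHandler` stand-in, not the real `OrbaxCheckpointHandler`, whose internals are not shown on this page.

```python
class ToyHandler:
    """Hypothetical stand-in that only records whether close() was called."""

    def __init__(self) -> None:
        self.closed = False

    def save(self, directory: str, target: object, step: int) -> None:
        # Toy failure mode so the example can show cleanup after an error.
        if target is None:
            raise ValueError("nothing to save")

    def close(self) -> None:
        self.closed = True

    def __enter__(self) -> "ToyHandler":
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        self.close()   # always runs, even when the body raised
        return False   # do not swallow the exception


handler = ToyHandler()
try:
    with handler:
        handler.save("/checkpoints", None, step=100)  # raises ValueError
except ValueError:
    pass

print(handler.closed)  # True
```

The exception still propagates to the caller; the context manager only guarantees that cleanup happens first.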

Saving Checkpoints¤

Basic Save¤

handler = OrbaxCheckpointHandler()

# Save without version
handler.save("/checkpoints", target)
# Creates: /checkpoints/checkpoint/

# Save with version (step number)
handler.save("/checkpoints", target, step=1000)
# Creates: /checkpoints/ckpt-1000/

With Metadata¤

handler.save(
    "/checkpoints",
    target,
    step=1000,
    metadata={
        "epoch": 10,
        "loss": 0.05,
        "config": {"lr": 1e-3},
    },
)

Checkpoint Retention¤

Control how many checkpoints to keep:

handler.save(
    "/checkpoints",
    target,
    step=1000,
    keep=5,  # Keep only last 5 checkpoints
)
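
One way to picture `keep`: after each versioned save, the existing steps are sorted and everything except the newest `keep` is removed. The helper below illustrates that policy under stated assumptions. It is not Datarax's actual pruning code, `steps_to_delete` is a name invented for this sketch, and treating `keep=None` as "no limit" is an assumption.

```python
from typing import Optional


def steps_to_delete(steps: list[int], keep: Optional[int]) -> list[int]:
    """Return the steps that a keep-last-N policy would remove."""
    if keep is None:          # assumed meaning: no limit, keep everything
        return []
    if keep < 1:
        raise ValueError("keep must be at least 1")
    ordered = sorted(steps)
    return ordered[:-keep]    # everything except the newest `keep` steps


existing = [100, 200, 300, 400, 500, 600, 700]
print(steps_to_delete(existing, keep=5))  # [100, 200]
```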

Overwriting¤

handler.save(
    "/checkpoints",
    target,
    overwrite=True,  # Overwrite existing checkpoint
)

Restoring Checkpoints¤

Basic Restore¤

# Restore latest checkpoint
state = handler.restore("/checkpoints")

# Restore specific step
state = handler.restore("/checkpoints", step=1000)

# Restore into existing object
handler.restore("/checkpoints", target=my_pipeline)

Restore Metadata Only¤

metadata = handler.restore(
    "/checkpoints",
    metadata_only=True,
)
print(f"Epoch: {metadata['epoch']}")

Working with Datarax Objects¤

Pipeline Checkpointing¤

import flax.nnx as nnx
from datarax.pipeline import Pipeline

# `source`, `train_step`, and `handler` are assumed to be defined elsewhere
pipeline = Pipeline(source=source, stages=[], batch_size=32, rngs=nnx.Rngs(0))

# Train for a while...
for step, batch in enumerate(pipeline):
    loss = train_step(batch)

    if step % 1000 == 0:
        handler.save("/checkpoints", pipeline, step=step)

# Later: restore and continue
handler.restore("/checkpoints", pipeline)
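
The point of checkpointing a pipeline is that iteration resumes where it stopped instead of starting over. The toy below sketches that idea with a hypothetical `ToyPipeline` whose only state is a read position. A real Datarax pipeline presumably tracks more, such as RNG state, and this is not its implementation.

```python
from typing import Any, Iterator


class ToyPipeline:
    """Hypothetical iterator whose only state is its read position."""

    def __init__(self, items: list[int]) -> None:
        self.items = items
        self.position = 0

    def __iter__(self) -> Iterator[int]:
        return self

    def __next__(self) -> int:
        if self.position >= len(self.items):
            raise StopIteration
        item = self.items[self.position]
        self.position += 1
        return item

    def get_state(self) -> dict[str, Any]:
        return {"position": self.position}

    def set_state(self, state: dict[str, Any]) -> None:
        self.position = state["position"]


data = [10, 20, 30, 40, 50]

first_run = ToyPipeline(data)
seen = [next(first_run) for _ in range(3)]   # consume 10, 20, 30
snapshot = first_run.get_state()             # stands in for "save"

resumed = ToyPipeline(data)
resumed.set_state(snapshot)                  # stands in for "restore"
remaining = list(resumed)

print(seen)       # [10, 20, 30]
print(remaining)  # [40, 50]
```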

NNX Module Checkpointing¤

import flax.nnx as nnx

class MyModel(nnx.Module):
    ...

model = MyModel()

# Save model state
handler.save("/checkpoints", model, step=1000)

# Restore into model
handler.restore("/checkpoints", model)

Checkpointable Protocol¤

Any object implementing the Checkpointable protocol can be saved:

from datarax.typing import Checkpointable

class MyCheckpointable:
    def __init__(self, data=None):
        # Initialize the attribute that get_state() reads
        self.data = data

    def get_state(self) -> dict:
        return {"my_data": self.data}

    def set_state(self, state: dict) -> None:
        self.data = state["my_data"]

obj = MyCheckpointable(data=0)
handler.save("/checkpoints", obj)
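
The definition of `Checkpointable` in `datarax.typing` is not reproduced on this page. As a rough sketch, assuming it is a structural protocol with `get_state` and `set_state`, an equivalent could be declared with `typing.Protocol`. `CheckpointableSketch` and `Counter` are names invented for this example.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class CheckpointableSketch(Protocol):
    """Assumed shape of the protocol: two methods, no base class required."""

    def get_state(self) -> dict[str, Any]: ...

    def set_state(self, state: dict[str, Any]) -> None: ...


class Counter:
    def __init__(self, count: int = 0) -> None:
        self.count = count

    def get_state(self) -> dict[str, Any]:
        return {"count": self.count}

    def set_state(self, state: dict[str, Any]) -> None:
        self.count = state["count"]


original = Counter(count=7)
snapshot = original.get_state()      # what a handler would serialize

fresh = Counter()
fresh.set_state(snapshot)            # what a restore would call

print(isinstance(fresh, CheckpointableSketch))  # True
print(fresh.count)                              # 7
```

Because the check is structural, `Counter` never has to inherit from the protocol class. It only needs both methods.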

Checkpoint Management¤

List Checkpoints¤

# Get all checkpoint steps
steps = handler.get_checkpoint_steps("/checkpoints")
# [100, 200, 300, 400, 500]

# Get latest step
latest = handler.latest_step("/checkpoints")
# 500

# List all checkpoints with paths
checkpoints = handler.list_checkpoints("/checkpoints")
# {100: '/checkpoints/ckpt-100', 200: '/checkpoints/ckpt-200', ...}
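
Given the `ckpt-<step>` directory naming shown in the comments above, step discovery can be pictured as a directory scan. The following is a sketch under that assumption only. `scan_steps` is a hypothetical helper, and the real methods may apply additional checks.

```python
import re
import tempfile
from pathlib import Path

_STEP_DIR = re.compile(r"^ckpt-(\d+)$")


def scan_steps(directory: str) -> dict[int, str]:
    """Map step number -> path for every `ckpt-<step>` subdirectory."""
    found = {}
    for child in Path(directory).iterdir():
        match = _STEP_DIR.match(child.name)
        if child.is_dir() and match:
            found[int(match.group(1))] = str(child)
    return dict(sorted(found.items()))


with tempfile.TemporaryDirectory() as root:
    for step in (300, 100, 200):
        (Path(root) / f"ckpt-{step}").mkdir()
    (Path(root) / "notes.txt").write_text("ignored")  # not a checkpoint

    checkpoints = scan_steps(root)
    steps = list(checkpoints)
    latest = max(steps) if steps else None

print(steps)   # [100, 200, 300]
print(latest)  # 300
```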

PRNG Key Handling¤

PRNG keys are automatically serialized:

import jax

state = {
    "params": model_params,
    "rng_key": jax.random.key(42),  # Automatically handled
}

handler.save("/checkpoints", state)

# Keys are restored as proper JAX PRNG keys
restored = handler.restore("/checkpoints")
# split() returns an array of keys; unpack it into two
key, subkey = jax.random.split(restored["rng_key"])

String Handling¤

Strings are converted to character codes for serialization:

state = {
    "model_name": "my_model_v2",
    "config_json": '{"lr": 0.001}',
}

handler.save("/checkpoints", state)

restored = handler.restore("/checkpoints")
assert restored["model_name"] == "my_model_v2"
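
The exact encoding Datarax uses is not specified here. To illustrate the idea, a string can be turned into a list of Unicode code points (plain integers, which array-based checkpoint formats can store) and back again. `encode_str` and `decode_str` are hypothetical helpers.

```python
def encode_str(text: str) -> list[int]:
    """Convert a string to its Unicode code points."""
    return [ord(ch) for ch in text]


def decode_str(codes: list[int]) -> str:
    """Rebuild the string from code points."""
    return "".join(chr(code) for code in codes)


codes = encode_str("my_model_v2")
print(codes[:3])            # [109, 121, 95]
print(decode_str(codes))    # my_model_v2
```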

API Reference¤

datarax.checkpoint.handlers ¤

Checkpoint handlers for Datarax.

This module provides checkpoint handlers for Datarax components using Orbax. It follows Orbax patterns, leveraging StandardCheckpointHandler and PyTreeCheckpointer rather than reimplementing serialization logic.

References:

  • orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py
  • orbax/checkpoint/_src/handlers/random_key_checkpoint_handler.py

logger module-attribute ¤

logger = getLogger(__name__)

OrbaxCheckpointHandler ¤

OrbaxCheckpointHandler(async_checkpointing: bool = False)

Checkpoint handler for Datarax components using Orbax.

This handler provides a high-level interface to Orbax checkpoint capabilities, leveraging StandardCheckpointer for PyTree serialization rather than reimplementing serialization logic.

Following Orbax patterns from:

  • standard_checkpoint_handler.py for PyTree checkpointing
  • random_key_checkpoint_handler.py for PRNG key handling

Parameters:

  • async_checkpointing (bool, default False): If True, save() returns immediately and serialization happens in the background. Call wait_until_finished() before restore() or before the next save() if you need ordering guarantees.

async_checkpointing instance-attribute ¤

async_checkpointing = async_checkpointing

checkpointer instance-attribute ¤

checkpointer = AsyncCheckpointer(StandardCheckpointHandler())

wait_until_finished ¤

wait_until_finished() -> None

Block until any outstanding async save completes.
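
To see why waiting matters, consider a toy background saver built on `concurrent.futures`. This sketches the general pattern only and does not show how Orbax's `AsyncCheckpointer` is implemented. `ToyAsyncSaver` is a hypothetical class.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional


class ToyAsyncSaver:
    """Hypothetical saver: save() returns at once, work runs in a thread."""

    def __init__(self) -> None:
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._pending: Optional[Future] = None
        self.store: dict[int, dict] = {}

    def _write(self, step: int, state: dict) -> None:
        time.sleep(0.05)                 # simulate slow serialization
        self.store[step] = state

    def save(self, state: dict, step: int) -> None:
        self.wait_until_finished()       # keep consecutive saves ordered
        self._pending = self._pool.submit(self._write, step, dict(state))

    def wait_until_finished(self) -> None:
        if self._pending is not None:
            self._pending.result()       # blocks; re-raises worker errors
            self._pending = None

    def close(self) -> None:
        self.wait_until_finished()
        self._pool.shutdown()


saver = ToyAsyncSaver()
saver.save({"position": 42}, step=1)
saver.wait_until_finished()              # without this, step 1 may be missing
restored = saver.store[1]
saver.close()

print(restored)  # {'position': 42}
```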

save_to_directory ¤

save_to_directory(directory: str | Path, target: Any, step: int | None = None, keep: int | None = 1, overwrite: bool = False, metadata: dict[str, Any] | None = None) -> str

Save a checkpoint to a directory.

Parameters:

  • directory (str | Path, required): Directory to save to.
  • target (Any, required): Object to checkpoint (Checkpointable, dict, or PyTree).
  • step (int | None, default None): Optional step number for versioned checkpoints.
  • keep (int | None, default 1): Number of checkpoints to keep (for versioned checkpoints).
  • overwrite (bool, default False): Whether to overwrite existing checkpoints.
  • metadata (dict[str, Any] | None, default None): Optional metadata to save with the checkpoint.

Returns:

  • str: Path to the saved checkpoint.

restore ¤

restore(directory: str | Path, target: Any | None = None, step: int | None = None, metadata_only: bool = False, _restore_args: Any | None = None) -> Any

Restore from a checkpoint.

Parameters:

  • directory (str | Path, required): Directory to restore from.
  • target (Any | None, default None): Optional target to restore into.
  • step (int | None, default None): Optional step to restore from (None = latest).
  • metadata_only (bool, default False): If True, only return metadata.
  • _restore_args (Any | None, default None): Reserved for Orbax interoperability.

Returns:

  • Any: The restored object, state, or metadata.

Raises:

  • ValueError: If directory doesn't exist or no checkpoints found.

get_checkpoint_steps ¤

get_checkpoint_steps(directory: str | Path) -> list[int]

Get all checkpoint steps in a directory.

latest_step ¤

latest_step(directory: str | Path) -> int | None

Get the latest checkpoint step.

list_checkpoints ¤

list_checkpoints(directory: str | Path) -> dict[int, str]

List all checkpoints in a directory.

close ¤

close() -> None

Close the checkpoint handler and release resources.

This method waits for any outstanding async operations to finish and cleans up the underlying checkpointer. Call it when you are done with checkpointing, or use the context manager protocol instead.