Troubleshooting Guide
This guide helps you diagnose and resolve common issues with Datarax's NNX-based architecture, particularly around checkpointing, state management, and module integration.
Checkpointing Issues
State Serialization Errors
Problem: Custom objects or complex state structures fail to serialize with Orbax.
Solution: Implement proper serialization methods in your custom modules.
from datarax.core import DataraxModule
import flax.nnx as nnx
class CustomModule(DataraxModule):
    """Datarax module that keeps user-supplied data in a serializable form."""

    def __init__(self, custom_data, name="custom"):
        super().__init__(name=name)
        # Normalise first, then wrap, so only converted content reaches the NNX variable.
        normalised = self._to_serializable(custom_data)
        self.serializable_data = nnx.Variable(normalised)

    def _to_serializable(self, data):
        """Return ``data`` converted to a serializable shape.

        Objects exposing ``to_dict`` are converted through it; lists and tuples
        are converted element by element (always producing a list); anything
        else is returned unchanged.
        """
        if hasattr(data, 'to_dict'):
            return data.to_dict()
        if not isinstance(data, (list, tuple)):
            return data
        return list(map(self._to_serializable, data))

    def get_serializable_state(self):
        """Return the parent's serializable state after cleaning every entry."""
        parent_state = super().get_serializable_state()
        return self._clean_state_for_serialization(parent_state)

    def _clean_state_for_serialization(self, state):
        """Build a new dict in which unsupported values are replaced by ``str(value)``.

        Plain scalars, ``None`` and objects exposing ``tolist`` (e.g. JAX arrays)
        are kept as they are.
        """
        passthrough_types = (str, int, float, bool, type(None))
        cleaned_state = {}
        for key, value in state.items():
            keep_as_is = isinstance(value, passthrough_types) or hasattr(value, 'tolist')
            # Fallback is a plain string conversion; swap in a custom conversion if needed.
            cleaned_state[key] = value if keep_as_is else str(value)
        return cleaned_state
Custom State Structure Compatibility
Problem: NNX's replace_by_pure_dict fails with custom nested state structures.
Solution: Flatten state structure or use standard NNX Variable patterns.
class CompatibleModule(DataraxModule):
    """Module whose state is held in a flat set of NNX Variables."""

    def __init__(self, name="compatible"):
        super().__init__(name=name)
        # Use flat structure with NNX Variables
        self.position = nnx.Variable(0)
        self.buffer_size = nnx.Variable(100)
        self.internal_seed = nnx.Variable(42)

    def get_state(self):
        """Return the parent's state unchanged (no nested custom dictionaries added)."""
        return super().get_state()

    def set_state(self, state):
        """Restore ``state`` through the parent, then validate the restored values.

        Raises:
            AssertionError: If the restored values fail ``_validate_state``.
        """
        super().set_state(state)
        self._validate_state()

    def _validate_state(self):
        """Check state consistency after restoration.

        Uses explicit raises instead of ``assert`` statements so the checks are
        still executed when Python runs with ``-O`` (which strips asserts).

        Raises:
            AssertionError: If ``position`` is not an ``int`` or ``buffer_size``
                is not greater than zero.
        """
        if not isinstance(self.position[...], int):
            raise AssertionError("position must be an int after restoration")
        if not (self.buffer_size[...] > 0):
            raise AssertionError("buffer_size must be positive after restoration")
Iterator State Management
Problem: Iterator state becomes inconsistent after checkpointing.
# Problem: the iterator does not resume from the correct position.
# The iterator is created once, before the state round trip below.
iterator = pipeline.create_iterator()
# ... consume some batches ...
state = pipeline.get_state()
pipeline.set_state(state)
# `iterator` is not recreated after set_state, so it might restart from the beginning
Solution: Implement proper iterator state tracking.
from datarax.core.module import CheckpointableIteratorModule
class RobustIteratorModule(CheckpointableIteratorModule):
    """Iterator module that records its position and epoch in NNX Variables."""

    def __init__(self, data, name="robust_iterator"):
        super().__init__(name=name)
        self.data = nnx.Variable(data)
        self.position = nnx.Variable(0)
        self.epoch = nnx.Variable(0)
        # Plain Python iterator over the remaining items; built on demand.
        self._iterator = None

    def create_iterator(self):
        """Re-create the underlying iterator at the current position and return ``self``."""
        self.reset_iterator()
        return self

    def reset_iterator(self):
        """Point the underlying iterator at the items from ``position`` onwards."""
        items = self.data[...]
        start = self.position[...]
        self._iterator = iter(items[start:])

    def __next__(self):
        """Return the next item and advance ``position``.

        On exhaustion, bump ``epoch``, rewind ``position`` to 0, rebuild the
        iterator and re-raise ``StopIteration``.
        """
        if self._iterator is None:
            self.reset_iterator()
        try:
            item = next(self._iterator)
        except StopIteration:
            self.epoch[...] = self.epoch[...] + 1
            self.position[...] = 0
            self.reset_iterator()
            raise
        else:
            self.position[...] = self.position[...] + 1
            return item

    def get_state(self):
        """Return the parent's state extended with the iterator position and epoch."""
        state = super().get_state()
        state.update(
            iterator_position=self.position[...],
            iterator_epoch=self.epoch[...],
        )
        return state

    def set_state(self, state):
        """Restore the parent's state, then position/epoch when present, and rebuild the iterator."""
        super().set_state(state)
        restorable = (
            ('iterator_position', self.position),
            ('iterator_epoch', self.epoch),
        )
        for key, variable in restorable:
            if key in state:
                variable[...] = state[key]
        # Rebuild so iteration continues from the restored position.
        self.reset_iterator()
State Management Issues
Variable Access Patterns
Problem: Incorrect access to NNX Variable values causing AttributeError.
# Wrong: assigning to the attribute itself instead of writing through the Variable
# NOTE(review): the exact failure mode depends on the module's attribute handling — confirm
sampler.current_position = 10 # AttributeError
# Wrong: no accessor is used, so the Variable object itself is compared
# (fragment intentionally left without a body)
if sampler.buffer_size > 0: # Comparing Variable object, not value
Solution: Use slice notation variable[...] or variable.get_value() for Variable access and modification.
# Correct: write through the Variable with slice notation (Flax 0.12.0+)
sampler.current_position[...] = 10

# Correct: read the stored value first, then compare it
if sampler.buffer_size[...] > 0:
    process_buffer()
# Correct: Initialize Variables properly
class SamplerModule(DataraxModule):
    """Sampler whose buffer size and position are stored in NNX Variables.

    Args:
        buffer_size: Initial value stored in the ``buffer_size`` Variable.
        name: Optional module name. It is forwarded to the parent initializer
            only when given, so subclasses can call
            ``super().__init__(name=...)`` while existing calls without a name
            behave exactly as before.
    """

    def __init__(self, buffer_size=100, *, name=None):
        if name is None:
            super().__init__()
        else:
            super().__init__(name=name)
        self.buffer_size = nnx.Variable(buffer_size)  # Store value in Variable
        self.current_position = nnx.Variable(0)
Note: The `.value` attribute is deprecated as of Flax 0.12.0. Use `variable[...]` for Array variables or `variable.get_value()` for other types.
PRNG State Consistency
Problem: Random number generation becomes inconsistent after state restoration.
# Problem: the PRNG state is not properly restored.
# The sampler is built from an Rngs object seeded with 42.
rngs = nnx.Rngs(42)
sampler = ShuffleSamplerModule(rngs=rngs)
# ... use sampler ...
state = sampler.get_state()
# ... restore state ...
sampler.set_state(state)
# After the round trip the random sequence might not continue correctly
Solution: Implement proper PRNG state management.
class StatefulSamplerModule(SamplerModule):
    """Sampler that checkpoints the key and counter of its default RNG stream."""

    def __init__(self, seed=0, name="stateful_sampler"):
        super().__init__(name=name)
        self.rngs = nnx.Rngs(default=seed)
        self.original_seed = nnx.Variable(seed)

    def get_state(self):
        """Return the parent's state extended with the PRNG key, counter and seed."""
        state = super().get_state()
        stream = self.rngs.default
        state.update(
            rng_state=stream.key[...],
            rng_count=stream.count[...],
            original_seed=self.original_seed[...],
        )
        return state

    def set_state(self, state):
        """Restore the parent's state, then the PRNG key/counter and seed when present."""
        super().set_state(state)
        # Key and counter are only restored together, never one without the other.
        if 'rng_state' in state and 'rng_count' in state:
            stream = self.rngs.default
            stream.key[...] = state['rng_state']
            stream.count[...] = state['rng_count']
        if 'original_seed' in state:
            self.original_seed[...] = state['original_seed']
Module Integration Issues
Module Registration
Problem: Custom modules not recognized by checkpointing system.
Solution: Ensure proper inheritance and registration.
from datarax.core import SamplerModule
import flax.nnx as nnx
# Correct: Inherit from appropriate Datarax base class
class CustomSamplerModule(SamplerModule):
    """Sampler that returns every ``custom_param``-th element of its input."""

    def __init__(self, custom_param=10, name="custom_sampler"):
        # The parent initializer always runs first.
        super().__init__(name=name)
        # State is held in NNX Variables.
        self.custom_param = nnx.Variable(custom_param)
        self.internal_state = nnx.Variable({})

    def sample(self, data):
        """Return ``data`` sliced with a step equal to the stored ``custom_param``."""
        stride = self.custom_param[...]
        return data[::stride]
Type Handler Registration
Problem: Complex custom types need explicit Orbax type handlers.
# For very complex custom types, register handlers
import orbax.checkpoint as ocp
class CustomTypeHandler:
    """Convert custom objects to and from a plain-dict representation."""

    def serialize(self, value):
        """Return a dict holding ``value.to_dict()`` under ``'data'`` and a ``'custom'`` type tag."""
        payload = value.to_dict()
        return dict(data=payload, type='custom')

    def deserialize(self, serialized):
        """Rebuild the object by passing ``serialized['data']`` to ``CustomType.from_dict``."""
        payload = serialized['data']
        return CustomType.from_dict(payload)
# Register handler: create a registry and associate CustomType with a handler instance.
# NOTE(review): confirm `TypeHandlerRegistry()` and `.register(...)` against the installed
# Orbax release — presumably `ocp.type_handlers.register_type_handler(...)` is the
# documented entry point for global registration; verify before copying this example.
handler_registry = ocp.type_handlers.TypeHandlerRegistry()
handler_registry.register(CustomType, CustomTypeHandler())
Performance Issues
Memory Leaks in Checkpointing
Problem: Memory usage grows over time with frequent checkpointing.
# Problem: every captured state stays referenced by the list
checkpoints = []
for i in range(1000):
    state = pipeline.get_state()
    checkpoints += [state]  # Memory leak!
Solution: Implement checkpoint rotation and cleanup.
class CheckpointManager:
    """Keep a bounded, in-memory history of pipeline states.

    Attributes:
        max_checkpoints: Maximum number of checkpoints retained at any time.
        checkpoints: Retained checkpoints, oldest first. Each entry is a dict
            with ``'state'`` and ``'timestamp'`` keys.
    """

    def __init__(self, max_checkpoints=5):
        """Create an empty manager.

        Args:
            max_checkpoints: Number of most recent checkpoints to keep.

        Raises:
            ValueError: If ``max_checkpoints`` is smaller than 1; such a
                manager could never hold the checkpoint it just saved.
        """
        if max_checkpoints < 1:
            raise ValueError("max_checkpoints must be at least 1")
        self.max_checkpoints = max_checkpoints
        self.checkpoints = []

    def save_checkpoint(self, pipeline):
        """Save the pipeline's current state and evict the oldest entries.

        Args:
            pipeline: Object exposing ``get_state()``.

        Returns:
            Index of the new checkpoint in ``self.checkpoints``. The index is
            only valid until the next save, because eviction shifts entries.
        """
        # Imported here so the class is self-contained; the original snippet
        # used `time` without importing it.
        import time

        self.checkpoints.append({
            'state': pipeline.get_state(),
            'timestamp': time.time(),
        })

        # Drop all surplus entries in one slice deletion instead of repeated
        # pop(0). Removing the list's reference is enough for garbage collection.
        overflow = len(self.checkpoints) - self.max_checkpoints
        if overflow > 0:
            del self.checkpoints[:overflow]

        return len(self.checkpoints) - 1  # Return checkpoint index

    def restore_checkpoint(self, pipeline, index=-1):
        """Restore the pipeline from a stored checkpoint (default: latest).

        Args:
            pipeline: Object exposing ``set_state(state)``.
            index: Position in ``self.checkpoints``; negative values count
                from the end.

        Returns:
            The timestamp recorded when the checkpoint was saved.

        Raises:
            ValueError: If no checkpoints are available.
            IndexError: If ``index`` is out of range.
        """
        if not self.checkpoints:
            raise ValueError("No checkpoints available")
        checkpoint = self.checkpoints[index]
        pipeline.set_state(checkpoint['state'])
        return checkpoint['timestamp']
Large State Serialization
Problem: Large pipeline states cause slow checkpointing.
# Problem: Serializing large data buffers
class LargeBufferModule(DataraxModule):
    """Anti-example: the whole buffer is stored in an NNX Variable."""

    def __init__(self, buffer_size=1000000):
        super().__init__()
        zero_buffer = jnp.zeros(buffer_size)
        self.large_buffer = nnx.Variable(zero_buffer)  # Too large!
Solution: Implement efficient state management strategies.
class EfficientBufferModule(DataraxModule):
    """Module that checkpoints buffer metadata and keeps the buffer itself out of NNX state."""

    def __init__(self, buffer_size=1000000):
        super().__init__()
        # Essential state only: size, position and seed live in NNX Variables.
        self.buffer_size = nnx.Variable(buffer_size)
        self.buffer_position = nnx.Variable(0)
        self.buffer_seed = nnx.Variable(42)
        # The buffer data is a plain attribute, created later on demand.
        self._buffer = None

    def get_serializable_state(self):
        """Return the parent's serializable state without entries whose key starts with ``_buffer``."""
        full_state = super().get_serializable_state()
        return {
            key: value
            for key, value in full_state.items()
            if not key.startswith('_buffer')
        }

    def set_state(self, state):
        """Restore state through the parent, then make sure the buffer exists."""
        super().set_state(state)
        self._rebuild_buffer()

    def _rebuild_buffer(self):
        """Allocate a zero-filled buffer of ``buffer_size`` if none exists yet."""
        if self._buffer is not None:
            return
        self._buffer = jnp.zeros(self.buffer_size[...])
        # Any further state-dependent initialization would go here.
Debugging Tools
State Inspection
def inspect_module_state(module, depth=0):
    """Recursively print a module's state for debugging.

    Prints the module's class name, then one line per state entry. Scalars are
    printed by value, objects with a ``shape`` by shape and dtype, nested
    mappings and nested objects exposing ``get_state`` are expanded one level
    deeper, and everything else is printed as its type.

    Args:
        module: Object to inspect. If it has no ``get_state`` method only the
            class name is printed.
        depth: Indentation level (one space per level) for the header line.
    """
    indent = " " * depth
    print(f"{indent}{module.__class__.__name__}:")
    if not hasattr(module, 'get_state'):
        return
    _print_state_entries(module.get_state(), depth)


def _print_state_entries(state, depth):
    """Print each entry of ``state`` one level below ``depth``, expanding nested values."""
    from collections.abc import Mapping

    indent = " " * depth
    for key, value in state.items():
        if isinstance(value, (int, float, str, bool)):
            print(f"{indent} {key}: {value}")
        elif hasattr(value, 'shape'):
            print(f"{indent} {key}: shape={value.shape}, dtype={value.dtype}")
        elif isinstance(value, Mapping):
            # Nested state dict: show its entries instead of "<class 'dict'>".
            print(f"{indent} {key}:")
            _print_state_entries(value, depth + 1)
        elif hasattr(value, 'get_state'):
            # Nested module: its header sits one level below the key line.
            print(f"{indent} {key}:")
            inspect_module_state(value, depth + 2)
        else:
            print(f"{indent} {key}: {type(value)}")
# Usage: print the state summary of an existing `pipeline` object
inspect_module_state(pipeline)
Checkpoint Validation
def validate_checkpoint_integrity(original_module, restored_module):
    """Validate that checkpoint restoration was successful.

    Compares the ``get_state()`` dictionaries of both modules: the key sets
    must match, array-like values (anything with a ``shape``) must have equal
    shapes and be element-wise close, and all other values must compare equal.

    Checks use explicit raises rather than ``assert`` statements so they still
    run under ``python -O``.

    Args:
        original_module: Module whose state is the reference.
        restored_module: Module whose state is being checked.

    Raises:
        AssertionError: On the first mismatch found.
    """
    orig_state = original_module.get_state()
    rest_state = restored_module.get_state()

    # Check state keys match
    if set(orig_state.keys()) != set(rest_state.keys()):
        raise AssertionError(
            f"State keys mismatch: {orig_state.keys()} vs {rest_state.keys()}"
        )

    # Check state values
    for key, orig_val in orig_state.items():
        rest_val = rest_state[key]
        if hasattr(orig_val, 'shape'):
            # allclose broadcasts, so differing shapes could otherwise pass
            # (e.g. a zero vector against a zero scalar).
            same_shape = jnp.shape(orig_val) == jnp.shape(rest_val)
            if not same_shape or not jnp.allclose(orig_val, rest_val):
                raise AssertionError(f"Array mismatch for key: {key}")
        elif not orig_val == rest_val:
            raise AssertionError(f"Value mismatch for key: {key}")

    print("✅ Checkpoint validation passed!")
# Usage: round-trip the pipeline state, then compare two module instances
original_state = pipeline.get_state()
pipeline.set_state(original_state)
# NOTE(review): `original_pipeline` is not defined in this snippet — presumably a separate
# instance captured before restoration; confirm before copying this example.
validate_checkpoint_integrity(original_pipeline, pipeline)
Best Practices Summary
- Always use NNX Variables: Store mutable state in `nnx.Variable` objects
- Access with slice notation: Use `variable[...]` for Arrays or `variable.get_value()` for other types (`.value` is deprecated in Flax 0.12.0+)
- Implement clean serialization: Override `get_serializable_state()` for complex objects
- Validate after restoration: Include validation in your restoration workflow
- Manage memory: Implement checkpoint rotation for long-running processes
- Keep state flat: Avoid deeply nested custom state structures
- Test thoroughly: Create unit tests for your checkpointing functionality