Best Practices for Flax NNX Module Design¤
This guide outlines best practices for designing and implementing Flax NNX-based modules in Datarax. Following these practices ensures modules are consistent, maintainable, and compatible with JAX transformations.
Core Principles¤
Flax NNX is built on fundamental principles:
- Python Reference Semantics: NNX modules are regular Python objects with reference sharing and mutability
- Eager Initialization: All parameters are created at module instantiation time
- Explicit State Management: State is managed through Variables with clear ownership
- Simplified Transforms: Lifted transforms that work directly with NNX objects
Module Structure¤
Standard Module Pattern¤
from flax import nnx
import jax.numpy as jnp
class MyModule(nnx.Module):
"""Module description.
More detailed explanation of what the module does.
"""
def __init__(
self,
features_in: int,
features_out: int,
*, # Keyword-only arguments after this
rngs: nnx.Rngs,
):
"""Initialize the module.
Args:
features_in: Input feature dimension
features_out: Output feature dimension
rngs: RNG container for parameter initialization
"""
# Parameters are created eagerly (no lazy initialization)
self.kernel = nnx.Param(
nnx.initializers.lecun_normal()(rngs.params(), (features_in, features_out))
)
self.bias = nnx.Param(jnp.zeros((features_out,)))
# Store static configuration (not wrapped in Variable)
self.features_in = features_in
self.features_out = features_out
def __call__(self, x: jax.Array) -> jax.Array:
"""Process input data.
Args:
x: Input array of shape (batch, features_in)
Returns:
Output array of shape (batch, features_out)
"""
# Variables support arithmetic operators directly
return x @ self.kernel + self.bias
Key Points¤
- Keyword-only RNGs: Use
*, rngs: nnx.Rngsto enforce keyword-only syntax - Eager initialization: Parameters are created immediately in
__init__ - Direct state ownership: Modules hold their parameters as instance attributes
- Static vs dynamic: Store hyperparameters (shapes, counts) as static attributes; only use Variables for what needs to be traced by JAX
Variable Types¤
Use the appropriate NNX variable type for each state:
| Variable Type | Purpose | Example |
|---|---|---|
nnx.Param |
Trainable parameters (weights, biases) | self.kernel = nnx.Param(weights) |
nnx.BatchStat |
Batch statistics (running means) | self.running_mean = nnx.BatchStat(jnp.zeros(dim)) |
nnx.Variable |
General state variables | self.count = nnx.Variable(jnp.array(0)) |
| Custom | Domain-specific state | class Count(nnx.Variable): pass |
class StatefulModule(nnx.Module):
def __init__(self, dim: int, *, rngs: nnx.Rngs):
# Trainable parameter
self.weight = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (dim, dim)))
# Batch statistics
self.running_mean = nnx.BatchStat(jnp.zeros(dim))
# General state
self.step_count = nnx.Variable(jnp.array(0))
Accessing Variable Values¤
NNX Variables implement numeric operators, so you can use them directly in expressions:
def __call__(self, x):
# Both forms work - Variables support arithmetic operators
y = x @ self.weight # Direct use (preferred for readability)
y = x @ self.weight[...] # Explicit slice access (for Array variables)
return y
Note: The
.valueattribute is deprecated as of Flax 0.12.0. Usevariable[...]for Array variables orvariable.get_value()for other types.
Updating State¤
In-place mutation of Variable values is supported and is the standard pattern:
def __call__(self, x):
# In-place update using slice notation (recommended for Flax 0.12.0+)
self.step_count[...] = self.step_count[...] + 1
# For arrays, use JAX update patterns
self.running_mean[...] = 0.9 * self.running_mean[...] + 0.1 * x.mean()
return self.process(x)
RNG Handling¤
Constructor Pattern¤
Always use keyword-only rngs parameter:
def __init__(
self,
dim: int,
*, # Keyword-only parameters after this
rngs: nnx.Rngs,
):
# Access specific RNG streams
self.kernel = nnx.Param(
nnx.initializers.lecun_normal()(rngs.params(), (dim, dim))
)
Named RNG Streams¤
Use named streams for different purposes:
# Create RNGs with named streams
rngs = nnx.Rngs(params=0, dropout=1, augment=2)
# Or with a single default seed
rngs = nnx.Rngs(0) # Creates 'default' stream
Using RNGs in Forward Pass¤
For modules that need randomness during forward pass (like Dropout):
class StochasticModule(nnx.Module):
def __init__(self, rate: float = 0.1, *, rngs: nnx.Rngs):
self.rate = rate
self.rngs = rngs
def __call__(self, x):
# Get a key from the RNG stream
key = self.rngs.dropout()
mask = jax.random.bernoulli(key, 1 - self.rate, x.shape)
return jnp.where(mask, x / (1 - self.rate), 0)
Training and Evaluation Modes¤
Use the built-in train() and eval() methods:
model = MLP([784, 256, 10], rngs=nnx.Rngs(0))
# Training mode: Dropout enabled, BatchNorm uses batch statistics
model.train()
y_train = model(x_batch)
# Evaluation mode: Dropout disabled, BatchNorm uses running statistics
model.eval()
y_eval = model(x_batch)
Module Composition¤
Nested Modules¤
Compose modules by assigning them as attributes:
class MLP(nnx.Module):
def __init__(self, features: list[int], *, rngs: nnx.Rngs):
# Use nnx.List for module collections (required in Flax 0.12.0+)
self.layers = nnx.List([
nnx.Linear(features[i], features[i + 1], rngs=rngs)
for i in range(len(features) - 1)
])
self.dropout = nnx.Dropout(0.1, rngs=rngs)
def __call__(self, x):
for i, layer in enumerate(self.layers[:-1]):
x = nnx.relu(layer(x))
x = self.dropout(x)
return self.layers[-1](x)
Important: As of Flax 0.12.0, plain Python lists containing modules must be wrapped with
nnx.List(). This ensures proper parameter tracking and JAX transformation compatibility.
Module Collections¤
Use nnx.Dict and nnx.List for dynamic module collections:
class MultiHeadModel(nnx.Module):
def __init__(self, num_heads: int, dim: int, *, rngs: nnx.Rngs):
# Use nnx.Dict for named collections
self.heads = nnx.Dict({
f'head_{i}': nnx.Linear(dim, dim, rngs=rngs)
for i in range(num_heads)
})
# Use nnx.List for indexed collections
self.layers = nnx.List([
nnx.Linear(dim, dim, rngs=rngs)
for _ in range(3)
])
Why use nnx.Dict/List? Plain Python
dict/listwon't properly track parameters, state, or integrate with NNX transformations.
JAX Transformations¤
Using nnx.jit¤
For automatic state management:
@nnx.jit
def train_step(model, optimizer, x, y):
def loss_fn(model):
y_pred = model(x)
return jnp.mean((y_pred - y) ** 2)
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(grads)
return loss
Using the Functional API¤
For pure JAX transforms, use split/merge:
graphdef, state = nnx.split(model)
@jax.jit
def forward(graphdef, state, x):
model = nnx.merge(graphdef, state)
y = model(x)
_, state = nnx.split(model)
return y, state
y, state = forward(graphdef, state, x)
nnx.update(model, state) # Propagate state back
StateAxes for Fine-Grained Control¤
Control how different state types are transformed:
# Vectorize parameters, broadcast counts
state_axes = nnx.StateAxes({nnx.Param: 0, Count: None})
y = nnx.vmap(forward_fn, in_axes=(state_axes, 0))(model, x_batch)
Performance Best Practices¤
Vectorization¤
Use vectorized operations, not Python loops:
# Good - vectorized
def transform_batch(self, batch):
return batch * self.scale
# Avoid - loop over batch
def transform_batch(self, batch):
results = []
for item in batch:
results.append(item * self.scale)
return jnp.stack(results)
Avoid Python Control Flow on Traced Values¤
# Bad - JAX cannot trace this
def process(self, x):
if x.sum() > 0: # Depends on traced value
return x * 2
return x
# Good - Use JAX control flow
def process(self, x):
return jax.lax.cond(
x.sum() > 0,
lambda x: x * 2,
lambda x: x,
x
)
Module Introspection¤
Use functional forms for iteration (instance methods are deprecated):
# Iterate over direct children
for name, child in nnx.iter_children(model):
print(f"Child: {name} -> {type(child)}")
# Iterate over all modules in tree
for path, module in nnx.iter_modules(model):
print(f"Module at {path}: {type(module)}")
# Iterate over entire graph including variables
for path, value in nnx.iter_graph(model):
if isinstance(value, nnx.Variable):
print(f"Variable at {path}: {value[...].shape}")
Visualization¤
# Rich visualization (in Jupyter/Colab)
nnx.display(model)
# Tabular summary
print(nnx.tabulate(model)(jnp.ones((1, 784))))