
PRNG (Random Number Generation)

JAX random number generation utilities.


datarax.utils.prng

PRNG handling utilities for Datarax.

This module provides utilities for managing random number generation in Datarax, focusing on compatibility with Flax NNX and JAX's functional paradigm.

Key Utilities

create_rngs: Creates flax.nnx.Rngs objects with multiple named streams derived from a single master seed.

Usage Note

Internal Datarax code should exclusively use flax.nnx.Rngs for randomness. When interfacing with external libraries that require raw JAX PRNG keys, call the specific stream (e.g., rngs.params()) to generate a unique key.

logger module-attribute

logger = getLogger(__name__)

DEFAULT_RNG_STREAMS module-attribute

DEFAULT_RNG_STREAMS = ['augment', 'dropout', 'params', 'shuffling', 'default']

create_rngs

create_rngs(seed: int | None = None, streams: list[str] | None = None) -> Rngs

Create an Rngs object with the specified streams.

This is a convenience function that creates multiple RNG streams from a single seed. For simple cases, you can use nnx.Rngs directly:

    rngs = nnx.Rngs(42)  # Single default stream
    rngs = nnx.Rngs(params=0, dropout=1)  # Multiple streams

Parameters:

Name     Type               Description                                             Default
seed     int | None         Optional seed for PRNG. Defaults to 0.                  None
streams  list[str] | None   List of stream names. Defaults to DEFAULT_RNG_STREAMS.  None

Returns:

Type  Description
Rngs  An nnx.Rngs object with keys for each stream.