PRNG (Random Number Generation)
JAX random number generation utilities.
See Also
- Utils Overview - All utilities
- Shuffle Sampler - Random shuffling
- Probabilistic Operator - Random augmentation
- NNX Best Practices - PRNG patterns
datarax.utils.prng
PRNG handling utilities for Datarax.
This module provides utilities for managing random number generation in Datarax, focusing on compatibility with Flax NNX and JAX's functional paradigm.
Key Utilities
create_rngs: Creates flax.nnx.Rngs objects with multiple named streams derived from a single master seed.
Usage Note
Internal Datarax code should exclusively use flax.nnx.Rngs for randomness.
When interfacing with external libraries that require raw JAX PRNG keys, call the appropriate stream (e.g., rngs.params()) to generate a unique key.
DEFAULT_RNG_STREAMS (module attribute)
Default list of stream names that create_rngs uses when streams is None.
create_rngs
Create an Rngs object with the specified streams.
This is a convenience function that creates multiple RNG streams from a single seed. For simple cases, you can use nnx.Rngs directly:
```python
from flax import nnx

rngs = nnx.Rngs(42)                   # Single default stream
rngs = nnx.Rngs(params=0, dropout=1)  # Multiple named streams
```
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `seed` | `int \| None` | Optional seed for the PRNG. If `None`, a seed of 0 is used. | `None` |
| `streams` | `list[str] \| None` | List of stream names. If `None`, `DEFAULT_RNG_STREAMS` is used. | `None` |
Returns:

| Type | Description |
|---|---|
| `Rngs` | An `nnx.Rngs` object with keys for each stream. |