get_key
- brainstate.random.get_key()
Get the current global random key.
This function returns the current random key used by the global random state. The returned key represents the internal state of the JAX PRNG and can be used to restore the random state later or to create independent random number generators.
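The snapshot-and-restore pattern described above can be illustrated with plain JAX. The sketch below is a hypothetical stand-in (the class `GlobalKeyState` and its methods are illustrative names, not brainstate's implementation). It assumes the global state holds a single key that is advanced by splitting on every draw.

```python
import jax
import jax.numpy as jnp


class GlobalKeyState:
    """Hypothetical global random state; NOT brainstate's actual implementation."""

    def __init__(self, seed: int = 0):
        self._key = jax.random.PRNGKey(seed)

    def get_key(self):
        # JAX arrays are immutable, so returning the key gives a safe snapshot.
        return self._key

    def set_key(self, key):
        # Replace the stored key, e.g. to restore an earlier snapshot.
        self._key = key

    def rand(self, n: int):
        # Advance the stored key and draw from a fresh subkey.
        self._key, subkey = jax.random.split(self._key)
        return jax.random.uniform(subkey, (n,))


state = GlobalKeyState(seed=42)
saved = state.get_key()
values1 = state.rand(3)
state.set_key(saved)          # restore the snapshot
values2 = state.rand(3)
print(jnp.array_equal(values1, values2))  # True
```

Because the draw depends only on the stored key, restoring the snapshot replays exactly the same sequence of numbers.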
- Returns:
The current JAX PRNG key as a NumPy array. With JAX's default threefry2x32 PRNG implementation, this is a 2-element uint32 array representing the internal state of the random number generator.
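For reference, a raw (legacy-style) JAX key has the shape and dtype described above. This short check uses only `jax.random.PRNGKey` and assumes JAX's default threefry2x32 implementation is active; it is independent of brainstate.

```python
import jax
import jax.numpy as jnp
import numpy as np

key = jax.random.PRNGKey(42)   # raw key under the default threefry2x32 implementation
print(key.shape)               # (2,)
print(key.dtype)               # uint32

# Convert to a plain NumPy array, e.g. for saving alongside a checkpoint.
key_np = np.asarray(key)
```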
Example
Get and store the current random state:
>>> import brainstate
>>> brainstate.random.seed(42)
>>> current_key = brainstate.random.get_key()
>>> print(current_key.shape)
(2,)
Use the key to restore state later:
>>> # Generate some random numbers
>>> values1 = brainstate.random.rand(3)
>>> # Restore the previous state
>>> brainstate.random.set_key(current_key)
>>> values2 = brainstate.random.rand(3)
>>> # values1 and values2 will be identical
Compare keys for debugging:
>>> import jax.numpy as jnp
>>> brainstate.random.seed(123)
>>> key1 = brainstate.random.get_key()
>>> brainstate.random.seed(123)
>>> key2 = brainstate.random.get_key()
>>> assert jnp.array_equal(key1, key2)  # Same seed gives same key
Note
The returned key is a snapshot of the current state. Subsequent calls to random functions will advance the internal state, so calling get_key() again will return a different key unless the state is reset.
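The snapshot behavior can be seen with a minimal pure-JAX sketch. Here `_KEY`, `get_key`, and `rand` are hypothetical module-level stand-ins for a global random state (not brainstate code), assuming each draw advances the stored key by splitting it.

```python
import jax
import jax.numpy as jnp

# Hypothetical module-level key mimicking a global random state.
_KEY = jax.random.PRNGKey(0)


def get_key():
    # Return the current key; later draws do not mutate this array.
    return _KEY


def rand(n):
    # Advance the global key, then sample with the fresh subkey.
    global _KEY
    _KEY, subkey = jax.random.split(_KEY)
    return jax.random.uniform(subkey, (n,))


snapshot = get_key()
rand(3)                                   # advances the global key
after = get_key()
print(jnp.array_equal(snapshot, after))   # False
```

The earlier snapshot still holds the original key, which is what makes it usable for restoring state later.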
See also
- set_key(): Set a new random key
- seed(): Set the random seed (also affects NumPy)
- split_key(): Create new keys from the current state
- seed_context(): Temporarily change the seed with automatic restoration
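Creating independent generators from one key is typically done by key splitting. The sketch below shows the underlying JAX primitive, `jax.random.split`; it assumes, without confirmation from this page, that `split_key()` relies on the same mechanism, and it does not call brainstate.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)

# Derive three child keys; each can drive an independent random stream.
keys = jax.random.split(key, 3)
print(keys.shape)  # (3, 2)

# Draw from each child key separately.
samples = [jax.random.uniform(k, (2,)) for k in keys]
```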