get_key

brainstate.random.get_key()

Get the current global random key.

This function returns the random key currently held by the global random state. The returned key captures the internal state of the JAX PRNG, so it can be passed to brainstate.random.set_key() to restore that state later, or used to create independent random number generators.
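Because a key is plain data, a snapshot can seed further, independent streams. The sketch below uses plain JAX only, assuming the snapshot is a raw uint32 key like the one get_key() returns. It derives streams with jax.random.split and is not meant to show how brainstate itself builds generators from a key.

```python
import jax
import jax.numpy as jnp

# Stand-in for a key returned by brainstate.random.get_key(); built directly
# with JAX here so the sketch runs on its own.
snapshot = jax.random.PRNGKey(42)

# jax.random.split derives two new, statistically independent keys.
stream_a, stream_b = jax.random.split(snapshot)

# Each derived key drives its own sequence of draws.
a = jax.random.uniform(stream_a, (3,))
b = jax.random.uniform(stream_b, (3,))

# Splitting the same snapshot again reproduces the same derived keys,
# so the draws are reproducible from the snapshot alone.
stream_a_again, _ = jax.random.split(snapshot)
a_again = jax.random.uniform(stream_a_again, (3,))
```

The snapshot itself is never consumed: JAX keys are immutable values, so any number of streams can be derived from it.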

Returns:

The current JAX PRNG key as a numpy array. This is typically a 2-element uint32 array (the raw key data used by JAX's default threefry2x32 generator) representing the internal state of the random number generator.
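For reference, a raw key created directly with JAX has the same layout. This sketch assumes JAX's default PRNG configuration (threefry2x32, legacy raw keys from jax.random.PRNGKey); other PRNG implementations or typed keys from jax.random.key have a different shape.

```python
import jax
import jax.numpy as jnp

# A legacy raw PRNG key for JAX's default threefry2x32 implementation.
key = jax.random.PRNGKey(42)
print(key.shape, key.dtype)  # (2,) uint32

# The key is a pure function of the seed: the same seed gives the same key.
same = jnp.array_equal(jax.random.PRNGKey(123), jax.random.PRNGKey(123))
```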

Example

Get and store the current random state:

>>> import brainstate
>>> brainstate.random.seed(42)
>>> current_key = brainstate.random.get_key()
>>> print(current_key.shape)
(2,)

Use the key to restore state later:

>>> # Generate some random numbers
>>> values1 = brainstate.random.rand(3)
>>> # Restore the previous state
>>> brainstate.random.set_key(current_key)
>>> values2 = brainstate.random.rand(3)
>>> # values1 and values2 are identical, because the key was restored before the second draw

Compare keys for debugging:

>>> import jax
>>> brainstate.random.seed(123)
>>> key1 = brainstate.random.get_key()
>>> brainstate.random.seed(123)
>>> key2 = brainstate.random.get_key()
>>> assert jax.numpy.array_equal(key1, key2)  # Same seed gives same key

Note

The returned key is a snapshot of the current state. Subsequent calls to random functions advance the internal state, so calling get_key() again returns a different key unless the state is reset, for example with seed() or set_key().
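The snapshot behaviour described in the note can be modelled in a few lines of plain JAX. The class ToyGlobalKey and its methods are hypothetical, named only to mirror get_key(), set_key() and rand(); brainstate's actual implementation may split keys differently, so the numbers produced here will not match brainstate's output.

```python
import jax
import jax.numpy as jnp


class ToyGlobalKey:
    """Hypothetical stand-in for a global random state (not brainstate's code)."""

    def __init__(self, seed):
        self.key = jax.random.PRNGKey(seed)

    def get_key(self):
        # JAX arrays are immutable, so returning the array is already a snapshot.
        return self.key

    def set_key(self, key):
        self.key = key

    def rand(self, n):
        # Advance the state: keep one half of the split, consume the other.
        self.key, subkey = jax.random.split(self.key)
        return jax.random.uniform(subkey, (n,))


state = ToyGlobalKey(42)
before = state.get_key()
values1 = state.rand(3)
after = state.get_key()   # differs from `before`: the draw advanced the state
state.set_key(before)     # rewind to the snapshot
values2 = state.rand(3)   # replays the same draw as values1
```

Rewinding and drawing again leaves the state exactly where the first draw left it, which is why a restored key reproduces the same sequence.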

See also