split_key

brainstate.random.split_key(n=None, backup=False)

Create new random key(s) from the current seed.

This function generates one or more random keys by splitting the current global random state. It follows JAX's explicit PRNG-key model, in which each key obtained by splitting produces a statistically independent random sequence.

Parameters:
  • n (int or None, optional) – The number of keys to generate. If None (the default), a single key is returned; if an integer, an array of n keys is returned.

  • backup (bool, optional) – Whether to back up the current key before splitting (default False). The backed-up key can later be restored with restore_key().
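
The splitting and backup behaviour described above can be sketched directly with jax.random. This is a minimal illustration under stated assumptions, not brainstate's implementation: the GlobalKey class and its split_key and restore_key methods are hypothetical stand-ins for the library's global state, and the choice of which split becomes the new state is an assumed convention.

```python
import jax
import jax.numpy as jnp


class GlobalKey:
    """Hypothetical stand-in for a global random state (not brainstate's code)."""

    def __init__(self, seed):
        # Legacy uint32 key with shape (2,), matching the shapes in the examples.
        self.key = jax.random.PRNGKey(seed)
        self._backup = None

    def split_key(self, n=None, backup=False):
        if backup:
            # Remember the state before it is advanced.
            self._backup = self.key
        num = 1 if n is None else n
        # Split into num + 1 keys: the first replaces the stored state,
        # the rest are handed to the caller (assumed convention).
        keys = jax.random.split(self.key, num + 1)
        self.key = keys[0]
        return keys[1] if n is None else keys[1:]

    def restore_key(self):
        if self._backup is None:
            raise ValueError("no key has been backed up")
        self.key = self._backup


state = GlobalKey(42)
single = state.split_key()         # shape (2,)
batch = state.split_key(4)         # shape (4, 2)
saved = state.key
state.split_key(2, backup=True)    # advances the stored state
state.restore_key()                # returns to the saved state
```

Replacing the stored key on every call is what makes successive calls return different keys; reseeding or restoring the key is the only way to repeat them.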

Returns:

  • If n is None: a single JAX PRNG key.

  • If n is an integer: an array of n independent JAX PRNG keys, stacked along the leading axis.

Example

Generate a single key:

>>> import brainstate
>>> brainstate.random.seed(42)
>>> key = brainstate.random.split_key()
>>> print(key.shape)
(2,)

Generate multiple keys for parallel computation:

>>> keys = brainstate.random.split_key(4)
>>> print(keys.shape)
(4, 2)
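
One common way to consume such a batch of keys is to map a per-key sampling function over its leading axis with jax.vmap. In this sketch the keys come from jax.random.split, used as an assumed stand-in for brainstate.random.split_key(4) so that the snippet runs with JAX alone; sample is a hypothetical helper.

```python
import jax

# Stand-in for brainstate.random.split_key(4): four keys, shape (4, 2).
keys = jax.random.split(jax.random.PRNGKey(42), 4)


def sample(key):
    # Draw three standard-normal values from a single key.
    return jax.random.normal(key, (3,))


# vmap applies sample to each key along axis 0, giving one row per key.
samples = jax.vmap(sample)(keys)
print(samples.shape)  # (4, 3)
```

Each row is driven by a different key, so the rows are independent draws rather than copies of one another.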

Use with backup for state restoration:

>>> import numpy as np
>>> original_key = brainstate.random.get_key()
>>> keys = brainstate.random.split_key(2, backup=True)
>>> brainstate.random.restore_key()
>>> assert np.array_equal(brainstate.random.get_key(), original_key)

Note

This function advances the global random state, so each call returns different keys unless the state is reset (for example with brainstate.random.seed()).

See also