split_keys

class brainstate.random.split_keys(n, backup=False)

Create multiple independent random keys from the current seed.

This is a convenience function that generates exactly n independent random keys by splitting the current global random state. It is commonly used internally by parallel computation functions like pmap and vmap so that each parallel instance (a mapped batch element or a device) receives its own random key.

Parameters:
  • n (int) – The number of independent keys to generate. Must be a positive integer.

  • backup (bool) – Whether to back up the current key before splitting. If True, the original key can be restored using restore_key().

Returns:

An array of n independent JAX PRNG keys with shape (n, 2).

Raises:

ValueError – If n is not a positive integer.

Example

Generate keys for parallel computation:

>>> import brainstate
>>> brainstate.random.seed(42)
>>> keys = brainstate.random.split_keys(4)
>>> print(keys.shape)
(4, 2)

Use with vmap for parallel random number generation:

>>> import jax
>>> keys = brainstate.random.split_keys(8)
>>> @jax.vmap
... def generate_random(key):
...     return jax.random.normal(key, (10,))
>>> parallel_randoms = generate_random(keys)
>>> print(parallel_randoms.shape)
(8, 10)
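
For intuition, the splitting step can be sketched in plain JAX. This is not brainstate's implementation: the helper `split_n` and the convention of keeping the first subkey as the new carried key are assumptions made purely for illustration. Only `jax.random.PRNGKey`, `jax.random.split`, `jax.random.normal`, and `jax.vmap` are real library calls.

```python
import jax


def split_n(key, n):
    """Hypothetical helper: derive a new carry key and n subkeys from `key`."""
    if not isinstance(n, int) or n < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")
    # Split into n + 1 keys: one to carry forward, n to hand out.
    all_keys = jax.random.split(key, n + 1)
    return all_keys[0], all_keys[1:]


key = jax.random.PRNGKey(42)
key, subkeys = split_n(key, 4)

# Each mapped call receives its own subkey, so the rows are drawn independently.
samples = jax.vmap(lambda k: jax.random.normal(k, (10,)))(subkeys)
```

Carrying a fresh key forward after each split means that a later call does not hand out the same subkeys again.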

Use with backup for state preservation:

>>> original_state = brainstate.random.get_key()
>>> keys = brainstate.random.split_keys(3, backup=True)
>>> # ... use keys for computation ...
>>> brainstate.random.restore_key()  # Restore original state
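
The effect of backing up and restoring can be illustrated with a small stand-in written in plain JAX. The `KeyHolder` class below is hypothetical and only mimics the behaviour described above (a single backup slot, a key that advances on every split); it makes no claim about how brainstate stores its state.

```python
import jax
import numpy as np


class KeyHolder:
    """Hypothetical stand-in for a global random state with one backup slot."""

    def __init__(self, seed):
        self.key = jax.random.PRNGKey(seed)
        self._backup = None

    def split_keys(self, n, backup=False):
        if backup:
            # Remember the key as it was before this split.
            self._backup = self.key
        keys = jax.random.split(self.key, n + 1)
        self.key = keys[0]  # advance the carried key
        return keys[1:]

    def restore_key(self):
        if self._backup is None:
            raise ValueError("no backup key to restore")
        self.key = self._backup
        self._backup = None


state = KeyHolder(0)
before = np.asarray(state.key)

first = state.split_keys(3, backup=True)
advanced = not np.array_equal(before, np.asarray(state.key))

state.restore_key()
restored = np.array_equal(before, np.asarray(state.key))

# After restoring, the same split yields the same subkeys again.
second = state.split_keys(3)
repeatable = np.array_equal(np.asarray(first), np.asarray(second))
```

Because the PRNG is deterministic, restoring the key makes later draws repeat the earlier sequence. That is useful for reproducing a computation, but it also means the two runs are not statistically independent.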

Note

This function is equivalent to calling split_key() with n as an argument. It’s provided as a convenience function with a more explicit name for clarity.

See also