split_keys
- brainstate.random.split_keys(n, backup=False)
Create multiple independent random keys from the current seed.
This is a convenience function that generates exactly n independent random keys by splitting the current global random state. It’s commonly used internally by parallel computation functions like pmap and vmap so that each mapped instance (each device under pmap, each batch element under vmap) receives its own unique random key.
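The splitting step can be pictured with plain JAX. The sketch below assumes the global state is a single legacy PRNG key held in a module-level variable, and that one extra key is split off to become the new global state so the returned keys are never reused. `_global_key` and `split_keys_sketch` are hypothetical names, and brainstate's actual implementation may split differently.

```python
import jax

# Hypothetical stand-in for brainstate's global random state (an assumption,
# not brainstate's internals): a single legacy uint32 PRNG key.
_global_key = jax.random.PRNGKey(42)


def split_keys_sketch(n):
    """Return n independent keys and advance the global key (sketch only)."""
    global _global_key
    # bool is a subclass of int, so reject it explicitly.
    if not isinstance(n, int) or isinstance(n, bool) or n <= 0:
        raise ValueError(f"n must be a positive integer, got {n!r}")
    # Split into n + 1 keys: the first becomes the new global state and the
    # remaining n are handed out, so no returned key is reused as state.
    keys = jax.random.split(_global_key, n + 1)
    _global_key = keys[0]
    return keys[1:]


keys = split_keys_sketch(4)
print(keys.shape)  # (4, 2)
```

Because the global key advances on every call, a second call to `split_keys_sketch(4)` returns a different set of keys.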
- Parameters:
n (int) – The number of independent keys to generate. Must be a positive integer.
backup (bool) – Whether to back up the current key before splitting. If True, the original key can be restored using restore_key().
- Returns:
An array of n independent JAX PRNG keys with shape (n, 2).
- Raises:
ValueError – If n is not a positive integer.
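The (n, 2) shape matches JAX's legacy raw keys, which under the default threefry implementation are uint32 arrays of shape (2,). JAX's newer typed keys, created with `jax.random.key`, are scalars with a special key dtype, so splitting them yields shape (n,). The snippet below only illustrates this JAX behaviour and assumes the default PRNG implementation; it does not show which representation brainstate uses internally.

```python
import jax

# Legacy raw key: a uint32 array of shape (2,) under the default threefry PRNG.
raw_key = jax.random.PRNGKey(0)
raw_keys = jax.random.split(raw_key, 4)
print(raw_key.shape, raw_keys.shape, raw_keys.dtype)  # (2,) (4, 2) uint32

# Typed key: a scalar with a key dtype; splitting gives shape (n,).
typed_key = jax.random.key(0)
typed_keys = jax.random.split(typed_key, 4)
print(typed_key.shape, typed_keys.shape)  # () (4,)

# key_data exposes the underlying uint32 data of typed keys.
print(jax.random.key_data(typed_keys).shape)  # (4, 2)
```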
Example
Generate keys for parallel computation:
>>> import brainstate
>>> brainstate.random.seed(42)
>>> keys = brainstate.random.split_keys(4)
>>> print(keys.shape)
(4, 2)
Use with vmap for parallel random number generation:
>>> import jax
>>> keys = brainstate.random.split_keys(8)
>>> @jax.vmap
... def generate_random(key):
...     return jax.random.normal(key, (10,))
>>> parallel_randoms = generate_random(keys)
>>> print(parallel_randoms.shape)
(8, 10)
Use with backup for state preservation:
>>> original_state = brainstate.random.get_key()
>>> keys = brainstate.random.split_keys(3, backup=True)
>>> # ... use keys for computation ...
>>> brainstate.random.restore_key()  # Restore original state
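The backup/restore pattern can be mimicked in plain JAX. In the sketch below, the hypothetical `KeyState` class saves one copy of its key when `backup=True` and puts it back in `restore_key`; this is an assumption about the semantics described above, not brainstate's code, and raising an error when nothing was backed up is likewise an assumed choice. Because JAX arrays are immutable, keeping a reference to the old key is enough to snapshot it.

```python
import jax


class KeyState:
    """Hypothetical global-key holder illustrating backup/restore (sketch)."""

    def __init__(self, seed):
        self.key = jax.random.PRNGKey(seed)
        self._backup = None

    def split_keys(self, n, backup=False):
        if backup:
            # JAX arrays are immutable, so storing the reference is a safe snapshot.
            self._backup = self.key
        # Advance the state with one extra key and hand out the other n.
        keys = jax.random.split(self.key, n + 1)
        self.key = keys[0]
        return keys[1:]

    def restore_key(self):
        # Assumed behaviour: restoring without a prior backup is an error.
        if self._backup is None:
            raise ValueError("no backed-up key to restore")
        self.key, self._backup = self._backup, None


state = KeyState(42)
before = state.key
keys = state.split_keys(3, backup=True)
state.restore_key()
print(bool((state.key == before).all()))  # True
```

After `restore_key()`, the next split starts from the same state as before the backed-up call, so the same keys would be produced again.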
Note
This function is equivalent to calling split_key() with n as an argument. It’s provided as a convenience function with a more explicit name for clarity.
See also
split_key(): More general key splitting function
self_assign_multi_keys(): Assign multiple keys to global state
seed_context(): Temporary seed changes