split_key
- class brainstate.random.split_key(n=None, backup=False)
Create new random key(s) from the current seed.
This function generates one or more independent random keys by splitting the current global random state. It follows JAX’s random paradigm, ensuring that each split key produces statistically independent random sequences.
- Parameters:
n (int) – The number of keys to generate. If None, returns a single key. If an integer, returns an array of n keys.
backup (bool) – Whether to back up the current key before splitting. This allows the original state to be restored with restore_key().
- Returns:
If n is None: a single JAX PRNG key. If n is an integer: an array of n independent JAX PRNG keys.
Example
Generate a single key:
>>> import brainstate
>>> brainstate.random.seed(42)
>>> key = brainstate.random.split_key()
>>> print(key.shape)
(2,)
Generate multiple keys for parallel computation:
>>> keys = brainstate.random.split_key(4)
>>> print(keys.shape)
(4, 2)
Use with backup for state restoration:
>>> import numpy as np
>>> original_key = brainstate.random.get_key()
>>> keys = brainstate.random.split_key(2, backup=True)
>>> brainstate.random.restore_key()
>>> assert np.array_equal(brainstate.random.get_key(), original_key)
Note
This function advances the global random state. Each call produces different keys unless the state is reset.
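To make the note concrete, here is a minimal sketch in plain JAX of how a stateful key holder can advance on every split. The GlobalKey class and its attributes are hypothetical and written only for illustration; this is not brainstate's implementation, and only jax.random.PRNGKey and jax.random.split are real library calls.

```python
import jax
import numpy as np


class GlobalKey:
    """Hypothetical stand-in for a global random state (illustration only)."""

    def __init__(self, seed):
        self.key = jax.random.PRNGKey(seed)
        self._backup = None

    def split_key(self, n=None, backup=False):
        if backup:
            # Remember the key as it was before this split.
            self._backup = self.key
        num = 1 if n is None else n
        # Split into num + 1 keys: one to carry the state forward, num to return.
        keys = jax.random.split(self.key, num + 1)
        self.key = keys[0]  # the stored state advances on every call
        return keys[1] if n is None else keys[1:]

    def restore_key(self):
        # Put the backed-up key back, if one exists.
        if self._backup is not None:
            self.key = self._backup
            self._backup = None


state = GlobalKey(42)
k1 = state.split_key()     # a single key
k2 = state.split_key()     # a different key, because the state advanced
keys = state.split_key(4)  # an array of four keys
```

Because the stored key is replaced on each call, two consecutive calls return different keys, and restoring a backup rewinds the holder to the key it had before the split.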
See also
split_keys(): Convenience function for multiple keys
seed(): Set the random seed
restore_key(): Restore backed up key