SeedOrKey
- brainstate.typing.SeedOrKey
Type for random number generator seeds or keys.
Represents values that can be used to seed random number generators or serve as PRNG keys in JAX’s random number generation system.
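Here is one way to spell such a type. It assumes the alias is a plain typing.Union over the three component types listed under Components; the actual brainstate.typing definition may be written differently. The block pairs the alias with a runtime check. Both SeedOrKeySketch and is_seed_or_key are hypothetical names used only for illustration.

```python
from typing import Union

import jax
import numpy as np

# Hypothetical alias mirroring the documented components (int, jax.Array,
# np.ndarray); the real brainstate.typing.SeedOrKey may be defined differently.
SeedOrKeySketch = Union[int, jax.Array, np.ndarray]


def is_seed_or_key(value) -> bool:
    """Hypothetical runtime check: True if `value` is one of the component types."""
    # A tuple of types is used instead of the Union itself, because
    # isinstance() only accepts Union objects on Python 3.10+.
    return isinstance(value, (int, jax.Array, np.ndarray))
```

A static type checker treats the Union as documentation of intent only. Whether a given array is actually usable as a key still depends on its dtype and shape, which the check above does not inspect.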
Components
- int
Integer seeds for random number generators.
- jax.Array
JAX PRNG keys (typically created with jax.random.PRNGKey).
- np.ndarray
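NumPy arrays holding raw key data that can serve as PRNG keys (for example, a uint32 array of shape (2,), which is the raw key format of JAX's default threefry implementation).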
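Code that accepts any of these components usually normalizes them to a JAX key first. The sketch below does this with a hypothetical helper, as_prng_key, which is not part of brainstate. It assumes that jax.Array inputs are already valid keys and that NumPy inputs already hold raw uint32 key data.

```python
import jax
import jax.numpy as jnp
import numpy as np


def as_prng_key(seed_or_key):
    """Hypothetical helper: convert a seed-or-key value into a JAX PRNG key.

    - int        -> a fresh key built with jax.random.PRNGKey
    - jax.Array  -> returned unchanged (assumed to already be a key)
    - np.ndarray -> copied into a jax.Array holding the same raw key data
    """
    if isinstance(seed_or_key, int):
        return jax.random.PRNGKey(seed_or_key)
    if isinstance(seed_or_key, jax.Array):
        return seed_or_key
    if isinstance(seed_or_key, np.ndarray):
        # No dtype cast here: raw keys are expected to be uint32 already.
        return jnp.asarray(seed_or_key)
    raise TypeError(
        f"expected int, jax.Array or np.ndarray, got {type(seed_or_key).__name__}"
    )


# The same raw key data, whether held by JAX or NumPy, gives the same draws.
key = as_prng_key(0)
raw = np.asarray(key)
same = jax.random.normal(as_prng_key(raw), (3,))
```

The NumPy branch does not cast the dtype on purpose. Silently casting, for example, negative int64 values to uint32 would wrap them into unrelated keys.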
Examples
>>> import jax
>>> import numpy as np
>>> from brainstate.typing import SeedOrKey, Shape
>>> def generate_random(key: SeedOrKey, shape: Shape) -> jax.Array:
...     '''Generate random numbers using the provided seed or key.'''
...     if isinstance(key, int):
...         key = jax.random.PRNGKey(key)  # convert an integer seed into a PRNG key
...     return jax.random.normal(key, shape)
>>> # Valid seeds/keys
>>> generate_random(42, (3, 4)).shape  # Integer seed
(3, 4)
>>> generate_random(jax.random.PRNGKey(123), (5,)).shape  # JAX PRNG key
(5,)
>>> generate_random(np.array([1, 2], dtype=np.uint32), (2, 2)).shape  # NumPy array
(2, 2)
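In the generate_random example, an integer seed is only converted into a key. Passing 42 directly and passing jax.random.PRNGKey(42) should therefore produce identical numbers. The sketch below checks this and also shows the usual JAX pattern of splitting a key before drawing twice. To keep the block runnable on its own, it redefines generate_random and substitutes simple stand-ins for SeedOrKey and Shape. Those stand-ins are assumptions and may not match the brainstate.typing definitions.

```python
from typing import Tuple, Union

import jax
import numpy as np

SeedOrKey = Union[int, jax.Array, np.ndarray]  # assumed stand-in for brainstate.typing.SeedOrKey
Shape = Tuple[int, ...]                          # assumed stand-in for brainstate.typing.Shape


def generate_random(key: SeedOrKey, shape: Shape) -> jax.Array:
    """Same logic as the documented example."""
    if isinstance(key, int):
        key = jax.random.PRNGKey(key)
    return jax.random.normal(key, shape)


# An int seed and the key built from that seed give the same draws.
a = generate_random(42, (3, 4))
b = generate_random(jax.random.PRNGKey(42), (3, 4))
assert np.array_equal(np.asarray(a), np.asarray(b))

# Reusing one key repeats the same numbers; split it to get independent draws.
key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(key)
x1 = generate_random(key1, (2,))
x2 = generate_random(key2, (2,))
```

The same property implies a caveat. Calling generate_random repeatedly with the same integer seed returns the same numbers every time, so callers that need fresh randomness should hold a key and split it.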