SeedOrKey

brainstate.typing.SeedOrKey

Type for random number generator seeds or keys.

Represents values that can be used to seed random number generators or serve as PRNG keys in JAX’s random number generation system.
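The difference between a seed and a key can be seen in a few lines. This is a minimal sketch, assuming JAX's default `threefry2x32` PRNG implementation and the legacy `jax.random.PRNGKey` constructor. An integer seed maps deterministically to a key, and equal keys always produce equal draws, which is what makes integer seeds useful for reproducibility.

```python
import jax
import jax.numpy as jnp

# An integer seed is mapped deterministically to a key.
# Assumption: the default threefry2x32 implementation, where a legacy
# key is a uint32 array of shape (2,).
key_a = jax.random.PRNGKey(42)
key_b = jax.random.PRNGKey(42)
key_c = jax.random.PRNGKey(7)

# Equal seeds give equal keys, and equal keys give equal draws.
same_key = bool(jnp.array_equal(key_a, key_b))
same_draws = bool(jnp.array_equal(jax.random.normal(key_a, (3,)),
                                  jax.random.normal(key_b, (3,))))

# A different seed gives a different key.
different_key = not bool(jnp.array_equal(key_a, key_c))

print(key_a.shape, key_a.dtype)             # (2,) uint32
print(same_key, same_draws, different_key)  # True True True
```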

Components

int

Integer seeds for random number generators.

jax.Array

JAX PRNG keys, either legacy raw keys created with jax.random.PRNGKey or typed keys created with jax.random.key.

np.ndarray

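NumPy arrays holding raw key data (for example, a uint32 array of shape (2,) under JAX's default PRNG implementation) that can serve as random keys.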
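The three components can be exercised side by side. The sketch below assumes JAX's default `threefry2x32` implementation, under which `jax.random.PRNGKey(42)` holds the raw words `[0, 42]`; it also uses `jax.random.key`, JAX's typed-key constructor, as the `jax.Array` case. Under those assumptions, all three forms drive identical draws.

```python
import jax
import jax.numpy as jnp
import numpy as np

shape = (4,)

# int: an integer seed, converted to a legacy raw key before use.
from_int = jax.random.normal(jax.random.PRNGKey(42), shape)

# jax.Array: a typed (scalar) key array created by jax.random.key.
typed_key = jax.random.key(42)
from_jax_key = jax.random.normal(typed_key, shape)

# np.ndarray: raw uint32 key data; assumed to match what
# jax.random.PRNGKey(42) produces under the default threefry2x32 impl.
raw_np_key = np.array([0, 42], dtype=np.uint32)
from_np_key = jax.random.normal(raw_np_key, shape)

# jax.random.key_data exposes the raw words behind a typed key.
raw_words = np.asarray(jax.random.key_data(typed_key))

print(typed_key.shape, raw_words.tolist())  # () [0, 42]
print(bool(jnp.array_equal(from_int, from_jax_key)),
      bool(jnp.array_equal(from_int, from_np_key)))  # True True
```

Raw NumPy keys are convenient when key data is loaded from disk or produced outside JAX, while typed keys carry their PRNG implementation with them and avoid confusing key data with ordinary integers.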

Examples

>>> import jax
>>> import numpy as np
>>> from brainstate.typing import SeedOrKey, Shape
>>>
>>> def generate_random(key: SeedOrKey, shape: Shape) -> jax.Array:
...     '''Generate random numbers using the provided seed or key.'''
...     if isinstance(key, int):
...         # Convert an integer seed into a JAX PRNG key first.
...         key = jax.random.PRNGKey(key)
...     return jax.random.normal(key, shape)
>>>
>>> # Valid seeds/keys
>>> generate_random(42, (3, 4)).shape                     # Integer seed
(3, 4)
>>> generate_random(jax.random.PRNGKey(123), (5,)).shape  # JAX PRNG key
(5,)
>>> generate_random(np.array([1, 2], dtype=np.uint32), (2, 2)).shape  # NumPy array
(2, 2)
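When one `SeedOrKey` value has to feed several independent draws, the usual JAX practice is to split the key rather than reuse it. The following is a minimal sketch: `as_key` and `sample_many` are hypothetical helpers, and the local `SeedOrKey` alias is an assumed stand-in for `brainstate.typing.SeedOrKey` built from the three components listed above.

```python
from typing import List, Sequence, Union

import jax
import numpy as np

# Assumed stand-in for brainstate.typing.SeedOrKey (same three components).
SeedOrKey = Union[int, jax.Array, np.ndarray]


def as_key(seed_or_key: SeedOrKey):
    """Hypothetical helper: turn an integer seed into a key; pass keys through."""
    if isinstance(seed_or_key, int):
        return jax.random.PRNGKey(seed_or_key)
    return seed_or_key


def sample_many(seed_or_key: SeedOrKey, n: int, shape: Sequence[int]) -> List[jax.Array]:
    """Hypothetical helper: draw n independent normal samples from one seed or key."""
    # split returns n fresh keys; each one is used exactly once.
    keys = jax.random.split(as_key(seed_or_key), n)
    return [jax.random.normal(keys[i], tuple(shape)) for i in range(n)]


samples = sample_many(42, 3, (2,))
print(len(samples), samples[0].shape)  # 3 (2,)
```

Because splitting is itself deterministic, calling `sample_many(42, 3, (2,))` again reproduces the same three samples, while each sample within one call comes from a distinct key.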