seed_context

class brainstate.random.seed_context(seed_or_key)

Context manager for temporary random seed changes with automatic restoration.

This context manager temporarily changes the global random seed for the duration of the block, then automatically restores the previous random state on exit. This makes it well suited to reproducible computations in specific code sections that must not permanently change the global random state.
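The save, reseed, and restore pattern described above can be sketched with NumPy's legacy global generator. This is a hypothetical stand-in, not brainstate's implementation. `numpy_seed_context` is an illustrative name, and the sketch assumes that `np.random.get_state()` and `np.random.set_state()` are enough to capture and reinstate the state. brainstate's own version also manages a JAX key.

```python
import contextlib

import numpy as np


@contextlib.contextmanager
def numpy_seed_context(seed):
    """Hypothetical helper (not brainstate API): temporarily reseed NumPy's global generator."""
    saved_state = np.random.get_state()  # snapshot the current global state
    try:
        np.random.seed(seed)  # reseed for the duration of the block
        yield
    finally:
        # Runs on normal exit and when an exception propagates out of the block.
        np.random.set_state(saved_state)


# Two blocks with the same seed draw identical values.
with numpy_seed_context(42):
    first = np.random.rand(2)
with numpy_seed_context(42):
    second = np.random.rand(2)

# Reference stream: seed 0, skip one draw, then take two values.
np.random.seed(0)
np.random.rand(1)
reference = np.random.rand(2)

# Same stream, but with a seeded block in the middle; after the block the
# global stream resumes where it left off.
np.random.seed(0)
np.random.rand(1)
with numpy_seed_context(42):
    np.random.rand(5)
resumed = np.random.rand(2)
```

The restore call sits in `finally`, so the outer stream resumes exactly where it left off. That is why `resumed` matches `reference`.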

Parameters:

seed_or_key (int | Array | ndarray) – The temporary seed or key to use within the context. It can be:

  • int: an integer seed for reproducible sequences

  • JAX PRNG key: a JAX random key array

The seed affects both the JAX and NumPy random states during the context.

Yields:

None – The context manager doesn’t yield any value, but provides a controlled random environment for the enclosed code block.

Example

Reproducible computations without affecting global state:

>>> import brainstate
>>> import numpy as np
>>> # Draw from the global random state
>>> global_values1 = brainstate.random.rand(2)
>>>
>>> with brainstate.random.seed_context(42):
...     temp_values1 = brainstate.random.rand(2)
...     print(f"First run: {temp_values1}")
First run: [0.95598125 0.4032725 ]
>>>
>>> with brainstate.random.seed_context(42):
...     temp_values2 = brainstate.random.rand(2)
...     print(f"Second run: {temp_values2}")
Second run: [0.95598125 0.4032725 ]
>>>
>>> # The same seed produces identical values in both contexts
>>> assert np.allclose(temp_values1, temp_values2)
>>>
>>> # Global state continues from where it left off
>>> global_values2 = brainstate.random.rand(2)

Nested contexts for complex scenarios:

>>> with brainstate.random.seed_context(123):
...     outer_values = brainstate.random.rand(2)
...     with brainstate.random.seed_context(456):
...         inner_values = brainstate.random.rand(2)
...     # Outer context is restored here
...     outer_values2 = brainstate.random.rand(2)

Exception safety (the previous state is restored even if the block raises an error):

>>> try:
...     with brainstate.random.seed_context(789):
...         some_values = brainstate.random.rand(3)
...         raise ValueError("Something went wrong")
... except ValueError:
...     pass
>>> # Random state is properly restored

Testing reproducible algorithms:

>>> def test_algorithm():
...     with brainstate.random.seed_context(42):
...         data = brainstate.random.normal(size=(100,))
...         return data.mean()
>>>
>>> result1 = test_algorithm()
>>> result2 = test_algorithm()
>>> assert result1 == result2  # Always same result

Note

  • The context manager saves and restores the complete JAX random state

  • NumPy’s random state is also temporarily modified during the context

  • Nested contexts work correctly: each level restores the state that was active when it was entered

  • Exception safety is guaranteed: the random state is restored even if an exception occurs within the context

  • This is more convenient than manually saving/restoring state with get_key() and set_key()
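The nesting and exception-safety guarantees listed above follow from two things: each level keeps its own snapshot, and it restores that snapshot in a `finally` clause. The sketch below illustrates both, using Python's standard `random` module as a stand-in for brainstate's JAX key. `stdlib_seed_context` and `draws` are hypothetical names, and this is not brainstate's code.

```python
import contextlib
import random


@contextlib.contextmanager
def stdlib_seed_context(seed):
    """Hypothetical helper: each call saves and restores its own snapshot."""
    saved_state = random.getstate()
    try:
        random.seed(seed)
        yield
    finally:
        random.setstate(saved_state)


def draws(n):
    """Draw n floats from the global stdlib generator."""
    return [random.random() for _ in range(n)]


# Nesting: after the inner block exits, the outer seed-123 stream resumes
# exactly where it was before the inner block started.
with stdlib_seed_context(123):
    outer_before = draws(2)
    with stdlib_seed_context(456):
        inner = draws(2)
    outer_after = draws(2)

random.seed(123)
expected_outer = draws(4)  # the same stream without the inner block

# Exception safety: the snapshot is restored even though the block raises.
random.seed(7)
snapshot = random.getstate()
try:
    with stdlib_seed_context(789):
        draws(3)
        raise ValueError("Something went wrong")
except ValueError:
    pass
restored = random.getstate() == snapshot
```

The four outer draws match an uninterrupted seed-123 stream. The state after the failing block equals the snapshot taken before it.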

See also

  • seed(): Permanently set the global random seed

  • get_key(): Get the current random key for manual state management

  • set_key(): Set the random key for manual state management

  • clone_rng(): Create independent random states