context

class brainstate.environ.context(*, env=None, **kwargs)

Context manager for temporary environment settings.

This context manager allows you to temporarily modify environment settings within a specific scope. Settings are automatically restored when exiting the context, even if an exception occurs.

Parameters:
  • env (EnvironmentState | None) – The environment state to modify. If None, uses the global environment.

  • **kwargs

    Environment settings to apply within the context. Common parameters include:

    • precision (int or str) – Numerical precision (8, 16, 32, 64, or 'bf16').

    • dt (float) – Time step for numerical integration.

    • mode (Mode) – Computation mode instance.

    • Any custom parameters registered via register_default_behavior

Yields:

dict – Current environment settings within the context.

Raises:
  • ValueError – If attempting to set platform or host_device_count in context (these must be set globally).

  • TypeError – If invalid parameter types are provided.
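The two error cases above can be illustrated with a small standalone sketch. It is not brainstate's code: the names GLOBAL_ONLY and validate_context_settings are hypothetical, and the type check on dt is an assumed example of what "invalid parameter types" might cover.

```python
# Hypothetical sketch of the validation described under "Raises".
# None of these names come from brainstate itself.
GLOBAL_ONLY = frozenset({"platform", "host_device_count"})


def validate_context_settings(settings):
    """Reject settings that may only be changed globally, then check types."""
    blocked = sorted(key for key in settings if key in GLOBAL_ONLY)
    if blocked:
        raise ValueError(
            f"cannot set {', '.join(blocked)} in a context; set it globally"
        )
    # Assumed check: dt, when given, must be a real number.
    dt = settings.get("dt")
    if dt is not None and not isinstance(dt, (int, float)):
        raise TypeError(f"dt must be a number, got {type(dt).__name__}")
    return settings


# Usage: valid settings pass through unchanged.
checked = validate_context_settings({"dt": 0.1, "precision": 32})
```

Running the check before any setting is applied means a rejected call leaves the environment untouched.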

Return type:

ContextManager[Dict[str, Any], bool | None]

Examples

Basic usage with precision control:

>>> import brainstate.environ as env
>>>
>>> # Set global precision
>>> env.set(precision=32)
>>>
>>> # Temporarily use higher precision
>>> with env.context(precision=64) as ctx:
...     print(f"Precision in context: {env.get('precision')}")  # 64
...     print(f"Float type: {env.dftype()}")  # float64
>>>
>>> print(f"Precision after context: {env.get('precision')}")  # 32

Nested contexts:

>>> import brainstate.environ as env
>>>
>>> with env.context(dt=0.1) as ctx1:
...     print(f"dt = {env.get('dt')}")  # 0.1
...
...     with env.context(dt=0.01) as ctx2:
...         print(f"dt = {env.get('dt')}")  # 0.01
...
...     print(f"dt = {env.get('dt')}")  # 0.1

Error handling in context:

>>> import brainstate.environ as env
>>>
>>> env.set(value=10)
>>> try:
...     with env.context(value=20):
...         print(env.get('value'))  # 20
...         raise ValueError("Something went wrong")
... except ValueError:
...     pass
>>>
>>> print(env.get('value'))  # 10 (restored)

Using custom environment:

>>> import brainstate.environ as env
>>>
>>> custom_env = env.EnvironmentState()
>>> env.set(precision=32, env=custom_env)
>>>
>>> with env.context(precision=64, env=custom_env):
...     print(env.get('precision', env=custom_env))  # 64
>>>
>>> print(env.get('precision', env=custom_env))  # 32

Notes

  • platform and host_device_count cannot be set in a context; they must be set globally

  • Contexts can be nested arbitrarily deep

  • Settings are restored in reverse order when exiting

  • Thread-safe: each thread maintains its own context stack

  • When using a custom env, JAX config is only updated if env is the global environment