context
- class brainstate.environ.context(*, env=None, **kwargs)
Context manager for temporary environment settings.
This context manager allows you to temporarily modify environment settings within a specific scope. Settings are automatically restored when exiting the context, even if an exception occurs.
- Parameters:
env (EnvironmentState | None) – The environment state to modify. If None, the global environment is used.
**kwargs – Environment settings to apply within the context. Common parameters include:
- precision : int or str
Numerical precision (8, 16, 32, 64, or 'bf16').
- dt : float
Time step for numerical integration.
- mode : Mode
Computation mode instance.
- Other keys
Any custom parameters registered via register_default_behavior.
- Yields:
dict – Current environment settings within the context.
- Raises:
ValueError – If attempting to set platform or host_device_count in context (these must be set globally).
TypeError – If invalid parameter types are provided.
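The `ValueError` rule can be enforced by checking the keyword arguments before any setting is applied. The sketch below is a hypothetical illustration: `GLOBAL_ONLY_KEYS` and `check_context_kwargs` are assumed names, and brainstate's actual validation may be structured differently.

```python
# Hypothetical set of keys that may only be changed globally, mirroring the
# "Raises" section above; brainstate's real check may differ.
GLOBAL_ONLY_KEYS = frozenset({"platform", "host_device_count"})


def check_context_kwargs(**kwargs):
    """Reject settings that must be configured globally, before anything is applied."""
    forbidden = GLOBAL_ONLY_KEYS.intersection(kwargs)
    if forbidden:
        raise ValueError(
            f"{sorted(forbidden)} cannot be set inside a context; set them globally instead."
        )
    return kwargs


check_context_kwargs(precision=64, dt=0.01)  # accepted
try:
    check_context_kwargs(platform="gpu")
except ValueError as err:
    print(err)  # ['platform'] cannot be set inside a context; set them globally instead.
```

Validating up front means a rejected call leaves the environment untouched, so nothing has to be rolled back.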
Examples
Basic usage with precision control:
>>> import brainstate.environ as env
>>>
>>> # Set global precision
>>> env.set(precision=32)
>>>
>>> # Temporarily use higher precision
>>> with env.context(precision=64) as ctx:
...     print(f"Precision in context: {env.get('precision')}")  # 64
...     print(f"Float type: {env.dftype()}")  # float64
>>>
>>> print(f"Precision after context: {env.get('precision')}")  # 32
Nested contexts:
>>> import brainstate.environ as env
>>>
>>> with env.context(dt=0.1) as ctx1:
...     print(f"dt = {env.get('dt')}")  # 0.1
...
...     with env.context(dt=0.01) as ctx2:
...         print(f"dt = {env.get('dt')}")  # 0.01
...
...     print(f"dt = {env.get('dt')}")  # 0.1
Error handling in context:
>>> import brainstate.environ as env
>>>
>>> env.set(value=10)
>>> try:
...     with env.context(value=20):
...         print(env.get('value'))  # 20
...         raise ValueError("Something went wrong")
... except ValueError:
...     pass
>>>
>>> print(env.get('value'))  # 10 (restored)
Using custom environment:
>>> import brainstate.environ as env
>>>
>>> custom_env = env.EnvironmentState()
>>> env.set(precision=32, env=custom_env)
>>>
>>> with env.context(precision=64, env=custom_env):
...     print(env.get('precision', env=custom_env))  # 64
>>>
>>> print(env.get('precision', env=custom_env))  # 32
Notes
- Platform and host_device_count cannot be set in a context.
- Contexts can be nested arbitrarily deep.
- Settings are restored in reverse order when exiting.
- Thread-safe: each thread maintains its own context stack.
- When using a custom env, the JAX config is only updated if env is the global environment.