set_precision

class brainstate.environ.set_precision(precision, *, env=None)

Set the numerical precision of the global environment, or of the environment passed as env.

Parameters:
  • precision (int | str) – Precision to use: a bit width of 8, 16, 32, or 64, or 'bf16' for bfloat16.

  • env (EnvironmentState | None) – The environment state to modify. If None, uses the global environment.

Raises:

ValueError – If precision is not one of the supported values listed above.

Return type:

None
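The parameter and error contract above can be sketched as a small validator. This is a hypothetical illustration, not brainstate's source: the names SUPPORTED_PRECISIONS and validate_precision are invented here, and rejecting bool and float inputs (so that True or 8.0 are not silently accepted) is an assumption made for the sketch.

```python
from typing import Union

# Hypothetical: the documented accepted values, not brainstate's actual constant.
SUPPORTED_PRECISIONS = (8, 16, 32, 64, 'bf16')


def validate_precision(precision: Union[int, str]) -> Union[int, str]:
    """Return precision unchanged if it is a documented value, else raise ValueError."""
    # bool is a subclass of int, so reject it explicitly; floats such as 8.0
    # compare equal to 8 and would otherwise slip through the membership test.
    if (
        isinstance(precision, bool)
        or not isinstance(precision, (int, str))
        or precision not in SUPPORTED_PRECISIONS
    ):
        raise ValueError(
            f"Unsupported precision {precision!r}; "
            f"expected one of {SUPPORTED_PRECISIONS}"
        )
    return precision
```

A value such as 128 or 'fp32' raises ValueError, matching the Raises entry; 64 and 'bf16' pass through unchanged.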

Examples

>>> import brainstate.environ as env
>>> import jax.numpy as jnp
>>>
>>> # Set to 64-bit precision
>>> env.set_precision(64)
>>>
>>> # Arrays will use float64 by default
>>> x = jnp.array([1.0, 2.0, 3.0])
>>> print(x.dtype)
float64
>>>
>>> # Set to bfloat16 to reduce memory use
>>> env.set_precision('bf16')

Using a custom environment:

>>> import brainstate.environ as env
>>>
>>> custom_env = env.EnvironmentState()
>>> env.set_precision(64, env=custom_env)
>>> print(env.get_precision(env=custom_env))
64

Notes

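When a custom env is passed, only that environment's precision is changed; the JAX configuration is updated only when the target is the global environment. Consequently, set_precision(64, env=custom_env) does not by itself make JAX arrays default to float64.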
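The gating described in the note can be sketched with toy stand-ins. Everything here is hypothetical: ToyEnvironment replaces EnvironmentState, the default precision of 32 is assumed, and the jax_config_updates list only records when a real implementation would touch JAX's configuration; it does not call JAX.

```python
from dataclasses import dataclass
from typing import List, Optional, Union

Precision = Union[int, str]
SUPPORTED = (8, 16, 32, 64, 'bf16')


@dataclass
class ToyEnvironment:
    """Hypothetical stand-in for EnvironmentState; the default of 32 is assumed."""
    precision: Precision = 32


GLOBAL_ENV = ToyEnvironment()
# Hypothetical: records each time the (stand-in) JAX config would be updated.
jax_config_updates: List[Precision] = []


def toy_set_precision(precision: Precision, *, env: Optional[ToyEnvironment] = None) -> None:
    """Toy version of the documented behavior, not brainstate's implementation."""
    if (
        isinstance(precision, bool)
        or not isinstance(precision, (int, str))
        or precision not in SUPPORTED
    ):
        raise ValueError(f"Unsupported precision {precision!r}")
    # env=None means "modify the global environment".
    target = GLOBAL_ENV if env is None else env
    target.precision = precision
    # Only the global environment reaches the JAX-config branch.
    if target is GLOBAL_ENV:
        jax_config_updates.append(precision)
```

Calling toy_set_precision(64, env=custom) changes only custom.precision and leaves GLOBAL_ENV and jax_config_updates untouched; a call without env reaches the config branch.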