set_precision
- brainstate.environ.set_precision(precision, *, env=None)
Set the global numerical precision.
- Parameters:
precision (int | str) – Precision to use (8, 16, 32, 64, or 'bf16').
env (EnvironmentState | None) – The environment state to modify. If None, the global environment is used.
- Raises:
ValueError – If precision is not supported.
- Return type:
None
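The Parameters and Raises fields above describe a fixed set of accepted precision values, with ValueError for anything else. The following is a minimal sketch of that validation step, assuming a plain membership check; `validate_precision`, `SUPPORTED_PRECISIONS` and `ASSUMED_FLOAT_DTYPES` are hypothetical names, not brainstate APIs. The dtype mapping is likewise an assumption (this page only states float64 for 64 and bfloat16 for 'bf16'), and 8 is left out of it because the 8-bit float format is not stated here.

```python
# Hypothetical sketch, not brainstate's implementation: check a precision
# argument against the values this page lists as supported.
SUPPORTED_PRECISIONS = (8, 16, 32, 64, "bf16")

# Assumed default float dtype names per precision. The 8-bit float format
# is not specified on this page, so 8 is deliberately omitted.
ASSUMED_FLOAT_DTYPES = {16: "float16", 32: "float32", 64: "float64", "bf16": "bfloat16"}


def validate_precision(precision):
    """Return `precision` unchanged if supported, else raise ValueError."""
    # Membership uses ==, so 64 and "bf16" match while "64" or 128 do not.
    if precision not in SUPPORTED_PRECISIONS:
        raise ValueError(
            f"Unsupported precision {precision!r}; expected one of {SUPPORTED_PRECISIONS}"
        )
    return precision


if __name__ == "__main__":
    print(validate_precision(64))                             # 64
    print(ASSUMED_FLOAT_DTYPES[validate_precision("bf16")])   # bfloat16
    try:
        validate_precision(128)
    except ValueError as err:
        print(err)  # reports the rejected value and the supported set
```

Checking against a small tuple keeps the error message self-describing: it echoes both the rejected value and the full list of accepted ones.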
Examples
>>> import brainstate.environ as env
>>> import jax.numpy as jnp
>>>
>>> # Set to 64-bit precision
>>> env.set_precision(64)
>>>
>>> # Arrays will use float64 by default
>>> x = jnp.array([1.0, 2.0, 3.0])
>>> print(x.dtype)
float64
>>>
>>> # Set to bfloat16 for efficiency
>>> env.set_precision('bf16')
Using a custom environment:
>>> import brainstate.environ as env
>>>
>>> custom_env = env.EnvironmentState()
>>> env.set_precision(64, env=custom_env)
>>> print(env.get_precision(env=custom_env))
64
Notes
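When a custom env is passed, JAX's global configuration is updated only if env is the global environment. A separate EnvironmentState records the new precision without changing JAX's settings.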
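The Notes above separate two effects: every environment stores its own precision, but only the global environment touches process-wide JAX configuration. The following is a hypothetical model of that rule, not brainstate's code. `EnvState`, `GLOBAL_ENV`, `FAKE_JAX_CONFIG` and `set_precision_sketch` are illustrative names. The dict stands in for JAX's config, and tying its `jax_enable_x64` entry to `precision == 64` is an assumption made for this sketch.

```python
# Hypothetical model of the Notes; all names here are illustrative,
# not brainstate APIs.
from dataclasses import dataclass, field


@dataclass
class EnvState:
    """Stand-in for an environment state that stores settings such as precision."""
    settings: dict = field(default_factory=dict)


# The single process-wide environment used when env=None.
GLOBAL_ENV = EnvState()

# Stand-in for JAX's process-wide configuration. Linking 'jax_enable_x64'
# to precision == 64 is an assumption of this sketch.
FAKE_JAX_CONFIG = {"jax_enable_x64": False}


def set_precision_sketch(precision, *, env=None):
    target = GLOBAL_ENV if env is None else env
    # Every environment records its own precision setting.
    target.settings["precision"] = precision
    # Only the global environment may change the process-wide config.
    if target is GLOBAL_ENV:
        FAKE_JAX_CONFIG["jax_enable_x64"] = precision == 64


if __name__ == "__main__":
    custom = EnvState()
    set_precision_sketch(64, env=custom)
    print(custom.settings["precision"], FAKE_JAX_CONFIG["jax_enable_x64"])  # 64 False
    set_precision_sketch(64)
    print(GLOBAL_ENV.settings["precision"], FAKE_JAX_CONFIG["jax_enable_x64"])  # 64 True
```

The identity check (`target is GLOBAL_ENV`) rather than an `env is None` check means that passing the global environment explicitly behaves the same as omitting `env`, which matches the wording of the note.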