dftype
- class brainstate.environ.dftype(*, env=None)
Get the default floating-point data type.
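This function returns the floating-point type that matches the current precision setting (for example float32 at 32-bit precision, float64 at 64-bit precision, and bfloat16 for 'bf16'), so array dtypes can follow the environment instead of being hard-coded.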
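The precision-to-dtype selection described above can be pictured as a simple lookup. The sketch below is a minimal, hypothetical illustration in plain Python: the helper `float_dtype_for` and the table `_PRECISION_TO_FLOAT` are assumptions, not part of brainstate, and dtypes are returned as name strings so the example runs without JAX. Only the three settings shown on this page are covered.

```python
# Hypothetical sketch: map a precision setting to a floating-point dtype name.
# The table mirrors the examples on this page (32 -> float32, 64 -> float64,
# 'bf16' -> bfloat16); it is NOT brainstate's actual implementation.
_PRECISION_TO_FLOAT = {
    32: "float32",
    64: "float64",
    "bf16": "bfloat16",
}


def float_dtype_for(precision):
    """Return the float dtype name for a precision setting (hypothetical helper)."""
    try:
        return _PRECISION_TO_FLOAT[precision]
    except KeyError:
        # Unknown settings are rejected rather than silently defaulted.
        raise ValueError(f"unsupported precision: {precision!r}") from None


print(float_dtype_for(32))      # float32
print(float_dtype_for("bf16"))  # bfloat16
```

Raising on an unknown setting is a design choice of this sketch; it keeps a typo in the precision value from quietly producing the wrong dtype.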
- Parameters:
env (EnvironmentState | None) – The environment state to query. If None, the global environment is used.
- Returns:
Default floating-point data type.
- Return type:
Examples
>>> import brainstate.environ as env
>>> import jax.numpy as jnp
>>>
>>> # With 32-bit precision
>>> env.set(precision=32)
>>> x = jnp.zeros(10, dtype=env.dftype())
>>> print(x.dtype)  # float32
>>>
>>> # With 64-bit precision
>>> with env.context(precision=64):
...     y = jnp.ones(5, dtype=env.dftype())
...     print(y.dtype)  # float64
>>>
>>> # With bfloat16
>>> env.set(precision='bf16')
>>> z = jnp.array([1, 2, 3], dtype=env.dftype())
>>> print(z.dtype)  # bfloat16
Using a custom environment:
>>> import brainstate.environ as env
>>>
>>> custom_env = env.EnvironmentState()
>>> env.set(precision=64, env=custom_env)
>>> print(env.dftype(env=custom_env))  # float64
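The `env=None` fallback and the temporary override used in the examples can be sketched as follows. This is a minimal sketch under stated assumptions: `EnvState`, `_GLOBAL_ENV`, `default_float`, and `precision_context` are hypothetical stand-ins for `EnvironmentState`, the global environment, `dftype`, and `env.context`, and brainstate's real implementation may store and restore settings differently.

```python
from contextlib import contextmanager
from dataclasses import dataclass


@dataclass
class EnvState:
    """Hypothetical stand-in for an environment state holding a precision setting."""
    precision: object = 32


# Module-level default state, used whenever env is None (hypothetical).
_GLOBAL_ENV = EnvState()

_FLOATS = {32: "float32", 64: "float64", "bf16": "bfloat16"}


def default_float(*, env=None):
    """Resolve env (falling back to the global state) and return its float dtype name."""
    state = _GLOBAL_ENV if env is None else env
    return _FLOATS[state.precision]


@contextmanager
def precision_context(precision):
    """Temporarily override the global precision, restoring it on exit."""
    old = _GLOBAL_ENV.precision
    _GLOBAL_ENV.precision = precision
    try:
        yield
    finally:
        # Restore even if the body raises.
        _GLOBAL_ENV.precision = old


custom = EnvState(precision=64)
print(default_float())            # float32 (global state)
print(default_float(env=custom))  # float64 (custom state, global untouched)
with precision_context("bf16"):
    print(default_float())        # bfloat16
print(default_float())            # float32 again
```

In this sketch, passing an explicit state leaves the global setting untouched, mirroring how the custom environment example queries `custom_env` directly rather than the global environment.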