dftype

brainstate.environ.dftype(*, env=None)

Get the default floating-point data type.

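This function returns the floating-point type that matches the current precision setting (for example, float32 for precision=32, float64 for precision=64, and bfloat16 for precision='bf16'), so array dtypes follow the environment instead of being hard-coded.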
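The selection can be pictured as a lookup from the precision setting to a dtype. The sketch below is a hypothetical, pure-Python model, not brainstate's implementation: the names `_PRECISION_TO_FLOAT` and `float_dtype_for` are made up for illustration, only the three precision values used in the examples on this page are covered, and dtype names are returned as strings.

```python
# Hypothetical model of dftype's precision lookup; not brainstate's actual code.
# Only the precision values that appear in this page's examples are covered.
_PRECISION_TO_FLOAT = {
    32: "float32",
    64: "float64",
    "bf16": "bfloat16",
}


def float_dtype_for(precision):
    """Return the floating-point dtype name for a precision setting."""
    try:
        return _PRECISION_TO_FLOAT[precision]
    except KeyError:
        # Unknown settings fail loudly instead of silently picking a default.
        raise ValueError(f"unsupported precision: {precision!r}") from None


for p in (32, 64, "bf16"):
    print(p, "->", float_dtype_for(p))
```

Keeping the mapping in one table means a single place decides which dtype each precision produces.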

Parameters:

env (EnvironmentState | None) – The environment state to query. If None (the default), the global environment is used.
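The `env` fallback amounts to "use the given state, otherwise the global one". The sketch below models that resolution under stated assumptions: `EnvState` is a hypothetical stand-in for `EnvironmentState`, `GLOBAL_ENV` is a hypothetical global state whose default precision of 32 is assumed for illustration, and `dftype_sketch` only mirrors the keyword-only signature shown above.

```python
from dataclasses import dataclass


# Hypothetical stand-in for EnvironmentState; brainstate's real class differs.
@dataclass
class EnvState:
    precision: object = 32  # assumed default, for illustration only


# Hypothetical global environment used when no env is passed.
GLOBAL_ENV = EnvState()

_FLOAT_NAMES = {32: "float32", 64: "float64", "bf16": "bfloat16"}


def dftype_sketch(*, env=None):
    """Resolve the float dtype name from env, or from GLOBAL_ENV if env is None."""
    state = GLOBAL_ENV if env is None else env
    return _FLOAT_NAMES[state.precision]


custom = EnvState(precision=64)
print(dftype_sketch())            # float32 (global environment)
print(dftype_sketch(env=custom))  # float64 (custom environment)
```

Because the custom state is a separate object, querying it never touches the global setting.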

Returns:

Default floating-point data type.

Return type:

str | type[Any] | dtype | SupportsDType

Examples

>>> import brainstate.environ as env
>>> import jax.numpy as jnp
>>>
>>> # With 32-bit precision
>>> env.set(precision=32)
>>> x = jnp.zeros(10, dtype=env.dftype())
>>> print(x.dtype)  # float32
>>>
>>> # With 64-bit precision
>>> with env.context(precision=64):
...     y = jnp.ones(5, dtype=env.dftype())
...     print(y.dtype)  # float64
>>>
>>> # With bfloat16
>>> env.set(precision='bf16')
>>> z = jnp.array([1, 2, 3], dtype=env.dftype())
>>> print(z.dtype)  # bfloat16
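Judging by the example, `env.context(precision=64)` scopes the override to the `with` block, while `env.set(...)` changes the setting persistently. One way to model such scoping is a context manager that saves the previous precision and restores it on exit. The sketch below is a hypothetical, pure-Python illustration; `_state`, `context`, and `current_float` are made-up names, not brainstate's API.

```python
from contextlib import contextmanager

# Hypothetical global precision store; brainstate manages its own state.
_state = {"precision": 32}
_FLOAT_NAMES = {32: "float32", 64: "float64", "bf16": "bfloat16"}


def current_float():
    """Float dtype name for the current (possibly overridden) precision."""
    return _FLOAT_NAMES[_state["precision"]]


@contextmanager
def context(precision):
    """Temporarily override the precision, restoring the old value on exit."""
    previous = _state["precision"]
    _state["precision"] = precision
    try:
        yield
    finally:
        # Restore even if the body raised an exception.
        _state["precision"] = previous


with context(precision=64):
    print(current_float())  # float64
print(current_float())      # float32
```

Restoring in `finally` keeps an exception inside the block from leaking the override into later code.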

Using a custom environment:

>>> import brainstate.environ as env
>>>
>>> custom_env = env.EnvironmentState()
>>> env.set(precision=64, env=custom_env)
>>> print(env.dftype(env=custom_env))  # float64

See also

ditype – Default integer type
dutype – Default unsigned integer type
dctype – Default complex type