field
- class brainstate.util.field(pytree_node=True, **kwargs)
Create a dataclass field with JAX pytree metadata.
- Parameters:
pytree_node (bool) – If True (default), the field is treated as part of the pytree, so its value is traversed and transformed by JAX. If False, the field is treated as static metadata and is not traced or modified by JAX transformations such as jax.jit or jax.grad.
**kwargs – Additional keyword arguments passed through to dataclasses.field().
- Returns:
A dataclass field with the appropriate metadata.
- Return type:
dataclasses.Field
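To make the parameter description concrete, here is a minimal sketch of how such a helper could be built on the standard library alone. It is not the brainstate source: `sketch_field` and `Config` are hypothetical names, and the metadata key `"pytree_node"` is an assumption (it mirrors the convention used by Flax's `flax.struct.field`). The sketch shows the flag being stored in the field's metadata while every other keyword, such as `default_factory`, is forwarded to `dataclasses.field()` unchanged.

```python
import dataclasses
from typing import List


def sketch_field(pytree_node: bool = True, **kwargs):
    # Hypothetical re-implementation (assumed metadata key "pytree_node"):
    # store the flag in the field's metadata and forward all other keywords
    # (default, default_factory, repr, ...) to dataclasses.field() unchanged.
    return dataclasses.field(metadata={"pytree_node": pytree_node}, **kwargs)


@dataclasses.dataclass
class Config:
    # default_factory is passed straight through **kwargs.
    history: List[float] = sketch_field(default_factory=list)
    # Marked as static metadata rather than a pytree child.
    name: str = sketch_field(pytree_node=False, default="model")


# Read the flag back, the way a pytree-aware decorator might.
flags = {f.name: f.metadata["pytree_node"] for f in dataclasses.fields(Config)}
print(flags)     # {'history': True, 'name': False}
print(Config())  # Config(history=[], name='model')
```

Storing the flag in `metadata` keeps the field an ordinary dataclass field; it is up to a decorator to read that flag and decide how to flatten the class.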
Examples
>>> import jax.numpy as jnp
>>> from brainstate.util import dataclass, field
>>> @dataclass
... class Model:
...     weights: jnp.ndarray
...     bias: jnp.ndarray
...     # This field won't be affected by JAX transformations
...     name: str = field(pytree_node=False, default="model")
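The effect of `pytree_node=False` in the example can be illustrated with a self-contained sketch that registers a dataclass with plain `jax.tree_util.register_pytree_node`. This is an assumption about how a pytree-aware decorator might behave, not brainstate's implementation: `sketch_field` and `pytree_dataclass` are hypothetical stand-ins for `field` and `dataclass`, and the `"pytree_node"` metadata key is assumed. Fields flagged `True` become pytree children; fields flagged `False` go into the treedef as static auxiliary data.

```python
import dataclasses

import jax
import jax.numpy as jnp


def sketch_field(pytree_node: bool = True, **kwargs):
    # Hypothetical stand-in for field(): records the flag under an assumed key.
    return dataclasses.field(metadata={"pytree_node": pytree_node}, **kwargs)


def pytree_dataclass(cls):
    # Hypothetical stand-in for a pytree-aware @dataclass decorator.
    cls = dataclasses.dataclass(frozen=True)(cls)
    fields = dataclasses.fields(cls)
    data_names = [f.name for f in fields if f.metadata.get("pytree_node", True)]
    meta_names = [f.name for f in fields if not f.metadata.get("pytree_node", True)]

    def flatten(obj):
        # Children are traversed by JAX; aux data is stored in the treedef.
        children = tuple(getattr(obj, n) for n in data_names)
        aux = tuple(getattr(obj, n) for n in meta_names)
        return children, aux

    def unflatten(aux, children):
        kwargs = dict(zip(data_names, children))
        kwargs.update(zip(meta_names, aux))
        return cls(**kwargs)

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@pytree_dataclass
class Model:
    weights: jnp.ndarray
    bias: jnp.ndarray
    name: str = sketch_field(pytree_node=False, default="model")


model = Model(weights=jnp.ones((2, 2)), bias=jnp.zeros(2))

# Only the two array fields are leaves; 'name' is static metadata.
print(len(jax.tree_util.tree_leaves(model)))  # 2

# tree_map touches leaves only; 'name' is carried through unchanged.
scaled = jax.tree_util.tree_map(lambda x: 2 * x, model)
print(scaled.name)  # model


@jax.jit
def forward(m, x):
    # Branching on m.name works under jit because 'name' is a plain Python str
    # taken from the treedef, not a traced value.
    if m.name == "model":
        return x @ m.weights + m.bias
    return x


print(forward(model, jnp.ones(2)))  # [2. 2.]
```

Keeping `name` out of the leaves is also why it may drive Python control flow inside `jax.jit`: changing it produces a different treedef and therefore a retrace, rather than a traced value.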