field

class brainstate.util.field(pytree_node=True, **kwargs)

Create a dataclass field with JAX pytree metadata.

Parameters:
  • pytree_node (bool) – If True (default), the field is part of the pytree: its value is flattened as a child that JAX transformations trace and map over. If False, the field is treated as static metadata and is left untouched by JAX transformations.

  • **kwargs – Additional keyword arguments forwarded to dataclasses.field(), such as default or default_factory.
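The split that pytree_node controls can be sketched with the standard library and jax.tree_util alone. This is a hypothetical re-implementation for illustration, not brainstate's source: the helper names pytree_field and pytree_dataclass and the metadata key "pytree_node" are assumptions. Fields flagged True become children that JAX maps over; fields flagged False travel as auxiliary data in the treedef.

```python
import dataclasses

import jax
import jax.numpy as jnp


def pytree_field(pytree_node=True, **kwargs):
    """Hypothetical stand-in for brainstate.util.field: record the flag in metadata."""
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata["pytree_node"] = pytree_node  # assumed metadata key
    return dataclasses.field(metadata=metadata, **kwargs)


def pytree_dataclass(cls):
    """Hypothetical stand-in for a pytree dataclass decorator: build and register cls."""
    cls = dataclasses.dataclass(cls)
    fields = dataclasses.fields(cls)
    # Fields without the flag (plain annotations) default to pytree nodes.
    data_names = [f.name for f in fields if f.metadata.get("pytree_node", True)]
    meta_names = [f.name for f in fields if not f.metadata.get("pytree_node", True)]

    def flatten(obj):
        # Children are traced/mapped by JAX; aux data is kept as static metadata.
        children = tuple(getattr(obj, n) for n in data_names)
        aux = tuple(getattr(obj, n) for n in meta_names)
        return children, aux

    def unflatten(aux, children):
        return cls(**dict(zip(data_names, children)), **dict(zip(meta_names, aux)))

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@pytree_dataclass
class Model:
    weights: jnp.ndarray
    bias: jnp.ndarray
    name: str = pytree_field(pytree_node=False, default="model")


m = Model(weights=jnp.ones((2, 2)), bias=jnp.zeros(2))
leaves = jax.tree_util.tree_leaves(m)                  # weights and bias only
doubled = jax.tree_util.tree_map(lambda x: 2 * x, m)  # name passes through as-is
```

Had name been a pytree node, tree_map would have applied the lambda to the string as well; keeping it in the auxiliary data is what shields it from the transformation.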

Returns:

A dataclass field whose metadata records whether it is a pytree node.

Return type:

Field

Examples

>>> import jax.numpy as jnp
>>> from brainstate.util import dataclass, field

>>> @dataclass
... class Model:
...     weights: jnp.ndarray
...     bias: jnp.ndarray
...     # This field won't be affected by JAX transformations
...     name: str = field(pytree_node=False, default="model")
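For comparison, JAX itself (version 0.4.26 or later) can express the same data/metadata split through jax.tree_util.register_dataclass, where field names are listed explicitly rather than flagged per field. This is a swapped-in, JAX-native technique, not brainstate's API; the sketch mirrors the Model above and shows that under jax.jit the arrays are traced while the metadata field stays an ordinary Python value.

```python
import dataclasses
import functools

import jax
import jax.numpy as jnp


@functools.partial(
    jax.tree_util.register_dataclass,
    data_fields=["weights", "bias"],  # children: traced by JAX transformations
    meta_fields=["name"],             # static metadata: stored in the treedef
)
@dataclasses.dataclass
class Model:
    weights: jax.Array
    bias: jax.Array
    name: str = "model"


@jax.jit
def forward(model, x):
    # At trace time model.name is a plain str; only weights and bias are tracers.
    return model.weights @ x + model.bias


m = Model(jnp.ones((2, 2)), jnp.zeros(2))
y = forward(m, jnp.ones(2))
```

Because metadata is part of the treedef that jax.jit uses as a cache key, calling forward with a Model whose name differs triggers a retrace, while new array values of the same shape and dtype reuse the compiled function.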