dataclass

class brainstate.util.dataclass(cls, **kwargs)

Create a dataclass that works with JAX transformations.

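This decorator turns a class into a dataclass that is immutable by default and can be passed safely through JAX transformations such as jax.jit, jax.grad, and jax.vmap. The decorated class is registered as a JAX pytree node: its fields are traversed as pytree children, except those declared with field(pytree_node=False), which are kept as static metadata.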
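To make the pytree registration concrete, here is a minimal sketch of how a decorator of this kind can be assembled from dataclasses.dataclass and jax.tree_util.register_pytree_node. It is not brainstate's implementation: the name simple_pytree_dataclass and the "pytree_node" metadata key used to mark static fields are assumptions chosen for illustration.

```python
import dataclasses

import jax
import jax.numpy as jnp


def simple_pytree_dataclass(cls, **kwargs):
    """Hypothetical sketch: build a frozen dataclass and register it as a pytree."""
    kwargs.setdefault("frozen", True)  # frozen unless the caller overrides it
    data_cls = dataclasses.dataclass(**kwargs)(cls)

    # Assumed convention: metadata {"pytree_node": False} marks a static field.
    all_fields = dataclasses.fields(data_cls)
    data_names = [f.name for f in all_fields if f.metadata.get("pytree_node", True)]
    meta_names = [f.name for f in all_fields if not f.metadata.get("pytree_node", True)]

    def flatten(obj):
        # Children are traced/transformed by JAX; aux data becomes part of the treedef.
        children = tuple(getattr(obj, name) for name in data_names)
        aux_data = tuple(getattr(obj, name) for name in meta_names)
        return children, aux_data

    def unflatten(aux_data, children):
        values = dict(zip(data_names, children))
        values.update(zip(meta_names, aux_data))
        return data_cls(**values)

    jax.tree_util.register_pytree_node(data_cls, flatten, unflatten)
    return data_cls


@simple_pytree_dataclass
class Params:
    w: jax.Array
    b: jax.Array
    # Static field: stored in the treedef, not exposed as a leaf.
    tag: str = dataclasses.field(default="params", metadata={"pytree_node": False})


params = Params(w=jnp.ones(2), b=jnp.zeros(2))
leaves = jax.tree_util.tree_leaves(params)  # only w and b are leaves

# jit traces through the registered pytree and rebuilds a Params on output.
doubled = jax.jit(lambda p: jax.tree_util.tree_map(lambda x: 2 * x, p))(params)
```

Because static fields live in the auxiliary data, they are part of the tree structure: under jax.jit, changing tag would trigger a retrace rather than being passed as a traced value.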

Parameters:
  • cls (type[TypeVar(T)]) – The class to decorate.

  • **kwargs – Additional keyword arguments passed through to dataclasses.dataclass(). If ‘frozen’ is not specified, it defaults to True.
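In the documented signature dataclass(cls, **kwargs), cls has no default, so one way to pass extra options while keeping decorator syntax is functools.partial. The stdlib-only sketch below mimics the documented forwarding rule (frozen defaults to True, everything else goes to dataclasses.dataclass) with a hypothetical helper named frozen_by_default; whether brainstate's dataclass also accepts a bare @dataclass(frozen=False) form is not stated on this page.

```python
import dataclasses
import functools


def frozen_by_default(cls, **kwargs):
    """Hypothetical helper: forward kwargs to dataclasses.dataclass, defaulting frozen=True."""
    kwargs.setdefault("frozen", True)
    return dataclasses.dataclass(**kwargs)(cls)


@frozen_by_default
class Frozen:
    x: int


# cls is a required positional argument, so bind keyword arguments with partial.
@functools.partial(frozen_by_default, frozen=False, order=True)
class Mutable:
    x: int


m = Mutable(1)
m.x = 2  # allowed: frozen=False was passed explicitly

f = Frozen(1)
try:
    f.x = 2  # rejected: frozen defaulted to True
    frozen_blocked = False
except dataclasses.FrozenInstanceError:
    frozen_blocked = True
```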

Returns:

The decorated class as a JAX-compatible dataclass, immutable unless frozen=False is passed.

Return type:

type[TypeVar(T)]

See also

PyTreeNode

Base class for creating JAX-compatible pytree nodes.

field

Create dataclass fields with pytree metadata.

Notes

The decorated class is frozen (immutable) by default, in keeping with JAX’s functional programming model: rather than mutating fields in place, create modified copies with replace().

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> from brainstate.util import dataclass, field

>>> @dataclass
... class Model:
...     weights: jax.Array
...     bias: jax.Array
...     name: str = field(pytree_node=False, default="model")

>>> model = Model(weights=jnp.ones((3, 3)), bias=jnp.zeros(3))

>>> # JAX transformations will only apply to weights and bias, not name
>>> grad_fn = jax.grad(lambda m: jnp.sum(m.weights))
>>> grads = grad_fn(model)

>>> # Use replace to create modified copies
>>> model2 = model.replace(weights=jnp.ones((3, 3)) * 2)