dataclass
- class brainstate.util.dataclass(cls, **kwargs)
Create a dataclass that works with JAX transformations.
This decorator creates immutable dataclasses that can be passed safely through JAX transformations such as jit, grad, and vmap. The resulting class is registered as a JAX pytree node.
- Parameters:
- Returns:
The decorated class as an immutable JAX-compatible dataclass.
- Return type:
See also
PyTreeNode: Base class for creating JAX-compatible pytree nodes.
field: Create dataclass fields with pytree metadata.
Notes
The decorated class will be frozen (immutable) by default to ensure compatibility with JAX’s functional programming paradigm.
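What freezing means in practice can be illustrated with the standard library alone. The sketch below uses `dataclasses.dataclass(frozen=True)` and `dataclasses.replace` as stand-ins. It assumes, without confirmation from this page, that the brainstate decorator rejects attribute assignment in a comparable way. `Config` is a hypothetical class used only for illustration.

```python
import dataclasses

# Stand-in for a frozen, JAX-friendly dataclass (stdlib only, an assumption).
@dataclasses.dataclass(frozen=True)
class Config:
    scale: float
    name: str = "model"

cfg = Config(scale=1.0)

# In-place mutation is rejected on a frozen dataclass.
try:
    cfg.scale = 2.0
    mutated = True
except dataclasses.FrozenInstanceError:
    mutated = False

# The functional alternative: build a modified copy, leaving cfg untouched.
cfg2 = dataclasses.replace(cfg, scale=2.0)
```

Because the original instance never changes, a value captured by a traced or transformed function cannot be altered behind its back, which is the property the functional style relies on.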
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> from brainstate.util import dataclass, field
>>> @dataclass
... class Model:
...     weights: jax.Array
...     bias: jax.Array
...     name: str = field(pytree_node=False, default="model")
>>> model = Model(weights=jnp.ones((3, 3)), bias=jnp.zeros(3))
>>> # JAX transformations will only apply to weights and bias, not name
>>> grad_fn = jax.grad(lambda m: jnp.sum(m.weights))
>>> grads = grad_fn(model)
>>> # Use replace to create modified copies
>>> model2 = model.replace(weights=jnp.ones((3, 3)) * 2)
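The comment in the example, that transformations touch weights and bias but not name, rests on splitting an instance into transformable leaves and static metadata. The following is a conceptual sketch of that split in plain Python. It is not brainstate's or JAX's implementation: the `flatten` and `unflatten` helpers and the `"pytree_node"` metadata key are assumptions made for illustration, and plain lists stand in for arrays.

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class Model:
    weights: list
    bias: list
    # Hypothetical metadata key mirroring field(pytree_node=False).
    name: str = dataclasses.field(default="model", metadata={"pytree_node": False})

def flatten(obj):
    """Split an instance into (leaves, static) according to field metadata."""
    leaves, static = {}, {}
    for f in dataclasses.fields(obj):
        target = leaves if f.metadata.get("pytree_node", True) else static
        target[f.name] = getattr(obj, f.name)
    return leaves, static

def unflatten(cls, leaves, static):
    """Rebuild an instance from transformed leaves and untouched static data."""
    return cls(**leaves, **static)

model = Model(weights=[1.0, 1.0], bias=[0.0])
leaves, static = flatten(model)

# A "transformation" that only ever sees the leaves.
doubled = {k: [2 * x for x in v] for k, v in leaves.items()}
model2 = unflatten(Model, doubled, static)
```

Here `name` travels through unchanged as static data, while `weights` and `bias` are the only values the transformation operates on.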