PyTreeNode
class brainstate.util.PyTreeNode(*args, **kwargs)
Base class for creating JAX-compatible pytree nodes.
Subclasses of PyTreeNode are automatically converted to immutable dataclasses that can be passed through JAX transformations such as jax.jit, jax.grad, and jax.vmap.
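Because PyTreeNode subclasses are immutable, updates are made by building modified copies rather than by assigning to attributes. The snippet below is a minimal stdlib sketch of those semantics. It assumes PyTreeNode behaves like `dataclasses.dataclass(frozen=True)` and that its `replace` method mirrors `dataclasses.replace`; `Config` is a hypothetical class and no brainstate code is involved.

```python
import dataclasses

# Assumption: a plain frozen dataclass standing in for a PyTreeNode subclass.
# brainstate itself is not used here.
@dataclasses.dataclass(frozen=True)
class Config:
    learning_rate: float
    activation: str = "relu"

cfg = Config(learning_rate=0.1)

# Assigning to a field of a frozen dataclass raises FrozenInstanceError.
try:
    cfg.learning_rate = 0.2
    mutated = True
except dataclasses.FrozenInstanceError:
    mutated = False

# dataclasses.replace returns a new instance; the original is left unchanged.
cfg2 = dataclasses.replace(cfg, learning_rate=0.2)
```

Copy-on-update keeps instances safe to share between traced functions, since no caller can observe another caller mutating a field.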
Notes
When subclassing PyTreeNode, all fields are treated as pytree children, which JAX traces and transforms, unless they are explicitly marked with pytree_node=False using the field() function.
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> from brainstate.util import PyTreeNode, field
>>> class Layer(PyTreeNode):
...     weights: jax.Array
...     bias: jax.Array
...     activation: str = field(pytree_node=False, default="relu")
>>> layer = Layer(weights=jnp.ones((4, 4)), bias=jnp.zeros(4))
>>> # Can be used in JAX transformations
>>> def loss_fn(layer):
...     return jnp.sum(layer.weights ** 2)
>>> grad_fn = jax.grad(loss_fn)
>>> grads = grad_fn(layer)
>>> # Create modified copies with replace
>>> layer2 = layer.replace(bias=jnp.ones(4))
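To show what the pytree_node=False note in the Notes section means in practice, here is a hand-rolled sketch of the same idea using only `jax.tree_util.register_pytree_node`. It is an illustration under stated assumptions, not brainstate's implementation: the array fields are returned as pytree children (leaves that JAX traces), while `activation` is returned as auxiliary data stored in the tree structure. The names `Layer`, `_flatten`, `_unflatten`, and `forward` are hypothetical.

```python
import dataclasses

import jax
import jax.numpy as jnp

# Hypothetical hand-rolled analogue of a PyTreeNode subclass (not brainstate code):
# array fields become leaves, `activation` is static metadata in the treedef.
@dataclasses.dataclass(frozen=True)
class Layer:
    weights: jax.Array
    bias: jax.Array
    activation: str = "relu"

def _flatten(layer):
    # Return (children, aux_data); aux_data must be hashable.
    return (layer.weights, layer.bias), layer.activation

def _unflatten(activation, children):
    weights, bias = children
    return Layer(weights, bias, activation)

jax.tree_util.register_pytree_node(Layer, _flatten, _unflatten)

layer = Layer(weights=jnp.ones((2, 2)), bias=jnp.zeros(2), activation="tanh")
leaves = jax.tree_util.tree_leaves(layer)  # only the two arrays

@jax.jit
def forward(layer, x):
    y = x @ layer.weights + layer.bias
    # `activation` is a concrete Python string here, not a tracer,
    # so ordinary Python branching on it is allowed under jit.
    if layer.activation == "relu":
        return jax.nn.relu(y)
    return jnp.tanh(y)

out = forward(layer, jnp.ones(2))
```

Because `activation` stays a plain Python string under `jax.jit`, the `if` statement is evaluated at trace time. Passing a layer with a different activation produces a different tree structure, so JAX retraces `forward` instead of raising a tracer error.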