PyTreeNode

class brainstate.util.PyTreeNode(*args, **kwargs)

Base class for creating JAX-compatible pytree nodes.

Subclasses of PyTreeNode are automatically converted to immutable dataclasses that work with JAX transformations.

See also

dataclass

Decorator for creating JAX-compatible dataclasses.

field

Create dataclass fields with pytree metadata.

Notes

When subclassing PyTreeNode, every field is treated as part of the pytree by default. To exclude a field, such as the static activation name in the example below, declare it with field(pytree_node=False).
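
The split between pytree fields and excluded fields can be illustrated with a minimal standard-library sketch. It assumes the flag is recorded in each dataclass field's metadata under a `pytree_node` key, which this page does not confirm. The helpers `sketch_field` and `split_fields` are hypothetical and are not part of brainstate.

```python
import dataclasses


def sketch_field(pytree_node=True, **kwargs):
    # Hypothetical stand-in for field(): records the flag in the field metadata.
    return dataclasses.field(metadata={"pytree_node": pytree_node}, **kwargs)


def split_fields(obj):
    # Partition field names into pytree fields and excluded (static) fields.
    data, static = [], []
    for f in dataclasses.fields(obj):
        # Assumed default: a field without the flag is part of the pytree.
        is_node = f.metadata.get("pytree_node", True)
        (data if is_node else static).append(f.name)
    return data, static


@dataclasses.dataclass(frozen=True)
class Layer:
    weights: list
    bias: list
    activation: str = sketch_field(pytree_node=False, default="relu")


layer = Layer(weights=[1.0, 1.0], bias=[0.0])
data_names, static_names = split_fields(layer)
print(data_names)    # ['weights', 'bias']
print(static_names)  # ['activation']
```

Under this assumption, only `weights` and `bias` would be visible to JAX transformations, while `activation` travels along unchanged.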

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> from brainstate.util import PyTreeNode, field

>>> class Layer(PyTreeNode):
...     weights: jax.Array
...     bias: jax.Array
...     activation: str = field(pytree_node=False, default="relu")

>>> layer = Layer(weights=jnp.ones((4, 4)), bias=jnp.zeros(4))

>>> # Can be used in JAX transformations
>>> def loss_fn(layer):
...     return jnp.sum(layer.weights ** 2)
>>> grad_fn = jax.grad(loss_fn)
>>> grads = grad_fn(layer)

>>> # Create modified copies with replace
>>> layer2 = layer.replace(bias=jnp.ones(4))

replace(**updates)

Replace specified fields with new values.

Parameters:

**updates – Field names and their new values.

Returns:

A new instance with updated fields.

Return type:

TypeVar(TNode, bound=PyTreeNode)
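
The immutable update pattern can be sketched with the standard library alone. This assumes that `replace` behaves like `dataclasses.replace` applied to a frozen dataclass, which this page does not confirm. The `Config` class is hypothetical.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Config:
    rate: float
    name: str = "base"


cfg = Config(rate=0.1)
# Build a new instance with one field changed; cfg itself is left untouched.
cfg2 = dataclasses.replace(cfg, rate=0.5)

try:
    cfg.rate = 1.0  # in-place mutation is rejected on a frozen dataclass
    mutated = True
except dataclasses.FrozenInstanceError:
    mutated = False

print(cfg.rate, cfg2.rate, cfg2.name, mutated)  # 0.1 0.5 base False
```

Fields that are not named in the call keep their original values, so only the changed fields need to be passed.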