NodeStates

class brainstate.graph.NodeStates(_graphdef, states, metadata)

A JAX pytree wrapper that carries both a GraphDef and one or more state mappings.

Used by graph_to_tree / tree_to_graph to represent graph nodes as pure pytrees so that JAX transforms (vmap, jit, etc.) can operate on them.
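The general idea can be illustrated with a minimal sketch. This is not brainstate's implementation. The hypothetical class ToyNodeStates is registered as a JAX pytree: in this sketch the graph definition and metadata travel as static auxiliary data, and the tuple of state mappings forms the traced children. The field names loosely mirror the constructor signature above. The string used as a graph definition and the contents of the states are assumptions made for illustration.

```python
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class ToyNodeStates:
    """Hypothetical stand-in for NodeStates (illustration only)."""
    graphdef: Any         # static description of the graph (must be hashable)
    states: tuple         # tuple of state mappings; their arrays are the leaves
    metadata: Any = None  # static extra information, never traced

    def tree_flatten(self):
        # Children are traced by JAX; aux data is treated as static.
        children = (self.states,)
        aux_data = (self.graphdef, self.metadata)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        graphdef, metadata = aux_data
        (states,) = children
        return cls(graphdef=graphdef, states=states, metadata=metadata)


def scale(ns: ToyNodeStates) -> ToyNodeStates:
    # Double every array in every state mapping; graphdef/metadata pass through.
    return jax.tree_util.tree_map(lambda x: 2 * x, ns)


def row_sum(ns: ToyNodeStates) -> jax.Array:
    # Operates on a single (unbatched) node: sums one row of the weight.
    return ns.states[0]["weight"].sum()


ns = ToyNodeStates(
    graphdef="Linear(in=3, out=2)",  # hypothetical static description
    states=({"weight": jnp.ones((4, 3))}, {"count": jnp.zeros((4,))}),
)

doubled = jax.jit(scale)(ns)           # jit traces only the state arrays
batched_sums = jax.vmap(row_sum)(ns)   # vmap maps axis 0 of every leaf
print(doubled.graphdef, batched_sums)
```

Because the graph definition sits in the auxiliary data, jit treats it as part of the cache key: passing a ToyNodeStates with a different graphdef triggers a retrace, while new array values with the same shapes do not.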

replace(**updates) → T

Return an object of the same type with the specified fields replaced by new values.

Return type:

TypeVar(T)
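Assuming replace follows the usual dataclass convention, which is consistent with the return type T, its behavior can be sketched with the standard library's dataclasses.replace. The class ToyNodeStates and the field values below are hypothetical, and the source does not say whether brainstate implements the method this way.

```python
import dataclasses
from typing import Any, TypeVar

T = TypeVar("T", bound="ToyNodeStates")


@dataclasses.dataclass(frozen=True)
class ToyNodeStates:
    """Hypothetical stand-in with the same three fields as NodeStates."""
    _graphdef: Any
    states: tuple
    metadata: Any = None

    def replace(self: T, **updates: Any) -> T:
        # Build a new instance of the same type. Fields not named in
        # `updates` keep their current values; `self` is left unchanged.
        return dataclasses.replace(self, **updates)


original = ToyNodeStates(_graphdef="graphdef-A", states=({"w": 1.0},))
updated = original.replace(states=({"w": 2.0},), metadata={"step": 1})
print(original.states, updated.states, updated.metadata)
```

In this sketch only states and metadata change, while _graphdef carries over untouched and the original object keeps its old values.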