brainstate.graph module

Most of these APIs are adapted from Flax (google/flax). The module enables structure-preserving retrieval and manipulation of State objects in BrainState.

Graph Node

Node

Base class for all graph nodes in the BrainState framework.

Graph Operations

pop_states

Pop (remove and return) the States of one or more given types from a graph node.
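
As a rough illustration of what "popping" means for a graph node, the sketch below walks a toy node recursively, detaches every State of a given type, and returns the detached States keyed by attribute path. `State`, `ParamState`, `Node`, and this `pop_states` are hypothetical pure-Python stand-ins, not brainstate's classes or signature; the real function accepts filters and returns its own state container.

```python
from typing import Dict, Tuple, Type

class State:
    """Hypothetical stand-in for a BrainState State: a box around a value."""
    def __init__(self, value):
        self.value = value

class ParamState(State):
    """Hypothetical State subclass used as the type to pop."""

class Node:
    """Hypothetical graph node: attributes may hold States or child Nodes."""

def pop_states(root: Node, state_type: Type[State]) -> Dict[Tuple[str, ...], State]:
    """Detach every State of `state_type` from `root` (recursively), returned by path."""
    popped: Dict[Tuple[str, ...], State] = {}
    seen = set()  # ids of visited nodes, so shared or cyclic nodes are walked once

    def visit(node: Node, path: Tuple[str, ...]) -> None:
        if id(node) in seen:
            return
        seen.add(id(node))
        # Copy the items first because attributes are deleted while iterating.
        for name, value in list(vars(node).items()):
            if isinstance(value, state_type):
                popped[path + (name,)] = value
                delattr(node, name)  # the State is no longer attached to the node
            elif isinstance(value, Node):
                visit(value, path + (name,))

    visit(root, ())
    return popped

class Model(Node):
    def __init__(self):
        self.w = ParamState(1.0)
        self.cache = State(None)

m = Model()
popped = pop_states(m, ParamState)
print(list(popped))         # [('w',)]
print(hasattr(m, "w"))      # False
print(hasattr(m, "cache"))  # True: States of other types stay attached
```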

nodes

Return all graph nodes, optionally filtered and limited by hierarchy depth.
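
To make "limited by hierarchy depth" concrete, here is a minimal sketch that collects every reachable node keyed by its attribute path and stops descending past `max_depth`. It assumes the root sits at depth 0 and that depth equals path length; `Node`, `nodes`, and `max_depth` here are hypothetical and may not match brainstate's actual depth convention or parameters.

```python
from typing import Dict, Optional, Tuple

class Node:
    """Hypothetical graph node."""

def nodes(root: Node, max_depth: Optional[int] = None) -> Dict[Tuple[str, ...], Node]:
    """Return {path: node} for `root` and every reachable child Node.

    Assumption of this sketch: the root is depth 0 and depth == len(path).
    """
    found: Dict[Tuple[str, ...], Node] = {}
    seen = set()  # each node object is reported once, under its first path

    def visit(node: Node, path: Tuple[str, ...]) -> None:
        if max_depth is not None and len(path) > max_depth:
            return
        if id(node) in seen:
            return
        seen.add(id(node))
        found[path] = node
        for name, value in vars(node).items():
            if isinstance(value, Node):
                visit(value, path + (name,))

    visit(root, ())
    return found

class Leaf(Node):
    pass

class Block(Node):
    def __init__(self):
        self.leaf = Leaf()

class Net(Node):
    def __init__(self):
        self.block = Block()

net = Net()
print(list(nodes(net)))               # [(), ('block',), ('block', 'leaf')]
print(list(nodes(net, max_depth=1)))  # [(), ('block',)]
```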

states

Return all State objects from a graph node, optionally filtered.
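
The phrase "structure-preserving retrieval" can be pictured as collecting States together with the attribute path that leads to each one. The sketch below does that for a toy node tree and filters by State subclass. The classes and this `states` function are hypothetical stand-ins; brainstate's real filters and return container differ.

```python
from typing import Dict, Tuple, Type

class State:
    """Hypothetical stand-in for a BrainState State: a box around a value."""
    def __init__(self, value):
        self.value = value

class ParamState(State):
    """Hypothetical State subclass used as a filter below."""

class Node:
    """Hypothetical graph node: attributes may hold States or child Nodes."""

def states(root: Node, state_type: Type[State] = State) -> Dict[Tuple[str, ...], State]:
    """Collect States of `state_type` reachable from `root`, keyed by attribute path."""
    found: Dict[Tuple[str, ...], State] = {}
    seen = set()  # ids of visited nodes, so shared or cyclic nodes are walked once

    def visit(node: Node, path: Tuple[str, ...]) -> None:
        if id(node) in seen:
            return
        seen.add(id(node))
        for name, value in vars(node).items():
            if isinstance(value, state_type):
                found[path + (name,)] = value
            elif isinstance(value, Node):
                visit(value, path + (name,))

    visit(root, ())
    return found

class Linear(Node):
    def __init__(self):
        self.weight = ParamState([1.0, 2.0])
        self.count = State(0)

class Net(Node):
    def __init__(self):
        self.layer = Linear()
        self.hidden = State([0.0])

net = Net()
print(sorted(states(net)))            # [('hidden',), ('layer', 'count'), ('layer', 'weight')]
print(list(states(net, ParamState)))  # [('layer', 'weight')]
```

The returned objects are the live States, not copies, so reading `.value` through the mapping sees the node's current data.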

treefy_states

Return the treefy state mapping of a graph node, optionally filtered.

update_states

Update the graph node in place with the given state dict(s).
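
An in-place update writes new values into the States the node already owns instead of replacing them, so any other reference to those States observes the change. The sketch assumes a flat `{path: value}` dict as input; brainstate's `update_states` takes its own state mappings, and the classes and function below are hypothetical.

```python
from typing import Any, Dict, Tuple

class State:
    """Hypothetical stand-in for a BrainState State."""
    def __init__(self, value):
        self.value = value

class Node:
    """Hypothetical graph node."""

def update_states(root: Node, new_values: Dict[Tuple[str, ...], Any]) -> None:
    """Write new values into existing States addressed by attribute path, in place."""
    for path, value in new_values.items():
        owner = root
        for name in path[:-1]:
            owner = getattr(owner, name)        # walk down to the node owning the State
        getattr(owner, path[-1]).value = value  # mutate the State; its identity is kept

class Layer(Node):
    def __init__(self):
        self.w = State(0.0)

class Net(Node):
    def __init__(self):
        self.layer = Layer()

net = Net()
w_before = net.layer.w
update_states(net, {("layer", "w"): 3.5})
print(net.layer.w.value)        # 3.5
print(net.layer.w is w_before)  # True
```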

flatten

Flatten a graph node into a (graph_def, state_mapping) pair.

unflatten

Unflatten a (graph_def, state_mapping) pair back into a node.
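
The flatten/unflatten round trip separates static structure from state while keeping shared sub-nodes shared. The minimal sketch below assigns each node an index the first time it is seen and emits a `NodeRef` on later visits, which is the role the `NodeRef` and `NodeDef` classes play in this module. Every class and function here is a hypothetical simplification, not brainstate's actual data layout.

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple, Union

class State:
    """Hypothetical stand-in for a BrainState State."""
    def __init__(self, value):
        self.value = value

class Node:
    """Hypothetical graph node."""

@dataclass(frozen=True)
class NodeRef:
    index: int  # index of a node already emitted earlier in the traversal

@dataclass(frozen=True)
class NodeDef:
    cls: type
    index: int
    attributes: Tuple[Tuple[str, Any], ...]  # (name, NodeDef | NodeRef | "state" | ("static", v))

def flatten(root: Node):
    """Split `root` into a static structure and a flat {path: state value} mapping."""
    index_of: Dict[int, int] = {}  # id(node) -> index; shared nodes are emitted once
    state_map: Dict[Tuple[str, ...], Any] = {}

    def visit(node: Node, path: Tuple[str, ...]) -> Union[NodeDef, NodeRef]:
        if id(node) in index_of:
            return NodeRef(index_of[id(node)])
        idx = len(index_of)
        index_of[id(node)] = idx
        attrs = []
        for name, value in vars(node).items():
            if isinstance(value, Node):
                attrs.append((name, visit(value, path + (name,))))
            elif isinstance(value, State):
                state_map[path + (name,)] = value.value
                attrs.append((name, "state"))
            else:
                attrs.append((name, ("static", value)))
        return NodeDef(type(node), idx, tuple(attrs))

    return visit(root, ()), state_map

def unflatten(graph_def: NodeDef, state_map: Dict[Tuple[str, ...], Any]) -> Node:
    """Rebuild a node graph; each NodeRef resolves to the object built for its index."""
    built: Dict[int, Node] = {}

    def build(d, path: Tuple[str, ...]) -> Node:
        if isinstance(d, NodeRef):
            return built[d.index]
        node = object.__new__(d.cls)  # skip __init__; attributes are restored below
        built[d.index] = node         # register before children so cycles resolve
        for name, spec in d.attributes:
            if isinstance(spec, (NodeDef, NodeRef)):
                setattr(node, name, build(spec, path + (name,)))
            elif spec == "state":
                setattr(node, name, State(state_map[path + (name,)]))
            else:
                setattr(node, name, spec[1])
        return node

    return build(graph_def, ())

class Encoder(Node):
    def __init__(self):
        self.w = State(1.0)

class Model(Node):
    def __init__(self):
        shared = Encoder()
        self.a = shared
        self.b = shared  # the same object is reachable by two paths
        self.name = "demo"

graph_def, state_map = flatten(Model())
print(state_map)                # {('a', 'w'): 1.0}
print(graph_def.attributes[1])  # ('b', NodeRef(index=1))
new = unflatten(graph_def, state_map)
print(new.a is new.b)           # True: sharing survives the round trip
```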

treefy_split

Split a graph node into a GraphDef and one or more state NestedDicts.

treefy_merge

Reconstruct a node from its GraphDef and one or more state NestedDicts.

iter_leaf

Iterate over all leaf values (i.e., values that are not graph nodes) in the graph node.
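
A leaf iterator can be sketched as a generator that descends into child nodes and yields every other attribute value together with its path. The `Node` class and this `iter_leaf` are hypothetical stand-ins; the real function's yielded items may be shaped differently.

```python
from typing import Any, Iterator, Tuple

class Node:
    """Hypothetical graph node."""

def iter_leaf(root: Node) -> Iterator[Tuple[Tuple[str, ...], Any]]:
    """Yield (path, value) for every attribute value that is not itself a Node."""
    seen = set()  # visit each node object once, even if it is shared

    def walk(node: Node, path: Tuple[str, ...]):
        if id(node) in seen:
            return
        seen.add(id(node))
        for name, value in vars(node).items():
            if isinstance(value, Node):
                yield from walk(value, path + (name,))
            else:
                yield path + (name,), value

    yield from walk(root, ())

class Layer(Node):
    def __init__(self):
        self.w = 0.5
        self.act = "relu"

class Net(Node):
    def __init__(self):
        self.layer = Layer()
        self.scale = 2

print(list(iter_leaf(Net())))
# [(('layer', 'w'), 0.5), (('layer', 'act'), 'relu'), (('scale',), 2)]
```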

iter_node

Iterate over all graph nodes within the given node.

clone

Create a deep copy of the given graph node.
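
What a deep copy of a graph means can be illustrated with Python's standard `copy.deepcopy`, whose memo dictionary copies each shared object exactly once. This only demonstrates the semantics on hypothetical toy classes; it is not a claim about how brainstate's `clone` is implemented.

```python
import copy

class State:
    """Hypothetical stand-in for a BrainState State."""
    def __init__(self, value):
        self.value = value

class Node:
    """Hypothetical graph node."""

class Layer(Node):
    def __init__(self):
        self.w = State([1.0, 2.0])

class Net(Node):
    def __init__(self):
        shared = Layer()
        self.enc = shared
        self.dec = shared  # one sub-node reachable from two attributes

net = Net()
twin = copy.deepcopy(net)  # deepcopy's memo keeps shared objects shared in the copy
twin.enc.w.value[0] = 9.0
print(net.enc.w.value)      # [1.0, 2.0]: the original is untouched
print(twin.enc is twin.dec)  # True: sharing is preserved inside the copy
```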

graphdef

Return the GraphDef of the given graph node.

Context Management

Context managers for handling complex state updates during graph transformations. These utilities enable splitting and merging graph states in a thread-safe manner.

split_context

Context manager for splitting multiple graph nodes sharing a reference index.

merge_context

Context manager for merging multiple graph nodes sharing a reference index.
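
Why a shared reference index matters: if two split calls each used their own table, a node reachable from both would be encoded twice and merged back as two separate objects. The sketch below uses `contextlib.contextmanager` to scope one id-to-index table across all splits, and one index-to-object table across all merges. The tuple encoding, the leaf-only nodes, and the yielded `split`/`merge` functions are hypothetical simplifications; brainstate's `split_context` and `merge_context` expose their own interfaces.

```python
import contextlib
from typing import Callable, Dict, Iterator

class Node:
    """Hypothetical graph node (leaf-only in this sketch: no child Nodes)."""

@contextlib.contextmanager
def split_context() -> Iterator[Callable[[Node], tuple]]:
    """Yield a split function whose id->index table is shared by every call."""
    index_of: Dict[int, int] = {}

    def split(node: Node) -> tuple:
        if id(node) in index_of:
            return ("ref", index_of[id(node)])
        index_of[id(node)] = len(index_of)
        return ("new", index_of[id(node)], type(node), dict(vars(node)))

    yield split

@contextlib.contextmanager
def merge_context() -> Iterator[Callable[[tuple], Node]]:
    """Yield a merge function whose index->object table is shared by every call."""
    built: Dict[int, Node] = {}

    def merge(encoded: tuple) -> Node:
        if encoded[0] == "ref":
            return built[encoded[1]]
        _, idx, cls, attrs = encoded
        node = object.__new__(cls)
        node.__dict__.update(attrs)
        built[idx] = node
        return node

    yield merge

class Counter(Node):
    def __init__(self):
        self.count = 0

shared = Counter()
with split_context() as split:
    enc_a = split(shared)  # first sight: full definition at index 0
    enc_b = split(shared)  # second sight: only a reference to index 0
print(enc_b)               # ('ref', 0)

with merge_context() as merge:
    a = merge(enc_a)
    b = merge(enc_b)
print(a is b, a is shared)  # True False: one new object, shared as in the original
```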

Graph Conversion

Utilities for converting between graph and tree representations, enabling flexible manipulation of nested module structures.

graph_to_tree

Convert a pytree that may contain graph nodes into a pure pytree of NodeStates.

tree_to_graph

Convert a pytree of NodeStates back into graph nodes.

NodeStates

A JAX pytree wrapper that carries both a GraphDef and one or more state mappings.

Graph Definition Classes

Core classes for representing graph structure, node definitions, and references. These classes provide the foundation for graph operations and state management.

GraphDef

Base class representing the static graph structure of a node.

NodeDef

Graph structure of a node, containing all static information for reconstruction.

NodeRef

A reference to an already-seen node in the graph (used for shared or circular references).

RefMap

A mapping that uses object identity (id) as the hash key.
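
An identity-keyed mapping treats two equal but distinct objects as different keys, and it also accepts unhashable keys such as lists. A minimal sketch of the idea, built on `collections.abc.MutableMapping`, follows; the class name mirrors the entry above, but the internals are an assumption, not brainstate's code.

```python
from collections.abc import MutableMapping
from typing import Any, Dict, Iterator, Tuple

class RefMap(MutableMapping):
    """Sketch of an identity-keyed mapping: keys are compared by id(), not by ==/hash."""

    def __init__(self):
        self._data: Dict[int, Tuple[Any, Any]] = {}  # id(key) -> (key, value)

    def __getitem__(self, key):
        return self._data[id(key)][1]

    def __setitem__(self, key, value):
        # Store the key itself too, so it stays alive and its id cannot be reused.
        self._data[id(key)] = (key, value)

    def __delitem__(self, key):
        del self._data[id(key)]

    def __iter__(self) -> Iterator[Any]:
        return (key for key, _ in self._data.values())

    def __len__(self) -> int:
        return len(self._data)

a, b = [1, 2], [1, 2]  # equal but distinct (and unhashable) lists
refs = RefMap()
refs["x"] = 0
del refs["x"]
refs[a] = "first"
refs[b] = "second"
print(len(refs), refs[a], refs[b])  # 2 first second
print(a in refs, [1, 2] in refs)    # True False
```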

register_graph_node_type

Register a custom graph node type with the graph system.
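
Registration lets the graph system handle a type that does not subclass `Node` by telling it how to take instances apart and put them back together. The sketch below uses a simplified two-callback protocol (flatten to children plus static data, and unflatten back); `register_node_type`, `to_parts`, `from_parts`, and `Pair` are hypothetical, and the real `register_graph_node_type` signature is not reproduced here.

```python
from typing import Any, Callable, Dict, Tuple

# Hypothetical registry: type -> (flatten_fn, unflatten_fn)
_REGISTRY: Dict[type, Tuple[Callable, Callable]] = {}

def register_node_type(cls: type, flatten_fn: Callable, unflatten_fn: Callable) -> None:
    """Record how to split instances of `cls` and how to rebuild them."""
    _REGISTRY[cls] = (flatten_fn, unflatten_fn)

def to_parts(obj: Any) -> Tuple[type, Dict[str, Any], Any]:
    """Split a registered object into (type, children, static data)."""
    flatten_fn, _ = _REGISTRY[type(obj)]
    children, static = flatten_fn(obj)
    return type(obj), children, static

def from_parts(cls: type, children: Dict[str, Any], static: Any) -> Any:
    """Rebuild an object of a registered type from its parts."""
    _, unflatten_fn = _REGISTRY[cls]
    return unflatten_fn(children, static)

class Pair:
    """A third-party container that does not subclass any Node base class."""
    def __init__(self, left, right, label):
        self.left, self.right, self.label = left, right, label

register_node_type(
    Pair,
    flatten_fn=lambda p: ({"left": p.left, "right": p.right}, p.label),
    unflatten_fn=lambda children, label: Pair(children["left"], children["right"], label),
)

cls, children, static = to_parts(Pair(1, 2, "xy"))
print(children, static)  # {'left': 1, 'right': 2} xy
rebuilt = from_parts(cls, children, static)
print(rebuilt.left, rebuilt.right, rebuilt.label)  # 1 2 xy
```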