A JAX pytree wrapper that carries both a GraphDef and one or more state mappings.
Used by graph_to_tree / tree_to_graph to represent graph nodes as
pure pytrees so that JAX transforms (vmap, jit, etc.) can operate on them.
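The sketch below shows why such a wrapper lets jit and vmap treat graph nodes as ordinary pytrees. It uses only plain jax.tree_util. The class GraphStateWrapper and the function scale are hypothetical, and a string stands in for a real GraphDef. It is not meant to reproduce how Flax implements the wrapper. It shows only the general idea: the static graph definition travels as auxiliary (non-traced) data, and the state mappings become the traced children.

```python
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class GraphStateWrapper:
    """Hypothetical stand-in: a static graph definition plus state mappings."""

    graphdef: Any             # static structure; must be hashable (a str here)
    states: tuple[dict, ...]  # mappings whose array leaves JAX transforms see

    def tree_flatten(self):
        # States become pytree children (traced by jit/vmap);
        # graphdef travels as static auxiliary data.
        return (self.states,), self.graphdef

    @classmethod
    def tree_unflatten(cls, graphdef, children):
        (states,) = children
        return cls(graphdef, states)


def scale(wrapper: GraphStateWrapper) -> GraphStateWrapper:
    # Only the leaves inside `states` are traced; graphdef passes through unchanged.
    new_states = jax.tree_util.tree_map(lambda x: 2.0 * x, wrapper.states)
    return GraphStateWrapper(wrapper.graphdef, new_states)


w = GraphStateWrapper(
    "Linear(2 -> 1)",
    ({"kernel": jnp.ones((2, 1))}, {"count": jnp.array(3)}),
)
out = jax.jit(scale)(w)

# vmap maps over a leading batch axis of every state leaf.
batched = GraphStateWrapper(
    "Linear(2 -> 1)",
    ({"kernel": jnp.ones((4, 2, 1))}, {"count": jnp.arange(4)}),
)
vout = jax.vmap(scale)(batched)
```

Because graphdef is auxiliary data, JAX never traces it. Instead, it becomes part of the tree structure that jit uses as a cache key. That is why this sketch requires it to be hashable.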
replace(**updates) → T
Return a new instance with the specified fields replaced by new values; fields not named in updates keep their current values.
- Return type:
TypeVar(T)
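The sketch below assumes that replace follows the common frozen-dataclass pattern of delegating to dataclasses.replace. Carrier is a hypothetical class used only for illustration. The TypeVar(T) return type means the call returns an object of the same type as the one it was invoked on. In this sketch, the original object stays unchanged and an unknown field name raises TypeError.

```python
import dataclasses
from typing import Any, TypeVar

T = TypeVar("T", bound="Carrier")


@dataclasses.dataclass(frozen=True)
class Carrier:
    """Hypothetical frozen container with a graph definition and states."""

    graphdef: Any
    states: tuple

    def replace(self: T, **updates: Any) -> T:
        # Build a new instance; fields not named in `updates` are copied over.
        return dataclasses.replace(self, **updates)


c1 = Carrier(graphdef="Linear", states=({"w": 1.0},))
c2 = c1.replace(states=({"w": 2.0},))
```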