NestedDict
- class brainstate.util.NestedDict
A pytree-like nested mapping structure for organizing hierarchical data.
This class represents a nested mapping from strings or integers to leaves, where valid leaf types include State, jax.Array, numpy.ndarray, or nested NestedDict and FlattedDict structures. NestedDict is a JAX pytree and can be used with JAX transformations. It provides methods for flattening to FlattedDict, splitting/filtering based on predicates, and merging multiple nested structures.
Example
>>> import jax.numpy as jnp
>>> from brainstate.util import NestedDict
>>> state = NestedDict({
...     'layer1': {'weight': jnp.ones((3, 3)), 'bias': jnp.zeros(3)},
...     'layer2': {'weight': jnp.ones((3, 1))}
... })
>>> flat = state.to_flat()
>>> print(flat)
FlattedDict({('layer1', 'weight'): ..., ('layer1', 'bias'): ..., ...})
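Because NestedDict is a JAX pytree, pytree-aware functions such as jax.tree_util.tree_map act on its leaves while preserving the nesting. As a rough, dependency-free illustration of that leaf-wise behaviour, the sketch below walks plain nested dicts. `map_leaves` is a hypothetical helper, not part of brainstate, and plain floats stand in for arrays.

```python
from typing import Any, Callable, Mapping


def map_leaves(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Apply ``fn`` to every leaf of a nested mapping, keeping the structure.

    Hypothetical helper: a pure-Python stand-in for how a pytree-aware
    function such as jax.tree_util.tree_map treats a nested mapping.
    Any value that is not a Mapping counts as a leaf.
    """
    if isinstance(tree, Mapping):
        return {key: map_leaves(fn, value) for key, value in tree.items()}
    return fn(tree)


params = {
    'layer1': {'weight': 2.0, 'bias': 0.5},
    'layer2': {'weight': -1.0},
}
scaled = map_leaves(lambda x: x * 10, params)  # multiply every leaf by 10
print(scaled)
# {'layer1': {'weight': 20.0, 'bias': 5.0}, 'layer2': {'weight': -10.0}}
```

The structure (keys and nesting) is untouched; only the leaves change, which is what lets a nested state be passed through JAX transformations as a single argument.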
See also
FlattedDict: The flattened counterpart with tuple keys.
flat_mapping(): Function to flatten a nested mapping.
nest_mapping(): Function to unflatten a flat mapping.
- filter(*filters)
Filter a NestedDict into one or more NestedDicts. The user must pass at least one Filter (e.g. State). This method is similar to split(), except that the filters can be non-exhaustive.
- Parameters:
first – The first filter.
*filters (Filter, i.e. type | str | Callable[[Tuple[Key, ...], Any], bool] | bool | Any | None | Tuple[Filter, ...] | List[Filter]) – Optional additional filters used to group the state into mutually exclusive substates.
- Return type:
NestedDict | Tuple[NestedDict, ...]
- Returns:
One NestedDict per filter passed (a single NestedDict when only one filter is given).
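The practical difference from split() is what happens to entries that match no filter: here they are simply left out. A minimal sketch of non-exhaustive filtering over tuple-keyed paths, assuming each filter is either a class (matched with isinstance) or a predicate taking (path, value), mirroring the Callable[[Tuple[Key, ...], Any], bool] member of the filter type, and assuming each entry goes to the first filter it matches. `filter_flat`, `Param` and `Buffer` are hypothetical names, not brainstate APIs.

```python
from typing import Any, Callable, Dict, Tuple, Union

Path = Tuple[Any, ...]
Filter = Union[type, Callable[[Path, Any], bool]]


def _to_predicate(f: Filter) -> Callable[[Path, Any], bool]:
    # A class matches values that are instances of it; a callable is used as-is.
    if isinstance(f, type):
        return lambda path, value: isinstance(value, f)
    return f


def filter_flat(flat: Dict[Path, Any], first: Filter, *filters: Filter):
    """Group ``flat`` by filters, dropping unmatched entries (hypothetical helper)."""
    predicates = [_to_predicate(f) for f in (first, *filters)]
    groups = [{} for _ in predicates]
    for path, value in flat.items():
        for group, pred in zip(groups, predicates):
            if pred(path, value):
                group[path] = value  # first matching filter wins
                break
        # Entries that match no filter are dropped: filtering is non-exhaustive.
    return groups[0] if len(groups) == 1 else tuple(groups)


class Param:  # hypothetical stand-in for a parameter state type
    def __init__(self, value):
        self.value = value


class Buffer:  # hypothetical stand-in for a non-parameter state type
    def __init__(self, value):
        self.value = value


flat = {
    ('linear', 'weight'): Param(1.0),
    ('linear', 'bias'): Param(0.0),
    ('bn', 'mean'): Buffer(0.0),
}
params = filter_flat(flat, Param)
biases = filter_flat(flat, lambda path, value: path[-1] == 'bias')
print(sorted(params))  # [('linear', 'bias'), ('linear', 'weight')]
print(sorted(biases))  # [('linear', 'bias')]
```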
- classmethod from_flat(flat_dict)
Create a NestedDict from a flat mapping.
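A flat mapping here is one keyed by path tuples, like the FlattedDict printed in the class example. Rebuilding the nesting from such keys can be sketched in pure Python as follows. `unflatten` is a hypothetical helper that only approximates what from_flat (or nest_mapping()) does, and it assumes no path is a prefix of another path.

```python
from typing import Any, Dict, Hashable, Tuple


def unflatten(flat: Dict[Tuple[Hashable, ...], Any]) -> Dict[Hashable, Any]:
    """Rebuild a nested dict from tuple-keyed paths (hypothetical helper)."""
    nested: Dict[Hashable, Any] = {}
    for path, value in flat.items():
        node = nested
        for key in path[:-1]:
            # Create intermediate levels on demand.
            node = node.setdefault(key, {})
        node[path[-1]] = value  # the last key element holds the leaf
    return nested


flat = {('layer1', 'weight'): 1.0, ('layer1', 'bias'): 0.0, ('layer2', 'weight'): 2.0}
print(unflatten(flat))
# {'layer1': {'weight': 1.0, 'bias': 0.0}, 'layer2': {'weight': 2.0}}
```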
- static merge(*states)
The inverse of split(). merge takes one or more PrettyDict objects and creates a new PrettyDict.
- Parameters:
*states – Additional PrettyDict objects.
- Return type:
- Returns:
The merged PrettyDict.
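Since merge is described as the inverse of split(), a split-then-merge round trip should reproduce the original contents. The sketch below illustrates that on plain dicts with tuple keys. `merge_flat` is hypothetical, and letting later arguments win on duplicate paths is an assumption of the sketch, since the behaviour for overlapping inputs is not documented here.

```python
from typing import Any, Dict, Tuple

Flat = Dict[Tuple[Any, ...], Any]


def merge_flat(*states: Flat) -> Flat:
    """Combine flat states into a new dict (hypothetical helper).

    On duplicate paths the later state wins; this is an assumption.
    """
    merged: Flat = {}
    for state in states:
        merged.update(state)
    return merged


full = {('a',): 1, ('b', 'c'): 2, ('b', 'd'): 3}
# Split by a predicate on the first path element...
left = {path: v for path, v in full.items() if path[0] == 'a'}
right = {path: v for path, v in full.items() if path[0] != 'a'}
# ...then merge the disjoint parts back into the original mapping.
print(merge_flat(left, right) == full)  # True
```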
- replace_by_pure_dict(pure_dict, replace_fn=None)
Replace values in this NestedDict using a pure dictionary.
This method updates the values in this NestedDict with values from a standard Python dictionary. For State objects with a replace method, the replace method is called; otherwise, values are directly assigned.
- Parameters:
pure_dict (Dict[str, Any]) – A pure dictionary with matching structure containing new values.
replace_fn (Callable[[V, Any], V] | None) – Optional custom function to replace values. It takes (old_value, new_value) and returns the updated value. Defaults to calling the replace() method if available, otherwise direct assignment.
- Raises:
ValueError – If a key in pure_dict is not found in this NestedDict.
- Return type:
Example
>>> from brainstate.util import NestedDict
>>> from brainstate._state import State
>>> nested = NestedDict({'a': State(1), 'b': 2})
>>> nested.replace_by_pure_dict({'a': 10, 'b': 20})
>>> nested['a'].value
10
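The default replacement rule (call the leaf's replace() method when it has one, otherwise assign the new value) and the ValueError on unknown keys can be sketched on plain nested dicts. `Box` is a hypothetical stand-in for a State, and having replace() return a new Box rather than mutate in place is an assumption of this sketch, not a statement about brainstate's State.

```python
from typing import Any, Callable, Dict, Optional


class Box:
    """Hypothetical stand-in for a State that supports replace()."""

    def __init__(self, value: Any):
        self.value = value

    def replace(self, value: Any) -> 'Box':
        return Box(value)


def _default_replace(old: Any, new: Any) -> Any:
    # Use the leaf's own replace() when it has one, otherwise assign directly.
    # Note: str leaves also have .replace, so this sketch assumes non-string leaves.
    return old.replace(new) if hasattr(old, 'replace') else new


def replace_by_pure_dict(
    nested: Dict[str, Any],
    pure_dict: Dict[str, Any],
    replace_fn: Optional[Callable[[Any, Any], Any]] = None,
) -> None:
    """Update ``nested`` in place from ``pure_dict`` (hypothetical helper)."""
    fn = replace_fn if replace_fn is not None else _default_replace
    for key, new in pure_dict.items():
        if key not in nested:
            raise ValueError(f'key {key!r} is not in the nested dict')
        old = nested[key]
        if isinstance(old, dict) and isinstance(new, dict):
            replace_by_pure_dict(old, new, fn)  # recurse into sub-dicts
        else:
            nested[key] = fn(old, new)


nested = {'a': Box(1), 'b': 2, 'sub': {'c': 3}}
replace_by_pure_dict(nested, {'a': 10, 'b': 20, 'sub': {'c': 30}})
print(nested['a'].value, nested['b'], nested['sub']['c'])  # 10 20 30
```

Passing a custom replace_fn, for example lambda old, new: old + new, turns the same traversal into an accumulate operation.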
- split(*filters)
Split a NestedDict into one or more NestedDicts. The user must pass at least one Filter (e.g. State), and the filters must be exhaustive (i.e. they must cover all State types in the NestedDict).
Example usage:
>>> import brainstate
>>> class Model(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
...         self.linear = brainstate.nn.Linear(3, 3)
...     def __call__(self, x):
...         return self.linear(self.batchnorm(x))
>>> model = Model()
>>> state_map = brainstate.graph.treefy_states(model)
>>> param, others = state_map.split(brainstate.ParamState, ...)
- Parameters:
first – The first filter.
*filters (Filter, i.e. type | str | Callable[[Tuple[Key, ...], Any], bool] | bool | Any | None | Tuple[Filter, ...] | List[Filter]) – Optional additional filters used to group the state into mutually exclusive substates.
- Return type:
NestedDict | Tuple[NestedDict, ...]
- Returns:
One NestedDict per filter passed (a single NestedDict when only one filter is given).
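What sets split() apart from filter() is the exhaustiveness check. A minimal sketch on tuple-keyed dicts, assuming predicate filters with the (path, value) signature and first-match assignment. `split_flat` and `everything` are hypothetical names; `everything` stands in for the catch-all role that `...` is assumed to play in the example above.

```python
from typing import Any, Callable, Dict, Tuple

Path = Tuple[Any, ...]
Predicate = Callable[[Path, Any], bool]


def split_flat(flat: Dict[Path, Any], first: Predicate, *filters: Predicate):
    """Partition ``flat`` by predicates; every entry must match one (hypothetical)."""
    predicates = (first, *filters)
    groups = [{} for _ in predicates]
    rest = {}
    for path, value in flat.items():
        for group, pred in zip(groups, predicates):
            if pred(path, value):
                group[path] = value  # first matching filter wins
                break
        else:
            rest[path] = value  # matched no filter at all
    if rest:
        # Unlike filtering, splitting must be exhaustive.
        raise ValueError(f'filters are not exhaustive; unmatched: {list(rest)}')
    return groups[0] if len(groups) == 1 else tuple(groups)


def everything(path: Path, value: Any) -> bool:
    # Catch-all filter: accepts every remaining entry.
    return True


def is_param(path: Path, value: Any) -> bool:
    return value == 'param'


flat = {('linear', 'weight'): 'param', ('bn', 'mean'): 'buffer'}
params, others = split_flat(flat, is_param, everything)
print(params)  # {('linear', 'weight'): 'param'}
print(others)  # {('bn', 'mean'): 'buffer'}
```

Calling split_flat(flat, is_param) on its own raises ValueError, because ('bn', 'mean') is not covered by any filter.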
- to_flat()
Flatten the nested mapping into a flat mapping.
- Return type:
FlattedDict
- Returns:
The flattened mapping.
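Flattening amounts to turning every root-to-leaf path into a tuple key, as in the FlattedDict output of the class example. A pure-Python sketch on plain dicts follows. `flatten` is a hypothetical helper rather than brainstate's flat_mapping(), and dropping empty sub-mappings is a choice of this sketch.

```python
from typing import Any, Dict, Hashable, Mapping, Tuple


def flatten(
    nested: Mapping, prefix: Tuple[Hashable, ...] = ()
) -> Dict[Tuple[Hashable, ...], Any]:
    """Flatten a nested mapping into ``{path_tuple: leaf}`` (hypothetical helper)."""
    flat: Dict[Tuple[Hashable, ...], Any] = {}
    for key, value in nested.items():
        path = prefix + (key,)
        if isinstance(value, Mapping):
            # Descend into sub-mappings; an empty one contributes no entries.
            flat.update(flatten(value, path))
        else:
            flat[path] = value  # anything that is not a mapping is a leaf
    return flat


nested = {'layer1': {'weight': 1.0, 'bias': 0.0}, 'layer2': {'weight': 2.0}}
print(flatten(nested))
# {('layer1', 'weight'): 1.0, ('layer1', 'bias'): 0.0, ('layer2', 'weight'): 2.0}
```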
- to_pure_dict()
Convert to a pure nested dictionary structure.
This method creates a standard Python dictionary with the same nested structure as this NestedDict, without any special class wrappers.
Example
>>> from brainstate.util import NestedDict
>>> nested = NestedDict({'a': {'b': 1, 'c': 2}})
>>> pure = nested.to_pure_dict()
>>> type(pure)
<class 'dict'>
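The conversion is essentially a recursive copy in which every mapping container becomes a built-in dict. A sketch using OrderedDict as a stand-in for the NestedDict wrapper: `to_pure_dict` here is a hypothetical free function, and it assumes only the mapping containers need converting (the text above does not say whether leaves such as State objects are unwrapped).

```python
from collections import OrderedDict
from typing import Any, Dict, Mapping


def to_pure_dict(tree: Mapping) -> Dict[Any, Any]:
    """Recursively copy a nested mapping into built-in dicts (hypothetical helper)."""
    return {
        key: to_pure_dict(value) if isinstance(value, Mapping) else value
        for key, value in tree.items()
    }


wrapped = OrderedDict(a=OrderedDict(b=1, c=2))  # stand-in for a NestedDict
pure = to_pure_dict(wrapped)
print(type(pure), type(pure['a']))  # <class 'dict'> <class 'dict'>
```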