NestedDict

class brainstate.util.NestedDict

A pytree-like nested mapping structure for organizing hierarchical data.

This class represents a nested mapping from strings or integers to leaves, where valid leaf types include State, jax.Array, numpy.ndarray, or nested NestedDict and FlattedDict structures.

NestedDict is a JAX pytree and can be used with JAX transformations. It provides methods for flattening to FlattedDict, splitting/filtering based on predicates, and merging multiple nested structures.

__module__

Module identifier set to ‘brainstate.util’.

Type:

str

Example

>>> import jax.numpy as jnp
>>> from brainstate.util import NestedDict
>>> state = NestedDict({
...     'layer1': {'weight': jnp.ones((3, 3)), 'bias': jnp.zeros(3)},
...     'layer2': {'weight': jnp.ones((3, 1))}
... })
>>> flat = state.to_flat()
>>> print(flat)
FlattedDict({('layer1', 'weight'): ..., ('layer1', 'bias'): ..., ...})

See also

FlattedDict: The flattened counterpart with tuple keys.

flat_mapping(): Function to flatten a nested mapping.

nest_mapping(): Function to unflatten a flat mapping.

filter(*filters)

Filter a NestedDict into one or more NestedDict instances. The user must pass at least one Filter (e.g. a State type). This method is similar to split(), except that the filters need not be exhaustive: entries matched by no filter are left out of the results.

Parameters:

*filters – One or more filters selecting the entries to keep.

Return type:

NestedDict | Tuple[NestedDict, ...]

Returns:

One or more NestedDict instances, one for each filter passed.
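
The non-exhaustive behaviour is easiest to see on plain dictionaries. The sketch below is not brainstate's implementation; it is a pure-Python illustration in which filter_nested is a hypothetical helper, filters are assumed to be Python types matched with isinstance, and each leaf is assumed to go to the first filter that matches it. Leaves that match no filter are dropped.

```python
def filter_nested(tree, *filters):
    """Return one nested dict per filter; unmatched leaves are dropped.

    Hypothetical helper for illustration only (not a brainstate function).
    """
    if not filters:
        raise ValueError("at least one filter is required")
    results = tuple({} for _ in filters)

    def visit(node, outs):
        for key, value in node.items():
            if isinstance(value, dict):
                subs = tuple({} for _ in filters)
                visit(value, subs)
                for out, sub in zip(outs, subs):
                    if sub:  # keep only branches that received a leaf
                        out[key] = sub
            else:
                for out, flt in zip(outs, filters):
                    if isinstance(value, flt):
                        out[key] = value
                        break  # assumption: the first matching filter wins

    visit(tree, results)
    return results[0] if len(results) == 1 else results


tree = {'layer1': {'weight': 1.5, 'steps': 3}, 'name': 'net'}
floats, ints = filter_nested(tree, float, int)  # 'name' matches neither filter
only_str = filter_nested(tree, str)             # a single filter returns one dict
```

With a single filter the sketch returns one dictionary rather than a tuple, mirroring the NestedDict | Tuple[NestedDict, ...] return type documented above.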

classmethod from_flat(flat_dict)

Create a NestedDict from a flat mapping.

Parameters:

flat_dict (Mapping[Tuple[Key, ...], TypeVar(V)] | Iterable[tuple[Tuple[Key, ...], TypeVar(V)]]) – The flat mapping.

Return type:

NestedDict

Returns:

The NestedDict.
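
As an illustration of what un-flattening involves, the following pure-Python sketch rebuilds a nested dictionary from tuple-keyed entries. It accepts either a mapping or an iterable of (path, value) pairs, matching the two input forms in the signature above. The helper name nest is hypothetical, and the code is a sketch of the idea rather than the library's implementation.

```python
def nest(flat):
    """Rebuild a nested dict from {path_tuple: value} or (path_tuple, value) pairs.

    Hypothetical helper for illustration only; paths are assumed non-empty.
    """
    items = flat.items() if hasattr(flat, "items") else flat
    nested = {}
    for path, value in items:
        node = nested
        for key in path[:-1]:
            node = node.setdefault(key, {})  # create intermediate levels on demand
        node[path[-1]] = value
    return nested


from_mapping = nest({('a', 'b'): 1, ('a', 'c'): 2, ('d',): 3})
from_pairs = nest([(('x', 0), 'v')])  # integer keys work the same way
```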

static merge(*states)

The inverse of split().

merge takes one or more NestedDict objects and creates a new NestedDict.

Parameters:

*states – The NestedDict objects to merge.

Return type:

NestedDict

Returns:

The merged NestedDict.
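
The sketch below shows one way such a merge can work on plain dictionaries. It is not the library's code: merge_nested is a hypothetical helper, and the rule that a later input overrides an earlier one on conflicting keys is an assumption made for the illustration.

```python
def merge_nested(*trees):
    """Recursively combine nested dicts into a new one without mutating inputs.

    Hypothetical helper for illustration only (not a brainstate function).
    """
    if not trees:
        raise ValueError("at least one mapping is required")
    merged = {}

    def update(dst, src):
        for key, value in src.items():
            if isinstance(value, dict):
                child = dst.get(key)
                if not isinstance(child, dict):
                    child = dst[key] = {}  # start (or restart) a branch
                update(child, value)
            else:
                dst[key] = value  # assumption: later inputs win on conflicts

    for tree in trees:
        update(merged, tree)
    return merged


params = {'layer1': {'weight': 1.0}}
others = {'layer1': {'count': 2}, 'step': 0}
merged = merge_nested(params, others)
```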

replace_by_pure_dict(pure_dict, replace_fn=None)

Replace values in this NestedDict using a pure dictionary.

This method updates the values in this NestedDict with values from a standard Python dictionary. For State objects with a replace method, the replace method is called; otherwise, values are directly assigned.

Parameters:
  • pure_dict (Dict[str, Any]) – A pure dictionary with matching structure containing new values.

  • replace_fn (Callable[[TypeVar(V), Any], TypeVar(V)] | None) – Optional custom function to replace values. Takes (old_value, new_value) and returns the updated value. Defaults to calling replace() method if available, otherwise direct assignment.

Raises:

ValueError – If a key in pure_dict is not found in this NestedDict.

Return type:

None
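
The role of replace_fn can be illustrated without the library. In the pure-Python sketch below, Box is a hypothetical stand-in for a State-like object with a replace method, and replace_by_pure is a hypothetical helper that mimics the documented behaviour (in-place update, ValueError on unknown keys). It is a sketch under those assumptions, not the actual implementation.

```python
class Box:
    """Hypothetical stand-in for a State-like container."""

    def __init__(self, value):
        self.value = value

    def replace(self, value):
        return Box(value)


def replace_by_pure(nested, pure, replace_fn=None):
    """Update `nested` in place from the plain dict `pure` (illustrative only)."""
    if replace_fn is None:
        def replace_fn(old, new):
            # Default: use replace() on Box objects, otherwise assign directly.
            return old.replace(new) if isinstance(old, Box) else new
    for key, new in pure.items():
        if key not in nested:
            raise ValueError(f"key {key!r} not found")
        old = nested[key]
        if isinstance(old, dict) and isinstance(new, dict):
            replace_by_pure(old, new, replace_fn)  # descend into sub-dicts
        else:
            nested[key] = replace_fn(old, new)


nested = {'a': Box(1), 'b': {'c': 2}}
replace_by_pure(nested, {'a': 10, 'b': {'c': 20}})
after_default = (nested['a'].value, nested['b']['c'])

# A custom replace_fn receives (old_value, new_value); here it accumulates.
replace_by_pure(nested, {'b': {'c': 5}}, replace_fn=lambda old, new: old + new)
after_custom = nested['b']['c']
```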

Example

>>> from brainstate import State
>>> from brainstate.util import NestedDict
>>> nested = NestedDict({'a': State(1), 'b': 2})
>>> nested.replace_by_pure_dict({'a': 10, 'b': 20})
>>> nested['a'].value
10

split(*filters)

Split a NestedDict into one or more NestedDict instances. The user must pass at least one Filter (e.g. a State type), and the filters must be exhaustive (i.e. together they must cover all State types in the NestedDict).

Example usage:

>>> import brainstate

>>> class Model(brainstate.nn.Module):
...   def __init__(self):
...     super().__init__()
...     self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
...     self.linear = brainstate.nn.Linear(2, 3)
...   def __call__(self, x):
...     return self.linear(self.batchnorm(x))

>>> model = Model()
>>> state_map = brainstate.graph.treefy_states(model)
>>> param, others = state_map.split(brainstate.ParamState, ...)

Parameters:

*filters – One or more filters that together cover every entry.

Return type:

NestedDict | Tuple[NestedDict, ...]

Returns:

One or more NestedDict instances, one for each filter passed.
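
The exhaustiveness requirement can be sketched on a flat, tuple-keyed dictionary. The code below is an illustration rather than the library's implementation: split_flat is a hypothetical helper, filters are assumed to be Python types, and a literal ... (Ellipsis) is assumed to act as a catch-all for every remaining leaf, as the example call suggests.

```python
def split_flat(flat, *filters):
    """Partition a {path: leaf} mapping; every leaf must match some filter.

    Hypothetical helper for illustration only (not a brainstate function).
    """
    if not filters:
        raise ValueError("at least one filter is required")
    groups = tuple({} for _ in filters)
    for path, leaf in flat.items():
        for group, flt in zip(groups, filters):
            # Assumption: Ellipsis matches anything not claimed earlier.
            if flt is ... or isinstance(leaf, flt):
                group[path] = leaf
                break
        else:
            raise ValueError(f"no filter matched {path!r}; filters must be exhaustive")
    return groups[0] if len(groups) == 1 else groups


flat = {
    ('linear', 'weight'): 0.5,
    ('linear', 'bias'): 0.1,
    ('batchnorm', 'count'): 7,
}
floats, rest = split_flat(flat, float, ...)  # `...` collects the integer leaf
```

Calling split_flat(flat, float) without the catch-all raises ValueError in this sketch, because the integer leaf is covered by no filter.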

to_flat()

Flatten the nested mapping into a flat mapping.

Return type:

FlattedDict

Returns:

The flattened mapping.
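
To make the tuple-key layout concrete, here is a minimal pure-Python sketch of flattening. The helper name flatten is hypothetical and the code only illustrates the idea on plain dictionaries; it is not the library's implementation.

```python
def flatten(nested, prefix=()):
    """Flatten a nested dict into {path_tuple: leaf} (illustrative only)."""
    flat = {}
    for key, value in nested.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten(value, path))  # recurse with the extended path
        else:
            flat[path] = value
    return flat


flat = flatten({
    'layer1': {'weight': 'w1', 'bias': 'b1'},
    'layer2': {'weight': 'w2'},
})
paths = list(flat)  # paths appear in depth-first insertion order
```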

to_pure_dict()

Convert to a pure nested dictionary structure.

This method creates a standard Python dictionary with the same nested structure as this NestedDict, without any special class wrappers.

Returns:

A pure nested dictionary representation.

Return type:

Dict[str, Any]

Example

>>> from brainstate.util import NestedDict
>>> nested = NestedDict({'a': {'b': 1, 'c': 2}})
>>> pure = nested.to_pure_dict()
>>> type(pure)
<class 'dict'>
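
What "without any special class wrappers" means can be shown with a small stand-alone sketch. Here Wrapper is a hypothetical dict-like class standing in for NestedDict, and to_pure is a hypothetical helper; the code illustrates the recursive conversion and is not the library's implementation.

```python
from collections import UserDict
from collections.abc import Mapping


class Wrapper(UserDict):
    """Hypothetical dict-like wrapper standing in for NestedDict."""


def to_pure(node):
    """Recursively convert any Mapping into a built-in dict (illustrative only)."""
    if isinstance(node, Mapping):
        return {key: to_pure(value) for key, value in node.items()}
    return node  # leaves are returned unchanged


wrapped = Wrapper({'a': Wrapper({'b': 1, 'c': 2})})
pure = to_pure(wrapped)
```

Every level of the result is a built-in dict, so it can be passed to code that does not know about the wrapper class.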