FlattedDict

class brainstate.util.FlattedDict

A pytree-like structure that holds a mapping from paths (tuples of strings or integers) to leaves.

Valid leaf types are State, jax.Array, numpy.ndarray, and plain Python values.

A FlattedDict can be generated by calling either states() or nodes() on a Module.

Example usage:

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> class Model(brainstate.nn.Module):
...   def __init__(self):
...     super().__init__()
...     self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
...     self.linear = brainstate.nn.Linear([10, 3], [10, 4])
...   def __call__(self, x):
...     return self.linear(self.batchnorm(x))
>>>
>>> model = Model()

>>> # retrieve the states of the model
>>> model.states()  # equivalent to ``brainstate.graph.states()``
FlattedDict({
  ('batchnorm', 'running_mean'): LongTermState(
    value=Array([[0., 0., 0.]], dtype=float32)
  ),
  ('batchnorm', 'running_var'): LongTermState(
    value=Array([[1., 1., 1.]], dtype=float32)
  ),
  ('batchnorm', 'weight'): ParamState(
    value={'bias': Array([[0., 0., 0.]], dtype=float32), 'scale': Array([[1., 1., 1.]], dtype=float32)}
  ),
  ('linear', 'weight'): ParamState(
    value={'weight': Array([[-0.21467684,  0.7621282 , -0.50756454, -0.49047297],
           [-0.90413696,  0.6711    , -0.1254792 ,  0.50412565],
           [ 0.23975602,  0.47905368,  1.4851435 ,  0.16745673]],      dtype=float32), 'bias': Array([0., 0., 0., 0.], dtype=float32)}
  )
})

>>> # retrieve the nodes of the model
>>> model.nodes()  # equivalent to ``brainstate.graph.nodes()``
FlattedDict({
  ('batchnorm',): BatchNorm1d(
    in_size=(10, 3),
    out_size=(10, 3),
    affine=True,
    bias_initializer=Constant(value=0.0, dtype=<class 'numpy.float32'>),
    scale_initializer=Constant(value=1.0, dtype=<class 'numpy.float32'>),
    dtype=<class 'numpy.float32'>,
    track_running_stats=True,
    momentum=Array(shape=(), dtype=float32),
    epsilon=Array(shape=(), dtype=float32),
    feature_axis=(1,),
    axis_name=None,
    axis_index_groups=None,
    running_mean=LongTermState(
      value=Array(shape=(1, 3), dtype=float32)
    ),
    running_var=LongTermState(
      value=Array(shape=(1, 3), dtype=float32)
    ),
    weight=ParamState(
      value={'bias': Array(shape=(1, 3), dtype=float32), 'scale': Array(shape=(1, 3), dtype=float32)}
    )
  ),
  ('linear',): Linear(
    in_size=(10, 3),
    out_size=(10, 4),
    w_mask=None,
    weight=ParamState(
      value={'bias': Array(shape=(4,), dtype=float32), 'weight': Array(shape=(3, 4), dtype=float32)}
    )
  ),
  (): Model(
    batchnorm=BatchNorm1d(...),
    linear=Linear(...)
  )
})
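The tuple keys in the outputs above are paths from the root module down to each entry. A minimal sketch of how such paths can be built from an ordinary nested dict follows; the `flatten` helper is hypothetical and not part of brainstate. It treats every value that is not a Mapping (such as a State) as a leaf, and because it records only leaves it produces no entry for the root path `()` that nodes() includes.

```python
from collections.abc import Hashable, Mapping
from typing import Any, Dict, Tuple

# A path is a tuple of keys, e.g. ('batchnorm', 'running_mean').
Path = Tuple[Hashable, ...]


def flatten(nested: Mapping, prefix: Path = ()) -> Dict[Path, Any]:
    """Hypothetical helper: map each leaf of a nested dict to its tuple path."""
    flat: Dict[Path, Any] = {}
    for key, value in nested.items():
        path = prefix + (key,)
        if isinstance(value, Mapping):
            # Descend into the sub-mapping, extending the path.
            flat.update(flatten(value, path))
        else:
            # Anything that is not a Mapping is stored as a leaf.
            flat[path] = value
    return flat


tree = {"batchnorm": {"running_mean": 0.0, "running_var": 1.0}, "linear": {"weight": 2.0}}
flat_tree = flatten(tree)
```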
assign_dict_values(data)

Assign values from a dictionary to this FlattedDict.

This method updates the values in the FlattedDict with values from the provided dictionary. For keys that correspond to State objects, the value attribute of the State is updated. For other keys, the value in the FlattedDict is directly replaced with the new value.

Parameters:

data (Dict[Tuple[Key, ...], Any]) – A dictionary containing the values to assign, where keys must match those in the FlattedDict.

Raises:

KeyError – If a key in the FlattedDict is not present in the provided dictionary.

Return type:

None

Example

>>> import jax.numpy as jnp
>>> from brainstate._state import ParamState
>>> from brainstate.util import FlattedDict
>>> weight = ParamState(value=jnp.zeros((5, 5)))
>>> flat_dict = FlattedDict({('model', 'weight'): weight})
>>> flat_dict.assign_dict_values({('model', 'weight'): jnp.ones((5, 5))})
>>> # The ParamState's value is now an array of ones
>>> bool((weight.value == 1).all())
True
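The update rule described above (State entries are updated in place through their value attribute, other entries are replaced, and a missing key raises KeyError) can be modelled in plain Python. This is a sketch, not brainstate's implementation: `SimpleState` is a hypothetical stand-in for State, and `assign_dict_values` here is a free function operating on an ordinary dict.

```python
from typing import Any, Dict, Hashable, Tuple

Path = Tuple[Hashable, ...]


class SimpleState:
    """Hypothetical stand-in for brainstate's State: it just holds a value."""

    def __init__(self, value: Any) -> None:
        self.value = value


def assign_dict_values(flat: Dict[Path, Any], data: Dict[Path, Any]) -> None:
    """Sketch of the documented update rule, applied to a plain dict."""
    # Check every key up front so a failing call leaves ``flat`` untouched.
    for path in flat:
        if path not in data:
            raise KeyError(f"missing value for {path}")
    for path, old in list(flat.items()):
        if isinstance(old, SimpleState):
            # State entries keep their identity; only ``.value`` changes.
            old.value = data[path]
        else:
            # Non-State entries are replaced outright.
            flat[path] = data[path]
```

Checking all keys before writing is a choice of this sketch; the documentation above only states that a missing key raises KeyError.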
filter(first, /, *filters)

Filter a FlattedDict into one or more FlattedDicts. At least one Filter (for example, a State subclass) must be passed. This method is similar to split(), except that the filters may be non-exhaustive: entries that match no filter are left out of the result.

Parameters:

first – The first filter (for example, a State subclass).

*filters – Optional additional filters.

Return type:

FlattedDict | Tuple[FlattedDict, ...]

Returns:

One FlattedDict per filter passed, in the same order as the filters; a single FlattedDict (not a tuple) when only one filter is given.
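The non-exhaustive behavior of filter() can be illustrated with a plain-Python sketch. It assumes that filters are State subclasses matched with isinstance and that each entry goes to the first filter it matches; `SimpleState`, `Param`, `Running`, `Counter`, and `filter_flat` are hypothetical names, and brainstate's real filters may accept more forms than plain types.

```python
from typing import Any, Dict, Hashable, List, Tuple, Type, Union

Path = Tuple[Hashable, ...]
FlatDict = Dict[Path, Any]


class SimpleState:
    """Hypothetical stand-in for brainstate's State."""

    def __init__(self, value: Any) -> None:
        self.value = value


class Param(SimpleState):
    pass


class Running(SimpleState):
    pass


class Counter(SimpleState):
    pass


def filter_flat(
    flat: FlatDict, first: Type, *filters: Type
) -> Union[FlatDict, Tuple[FlatDict, ...]]:
    """Hypothetical helper: group entries by type, dropping unmatched ones."""
    all_filters = (first, *filters)
    groups: List[FlatDict] = [{} for _ in all_filters]
    for path, leaf in flat.items():
        for group, state_type in zip(groups, all_filters):
            if isinstance(leaf, state_type):
                group[path] = leaf  # the first matching filter claims the entry
                break
        # An entry that matches no filter is simply not copied anywhere.
    return groups[0] if len(groups) == 1 else tuple(groups)
```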

classmethod from_nest(nested_dict)

Create a FlattedDict from a flat mapping.

Parameters:

nested_dict (Mapping[Tuple[Key, ...], TypeVar(V)] | Iterable[tuple[Tuple[Key, ...], TypeVar(V)]]) – The flat mapping.

Return type:

FlattedDict

Returns:

The resulting FlattedDict.

static merge(*states)

The inverse of split().

merge() takes one or more FlattedDicts and combines their entries into a new FlattedDict.

Parameters:

*states (FlattedDict) – The FlattedDicts to merge.

Return type:

FlattedDict

Returns:

The merged FlattedDict.
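Because merge() is described as the inverse of split(), partitioning a flat dict and merging the parts should restore the original entries. The hypothetical `merge_flat` below sketches this over plain dicts. It assumes the inputs normally have disjoint keys and, as a choice of this sketch only, lets later arguments win on any overlap.

```python
from typing import Any, Dict, Hashable, Mapping, Tuple

Path = Tuple[Hashable, ...]


def merge_flat(*parts: Mapping[Path, Any]) -> Dict[Path, Any]:
    """Hypothetical helper: combine flat dicts into a new dict."""
    merged: Dict[Path, Any] = {}
    for part in parts:
        # Inputs are expected to be disjoint; on overlap, later parts win.
        merged.update(part)
    return merged


# Partition by leaf type, then merge the parts back together.
flat = {("linear", "weight"): 1.0, ("batchnorm", "running_mean"): 0.0, ("step",): 3}
ints = {p: v for p, v in flat.items() if isinstance(v, int)}
others = {p: v for p, v in flat.items() if not isinstance(v, int)}
restored = merge_flat(ints, others)
```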

split(first, /, *filters)

Split a FlattedDict into one or more FlattedDicts. At least one Filter (for example, a State subclass) must be passed, and the filters must be exhaustive, i.e. together they must cover every State type in the FlattedDict.

Parameters:

first – The first filter (for example, a State subclass).

*filters – Optional additional filters.

Return type:

FlattedDict | tuple[FlattedDict, ...]

Returns:

One FlattedDict per filter passed, in the same order as the filters; a single FlattedDict (not a tuple) when only one filter is given.
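The exhaustiveness requirement is what separates split() from filter(). The sketch below models it over plain dicts, assuming filters are State subclasses matched with isinstance and that each entry goes to the first filter it matches. `SimpleState`, `Param`, `Running`, and `split_flat` are hypothetical, and raising ValueError on an unmatched entry is this sketch's choice, not necessarily the exception brainstate uses.

```python
from typing import Any, Dict, Hashable, List, Tuple, Type, Union

Path = Tuple[Hashable, ...]
FlatDict = Dict[Path, Any]


class SimpleState:
    """Hypothetical stand-in for brainstate's State."""

    def __init__(self, value: Any) -> None:
        self.value = value


class Param(SimpleState):
    pass


class Running(SimpleState):
    pass


def split_flat(
    flat: FlatDict, first: Type, *filters: Type
) -> Union[FlatDict, Tuple[FlatDict, ...]]:
    """Hypothetical helper: partition ``flat`` by type, requiring full coverage."""
    all_filters = (first, *filters)
    groups: List[FlatDict] = [{} for _ in all_filters]
    for path, leaf in flat.items():
        for group, state_type in zip(groups, all_filters):
            if isinstance(leaf, state_type):
                group[path] = leaf  # the first matching filter claims the entry
                break
        else:
            # No filter matched this entry: the filters are not exhaustive.
            raise ValueError(f"no filter matches the entry at {path}")
    return groups[0] if len(groups) == 1 else tuple(groups)
```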

to_dict_values()

Convert a FlattedDict containing State objects to a dictionary of raw values.

This method extracts the underlying values from any State objects in the FlattedDict, creating a new dictionary with the same keys but where each State object is replaced by its value attribute. Non-State objects are kept as is.

Returns:

A dictionary with the same keys as the FlattedDict, but where each State object is replaced by its value. Non-State objects remain unchanged.

Return type:

Dict[Tuple[Key, ...], Any]

Example

>>> import jax.numpy as jnp
>>> from brainstate._state import ParamState
>>> from brainstate.util import FlattedDict
>>> flat_dict = FlattedDict({
...     ('model', 'layer1', 'weight'): ParamState(value=jnp.ones((10, 5)))
... })
>>> values = flat_dict.to_dict_values()
>>> values[('model', 'layer1', 'weight')].shape
(10, 5)
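The unwrapping rule can be sketched as a single dict comprehension. As before, `SimpleState` is a hypothetical stand-in for State and `to_dict_values` here is a free function over a plain dict, not brainstate's method. The result is a new dict, so the original entries, including the State objects themselves, are left in place.

```python
from typing import Any, Dict, Hashable, Mapping, Tuple

Path = Tuple[Hashable, ...]


class SimpleState:
    """Hypothetical stand-in for brainstate's State."""

    def __init__(self, value: Any) -> None:
        self.value = value


def to_dict_values(flat: Mapping[Path, Any]) -> Dict[Path, Any]:
    """Sketch: unwrap State-like leaves and pass every other leaf through."""
    return {
        path: leaf.value if isinstance(leaf, SimpleState) else leaf
        for path, leaf in flat.items()
    }
```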
to_nest()

Unflatten the flat mapping into a nested mapping.

Return type:

NestedDict

Returns:

The nested mapping.
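Unflattening walks each tuple path and creates intermediate dicts as needed. The hypothetical `unflatten` below sketches this over plain dicts; it is not brainstate's implementation. It assumes no path is a prefix of another and rejects the empty path, so a root entry such as the `()` key in the nodes() output would need separate handling.

```python
from typing import Any, Dict, Hashable, Mapping, Tuple

Path = Tuple[Hashable, ...]


def unflatten(flat: Mapping[Path, Any]) -> Dict[Hashable, Any]:
    """Hypothetical helper: rebuild nested dicts from tuple-path keys."""
    nested: Dict[Hashable, Any] = {}
    for path, value in flat.items():
        if not path:
            # A leaf needs at least one key to be stored under.
            raise ValueError("an empty path cannot be unflattened")
        node = nested
        for key in path[:-1]:
            # Create each intermediate level on first use.
            node = node.setdefault(key, {})
        node[path[-1]] = value
    return nested
```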