FlattedDict
- class brainstate.util.FlattedDict
A pytree-like structure that contains a Mapping from key paths (tuples of strings or integers) to leaves. Valid leaves are State objects, jax.Array, numpy.ndarray, or plain Python values. A FlattedDict can be generated by calling either states() or nodes() on a Module.

Example usage:
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> class Model(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
...         self.linear = brainstate.nn.Linear([10, 3], [10, 4])
...     def __call__(self, x):
...         return self.linear(self.batchnorm(x))
>>>
>>> model = Model()
>>> # retrieve the states of the model
>>> model.states()  # equivalent to ``brainstate.graph.states()``
FlattedDict({
  ('batchnorm', 'running_mean'): LongTermState(
    value=Array([[0., 0., 0.]], dtype=float32)
  ),
  ('batchnorm', 'running_var'): LongTermState(
    value=Array([[1., 1., 1.]], dtype=float32)
  ),
  ('batchnorm', 'weight'): ParamState(
    value={'bias': Array([[0., 0., 0.]], dtype=float32), 'scale': Array([[1., 1., 1.]], dtype=float32)}
  ),
  ('linear', 'weight'): ParamState(
    value={'weight': Array([[-0.21467684, 0.7621282 , -0.50756454, -0.49047297],
           [-0.90413696, 0.6711 , -0.1254792 , 0.50412565],
           [ 0.23975602, 0.47905368, 1.4851435 , 0.16745673]], dtype=float32), 'bias': Array([0., 0., 0., 0.], dtype=float32)}
  )
})
>>> # retrieve the nodes of the model
>>> model.nodes()  # equivalent to ``brainstate.graph.nodes()``
FlattedDict({
  ('batchnorm',): BatchNorm1d(
    in_size=(10, 3),
    out_size=(10, 3),
    affine=True,
    bias_initializer=Constant(value=0.0, dtype=<class 'numpy.float32'>),
    scale_initializer=Constant(value=1.0, dtype=<class 'numpy.float32'>),
    dtype=<class 'numpy.float32'>,
    track_running_stats=True,
    momentum=Array(shape=(), dtype=float32),
    epsilon=Array(shape=(), dtype=float32),
    feature_axis=(1,),
    axis_name=None,
    axis_index_groups=None,
    running_mean=LongTermState(
      value=Array(shape=(1, 3), dtype=float32)
    ),
    running_var=LongTermState(
      value=Array(shape=(1, 3), dtype=float32)
    ),
    weight=ParamState(
      value={'bias': Array(shape=(1, 3), dtype=float32), 'scale': Array(shape=(1, 3), dtype=float32)}
    )
  ),
  ('linear',): Linear(
    in_size=(10, 3),
    out_size=(10, 4),
    w_mask=None,
    weight=ParamState(
      value={'bias': Array(shape=(4,), dtype=float32), 'weight': Array(shape=(3, 4), dtype=float32)}
    )
  ),
  (): Model(
    batchnorm=BatchNorm1d(...),
    linear=Linear(...)
  )
})
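In the nodes() output above, each key is the chain of attribute names leading from the model to a submodule, and the empty tuple () stands for the model itself. The pure-Python sketch below mimics that key layout only; Node, collect_nodes, and Model are hypothetical stand-ins for illustration, not brainstate's own graph traversal.

```python
from typing import Dict, Tuple

Path = Tuple[str, ...]


class Node:
    """Hypothetical stand-in for a Module; child nodes are plain attributes."""


def collect_nodes(node: Node, path: Path = ()) -> Dict[Path, Node]:
    # Record this node under its path; the root is stored under the empty tuple ().
    found: Dict[Path, Node] = {path: node}
    for name, child in vars(node).items():
        if isinstance(child, Node):
            # Extend the path by the attribute name and recurse into the child.
            found.update(collect_nodes(child, path + (name,)))
    return found


class Model(Node):
    def __init__(self):
        self.batchnorm = Node()
        self.linear = Node()


paths = sorted(collect_nodes(Model()))
print(paths)  # [(), ('batchnorm',), ('linear',)]
```

Keying every entry by its full path is what keeps the mapping flat: no nesting is needed to tell ('batchnorm',) apart from ('linear',).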
- assign_dict_values(data)
Assign values from a dictionary to this FlattedDict.
This method updates the values in the FlattedDict with values from the provided dictionary. For keys that correspond to State objects, the value attribute of the State is updated. For other keys, the value in the FlattedDict is directly replaced with the new value.
- Parameters:
data (Dict[Tuple[Key, ...], Any]) – A dictionary containing the values to assign, where keys must match those in the FlattedDict.
- Raises:
KeyError – If a key in the FlattedDict is not present in the provided dictionary.
- Return type:
Example
>>> import jax.numpy as jnp
>>> from brainstate._state import ParamState
>>> from brainstate.util import FlattedDict
>>> flat_dict = FlattedDict({
...     ('model', 'weight'): ParamState(value=jnp.zeros((5, 5)))
... })
>>> flat_dict.assign_dict_values({('model', 'weight'): jnp.ones((5, 5))})
>>> # The ParamState's value is now an array of ones
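The two update rules described above (State objects are updated in place, other leaves are replaced) can be sketched with plain dicts. State and assign_dict_values below are hypothetical stand-ins, not brainstate's classes; this sketch checks keys one at a time, so a KeyError may leave earlier entries already updated, and whether brainstate validates all keys up front is not stated here.

```python
from typing import Any, Dict, Hashable, Tuple

Path = Tuple[Hashable, ...]


class State:
    """Hypothetical stand-in for brainstate's State: a mutable box around `value`."""

    def __init__(self, value: Any):
        self.value = value


def assign_dict_values(flat: Dict[Path, Any], data: Dict[Path, Any]) -> None:
    # Every key already in `flat` must be supplied by `data`.
    for key in list(flat):
        if key not in data:
            raise KeyError(key)
        if isinstance(flat[key], State):
            # Keep the same State object and only swap its contents.
            flat[key].value = data[key]
        else:
            # Plain leaves are replaced outright.
            flat[key] = data[key]


weight = State(0.0)
flat = {('model', 'weight'): weight, ('model', 'step'): 0}
assign_dict_values(flat, {('model', 'weight'): 1.0, ('model', 'step'): 5})
print(flat[('model', 'weight')] is weight, weight.value, flat[('model', 'step')])  # True 1.0 5
```

Updating a State in place, rather than replacing it, means any other code holding a reference to that same State object sees the new value.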
- filter(first, /, *filters)
Filter a FlattedDict into one or more FlattedDicts. The user must pass at least one Filter (e.g. a State type). This method is similar to split(), except the filters can be non-exhaustive.
- Parameters:
first (Filter) – The first filter.
*filters (Filter) – The optional, additional filters to group the state into mutually exclusive substates.
A Filter is any of type | str | Callable[[Tuple[Key, ...], Any], bool] | bool | Any | None, or a tuple or list of Filters.
- Return type:
- Returns:
One or more FlattedDicts, one per filter passed.
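Non-exhaustive filtering can be sketched in plain Python under two assumptions not stated above: only type and callable filters are handled (the str, bool, and nested tuple/list forms are omitted), and each entry goes to the first filter it matches. filter_flat and to_predicate are hypothetical names, and plain dicts stand in for FlattedDict.

```python
from typing import Any, Callable, Dict, Hashable, List, Tuple, Union

Path = Tuple[Hashable, ...]
Predicate = Callable[[Path, Any], bool]


def to_predicate(f: Union[type, Predicate]) -> Predicate:
    # A type matches values by isinstance; a callable is used as a (path, value) test.
    if isinstance(f, type):
        return lambda path, value: isinstance(value, f)
    return f


def filter_flat(flat: Dict[Path, Any], first, *filters) -> List[Dict[Path, Any]]:
    preds = [to_predicate(f) for f in (first, *filters)]
    groups: List[Dict[Path, Any]] = [{} for _ in preds]
    for path, value in flat.items():
        for group, pred in zip(groups, preds):
            if pred(path, value):
                group[path] = value  # the first matching filter claims the entry
                break
        # An entry that matches no filter is simply left out (non-exhaustive).
    return groups


flat = {('w',): 1.0, ('b',): 0.5, ('n',): 3, ('name',): 'lin'}
floats, ints = filter_flat(flat, float, int)
print(floats, ints)  # {('w',): 1.0, ('b',): 0.5} {('n',): 3}
```

The string entry ('name',) matches neither filter and is dropped, which is exactly what the non-exhaustive contract permits.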
- classmethod from_nest(nested_dict)
Create a FlattedDict from a nested mapping.
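The nested-to-flat conversion can be sketched in plain Python: each route from the root of the nested mapping to a leaf becomes one tuple key. from_nest below is a hypothetical stand-in that returns a plain dict, and it assumes every Mapping value is a subtree and everything else is a leaf.

```python
from typing import Any, Dict, Hashable, Mapping, Tuple

Path = Tuple[Hashable, ...]


def from_nest(nested: Mapping, prefix: Path = ()) -> Dict[Path, Any]:
    flat: Dict[Path, Any] = {}
    for key, value in nested.items():
        path = prefix + (key,)
        if isinstance(value, Mapping):
            # Descend into sub-mappings, extending the key path.
            flat.update(from_nest(value, path))
        else:
            # Anything that is not a mapping is treated as a leaf.
            flat[path] = value
    return flat


nested = {'linear': {'weight': 1.0, 'bias': 0.0}, 'step': 3}
print(from_nest(nested))
# {('linear', 'weight'): 1.0, ('linear', 'bias'): 0.0, ('step',): 3}
```

In this sketch, empty sub-mappings produce no entries, so they would not survive a round trip back to nested form.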
- static merge(*states)
The inverse of split(). merge takes one or more FlattedDicts and creates a new FlattedDict.
- Parameters:
*states (FlattedDict | NestedDict) – The PrettyDict objects to merge.
- Return type:
- Returns:
The merged FlattedDict.
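Since merge is described as the inverse of split(), merging the pieces produced by a split should give back the original entries. A plain-dict sketch of that idea follows; merge here is a hypothetical stand-in, and the duplicate-key rule (the later argument wins) is an assumption, because the documentation above does not say how collisions are handled.

```python
from typing import Any, Dict, Hashable, Tuple

Path = Tuple[Hashable, ...]


def merge(*states: Dict[Path, Any]) -> Dict[Path, Any]:
    # Build a fresh dict so none of the inputs is modified.
    merged: Dict[Path, Any] = {}
    for state in states:
        # Assumption: if two inputs share a key, the later one wins.
        merged.update(state)
    return merged


params = {('linear', 'weight'): 1.0}
stats = {('bn', 'running_mean'): 0.0}
print(merge(params, stats))
# {('linear', 'weight'): 1.0, ('bn', 'running_mean'): 0.0}
```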
- split(first, /, *filters)
Split a FlattedDict into one or more FlattedDicts. The user must pass at least one Filter (e.g. a State type), and the filters must be exhaustive (i.e. they must cover every State type in the FlattedDict).
- Parameters:
first (Filter) – The first filter. Accepts the same filter forms as filter().
*filters (Filter) – The optional, additional filters to group the state into mutually exclusive substates.
- Return type:
- Returns:
One or more FlattedDicts, one per filter passed.
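What distinguishes split from filter() is exhaustiveness: every entry must be claimed by some filter. The plain-Python sketch below assumes filters are (path, value) predicates, that each entry goes to the first filter it matches, and that an unmatched entry raises ValueError; split is a hypothetical stand-in, and the exception type is an assumption, not necessarily what brainstate raises.

```python
from typing import Any, Callable, Dict, Hashable, List, Tuple

Path = Tuple[Hashable, ...]
Predicate = Callable[[Path, Any], bool]


def split(flat: Dict[Path, Any], first: Predicate, *filters: Predicate) -> List[Dict[Path, Any]]:
    preds = (first, *filters)
    groups: List[Dict[Path, Any]] = [{} for _ in preds]
    for path, value in flat.items():
        for group, pred in zip(groups, preds):
            if pred(path, value):
                group[path] = value  # the first matching filter claims the entry
                break
        else:
            # No filter matched: the filters were not exhaustive.
            raise ValueError(f"no filter matches entry {path!r}")
    return groups


flat = {('linear', 'weight'): 1.0, ('bn', 'running_mean'): 0.0}
params, rest = split(
    flat,
    lambda path, value: path[-1] == 'weight',
    lambda path, value: True,  # catch-all filter makes the split exhaustive
)
print(params, rest)
# {('linear', 'weight'): 1.0} {('bn', 'running_mean'): 0.0}
```

Passing a catch-all predicate as the last filter, as above, is a simple way to make any split exhaustive.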
- to_dict_values()
Convert a FlattedDict containing State objects to a dictionary of raw values.
This method extracts the underlying values from any State objects in the FlattedDict, creating a new dictionary with the same keys but where each State object is replaced by its value attribute. Non-State objects are kept as is.
- Returns:
A dictionary with the same keys as the FlattedDict, but where each State object is replaced by its value. Non-State objects remain unchanged.
- Return type:
Example
>>> import jax.numpy as jnp
>>> from brainstate._state import ParamState
>>> from brainstate.util import FlattedDict
>>> flat_dict = FlattedDict({
...     ('model', 'layer1', 'weight'): ParamState(value=jnp.ones((10, 5)))
... })
>>> values = flat_dict.to_dict_values()
>>> values[('model', 'layer1', 'weight')]
Array([[1., 1., ...]], dtype=float32)
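The unwrapping rule can be sketched with plain dicts. State and to_dict_values below are hypothetical stand-ins for illustration, not brainstate's classes.

```python
from typing import Any, Dict, Hashable, Tuple

Path = Tuple[Hashable, ...]


class State:
    """Hypothetical stand-in for brainstate's State: a mutable box around `value`."""

    def __init__(self, value: Any):
        self.value = value


def to_dict_values(flat: Dict[Path, Any]) -> Dict[Path, Any]:
    # Unwrap State objects; pass every other leaf through unchanged.
    return {key: (leaf.value if isinstance(leaf, State) else leaf) for key, leaf in flat.items()}


flat = {('layer1', 'weight'): State([1.0, 1.0]), ('layer1', 'name'): 'dense'}
values = to_dict_values(flat)
print(values)  # {('layer1', 'weight'): [1.0, 1.0], ('layer1', 'name'): 'dense'}
```

Rebinding an entry in the returned dict does not affect the original State objects, because the sketch builds a new dict (the values themselves are not deep-copied). The documented assign_dict_values() method covers the reverse direction, writing raw values back into States.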