assign_state_values


brainstate.nn.assign_state_values(target, *state_by_abs_path)

Assign state values to a module from one or more state dictionaries.

This function updates the state values of a module from the provided state dictionaries. Keys should be absolute state paths (e.g., ‘layer1.weight’, ‘layer2.bias’). Keys that do not match are not treated as errors: the function collects the unmatched keys on both sides and returns them for inspection.

Parameters:
  • target (Module) – The target module whose states will be updated.

  • *state_by_abs_path (Mapping[str, Any]) – One or more state dictionaries with absolute path keys mapping to state values. If multiple dictionaries are provided, they will be merged (later dictionaries override earlier ones for duplicate keys).
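The merge rule can be illustrated with plain dictionaries. This is a sketch, not the library's code: merge_state_dicts is a hypothetical helper that assumes the merge behaves like successive dict.update calls, which matches the documented "later dictionaries override earlier ones" rule.

```python
from typing import Any, Mapping


def merge_state_dicts(*dicts: Mapping[str, Any]) -> dict[str, Any]:
    """Hypothetical helper: merge state dicts from left to right.

    Later mappings overwrite earlier ones on duplicate keys, mirroring
    the precedence rule documented for assign_state_values.
    """
    merged: dict[str, Any] = {}
    for d in dicts:
        merged.update(d)  # a later value replaces an earlier one
    return merged


base = {"layer1.weight": 1.0, "layer1.bias": 0.0}
override = {"layer1.bias": 0.5}
merged = merge_state_dicts(base, override)
print(merged)  # {'layer1.weight': 1.0, 'layer1.bias': 0.5}
```

Passing base and override in the opposite order would keep layer1.bias at 0.0, because the last dictionary wins.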

Returns:

A tuple of (unexpected_keys, missing_keys):

  • unexpected_keys: Keys present in the state dictionaries but not in the module

  • missing_keys: Keys present in the module but not in the state dictionaries

Return type:

tuple[list[str], list[str]]
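The two returned lists are set differences between the supplied keys and the module's own state keys. The sketch below shows that bookkeeping with plain sets. reconcile_keys is a hypothetical helper, and module_keys stands in for the keys of the target module's states; this is an assumption about how the keys are compared, not brainstate's actual code. The lists are sorted only to make the output deterministic; the documentation does not promise any particular order.

```python
from typing import Any, Iterable, Mapping


def reconcile_keys(
    module_keys: Iterable[str],
    supplied: Mapping[str, Any],
) -> tuple[list[str], list[str]]:
    """Hypothetical helper returning (unexpected_keys, missing_keys)."""
    module_set = set(module_keys)
    supplied_set = set(supplied)
    # Present in the supplied dictionaries but unknown to the module.
    unexpected = sorted(supplied_set - module_set)
    # Known to the module but absent from the supplied dictionaries.
    missing = sorted(module_set - supplied_set)
    return unexpected, missing


module_keys = ["layer1.weight", "layer1.bias", "layer2.weight"]
supplied = {"layer1.weight": 0.1, "layer1.bias": 0.0, "layer3.weight": 0.2}
unexpected, missing = reconcile_keys(module_keys, supplied)
print(unexpected)  # ['layer3.weight']
print(missing)     # ['layer2.weight']
```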

Examples

>>> import brainstate
>>>
>>> net = brainstate.nn.Linear(10, 20)
>>> brainstate.nn.init_all_states(net)
>>>
>>> # Save state values
>>> state_dict = {path: state.value for path, state in net.states().items()}
>>>
>>> # Later, restore state values
>>> unexpected, missing = brainstate.nn.assign_state_values(net, state_dict)
>>> print(f"Unexpected keys: {unexpected}")
>>> print(f"Missing keys: {missing}")

Notes

  • All values are automatically converted to JAX arrays using jax.numpy.asarray.

  • Only states with matching keys are updated; unexpected and missing keys are returned but do not cause errors.

  • If multiple dictionaries contain the same key, the last one takes precedence.
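Taken together, these notes describe an update step that touches only keys present on both sides. The sketch below models it under stated assumptions: Box is a hypothetical stand-in for a brainstate state object exposing a .value attribute, numpy.asarray stands in for jax.numpy.asarray so the example runs without JAX, and update_matching is a hypothetical helper rather than the library's implementation.

```python
from dataclasses import dataclass
from typing import Any, Mapping

import numpy as np


@dataclass
class Box:
    """Hypothetical stand-in for a state object with a mutable .value."""
    value: Any


def update_matching(states: Mapping[str, Box], values: Mapping[str, Any]) -> None:
    """Assign values only for keys present in both mappings; ignore the rest."""
    for key in states.keys() & values.keys():
        # numpy.asarray stands in for jax.numpy.asarray in this sketch.
        states[key].value = np.asarray(values[key])


states = {"w": Box(np.zeros(2)), "b": Box(np.zeros(1))}
update_matching(states, {"w": [1.0, 2.0], "extra": [9.0]})
print(states["w"].value)  # [1. 2.]
print(states["b"].value)  # [0.]
```

The unmatched key "extra" is skipped silently and "b" keeps its old value, which is why the real function reports such keys in its return value instead of raising.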