assign_state_values
- brainstate.nn.assign_state_values(target, *state_by_abs_path)
Assign state values to a module from one or more state dictionaries.
This function updates the state values of a module from the provided state dictionaries. The dictionaries must use absolute state paths as keys (e.g., 'layer1.weight', 'layer2.bias'). Mismatched keys do not raise errors. Instead, the unexpected and missing keys are returned for inspection.
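The bookkeeping for mismatched keys can be pictured as two set differences between the provided keys and the module's state paths. The sketch below is illustrative only: `split_keys` is a hypothetical helper, not a brainstate function, and the results are sorted here for deterministic output, not because the library is known to sort them.

```python
from typing import Iterable, List, Tuple


def split_keys(provided: Iterable[str], module_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: compute (unexpected_keys, missing_keys).

    unexpected_keys: in the provided dictionaries but not in the module.
    missing_keys: in the module but not in the provided dictionaries.
    Sorting only makes the output deterministic for this illustration.
    """
    provided_set = set(provided)
    module_set = set(module_keys)
    unexpected = sorted(provided_set - module_set)
    missing = sorted(module_set - provided_set)
    return unexpected, missing


module_keys = ['layer1.weight', 'layer1.bias', 'layer2.weight']
provided = ['layer1.weight', 'layer3.weight']
unexpected, missing = split_keys(provided, module_keys)
print(unexpected)  # ['layer3.weight']
print(missing)     # ['layer1.bias', 'layer2.weight']
```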
- Parameters:
target (Module) – The target module whose states will be updated.
*state_by_abs_path (Mapping[str, Any]) – One or more state dictionaries with absolute path keys mapping to state values. If multiple dictionaries are provided, they are merged, and later dictionaries override earlier ones for duplicate keys.
- Returns:
A tuple (unexpected_keys, missing_keys), where:
unexpected_keys: keys present in the state dictionaries but not in the module.
missing_keys: keys present in the module but not in the state dictionaries.
- Return type:
tuple
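When several dictionaries are passed through `*state_by_abs_path`, the documented precedence rule behaves like successive `dict.update` calls. The plain-Python sketch below illustrates that rule; `merge_state_dicts` is a hypothetical helper, not part of the brainstate API, and it assumes merging happens in argument order as described in the parameter list.

```python
from typing import Any, Dict, Mapping


def merge_state_dicts(*state_by_abs_path: Mapping[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: merge dictionaries left to right.

    Later dictionaries override earlier ones for duplicate keys,
    mirroring the precedence rule documented for assign_state_values.
    """
    merged: Dict[str, Any] = {}
    for mapping in state_by_abs_path:
        merged.update(mapping)  # a later value replaces an earlier one
    return merged


base = {'layer1.weight': 1.0, 'layer1.bias': 0.0}
override = {'layer1.bias': 0.5}
print(merge_state_dicts(base, override))
# {'layer1.weight': 1.0, 'layer1.bias': 0.5}
```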
Examples
>>> import brainstate
>>>
>>> net = brainstate.nn.Linear(10, 20)
>>> brainstate.nn.init_all_states(net)
>>>
>>> # Save state values
>>> state_dict = {path: state.value for path, state in net.states().items()}
>>>
>>> # Later, restore state values
>>> unexpected, missing = brainstate.nn.assign_state_values(net, state_dict)
>>> print(f"Unexpected keys: {unexpected}")
>>> print(f"Missing keys: {missing}")
Notes
All values are automatically converted to JAX arrays using jax.numpy.asarray.
Only states with matching keys are updated; unexpected and missing keys are returned but do not cause errors.
If multiple dictionaries contain the same key, the last one takes precedence.
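Taken together, these notes describe a simple procedure, which the sketch below mimics on a plain dictionary of states. `ToyState`, `toy_assign`, and the `convert` hook are hypothetical stand-ins (`convert` replaces `jax.numpy.asarray` so the code runs without JAX); this is an illustration of the documented behavior, not brainstate's implementation.

```python
from typing import Any, Callable, Dict, List, Mapping, Tuple


class ToyState:
    """Hypothetical stand-in for a brainstate State; it only holds `.value`."""

    def __init__(self, value: Any) -> None:
        self.value = value


def toy_assign(
    states: Dict[str, ToyState],
    *state_by_abs_path: Mapping[str, Any],
    convert: Callable[[Any], Any] = lambda v: v,
) -> Tuple[List[str], List[str]]:
    """Hypothetical sketch of the documented behavior on a dict of states.

    `convert` stands in for jax.numpy.asarray so the sketch runs without JAX.
    """
    merged: Dict[str, Any] = {}
    for mapping in state_by_abs_path:
        merged.update(mapping)  # the last dictionary wins on duplicate keys
    for key in merged.keys() & states.keys():
        states[key].value = convert(merged[key])  # only matching keys are written
    unexpected = sorted(merged.keys() - states.keys())
    missing = sorted(states.keys() - merged.keys())
    return unexpected, missing  # mismatches are reported, not raised


states = {'layer1.weight': ToyState(0.0), 'layer1.bias': ToyState(0.0)}
unexpected, missing = toy_assign(
    states,
    {'layer1.weight': 1, 'layer9.weight': 2},
    {'layer1.weight': 3},
    convert=float,
)
print(states['layer1.weight'].value)  # 3.0
print(states['layer1.bias'].value)    # 0.0 (untouched)
print(unexpected, missing)            # ['layer9.weight'] ['layer1.bias']
```

States whose keys appear in none of the dictionaries keep their previous values, which is why `layer1.bias` is reported as missing rather than reset.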