UniqueStateManager
- class braintools.optim.UniqueStateManager(pytree=None)
A class to manage unique State objects in a PyTree structure.
This class:
1. Flattens a PyTree with State leaves into (path, leaf) pairs.
2. Removes duplicate State objects based on their id().
3. Supports recovering the unique states back into a PyTree structure.
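The first two steps can be sketched in plain Python. This is a minimal illustration, not the braintools implementation: `State` is a hypothetical stand-in class, the tree is assumed to be nested dicts only, and paths are assumed to be tuples of string keys (the real class works on general JAX PyTrees and their key paths).

```python
from typing import Any, List, Tuple


class State:
    """Hypothetical stand-in for a brainstate State: it just wraps a value."""

    def __init__(self, value: Any):
        self.value = value


def flatten_with_path(tree: Any, prefix: Tuple[str, ...] = ()) -> List[Tuple[Tuple[str, ...], Any]]:
    """Depth-first flatten of nested dicts into (path, leaf) pairs."""
    if isinstance(tree, dict):
        pairs = []
        for key in sorted(tree):  # sorted keys, mirroring how JAX orders dict keys
            pairs.extend(flatten_with_path(tree[key], prefix + (key,)))
        return pairs
    return [(prefix, tree)]


def unique_by_id(pairs):
    """Keep only the first (path, leaf) pair for each distinct leaf object."""
    seen = set()
    unique = []
    for path, leaf in pairs:
        if id(leaf) not in seen:  # identity, not value equality
            seen.add(id(leaf))
            unique.append((path, leaf))
    return unique


s1, s2, s3 = State(1.0), State(2.0), State(3.0)
tree = {'layer1': {'weight': s1, 'bias': s2}, 'layer2': {'weight': s1, 'bias': s3}}
pairs = flatten_with_path(tree)   # 4 pairs; s1 appears twice
unique = unique_by_id(pairs)      # 3 pairs; s1 kept under ('layer1', 'weight')
```

Deduplicating by `id()` rather than by value matters here: two distinct parameters may hold equal arrays and still need separate updates, whereas one `State` reached through two paths (shared weights) is a single parameter.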
- Example usage:
>>> import jax.numpy as jnp
>>> from brainstate import ParamState
>>> from braintools.optim import UniqueStateManager
>>>
>>> # Create a pytree with some duplicate State objects
>>> state1 = ParamState(jnp.ones((2, 3)))
>>> state2 = ParamState(jnp.zeros((3, 4)))
>>>
>>> pytree = {
...     'layer1': {'weight': state1, 'bias': state2},
...     'layer2': {'weight': state1, 'bias': ParamState(jnp.ones((4,)))}  # state1 is a duplicate
... }
>>>
>>> # Create the manager and process the pytree
>>> manager = UniqueStateManager()
>>> unique_pytree = manager.make_unique(pytree)
>>>
>>> # The duplicate state1 in layer2.weight is removed
>>> print(len(manager.flattened_states))  # 3 instead of 4
>>>
>>> # Recover the pytree structure
>>> recovered = manager.to_pytree()
- get_state_by_path(target_path)
Retrieve a State object by its path.
- Parameters:
target_path (Any) – The path to the desired State.
- Return type:
State
- Returns:
The State object at the given path, or None if not found.
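Path lookup can be sketched as a scan over stored (path, state) pairs. The `State` stand-in, the tuple-of-keys path format, and a `get_state_by_path` helper that takes a plain list of pairs are assumptions for illustration, not the library's internals.

```python
from typing import Any, List, Optional, Tuple


class State:
    """Hypothetical stand-in for a brainstate State."""

    def __init__(self, value: Any):
        self.value = value


def get_state_by_path(pairs: List[Tuple[Any, State]], target_path: Any) -> Optional[State]:
    """Return the State stored under target_path, or None if no pair matches."""
    for path, state in pairs:
        if path == target_path:
            return state
    return None


w = State([1.0, 2.0])
pairs = [(('layer1', 'weight'), w)]
found = get_state_by_path(pairs, ('layer1', 'weight'))    # the State w
missing = get_state_by_path(pairs, ('layer2', 'weight'))  # None
```

In this sketch only the path of each State's first occurrence is stored, so a path that pointed at a removed duplicate also yields None. If paths are hashable, a dict keyed by path would make the lookup constant-time.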
- make_unique(pytree)
Process a PyTree with State leaves and remove duplicates.
- Parameters:
pytree (PyTree[State]) – A PyTree whose leaves are State objects.
- Return type:
PyTree[State]
- Returns:
A PyTree containing only unique State objects (duplicates removed).
- merge_with(other_pytree)
Merge another PyTree with the current unique states, maintaining uniqueness.
- Parameters:
other_pytree (PyTree[State]) – Another PyTree with State leaves to merge.
- Returns:
A merged PyTree containing all unique State objects.
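Merging while maintaining uniqueness amounts to appending only those incoming States whose id() has not been seen yet. A minimal sketch, assuming existing entries take precedence and paths are tuples; `merge_unique` and `State` are hypothetical names, and the sketch does not resolve two different States arriving under the same path.

```python
from typing import Any, List, Tuple


class State:
    """Hypothetical stand-in for a brainstate State."""

    def __init__(self, value: Any):
        self.value = value


def merge_unique(current: List[Tuple[Any, State]], other: List[Tuple[Any, State]]) -> List[Tuple[Any, State]]:
    """Return current + the pairs from other whose State object is new (by id)."""
    seen = {id(state) for _, state in current}
    merged = list(current)  # copy so the caller's list is not mutated
    for path, state in other:
        if id(state) not in seen:
            seen.add(id(state))
            merged.append((path, state))
    return merged


a, b, c = State(1.0), State(2.0), State(3.0)
current = [(('w',), a), (('b',), b)]
other = [(('w_shared',), a), (('c',), c)]  # a is already present under ('w',)
merged = merge_unique(current, other)       # paths: ('w',), ('b',), ('c',)
```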
- to_dict_value()
Convert the stored unique states to a dictionary with path strings as keys and State.value as values.
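One possible shape for this conversion: join each path into a string key and read `State.value`. The '.' separator, tuple paths, and the `State` stand-in are assumptions; the key strings braintools actually produces may be formatted differently.

```python
from typing import Any, Dict, List, Tuple


class State:
    """Hypothetical stand-in for a brainstate State."""

    def __init__(self, value: Any):
        self.value = value


def to_dict_value(pairs: List[Tuple[Tuple[Any, ...], State]], sep: str = '.') -> Dict[str, Any]:
    """Map each joined path string to the raw value held by its State."""
    return {sep.join(str(k) for k in path): state.value for path, state in pairs}


pairs = [(('layer1', 'weight'), State([1.0, 2.0])), (('layer1', 'bias'), State(0.0))]
values = to_dict_value(pairs)  # {'layer1.weight': [1.0, 2.0], 'layer1.bias': 0.0}
```

A flat mapping of this kind is convenient for logging or saving parameter values without the State wrappers.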
- to_pytree()
Convert the stored unique states back to a PyTree structure.
- Return type:
PyTree[State]
- Returns:
A PyTree containing the unique State objects.
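Rebuilding a nested structure from stored pairs can be sketched as follows, assuming nested-dict trees, non-empty tuple paths, and that paths of removed duplicates are simply absent from the result. `State` and `to_pytree` here are hypothetical stand-ins, not the library code.

```python
from typing import Any, Dict, List, Tuple


class State:
    """Hypothetical stand-in for a brainstate State."""

    def __init__(self, value: Any):
        self.value = value


def to_pytree(pairs: List[Tuple[Tuple[str, ...], State]]) -> Dict[str, Any]:
    """Rebuild a nested dict from (path, state) pairs; each path must be non-empty."""
    tree: Dict[str, Any] = {}
    for path, state in pairs:
        node = tree
        for key in path[:-1]:
            node = node.setdefault(key, {})  # create intermediate dicts on demand
        node[path[-1]] = state
    return tree


s1, s2, s3 = State(1.0), State(0.0), State(2.0)
# Unique pairs after deduplication: layer2.weight (a duplicate of s1) is gone.
pairs = [(('layer1', 'bias'), s2), (('layer1', 'weight'), s1), (('layer2', 'bias'), s3)]
rebuilt = to_pytree(pairs)
```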