UniqueStateManager

class braintools.optim.UniqueStateManager(pytree=None)

A class to manage unique State objects in a PyTree structure.

This class:

1. Flattens a PyTree with State leaves into (path, leaf) pairs.
2. Removes duplicate State objects based on their id().
3. Supports recovering the unique states back into a PyTree structure.
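The first two steps can be illustrated without brainstate. The following is a minimal pure-Python sketch, not braintools' implementation. It assumes the PyTree is a nested dict, uses a hypothetical `State` stand-in class, represents each path as a tuple of keys, visits dict keys in sorted order (as `jax.tree_util` does), and keeps the first occurrence of each object.

```python
from typing import Any, List, Tuple


class State:
    """Hypothetical stand-in for brainstate.State: it only wraps a value."""

    def __init__(self, value: Any):
        self.value = value


Path = Tuple[str, ...]


def flatten_with_path(tree: Any, prefix: Path = ()) -> List[Tuple[Path, Any]]:
    """Step 1: flatten a nested dict into (path, leaf) pairs, visiting keys in sorted order."""
    if isinstance(tree, dict):
        pairs: List[Tuple[Path, Any]] = []
        for key in sorted(tree):
            pairs.extend(flatten_with_path(tree[key], prefix + (key,)))
        return pairs
    return [(prefix, tree)]


def unique_by_id(pairs: List[Tuple[Path, Any]]) -> List[Tuple[Path, Any]]:
    """Step 2: keep only the first (path, state) pair for each distinct id(state)."""
    seen = set()
    unique = []
    for path, state in pairs:
        if id(state) not in seen:
            seen.add(id(state))
            unique.append((path, state))
    return unique


state1 = State([1.0, 1.0])
state2 = State([0.0])
pytree = {
    'layer1': {'weight': state1, 'bias': state2},
    'layer2': {'weight': state1, 'bias': State([1.0])},  # state1 appears twice
}

pairs = flatten_with_path(pytree)
unique = unique_by_id(pairs)
print(len(pairs), len(unique))       # 4 3
print([path for path, _ in unique])  # [('layer1', 'bias'), ('layer1', 'weight'), ('layer2', 'bias')]
```

Because uniqueness is decided by `id()`, two distinct State objects holding equal values are both kept; only repeated references to the same object are dropped.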

Example usage:
>>> import jax.numpy as jnp
>>> from brainstate import ParamState
>>> from braintools.optim import UniqueStateManager
>>>
>>> # Create a pytree in which one State object appears twice
>>> state1 = ParamState(jnp.ones((2, 3)))
>>> state2 = ParamState(jnp.zeros((3, 4)))
>>>
>>> pytree = {
...     'layer1': {'weight': state1, 'bias': state2},
...     'layer2': {'weight': state1, 'bias': ParamState(jnp.ones((4,)))}  # state1 is a duplicate
... }
>>>
>>> # Create a manager and process the pytree
>>> manager = UniqueStateManager()
>>> unique_pytree = manager.make_unique(pytree)
>>>
>>> # The second reference to state1 (layer2.weight) is removed:
>>> # 4 State leaves, but only 3 unique State objects
>>> print(manager.num_unique_states)
3
>>>
>>> # Recover the unique states as a pytree
>>> recovered = manager.to_pytree()

clear()

Clear all stored states and reset the manager.

get_flattened()

Get the flattened list of (path, state) pairs.

Return type:

List[Tuple[Any, State]]

Returns:

List of tuples containing (path, State) for unique states

get_state_by_path(target_path)

Retrieve a State object by its path.

Parameters:

target_path (Any) – The path to the desired State

Return type:

State

Returns:

The State object at the given path, or None if not found
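A lookup over the stored (path, state) pairs can be sketched as a linear search. This is an illustrative stand-in with a hypothetical `State` class and tuple paths, not the library's code. In the sketch, a path whose State was dropped as a duplicate is no longer stored, so looking it up returns None; whether braintools behaves identically is an assumption based on get_flattened() returning only the unique states.

```python
from typing import Any, List, Optional, Tuple


class State:
    """Hypothetical stand-in for brainstate.State."""

    def __init__(self, value: Any):
        self.value = value


def get_state_by_path(pairs: List[Tuple[Any, State]], target_path: Any) -> Optional[State]:
    """Return the State stored under target_path, or None if no pair matches."""
    for path, state in pairs:
        if path == target_path:
            return state
    return None


weight = State(1.0)
# Unique pairs after deduplication: layer2.weight referenced `weight` again and is absent
pairs = [
    (('layer1', 'bias'), State(0.0)),
    (('layer1', 'weight'), weight),
    (('layer2', 'bias'), State(2.0)),
]

print(get_state_by_path(pairs, ('layer1', 'weight')) is weight)  # True
print(get_state_by_path(pairs, ('layer2', 'weight')))            # None
```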

make_unique(pytree)

Process a PyTree with State leaves and remove duplicates.

Parameters:

pytree (PyTree[State]) – A PyTree whose leaves are State objects

Return type:

PyTree[State]

Returns:

A PyTree with only unique State objects (duplicates removed)

merge_with(other_pytree)

Merge another PyTree with the current unique states, maintaining uniqueness.

Parameters:

other_pytree (PyTree[State]) – Another PyTree with State leaves to merge

Return type:

UniqueStateManager

Returns:

A UniqueStateManager holding all unique State objects from both PyTrees; call to_pytree() on it to obtain the merged PyTree
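Merging while maintaining uniqueness amounts to appending only those incoming states whose id() has not been seen yet. The sketch below works on plain lists of (path, state) pairs rather than on a manager, uses a hypothetical `State` class, and keeps incoming paths unchanged; how braintools handles two different states that share the same path is not covered here.

```python
from typing import Any, List, Tuple


class State:
    """Hypothetical stand-in for brainstate.State."""

    def __init__(self, value: Any):
        self.value = value


Pairs = List[Tuple[Tuple[str, ...], State]]


def merge_unique(pairs: Pairs, other_pairs: Pairs) -> Pairs:
    """Return `pairs` followed by every pair from `other_pairs` whose State is new (by id)."""
    seen = {id(state) for _, state in pairs}
    merged = list(pairs)  # copy so the input list is left untouched
    for path, state in other_pairs:
        if id(state) not in seen:
            seen.add(id(state))
            merged.append((path, state))
    return merged


shared = State(1.0)
current = [(('encoder', 'weight'), shared)]
incoming = [(('decoder', 'weight'), shared), (('decoder', 'bias'), State(0.0))]

merged = merge_unique(current, incoming)
print([path for path, _ in merged])  # [('encoder', 'weight'), ('decoder', 'bias')]
```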

property num_unique_states: int

Get the number of unique State objects.

to_dict()

Convert the stored unique states to a dictionary with path strings as keys.

Return type:

Dict[str, State]

Returns:

Dictionary where keys are string representations of paths and values are State objects

to_dict_value()

Convert the stored unique states to a dictionary with path strings as keys and State.value as values.

Return type:

Dict[str, Any]

Returns:

Dictionary where keys are string representations of paths and values are State.value
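The two dictionary conversions differ only in what is stored under each key. In the sketch below, the key format (path elements joined with '.') is a hypothetical choice for illustration; the exact string representation braintools produces for a path is not specified here, and `State` is again a stand-in class.

```python
from typing import Any, Dict, List, Tuple


class State:
    """Hypothetical stand-in for brainstate.State."""

    def __init__(self, value: Any):
        self.value = value


Pairs = List[Tuple[Tuple[Any, ...], State]]


def path_to_str(path: Tuple[Any, ...]) -> str:
    """Hypothetical key format: join the path elements with '.'."""
    return '.'.join(str(key) for key in path)


def to_dict(pairs: Pairs) -> Dict[str, State]:
    """Map each path string to its State object."""
    return {path_to_str(path): state for path, state in pairs}


def to_dict_value(pairs: Pairs) -> Dict[str, Any]:
    """Map each path string to the State's .value."""
    return {path_to_str(path): state.value for path, state in pairs}


pairs = [(('layer1', 'weight'), State([1.0, 1.0])), (('layer1', 'bias'), State([0.0]))]
print(to_dict_value(pairs))  # {'layer1.weight': [1.0, 1.0], 'layer1.bias': [0.0]}
```

to_dict() keeps the State wrappers, so changes made through it reach the original objects; to_dict_value() returns only the stored values.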

to_pytree()

Convert the stored unique states back to a PyTree structure.

Return type:

PyTree[State]

Returns:

PyTree with unique State objects

to_pytree_value()

Convert the stored unique states to a PyTree with State.value as leaves.

Return type:

PyTree

Returns:

PyTree where leaves are the values (State.value) of the State objects
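Recovering a PyTree from the stored pairs means rebuilding the nesting that each path describes. This sketch is not braintools' implementation: it assumes nested-dict PyTrees with non-empty tuple-of-keys paths and uses a hypothetical `State` class. One helper covers both conversions; to_pytree() corresponds to keeping the State objects, and to_pytree_value() to storing State.value at each leaf.

```python
from typing import Any, Callable, Dict, List, Tuple


class State:
    """Hypothetical stand-in for brainstate.State."""

    def __init__(self, value: Any):
        self.value = value


def build_pytree(pairs: List[Tuple[Tuple[str, ...], State]],
                 leaf_fn: Callable[[State], Any] = lambda s: s) -> Dict[str, Any]:
    """Rebuild a nested dict from (path, state) pairs; assumes every path is non-empty."""
    tree: Dict[str, Any] = {}
    for path, state in pairs:
        node = tree
        for key in path[:-1]:
            node = node.setdefault(key, {})  # create intermediate dicts as needed
        node[path[-1]] = leaf_fn(state)
    return tree


# Unique pairs left after deduplication (layer2.weight was a duplicate)
pairs = [
    (('layer1', 'bias'), State(0.0)),
    (('layer1', 'weight'), State(1.0)),
    (('layer2', 'bias'), State(2.0)),
]

states_tree = build_pytree(pairs)                     # analogous to to_pytree()
values_tree = build_pytree(pairs, lambda s: s.value)  # analogous to to_pytree_value()
print(values_tree)  # {'layer1': {'bias': 0.0, 'weight': 1.0}, 'layer2': {'bias': 2.0}}
```

In this sketch the recovered tree has no layer2.weight entry, because that slot held a duplicate reference.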

update_state(target_path, new_state)

Update a State object at a specific path.

Parameters:
  • target_path (Any) – The path to the State to update

  • new_state (State) – The new State object

Return type:

bool

Returns:

True if update was successful, False if path not found
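update_state() can be pictured as the same path search as get_state_by_path(), followed by replacing the matching entry and reporting success as a bool. A minimal sketch with a hypothetical `State` class and a plain list of (path, state) pairs, not the library's code:

```python
from typing import Any, List, Tuple


class State:
    """Hypothetical stand-in for brainstate.State."""

    def __init__(self, value: Any):
        self.value = value


def update_state(pairs: List[Tuple[Any, State]], target_path: Any, new_state: State) -> bool:
    """Replace the State stored at target_path in place; return whether the path existed."""
    for i, (path, _) in enumerate(pairs):
        if path == target_path:
            pairs[i] = (path, new_state)
            return True
    return False


pairs = [(('layer1', 'weight'), State(1.0))]
replacement = State(5.0)

print(update_state(pairs, ('layer1', 'weight'), replacement))  # True
print(update_state(pairs, ('layer9', 'weight'), State(0.0)))   # False
print(pairs[0][1] is replacement)                              # True
```

The sketch swaps the object stored at the path; it does not modify the value of the State that was previously there.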