DictManager

class brainstate.util.DictManager(*args, **kwargs)

A dictionary for managing named collections of objects in BrainState.

DictManager extends the standard Python dict with methods for filtering, splitting, and otherwise managing collections of objects. It is registered as a JAX pytree node, so it can be passed through JAX transformations: its values are the pytree children and its keys are static auxiliary data.

Examples

>>> from brainstate.util import DictManager
>>> dm = DictManager({'a': 1, 'b': 2.0, 'c': 'text'})
>>> dm.subset(int)  # Get only integer values
DictManager({'a': 1})
>>> dm.unique()  # Get unique values only
DictManager({'a': 1, 'b': 2.0, 'c': 'text'})

add_unique_key(key, val)

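Add a key-value pair, ensuring that the key maps to a single value. Re-adding an existing key with the same value is allowed and leaves the mapping unchanged.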

Parameters:
  • key (TypeVar(K)) – The key to add.

  • val (TypeVar(V)) – The value to associate with the key.

Raises:

ValueError – If the key already exists with a different value.

Return type:

None
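
A minimal sketch of this contract on a plain dict follows. The function add_unique_key(d, key, val) is a hypothetical stand-in, not brainstate's implementation, and it assumes that "different value" means a different object (an identity check, the same notion unique() uses); the signature above does not say which comparison is used.

```python
def add_unique_key(d: dict, key, val) -> None:
    """Hypothetical stand-in for DictManager.add_unique_key, operating on a plain dict."""
    if key in d:
        # Assumption: "different value" is judged by identity, not equality.
        if d[key] is not val:
            raise ValueError(f"{key!r} is already registered with a different value")
    else:
        d[key] = val


registry = {}
weights = [0.1, 0.2]
add_unique_key(registry, "w", weights)
add_unique_key(registry, "w", weights)  # same object again: accepted, nothing changes
try:
    add_unique_key(registry, "w", [0.1, 0.2])  # equal list, but a different object
except ValueError as err:
    print(err)  # 'w' is already registered with a different value
```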

add_unique_value(key, val)

Add a new element only if the value is unique across all entries.

Parameters:
  • key (TypeVar(K)) – The key to add.

  • val (TypeVar(V)) – The value to associate with the key.

Returns:

True if the value was added (was unique), False otherwise.

Return type:

bool
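
A minimal sketch on a plain dict. The function add_unique_value(d, key, val) is a hypothetical stand-in, and checking uniqueness by identity is an assumption carried over from unique(), not something this signature documents.

```python
def add_unique_value(d: dict, key, val) -> bool:
    """Hypothetical stand-in for DictManager.add_unique_value, operating on a plain dict."""
    # Assumption: a value counts as already present if the same object is stored.
    if any(existing is val for existing in d.values()):
        return False
    d[key] = val
    return True


layers = {}
dense = object()
print(add_unique_value(layers, "dense", dense))        # True: first occurrence
print(add_unique_value(layers, "dense_alias", dense))  # False: same object already stored
print(sorted(layers))                                   # ['dense']
```

The linear scan keeps the sketch short; an implementation could instead cache the id() of every stored value to make the check constant-time.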

assign(*args, **kwargs)

Update the DictManager in place with the items of one or more dictionaries and any additional keyword arguments.

Parameters:
  • *args (Dict[TypeVar(K), TypeVar(V)]) – Dictionaries to merge into this one.

  • **kwargs (TypeVar(V)) – Additional key-value pairs to add.

Return type:

None
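
A minimal sketch of the merge on a plain dict. The function assign(d, *args, **kwargs) is hypothetical, and the precedence it implements (later dictionaries overwrite earlier ones, keyword arguments are applied last) is an assumption modeled on dict.update, not documented behavior.

```python
def assign(d: dict, *args: dict, **kwargs) -> None:
    """Hypothetical stand-in for DictManager.assign, mutating d in place."""
    for other in args:
        d.update(other)  # later dictionaries overwrite keys set by earlier ones
    d.update(kwargs)     # keyword arguments are applied last


params = {"lr": 0.1}
assign(params, {"lr": 0.01, "momentum": 0.9}, {"decay": 1e-4}, lr=0.001)
print(params)  # {'lr': 0.001, 'momentum': 0.9, 'decay': 0.0001}
```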

difference_by_keys(keys)

Get a new DictManager containing the items whose keys are not in keys.

Return type:

DictManager

difference_by_values(values, by='id')

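Get a new DictManager containing the items whose values are not in values. The by argument selects the comparison: ‘id’ (identity, the default) or ‘value’ (equality).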

Return type:

DictManager
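
The by argument matters when distinct objects compare equal. The sketch below is a hypothetical stand-in that returns a plain dict; it assumes by accepts the same ‘id’ and ‘value’ options documented for pop_by_values, and that any other value raises ValueError.

```python
def difference_by_values(d: dict, values, by: str = "id") -> dict:
    """Hypothetical stand-in for DictManager.difference_by_values (returns a plain dict)."""
    values = list(values)  # keep references alive so their id()s stay valid
    if by == "id":
        ids = {id(v) for v in values}
        return {k: v for k, v in d.items() if id(v) not in ids}
    if by == "value":
        return {k: v for k, v in d.items() if v not in values}
    # Assumption: unknown comparison modes are rejected.
    raise ValueError(f"by must be 'id' or 'value', got {by!r}")


shared = [1, 2]
d = {"a": shared, "b": [1, 2], "c": [3]}
print(difference_by_values(d, [shared]))              # {'b': [1, 2], 'c': [3]}
print(difference_by_values(d, [[1, 2]], by="value"))  # {'c': [3]}
```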

filter_by_predicate(predicate)

Filter items using a predicate function.

Parameters:

predicate (Callable[[TypeVar(K), TypeVar(V)], bool]) – Function called with each key and its value; items for which it returns True are kept.

Returns:

A new DictManager with filtered items.

Return type:

DictManager
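
Unlike subset(), whose callable receives only the value, the predicate here receives both the key and the value. A minimal sketch with a hypothetical stand-in function on a plain dict:

```python
from typing import Callable


def filter_by_predicate(d: dict, predicate: Callable[[object, object], bool]) -> dict:
    """Hypothetical stand-in for DictManager.filter_by_predicate (returns a plain dict)."""
    # The predicate sees the key and the value, so it can select on names as well as contents.
    return {k: v for k, v in d.items() if predicate(k, v)}


states = {"w_in": 0.5, "w_out": -1.0, "bias": 0.0}
print(filter_by_predicate(states, lambda k, v: k.startswith("w_") and v > 0))  # {'w_in': 0.5}
```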

intersection_by_keys(keys)

Get a new DictManager containing only the items whose keys are in keys.

Return type:

DictManager

intersection_by_values(values, by='id')

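Get a new DictManager containing only the items whose values are in values. The by argument selects the comparison: ‘id’ (identity, the default) or ‘value’ (equality).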

Return type:

DictManager
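
intersection_by_keys() and difference_by_keys() are complementary: for the same keys argument, the two results together contain every original item exactly once. The sketch below illustrates this with hypothetical stand-ins that return plain dicts; that keys absent from the dictionary are silently ignored is an assumption of the sketch.

```python
def intersection_by_keys(d: dict, keys) -> dict:
    """Hypothetical stand-in for DictManager.intersection_by_keys (returns a plain dict)."""
    keys = set(keys)
    return {k: v for k, v in d.items() if k in keys}


def difference_by_keys(d: dict, keys) -> dict:
    """Hypothetical stand-in for DictManager.difference_by_keys (returns a plain dict)."""
    keys = set(keys)
    return {k: v for k, v in d.items() if k not in keys}


d = {"a": 1, "b": 2, "c": 3}
selected = ["a", "c", "z"]  # "z" is not in d and is simply ignored in this sketch
inside = intersection_by_keys(d, selected)
outside = difference_by_keys(d, selected)
print(inside, outside)  # {'a': 1, 'c': 3} {'b': 2}
assert {**inside, **outside} == d  # together, the two results cover the original items
```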

map_keys(func)

Apply a function to all keys.

Parameters:

func (Callable[[TypeVar(K)], Any]) – Function to apply to each key.

Returns:

A new DictManager with transformed keys.

Return type:

DictManager

Raises:

ValueError – If the transformation creates duplicate keys.
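
The duplicate-key check guards against a function that maps two original keys to the same new key, which would otherwise silently drop an item. A minimal sketch with a hypothetical stand-in on a plain dict; the exact error message is an assumption.

```python
from typing import Any, Callable


def map_keys(d: dict, func: Callable[[Any], Any]) -> dict:
    """Hypothetical stand-in for DictManager.map_keys (returns a plain dict)."""
    result = {}
    for k, v in d.items():
        new_key = func(k)
        if new_key in result:
            # Two original keys collapsed onto one new key; overwriting would lose an item.
            raise ValueError(f"map_keys produced a duplicate key: {new_key!r}")
        result[new_key] = v
    return result


print(map_keys({"layer1.w": 1, "layer2.w": 2}, str.upper))  # {'LAYER1.W': 1, 'LAYER2.W': 2}
try:
    map_keys({"layer1.w": 1, "layer2.w": 2}, lambda k: k.split(".")[-1])  # both become "w"
except ValueError as err:
    print(err)  # map_keys produced a duplicate key: 'w'
```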

map_values(func)

Apply a function to all values.

Parameters:

func (Callable[[TypeVar(V)], Any]) – Function to apply to each value.

Returns:

A new DictManager with transformed values.

Return type:

DictManager

not_subset(sep)

Get a new DictManager excluding the items whose values are instances of the specified types. For type arguments, this is the complement of subset().

Parameters:

sep (Type | Tuple[Type, ...]) – Types to exclude from the result.

Returns:

A new DictManager excluding items of specified types.

Return type:

DictManager

pop_by_keys(keys)

Remove the items with the given keys from the DictManager in place.

Return type:

None

pop_by_values(values, by='id')

Remove, in place, every item whose value is in values.

Parameters:
  • values (Iterable[TypeVar(V)]) – Values to remove.

  • by (str) – Comparison method: ‘id’ (identity) or ‘value’ (equality).

Return type:

None
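
Unlike difference_by_values(), this method mutates the DictManager and returns None. A minimal in-place sketch on a plain dict; pop_by_values(d, values, by) is a hypothetical stand-in, and rejecting unknown by values with ValueError is an assumption.

```python
def pop_by_values(d: dict, values, by: str = "id") -> None:
    """Hypothetical stand-in for DictManager.pop_by_values, mutating d in place."""
    values = list(values)
    if by == "id":
        ids = {id(v) for v in values}
        doomed = [k for k, v in d.items() if id(v) in ids]
    elif by == "value":
        doomed = [k for k, v in d.items() if v in values]
    else:
        raise ValueError(f"by must be 'id' or 'value', got {by!r}")
    # Collect keys first: deleting from a dict while iterating over it raises RuntimeError.
    for k in doomed:
        del d[k]


cache = {"x": [0.0], "y": [0.0], "z": [1.0]}
pop_by_values(cache, [cache["x"]])         # identity: only "x" holds that exact list
print(cache)                               # {'y': [0.0], 'z': [1.0]}
pop_by_values(cache, [[0.0]], by="value")  # equality: removes "y"
print(cache)                               # {'z': [1.0]}
```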

split(*types)

Split the DictManager into several DictManagers according to the types of its values.

Parameters:

*types (Type) – Types to use for splitting. Each type gets its own DictManager.

Returns:

A tuple of DictManagers, one for each type plus one for unmatched items.

Return type:

Tuple[DictManager, ...]
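
A minimal sketch of the splitting rule on a plain dict, returning one group per type plus a remainder. The stand-in split(d, *types) is hypothetical, and sending a value that matches several types to the first matching one is an assumption; the documentation above does not cover that case.

```python
def split(d: dict, *types: type) -> tuple:
    """Hypothetical stand-in for DictManager.split (returns plain dicts)."""
    groups = tuple({} for _ in range(len(types) + 1))  # one per type, plus a remainder
    for k, v in d.items():
        for i, t in enumerate(types):
            # Assumption: a value matching several types goes to the first matching type.
            if isinstance(v, t):
                groups[i][k] = v
                break
        else:
            groups[-1][k] = v  # no type matched
    return groups


ints, floats, rest = split({"a": 1, "b": 2.0, "c": "text", "d": 3}, int, float)
print(ints, floats, rest)  # {'a': 1, 'd': 3} {'b': 2.0} {'c': 'text'}
```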

subset(sep)

Get a new DictManager containing the items whose values match a type or satisfy a predicate.

Parameters:

sep (Type | Tuple[Type, ...] | Callable[[Any], bool]) – If a type or tuple of types, select values that are instances of those types. If a callable, select values for which sep(value) returns True.

Returns:

A new DictManager containing only matching items.

Return type:

DictManager

Examples

>>> dm = DictManager({'a': 1, 'b': 2.0, 'c': 'text'})
>>> dm.subset(int)
DictManager({'a': 1})
>>> dm.subset(lambda x: isinstance(x, (int, float)))
DictManager({'a': 1, 'b': 2.0})

to_dict()

Convert to a standard Python dict.

Return type:

Dict[TypeVar(K), TypeVar(V)]

tree_flatten()

Flatten into JAX pytree form, returning the values as the pytree children and the keys as static auxiliary data.

Return type:

Tuple[Tuple[TypeVar(V), ...], Tuple[TypeVar(K), ...]]

classmethod tree_unflatten(keys, values)

Rebuild a DictManager from the keys (auxiliary data) and values (children) produced by tree_flatten().

Return type:

DictManager
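
The two methods follow the convention JAX expects from a pytree class: tree_flatten() returns (children, aux_data) and tree_unflatten() receives (aux_data, children). The hypothetical TinyDictManager below shows the round trip without importing JAX; with JAX installed, jax.tree_util.register_pytree_node_class is one standard way to register such a class.

```python
class TinyDictManager(dict):
    """Hypothetical stand-in showing the documented pytree flatten/unflatten contract."""

    def tree_flatten(self):
        # Children (the values) first, static auxiliary data (the keys) second.
        return tuple(self.values()), tuple(self.keys())

    @classmethod
    def tree_unflatten(cls, keys, values):
        # JAX passes the auxiliary data first, then the (possibly transformed) children.
        return cls(zip(keys, values))


dm = TinyDictManager(a=1.0, b=2.0)
values, keys = dm.tree_flatten()
# A JAX transformation would operate on the children; here they are doubled by hand.
rebuilt = TinyDictManager.tree_unflatten(keys, tuple(2 * v for v in values))
print(rebuilt)  # {'a': 2.0, 'b': 4.0}
```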

unique()

Get a new DictManager with unique values only.

If multiple keys map to the same value (by identity), only the first key-value pair is retained.

Returns:

A new DictManager with unique values.

Return type:

DictManager
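
Because the comparison is by identity, two equal but distinct objects are both kept. A minimal sketch with a hypothetical stand-in on a plain dict:

```python
def unique(d: dict) -> dict:
    """Hypothetical stand-in for DictManager.unique (returns a plain dict)."""
    seen_ids = set()
    result = {}
    for k, v in d.items():
        if id(v) not in seen_ids:  # identity, not equality
            seen_ids.add(id(v))
            result[k] = v          # the first key for each object wins
    return result


shared = [0.0]
d = {"w": shared, "w_alias": shared, "w_copy": [0.0]}
print(unique(d))  # {'w': [0.0], 'w_copy': [0.0]}
```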

unique_()

Remove duplicate values in place; this is the in-place counterpart of unique().

Returns:

Self, for method chaining.

Return type:

DictManager
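
A minimal sketch of an in-place variant that returns self. InPlaceDict is a hypothetical stand-in, and keeping the first key for each value, as unique() does, is an assumption about unique_().

```python
class InPlaceDict(dict):
    """Hypothetical stand-in showing an in-place unique_() that returns self."""

    def unique_(self):
        seen_ids = set()
        for k in list(self.keys()):  # snapshot the keys so items can be deleted while looping
            if id(self[k]) in seen_ids:
                del self[k]          # later duplicate: drop it, keep the first key
            else:
                seen_ids.add(id(self[k]))
        return self                  # returning self enables method chaining


shared = object()
dm = InPlaceDict(a=shared, b=shared, c=object())
same = dm.unique_()
print(sorted(dm), same is dm)  # ['a', 'c'] True
```

Returning self lets a call such as dm.unique_().keys() be written as a single expression.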