FrozenDict

class brainstate.util.FrozenDict(*args, **kwargs)

An immutable dictionary that works as a JAX pytree.

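FrozenDict provides a read-only mapping interface (item lookup, iteration, membership tests, get, keys, items, values) and, because it is a pytree, can be passed through JAX transformations such as jax.jit. Operations that would mutate an ordinary dict, such as adding, replacing, or removing entries, are exposed as methods (copy, pop) that return a new FrozenDict instead.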
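To illustrate what a read-only mapping that is also a JAX pytree looks like in practice, here is a minimal sketch. `MiniFrozenDict` is a hypothetical stand-in written for this page, not brainstate's implementation; it assumes registration through JAX's `register_pytree_with_keys_class`, and sorting the keys for a deterministic leaf order is an illustrative choice.

```python
from collections.abc import Mapping

import jax
import jax.numpy as jnp
from jax.tree_util import DictKey, register_pytree_with_keys_class


@register_pytree_with_keys_class
class MiniFrozenDict(Mapping):
    """Hypothetical stand-in: a read-only mapping registered as a JAX pytree."""

    def __init__(self, *args, **kwargs):
        self._data = dict(*args, **kwargs)

    # Read-only Mapping interface. There is no __setitem__, so
    # `d[key] = value` raises TypeError.
    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def tree_flatten_with_keys(self):
        # Sorted keys give a deterministic leaf order (an illustrative choice).
        keys = tuple(sorted(self._data))
        children = [(DictKey(k), self._data[k]) for k in keys]
        return children, keys

    @classmethod
    def tree_unflatten(cls, keys, values):
        return cls(zip(keys, values))


params = MiniFrozenDict(w=jnp.ones(3), b=jnp.zeros(3))


@jax.jit
def add_one(p):
    # jit flattens `p` into leaves, traces them, and rebuilds the container.
    return jax.tree_util.tree_map(lambda x: x + 1.0, p)


new_params = add_one(params)
print(type(new_params).__name__, new_params['w'], new_params['b'])
```

Because the transformed container is rebuilt rather than mutated, `params` still holds the original arrays after the call.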

Parameters:
  • *args – Positional arguments, interpreted as for the dict constructor.

  • **kwargs – Keyword arguments, interpreted as for the dict constructor.

_data

Internal immutable data storage.

Type:

dict

_hash

Cached hash value, or None if the hash has not been computed yet.

Type:

int or None

See also

freeze

Convert a mapping to a FrozenDict.

unfreeze

Convert a FrozenDict to a regular dict.

Notes

FrozenDict is immutable: operations that would modify a dictionary instead return a new FrozenDict instance with the changes applied, and the original instance is left unchanged.
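The copy-on-update behaviour described in this note can be sketched in plain Python. `FrozenMapSketch` is a hypothetical class, not brainstate's code; it assumes a `types.MappingProxyType` view over a private dict and mirrors the documented `copy(add_or_replace=None)` and `pop(key)` signatures so the semantics can be compared.

```python
from collections.abc import Mapping
from types import MappingProxyType


class FrozenMapSketch(Mapping):
    """Hypothetical immutable mapping: every "update" builds a new instance."""

    def __init__(self, *args, **kwargs):
        # Read-only proxy over a private copy of the input.
        self._data = MappingProxyType(dict(*args, **kwargs))

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def copy(self, add_or_replace=None):
        # Merge into a fresh dict; self is never touched.
        return type(self)({**self._data, **(add_or_replace or {})})

    def pop(self, key):
        # Raises KeyError for a missing key, like dict.pop without a default.
        value = self._data[key]
        rest = {k: v for k, v in self._data.items() if k != key}
        return type(self)(rest), value


base = FrozenMapSketch(a=1, b=2)
updated = base.copy({'b': 3, 'c': 4})  # new instance; base is untouched
smaller, removed = base.pop('a')       # new instance plus the removed value
print(dict(base), dict(updated), dict(smaller), removed)
```

Returning `(new_mapping, value)` from `pop` is what lets an immutable container offer a pop-like operation at all, since there is no way to remove the entry in place.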

Examples

>>> from brainstate.util import FrozenDict

>>> # Create a FrozenDict
>>> fd = FrozenDict({'a': 1, 'b': 2})
>>> fd['a']
1

>>> # Copy with updates (returns new FrozenDict)
>>> new_fd = fd.copy({'c': 3})
>>> new_fd['c']
3

>>> # Pop an item (returns a new FrozenDict and the popped value)
>>> new_fd, value = fd.pop('b')
>>> value
2
>>> 'b' in new_fd
False

>>> # Nested dictionaries are automatically frozen
>>> fd = FrozenDict({'x': {'y': 1}})
>>> isinstance(fd['x'], FrozenDict)
True

copy(add_or_replace=None)

Create a new FrozenDict with added or replaced entries.

Parameters:

add_or_replace (Mapping[TypeVar(K), TypeVar(V)] | None) – Entries to add to, or overwrite in, the new dictionary. The original FrozenDict is not modified.

Returns:

A new FrozenDict with the updates applied.

Return type:

FrozenDict[TypeVar(K), TypeVar(V)]

Examples

>>> fd = FrozenDict({'a': 1, 'b': 2})
>>> fd2 = fd.copy({'b': 3, 'c': 4})
>>> fd2['b'], fd2['c']
(3, 4)

get(key, default=None)

Return the value for key if key is present, otherwise default.

Parameters:
  • key (TypeVar(K)) – The key to look up.

  • default (TypeVar(V) | None) – The default value to return if key is not found.

Returns:

The value associated with the key, or default.

Return type:

TypeVar(V) | None

items()

Return a view of the (key, value) pairs.

Returns:

A view object of the dictionary’s items.

Return type:

ItemsView[TypeVar(K), TypeVar(V)]

keys()

Return a view of the keys.

Returns:

A view object of the dictionary’s keys.

Return type:

KeysView[TypeVar(K)]

pop(key)

Create a new FrozenDict with one entry removed.

Parameters:

key (TypeVar(K)) – The key to remove.

Returns:

A tuple of (new FrozenDict without the key, removed value).

Return type:

tuple[FrozenDict[TypeVar(K), TypeVar(V)], TypeVar(V)]

Raises:

KeyError – If the key is not found in the dictionary.

Examples

>>> fd = FrozenDict({'a': 1, 'b': 2})
>>> fd2, value = fd.pop('a')
>>> value
1
>>> 'a' in fd2
False

pretty_repr(indent=2)

Return a pretty-printed representation.

Parameters:

indent (int) – Number of spaces per indentation level (default 2).

Returns:

A formatted string representation of the FrozenDict.

Return type:

str
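The exact layout produced by pretty_repr is not specified on this page. As an illustration of how an `indent` parameter typically controls nested output, the hypothetical `pretty_format` helper below prints one entry per line and indents each nesting level by `indent` spaces; brainstate's actual format may differ.

```python
from collections.abc import Mapping


def pretty_format(mapping, indent=2, _level=0):
    """Hypothetical pretty-printer: one entry per line, nesting indented."""
    if not mapping:
        return '{}'
    pad = ' ' * (indent * (_level + 1))  # indentation for this level's entries
    closing_pad = ' ' * (indent * _level)  # indentation for the closing brace
    lines = ['{']
    for key, value in mapping.items():
        if isinstance(value, Mapping):
            # Nested mappings are rendered one level deeper.
            rendered = pretty_format(value, indent, _level + 1)
        else:
            rendered = repr(value)
        lines.append(f'{pad}{key!r}: {rendered},')
    lines.append(f'{closing_pad}}}')
    return '\n'.join(lines)


text = pretty_format({'a': 1, 'b': {'c': 2}}, indent=4)
print(text)
```

With indent=4 each nesting level sits four spaces deeper than its parent, and a nested mapping's closing brace lines up with the key that opened it.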

tree_flatten_with_keys()

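Flatten the FrozenDict for JAX's keyed pytree protocol, returning a list of (key, child) pairs together with auxiliary data (the keys) that tree_unflatten uses to rebuild the dictionary.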

Return type:

tuple[list[tuple[Any, Any]], tuple[Any, ...]]

classmethod tree_unflatten(keys, values)

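Rebuild a FrozenDict from the keys (the auxiliary data produced by tree_flatten_with_keys) and the corresponding leaf values. JAX calls this when reconstructing the pytree after flattening.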

Return type:

FrozenDict
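How tree_flatten_with_keys and tree_unflatten cooperate can be seen with JAX's generic tree utilities. `KeyedFrozen` is a hypothetical minimal container written for this sketch, assuming registration through `jax.tree_util.register_pytree_with_keys_class`; as in the `tree_unflatten(keys, values)` signature above, the keys serve as the auxiliary data.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import DictKey, keystr, register_pytree_with_keys_class


@register_pytree_with_keys_class
class KeyedFrozen:
    """Hypothetical minimal container implementing JAX's keyed-pytree protocol."""

    def __init__(self, data):
        self._data = dict(data)

    def __getitem__(self, key):
        return self._data[key]

    def tree_flatten_with_keys(self):
        keys = tuple(sorted(self._data))
        # Pair every child with a DictKey so JAX can report its path.
        children = [(DictKey(k), self._data[k]) for k in keys]
        return children, keys  # the keys double as auxiliary data

    @classmethod
    def tree_unflatten(cls, keys, values):
        # `values` arrive in the same order as the keys emitted above.
        return cls(zip(keys, values))


tree = KeyedFrozen({'w': jnp.ones(2), 'sub': KeyedFrozen({'b': 3.0})})

# Keyed flattening: each leaf comes with the path of keys leading to it.
path_leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
paths = [keystr(path) for path, _ in path_leaves]
print(paths)

# Round trip: flatten to leaves, transform them, and rebuild the container.
leaves, treedef = jax.tree_util.tree_flatten(tree)
rebuilt = jax.tree_util.tree_unflatten(treedef, [leaf * 2 for leaf in leaves])
print(rebuilt['sub']['b'], rebuilt['w'])
```

Nested containers are flattened recursively, so a leaf's path records every key from the root down to it.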

unfreeze()

Convert to a regular mutable dictionary.

Returns:

A mutable dict with the same contents.

Return type:

dict[TypeVar(K), TypeVar(V)]

Examples

>>> fd = FrozenDict({'a': 1, 'b': {'c': 2}})
>>> d = fd.unfreeze()
>>> isinstance(d, dict)
True
>>> isinstance(d['b'], dict)  # Nested dicts also unfrozen
True
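The recursion that makes nested entries come back as plain dicts can be sketched without brainstate. The hypothetical helpers `freeze_sketch` and `unfreeze_sketch` below use `types.MappingProxyType` as a stand-in for FrozenDict; they illustrate the recursion only and are not the library's freeze and unfreeze functions.

```python
from collections.abc import Mapping
from types import MappingProxyType


def freeze_sketch(obj):
    """Hypothetical helper: recursively wrap mappings in read-only proxies."""
    if isinstance(obj, Mapping):
        return MappingProxyType({k: freeze_sketch(v) for k, v in obj.items()})
    return obj


def unfreeze_sketch(obj):
    """Hypothetical helper: recursively copy any mapping into a plain dict."""
    if isinstance(obj, Mapping):
        return {k: unfreeze_sketch(v) for k, v in obj.items()}
    return obj


frozen = freeze_sketch({'a': 1, 'b': {'c': 2}})
thawed = unfreeze_sketch(frozen)
thawed['b']['c'] = 99  # the thawed copy is mutable at every level
print(thawed, frozen['b']['c'])
```

Because the thawed result is built from fresh dicts, mutating it leaves the frozen original unchanged.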

values()

Return a view of the values.

Returns:

A view object of the dictionary’s values.

Return type:

ValuesView[TypeVar(V)]