Utility Toolkit
The brainstate.util package bundles helpers for collections, structured
PyTrees, pretty-printing, and runtime hygiene. This notebook walks through
the most frequently used APIs with runnable examples.
Sections:
1. Scheduling and naming helpers
2. Memory housekeeping
3. Managing collections with DictManager
4. Configuration access via DotDict
5. Dictionary utilities (merge, flatten, unflatten)
6. Structured PyTrees with util.struct
7. Filtering nested objects
8. Pretty PyTree containers
```python
from typing import Any

import jax
import jax.numpy as jnp

from brainstate import util
from brainstate.util import (
    DictManager,
    DotDict,
    clear_buffer_memory,
    flatten_dict,
    merge_dicts,
    split_total,
    unflatten_dict,
)
from brainstate.util import struct, filter as util_filter
```
1. Scheduling and naming helpers
split_total returns an integer share of total. A float fraction (between 0 and 1) is scaled by total, and an int is taken as an absolute count that must not exceed total. get_unique_name keeps thread-local counters, so repeated calls return distinct names without any manual bookkeeping.
```python
# A float is read as a fraction of the total; an int as an absolute count.
fractional = split_total(total=120, fraction=0.25)
absolute = split_total(total=120, fraction=30)
print('fractional schedule:', fractional)
print('absolute schedule:', absolute)

# Each call advances a counter, so repeated calls return distinct names.
names = [util.get_unique_name('layer') for _ in range(3)]
scoped = [util.get_unique_name('block', prefix='encoder_') for _ in range(2)]
print('names:', names)
print('scoped names:', scoped)
```
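The rules above are simple enough to restate in plain Python. The sketch below is a hypothetical re-implementation: split_total_sketch and unique_name_sketch are illustrative names, not brainstate APIs. It assumes the validation ranges described above. It also assumes that names are built as prefix + type + counter, which may not match the library's exact format.

```python
import threading
from typing import Union


def split_total_sketch(total: int, fraction: Union[int, float]) -> int:
    """Hypothetical re-implementation of the split rule described above."""
    if not isinstance(total, int) or total <= 0:
        raise ValueError('total must be a positive int')
    if isinstance(fraction, float):
        # Fractional quota: scale the total and truncate toward zero.
        if not 0.0 <= fraction <= 1.0:
            raise ValueError('a float fraction must lie in [0, 1]')
        return int(total * fraction)
    if isinstance(fraction, int):
        # Absolute count: returned unchanged, but it may not exceed the total.
        if not 0 <= fraction <= total:
            raise ValueError('an int fraction must lie in [0, total]')
        return fraction
    raise TypeError('fraction must be an int or a float')


# Each thread gets its own {type_name: next_index} table.
_name_counters = threading.local()


def unique_name_sketch(type_: str, prefix: str = '') -> str:
    """Hypothetical stand-in for get_unique_name: prefix + type + counter."""
    table = getattr(_name_counters, 'table', None)
    if table is None:
        table = _name_counters.table = {}
    index = table.get(type_, 0)
    table[type_] = index + 1
    return f'{prefix}{type_}{index}'


print(split_total_sketch(120, 0.25))  # → 30
print(split_total_sketch(120, 30))    # → 30
print([unique_name_sketch('layer') for _ in range(3)])  # → ['layer0', 'layer1', 'layer2']
```

Keeping the counter table in threading.local means two threads asking for 'layer' at the same time each start from 0. No lock is needed, but names are only unique within a thread.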
2. Memory housekeeping
clear_buffer_memory can release cached device buffers and compilation artifacts between experiments. Passing array=False tells it not to delete live device arrays, so the arrays created elsewhere in this notebook stay usable. The rest of the cleanup still runs.
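```python
# array=False skips deleting live device arrays; the remaining cleanup still runs.
clear_buffer_memory(array=False)
print('clear_buffer_memory finished; live arrays were kept.')
```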
3. Managing collections with DictManager
DictManager extends the standard mapping interface with filters, splits,
combination operators, and JAX PyTree support.
```python
modules = DictManager({
    'encoder': {'params': 32},
    'decoder': {'params': 45},
    'dropout': 0.1,
})
print('original:', modules)

# Keep only the submodules (values that are dict instances)
submods = modules.subset(dict)
print('subset:', submods)

# Split by type: dict entries first, everything else in the remainder
dicts, remainder = modules.split(dict)
print('split dicts:', dicts)
print('split remainder:', remainder)

# Map over values to extract parameter counts
param_counts = submods.map_values(lambda layer: layer['params'])
print('param counts:', param_counts)
```
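As a mental model, the subset and split behaviour used above can be sketched with plain dicts. This assumes, as the example suggests, that both methods test each value with isinstance. It also assumes that split returns one bucket per requested type plus a final remainder. subset_by_type and split_by_type are hypothetical helpers, not brainstate functions.

```python
from typing import Any, Dict, Tuple


def subset_by_type(mapping: Dict[str, Any], kind: type) -> Dict[str, Any]:
    """Keep only entries whose value is an instance of `kind`."""
    return {key: value for key, value in mapping.items() if isinstance(value, kind)}


def split_by_type(mapping: Dict[str, Any], *kinds: type) -> Tuple[Dict[str, Any], ...]:
    """Route each entry to the first matching type; unmatched entries go last."""
    buckets = tuple({} for _ in range(len(kinds) + 1))
    for key, value in mapping.items():
        for bucket, kind in zip(buckets, kinds):
            if isinstance(value, kind):
                bucket[key] = value
                break
        else:
            # No type matched, so the entry lands in the trailing remainder bucket.
            buckets[-1][key] = value
    return buckets


plain_modules = {'encoder': {'params': 32}, 'decoder': {'params': 45}, 'dropout': 0.1}
sub_dicts, rest = split_by_type(plain_modules, dict)
print(subset_by_type(plain_modules, dict))  # → {'encoder': {'params': 32}, 'decoder': {'params': 45}}
print(rest)                                  # → {'dropout': 0.1}
print({k: v['params'] for k, v in sub_dicts.items()})  # → {'encoder': 32, 'decoder': 45}
```

Sending each entry to the first matching type keeps the buckets disjoint, so the pieces can be recombined without duplicates.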
4. Configuration access via DotDict
DotDict lets you treat nested dictionaries like lightweight objects while
preserving conversion back to standard dicts when needed.
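```python
config = DotDict({
    'model': {
        'layers': 4,
        'hidden': 256,
    },
    'training': {
        'lr': 3e-4,
        'scheduler': {'warmup_steps': 500},
    },
})
print('hidden units:', config.model.hidden)

# Attribute assignment adds a new key to the nested section
config.training.dropout = 0.2
print('with dropout:', config.training.dropout)

# Convert back to plain nested dicts, e.g. for serialization
round_trip = config.to_dict()
print('as dict:', round_trip)
```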
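One plausible way to get this behaviour is a dict subclass that maps attribute access onto keys and wraps nested dicts when they are inserted. The AttrDict class below is a hypothetical sketch, not brainstate's DotDict. It shows one trade-off of the design: a key that collides with a dict method name (for example 'items') can only be reached with square brackets.

```python
from typing import Any, Dict, Optional


class AttrDict(dict):
    """Hypothetical DotDict-like mapping: nested dicts become AttrDicts."""

    def __init__(self, data: Optional[Dict[str, Any]] = None):
        super().__init__()
        for key, value in (data or {}).items():
            self[key] = value  # routed through __setitem__ below

    def __setitem__(self, key: str, value: Any) -> None:
        # Wrap plain nested dicts so chained attribute access keeps working.
        if isinstance(value, dict) and not isinstance(value, AttrDict):
            value = AttrDict(value)
        super().__setitem__(key, value)

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails, so dict methods still win.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def to_dict(self) -> Dict[str, Any]:
        """Recursively convert back to plain dicts."""
        return {k: v.to_dict() if isinstance(v, AttrDict) else v for k, v in self.items()}


demo_cfg = AttrDict({'model': {'hidden': 256}, 'training': {'lr': 3e-4}})
demo_cfg.training.dropout = 0.2
print(demo_cfg.model.hidden)  # → 256
print(demo_cfg.to_dict())     # → {'model': {'hidden': 256}, 'training': {'lr': 0.0003, 'dropout': 0.2}}
```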
5. Dictionary utilities
merge_dicts merges dictionaries and can optionally recurse into nested dictionaries instead of replacing them wholesale. flatten_dict and unflatten_dict convert between nested and dotted-key representations, which is useful for logging or for command-line overrides.
```python
base = {'optimizer': {'lr': 1e-3, 'beta1': 0.9}}
override = {'optimizer': {'lr': 5e-4}, 'seed': 1234}

merged = merge_dicts(base, override)
print('merged:', merged)

flat = flatten_dict(merged)
print('flattened:', flat)
print('unflattened:', unflatten_dict(flat))
```
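To make the merge and flatten semantics concrete, here is a stdlib sketch. merge_sketch, flatten_sketch and unflatten_sketch are hypothetical. brainstate's merge_dicts, flatten_dict and unflatten_dict may use a different separator, different argument names, or different edge-case handling. The recursive flag shows the difference between a deep merge, which keeps beta1, and a shallow one, which replaces the whole optimizer entry.

```python
from typing import Any, Dict


def merge_sketch(base: Dict[str, Any], update: Dict[str, Any], recursive: bool = True) -> Dict[str, Any]:
    """Return a new dict; inputs are left untouched."""
    out = dict(base)
    for key, value in update.items():
        if recursive and isinstance(out.get(key), dict) and isinstance(value, dict):
            # Both sides are dicts: merge them key by key instead of replacing.
            out[key] = merge_sketch(out[key], value, recursive=True)
        else:
            out[key] = value  # later value wins
    return out


def flatten_sketch(tree: Dict[str, Any], sep: str = '.', parent: str = '') -> Dict[str, Any]:
    """Collapse nested dicts into {'a.b': leaf}; empty dicts are kept as leaves."""
    flat: Dict[str, Any] = {}
    for key, value in tree.items():
        path = f'{parent}{sep}{key}' if parent else key
        if isinstance(value, dict) and value:
            flat.update(flatten_sketch(value, sep, path))
        else:
            flat[path] = value
    return flat


def unflatten_sketch(flat: Dict[str, Any], sep: str = '.') -> Dict[str, Any]:
    """Rebuild nested dicts; round-trips as long as no key contains `sep`."""
    tree: Dict[str, Any] = {}
    for path, value in flat.items():
        *parents, leaf = path.split(sep)
        node = tree
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree


defaults = {'optimizer': {'lr': 1e-3, 'beta1': 0.9}}
overrides = {'optimizer': {'lr': 5e-4}, 'seed': 1234}
demo_merged = merge_sketch(defaults, overrides)
demo_flat = flatten_sketch(demo_merged)
print(demo_merged)  # → {'optimizer': {'lr': 0.0005, 'beta1': 0.9}, 'seed': 1234}
print(merge_sketch(defaults, overrides, recursive=False))  # → {'optimizer': {'lr': 0.0005}, 'seed': 1234}
print(demo_flat)    # → {'optimizer.lr': 0.0005, 'optimizer.beta1': 0.9, 'seed': 1234}
print(unflatten_sketch(demo_flat) == demo_merged)  # → True
```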
6. Structured PyTrees with util.struct
The struct submodule mirrors Flax's struct utilities. The dataclass decorator registers a class as a PyTree, and fields declared with pytree_node=False are kept as static metadata rather than leaves. FrozenDict provides immutable mappings that are compatible with JAX transformations.
```python
@struct.dataclass
class LayerConfig:
    weight: jax.Array
    bias: jax.Array
    # Static field: stored as metadata, not traced as a PyTree leaf
    name: str = struct.field(pytree_node=False, default='layer')


cfg = LayerConfig(weight=jnp.ones((2, 2)), bias=jnp.zeros(2))
print(cfg)

# Instances are immutable; replace() returns an updated copy
cfg2 = cfg.replace(weight=jnp.full((2, 2), 3.0))
print('updated weight:', cfg2.weight)

# Only weight and bias appear as leaves; name is static
flat_leaves, _ = jax.tree_util.tree_flatten(cfg)
print('pytree leaves:', [leaf.shape for leaf in flat_leaves])

frozen = struct.freeze({'encoder': jnp.arange(3)})
print('frozen dict:', frozen)
print('unfrozen:', struct.unfreeze(frozen))
```
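In Flax, struct.field(pytree_node=False) records the flag in the dataclass field's metadata, and struct.dataclass builds a frozen dataclass with a replace method. The flag decides which fields become leaves and which become static data. Assuming brainstate.util.struct follows the same design, the stdlib-only sketch below imitates the leaf/static split without JAX. static_field, LayerConfigSketch and leaves_and_static are hypothetical names.

```python
import dataclasses
from typing import Any, Dict, List, Tuple


def static_field(**kwargs: Any) -> Any:
    """Hypothetical helper: mark a field as static metadata, not a leaf."""
    return dataclasses.field(metadata={'pytree_node': False}, **kwargs)


@dataclasses.dataclass(frozen=True)
class LayerConfigSketch:
    weight: List[float]
    bias: List[float]
    name: str = static_field(default='layer')

    def replace(self, **updates: Any) -> 'LayerConfigSketch':
        # Frozen instances cannot be mutated, so updates build a new object.
        return dataclasses.replace(self, **updates)


def leaves_and_static(obj: Any) -> Tuple[List[Any], Dict[str, Any]]:
    """Split dataclass fields into leaf values and static auxiliary data."""
    leaves, static = [], {}
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        if f.metadata.get('pytree_node', True):
            leaves.append(value)
        else:
            static[f.name] = value
    return leaves, static


cfg_sketch = LayerConfigSketch(weight=[1.0, 1.0], bias=[0.0])
updated = cfg_sketch.replace(weight=[3.0, 3.0])
print(updated.weight, cfg_sketch.weight)  # → [3.0, 3.0] [1.0, 1.0]
print(leaves_and_static(cfg_sketch))      # → ([[1.0, 1.0], [0.0]], {'name': 'layer'})
```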
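7. Filtering nested objects
brainstate.util.filter turns declarative filters into predicates. Each predicate is called as predicate(path, value), where path is the tuple of keys leading to value. For example, to_predicate turns a bare string such as 'trainable' into a tag check. Combine tag, type, and path checks when traversing parameter trees.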
```python
class Module:
    def __init__(self, tag: str | None, kind: str):
        self.tag = tag
        self.kind = kind
        self.params = jnp.arange(2)


model_tree = {
    'encoder': Module(tag='trainable', kind='linear'),
    'decoder': Module(tag='frozen', kind='linear'),
    'head': Module(tag='trainable', kind='mlp'),
}

# A bare string is converted into a tag filter
tag_filter = util_filter.to_predicate('trainable')
type_filter = util_filter.OfType(Module)
# All(...) matches only when every sub-filter matches
combined = util_filter.All(type_filter, util_filter.WithTag('trainable'))


def collect(tree: dict[str, Any], predicate) -> dict[str, Any]:
    """Keep top-level entries for which predicate(path, value) is true."""
    out = {}
    for key, value in tree.items():
        if predicate((key,), value):
            out[key] = value
    return out


trainable_modules = collect(model_tree, tag_filter)
both = collect(model_tree, combined)
print('trainable keys:', tuple(trainable_modules.keys()))
print('trainable Module instances:', tuple(both.keys()))
```
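Because each filter is just a (path, value) predicate, combining filters amounts to composing functions. The stdlib sketch below imitates that idea with hypothetical helpers (with_tag, of_type, path_contains, all_of, any_of, not_) that stand in for util_filter's WithTag, OfType and All. Its select walker recurses into nested dicts, so the path argument carries more than one key. That depth is what makes a path check useful.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict, Hashable, Tuple

Path = Tuple[Hashable, ...]
Predicate = Callable[[Path, Any], bool]


def with_tag(tag: str) -> Predicate:
    return lambda path, x: getattr(x, 'tag', None) == tag


def of_type(cls: type) -> Predicate:
    return lambda path, x: isinstance(x, cls)


def path_contains(key: Hashable) -> Predicate:
    return lambda path, x: key in path


def all_of(*preds: Predicate) -> Predicate:
    return lambda path, x: all(p(path, x) for p in preds)


def any_of(*preds: Predicate) -> Predicate:
    return lambda path, x: any(p(path, x) for p in preds)


def not_(pred: Predicate) -> Predicate:
    return lambda path, x: not pred(path, x)


def select(tree: Dict[Hashable, Any], pred: Predicate, path: Path = ()) -> Dict[Path, Any]:
    """Depth-first walk: keep matches, descend into unmatched nested dicts."""
    hits: Dict[Path, Any] = {}
    for key, value in tree.items():
        sub_path = path + (key,)
        if pred(sub_path, value):
            hits[sub_path] = value
        elif isinstance(value, dict):
            hits.update(select(value, pred, sub_path))
    return hits


demo_tree = {
    'encoder': {'proj': SimpleNamespace(tag='trainable'), 'norm': SimpleNamespace(tag='frozen')},
    'head': SimpleNamespace(tag='trainable'),
}
trainable_in_encoder = all_of(with_tag('trainable'), path_contains('encoder'))
print(list(select(demo_tree, with_tag('trainable'))))  # → [('encoder', 'proj'), ('head',)]
print(list(select(demo_tree, trainable_in_encoder)))   # → [('encoder', 'proj')]
```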
8. Pretty PyTree containers
NestedDict, FlattedDict, and PrettyList bring readable reprs plus PyTree
semantics. Use them to explore checkpoints or log structured configs.
```python
from brainstate.util import NestedDict, flat_mapping, nest_mapping, PrettyList

state = NestedDict({
    'encoder': {'weight': jnp.ones((2, 2)), 'bias': jnp.zeros(2)},
    'decoder': {'weight': jnp.eye(2)},
})
print(state)

# Flatten to a single-level mapping, then rebuild the nesting
flat_state = flat_mapping(state)
print('flat keys:', list(flat_state.keys()))
round_trip = nest_mapping(flat_state)
print('round-trip equal:', round_trip == state)

history = PrettyList([{'loss': 0.8}, {'loss': 0.42}])
print(history)
```
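The example does not show what the flat keys look like. A common design in PyTree tooling is to key each leaf by a tuple path; that is assumed here, not confirmed for flat_mapping. Compared with the dotted keys of Section 5, tuple paths survive keys that themselves contain a dot, such as 'w.0' below. flat_by_path and nest_by_path are hypothetical helpers written with plain dicts.

```python
from typing import Any, Dict, Hashable, Tuple

Path = Tuple[Hashable, ...]


def flat_by_path(tree: Dict[Hashable, Any], prefix: Path = ()) -> Dict[Path, Any]:
    """Flatten nested dicts into {tuple_path: leaf}; empty dicts stay leaves."""
    flat: Dict[Path, Any] = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict) and value:
            flat.update(flat_by_path(value, path))
        else:
            flat[path] = value
    return flat


def nest_by_path(flat: Dict[Path, Any]) -> Dict[Hashable, Any]:
    """Rebuild the nesting by walking each tuple path."""
    tree: Dict[Hashable, Any] = {}
    for path, value in flat.items():
        node = tree
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = value
    return tree


ckpt = {'encoder': {'w.0': 1.0, 'bias': 0.0}, 'decoder': {'weight': 2.0}}
flat_ckpt = flat_by_path(ckpt)
print(list(flat_ckpt))  # → [('encoder', 'w.0'), ('encoder', 'bias'), ('decoder', 'weight')]
print(nest_by_path(flat_ckpt) == ckpt)  # → True
```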
Summary
- Use the scheduling helpers (split_total, get_unique_name) to coordinate experiments.
- Reach for DictManager and DotDict to manage nested collections.
- Convert between nested and flat configs with merge_dicts, flatten_dict, and unflatten_dict.
- Wrap structured data using util.struct, and use the filter and pretty-printing utilities when exploring PyTrees.