Utility Toolkit

The brainstate.util package bundles helpers for collections, structured PyTrees, pretty-printing, and runtime housekeeping such as freeing memory. This notebook walks through the most frequently used APIs with runnable examples.

Sections:

  1. Scheduling and naming helpers

  2. Memory housekeeping

  3. Managing collections with DictManager

  4. Configuration access via DotDict

  5. Dictionary utilities (merge, flatten, unflatten)

  6. Structured PyTrees with util.struct

  7. Filtering nested objects

  8. Pretty PyTree containers

from typing import Any

import jax
import jax.numpy as jnp

from brainstate import util
from brainstate.util import (
    DictManager,
    DotDict,
    clear_buffer_memory,
    flatten_dict,
    merge_dicts,
    split_total,
    unflatten_dict,
)

from brainstate.util import struct, filter as util_filter

1. Scheduling and naming helpers

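split_total computes a share of a total workload. You can pass either a fractional quota (a float, applied to total) or an absolute count (an int, used as given). In the cell below, a quarter of 120 and an absolute count of 30 describe the same share. get_unique_name keeps thread-local counters, so repeated calls return distinct names without manual bookkeeping.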

epochs = split_total(total=120, fraction=0.25)
override = split_total(total=120, fraction=30)
print('fractional schedule:', epochs)
print('absolute schedule:', override)

names = [util.get_unique_name('layer') for _ in range(3)]
scoped = [util.get_unique_name('block', prefix='encoder_') for _ in range(2)]
print('names:', names)
print('scoped names:', scoped)
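The following plain-Python sketch makes the two contracts concrete. split_total_sketch and unique_name_sketch are hypothetical stand-ins, not brainstate's implementation. The range checks and the `{prefix}{base}{index}` naming format are assumptions, so the real functions may validate inputs or format names differently.

```python
import threading
from typing import Union


def split_total_sketch(total: int, fraction: Union[int, float]) -> int:
    """Hypothetical re-implementation of the split_total contract described above."""
    if isinstance(fraction, bool):  # bool subclasses int, so reject it explicitly
        raise TypeError("fraction must be a float or an int")
    if isinstance(fraction, float):
        if not 0.0 <= fraction <= 1.0:
            raise ValueError("a fractional quota must lie in [0, 1]")
        return int(total * fraction)  # share of the total, truncated toward zero
    if isinstance(fraction, int):
        if not 0 <= fraction <= total:
            raise ValueError("an absolute count must lie in [0, total]")
        return fraction  # absolute count, used as given
    raise TypeError("fraction must be a float or an int")


_name_state = threading.local()  # each thread sees its own attributes


def unique_name_sketch(base: str, prefix: str = "") -> str:
    """Hypothetical per-thread naming: every base name gets its own counter."""
    counters = getattr(_name_state, "counters", None)
    if counters is None:
        counters = _name_state.counters = {}
    index = counters.get(base, 0)
    counters[base] = index + 1
    return f"{prefix}{base}{index}"


print(split_total_sketch(120, 0.25), split_total_sketch(120, 30))  # 30 30
print([unique_name_sketch("layer") for _ in range(3)])  # ['layer0', 'layer1', 'layer2']
```

The sketch stores its counters in threading.local. As a result, two threads that build models concurrently each start their own sequence. This avoids locking, at the cost of names being unique only within a thread.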

2. Memory housekeeping

clear_buffer_memory can release live device buffers and compilation caches between experiments, and it triggers Python garbage collection. Passing array=False leaves live device arrays untouched, so this cell does not invalidate arrays created earlier in the notebook.

clear_buffer_memory(array=False)
print('Triggered garbage collection; live arrays were left intact.')
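A routine like clear_buffer_memory follows a flag-gated cleanup pattern. Each destructive step is opt-in, and garbage collection runs last to sweep up any reference cycles left behind. The stdlib-only sketch below models device buffers as a list and compiled artifacts as a functools.lru_cache. The names live_buffers, compile_kernel, and clear_memory_sketch are hypothetical, not brainstate internals, and the flag defaults chosen here are assumptions.

```python
import functools
import gc
from typing import Dict, Tuple


@functools.lru_cache(maxsize=None)
def compile_kernel(shape: Tuple[int, ...]) -> str:
    """Hypothetical stand-in for an expensive, memoized compilation step."""
    return f"kernel{shape}"


# Hypothetical stand-in for live device buffers.
live_buffers = [bytearray(1024) for _ in range(4)]


def clear_memory_sketch(array: bool = True, compilation: bool = False) -> Dict[str, object]:
    """Flag-gated cleanup: each destructive step runs only when requested."""
    report: Dict[str, object] = {"buffers_freed": 0, "cache_cleared": False}
    if array:
        report["buffers_freed"] = len(live_buffers)
        live_buffers.clear()  # drop the references so the memory can be reclaimed
    if compilation:
        compile_kernel.cache_clear()  # forget every memoized "compiled" result
        report["cache_cleared"] = True
    report["gc_collected"] = gc.collect()  # sweep up leftover reference cycles
    return report


compile_kernel((2, 2))
print(clear_memory_sketch(array=False))  # buffers and cache both survive this call
```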

3. Managing collections with DictManager

DictManager extends the standard mapping interface with filters, splits, combination operators, and JAX PyTree support.

modules = DictManager({
    'encoder': {'params': 32},
    'decoder': {'params': 45},
    'dropout': 0.1,
})
print('original:', modules)

# Filter only submodules (dict instances)
submods = modules.subset(dict)
print('subset:', submods)

# Split by type: dict entries vs everything else
dicts, remainder = modules.split(dict)
print('split dicts:', dicts)
print('split remainder:', remainder)

# Map over values to extract parameter counts
param_counts = submods.map_values(lambda layer: layer['params'])
param_counts
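The subset and split semantics can be sketched over a plain dict. subset_by_type and split_by_type are hypothetical helpers. Their rules are assumptions modeled on how split is used above, not a description of DictManager's internals: an entry goes to the first matching type, and leftovers form a final remainder dict.

```python
from typing import Any, Dict, Tuple


def subset_by_type(d: Dict[str, Any], types) -> Dict[str, Any]:
    """Keep entries whose value is an instance of `types` (a type or a tuple of types)."""
    return {k: v for k, v in d.items() if isinstance(v, types)}


def split_by_type(d: Dict[str, Any], *types: type) -> Tuple[Dict[str, Any], ...]:
    """Return one dict per requested type, in order, plus a final dict of unmatched entries."""
    buckets = [dict() for _ in types]
    remainder: Dict[str, Any] = {}
    for key, value in d.items():
        for bucket, typ in zip(buckets, types):
            if isinstance(value, typ):
                bucket[key] = value
                break  # the first matching type wins
        else:
            remainder[key] = value  # no requested type matched
    return (*buckets, remainder)


plain_modules = {"encoder": {"params": 32}, "decoder": {"params": 45}, "dropout": 0.1}
dicts_sketch, rest_sketch = split_by_type(plain_modules, dict)
print(dicts_sketch)  # {'encoder': {'params': 32}, 'decoder': {'params': 45}}
print(rest_sketch)   # {'dropout': 0.1}
print({k: v["params"] for k, v in subset_by_type(plain_modules, dict).items()})  # {'encoder': 32, 'decoder': 45}
```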

4. Configuration access via DotDict

DotDict lets you read and write nested dictionary entries as attributes, and to_dict() converts the result back to standard dicts when needed.

config = DotDict({
    'model': {
        'layers': 4,
        'hidden': 256,
    },
    'training': {
        'lr': 3e-4,
        'scheduler': {'warmup_steps': 500},
    },
})

print('hidden units:', config.model.hidden)
config.training.dropout = 0.2
print('with dropout:', config.training.dropout)

round_trip = config.to_dict()
round_trip
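Attribute access on nested dictionaries comes down to two steps. __getattr__ and __setattr__ forward to item access, and nested dicts are wrapped when they are inserted. AttrDict below is a hypothetical minimal sketch of that idea. It is not DotDict's actual implementation, which may handle more cases.

```python
from typing import Any


class AttrDict(dict):
    """Hypothetical minimal DotDict-like mapping: attribute access over dict keys."""

    def __init__(self, data=None):
        super().__init__()
        for key, value in (data or {}).items():
            self[key] = value  # route through __setitem__ so nested dicts get wrapped

    def __setitem__(self, key, value):
        # Wrap nested plain dicts so chained attribute access keeps working.
        if isinstance(value, dict) and not isinstance(value, AttrDict):
            value = AttrDict(value)
        super().__setitem__(key, value)

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def to_dict(self) -> dict:
        """Recursively convert back to plain dicts."""
        return {k: v.to_dict() if isinstance(v, AttrDict) else v for k, v in self.items()}


cfg_sketch = AttrDict({"model": {"hidden": 256}, "training": {"lr": 3e-4}})
cfg_sketch.training.dropout = 0.2
print(cfg_sketch.model.hidden)  # 256
print(cfg_sketch.to_dict())     # {'model': {'hidden': 256}, 'training': {'lr': 0.0003, 'dropout': 0.2}}
```

One consequence of this design is that normal attribute lookup runs before __getattr__. A key that collides with a dict method name, such as 'items', must therefore be read with item syntax: cfg_sketch['items'].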

5. Dictionary utilities

merge_dicts merges dictionaries, optionally recursing into nested ones. flatten_dict and unflatten_dict convert between nested and dotted-key representations, which is useful for logging or command-line overrides.

base = {'optimizer': {'lr': 1e-3, 'beta1': 0.9}}
override = {'optimizer': {'lr': 5e-4}, 'seed': 1234}
merged = merge_dicts(base, override)
print('merged:', merged)

flat = flatten_dict(merged)
print('flattened:', flat)
unflatten_dict(flat)
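A minimal sketch of the three operations on plain dicts follows. It assumes string keys and a '.' separator. merge_sketch, flatten_sketch, and unflatten_sketch are hypothetical and may differ from brainstate's handling of edge cases. As in the example above, values from the second argument take precedence in this sketch.

```python
from typing import Any, Dict


def merge_sketch(base: Dict[str, Any], override: Dict[str, Any], recursive: bool = True) -> Dict[str, Any]:
    """Return a new dict in which values from `override` take precedence."""
    out = dict(base)  # shallow copy: `base` itself is never mutated
    for key, value in override.items():
        if recursive and isinstance(out.get(key), dict) and isinstance(value, dict):
            out[key] = merge_sketch(out[key], value, recursive)  # merge nested dicts
        else:
            out[key] = value  # replace outright
    return out


def flatten_sketch(d: Dict[str, Any], parent: str = "", sep: str = ".") -> Dict[str, Any]:
    """Collapse nested dicts into a single level of `sep`-joined keys."""
    flat: Dict[str, Any] = {}
    for key, value in d.items():
        path = f"{parent}{sep}{key}" if parent else key
        if isinstance(value, dict) and value:  # empty dicts are kept as leaves
            flat.update(flatten_sketch(value, path, sep))
        else:
            flat[path] = value
    return flat


def unflatten_sketch(flat: Dict[str, Any], sep: str = ".") -> Dict[str, Any]:
    """Rebuild nesting by splitting each key on `sep`."""
    nested: Dict[str, Any] = {}
    for path, value in flat.items():
        *parents, leaf = path.split(sep)
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


base_sketch = {"optimizer": {"lr": 1e-3, "beta1": 0.9}}
merged_sketch = merge_sketch(base_sketch, {"optimizer": {"lr": 5e-4}, "seed": 1234})
flat_sketch = flatten_sketch(merged_sketch)
print(flat_sketch)  # {'optimizer.lr': 0.0005, 'optimizer.beta1': 0.9, 'seed': 1234}
print(unflatten_sketch(flat_sketch) == merged_sketch)  # True
```

The round trip is lossless only when no key contains the separator. For example, unflatten_sketch would split a key like 'a.b' into two nesting levels.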

6. Structured PyTrees with util.struct

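The struct submodule provides Flax-style data structures. The dataclass decorator registers a class as a PyTree. Regular fields become leaves, while fields declared with struct.field(pytree_node=False) are kept as static metadata, which is why only weight and bias show up as leaves below. FrozenDict provides immutable mappings compatible with JAX transformations, and freeze and unfreeze convert between FrozenDict and plain dicts.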

@struct.dataclass
class LayerConfig:
    weight: jax.Array
    bias: jax.Array
    name: str = struct.field(pytree_node=False, default='layer')

cfg = LayerConfig(weight=jnp.ones((2, 2)), bias=jnp.zeros(2))
print(cfg)

cfg2 = cfg.replace(weight=jnp.full((2, 2), 3.0))
print('updated weight:', cfg2.weight)

flat_leaves, _ = jax.tree_util.tree_flatten(cfg)
print('pytree leaves:', [leaf.shape for leaf in flat_leaves])

frozen = struct.freeze({'encoder': jnp.arange(3)})
print('frozen dict:', frozen)
print('unfrozen:', struct.unfreeze(frozen))
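The split between leaves and static fields can be illustrated with the standard library alone. Each field's metadata records whether it is a PyTree node, and a flatten step partitions the fields accordingly, which is the role PyTree registration plays. static_field, LayerSketch, split_fields, freeze_sketch, and unfreeze_sketch are hypothetical stand-ins. Note that the mappingproxy used for freezing only blocks writes; unlike FrozenDict, it is not a PyTree.

```python
import dataclasses
from types import MappingProxyType
from typing import Any, Dict, List, Mapping, Tuple


def static_field(**kwargs: Any) -> Any:
    """Hypothetical helper: mark a field as static metadata rather than a leaf."""
    return dataclasses.field(metadata={"pytree_node": False}, **kwargs)


@dataclasses.dataclass(frozen=True)
class LayerSketch:
    weight: List[float]
    bias: List[float]
    name: str = static_field(default="layer")


def split_fields(obj: Any) -> Tuple[List[Any], Dict[str, Any]]:
    """Partition dataclass fields into leaves and static auxiliary data."""
    leaves: List[Any] = []
    static: Dict[str, Any] = {}
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        if f.metadata.get("pytree_node", True):
            leaves.append(value)  # would be visible to JAX transformations
        else:
            static[f.name] = value  # kept out of the leaves, as auxiliary data
    return leaves, static


def freeze_sketch(d: Mapping[str, Any]) -> Mapping[str, Any]:
    """Recursively copy dicts into read-only mappingproxy views."""
    return MappingProxyType(
        {k: freeze_sketch(v) if isinstance(v, dict) else v for k, v in d.items()}
    )


def unfreeze_sketch(m: Mapping[str, Any]) -> Dict[str, Any]:
    """Convert read-only views back into ordinary mutable dicts."""
    return {k: unfreeze_sketch(v) if isinstance(v, MappingProxyType) else v for k, v in m.items()}


layer = LayerSketch(weight=[1.0, 1.0], bias=[0.0])
updated = dataclasses.replace(layer, weight=[3.0, 3.0])  # new instance; `layer` is untouched
print(split_fields(updated))  # ([[3.0, 3.0], [0.0]], {'name': 'layer'})
print(unfreeze_sketch(freeze_sketch({"encoder": {"w": 1}})))  # {'encoder': {'w': 1}}
```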

7. Filtering nested objects

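brainstate.util.filter turns declarative filter specifications into predicates that are called as predicate(path, value). to_predicate interprets a string such as 'trainable' as a tag filter, OfType matches by type, and All requires every filter it wraps to match. Combining filters this way lets you mix tag, type, and path checks when traversing parameter trees.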

class Module:
    def __init__(self, tag: str | None, kind: str):
        self.tag = tag
        self.kind = kind
        self.params = jnp.arange(2)

model_tree = {
    'encoder': Module(tag='trainable', kind='linear'),
    'decoder': Module(tag='frozen', kind='linear'),
    'head': Module(tag='trainable', kind='mlp'),
}

tag_filter = util_filter.to_predicate('trainable')
type_filter = util_filter.OfType(Module)
combined = util_filter.All(type_filter, util_filter.WithTag('trainable'))

def collect(tree: dict[str, Any], predicate) -> dict[str, Any]:
    out = {}
    for key, value in tree.items():
        if predicate((key,), value):
            out[key] = value
    return out

trainable_modules = collect(model_tree, tag_filter)
both = collect(model_tree, combined)
print('trainable keys:', tuple(trainable_modules.keys()))
print('trainable Modules:', tuple(both.keys()))
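The filters above share a simple shape. Every filter is a function of (path, value) that returns a bool, and All-style combinators compose those functions. The sketch below reimplements that shape in plain Python and adds the path check mentioned in the prose. with_tag, of_type, path_contains, all_of, any_of, not_, and walk are hypothetical names, not the brainstate.util.filter API.

```python
from typing import Any, Callable, Iterator, Tuple

Path = Tuple[str, ...]
Predicate = Callable[[Path, Any], bool]


def with_tag(tag: str) -> Predicate:
    """Match values whose `tag` attribute equals `tag`."""
    return lambda path, x: getattr(x, "tag", None) == tag


def of_type(cls: type) -> Predicate:
    return lambda path, x: isinstance(x, cls)


def path_contains(key: str) -> Predicate:
    """Match values whose path includes `key` at any depth."""
    return lambda path, x: key in path


def all_of(*preds: Predicate) -> Predicate:
    return lambda path, x: all(p(path, x) for p in preds)


def any_of(*preds: Predicate) -> Predicate:
    return lambda path, x: any(p(path, x) for p in preds)


def not_(pred: Predicate) -> Predicate:
    return lambda path, x: not pred(path, x)


def walk(tree: Any, pred: Predicate, path: Path = ()) -> Iterator[Tuple[Path, Any]]:
    """Yield (path, leaf) for every non-dict leaf of a nested dict that satisfies pred."""
    if isinstance(tree, dict):
        for key, value in tree.items():
            yield from walk(value, pred, path + (key,))
    elif pred(path, tree):
        yield path, tree


class Tagged:
    def __init__(self, tag: str):
        self.tag = tag


sketch_tree = {
    "encoder": {"attn": Tagged("trainable"), "norm": Tagged("frozen")},
    "decoder": {"attn": Tagged("trainable")},
}
pick = all_of(with_tag("trainable"), not_(path_contains("decoder")))
print([p for p, _ in walk(sketch_tree, pick)])  # [('encoder', 'attn')]
```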

8. Pretty PyTree containers

NestedDict, FlattedDict, and PrettyList bring readable reprs plus PyTree semantics. Use them to explore checkpoints or log structured configs.

from brainstate.util import NestedDict, flat_mapping, nest_mapping, PrettyList

state = NestedDict({
    'encoder': {'weight': jnp.ones((2, 2)), 'bias': jnp.zeros(2)},
    'decoder': {'weight': jnp.eye(2)},
})
print(state)

flat_state = flat_mapping(state)
print('flat keys:', list(flat_state.keys()))

round_trip = nest_mapping(flat_state)
print('round-trip equal:', round_trip == state)

history = PrettyList([{'loss': 0.8}, {'loss': 0.42}])
print(history)
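The readable reprs come from rendering each nesting level on its own indented line. render_tree below is a hypothetical helper that sketches this idea for nested dicts and lists. It does not reproduce the exact repr format of NestedDict or PrettyList.

```python
from typing import Any, List


def render_lines(node: Any, indent: int = 0) -> List[str]:
    """Return one line per entry, indenting each nesting level by two spaces."""
    pad = "  " * indent
    if isinstance(node, dict):
        items = [(str(k), v) for k, v in node.items()]
    elif isinstance(node, list):
        items = [(f"[{i}]", v) for i, v in enumerate(node)]
    else:
        return [f"{pad}{node!r}"]  # a bare leaf
    lines: List[str] = []
    for label, value in items:
        if isinstance(value, (dict, list)):
            lines.append(f"{pad}{label}:")  # container: label on its own line
            lines.extend(render_lines(value, indent + 1))
        else:
            lines.append(f"{pad}{label}: {value!r}")  # leaf: label and value together
    return lines


def render_tree(node: Any) -> str:
    return "\n".join(render_lines(node))


print(render_tree([{"loss": 0.8}, {"loss": 0.42}]))
```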

Summary

  • Use split_total and get_unique_name to size workloads and generate unique names across experiments.

  • Reach for DictManager and DotDict to manage nested collections.

  • Convert between nested and flat configs with merge_dicts, flatten_dict, and unflatten_dict.

  • Wrap structured data using util.struct and leverage filter/pretty utilities when exploring PyTrees.