{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Utility Toolkit",
   "id": "290dc8e5ad5655f4"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "The `brainstate.util` package bundles helpers for collections, structured\n",
    "PyTrees, pretty-printing, and runtime hygiene. This notebook walks through\n",
    "the most frequently used APIs with runnable examples.\n",
    "\n",
    "Sections:\n",
    "\n",
    "1. Scheduling and naming helpers\n",
    "2. Memory housekeeping\n",
    "3. Managing collections with `DictManager`\n",
    "4. Configuration access via `DotDict`\n",
    "5. Dictionary utilities (`merge`, `flatten`, `unflatten`)\n",
    "6. Structured PyTrees with `util.struct`\n",
    "7. Filtering nested objects\n",
    "8. Pretty PyTree containers"
   ],
   "id": "7d1bfe6168ad2b25"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from typing import Any\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "from brainstate import util\n",
    "from brainstate.util import (\n",
    "    DictManager,\n",
    "    DotDict,\n",
    "    clear_buffer_memory,\n",
    "    flatten_dict,\n",
    "    merge_dicts,\n",
    "    split_total,\n",
    "    unflatten_dict,\n",
    ")\n",
    "\n",
    "from brainstate.util import struct, filter as util_filter"
   ],
   "id": "7ac40eefb7aa8ffb"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## 1. Scheduling and naming helpers\n",
    "\n",
    "`split_total` computes a share of a total budget such as a number of epochs:\n",
    "a float `fraction` is treated as a proportion of `total`, while an integer is\n",
    "used as an absolute count. `get_unique_name` keeps thread-local counters, so\n",
    "repeated calls return distinct names without manual bookkeeping."
   ],
   "id": "c64a1b3296c10f55"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# A float fraction is a share of the total; an int is an absolute count.\n",
    "fractional = split_total(total=120, fraction=0.25)\n",
    "absolute = split_total(total=120, fraction=30)\n",
    "print('fractional schedule:', fractional)\n",
    "print('absolute schedule:', absolute)\n",
    "\n",
    "names = [util.get_unique_name('layer') for _ in range(3)]\n",
    "scoped = [util.get_unique_name('block', prefix='encoder_') for _ in range(2)]\n",
    "print('names:', names)\n",
    "print('scoped names:', scoped)"
   ],
   "id": "785307a82d02fe20"
  },
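  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "The plain-Python sketch below makes both behaviors concrete. `split_total_sketch`\n",
    "and `unique_name_sketch` are hypothetical helpers, not part of `brainstate`: the\n",
    "validation rules (a float must lie in [0, 1], an int must not exceed `total`) and\n",
    "the `prefix + type + index` name format are assumptions, so check the library\n",
    "source before relying on them."
   ],
   "id": "3f9c2a71d0b84e16"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import threading\n",
    "\n",
    "\n",
    "def split_total_sketch(total: int, fraction) -> int:\n",
    "    # Hypothetical re-implementation: a float is a share of `total`,\n",
    "    # an int is an absolute count (validation rules are assumptions).\n",
    "    if isinstance(fraction, float):\n",
    "        if not 0.0 <= fraction <= 1.0:\n",
    "            raise ValueError('a float fraction must lie in [0, 1]')\n",
    "        return int(total * fraction)\n",
    "    if isinstance(fraction, int):\n",
    "        if not 0 <= fraction <= total:\n",
    "            raise ValueError('an int fraction must lie in [0, total]')\n",
    "        return fraction\n",
    "    raise TypeError('fraction must be a float or an int')\n",
    "\n",
    "\n",
    "# Hypothetical naming helper: one counter dict per thread, keyed by type.\n",
    "_name_counters = threading.local()\n",
    "\n",
    "\n",
    "def unique_name_sketch(type_: str, prefix: str = '') -> str:\n",
    "    counts = getattr(_name_counters, 'counts', None)\n",
    "    if counts is None:\n",
    "        counts = _name_counters.counts = {}\n",
    "    index = counts.get(type_, 0)\n",
    "    counts[type_] = index + 1\n",
    "    return f'{prefix}{type_}{index}'\n",
    "\n",
    "\n",
    "print(split_total_sketch(120, 0.25), split_total_sketch(120, 30))  # 30 30\n",
    "print([unique_name_sketch('layer') for _ in range(3)])  # ['layer0', 'layer1', 'layer2']\n",
    "\n",
    "# A fresh thread starts with its own, empty counters.\n",
    "other_thread_names = []\n",
    "worker = threading.Thread(\n",
    "    target=lambda: other_thread_names.append(unique_name_sketch('layer'))\n",
    ")\n",
    "worker.start()\n",
    "worker.join()\n",
    "print(other_thread_names)  # ['layer0']"
   ],
   "id": "8e41b7c05a2d9f63"
  },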
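  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## 2. Memory housekeeping\n",
    "\n",
    "`clear_buffer_memory` releases cached device buffers and compilation artifacts\n",
    "between experiments and runs Python's garbage collector. With its default\n",
    "settings it deletes every live device array, which invalidates any JAX arrays\n",
    "still in use; passing `array=False` skips that step, so this example leaves the\n",
    "arrays created elsewhere in the notebook intact."
   ],
   "id": "61af2e565d430bb4"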
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# array=False keeps live device arrays; the remaining cleanup still runs.\n",
    "clear_buffer_memory(array=False)\n",
    "print('Cleanup finished; live device arrays were left untouched.')"
   ],
   "id": "dbfe527da19a5381"
  },
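  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Why `array=False` matters can be seen with plain JAX. This sketch does not call\n",
    "`clear_buffer_memory`; it frees one array by hand with `jax.Array.delete()` and\n",
    "assumes the default `array=True` path releases buffers the same way, which would\n",
    "leave every affected Python array unusable."
   ],
   "id": "c72e19a4f5b0d831"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import gc\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "scratch = jnp.arange(4.0)\n",
    "print('live arrays on the default backend:', len(jax.live_arrays()))\n",
    "\n",
    "# Free the device buffer behind `scratch` explicitly.\n",
    "scratch.delete()\n",
    "print('deleted:', scratch.is_deleted())  # deleted: True\n",
    "\n",
    "# Any further use of the Python handle now fails.\n",
    "try:\n",
    "    scratch + 1.0\n",
    "except RuntimeError as err:\n",
    "    print('RuntimeError:', err)\n",
    "\n",
    "# Python-side garbage collection is a separate step.\n",
    "print('objects collected by gc:', gc.collect())"
   ],
   "id": "5a0d3e8c91f7b246"
  },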
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## 3. Managing collections with `DictManager`\n",
    "\n",
    "`DictManager` extends the standard mapping interface with filters, splits,\n",
    "combination operators, and JAX PyTree support."
   ],
   "id": "e03ca8e3bd2469d1"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "modules = DictManager({\n",
    "    'encoder': {'params': 32},\n",
    "    'decoder': {'params': 45},\n",
    "    'dropout': 0.1,\n",
    "})\n",
    "print('original:', modules)\n",
    "\n",
    "# Filter only submodules (dict instances)\n",
    "submods = modules.subset(dict)\n",
    "print('subset:', submods)\n",
    "\n",
    "# Split by type: dict entries vs everything else\n",
    "dicts, remainder = modules.split(dict)\n",
    "print('split dicts:', dicts)\n",
    "print('split remainder:', remainder)\n",
    "\n",
    "# Map over values to extract parameter counts\n",
    "param_counts = submods.map_values(lambda layer: layer['params'])\n",
    "param_counts"
   ],
   "id": "2b1845081b284dd0"
  },
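  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "The type-based `subset` and `split` calls can be pictured with plain dictionaries.\n",
    "`split_by_type` is a hypothetical helper, not a `brainstate` function, and its\n",
    "rules (each value goes to the first matching type, unmatched values form a final\n",
    "remainder) are assumptions about how `DictManager.split` behaves rather than a\n",
    "description of its implementation."
   ],
   "id": "e9b4716c2d05af38"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from typing import Any\n",
    "\n",
    "\n",
    "def split_by_type(mapping: dict[str, Any], *types: type) -> tuple[dict[str, Any], ...]:\n",
    "    # One bucket per requested type, plus a trailing bucket for everything else.\n",
    "    buckets: list[dict[str, Any]] = [{} for _ in types]\n",
    "    remainder: dict[str, Any] = {}\n",
    "    for key, value in mapping.items():\n",
    "        for bucket, kind in zip(buckets, types):\n",
    "            if isinstance(value, kind):\n",
    "                bucket[key] = value\n",
    "                break  # the first matching type wins\n",
    "        else:\n",
    "            remainder[key] = value\n",
    "    return (*buckets, remainder)\n",
    "\n",
    "\n",
    "plain_modules = {\n",
    "    'encoder': {'params': 32},\n",
    "    'decoder': {'params': 45},\n",
    "    'dropout': 0.1,\n",
    "}\n",
    "dict_part, rest_part = split_by_type(plain_modules, dict)\n",
    "print('dicts:', dict_part)\n",
    "print('rest:', rest_part)  # rest: {'dropout': 0.1}\n",
    "\n",
    "# Mapping over the values mirrors the map_values step above.\n",
    "print({name: layer['params'] for name, layer in dict_part.items()})  # {'encoder': 32, 'decoder': 45}"
   ],
   "id": "14f8c3a9e6d27b50"
  },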
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## 4. Configuration access via `DotDict`\n",
    "\n",
    "`DotDict` exposes nested dictionary keys as attributes (`config.model.hidden`),\n",
    "accepts new keys through attribute assignment, and converts back to standard\n",
    "dicts with `to_dict()` when needed."
   ],
   "id": "ae2524865e01b520"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "config = DotDict({\n",
    "    'model': {\n",
    "        'layers': 4,\n",
    "        'hidden': 256,\n",
    "    },\n",
    "    'training': {\n",
    "        'lr': 3e-4,\n",
    "        'scheduler': {'warmup_steps': 500},\n",
    "    },\n",
    "})\n",
    "\n",
    "print('hidden units:', config.model.hidden)\n",
    "config.training.dropout = 0.2\n",
    "print('with dropout:', config.training.dropout)\n",
    "\n",
    "round_trip = config.to_dict()\n",
    "round_trip"
   ],
   "id": "6386affe6b43cc23"
  },
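  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "A minimal sketch of the attribute-access pattern follows. `MiniDotDict` is\n",
    "hypothetical and far simpler than `DotDict`: it wraps nested dictionaries on\n",
    "construction and assignment and converts back with `to_dict()`, but it has no\n",
    "freezing, copying, or PyTree behavior."
   ],
   "id": "b6d20e95a8c1f374"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from typing import Any, Optional\n",
    "\n",
    "\n",
    "class MiniDotDict(dict):\n",
    "    # Hypothetical sketch: a dict whose keys can also be read and set as attributes.\n",
    "\n",
    "    def __init__(self, data: Optional[dict[str, Any]] = None):\n",
    "        super().__init__()\n",
    "        for key, value in (data or {}).items():\n",
    "            self[key] = value\n",
    "\n",
    "    def __setitem__(self, key: str, value: Any) -> None:\n",
    "        # Wrap nested plain dicts so chained attribute access keeps working.\n",
    "        if isinstance(value, dict) and not isinstance(value, MiniDotDict):\n",
    "            value = MiniDotDict(value)\n",
    "        super().__setitem__(key, value)\n",
    "\n",
    "    def __getattr__(self, name: str) -> Any:\n",
    "        # Only called when normal attribute lookup fails.\n",
    "        try:\n",
    "            return self[name]\n",
    "        except KeyError as err:\n",
    "            raise AttributeError(name) from err\n",
    "\n",
    "    def __setattr__(self, name: str, value: Any) -> None:\n",
    "        self[name] = value\n",
    "\n",
    "    def to_dict(self) -> dict[str, Any]:\n",
    "        # Recursively convert back to plain dicts.\n",
    "        return {\n",
    "            key: value.to_dict() if isinstance(value, MiniDotDict) else value\n",
    "            for key, value in self.items()\n",
    "        }\n",
    "\n",
    "\n",
    "cfg_sketch = MiniDotDict({'model': {'hidden': 256}, 'training': {'lr': 3e-4}})\n",
    "cfg_sketch.training.dropout = 0.2\n",
    "print(cfg_sketch.model.hidden)  # 256\n",
    "print(cfg_sketch.to_dict())\n",
    "# {'model': {'hidden': 256}, 'training': {'lr': 0.0003, 'dropout': 0.2}}"
   ],
   "id": "7c35a1f0e49b82d6"
  },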
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## 5. Dictionary utilities\n",
    "\n",
    "`merge_dicts` combines dictionaries, with values from later arguments (here\n",
    "`override`) replacing earlier ones; nested dictionaries can optionally be\n",
    "merged recursively. `flatten_dict` and `unflatten_dict` convert between nested\n",
    "and dotted-key representations, which is useful for logging or CLI overrides."
   ],
   "id": "ed23ad11f6870269"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "base = {'optimizer': {'lr': 1e-3, 'beta1': 0.9}}\n",
    "override = {'optimizer': {'lr': 5e-4}, 'seed': 1234}\n",
    "merged = merge_dicts(base, override)\n",
    "print('merged:', merged)\n",
    "\n",
    "flat = flatten_dict(merged)\n",
    "print('flattened:', flat)\n",
    "unflatten_dict(flat)"
   ],
   "id": "5b53201db5c49aa8"
  },
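  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "The plain-Python sketch below spells out the merge and flatten semantics described\n",
    "above. `deep_merge`, `flatten_dotted`, and `unflatten_dotted` are hypothetical; the\n",
    "`.` separator and the rule that later dictionaries win are assumptions, and the\n",
    "real `flatten_dict` may use a different key format or separator."
   ],
   "id": "d18f6b2c7a9e0345"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from typing import Any\n",
    "\n",
    "\n",
    "def deep_merge(*dicts: dict[str, Any]) -> dict[str, Any]:\n",
    "    # Later dictionaries win; nested dicts are merged key by key.\n",
    "    out: dict[str, Any] = {}\n",
    "    for d in dicts:\n",
    "        for key, value in d.items():\n",
    "            if isinstance(value, dict) and isinstance(out.get(key), dict):\n",
    "                out[key] = deep_merge(out[key], value)\n",
    "            else:\n",
    "                out[key] = value\n",
    "    return out\n",
    "\n",
    "\n",
    "def flatten_dotted(tree: dict[str, Any], prefix: str = '') -> dict[str, Any]:\n",
    "    # Assumes no key itself contains a '.'.\n",
    "    flat_out: dict[str, Any] = {}\n",
    "    for key, value in tree.items():\n",
    "        name = f'{prefix}{key}'\n",
    "        if isinstance(value, dict):\n",
    "            flat_out.update(flatten_dotted(value, prefix=name + '.'))\n",
    "        else:\n",
    "            flat_out[name] = value\n",
    "    return flat_out\n",
    "\n",
    "\n",
    "def unflatten_dotted(flat_in: dict[str, Any]) -> dict[str, Any]:\n",
    "    tree: dict[str, Any] = {}\n",
    "    for dotted, value in flat_in.items():\n",
    "        *parents, leaf = dotted.split('.')\n",
    "        node = tree\n",
    "        for part in parents:\n",
    "            node = node.setdefault(part, {})\n",
    "        node[leaf] = value\n",
    "    return tree\n",
    "\n",
    "\n",
    "merged_sketch = deep_merge(\n",
    "    {'optimizer': {'lr': 1e-3, 'beta1': 0.9}},\n",
    "    {'optimizer': {'lr': 5e-4}, 'seed': 1234},\n",
    ")\n",
    "flat_sketch = flatten_dotted(merged_sketch)\n",
    "print(flat_sketch)  # {'optimizer.lr': 0.0005, 'optimizer.beta1': 0.9, 'seed': 1234}\n",
    "print(unflatten_dotted(flat_sketch) == merged_sketch)  # True"
   ],
   "id": "2e7a940bd6c5f1a8"
  },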
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## 6. Structured PyTrees with `util.struct`\n",
    "\n",
    "The `struct` submodule provides Flax-style data structures. The `dataclass`\n",
    "decorator registers classes as PyTrees, `field(pytree_node=False)` marks\n",
    "static metadata that is excluded from the PyTree leaves, and `FrozenDict`\n",
    "provides immutable mappings compatible with JAX transformations."
   ],
   "id": "51c3f4b8ccd032c5"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "@struct.dataclass\n",
    "class LayerConfig:\n",
    "    weight: jax.Array\n",
    "    bias: jax.Array\n",
    "    name: str = struct.field(pytree_node=False, default='layer')\n",
    "\n",
    "cfg = LayerConfig(weight=jnp.ones((2, 2)), bias=jnp.zeros(2))\n",
    "print(cfg)\n",
    "\n",
    "cfg2 = cfg.replace(weight=jnp.full((2, 2), 3.0))\n",
    "print('updated weight:', cfg2.weight)\n",
    "\n",
    "flat_leaves, _ = jax.tree_util.tree_flatten(cfg)\n",
    "print('pytree leaves:', [leaf.shape for leaf in flat_leaves])\n",
    "\n",
    "frozen = struct.freeze({'encoder': jnp.arange(3)})\n",
    "print('frozen dict:', frozen)\n",
    "print('unfrozen:', struct.unfreeze(frozen))"
   ],
   "id": "8f930f5513453072"
  },
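  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "For comparison, the same weight/bias/name split can be expressed with core JAX\n",
    "alone. This sketch registers a frozen standard-library dataclass with\n",
    "`jax.tree_util.register_pytree_node_class` and keeps `name` in the static\n",
    "auxiliary data. It illustrates what `pytree_node=False` is meant to achieve; it\n",
    "is not how `struct.dataclass` is implemented."
   ],
   "id": "f04c8d3e1b7a5962"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import dataclasses\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "@jax.tree_util.register_pytree_node_class\n",
    "@dataclasses.dataclass(frozen=True)\n",
    "class LayerParams:\n",
    "    weight: jax.Array\n",
    "    bias: jax.Array\n",
    "    name: str = 'layer'  # static metadata, not a PyTree leaf\n",
    "\n",
    "    def tree_flatten(self):\n",
    "        # Children become leaves; `name` travels as auxiliary (static) data.\n",
    "        return (self.weight, self.bias), self.name\n",
    "\n",
    "    @classmethod\n",
    "    def tree_unflatten(cls, name, children):\n",
    "        weight, bias = children\n",
    "        return cls(weight=weight, bias=bias, name=name)\n",
    "\n",
    "\n",
    "layer_params = LayerParams(weight=jnp.ones((2, 2)), bias=jnp.zeros(2))\n",
    "print([leaf.shape for leaf in jax.tree_util.tree_leaves(layer_params)])  # [(2, 2), (2,)]\n",
    "\n",
    "# tree_map transforms the leaves and carries `name` through unchanged.\n",
    "doubled = jax.tree_util.tree_map(lambda x: 2 * x, layer_params)\n",
    "print(doubled.name, doubled.weight)\n",
    "\n",
    "# dataclasses.replace plays the role of `cfg.replace(...)` above.\n",
    "renamed = dataclasses.replace(layer_params, name='encoder')\n",
    "print(renamed.name)  # encoder"
   ],
   "id": "9b1e5a7f3c08d4e2"
  },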
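  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## 7. Filtering nested objects\n",
    "\n",
    "`brainstate.util.filter` turns declarative filters into predicates called as\n",
    "`predicate(path, value)`, where `path` is a tuple of keys. Tag, type, and path\n",
    "checks can be combined; here a string passed to `to_predicate` acts as a tag\n",
    "check, and `All` joins a type check with a tag check."
   ],
   "id": "a3859603f03f9387"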
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "class Module:\n",
    "    def __init__(self, tag: str | None, kind: str):\n",
    "        self.tag = tag\n",
    "        self.kind = kind\n",
    "        self.params = jnp.arange(2)\n",
    "\n",
    "model_tree = {\n",
    "    'encoder': Module(tag='trainable', kind='linear'),\n",
    "    'decoder': Module(tag='frozen', kind='linear'),\n",
    "    'head': Module(tag='trainable', kind='mlp'),\n",
    "}\n",
    "\n",
    "tag_filter = util_filter.to_predicate('trainable')\n",
    "type_filter = util_filter.OfType(Module)\n",
    "combined = util_filter.All(type_filter, util_filter.WithTag('trainable'))\n",
    "\n",
    "def collect(tree: dict[str, Any], predicate) -> dict[str, Any]:\n",
    "    # Apply a (path, value) predicate to the top-level entries of `tree`.\n",
    "    out = {}\n",
    "    for key, value in tree.items():\n",
    "        if predicate((key,), value):\n",
    "            out[key] = value\n",
    "    return out\n",
    "\n",
    "trainable_modules = collect(model_tree, tag_filter)\n",
    "trainable_of_type = collect(model_tree, combined)  # filters are callables\n",
    "print('tag matches:', tuple(trainable_modules.keys()))\n",
    "print('trainable Module instances:', tuple(trainable_of_type.keys()))"
   ],
   "id": "bf3c98ba145d1eb"
  },
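  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "The same `(path, value)` convention extends to deeper trees. The sketch below\n",
    "uses hypothetical stand-ins (`with_tag`, `of_type`, `path_contains`, `all_of`,\n",
    "`walk`) instead of the `brainstate.util.filter` classes, so it runs without the\n",
    "library; only the calling convention is taken from the example above."
   ],
   "id": "a7d3f9e0c2b61458"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from typing import Any, Callable, Iterator\n",
    "\n",
    "Predicate = Callable[[tuple, Any], bool]\n",
    "\n",
    "\n",
    "def with_tag(tag: str) -> Predicate:\n",
    "    return lambda path, x: getattr(x, 'tag', None) == tag\n",
    "\n",
    "\n",
    "def of_type(kind: type) -> Predicate:\n",
    "    return lambda path, x: isinstance(x, kind)\n",
    "\n",
    "\n",
    "def path_contains(key: str) -> Predicate:\n",
    "    return lambda path, x: key in path\n",
    "\n",
    "\n",
    "def all_of(*preds: Predicate) -> Predicate:\n",
    "    return lambda path, x: all(pred(path, x) for pred in preds)\n",
    "\n",
    "\n",
    "def walk(node: Any, path: tuple = ()) -> Iterator[tuple[tuple, Any]]:\n",
    "    # Yield (path, value) for every node, recursing into nested dicts.\n",
    "    yield path, node\n",
    "    if isinstance(node, dict):\n",
    "        for key, value in node.items():\n",
    "            yield from walk(value, path + (key,))\n",
    "\n",
    "\n",
    "class Tagged:\n",
    "    def __init__(self, tag: str):\n",
    "        self.tag = tag\n",
    "\n",
    "\n",
    "nested_modules = {\n",
    "    'encoder': {'attn': Tagged('trainable'), 'norm': Tagged('frozen')},\n",
    "    'decoder': {'attn': Tagged('trainable')},\n",
    "}\n",
    "selected = all_of(of_type(Tagged), with_tag('trainable'), path_contains('encoder'))\n",
    "matches = [path for path, value in walk(nested_modules) if selected(path, value)]\n",
    "print(matches)  # [('encoder', 'attn')]"
   ],
   "id": "6f2b0c84e9d7a13c"
  },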
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## 8. Pretty PyTree containers\n",
    "\n",
    "`NestedDict`, `FlattedDict`, and `PrettyList` combine readable reprs with PyTree\n",
    "semantics, and `flat_mapping`/`nest_mapping` convert between nested and flat\n",
    "(path-keyed) layouts. Use them to explore checkpoints or log structured configs."
   ],
   "id": "bdec786545b99fa"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from brainstate.util import NestedDict, flat_mapping, nest_mapping, PrettyList\n",
    "\n",
    "state = NestedDict({\n",
    "    'encoder': {'weight': jnp.ones((2, 2)), 'bias': jnp.zeros(2)},\n",
    "    'decoder': {'weight': jnp.eye(2)},\n",
    "})\n",
    "print(state)\n",
    "\n",
    "flat_state = flat_mapping(state)\n",
    "print('flat keys:', list(flat_state.keys()))\n",
    "\n",
    "round_trip = nest_mapping(flat_state)\n",
    "print('round-trip equal:', round_trip == state)\n",
    "\n",
    "history = PrettyList([{'loss': 0.8}, {'loss': 0.42}])\n",
    "print(history)"
   ],
   "id": "bb5052af563dc593"
  },
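  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Path-keyed flat views can also be produced with JAX key paths. This sketch uses\n",
    "`jax.tree_util.tree_flatten_with_path`; the tuple-of-keys layout is an\n",
    "assumption chosen to resemble `flat_mapping`, not a statement of its exact output."
   ],
   "id": "0d9e4b6a1f38c7e5"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "nested_state = {\n",
    "    'encoder': {'weight': jnp.ones((2, 2)), 'bias': jnp.zeros(2)},\n",
    "    'decoder': {'weight': jnp.eye(2)},\n",
    "}\n",
    "\n",
    "# Each key path is a tuple of DictKey entries; keep only the raw keys.\n",
    "with_paths, _ = jax.tree_util.tree_flatten_with_path(nested_state)\n",
    "flat_by_path = {tuple(entry.key for entry in path): leaf for path, leaf in with_paths}\n",
    "print(list(flat_by_path))  # [('decoder', 'weight'), ('encoder', 'bias'), ('encoder', 'weight')]\n",
    "\n",
    "# Rebuild the nested layout from the path-keyed view.\n",
    "rebuilt: dict = {}\n",
    "for path, leaf in flat_by_path.items():\n",
    "    node = rebuilt\n",
    "    for key in path[:-1]:\n",
    "        node = node.setdefault(key, {})\n",
    "    node[path[-1]] = leaf\n",
    "\n",
    "same = jax.tree_util.tree_map(lambda a, b: bool(jnp.array_equal(a, b)), nested_state, rebuilt)\n",
    "print('round-trip equal:', jax.tree_util.tree_all(same))  # round-trip equal: True"
   ],
   "id": "4c6a8f2d0b9e7153"
  },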
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Summary\n",
    "\n",
    "- Use `split_total` to size work budgets and `get_unique_name` to generate\n",
    "  distinct names when coordinating experiments.\n",
    "- Reach for `DictManager` and `DotDict` to manage nested collections.\n",
    "- Merge configs with `merge_dicts`, and convert between nested and flat\n",
    "  layouts with `flatten_dict` and `unflatten_dict`.\n",
    "- Wrap structured data using `util.struct` and leverage filter/pretty utilities\n",
    "  when exploring PyTrees."
   ],
   "id": "5ada76ab8937794f"
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
