{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5h4hj579xfl",
   "source": "# Hidden State Management\n\nIn recurrent neural networks and spiking neural networks, **hidden states** are the recurrent state variables that carry information across time steps. Examples include membrane potentials, adaptation currents, and synaptic conductances.\n\nbraintrace's compiler **automatically discovers** hidden states in your model by tracing the JAX intermediate representation (Jaxpr). It identifies which state variables are both read and written during a forward pass, then groups related hidden states into **hidden groups** for efficient Jacobian computation during online learning.\n\nThis tutorial covers:\n\n1. The three hidden state types provided by `brainstate`\n2. How the compiler discovers and groups hidden states\n3. State initialization and batching\n4. How hidden states interact with online learning algorithms",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "id": "b7biugn65oc",
   "source": "import jax\nimport jax.numpy as jnp\nimport brainstate\nimport braintrace",
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "1djmk4705si",
   "source": "## Hidden State Types\n\n`brainstate` provides three hidden state classes, each suited to different model architectures:\n\n| Type | Use Case | Example |\n|------|----------|---------|\n| `brainstate.HiddenState` | Single state variable | Membrane potential of a LIF neuron |\n| `brainstate.HiddenGroupState` | Multiple correlated states with the same shape | Voltage *V* and adaptation current *I* in an adaptive neuron |\n| `brainstate.HiddenTreeState` | Hierarchical / heterogeneous state structures | LSTM cell state and hidden state, or a dict of named states |\n\nAll three are subclasses of `brainstate.HiddenState`. The compiler treats them uniformly when discovering recurrent dependencies -- you choose the one that best matches your model's structure.",
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "id": "8t2k8qttxub",
   "source": "### `HiddenState`: Single State Variable\n\n`brainstate.HiddenState` manages exactly one state tensor. This is the simplest and most common case -- use it when your neuron or synapse has a single recurrent variable (e.g., membrane potential).",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "id": "5nbmdy6vxls",
   "source": "class SimpleNeuron(brainstate.nn.Module):\n    \"\"\"A minimal recurrent neuron with a single hidden state.\"\"\"\n\n    def __init__(self, size):\n        super().__init__()\n        self.w = brainstate.ParamState(brainstate.random.randn(size, size) * 0.01)\n        self.h = brainstate.HiddenState(jnp.zeros(size))\n\n    def update(self, x):\n        # braintrace.matmul marks w as participating in online learning\n        self.h.value = jax.nn.tanh(x + braintrace.matmul(self.h.value, self.w.value))\n        return self.h.value\n\n\n# Create the model and initialize states\nmodel_simple = SimpleNeuron(32)\nbrainstate.nn.init_all_states(model_simple)\n\nprint(f\"Hidden state shape: {model_simple.h.value.shape}\")\nprint(f\"Number of state variables: {model_simple.h.num_state}\")",
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "6fr14i7akov",
   "source": "### `HiddenGroupState`: Multiple Correlated States\n\nWhen a neuron has multiple state variables that are correlated and share the same shape, use `brainstate.HiddenGroupState`. This tells the compiler that these states form a single group -- their Jacobians should be computed together.\n\nA common example is an adaptive neuron with both a membrane voltage *V* and an adaptation current *I*.",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "id": "efoif7n1wa",
   "source": "class AdaptiveNeuron(brainstate.nn.Module):\n    \"\"\"A neuron with two correlated hidden states: voltage V and adaptation current I.\"\"\"\n\n    def __init__(self, size):\n        super().__init__()\n        self.w = brainstate.ParamState(brainstate.random.randn(size, size) * 0.01)\n        # Two correlated states bundled into one HiddenGroupState\n        self.state = brainstate.HiddenGroupState(\n            V=jnp.zeros(size),\n            I=jnp.zeros(size),\n        )\n\n    def update(self, x):\n        V, I = self.state['V'], self.state['I']\n        new_V = 0.9 * V + x + braintrace.matmul(V, self.w.value) - I\n        new_I = 0.95 * I + 0.1 * V\n        self.state.value = dict(V=new_V, I=new_I)\n        return new_V\n\n\nmodel_adaptive = AdaptiveNeuron(32)\nbrainstate.nn.init_all_states(model_adaptive)\n\nprint(f\"Number of states in group: {model_adaptive.state.num_state}\")\nprint(f\"State shape: {model_adaptive.state.varshape}\")",
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
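  {
   "cell_type": "markdown",
   "id": "adaptive-jacobian-sketch-md",
   "source": "Why do *V* and *I* belong in one group? The plain-JAX sketch below is our own illustration, not braintrace's internal code. It rewrites the `AdaptiveNeuron` update as a pure function and computes its Jacobian with `jax.jacfwd`. For each neuron, the 2-by-2 block relating the new $(V, I)$ to the previous $(V, I)$ has nonzero off-diagonal entries: $-1$ from *I* into *V*, and $0.1$ from *V* into *I*. As a result, the Jacobian of one state cannot be computed without the other. The weight matrix `w` and the size `N` are hypothetical values chosen for the example.",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "id": "adaptive-jacobian-sketch-code",
   "source": "import jax\nimport jax.numpy as jnp\n\nN = 3  # hypothetical number of neurons\nw = jnp.eye(N) * 0.1 + 0.01  # hypothetical recurrent weights\n\n\ndef adaptive_step(VI, x):\n    # Same update equations as AdaptiveNeuron.update, written as a pure function.\n    # VI has shape (2, N): row 0 holds V, row 1 holds I.\n    V, I = VI[0], VI[1]\n    new_V = 0.9 * V + x + V @ w - I\n    new_I = 0.95 * I + 0.1 * V\n    return jnp.stack([new_V, new_I])\n\n\nVI0 = jnp.zeros((2, N))\nx = jnp.ones(N)\n\n# J[a, i, b, j] = d new_state[a, i] / d old_state[b, j]\nJ = jax.jacfwd(adaptive_step)(VI0, x)\n\n# Per-neuron block: how neuron i's new (V, I) depends on its own previous (V, I)\nblocks = jnp.stack([J[:, i, :, i] for i in range(N)])\nprint('Jacobian block of neuron 0:')\nprint(blocks[0])",
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },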
  {
   "cell_type": "markdown",
   "id": "yo3v5c5b08c",
   "source": "### `HiddenTreeState`: Hierarchical State Structures\n\n`brainstate.HiddenTreeState` supports arbitrary PyTree structures (dicts, lists, nested containers). Use it when your model has many state variables that you want to organize hierarchically, or when different states have different shapes.\n\nFor instance, a GIF (Generalized Integrate-and-Fire) neuron has four state variables: two adaptation currents $I_1$ and $I_2$, a membrane potential $V$, and a dynamic threshold $V_{th}$. The `TreeNeuron` below stores these four variables in a dict. Its update rule is a simplified, GIF-like toy rather than the full GIF dynamics.",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "id": "7pwbcriumje",
   "source": "class TreeNeuron(brainstate.nn.Module):\n    \"\"\"A neuron using HiddenTreeState for hierarchical state management.\"\"\"\n\n    def __init__(self, size):\n        super().__init__()\n        self.w = brainstate.ParamState(brainstate.random.randn(size, size) * 0.01)\n        # Four GIF-like state variables organized in a dict tree\n        self.state = brainstate.HiddenTreeState({\n            'I1': jnp.zeros(size),\n            'I2': jnp.zeros(size),\n            'V': jnp.zeros(size),\n            'V_th': jnp.ones(size),\n        })\n\n    def update(self, x):\n        I1 = self.state['I1']\n        I2 = self.state['I2']\n        V = self.state['V']\n        V_th = self.state['V_th']\n\n        new_I1 = 0.9 * I1\n        new_I2 = 0.95 * I2\n        new_V = 0.8 * V + x + braintrace.matmul(V, self.w.value) + I1 + I2\n        new_V_th = 0.99 * V_th + 0.01 * V\n\n        self.state.value = dict(I1=new_I1, I2=new_I2, V=new_V, V_th=new_V_th)\n        return new_V\n\n\nmodel_tree = TreeNeuron(32)\nbrainstate.nn.init_all_states(model_tree)\n\nprint(f\"Number of state variables in tree: {model_tree.state.num_state}\")",
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
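  {
   "cell_type": "markdown",
   "id": "tree-state-pytree-sketch-md",
   "source": "`HiddenTreeState` builds on JAX's PyTree machinery. The sketch below is a rough illustration that uses only `jax.tree_util`; it does not show how brainstate stores the tree internally. It flattens a nested dict of GIF-like states into leaves and prints each leaf's path. It then stacks the equal-shape leaves along a trailing axis (one slot per state variable) and rebuilds the nested dict. Grouping the two currents under `'adapt'` is a hypothetical layout chosen for the example.",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "id": "tree-state-pytree-sketch-code",
   "source": "import jax\nimport jax.numpy as jnp\n\nsize = 4\n# Hypothetical nested layout: the two currents are grouped under 'adapt'\ngif_states = {\n    'adapt': {'I1': jnp.zeros(size), 'I2': jnp.zeros(size)},\n    'V': jnp.zeros(size),\n    'V_th': jnp.ones(size),\n}\n\n# Flatten to (path, leaf) pairs; dict keys are visited in sorted order\npath_leaves, treedef = jax.tree_util.tree_flatten_with_path(gif_states)\nfor path, leaf in path_leaves:\n    print(jax.tree_util.keystr(path), leaf.shape)\nprint('number of state variables:', treedef.num_leaves)\n\n# Equal-shape leaves can be stacked along a trailing axis: shape (size, num_state)\nleaves = [leaf for _, leaf in path_leaves]\nstacked = jnp.stack(leaves, axis=-1)\nprint('stacked shape:', stacked.shape)\n\n# The same treedef rebuilds the nested dict from the stacked slots\nrestored = jax.tree_util.tree_unflatten(\n    treedef, [stacked[..., k] for k in range(stacked.shape[-1])]\n)\nprint('restored keys:', sorted(restored))",
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },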
  {
   "cell_type": "markdown",
   "id": "jnr84irdsnq",
   "source": "## How the Compiler Discovers Hidden States\n\nWhen you compile a model for online learning, braintrace performs the following steps:\n\n1. **Trace the Jaxpr**: The model's `update` method is traced to produce a JAX intermediate representation (Jaxpr).\n2. **Identify recurrent states**: The compiler finds state variables that appear as both inputs (read) and outputs (written) in the Jaxpr -- these are the hidden states.\n3. **Group by data flow**: States that are connected through data flow dependencies are placed into the same **hidden group**. Each group gets its own transition Jaxpr for computing the hidden-to-hidden Jacobian $\\frac{\\partial h^t}{\\partial h^{t-1}}$.\n\nYou can inspect the discovered hidden groups using `braintrace.find_hidden_groups_from_module()`.",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "id": "dxorizb39zp",
   "source": "# Inspect hidden groups for the SimpleNeuron model\nmodel_simple = SimpleNeuron(32)\nbrainstate.nn.init_all_states(model_simple)\n\ngroups, path_map = braintrace.find_hidden_groups_from_module(model_simple, jnp.zeros(32))\n\nfor g in groups:\n    print(f\"Group {g.index}:\")\n    print(f\"  Hidden state paths: {g.hidden_paths}\")\n    print(f\"  Number of states:   {g.num_state}\")\n    print(f\"  State shape:        {g.varshape}\")\n    print()",
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "skqcc7kxo29",
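   "source": "In the `AdaptiveNeuron`, *V* and *I* are declared in a single `HiddenGroupState` and are coupled through the update equations (*I* enters the update of *V*, and *V* enters the update of *I*). The compiler therefore places them in the same hidden group:",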
   "metadata": {}
  },
  {
   "cell_type": "code",
   "id": "iu6bggu8lyb",
   "source": "# Inspect hidden groups for the AdaptiveNeuron model\nmodel_adaptive = AdaptiveNeuron(32)\nbrainstate.nn.init_all_states(model_adaptive)\n\ngroups, path_map = braintrace.find_hidden_groups_from_module(model_adaptive, jnp.zeros(32))\n\nfor g in groups:\n    print(f\"Group {g.index}:\")\n    print(f\"  Hidden state paths: {g.hidden_paths}\")\n    print(f\"  Number of states:   {g.num_state}\")\n    print(f\"  State shape:        {g.varshape}\")\n    print()",
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
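  {
   "cell_type": "markdown",
   "id": "read-written-sketch-md",
   "source": "Steps 1 and 2 of the discovery procedure can be imitated in a few lines of plain JAX. The sketch below is our own simplification, not braintrace's compiler, and `step` and `find_read_and_written` are hypothetical names. It writes one time step as a pure function over a dict of states and traces it with `jax.make_jaxpr`. It then keeps the states that meet two conditions: their input variable is consumed by some equation, and their output is a newly computed value rather than the input passed through unchanged. Here only `'h'` qualifies. `'w'` and `'bias'` are read but never written, much like a `ParamState` in a forward pass.",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "id": "read-written-sketch-code",
   "source": "import jax\nimport jax.numpy as jnp\n\n\ndef step(states, x):\n    # Pure-function stand-in for a module's update():\n    # 'h' is read and rewritten; 'w' and 'bias' are only read.\n    h_new = jnp.tanh(x + states['h'] @ states['w'] + states['bias'])\n    return {'h': h_new, 'w': states['w'], 'bias': states['bias']}\n\n\ndef find_read_and_written(fn, states, x):\n    # Step 1: trace fn to a Jaxpr.\n    jaxpr = jax.make_jaxpr(fn)(states, x).jaxpr\n    names = sorted(states)  # dict leaves are flattened in sorted-key order\n    state_invars = jaxpr.invars[:len(names)]\n    # Ids of every variable consumed by at least one equation\n    used = {id(v) for eqn in jaxpr.eqns for v in eqn.invars}\n    # Step 2: keep states that are both read and written.\n    hidden = []\n    for name, var_in, var_out in zip(names, state_invars, jaxpr.outvars):\n        is_read = id(var_in) in used\n        is_written = var_out is not var_in  # a pass-through output reuses the input variable\n        if is_read and is_written:\n            hidden.append(name)\n    return hidden\n\n\nstates = {'h': jnp.zeros(4), 'w': 0.5 * jnp.eye(4), 'bias': jnp.ones(4)}\nx = jnp.ones(4)\n\nprint(jax.make_jaxpr(step)(states, x))\nprint('hidden states:', find_read_and_written(step, states, x))",
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },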
  {
   "cell_type": "markdown",
   "id": "5bokhmcta94",
   "source": "## State Initialization and Batching\n\nbraintrace relies on `brainstate.nn.init_all_states()` to initialize all hidden states in a model. There are two main approaches:\n\n- **Single-sample initialization**: `brainstate.nn.init_all_states(model)` -- state tensors have shape `(M,)`, where `M` is the number of units (e.g., `32` for `SimpleNeuron(32)`).\n- **Batched initialization**: `brainstate.nn.init_all_states(model, batch_size=N)` -- state tensors have shape `(N, M)`, where `N` is the batch size. Use this for manual batching, where the model's own operations handle the leading batch dimension.\n\nFor automatic batching with `vmap`, use `brainstate.transform.vmap_new_states` instead. It initializes per-sample states while the model definition stays single-sample.",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "id": "e0e6zy4v4aw",
   "source": "model = SimpleNeuron(32)\n\n# --- Single-sample initialization ---\nbrainstate.nn.init_all_states(model)\nprint(\"Single-sample h shape:\", model.h.value.shape)\n\n# --- Batched initialization (manual batching) ---\nbrainstate.nn.init_all_states(model, batch_size=16)\nprint(\"Batched h shape:      \", model.h.value.shape)",
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "id": "raeddpwnz5o",
   "source": "# --- Automatic batching with vmap_new_states ---\nmodel = SimpleNeuron(32)\n\n@brainstate.transform.vmap_new_states(state_tag='new', axis_size=16)\ndef init():\n    brainstate.nn.init_all_states(model)\n\ninit()\n\n# After vmap initialization, hidden states are managed per-sample internally.\n# The model still \"thinks\" it processes a single sample, but vmap replicates\n# the computation across the batch dimension automatically.\nprint(\"After vmap_new_states, model is ready for automatic batching.\")",
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
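  {
   "cell_type": "markdown",
   "id": "vmap-batching-sketch-md",
   "source": "The idea behind automatic batching can be checked in plain JAX, independent of brainstate. Take a step written for one sample and wrap it in `jax.vmap`. It gives the same result as a hand-batched step whose state carries a leading batch axis. The sizes, random weights, and function names below are hypothetical.",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "id": "vmap-batching-sketch-code",
   "source": "import jax\nimport jax.numpy as jnp\n\nsize, batch = 8, 16\nkey_w, key_x = jax.random.split(jax.random.PRNGKey(0))\nw = jax.random.normal(key_w, (size, size)) * 0.5\nxs = jax.random.normal(key_x, (batch, size))\n\n\ndef single_step(h, x):\n    # Written for one sample: h and x both have shape (size,)\n    return jnp.tanh(x + h @ w)\n\n\n# Automatic batching: vmap maps single_step over the leading axis of h and x\nbatched_step = jax.vmap(single_step)\n\nh_manual = jnp.zeros((batch, size))  # manual batching: state shape (batch, size)\nh_auto = jnp.zeros((batch, size))\nfor _ in range(3):\n    h_manual = jnp.tanh(xs + h_manual @ w)  # batch axis handled by hand\n    h_auto = batched_step(h_auto, xs)       # batch axis added by vmap\n\nprint('state shape:', h_auto.shape)\nprint('same result:', bool(jnp.allclose(h_manual, h_auto, atol=1e-6)))",
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },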
  {
   "cell_type": "markdown",
   "id": "ze96hximbu",
   "source": "## Hidden States in Online Learning\n\nDuring online learning, the algorithm needs to track how hidden states evolve over time. Specifically, it computes:\n\n- **Hidden-to-hidden Jacobians** $\\frac{\\partial h^t}{\\partial h^{t-1}}$: How the current hidden state depends on the previous one. These drive the propagation of eligibility traces.\n- **Weight spatial gradients** $\\frac{\\partial h^t}{\\partial w}$: How the hidden state depends on each weight parameter within a single step, with $h^{t-1}$ held fixed.\n\nExact real-time recurrent learning (RTRL) tracks the sensitivity of every hidden unit to every recurrent weight. For $N$ fully connected units, this costs $O(N^3)$ memory and $O(N^4)$ time per step. Keeping only the diagonal of the hidden-to-hidden Jacobian avoids this cost and makes online learning tractable for large networks. The compiler automatically extracts the transition function from the Jaxpr and computes these Jacobians.\n\nLet us see the full pipeline: define a model, compile the online learning graph, and inspect its structure.",
   "metadata": {}
  },
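  {
   "cell_type": "markdown",
   "id": "diag-jacobian-sketch-md",
   "source": "Before running that pipeline, the two quantities defined above can be computed by hand in plain JAX for an update of the same form as `SimpleNeuron`, $h^t = \\tanh(x^t + h^{t-1} W)$. The sketch below uses hypothetical sizes and random values and is not braintrace's internal code. It takes the full Jacobian with `jax.jacrev`, keeps its diagonal, and checks the result against the closed form $(1 - (h^t)^2) \\odot \\mathrm{diag}(W)$. It also shows that the spatial gradient $\\partial h^t_j / \\partial W_{ab}$ is nonzero only when $b = j$, so it fits in an array with the same shape as $W$.",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "id": "diag-jacobian-sketch-code",
   "source": "import jax\nimport jax.numpy as jnp\n\nsize = 5\nkey_w, key_h, key_x = jax.random.split(jax.random.PRNGKey(0), 3)\nW = jax.random.normal(key_w, (size, size)) * 0.5\nh_prev = jax.random.normal(key_h, (size,))\nx = jax.random.normal(key_x, (size,))\n\n\ndef step(h, W, x):\n    # Same form as SimpleNeuron.update: h_new = tanh(x + h @ W)\n    return jnp.tanh(x + h @ W)\n\n\nh_new = step(h_prev, W, x)\ndphi = 1.0 - h_new ** 2  # derivative of tanh at the new state\n\n# Hidden-to-hidden Jacobian: J_full[j, i] = d h_new[j] / d h_prev[i]\nJ_full = jax.jacrev(step, argnums=0)(h_prev, W, x)\nJ_diag = jnp.diagonal(J_full)           # diagonal approximation keeps J_full[j, j]\nJ_diag_closed = dphi * jnp.diagonal(W)  # closed form for this update\n\n# Spatial gradient at fixed h_prev: S[j, a, b] = d h_new[j] / d W[a, b]\nS = jax.jacrev(step, argnums=1)(h_prev, W, x)\n# Only b == j is nonzero, with value dphi[j] * h_prev[a]; store it as [a, j], the shape of W\nS_compact = h_prev[:, None] * dphi[None, :]\nS_from_jax = jnp.stack([S[j, :, j] for j in range(size)], axis=1)\n\nprint('diagonal matches closed form:', bool(jnp.allclose(J_diag, J_diag_closed, atol=1e-5)))\nprint('spatial gradient matches:', bool(jnp.allclose(S_from_jax, S_compact, atol=1e-5)))",
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },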
  {
   "cell_type": "code",
   "id": "tw9s67jlxpr",
   "source": "# Complete example: model -> compile -> inspect graph structure\n\nmodel = SimpleNeuron(8)\nbrainstate.nn.init_all_states(model)\n\n# Wrap the model in the D-RTRL online learning algorithm\nalgo = braintrace.D_RTRL(model)\n\n# Compile the computation graph with a dummy input\nalgo.compile_graph(jnp.zeros(8))\n\n# Display the discovered graph structure:\n# - Which hidden groups were found\n# - Which weight parameters are associated with each group\nalgo.show_graph()",
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "58knkps468a",
   "source": "The `show_graph()` output tells you:\n\n- **Hidden groups**: Which hidden states were discovered and how they are grouped. Each group corresponds to a set of states whose Jacobian is computed together.\n- **Weight parameters**: Which `ParamState` weights are associated with each hidden group. A weight is associated with a group if two conditions hold: it is used through one of braintrace's eligibility-trace operations (such as `braintrace.matmul`), and the output of that operation is shape-compatible with the group's hidden states.\n\nLet us also inspect a more complex model with multiple hidden states:",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "id": "nj9ev53cba",
   "source": "# A two-layer recurrent network to demonstrate multi-group discovery\n\nclass TwoLayerRNN(brainstate.nn.Module):\n    \"\"\"Two stacked recurrent layers, each with its own hidden state.\"\"\"\n\n    def __init__(self, in_size, hidden_size, out_size):\n        super().__init__()\n        # Layer 1\n        self.w1_in = brainstate.ParamState(brainstate.random.randn(in_size, hidden_size) * 0.01)\n        self.w1_rec = brainstate.ParamState(brainstate.random.randn(hidden_size, hidden_size) * 0.01)\n        self.h1 = brainstate.HiddenState(jnp.zeros(hidden_size))\n\n        # Layer 2\n        self.w2_in = brainstate.ParamState(brainstate.random.randn(hidden_size, out_size) * 0.01)\n        self.w2_rec = brainstate.ParamState(brainstate.random.randn(out_size, out_size) * 0.01)\n        self.h2 = brainstate.HiddenState(jnp.zeros(out_size))\n\n    def update(self, x):\n        # Layer 1: x feeds in, h1 recurs\n        self.h1.value = jax.nn.tanh(\n            x @ self.w1_in.value + braintrace.matmul(self.h1.value, self.w1_rec.value)\n        )\n        # Layer 2: h1 feeds in, h2 recurs\n        self.h2.value = jax.nn.tanh(\n            self.h1.value @ self.w2_in.value + braintrace.matmul(self.h2.value, self.w2_rec.value)\n        )\n        return self.h2.value\n\n\nmodel_2layer = TwoLayerRNN(in_size=10, hidden_size=16, out_size=8)\nbrainstate.nn.init_all_states(model_2layer)\n\nalgo_2layer = braintrace.D_RTRL(model_2layer)\nalgo_2layer.compile_graph(jnp.zeros(10))\nalgo_2layer.show_graph()",
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "0e9nguauw0xg",
   "source": "Notice that the compiler automatically discovered two separate hidden groups (one for each layer) and correctly associated each recurrent weight with its corresponding group. The feedforward weights (`w1_in`, `w2_in`) do not appear because they use regular JAX `@` rather than `braintrace.matmul`, so they are excluded from eligibility trace propagation.\n\nThis is a key design principle: **the operation choice controls which parameters participate in online learning**, not the parameter class. Use `braintrace.matmul(h, w)` to include a weight, and `h @ w` (standard JAX) to exclude it.",
   "metadata": {}
  },
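  {
   "cell_type": "markdown",
   "id": "eligibility-trace-sketch-md",
   "source": "To make 'participating in online learning' concrete, here is a plain-JAX sketch of a diagonal-RTRL-style eligibility trace for one recurrent layer. It is a generic illustration under our own assumptions, not braintrace's `D_RTRL` implementation, and `run_with_trace` and `final_state` are hypothetical helpers. The recurrent weight `w_rec` (the one that would go through `braintrace.matmul`) carries a trace $e$ with the same shape as the weight. The trace is updated each step as $e^t_{aj} = D^t_j \\, e^{t-1}_{aj} + \\frac{\\partial h^t_j}{\\partial w_{aj}}$, where $D^t_j$ is the $j$-th diagonal entry of $\\frac{\\partial h^t}{\\partial h^{t-1}}$ and the last term is the spatial gradient. The feedforward weight `w_in` (plain `@`) gets no trace at all. As a sanity check, the sketch uses a diagonal `w_rec`. In that case the hidden-to-hidden Jacobian is exactly diagonal, so the trace should match the exact sensitivity obtained by differentiating the whole rollout.",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "id": "eligibility-trace-sketch-code",
   "source": "import jax\nimport jax.numpy as jnp\n\nin_size, size, T = 3, 4, 10\nk_in, k_x = jax.random.split(jax.random.PRNGKey(0))\nw_in = jax.random.normal(k_in, (in_size, size)) * 0.5\nxs = jax.random.normal(k_x, (T, in_size))\n\n\ndef run_with_trace(w_in, w_rec, xs):\n    # Forward pass that also carries an eligibility trace for w_rec only.\n    h = jnp.zeros(w_rec.shape[0])\n    trace = jnp.zeros_like(w_rec)  # trace[a, j] approximates d h[j] / d w_rec[a, j]\n    for x in xs:\n        h_prev = h\n        h = jnp.tanh(x @ w_in + h_prev @ w_rec)\n        dphi = 1.0 - h ** 2                          # tanh derivative\n        diag_jac = dphi * jnp.diagonal(w_rec)        # diagonal of dh_t / dh_{t-1}\n        spatial = h_prev[:, None] * dphi[None, :]    # d h_t[j] / d w_rec[a, j] at fixed h_prev\n        trace = diag_jac[None, :] * trace + spatial  # decay by the unit's own Jacobian entry\n    return h, trace\n\n\ndef final_state(w_rec, w_in, xs):\n    # Plain rollout, differentiated end to end for the exact reference.\n    h = jnp.zeros(w_rec.shape[0])\n    for x in xs:\n        h = jnp.tanh(x @ w_in + h @ w_rec)\n    return h\n\n\n# Sanity check with a diagonal w_rec, where the diagonal approximation is exact\nw_rec = jnp.diag(jnp.array([0.5, -0.3, 0.8, 0.1]))\nh_T, trace = run_with_trace(w_in, w_rec, xs)\nexact = jax.jacrev(final_state)(w_rec, w_in, xs)  # exact[j, a, b] = d h_T[j] / d w_rec[a, b]\nexact_cols = jnp.stack([exact[j, :, j] for j in range(size)], axis=1)\n\nprint('trace shape:', trace.shape, '(same as w_rec)')\nprint('matches exact sensitivity:', bool(jnp.allclose(trace, exact_cols, atol=1e-5)))",
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },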
  {
   "cell_type": "markdown",
   "id": "ph8dz1ta9i",
   "source": "## Summary\n\nThis tutorial covered the three `brainstate` hidden state types and how braintrace's compiler uses them:\n\n- **`brainstate.HiddenState`** -- for a single recurrent state variable (e.g., membrane potential). The simplest and most common choice.\n- **`brainstate.HiddenGroupState`** -- for multiple correlated states with the same shape (e.g., voltage and adaptation current). The compiler treats them as a single group.\n- **`brainstate.HiddenTreeState`** -- for hierarchical or heterogeneous state structures (e.g., dicts of named states). Supports arbitrary PyTree layouts.\n\nKey takeaways:\n\n1. **Automatic discovery**: The compiler traces the model's Jaxpr and automatically identifies which states are recurrent. Beyond declaring states with `brainstate`'s hidden state classes, no extra annotation is needed.\n2. **Grouping**: Related hidden states are grouped together for efficient Jacobian computation. `HiddenGroupState` explicitly declares a group; separate `HiddenState` variables are grouped by data flow analysis.\n3. **Operation-based selection**: Whether a weight participates in online learning depends on the operation used (`braintrace.matmul` vs. regular `@`), not on the parameter class.\n4. **Flexible initialization**: Use `init_all_states` for single-sample or manual batching, and `vmap_new_states` for automatic batching.",
   "metadata": {}
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}