{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a8c16f8a",
   "metadata": {},
   "source": [
    "# The mental model\n",
    "\n",
    "**Who it's for:** everyone, before you go deeper. The four ideas below are the whole framework in miniature — they apply identically whether you simulate biophysical networks or train spiking networks with gradients.\n",
    "\n",
    "**What you'll learn:** (1) state-based programming, (2) physical units, (3) how neurons, synapses, and projections compose, and (4) how to drive a model with `brainstate.transform`. We close with the \"two worlds, one substrate\" idea that organizes the rest of the docs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "325b9a94",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-17T09:11:51.187852Z",
     "iopub.status.busy": "2026-06-17T09:11:51.187680Z",
     "iopub.status.idle": "2026-06-17T09:11:55.342654Z",
     "shell.execute_reply": "2026-06-17T09:11:55.341581Z"
    }
   },
   "outputs": [],
   "source": [
    "import brainpy\n",
    "import brainstate\n",
    "import braintools\n",
    "import brainunit as u"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e01e4ef",
   "metadata": {},
   "source": [
    "## Idea 1 — State-based programming\n",
    "\n",
    "JAX is *functional*: transformations like JIT and autodiff want pure functions with no hidden mutable variables. But a neuron's membrane potential is mutable state that persists across time steps. `brainstate` resolves this tension by making every such variable an explicit **`State`** owned by a module, so that transformations can find it and carry its value into and out of the compiled program.\n",
    "\n",
    "You rarely create raw states by hand — models declare their own. Your job is to **construct** the model and then **initialize** its states with `brainstate.nn.init_all_states`, which allocates and resets every dynamic variable to a clean starting point."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "373b2188",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-17T09:11:55.345418Z",
     "iopub.status.busy": "2026-06-17T09:11:55.344911Z",
     "iopub.status.idle": "2026-06-17T09:11:58.153333Z",
     "shell.execute_reply": "2026-06-17T09:11:58.152328Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LIF(\n",
       "  in_size=(100,),\n",
       "  out_size=(100,),\n",
       "  spk_reset=soft,\n",
       "  spk_fun=ReluGrad(alpha=0.3, width=1.0),\n",
       "  R=Quantity(1., \"ohm\"),\n",
       "  tau=Quantity(10., \"ms\"),\n",
       "  V_th=Quantity(-50., \"mV\"),\n",
       "  V_rest=Quantity(-65., \"mV\"),\n",
       "  V_reset=Quantity(-65., \"mV\"),\n",
       "  V_initializer=Constant(value=-65. mV),\n",
       "  V=HiddenState(\n",
       "    value=Quantity(~float32[100], \"mV\")\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "neuron = brainpy.state.LIF(\n",
    "    100, tau=10. * u.ms, V_th=-50. * u.mV,\n",
    "    V_rest=-65. * u.mV, V_reset=-65. * u.mV,              # rest and reset must sit below V_th\n",
    "    V_initializer=braintools.init.Constant(-65. * u.mV),  # start at rest, not the default 0 mV\n",
    ")\n",
    "brainstate.nn.init_all_states(neuron)   # allocate + reset its States"
   ]
  },
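  {
   "cell_type": "markdown",
   "id": "c1a7e2f0",
   "metadata": {},
   "source": [
    "To see what a model does with its `State`s, here is a minimal sketch of a module that declares one itself. `Accumulator` is a hypothetical toy written for illustration, not part of brainpy. The sketch assumes that `brainstate.nn.Module` routes calls to `update`, and that `brainstate.nn.init_all_states` calls each module's `init_state` to allocate its States."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a7e2f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "\n",
    "# Hypothetical toy module, not part of brainpy: a running sum kept in a State.\n",
    "class Accumulator(brainstate.nn.Module):\n",
    "    def __init__(self, n):\n",
    "        super().__init__()\n",
    "        self.n = n\n",
    "\n",
    "    def init_state(self, *args, **kwargs):\n",
    "        # Declared here; brainstate.nn.init_all_states calls this to allocate it.\n",
    "        self.total = brainstate.HiddenState(jnp.zeros(self.n))\n",
    "\n",
    "    def update(self, x):\n",
    "        # Writing to .value is a tracked State mutation, not a hidden side effect.\n",
    "        self.total.value = self.total.value + x\n",
    "        return self.total.value\n",
    "\n",
    "acc = Accumulator(3)\n",
    "brainstate.nn.init_all_states(acc)\n",
    "# Five steps of input 1.0: the State survives across steps inside one compiled loop.\n",
    "outs = brainstate.transform.for_loop(lambda x: acc(x), jnp.ones((5, 3)))\n",
    "print(outs.shape, acc.total.value)"
   ]
  },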
  {
   "cell_type": "markdown",
   "id": "b18f1397",
   "metadata": {},
   "source": [
    "Time-dependent context (the step size `dt`, the current time `t`) is supplied with `brainstate.environ.context(...)`, so models read it without you threading it through every call:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a4ef65ca",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-17T09:11:58.156518Z",
     "iopub.status.busy": "2026-06-17T09:11:58.155928Z",
     "iopub.status.idle": "2026-06-17T09:11:58.161296Z",
     "shell.execute_reply": "2026-06-17T09:11:58.160570Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dt = 0.1 ms\n"
     ]
    }
   ],
   "source": [
    "with brainstate.environ.context(dt=0.1 * u.ms):\n",
    "    print('dt =', brainstate.environ.get_dt())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa16e64e",
   "metadata": {},
   "source": [
    "## Idea 2 — Physical units\n",
    "\n",
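    "Parameters carry real physical units via `brainunit`: a time constant is in *milliseconds*, a threshold in *millivolts*, a current in *milliamps*. Because every value carries its unit, quantities of the same dimension are converted for you (`0.01 * u.second` equals `10. * u.ms`), while combining incompatible dimensions, such as adding a voltage to a time, raises an error. A unit slip is therefore caught immediately instead of silently corrupting a run."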
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "dcb09daf",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-17T09:11:58.163314Z",
     "iopub.status.busy": "2026-06-17T09:11:58.162953Z",
     "iopub.status.idle": "2026-06-17T09:11:58.167259Z",
     "shell.execute_reply": "2026-06-17T09:11:58.166309Z"
    }
   },
   "outputs": [],
   "source": [
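    "tau = 10. * u.ms             # membrane time constant\n",
    "V_threshold = -50. * u.mV    # spike threshold\n",
    "V_rest = -65. * u.mV         # resting and reset potential (LIF defaults both to 0 mV)\n",
    "current = 20. * u.mA         # input current; with the default R = 1 ohm, R * I = 20 mV\n",
    "\n",
    "# Units flow straight into model construction. Rest and reset must sit below\n",
    "# the threshold, and V starts at rest rather than at the default 0 mV.\n",
    "neuron = brainpy.state.LIF(\n",
    "    100, tau=tau, V_th=V_threshold, V_rest=V_rest, V_reset=V_rest,\n",
    "    V_initializer=braintools.init.Constant(V_rest),\n",
    ")"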
   ]
  },
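  {
   "cell_type": "markdown",
   "id": "d2b8f3a0",
   "metadata": {},
   "source": [
    "A quick check of these claims, as a minimal sketch that uses only `brainunit` arithmetic. It assumes `Quantity.to_decimal` and `u.DimensionMismatchError` are available in your brainunit version; the variable names are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2b8f3a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import brainunit as u\n",
    "\n",
    "# Same dimension (time), different scale: converted, not misread.\n",
    "total = 10. * u.ms + 0.01 * u.second\n",
    "total_ms = float(total.to_decimal(u.ms))\n",
    "print('10 ms + 0.01 s =', total_ms, 'ms')\n",
    "\n",
    "# Ohm's law with units: milliamps times ohms is a voltage in millivolts.\n",
    "drive = (20. * u.mA) * (1. * u.ohm)\n",
    "drive_mV = float(drive.to_decimal(u.mV))\n",
    "print('R * I =', drive_mV, 'mV')\n",
    "\n",
    "# Incompatible dimensions raise instead of silently producing a number.\n",
    "try:\n",
    "    (10. * u.ms) + (-50. * u.mV)\n",
    "    mismatch_caught = False\n",
    "except u.DimensionMismatchError:\n",
    "    mismatch_caught = True\n",
    "print('time + voltage rejected:', mismatch_caught)"
   ]
  },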
  {
   "cell_type": "markdown",
   "id": "cb2040c2",
   "metadata": {},
   "source": [
    "## Idea 3 — Compose neurons + synapses + projections\n",
    "\n",
    "Networks are built by composition. The three building blocks:\n",
    "\n",
    "- **Neurons** (e.g. `LIF`, `LIFRef`, `ALIF`, `HH`) hold membrane state and emit spikes.\n",
    "- **Synapses** (e.g. `Expon`, `Alpha`, `AMPA`) filter incoming spikes into currents/conductances over time.\n",
    "- **Projections** wire populations together. A projection separates four roles: `comm` (the connectivity and its weights), `syn` (the synapse dynamics), `out` (how the synaptic variable drives the target, e.g. `COBA` or `CUBA`), and `post` (the target population).\n",
    "\n",
    "Here two populations are joined by a single `AlignPostProj`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4828374c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-17T09:11:58.169442Z",
     "iopub.status.busy": "2026-06-17T09:11:58.169213Z",
     "iopub.status.idle": "2026-06-17T09:11:58.174835Z",
     "shell.execute_reply": "2026-06-17T09:11:58.174052Z"
    }
   },
   "outputs": [],
   "source": [
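    "# Rest/reset below threshold (the LIF defaults are 0 mV), so E = 0 mV below is excitatory.\n",
    "V_rest = -65. * u.mV\n",
    "lif_kw = dict(tau=10. * u.ms, V_th=-50. * u.mV, V_rest=V_rest, V_reset=V_rest,\n",
    "              V_initializer=braintools.init.Constant(V_rest))\n",
    "pre = brainpy.state.LIF(100, **lif_kw)\n",
    "post = brainpy.state.LIF(50, **lif_kw)\n",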
    "\n",
    "proj = brainpy.state.AlignPostProj(\n",
    "    comm=brainstate.nn.EventFixedProb(100, 50, conn_num=0.1, conn_weight=0.5 * u.mS),\n",
    "    syn=brainpy.state.Expon.desc(50, tau=5. * u.ms),\n",
    "    out=brainpy.state.COBA.desc(E=0. * u.mV),\n",
    "    post=post,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6df82878",
   "metadata": {},
   "source": [
    "Which projection to reach for — and why aligning synaptic state to the *post* (or *pre*) population keeps memory linear — is the subject of the keystone chapter, {doc}`/concepts/alignpre-alignpost`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66ee4420",
   "metadata": {},
   "source": [
    "## Idea 4 — Drive with `brainstate.transform`\n",
    "\n",
    "A model runs over many time steps. **Never** drive it with a bare Python `for`/`while` loop: that dispatches every operation from Python on every step and forfeits XLA's fusion across the loop body. Instead, lower the whole loop into one compiled program with a `brainstate.transform` primitive:\n",
    "\n",
    "- **single step / one-shot** → `brainstate.transform.jit`\n",
    "- **many steps, collect outputs** → `brainstate.transform.for_loop`\n",
    "- **many steps with an explicit carry** → `brainstate.transform.scan`\n",
    "- **long rollout under autograd (BPTT)** → `brainstate.transform.checkpointed_for_loop` / `checkpointed_scan`\n",
    "\n",
    "Running the neuron from Idea 2 for 200 ms of constant input (2,000 steps at `dt = 0.1 ms`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "707ec574",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-17T09:11:58.176578Z",
     "iopub.status.busy": "2026-06-17T09:11:58.176365Z",
     "iopub.status.idle": "2026-06-17T09:11:58.391313Z",
     "shell.execute_reply": "2026-06-17T09:11:58.390349Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "spikes shape: (2000, 100)\n"
     ]
    }
   ],
   "source": [
    "brainstate.nn.init_all_states(neuron)\n",
    "\n",
    "def step(t):\n",
    "    with brainstate.environ.context(t=t):\n",
    "        neuron(current)\n",
    "        return neuron.get_spike()\n",
    "\n",
    "with brainstate.environ.context(dt=0.1 * u.ms):\n",
    "    times = u.math.arange(0. * u.ms, 200. * u.ms, brainstate.environ.get_dt())\n",
    "    spikes = brainstate.transform.for_loop(step, times)\n",
    "\n",
    "print('spikes shape:', spikes.shape)   # [time, neuron]"
   ]
  },
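  {
   "cell_type": "markdown",
   "id": "e3c9a4b0",
   "metadata": {},
   "source": [
    "`for_loop` keeps the evolving state inside the model's `State`s. When you would rather pass state explicitly, `brainstate.transform.scan` is the tool; this sketch assumes it follows the `jax.lax.scan` convention, where `step(carry, x)` returns `(new_carry, y)` and the call returns `(final_carry, stacked_ys)`. The leaky integrator is a hypothetical toy with units stripped for brevity, not a brainpy model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3c9a4b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "\n",
    "# Hypothetical toy leaky integrator; plain floats stand in for ms and mV.\n",
    "dt, tau = 0.1, 10.0\n",
    "\n",
    "def step(v, i_t):\n",
    "    # Explicit carry: v is passed in and returned rather than stored in a State.\n",
    "    v = v + dt / tau * (-v + i_t)\n",
    "    return v, v   # (new carry, per-step output)\n",
    "\n",
    "inputs = jnp.full((2000,), 20.0)   # 200 ms of constant drive at dt = 0.1\n",
    "v_final, v_trace = brainstate.transform.scan(step, jnp.asarray(0.0), inputs)\n",
    "print(v_trace.shape, float(v_final))"
   ]
  },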
  {
   "cell_type": "markdown",
   "id": "4f2c7300",
   "metadata": {},
   "source": [
    "## Two worlds, one substrate\n",
    "\n",
    "Everything above is shared. The *same* state-based models, the *same* units, the *same* `transform`-driven loops power both:\n",
    "\n",
    "- **Brain simulation** — run biophysical E/I networks and analyze their dynamics (the {doc}`5-minute tour </get-started/5-minute-tour>` you may have just seen).\n",
    "- **Brain-inspired computing** — because the models are differentiable (neurons use a surrogate-gradient `spk_fun`, `ReluGrad` by default, and `for_loop`/`scan` support autodiff), you can train them with gradients and scale training with linear-memory online learning.\n",
    "\n",
    "The hinge between the two worlds is the **AlignPre/AlignPost** projection design: the same alignment that makes simulation memory-efficient is what makes gradient-based and online learning memory-efficient. That is why it is the keystone of the concept spine."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0b443d5",
   "metadata": {},
   "source": [
    "## See also\n",
    "\n",
    "- {doc}`/concepts/alignpre-alignpost` — the keystone chapter the four ideas build toward.\n",
    "- {doc}`/concepts/index` — the full Core Concepts spine and a recommended reading order.\n",
    "- {doc}`/get-started/5-minute-tour` — see these ideas at work in a complete, runnable E/I network."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
