{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Overview\n",
    "\n",
    "`brainpy.state` introduces a modern, state-based architecture built on top of `brainstate`. This overview will help you understand the key concepts and design philosophy.\n",
    "\n",
    "## What's New\n",
    "\n",
    "`brainpy.state` has been completely rewritten to provide:\n",
    "\n",
    "- **State-based programming**: Built on `brainstate` for efficient state management\n",
    "- **Modular architecture**: Clear separation of concerns (communication, dynamics, outputs)\n",
    "- **Physical units**: Integration with `saiunit` for scientifically accurate simulations\n",
    "- **Modern API**: Cleaner, more intuitive interfaces\n",
    "- **Better performance**: Optimized JIT compilation and memory management\n",
    "\n",
    "## Key Architectural Components\n",
    "\n",
    "`brainpy.state` is organized around several core concepts:\n",
    "\n",
    "### 1. State Management\n",
    "\n",
    "Everything in `brainpy.state` revolves around **states**. States are variables that persist across time steps:\n",
    "\n",
    "- `brainstate.State`: Base state container\n",
    "- `brainstate.ParamState`: Trainable parameters\n",
    "- `brainstate.ShortTermState`: Temporary variables\n",
    "\n",
    "States enable:\n",
    "\n",
    "- Automatic differentiation for training\n",
    "- Efficient memory management\n",
    "- Batching and parallelization\n",
    "\n",
    "### 2. Neurons\n",
    "\n",
    "Neurons are the fundamental computational units:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:33.064624Z",
     "iopub.status.busy": "2026-05-11T06:19:33.064445Z",
     "iopub.status.idle": "2026-05-11T06:19:33.527078Z",
     "shell.execute_reply": "2026-05-11T06:19:33.525789Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
     ]
    }
   ],
   "source": [
    "import brainpy.state\n",
    "import saiunit as u\n",
    "\n",
    "# Create a population of 100 LIF neurons\n",
    "neurons = brainpy.state.LIF(100, tau=10*u.ms, V_th=-50*u.mV)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Key neuron models:\n",
    "\n",
    "- `brainpy.state.IF`: Integrate-and-Fire\n",
    "- `brainpy.state.LIF`: Leaky Integrate-and-Fire\n",
    "- `brainpy.state.LIFRef`: LIF with refractory period\n",
    "- `brainpy.state.ALIF`: Adaptive LIF\n",
    "\n",
    "### 3. Synapses\n",
    "\n",
    "Synapses model the dynamics of neural connections:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:33.530642Z",
     "iopub.status.busy": "2026-05-11T06:19:33.530221Z",
     "iopub.status.idle": "2026-05-11T06:19:33.534503Z",
     "shell.execute_reply": "2026-05-11T06:19:33.533314Z"
    }
   },
   "outputs": [],
   "source": [
    "# Exponential synapse\n",
    "synapse = brainpy.state.Expon(100, tau=5*u.ms)\n",
    "\n",
    "# Alpha synapse (more realistic)\n",
    "synapse = brainpy.state.Alpha(100, tau=5*u.ms)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Synapse models:\n",
    "\n",
    "- `brainpy.state.Expon`: Single exponential decay\n",
    "- `brainpy.state.Alpha`: Double exponential (alpha function)\n",
    "- `brainpy.state.AMPA`: Excitatory receptor dynamics\n",
    "- `brainpy.state.GABAa`: Inhibitory receptor dynamics\n",
    "\n",
    "### 4. Projections\n",
    "\n",
    "Projections connect neural populations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:33.536943Z",
     "iopub.status.busy": "2026-05-11T06:19:33.536720Z",
     "iopub.status.idle": "2026-05-11T06:19:33.592964Z",
     "shell.execute_reply": "2026-05-11T06:19:33.591894Z"
    }
   },
   "outputs": [],
   "source": [
    "import brainstate\n",
    "\n",
    "N_pre=100\n",
    "N_post=50\n",
    "prob=0.1\n",
    "weight=0.5\n",
    "\n",
    "projection = brainpy.state.AlignPostProj(\n",
    "    comm=brainstate.nn.EventFixedProb(N_pre, N_post, prob, weight),\n",
    "    syn=brainpy.state.Expon.desc(N_post, tau=5*u.ms),\n",
    "    out=brainpy.state.CUBA.desc(),\n",
    "    post=neurons\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The projection architecture separates:\n",
    "\n",
    "- **Communication**: How spikes are transmitted (connectivity, weights)\n",
    "- **Synaptic dynamics**: How synapses respond (temporal filtering)\n",
    "- **Output mechanism**: How synaptic currents affect neurons (CUBA/COBA)\n",
    "\n",
    "### 5. Networks\n",
    "\n",
    "Networks combine neurons and projections:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:33.596108Z",
     "iopub.status.busy": "2026-05-11T06:19:33.595703Z",
     "iopub.status.idle": "2026-05-11T06:19:33.602776Z",
     "shell.execute_reply": "2026-05-11T06:19:33.601827Z"
    }
   },
   "outputs": [],
   "source": [
    "class EINet(brainstate.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.E = brainpy.state.LIF(800)\n",
    "        self.I = brainpy.state.LIF(200)\n",
    "        self.E2E = brainpy.state.AlignPostProj(...)\n",
    "        self.E2I = brainpy.state.AlignPostProj(...)\n",
    "        # ... more projections\n",
    "\n",
    "    def update(self, input):\n",
    "        # Define network dynamics\n",
    "        pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Computational Model\n",
    "\n",
    "### Time-Stepped Simulation\n",
    "\n",
    "BrainPy uses discrete time steps for simulation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:33.605969Z",
     "iopub.status.busy": "2026-05-11T06:19:33.605630Z",
     "iopub.status.idle": "2026-05-11T06:19:33.923415Z",
     "shell.execute_reply": "2026-05-11T06:19:33.922398Z"
    }
   },
   "outputs": [],
   "source": [
    "# Set simulation time step\n",
    "brainstate.environ.set(dt=0.1 * u.ms)\n",
    "\n",
    "# Create a simple neuron for demonstration\n",
    "neurons = brainpy.state.LIF(100, tau=10*u.ms, V_th=-50*u.mV)\n",
    "\n",
    "# Initialize all states\n",
    "brainstate.nn.init_all_states(neurons)\n",
    "\n",
    "# Run simulation\n",
    "def step(t, i):\n",
    "    with brainstate.environ.context(t=t, i=i):\n",
    "        # Provide input current to the neurons\n",
    "        neurons.update(5 * u.nA)\n",
    "        return neurons.get_spike()\n",
    "\n",
    "times = u.math.arange(0*u.ms, 1000*u.ms, brainstate.environ.get_dt())\n",
    "indices = u.math.arange(times.size)\n",
    "results = brainstate.transform.for_loop(step, times, indices)"
   ]
  },
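  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the time-stepped update concrete, the following is a minimal pure-Python sketch of a forward-Euler loop for a single leaky integrate-and-fire neuron. It is an illustration under stated assumptions, not the `brainpy.state.LIF` implementation: the helper name `lif_euler_step`, the parameters `v_rest`, `v_reset`, and `r_m`, and all numerical values are hypothetical, and quantities are plain floats (ms, mV, nA, megaohm) rather than `saiunit` quantities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lif_euler_step(v, i_ext, dt=0.1, tau=10.0, v_rest=-65.0, v_th=-50.0, v_reset=-65.0, r_m=10.0):\n",
    "    # Hypothetical helper: one Euler step of dV/dt = (-(V - V_rest) + R * I) / tau\n",
    "    v = v + dt * (-(v - v_rest) + r_m * i_ext) / tau\n",
    "    spiked = v >= v_th\n",
    "    if spiked:\n",
    "        v = v_reset  # hard reset after a spike\n",
    "    return v, spiked\n",
    "\n",
    "\n",
    "# Drive the neuron with a constant 5 nA current for 1000 steps of 0.1 ms (100 ms)\n",
    "v = -65.0\n",
    "spike_steps = []\n",
    "for step_index in range(1000):\n",
    "    v, spiked = lif_euler_step(v, i_ext=5.0)\n",
    "    if spiked:\n",
    "        spike_steps.append(step_index)\n",
    "\n",
    "print(len(spike_steps), 'spikes in 100 ms')"
   ]
  },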
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### JIT Compilation\n",
    "\n",
    "BrainPy leverages JAX for Just-In-Time compilation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:33.926288Z",
     "iopub.status.busy": "2026-05-11T06:19:33.925957Z",
     "iopub.status.idle": "2026-05-11T06:19:34.097370Z",
     "shell.execute_reply": "2026-05-11T06:19:34.096507Z"
    }
   },
   "outputs": [],
   "source": [
    "# Create a simple network for demonstration\n",
    "network = brainpy.state.LIF(100, tau=10*u.ms, V_th=-50*u.mV)\n",
    "brainstate.nn.init_all_states(network)\n",
    "\n",
    "# Define input current\n",
    "input_current = 5 * u.nA\n",
    "\n",
    "# JIT-compiled simulation function\n",
    "@brainstate.transform.jit\n",
    "def simulate(t, i):\n",
    "    with brainstate.environ.context(t=t, i=i):\n",
    "        network.update(input_current)\n",
    "        return network.get_spike()\n",
    "\n",
    "# First call compiles, subsequent calls are fast\n",
    "times = u.math.arange(0*u.ms, 100*u.ms, brainstate.environ.get_dt())\n",
    "indices = u.math.arange(times.size)\n",
    "result = brainstate.transform.for_loop(simulate, times, indices)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Benefits:\n",
    "\n",
    "- Near-C performance\n",
    "- Automatic GPU/TPU dispatch\n",
    "- Optimized memory usage\n",
    "\n",
    "### Physical Units\n",
    "\n",
    "``brainpy.state`` integrates `saiunit` for scientific accuracy:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:34.099901Z",
     "iopub.status.busy": "2026-05-11T06:19:34.099659Z",
     "iopub.status.idle": "2026-05-11T06:19:34.103778Z",
     "shell.execute_reply": "2026-05-11T06:19:34.103017Z"
    }
   },
   "outputs": [],
   "source": [
    "import saiunit as u\n",
    "\n",
    "# Define parameters with units\n",
    "tau = 10 * u.ms\n",
    "V_threshold = -50 * u.mV\n",
    "current = 5 * u.nA\n",
    "\n",
    "# Units are checked automatically\n",
    "neurons = brainpy.state.LIF(100, tau=tau, V_th=V_threshold)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This prevents unit-related bugs and makes code self-documenting.\n",
    "\n",
    "## Training and Learning\n",
    "\n",
    "``brainpy.state``  supports gradient-based training:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:34.106015Z",
     "iopub.status.busy": "2026-05-11T06:19:34.105788Z",
     "iopub.status.idle": "2026-05-11T06:19:34.269475Z",
     "shell.execute_reply": "2026-05-11T06:19:34.268521Z"
    }
   },
   "outputs": [],
   "source": [
    "import braintools\n",
    "\n",
    "# Create a simple network for training\n",
    "net = brainpy.state.LIF(10, tau=10*u.ms, V_th=-50*u.mV)\n",
    "brainstate.nn.init_all_states(net)\n",
    "\n",
    "# Define optimizer\n",
    "optimizer = braintools.optim.Adam(lr=1e-3)\n",
    "optimizer.register_trainable_weights(net.states(brainstate.ParamState))\n",
    "\n",
    "# Prepare dummy data for demonstration\n",
    "num_steps = 100\n",
    "inputs = u.math.ones((num_steps,)) * 5 * u.nA\n",
    "targets = u.math.zeros((num_steps, 10))  # dummy target\n",
    "\n",
    "# Define loss function\n",
    "def loss_fn():\n",
    "    def step(t, i, inp):\n",
    "        with brainstate.environ.context(t=t, i=i):\n",
    "            net.update(inp)\n",
    "            return net.spike.value\n",
    "    \n",
    "    times = u.math.arange(0*u.ms, num_steps*brainstate.environ.get_dt(), brainstate.environ.get_dt())\n",
    "    indices = u.math.arange(times.size)\n",
    "    predictions = brainstate.transform.for_loop(step, times, indices, inputs)\n",
    "    # Simple MSE loss\n",
    "    return u.math.mean((predictions.astype(float) - targets) ** 2)\n",
    "\n",
    "# Training step\n",
    "@brainstate.transform.jit\n",
    "def train_step():\n",
    "    grads, loss_value = brainstate.transform.grad(\n",
    "        loss_fn,\n",
    "        net.states(brainstate.ParamState),\n",
    "        return_value=True\n",
    "    )()\n",
    "    optimizer.update(grads)\n",
    "    return loss_value"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Key features:\n",
    "\n",
    "- Surrogate gradients for spiking neurons\n",
    "- Automatic differentiation\n",
    "- Various optimizers (Adam, SGD, etc.)\n",
    "\n",
    "## Ecosystem Components\n",
    "\n",
    "`brainpy.state` is part of a larger ecosystem:\n",
    "\n",
    "### brainstate\n",
    "\n",
    "The foundation for state management and compilation:\n",
    "\n",
    "- State-based IR construction\n",
    "- JIT compilation\n",
    "- Program augmentation (batching, etc.)\n",
    "\n",
    "### saiunit\n",
    "\n",
    "Physical units system:\n",
    "\n",
    "- SI units support\n",
    "- Automatic unit checking\n",
    "- Unit conversions\n",
    "\n",
    "### braintools\n",
    "\n",
    "Utilities and tools:\n",
    "\n",
    "- Optimizers (`braintools.optim`)\n",
    "- Initialization (`braintools.init`)\n",
    "- Metrics and losses (`braintools.metric`)\n",
    "- Surrogate gradients (`braintools.surrogate`)\n",
    "- Visualization (`braintools.visualize`)\n",
    "\n",
    "## Design Philosophy\n",
    "\n",
    "`brainpy.state` follows these principles:\n",
    "\n",
    "1. **Explicit over implicit**: Clear, readable code\n",
    "2. **Modular composition**: Build complex models from simple components\n",
    "3. **Performance by default**: JIT compilation and optimization built-in\n",
    "4. **Scientific accuracy**: Physical units and biologically realistic models\n",
    "5. **Extensibility**: Easy to add custom components\n",
    "\n",
    "## Next Steps\n",
    "\n",
    "Now that you understand the core concepts:\n",
    "\n",
    "- Try the [5-minute tutorial](5min-tutorial.ipynb) to get hands-on experience\n",
    "- Read the detailed [BrainPy-style modeling guide](../brainpy-guide/index.rst)\n",
    "- Explore the examples in the repository to learn each component\n",
    "- Check out the [examples gallery](../examples/gallery.rst) for real-world models"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Ecosystem-py",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
