{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Overview\n",
    "\n",
    "``brainpy.state`` represents a complete architectural redesign built on top of the ``brainstate`` framework. This document explains the design principles and architectural components that make ``brainpy.state`` powerful and flexible."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Design Philosophy\n",
    "\n",
    "``brainpy.state`` is built around several core principles:\n",
    "\n",
    "**State-Based Programming**\n",
    "   All dynamical variables are managed as explicit states, enabling automatic differentiation, efficient compilation, and clear data flow.\n",
    "\n",
    "**Modular Composition**\n",
    "   Complex models are built by composing simple, reusable components. Each component has a well-defined interface and responsibility.\n",
    "\n",
    "**Scientific Accuracy**\n",
    "   Integration with ``saiunit`` enforces dimensional consistency and catches unit-related errors, such as mixing up milliseconds and seconds.\n",
    "\n",
    "**Performance by Default**\n",
    "   JIT compilation and optimization are built into the framework, not an afterthought.\n",
    "\n",
    "**Extensibility**\n",
    "   Adding new neuron models, synapse types, or learning rules is straightforward and follows clear patterns."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Architectural Layers\n",
    "\n",
    "``brainpy.state`` is organized into four layers:\n",
    "\n",
    "```text\n",
    "┌─────────────────────────────────────────┐\n",
    "│         User Models & Networks          │  ← Your code\n",
    "├─────────────────────────────────────────┤\n",
    "│      BrainPy Components Layer           │  ← Neurons, Synapses, Projections\n",
    "├─────────────────────────────────────────┤\n",
    "│       BrainState Framework              │  ← State management, compilation\n",
    "├─────────────────────────────────────────┤\n",
    "│       JAX + XLA Backend                 │  ← JIT compilation, autodiff\n",
    "└─────────────────────────────────────────┘\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. JAX + XLA Backend\n",
    "\n",
    "The foundation layer provides:\n",
    "\n",
    "- Just-In-Time (JIT) compilation\n",
    "- Automatic differentiation\n",
    "- Hardware acceleration (CPU/GPU/TPU)\n",
    "- Functional transformations (vmap, grad, etc.)\n",
    "\n",
    "### 2. BrainState Framework\n",
    "\n",
    "Built on JAX, ``brainstate`` provides:\n",
    "\n",
    "- State management system\n",
    "- Module composition\n",
    "- Compilation and optimization\n",
    "- Program transformations (for_loop, etc.)\n",
    "\n",
    "### 3. BrainPy Components\n",
    "\n",
    "High-level neuroscience-specific components:\n",
    "\n",
    "- Neuron models (LIF, ALIF, etc.)\n",
    "- Synapse models (Expon, Alpha, etc.)\n",
    "- Projection architectures\n",
    "- Learning rules and plasticity\n",
    "\n",
    "### 4. User Models\n",
    "\n",
    "Your custom networks and experiments built using BrainPy components."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## State Management System\n",
    "\n",
    "### The Foundation: ``brainstate.State``\n",
    "\n",
    "Everything in ``brainpy.state`` revolves around **states**:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T09:31:08.201528Z",
     "start_time": "2025-11-13T09:31:02.936126Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:33.009178Z",
     "iopub.status.busy": "2026-05-11T06:19:33.008966Z",
     "iopub.status.idle": "2026-05-11T06:19:35.657920Z",
     "shell.execute_reply": "2026-05-11T06:19:35.657134Z"
    }
   },
   "outputs": [],
   "source": [
    "import brainpy.state\n",
    "import brainstate\n",
    "import braintools\n",
    "import saiunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Create states\n",
    "voltage = brainstate.State(0.0)  # Scalar\n",
    "weights = brainstate.State(jnp.array([[0.1, 0.2], [0.3, 0.4]]))  # 2x2 matrix"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
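    "States are special containers that:\n",
    "\n",
    "- Hold mutable values that persist across function calls and simulation steps\n",
    "- Can be selected as targets for automatic differentiation (``brainstate.transform.grad``)\n",
    "- Are tracked by ``brainstate`` transformations such as ``jit`` and ``for_loop``, so reads and writes work correctly inside compiled code\n",
    "- Can be batched with ``brainstate``'s vectorizing transformations (``vmap``)"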
   ]
  },
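  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sketch below shows a state used inside compiled code: a function compiled with ``brainstate.transform.jit`` writes to a ``ShortTermState`` defined outside it, and the new value persists across calls. The names ``counter`` and ``increment`` are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import brainstate\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# A scalar state that lives outside the compiled function (illustrative name)\n",
    "counter = brainstate.ShortTermState(jnp.zeros(()))\n",
    "\n",
    "@brainstate.transform.jit\n",
    "def increment(step):\n",
    "    # Read and write the state inside compiled code; brainstate writes the new value back\n",
    "    counter.value = counter.value + step\n",
    "    return counter.value\n",
    "\n",
    "for _ in range(3):\n",
    "    increment(1.0)\n",
    "\n",
    "print(counter.value)  # 3.0"
   ]
  },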
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### State Types\n",
    "\n",
    "BrainPy uses different state types for different purposes:\n",
    "\n",
    "**ParamState** - Trainable Parameters\n",
    "   Used for weights, time constants, and other trainable parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T09:31:08.232382Z",
     "start_time": "2025-11-13T09:31:08.227046Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:35.660619Z",
     "iopub.status.busy": "2026-05-11T06:19:35.660147Z",
     "iopub.status.idle": "2026-05-11T06:19:35.664740Z",
     "shell.execute_reply": "2026-05-11T06:19:35.663713Z"
    }
   },
   "outputs": [],
   "source": [
    "class MyNeuron(brainstate.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.tau = brainstate.ParamState(10.0)  # Trainable scalar\n",
    "        self.weight = brainstate.ParamState(jnp.array([[0.1, 0.2]]))  # Trainable 1x2 matrix"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
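    "**ShortTermState** - Dynamical Variables\n",
    "   Used for membrane potentials, synaptic currents, and other variables that evolve during a simulation and are typically reinitialized between runs (for example by ``brainstate.nn.init_all_states``)."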
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T09:31:08.256361Z",
     "start_time": "2025-11-13T09:31:08.248156Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:35.667391Z",
     "iopub.status.busy": "2026-05-11T06:19:35.667156Z",
     "iopub.status.idle": "2026-05-11T06:19:35.670708Z",
     "shell.execute_reply": "2026-05-11T06:19:35.669836Z"
    }
   },
   "outputs": [],
   "source": [
    "class MyNeuron(brainstate.nn.Module):\n",
    "    def __init__(self, size):\n",
    "        super().__init__()\n",
    "        self.V = brainstate.ShortTermState(jnp.zeros(size))  # Dynamic\n",
    "        self.spike = brainstate.ShortTermState(jnp.zeros(size))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### State Initialization\n",
    "\n",
    "States can be initialized with the initializers in ``braintools.init``; each initializer is called with a shape and returns an array of that shape:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T09:31:08.332155Z",
     "start_time": "2025-11-13T09:31:08.288203Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:35.673616Z",
     "iopub.status.busy": "2026-05-11T06:19:35.673335Z",
     "iopub.status.idle": "2026-05-11T06:19:36.379440Z",
     "shell.execute_reply": "2026-05-11T06:19:36.378488Z"
    }
   },
   "outputs": [],
   "source": [
    "# Define example size and shape\n",
    "size = 100  # Number of neurons\n",
    "shape = (100, 50)  # Weight matrix shape\n",
    "\n",
    "# Constant initialization\n",
    "V = brainstate.ShortTermState(\n",
    "    braintools.init.Constant(-65.0, unit=u.mV)(size)\n",
    ")\n",
    "\n",
    "# Normal distribution\n",
    "V = brainstate.ShortTermState(\n",
    "    braintools.init.Normal(-65.0, 5.0, unit=u.mV)(size)\n",
    ")\n",
    "\n",
    "# Uniform distribution\n",
    "weights = brainstate.ParamState(\n",
    "    braintools.init.Uniform(0.0, 1.0)(shape)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Module System\n",
    "\n",
    "### Base Class: ``brainstate.nn.Module``\n",
    "\n",
    "All BrainPy components inherit from ``brainstate.nn.Module``:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T09:31:08.344089Z",
     "start_time": "2025-11-13T09:31:08.338163Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:36.381651Z",
     "iopub.status.busy": "2026-05-11T06:19:36.381492Z",
     "iopub.status.idle": "2026-05-11T06:19:36.385023Z",
     "shell.execute_reply": "2026-05-11T06:19:36.384323Z"
    }
   },
   "outputs": [],
   "source": [
    "class MyComponent(brainstate.nn.Module):\n",
    "    def __init__(self, size):\n",
    "        super().__init__()\n",
    "        # Initialize states\n",
    "        self.state1 = brainstate.ShortTermState(jnp.zeros(size))\n",
    "        self.param1 = brainstate.ParamState(jnp.ones(size))\n",
    "\n",
    "    def update(self, input):\n",
    "        # Define dynamics\n",
    "        pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Benefits of Module:\n",
    "\n",
    "- Automatic state registration\n",
    "- Nested module support\n",
    "- State collection and filtering\n",
    "- Serialization support"
   ]
  },
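  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an illustration of automatic registration and filtering, the sketch below nests one module inside another and collects states by type with ``Module.states``, the same call the gradient example later in this document uses. ``Inner`` and ``Outer`` are hypothetical classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import brainstate\n",
    "import jax.numpy as jnp\n",
    "\n",
    "class Inner(brainstate.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.w = brainstate.ParamState(jnp.ones(3))       # trainable\n",
    "        self.v = brainstate.ShortTermState(jnp.zeros(3))  # dynamical\n",
    "\n",
    "class Outer(brainstate.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.inner = Inner()                              # nested module, registered automatically\n",
    "        self.b = brainstate.ParamState(jnp.zeros(3))\n",
    "\n",
    "net = Outer()\n",
    "params = net.states(brainstate.ParamState)        # collects Outer.b and Inner.w\n",
    "dynamics = net.states(brainstate.ShortTermState)  # collects Inner.v\n",
    "print(len(params), len(dynamics))"
   ]
  },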
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Module Composition\n",
    "\n",
    "Modules can contain other modules:\n",
    "\n",
    "```python\n",
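    "class Network(brainstate.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Child modules assigned as attributes are registered automatically\n",
    "        self.neurons = brainpy.state.LIF(100, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)\n",
    "        self.synapse = brainpy.state.Expon(100, tau=5*u.ms)\n",
    "        self.projection = brainpy.state.AlignPostProj(...)  # Placeholder: a full projection needs comm, syn, out, and post\n",
    "\n",
    "    def update(self, input):\n",
    "        # Spikes emitted by the population on the previous step\n",
    "        spikes = self.neurons.get_spike()\n",
    "        # Deliver them through the projection, then integrate the external input\n",
    "        self.projection(spikes)\n",
    "        self.neurons(input)\n",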
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Component Architecture\n",
    "\n",
    "### Neurons\n",
    "\n",
    "Neurons model the dynamics of neural populations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T09:31:08.354960Z",
     "start_time": "2025-11-13T09:31:08.350705Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:36.387295Z",
     "iopub.status.busy": "2026-05-11T06:19:36.387033Z",
     "iopub.status.idle": "2026-05-11T06:19:36.391211Z",
     "shell.execute_reply": "2026-05-11T06:19:36.390139Z"
    }
   },
   "outputs": [],
   "source": [
    "class Neuron(brainstate.nn.Module):\n",
    "    def __init__(self, size, **kwargs):\n",
    "        super().__init__()\n",
    "        # Membrane potential\n",
    "        self.V = brainstate.ShortTermState(jnp.zeros(size))\n",
    "        # Spike output\n",
    "        self.spike = brainstate.ShortTermState(jnp.zeros(size))\n",
    "\n",
    "    def update(self, input_current):\n",
    "        # Update membrane potential\n",
    "        # Generate spikes\n",
    "        pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Key responsibilities:\n",
    "\n",
    "- Maintain membrane potential\n",
    "- Generate spikes when threshold is crossed\n",
    "- Reset after spiking\n",
    "- Integrate input currents"
   ]
  },
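  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make these responsibilities concrete, here is a minimal, unitless leaky integrate-and-fire population that fills in the skeleton above. It is a hypothetical sketch using forward-Euler integration, not the implementation of ``brainpy.state.LIF``; the class name, parameter names, and default values are assumptions chosen for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import brainstate\n",
    "import jax.numpy as jnp\n",
    "\n",
    "class ToyLIF(brainstate.nn.Module):\n",
    "    # Hypothetical unitless LIF: tau * dV/dt = -(V - V_rest) + R * I\n",
    "    def __init__(self, size, tau=10.0, V_rest=-65.0, V_th=-50.0, V_reset=-65.0, R=1.0):\n",
    "        super().__init__()\n",
    "        self.tau, self.V_rest, self.V_th, self.V_reset, self.R = tau, V_rest, V_th, V_reset, R\n",
    "        self.V = brainstate.ShortTermState(jnp.full(size, V_rest))  # membrane potential\n",
    "        self.spike = brainstate.ShortTermState(jnp.zeros(size))     # spike output\n",
    "\n",
    "    def update(self, input_current, dt=0.1):\n",
    "        # Integrate input currents with one forward-Euler step\n",
    "        dV = (-(self.V.value - self.V_rest) + self.R * input_current) / self.tau\n",
    "        V = self.V.value + dt * dV\n",
    "        # Generate spikes where the threshold is crossed\n",
    "        spike = (V >= self.V_th).astype(jnp.float32)\n",
    "        # Reset the neurons that spiked\n",
    "        self.V.value = jnp.where(spike > 0, self.V_reset, V)\n",
    "        self.spike.value = spike\n",
    "        return spike\n",
    "\n",
    "neuron = ToyLIF(3)\n",
    "inputs = jnp.array([0.0, 10.0, 20.0])  # only R * I = 20 drives V above V_th = -50\n",
    "spike_counts = jnp.zeros(3)\n",
    "for _ in range(500):                    # 50 time units at dt = 0.1\n",
    "    spike_counts = spike_counts + neuron.update(inputs)\n",
    "print(spike_counts)"
   ]
  },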
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Synapses\n",
    "\n",
    "Synapses model temporal filtering of spike trains:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T09:31:08.365530Z",
     "start_time": "2025-11-13T09:31:08.361447Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:36.393683Z",
     "iopub.status.busy": "2026-05-11T06:19:36.393414Z",
     "iopub.status.idle": "2026-05-11T06:19:36.397456Z",
     "shell.execute_reply": "2026-05-11T06:19:36.396621Z"
    }
   },
   "outputs": [],
   "source": [
    "class Synapse(brainstate.nn.Module):\n",
    "    def __init__(self, size, tau, **kwargs):\n",
    "        super().__init__()\n",
    "        # Synaptic conductance/current\n",
    "        self.g = brainstate.ShortTermState(jnp.zeros(size))\n",
    "        self.tau = tau\n",
    "\n",
    "    def update(self, spike_input):\n",
    "        # Update synaptic variable\n",
    "        # Return filtered output\n",
    "        pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Key responsibilities:\n",
    "\n",
    "- Filter spike inputs temporally\n",
    "- Model synaptic dynamics (exponential, alpha, etc.)\n",
    "- Provide smooth currents to postsynaptic neurons"
   ]
  },
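  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal exponential synapse following the same skeleton: between spikes the variable ``g`` decays with time constant ``tau``, and each incoming spike adds a unit jump. This is a hypothetical, unitless sketch that applies the exact decay factor ``exp(-dt / tau)`` once per step; ``brainpy.state.Expon`` may differ in its integration scheme and scaling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import brainstate\n",
    "import jax.numpy as jnp\n",
    "\n",
    "class ToyExpon(brainstate.nn.Module):\n",
    "    # Hypothetical unitless exponential synapse: dg/dt = -g / tau + incoming spikes\n",
    "    def __init__(self, size, tau=5.0, dt=0.1):\n",
    "        super().__init__()\n",
    "        self.decay = math.exp(-dt / tau)  # exact decay over one step\n",
    "        self.g = brainstate.ShortTermState(jnp.zeros(size))\n",
    "\n",
    "    def update(self, spike_input):\n",
    "        # Decay the synaptic variable, then add the incoming spikes\n",
    "        self.g.value = self.g.value * self.decay + spike_input\n",
    "        return self.g.value\n",
    "\n",
    "toy_syn = ToyExpon(1)\n",
    "trace = [float(toy_syn.update(jnp.ones(1))[0])]  # one spike at step 0\n",
    "for _ in range(50):                               # 50 steps of 0.1 = one time constant\n",
    "    trace.append(float(toy_syn.update(jnp.zeros(1))[0]))\n",
    "print(trace[0], trace[-1])  # trace[-1] is close to exp(-1)"
   ]
  },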
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Projections: The Comm-Syn-Out Pattern\n",
    "\n",
    "Projections connect populations using a three-stage architecture:\n",
    "\n",
    "```text\n",
    "Presynaptic Spikes → [Comm] → [Syn] → [Out] → Postsynaptic Neurons\n",
    "                      │         │       │\n",
    "                  Connectivity  │    Current\n",
    "                  & Weights   Dynamics  Injection\n",
    "```\n",
    "\n",
    "**Communication (Comm)**\n",
    "   Handles spike transmission, connectivity, and weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T09:31:08.380053Z",
     "start_time": "2025-11-13T09:31:08.371055Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:36.400171Z",
     "iopub.status.busy": "2026-05-11T06:19:36.399867Z",
     "iopub.status.idle": "2026-05-11T06:19:36.428644Z",
     "shell.execute_reply": "2026-05-11T06:19:36.427687Z"
    }
   },
   "outputs": [],
   "source": [
    "# Define population sizes\n",
    "pre_size = 100\n",
    "post_size = 50\n",
    "\n",
    "# Connection probability and synaptic weight\n",
    "prob = 0.1\n",
    "weight = 0.5\n",
    "\n",
    "comm = brainstate.nn.EventFixedProb(\n",
    "    pre_size, post_size, prob, weight\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Synaptic Dynamics (Syn)**\n",
    "   Temporal filtering of transmitted spikes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T09:31:08.390544Z",
     "start_time": "2025-11-13T09:31:08.386339Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:36.431009Z",
     "iopub.status.busy": "2026-05-11T06:19:36.430774Z",
     "iopub.status.idle": "2026-05-11T06:19:36.433910Z",
     "shell.execute_reply": "2026-05-11T06:19:36.433215Z"
    }
   },
   "outputs": [],
   "source": [
    "post_size = 50  # Postsynaptic population size\n",
    "\n",
    "syn = brainpy.state.Expon(post_size, tau=5*u.ms)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Output Mechanism (Out)**\n",
    "   How synaptic variables affect postsynaptic neurons."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T09:31:08.400227Z",
     "start_time": "2025-11-13T09:31:08.397246Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:36.435677Z",
     "iopub.status.busy": "2026-05-11T06:19:36.435509Z",
     "iopub.status.idle": "2026-05-11T06:19:36.438939Z",
     "shell.execute_reply": "2026-05-11T06:19:36.438095Z"
    }
   },
   "outputs": [],
   "source": [
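    "# Current-based output: the synaptic variable is delivered as a current\n",
    "out = brainpy.state.CUBA()\n",
    "\n",
    "# Or conductance-based output with reversal potential E.\n",
    "# This reassigns `out`, so the projection below uses COBA.\n",
    "out = brainpy.state.COBA(E=0*u.mV)"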
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Complete Projection**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T09:31:08.420603Z",
     "start_time": "2025-11-13T09:31:08.405533Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:36.441306Z",
     "iopub.status.busy": "2026-05-11T06:19:36.441071Z",
     "iopub.status.idle": "2026-05-11T06:19:36.444669Z",
     "shell.execute_reply": "2026-05-11T06:19:36.444005Z"
    }
   },
   "outputs": [],
   "source": [
    "# Define postsynaptic neurons\n",
    "postsynaptic_neurons = brainpy.state.LIF(50, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)\n",
    "\n",
    "# Create complete projection\n",
    "projection = brainpy.state.AlignPostProj(\n",
    "    comm=comm,\n",
    "    syn=syn,\n",
    "    out=out,\n",
    "    post=postsynaptic_neurons\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This separation provides:\n",
    "\n",
    "- Clear responsibility boundaries\n",
    "- Easy component swapping\n",
    "- Reusable building blocks\n",
    "- Better testing and debugging"
   ]
  },
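  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three stages can be traced by hand with plain JAX arrays. The sketch below is illustrative only: it uses a dense random 0/1 mask for fixed-probability connectivity (``EventFixedProb`` is event-driven and does not necessarily store a dense matrix), an exact exponential decay for the synapse, and unitless numbers. All function names in it are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Hypothetical, unitless parameters mirroring the example above\n",
    "pre_size, post_size, prob, weight = 100, 50, 0.1, 0.5\n",
    "dt, tau_syn, E = 0.1, 5.0, 0.0\n",
    "\n",
    "key_conn, key_spk = jax.random.split(jax.random.PRNGKey(0))\n",
    "# Comm: each pre->post connection exists with probability prob and carries the same weight\n",
    "W_dense = weight * (jax.random.uniform(key_conn, (pre_size, post_size)) < prob)\n",
    "\n",
    "def comm_stage(pre_spikes):\n",
    "    return pre_spikes @ W_dense             # weighted spike count per postsynaptic neuron\n",
    "\n",
    "def syn_stage(g, x):\n",
    "    return g * jnp.exp(-dt / tau_syn) + x   # exponential decay plus new input\n",
    "\n",
    "def out_cuba(g, V):\n",
    "    return g                                # current-based: current follows g directly\n",
    "\n",
    "def out_coba(g, V):\n",
    "    return g * (E - V)                      # conductance-based: scaled by driving force E - V\n",
    "\n",
    "pre_spikes = (jax.random.uniform(key_spk, (pre_size,)) < 0.2).astype(jnp.float32)\n",
    "V_post = jnp.full(post_size, -65.0)\n",
    "g = syn_stage(jnp.zeros(post_size), comm_stage(pre_spikes))\n",
    "I_cuba, I_coba = out_cuba(g, V_post), out_coba(g, V_post)\n",
    "print(I_cuba.shape, I_coba.shape)"
   ]
  },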
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compilation and Execution\n",
    "\n",
    "### Time-Stepped Simulation\n",
    "\n",
    "``brainpy.state`` advances simulations in discrete time steps of size ``dt``, which is set globally through ``brainstate.environ``:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T09:31:08.937334Z",
     "start_time": "2025-11-13T09:31:08.430048Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:36.446778Z",
     "iopub.status.busy": "2026-05-11T06:19:36.446531Z",
     "iopub.status.idle": "2026-05-11T06:19:36.688565Z",
     "shell.execute_reply": "2026-05-11T06:19:36.687741Z"
    }
   },
   "outputs": [],
   "source": [
    "# Example: create a simple network\n",
    "class SimpleNetwork(brainstate.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.neurons = brainpy.state.LIF(100, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)\n",
    "    \n",
    "    def update(self, t, i):\n",
    "        # Generate constant input current\n",
    "        inp = jnp.ones(100) * 5.0 * u.nA\n",
    "        with brainstate.environ.context(t=t, i=i):\n",
    "            self.neurons(inp)\n",
    "            return self.neurons.get_spike()\n",
    "\n",
    "network = SimpleNetwork()\n",
    "brainstate.nn.init_all_states(network)\n",
    "\n",
    "# Set global time step\n",
    "brainstate.environ.set(dt=0.1 * u.ms)\n",
    "\n",
    "# Define simulation duration\n",
    "times = u.math.arange(0*u.ms, 1000*u.ms, brainstate.environ.get_dt())\n",
    "indices = u.math.arange(times.size)\n",
    "\n",
    "# Run simulation\n",
    "results = brainstate.transform.for_loop(\n",
    "    network.update,\n",
    "    times,\n",
    "    indices,\n",
    "    pbar=brainstate.transform.ProgressBar(10)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### JIT Compilation\n",
    "\n",
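    "Wrapping a function in ``brainstate.transform.jit`` compiles it with XLA on its first call; later calls with inputs of the same shape and dtype reuse the compiled program:"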
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T09:31:09.104741Z",
     "start_time": "2025-11-13T09:31:08.980341Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:36.690844Z",
     "iopub.status.busy": "2026-05-11T06:19:36.690649Z",
     "iopub.status.idle": "2026-05-11T06:19:36.794900Z",
     "shell.execute_reply": "2026-05-11T06:19:36.793822Z"
    }
   },
   "outputs": [],
   "source": [
    "# Create example input\n",
    "input_example = jnp.ones(100) * 2.0 * u.nA\n",
    "\n",
    "@brainstate.transform.jit\n",
    "def simulate_step(t, i, input_current):\n",
    "    # Advance the LIF population by one step driven by `input_current`\n",
    "    with brainstate.environ.context(t=t, i=i):\n",
    "        network.neurons(input_current)\n",
    "        return network.neurons.get_spike()\n",
    "\n",
    "# First call: trace and compile (slow)\n",
    "result = simulate_step(0.0*u.ms, 0, input_example)\n",
    "\n",
    "# Later calls with same-shaped inputs reuse the compiled code (fast)\n",
    "result = simulate_step(0.1*u.ms, 1, input_example)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
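    "Compilation benefits:\n",
    "\n",
    "- Large speedups over executing the same operations eagerly from Python (the exact factor depends on model size and hardware)\n",
    "- Execution on whichever backend JAX is configured for (CPU, GPU, or TPU)\n",
    "- Memory planning and buffer reuse by XLA\n",
    "- Fusion of element-wise operations into fewer kernels"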
   ]
  },
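  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to observe compile-once behavior is to count how often Python actually executes the function body. The sketch uses plain ``jax.jit`` to stay minimal (``brainstate.transform.jit`` builds on it and additionally handles ``State`` reads and writes); ``lif_step`` and ``trace_count`` are illustrative names. The body runs only while JAX traces the function, so the counter increases on the first call and again when a new input shape forces recompilation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "trace_count = 0\n",
    "\n",
    "@jax.jit\n",
    "def lif_step(V, I):\n",
    "    global trace_count\n",
    "    trace_count += 1  # Python side effect: executes only during tracing\n",
    "    return V + 0.1 * (-(V + 65.0) + I) / 10.0\n",
    "\n",
    "V = jnp.full(100, -65.0, dtype=jnp.float32)\n",
    "I = jnp.full(100, 20.0, dtype=jnp.float32)\n",
    "for _ in range(5):\n",
    "    V = lif_step(V, I)  # traced once, then served from the compilation cache\n",
    "print(trace_count)  # 1\n",
    "\n",
    "lif_step(jnp.zeros(10, dtype=jnp.float32), jnp.zeros(10, dtype=jnp.float32))  # new shape -> retrace\n",
    "print(trace_count)  # 2"
   ]
  },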
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gradient Computation\n",
    "\n",
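    "For training, ``brainstate.transform.grad`` computes gradients of a loss with respect to selected states, typically the ``ParamState`` instances of a model:"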
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T09:31:09.356118Z",
     "start_time": "2025-11-13T09:31:09.113521Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:36.797040Z",
     "iopub.status.busy": "2026-05-11T06:19:36.796824Z",
     "iopub.status.idle": "2026-05-11T06:19:37.057483Z",
     "shell.execute_reply": "2026-05-11T06:19:37.056455Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss (no trainable params): 0.0\n"
     ]
    }
   ],
   "source": [
    "# Example: Define mock functions for demonstration\n",
    "def compute_loss(predictions, targets):\n",
    "    return jnp.mean((predictions.astype(float) - targets) ** 2)\n",
    "\n",
    "# Mock targets\n",
    "num_steps = 100\n",
    "targets = jnp.zeros((num_steps, 100))\n",
    "\n",
    "def loss_fn():\n",
    "    # Run network for multiple timesteps\n",
    "    def step(t, i):\n",
    "        with brainstate.environ.context(t=t, i=i):\n",
    "            return network.update(t, i)\n",
    "    \n",
    "    times = u.math.arange(0*u.ms, num_steps*brainstate.environ.get_dt(), brainstate.environ.get_dt())\n",
    "    indices = u.math.arange(times.size)\n",
    "    predictions = brainstate.transform.for_loop(step, times, indices)\n",
    "    return compute_loss(predictions, targets)\n",
    "\n",
    "# Get trainable parameters\n",
    "params = network.states(brainstate.ParamState)\n",
    "\n",
    "# Compute gradients\n",
    "if len(params) > 0:\n",
    "    optimizer = braintools.optim.Adam(lr=1e-3)\n",
    "    # The optimizer must know which states it updates\n",
    "    optimizer.register_trainable_weights(params)\n",
    "    grads, loss = brainstate.transform.grad(\n",
    "        loss_fn,\n",
    "        grad_states=params,\n",
    "        return_value=True\n",
    "    )()\n",
    "    print(f\"Loss: {loss}\")\n",
    "    # Apply one optimization step to the registered parameters\n",
    "    optimizer.update(grads)\n",
    "else:\n",
    "    # If no trainable parameters, just compute loss\n",
    "    loss = loss_fn()\n",
    "    print(f\"Loss (no trainable params): {loss}\")"
   ]
  },
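  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above falls back to evaluating the loss because ``SimpleNetwork`` has no ``ParamState``. The sketch below exercises the gradient path on a toy loss instead: it differentiates ``sum(w ** 2)`` with respect to a hypothetical parameter ``w``, whose analytic gradient is ``2 * w``. Passing ``grad_states`` as a dict and reading the gradient back under the same key assumes that ``brainstate.transform.grad`` returns gradients with the same structure as ``grad_states``."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import brainstate\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Hypothetical trainable parameter\n",
    "w = brainstate.ParamState(jnp.array([1.0, 2.0, 3.0]))\n",
    "\n",
    "def toy_loss():\n",
    "    # The loss reads the parameter through its .value\n",
    "    return jnp.sum(w.value ** 2)\n",
    "\n",
    "# Assumption: gradients come back keyed like grad_states\n",
    "grads, loss = brainstate.transform.grad(\n",
    "    toy_loss,\n",
    "    grad_states={'w': w},\n",
    "    return_value=True,\n",
    ")()\n",
    "print(loss)        # 14.0\n",
    "print(grads['w'])  # [2. 4. 6.]"
   ]
  },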
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Physical Units System\n",
    "\n",
    "### Integration with saiunit\n",
    "\n",
    "``brainpy.state`` integrates ``saiunit`` to keep every quantity dimensionally consistent:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T09:31:09.386693Z",
     "start_time": "2025-11-13T09:31:09.382194Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:37.059734Z",
     "iopub.status.busy": "2026-05-11T06:19:37.059472Z",
     "iopub.status.idle": "2026-05-11T06:19:37.064489Z",
     "shell.execute_reply": "2026-05-11T06:19:37.063364Z"
    }
   },
   "outputs": [],
   "source": [
    "# Define with units\n",
    "tau = 10 * u.ms\n",
    "threshold = -50 * u.mV\n",
    "current = 5 * u.nA\n",
    "\n",
    "# Units are checked automatically\n",
    "neuron = brainpy.state.LIF(100, tau=tau, V_th=threshold)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Benefits:\n",
    "\n",
    "- Catches dimension mismatches (e.g., adding a voltage to a time)\n",
    "- Prevents scale mistakes (e.g., ms vs s)\n",
    "- Self-documenting code\n",
    "- Automatic conversion between compatible units"
   ]
  },
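  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick check of dimension checking: adding a voltage to a time has no physical meaning, so ``saiunit`` rejects the operation when it happens. The sketch catches a generic ``Exception`` because the exact exception class is not named in this document. Quantities of the same dimension but different prefixes, by contrast, combine and convert as expected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import saiunit as u\n",
    "\n",
    "voltage = -65 * u.mV\n",
    "duration = 10 * u.ms\n",
    "\n",
    "try:\n",
    "    voltage + duration               # mV + ms: incompatible dimensions\n",
    "    mismatch_detected = False\n",
    "except Exception as err:             # saiunit raises a dimension-mismatch error here\n",
    "    mismatch_detected = True\n",
    "    print(type(err).__name__)\n",
    "\n",
    "# Compatible dimensions are converted automatically: 1 s + 500 ms = 1500 ms\n",
    "total_ms = (1.0 * u.second + 500.0 * u.ms).to_decimal(u.ms)\n",
    "print(mismatch_detected, total_ms)"
   ]
  },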
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Unit Operations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T09:31:09.432404Z",
     "start_time": "2025-11-13T09:31:09.427692Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:37.067119Z",
     "iopub.status.busy": "2026-05-11T06:19:37.066849Z",
     "iopub.status.idle": "2026-05-11T06:19:37.083311Z",
     "shell.execute_reply": "2026-05-11T06:19:37.082377Z"
    }
   },
   "outputs": [],
   "source": [
    "# Arithmetic with units\n",
    "total_time = 100 * u.ms + 0.5 * u.second  # → 600 ms\n",
    "\n",
    "# Unit conversion\n",
    "time_in_seconds = (100 * u.ms).to_decimal(u.second)  # → 0.1\n",
    "\n",
    "# Derived quantities: dimensions are tracked automatically\n",
    "voltage = -65 * u.mV\n",
    "current = 2 * u.nA\n",
    "resistance = voltage / current  # mV / nA has the dimension of resistance (here -32.5 MΩ)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ecosystem Integration\n",
    "\n",
    "``brainpy.state`` integrates tightly with its ecosystem:\n",
    "\n",
    "### braintools\n",
    "\n",
    "Utilities and tools:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T09:31:09.471630Z",
     "start_time": "2025-11-13T09:31:09.465919Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:37.085634Z",
     "iopub.status.busy": "2026-05-11T06:19:37.085380Z",
     "iopub.status.idle": "2026-05-11T06:19:37.089347Z",
     "shell.execute_reply": "2026-05-11T06:19:37.088417Z"
    }
   },
   "outputs": [],
   "source": [
    "# Optimizers\n",
    "optimizer = braintools.optim.Adam(lr=1e-3)\n",
    "\n",
    "# Initializers\n",
    "init = braintools.init.KaimingNormal()\n",
    "\n",
    "# Surrogate gradients\n",
    "spike_fn = braintools.surrogate.ReluGrad()\n",
    "\n",
    "# Metrics (example with dummy data)\n",
    "# pred = jnp.array([0.1, 0.9])\n",
    "# target = jnp.array([0, 1])\n",
    "# loss = braintools.metric.cross_entropy(pred, target)"
   ]
  },
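  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Surrogate gradients are what make spike-based losses trainable: the forward pass keeps the non-differentiable Heaviside step, while the backward pass substitutes a smooth pseudo-derivative. The sketch below builds such a function from scratch with ``jax.custom_jvp`` so its behavior is explicit. The triangular pseudo-derivative is similar in shape to a ReLU-style surrogate such as ``ReluGrad``, but its width and scale here are assumptions, not ``braintools`` defaults, and ``heaviside_surrogate`` is a hypothetical name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "@jax.custom_jvp\n",
    "def heaviside_surrogate(x):\n",
    "    # Forward pass: Heaviside step, 1.0 where x > 0 and 0.0 elsewhere\n",
    "    return (x > 0).astype(jnp.float32)\n",
    "\n",
    "@heaviside_surrogate.defjvp\n",
    "def heaviside_surrogate_jvp(primals, tangents):\n",
    "    (x,), (x_dot,) = primals, tangents\n",
    "    # Backward pass: triangular pseudo-derivative max(0, 1 - |x|) replaces the true derivative\n",
    "    pseudo_derivative = jnp.maximum(0.0, 1.0 - jnp.abs(x))\n",
    "    return heaviside_surrogate(x), pseudo_derivative * x_dot\n",
    "\n",
    "x = jnp.array([-2.0, -0.5, 0.5, 2.0])\n",
    "spikes = heaviside_surrogate(x)\n",
    "surrogate_grads = jax.grad(lambda v: heaviside_surrogate(v).sum())(x)\n",
    "print(spikes)           # [0. 0. 1. 1.]\n",
    "print(surrogate_grads)  # [0.  0.5 0.5 0. ]"
   ]
  },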
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### saiunit\n",
    "\n",
    "Physical units:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T09:31:09.496064Z",
     "start_time": "2025-11-13T09:31:09.490951Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:37.091255Z",
     "iopub.status.busy": "2026-05-11T06:19:37.091053Z",
     "iopub.status.idle": "2026-05-11T06:19:37.094309Z",
     "shell.execute_reply": "2026-05-11T06:19:37.093287Z"
    }
   },
   "outputs": [],
   "source": [
    "# All standard SI units\n",
    "time = 10 * u.ms\n",
    "voltage = -65 * u.mV\n",
    "current = 2 * u.nA"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### brainstate\n",
    "\n",
    "Core framework underlying every ``brainpy.state`` model (module system, compilation, transformations):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T09:31:09.518362Z",
     "start_time": "2025-11-13T09:31:09.512404Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:37.096419Z",
     "iopub.status.busy": "2026-05-11T06:19:37.096207Z",
     "iopub.status.idle": "2026-05-11T06:19:37.100116Z",
     "shell.execute_reply": "2026-05-11T06:19:37.099334Z"
    }
   },
   "outputs": [],
   "source": [
    "import brainstate\n",
    "\n",
    "# Module system\n",
    "class Net(brainstate.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        pass\n",
    "\n",
    "# Compilation\n",
    "@brainstate.transform.jit\n",
    "def fn():\n",
    "    return 0\n",
    "\n",
    "# Transformations\n",
    "# result = brainstate.transform.for_loop(...)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Ecosystem-py",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {
     "42c8b66c76fc4d14977472fae4788d86": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "4c3ad8ac42d84cf5ab91d940be0ac446": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "50db3d97379845a69132eca4e50c8879": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_f18277ee6ba144d0b8da742d98cae5b9",
        "IPY_MODEL_fbc2f4370f7d48779ba38a5891b69d0e",
        "IPY_MODEL_dd48ea99ce9f4c2c99062b788c37ce58"
       ],
       "layout": "IPY_MODEL_42c8b66c76fc4d14977472fae4788d86",
       "tabbable": null,
       "tooltip": null
      }
     },
     "9307d4d265e44f089a44dc6d73efd672": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "9901cf5dbd594b89aa3cb176a5d2867e": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "b8d718881731476a8780f9771ae399f3": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "dd48ea99ce9f4c2c99062b788c37ce58": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_b8d718881731476a8780f9771ae399f3",
       "placeholder": "​",
       "style": "IPY_MODEL_9901cf5dbd594b89aa3cb176a5d2867e",
       "tabbable": null,
       "tooltip": null,
       "value": " 10000/10000 [00:00&lt;00:00, 282688.38it/s]"
      }
     },
     "f060ed6079f0472cab47cd284c6357ed": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "f18277ee6ba144d0b8da742d98cae5b9": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_4c3ad8ac42d84cf5ab91d940be0ac446",
       "placeholder": "​",
       "style": "IPY_MODEL_f060ed6079f0472cab47cd284c6357ed",
       "tabbable": null,
       "tooltip": null,
       "value": "Running for 10,000 iterations: 100%"
      }
     },
     "f97d10eb383d4c029bdf05fe7c433567": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "fbc2f4370f7d48779ba38a5891b69d0e": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "ProgressView",
       "bar_style": "success",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_f97d10eb383d4c029bdf05fe7c433567",
       "max": 10000.0,
       "min": 0.0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_9307d4d265e44f089a44dc6d73efd672",
       "tabbable": null,
       "tooltip": null,
       "value": 10000.0
      }
     }
    },
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
