{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f6",
   "metadata": {},
   "source": [
    "# Getting Started\n",
    "\n",
    "Welcome to **BrainState**! This tutorial will guide you through the basics of using BrainState, a state-based transformation system designed for brain modeling and neural network programming.\n",
    "\n",
    "By the end of this tutorial, you will:\n",
    "- Understand what BrainState is and why it's useful\n",
    "- Know how to install and set up BrainState\n",
    "- Learn the core concepts and design philosophy\n",
    "- Build your first simple neural network with BrainState"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2c3d4e5f6a7",
   "metadata": {},
   "source": [
    "## What is BrainState?\n",
    "\n",
    "**BrainState** is a Python library built on top of JAX that provides:\n",
    "\n",
    "- 🧠 **Stateful Programming Model**: Manage mutable states in a JAX-compatible way\n",
    "- 🚀 **High Performance**: Leverage JAX's JIT compilation, automatic differentiation, and vectorization\n",
    "- 🔧 **Modular Design**: Build complex models from simple, composable components\n",
    "- 🌐 **Brain Modeling**: Specialized tools for computational neuroscience and brain-inspired computing\n",
    "\n",
    "BrainState bridges the gap between the functional programming paradigm of JAX and the intuitive, stateful programming style commonly used in neural network frameworks."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3d4e5f6a7b8",
   "metadata": {},
   "source": [
    "## Installation and Environment Setup\n",
    "\n",
    "### Prerequisites\n",
    "\n",
    "Before installing BrainState, ensure you have:\n",
    "- Python 3.9 or higher\n",
    "- pip package manager\n",
    "\n",
    "### Installing BrainState\n",
    "\n",
    "The easiest way to install BrainState is via pip:\n",
    "\n",
    "```bash\n",
    "pip install brainstate --upgrade\n",
    "```\n",
    "\n",
    "### Installing the Complete Ecosystem\n",
    "\n",
    "For a complete brain modeling ecosystem, you can install BrainX, which bundles BrainState with other compatible packages:\n",
    "\n",
    "```bash\n",
    "pip install BrainX -U\n",
    "```\n",
    "\n",
    "This includes:\n",
    "- **brainstate**: Core state management and transformations\n",
    "- **brainunit**: Physical units and dimensional analysis\n",
    "- **braintools**: Optimization algorithms and utilities\n",
    "- **brainpy**: Spiking neural network modeling\n",
    "\n",
    "### Verifying Installation\n",
    "\n",
    "Let's verify that BrainState is installed correctly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d4e5f6a7b8c9",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:00:50.820131Z",
     "start_time": "2025-10-10T10:00:47.845537Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BrainState version: 0.2.3\n",
      "Installation successful! ✓\n"
     ]
    }
   ],
   "source": [
    "import brainstate\n",
    "import braintools\n",
    "import jax.numpy as jnp\n",
    "\n",
    "print(f\"BrainState version: {brainstate.__version__}\")\n",
    "print(\"Installation successful! ✓\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5f6a7b8c9d0",
   "metadata": {},
   "source": [
    "## Core Concepts Overview\n",
    "\n",
    "BrainState is built around several key concepts that work together to enable stateful, high-performance neural network programming.\n",
    "\n",
    "### 1. State: Managing Mutable Variables\n",
    "\n",
    "JAX follows a functional programming style: JAX arrays are immutable, and transformations such as `jit` and `grad` expect pure functions without side effects. Neural networks and brain models, however, are naturally described in terms of mutable states (e.g., neuron membrane potentials, network weights) that change as the model runs or learns.\n",
    "\n",
    "**BrainState's `State`** provides a solution by wrapping mutable variables in a way that's compatible with JAX transformations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f6a7b8c9d0e1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:00:50.943097Z",
     "start_time": "2025-10-10T10:00:50.828845Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initial voltage: [  0. -70. -55.]\n",
      "Updated voltage: [ 10. -60. -45.]\n"
     ]
    }
   ],
   "source": [
    "# Creating a State object\n",
    "voltage = brainstate.State(jnp.array([0.0, -70.0, -55.0]))\n",
    "print(\"Initial voltage:\", voltage.value)\n",
    "\n",
    "# Updating the state\n",
    "voltage.value = voltage.value + 10.0\n",
    "print(\"Updated voltage:\", voltage.value)"
   ]
  },
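  {
   "cell_type": "markdown",
   "id": "3e7a1c5b9d01",
   "metadata": {},
   "source": [
    "For comparison, here is a minimal sketch of the same update in plain JAX, without BrainState. The voltage is an explicit input and an explicit output. `v0` is never modified, so the caller must keep track of the returned array. `increase_voltage` is a hypothetical helper used only for illustration. `brainstate.State` lets you write `voltage.value = ...` instead, while remaining usable inside JAX transformations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e7a1c5b9d02",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Hypothetical helper: the state goes in as an argument and comes back as a new array\n",
    "@jax.jit\n",
    "def increase_voltage(v, delta):\n",
    "    \"\"\"Return a new voltage array; the input array is left unchanged.\"\"\"\n",
    "    return v + delta\n",
    "\n",
    "v0 = jnp.array([0.0, -70.0, -55.0])\n",
    "v1 = increase_voltage(v0, 10.0)\n",
    "print(\"Initial voltage:\", v0)\n",
    "print(\"Updated voltage:\", v1)"
   ]
  },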
  {
   "cell_type": "markdown",
   "id": "a7b8c9d0e1f2",
   "metadata": {},
   "source": [
    "**Key Types of States:**\n",
    "\n",
    "- `State`: Generic mutable state\n",
    "- `ParamState`: Trainable parameters (weights, biases)\n",
    "- `HiddenState`: Hidden activations (membrane potentials, hidden layer outputs)\n",
    "- `ShortTermState`: Temporary states (spike times, current values)\n",
    "- `LongTermState`: Long-term states (running statistics, momentum)\n",
    "\n",
    "We'll explore these in detail in the next tutorial."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8c9d0e1f2a3",
   "metadata": {},
   "source": [
    "### 2. Module: Building Blocks of Neural Networks\n",
    "\n",
    "`brainstate.nn.Module`, which builds on `brainstate.graph.Node`, is the base class for neural network components in BrainState. Any `State` stored as an attribute of a module is tracked automatically, which makes it straightforward to compose complex models from simpler ones."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c9d0e1f2a3b4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:00:51.038458Z",
     "start_time": "2025-10-10T10:00:50.950516Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initial voltage: 0.0\n",
      "Spike at time 6! V=0.0\n",
      "Spike at time 13! V=0.0\n"
     ]
    }
   ],
   "source": [
    "class SimpleNeuron(brainstate.nn.Module):\n",
    "    \"\"\"A simple leaky integrate-and-fire neuron.\"\"\"\n",
    "    \n",
    "    def __init__(self, threshold=1.0, reset=0.0, tau=10.0):\n",
    "        super().__init__()\n",
    "        self.threshold = threshold\n",
    "        self.reset = reset\n",
    "        self.tau = tau\n",
    "        \n",
    "        # Membrane potential is a hidden state\n",
    "        self.V = brainstate.HiddenState(jnp.array(0.0))\n",
    "    \n",
    "    def __call__(self, I_input):\n",
    "        \"\"\"Update neuron state given input current.\"\"\"\n",
    "        # Leaky integration\n",
    "        dV = (-self.V.value + I_input) / self.tau\n",
    "        self.V.value = self.V.value + dV\n",
    "        \n",
    "        # Spike and reset\n",
    "        spike = self.V.value >= self.threshold\n",
    "        self.V.value = jnp.where(spike, self.reset, self.V.value)\n",
    "        \n",
    "        return spike\n",
    "\n",
    "# Create and test the neuron\n",
    "neuron = SimpleNeuron()\n",
    "print(\"Initial voltage:\", neuron.V.value)\n",
    "\n",
    "# Simulate with input current\n",
    "for t in range(20):\n",
    "    spike = neuron(2.0)  # constant input\n",
    "    if spike:\n",
    "        print(f\"Spike at time {t}! V={neuron.V.value}\")"
   ]
  },
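  {
   "cell_type": "markdown",
   "id": "3e7a1c5b9d03",
   "metadata": {},
   "source": [
    "The `__call__` method above is a forward-Euler discretization (time step $\\Delta t = 1$) of the leaky integrate-and-fire equation $\\tau \\, dV/dt = -V + I$, followed by a threshold-and-reset rule. The symbols map to the code as $V_{\\mathrm{th}}$ = `threshold`, $V_{\\mathrm{reset}}$ = `reset`, and $\\tau$ = `tau`:\n",
    "\n",
    "$$\n",
    "\\tilde{V}_t = V_t + \\frac{-V_t + I_t}{\\tau}, \\qquad V_{t+1} = \\begin{cases} V_{\\mathrm{reset}} & \\text{if } \\tilde{V}_t \\ge V_{\\mathrm{th}} \\quad \\text{(spike)} \\\\ \\tilde{V}_t & \\text{otherwise} \\end{cases}\n",
    "$$\n",
    "\n",
    "Take a constant input $I = 2$, with $\\tau = 10$ and $V_{\\mathrm{th}} = 1$. Starting from 0, the potential rises as 0.2, 0.38, 0.542, ... and first crosses the threshold on the seventh step (loop index $t = 6$). After the reset, the same climb repeats, so the next spike comes seven steps later at $t = 13$. This matches the printed output."
   ]
  },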
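  {
   "cell_type": "markdown",
   "id": "d0e1f2a3b4c5",
   "metadata": {},
   "source": [
    "### 3. Transform: JAX Transformations with States\n",
    "\n",
    "BrainState provides state-aware versions of JAX transformations:\n",
    "\n",
    "- `brainstate.transform.jit`: Just-in-time compilation\n",
    "- `brainstate.transform.grad`: Automatic differentiation\n",
    "- `brainstate.transform.vmap`: Vectorization (batching)\n",
    "- `brainstate.transform.scan`: Compiled loops that carry a value across iterations\n",
    "- `brainstate.transform.for_loop`: A convenience loop that runs a function over a sequence of inputs and stacks the per-step outputs\n",
    "\n",
    "These transformations track the `State` objects that a function reads and writes, so you can apply them directly to stateful code. The next cell uses `for_loop` to drive the neuron with a sequence of input currents."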
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e1f2a3b4c5d6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:00:51.125716Z",
     "start_time": "2025-10-10T10:00:51.059865Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Spike train: [0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0]\n"
     ]
    }
   ],
   "source": [
    "# Reset neuron\n",
    "neuron.V.value = jnp.array(0.0)\n",
    "\n",
    "# Simulate with varying input\n",
    "inputs = jnp.array([1.5, 2.0, 2.5, 3.0, 1.0] * 4)\n",
    "spikes = brainstate.transform.for_loop(neuron, inputs)\n",
    "print(\"Spike train:\", spikes.astype(int))"
   ]
  },
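  {
   "cell_type": "markdown",
   "id": "3e7a1c5b9d04",
   "metadata": {},
   "source": [
    "Conceptually, a loop like `for_loop` corresponds to `jax.lax.scan` in plain JAX. There, the state is threaded explicitly as a *carry* value. The sketch below rewrites the neuron as a pure step function and runs it over the same inputs. `lif_step` is a hypothetical helper that mirrors `SimpleNeuron.__call__` with its default parameters. It illustrates the functional pattern and is not BrainState's actual implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e7a1c5b9d05",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def lif_step(v, i_input, threshold=1.0, reset=0.0, tau=10.0):\n",
    "    \"\"\"One neuron update with signature (carry, input) -> (new carry, output).\"\"\"\n",
    "    v = v + (-v + i_input) / tau      # leaky integration (Euler step, dt = 1)\n",
    "    spike = v >= threshold            # threshold crossing\n",
    "    v = jnp.where(spike, reset, v)    # reset after a spike\n",
    "    return v, spike\n",
    "\n",
    "current_inputs = jnp.array([1.5, 2.0, 2.5, 3.0, 1.0] * 4)\n",
    "\n",
    "# scan threads the membrane potential through every step and stacks the spike outputs\n",
    "final_v, scan_spikes = jax.lax.scan(lif_step, jnp.zeros(()), current_inputs)\n",
    "print(\"Spike train:\", scan_spikes.astype(int))"
   ]
  },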
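  {
   "cell_type": "markdown",
   "id": "f2a3b4c5d6e7",
   "metadata": {},
   "source": [
    "### 4. Random: Stateful Random Number Generation\n",
    "\n",
    "JAX's own random number API is functional: every call takes an explicit PRNG key, and you must split keys yourself to get fresh numbers. BrainState provides a stateful random number generator that manages these keys internally. You can call NumPy-like functions such as `rand` and `randn` directly, and results stay reproducible through `seed`."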
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a3b4c5d6e7f8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:00:51.425858Z",
     "start_time": "2025-10-10T10:00:51.133928Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Uniform samples: [0.72766423 0.78786755 0.18169427 0.26263022 0.11072934]\n",
      "Normal samples: [-0.21089035 -1.3627948  -0.04500385 -1.1536394   1.9141139 ]\n"
     ]
    }
   ],
   "source": [
    "# Set random seed for reproducibility\n",
    "brainstate.random.seed(42)\n",
    "\n",
    "# Generate random numbers\n",
    "uniform_samples = brainstate.random.rand(5)\n",
    "normal_samples = brainstate.random.randn(5)\n",
    "\n",
    "print(\"Uniform samples:\", uniform_samples)\n",
    "print(\"Normal samples:\", normal_samples)"
   ]
  },
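  {
   "cell_type": "markdown",
   "id": "3e7a1c5b9d06",
   "metadata": {},
   "source": [
    "For comparison, here is a minimal sketch of similar sampling in plain JAX. The PRNG keys are created and split by hand, so each call receives its own key. The numbers need not match the BrainState output above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e7a1c5b9d07",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "# Create a root key from a seed, then split it so each sampling call gets its own key\n",
    "key = jax.random.PRNGKey(42)\n",
    "key_uniform, key_normal = jax.random.split(key)\n",
    "\n",
    "uniform_jax = jax.random.uniform(key_uniform, (5,))\n",
    "normal_jax = jax.random.normal(key_normal, (5,))\n",
    "\n",
    "print(\"Uniform samples:\", uniform_jax)\n",
    "print(\"Normal samples:\", normal_jax)"
   ]
  },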
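  {
   "cell_type": "markdown",
   "id": "b4c5d6e7f8a9",
   "metadata": {},
   "source": [
    "## Hello World: Building Your First Neural Network\n",
    "\n",
    "Let's build a simple feedforward neural network sized for handwritten-digit classification (784 = 28 × 28 pixel inputs, 10 output classes). To keep the tutorial self-contained, we train it on randomly generated dummy data instead of a real dataset. The goal is to see the key concepts working together, not to reach a meaningful accuracy.\n",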
    "\n",
    "### Step 1: Define the Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c5d6e7f8a9b0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:00:51.774758Z",
     "start_time": "2025-10-10T10:00:51.434028Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Network created!\n",
      "Total parameters: 101,770\n"
     ]
    }
   ],
   "source": [
    "class MLP(brainstate.nn.Module):\n",
    "    \"\"\"A simple multi-layer perceptron.\"\"\"\n",
    "    \n",
    "    def __init__(self, input_dim, hidden_dim, output_dim):\n",
    "        super().__init__()\n",
    "        \n",
    "        # Initialize weights and biases as trainable parameters\n",
    "        self.w1 = brainstate.ParamState(brainstate.random.randn(input_dim, hidden_dim) * 0.1)\n",
    "        self.b1 = brainstate.ParamState(jnp.zeros(hidden_dim))\n",
    "        \n",
    "        self.w2 = brainstate.ParamState(brainstate.random.randn(hidden_dim, output_dim) * 0.1)\n",
    "        self.b2 = brainstate.ParamState(jnp.zeros(output_dim))\n",
    "    \n",
    "    def __call__(self, x):\n",
    "        \"\"\"Forward pass through the network.\"\"\"\n",
    "        # Hidden layer with ReLU activation\n",
    "        hidden = jnp.maximum(0, x @ self.w1.value + self.b1.value)\n",
    "        \n",
    "        # Output layer\n",
    "        logits = hidden @ self.w2.value + self.b2.value\n",
    "        \n",
    "        return logits\n",
    "\n",
    "# Create the network\n",
    "brainstate.random.seed(0)\n",
    "model = MLP(input_dim=784, hidden_dim=128, output_dim=10)\n",
    "print(\"Network created!\")\n",
    "print(f\"Total parameters: {784*128 + 128 + 128*10 + 10:,}\")"
   ]
  },
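  {
   "cell_type": "markdown",
   "id": "3e7a1c5b9d08",
   "metadata": {},
   "source": [
    "The hard-coded total in the cell above is the sum of each layer's weight matrix and bias vector. The hidden layer contributes 784 × 128 + 128 = 100,480 parameters and the output layer 128 × 10 + 10 = 1,290, giving 101,770. A small helper generalizes the count to any stack of fully connected layers. `count_dense_params` is a hypothetical function written for this tutorial, not part of BrainState. It assumes every layer has exactly one weight matrix and one bias vector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e7a1c5b9d09",
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_dense_params(layer_sizes):\n",
    "    \"\"\"Count weights and biases in a stack of fully connected layers.\"\"\"\n",
    "    total = 0\n",
    "    for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):\n",
    "        total += fan_in * fan_out + fan_out  # weight matrix + bias vector\n",
    "    return total\n",
    "\n",
    "print(f\"Total parameters: {count_dense_params([784, 128, 10]):,}\")"
   ]
  },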
  {
   "cell_type": "markdown",
   "id": "d6e7f8a9b0c1",
   "metadata": {},
   "source": [
    "### Step 2: Define Loss Function and Training Step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e7f8a9b0c1d2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:00:51.786579Z",
     "start_time": "2025-10-10T10:00:51.781764Z"
    }
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "def cross_entropy_loss(logits, labels):\n",
    "    \"\"\"Compute the mean cross-entropy loss over a batch.\"\"\"\n",
    "    # One-hot encode the integer labels\n",
    "    one_hot_labels = jax.nn.one_hot(labels, logits.shape[-1])\n",
    "    \n",
    "    # Numerically stable log-softmax (subtracts the largest logit before exponentiating)\n",
    "    log_probs = jax.nn.log_softmax(logits, axis=-1)\n",
    "    \n",
    "    # Compute loss\n",
    "    loss = -jnp.mean(jnp.sum(one_hot_labels * log_probs, axis=-1))\n",
    "    return loss\n",
    "\n",
    "def accuracy(logits, labels):\n",
    "    \"\"\"Compute classification accuracy.\"\"\"\n",
    "    predictions = jnp.argmax(logits, axis=-1)\n",
    "    return jnp.mean(predictions == labels)\n",
    "\n",
    "def loss_fn(x, y):\n",
    "    \"\"\"Compute loss for the model.\"\"\"\n",
    "    logits = model(x)\n",
    "    return cross_entropy_loss(logits, y)\n"
   ]
  },
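  {
   "cell_type": "markdown",
   "id": "3e7a1c5b9d10",
   "metadata": {},
   "source": [
    "Computing log-softmax directly as `logits - log(sum(exp(logits)))` is correct on paper but can overflow in floating point, because `exp` is evaluated before any normalization. `jax.nn.log_softmax` avoids this by subtracting the largest logit first. The check below uses one extreme logit; `big_logits` is an illustrative input, not tutorial data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e7a1c5b9d11",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "big_logits = jnp.array([1000.0, 0.0])\n",
    "\n",
    "# Naive formula: exp(1000) overflows to inf, so every entry becomes -inf\n",
    "naive_log_probs = big_logits - jnp.log(jnp.sum(jnp.exp(big_logits)))\n",
    "\n",
    "# Stable formula: the largest logit is subtracted before exponentiating\n",
    "stable_log_probs = jax.nn.log_softmax(big_logits)\n",
    "\n",
    "print(\"naive: \", naive_log_probs)\n",
    "print(\"stable:\", stable_log_probs)"
   ]
  },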
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "64b40afff963d61b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:00:52.135665Z",
     "start_time": "2025-10-10T10:00:51.804060Z"
    }
   },
   "outputs": [],
   "source": [
    "# Generate dummy data for demonstration\n",
    "brainstate.random.seed(42)\n",
    "X_train = brainstate.random.randn(100, 784) * 0.1  # 100 samples\n",
    "y_train = brainstate.random.randint(0, 10, 100)     # Random labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e80c02d8340e86cc",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:00:52.156399Z",
     "start_time": "2025-10-10T10:00:52.142640Z"
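    }
   },
   "outputs": [],
   "source": [
    "# Collect the ParamState objects that loss_fn reads when called on example inputs\n",
    "param_states = brainstate.transform.StateFinder(loss_fn, brainstate.ParamState)(X_train, y_train)\n",
    "\n",
    "# Build a function that returns gradients of loss_fn with respect to those states\n",
    "grad_fn = brainstate.transform.grad(loss_fn, grad_states=param_states)"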
   ]
  },
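  {
   "cell_type": "markdown",
   "id": "3e7a1c5b9d12",
   "metadata": {},
   "source": [
    "Conceptually, `grad_states` marks which states act as the inputs being differentiated. In plain JAX, you would instead pass the parameters to the loss function explicitly and differentiate with respect to that argument. The gradients then come back with the same structure as the parameters. Below is a minimal sketch with a toy linear model; `pure_loss` and the toy arrays are illustrative and unrelated to the MLP above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e7a1c5b9d13",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def pure_loss(params, x, y):\n",
    "    \"\"\"Mean squared error of a linear model whose parameters are passed in explicitly.\"\"\"\n",
    "    pred = x @ params[\"w\"] + params[\"b\"]\n",
    "    return jnp.mean((pred - y) ** 2)\n",
    "\n",
    "toy_params = {\"w\": jnp.ones((3, 1)), \"b\": jnp.zeros((1,))}\n",
    "toy_x = jnp.ones((4, 3))\n",
    "toy_y = jnp.zeros((4, 1))\n",
    "\n",
    "# jax.grad differentiates with respect to the first argument (the params dict)\n",
    "toy_grads = jax.grad(pure_loss)(toy_params, toy_x, toy_y)\n",
    "print({name: g.shape for name, g in toy_grads.items()})"
   ]
  },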
  {
   "cell_type": "markdown",
   "id": "f8a9b0c1d2e3",
   "metadata": {},
   "source": [
    "### Step 3: Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f175f290b7246984",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:01:50.993042Z",
     "start_time": "2025-10-10T10:01:50.988317Z"
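    }
   },
   "outputs": [],
   "source": [
    "# Plain stochastic gradient descent with a learning rate of 0.1\n",
    "optimizer = braintools.optim.SGD(1e-1)\n",
    "# Register the parameter states this optimizer should update\n",
    "_ = optimizer.register_trainable_weights(param_states)"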
   ]
  },
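  {
   "cell_type": "markdown",
   "id": "3e7a1c5b9d14",
   "metadata": {},
   "source": [
    "Plain stochastic gradient descent moves every parameter against its gradient, $\\theta \\leftarrow \\theta - \\eta \\, \\nabla_\\theta \\mathcal{L}$, where $\\eta$ is the learning rate. Assuming `braintools.optim.SGD` applies this rule without momentum or weight decay by default, one update is equivalent to the pure-JAX sketch below over a dictionary of arrays. `sgd_update` and the toy values are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e7a1c5b9d15",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def sgd_update(params, grads, lr=0.1):\n",
    "    \"\"\"Apply one plain SGD step to every leaf of a parameter pytree.\"\"\"\n",
    "    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)\n",
    "\n",
    "sgd_params = {\"w\": jnp.array([1.0, 2.0]), \"b\": jnp.array(0.5)}\n",
    "sgd_grads = {\"w\": jnp.array([0.5, -1.0]), \"b\": jnp.array(1.0)}\n",
    "\n",
    "new_params = sgd_update(sgd_params, sgd_grads)\n",
    "print(new_params)"
   ]
  },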
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a9b0c1d2e3f4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:01:51.778086Z",
     "start_time": "2025-10-10T10:01:51.583763Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting training...\n",
      "\n",
      "Epoch  2: Loss = 2.2960, Accuracy = 0.1000\n",
      "Epoch  4: Loss = 2.2739, Accuracy = 0.1200\n",
      "Epoch  6: Loss = 2.2529, Accuracy = 0.1500\n",
      "Epoch  8: Loss = 2.2326, Accuracy = 0.1900\n",
      "Epoch 10: Loss = 2.2130, Accuracy = 0.2300\n",
      "\n",
      "Training complete!\n"
     ]
    }
   ],
   "source": [
    "@brainstate.transform.jit\n",
    "def train_step(x, y):\n",
    "    \"\"\"Perform one training step.\"\"\"\n",
    "    # Compute gradients\n",
    "    grads = grad_fn(x, y)\n",
    "    \n",
    "    # Update parameters using gradient descent\n",
    "    optimizer.update(grads)\n",
    "    \n",
    "    # Compute metrics\n",
    "    logits = model(x)\n",
    "    loss = cross_entropy_loss(logits, y)\n",
    "    acc = accuracy(logits, y)\n",
    "    \n",
    "    return loss, acc\n",
    "\n",
    "# Training loop\n",
    "print(\"Starting training...\\n\")\n",
    "for epoch in range(10):\n",
    "    loss, acc = train_step(X_train, y_train)\n",
    "    if (epoch + 1) % 2 == 0:\n",
    "        print(f\"Epoch {epoch+1:2d}: Loss = {loss:.4f}, Accuracy = {acc:.4f}\")\n",
    "\n",
    "print(\"\\nTraining complete!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0c1d2e3f4a5",
   "metadata": {},
   "source": [
    "### Step 4: Making Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c1d2e3f4a5b6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:02:03.719728Z",
     "start_time": "2025-10-10T10:02:03.447690Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predictions on test data:\n",
      "[3 9 7 7 7 7 7 9 5 0]\n"
     ]
    }
   ],
   "source": [
    "@brainstate.transform.jit\n",
    "def predict(x):\n",
    "    \"\"\"Make predictions with the model.\"\"\"\n",
    "    logits = model(x)\n",
    "    return jnp.argmax(logits, axis=-1)\n",
    "\n",
    "# Generate test data\n",
    "X_test = brainstate.random.randn(10, 784) * 0.1\n",
    "predictions = predict(X_test)\n",
    "\n",
    "print(\"Predictions on test data:\")\n",
    "print(predictions)"
   ]
  },
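  {
   "cell_type": "markdown",
   "id": "3e7a1c5b9d16",
   "metadata": {},
   "source": [
    "`MLP.__call__` already handles a whole batch, because the matrix product `x @ w` works row by row on a 2-D input. For a function written for a single example, vectorization adds the batch dimension for you; this is the role of `vmap` in the transformation list. Below is a minimal plain-JAX sketch with toy weights. `predict_one` is hypothetical, and `brainstate.transform.vmap` is the state-aware counterpart."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e7a1c5b9d17",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def predict_one(w, b, x):\n",
    "    \"\"\"Return the predicted class for a single example x of shape (features,).\"\"\"\n",
    "    return jnp.argmax(x @ w + b)\n",
    "\n",
    "toy_w = jnp.eye(2)\n",
    "toy_b = jnp.zeros(2)\n",
    "toy_batch = jnp.array([[2.0, 1.0], [0.0, 3.0], [5.0, 4.0]])\n",
    "\n",
    "# in_axes=(None, None, 0): share w and b, map over the rows of the batch\n",
    "predict_batch = jax.vmap(predict_one, in_axes=(None, None, 0))\n",
    "print(predict_batch(toy_w, toy_b, toy_batch))"
   ]
  },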
  {
   "cell_type": "markdown",
   "id": "d2e3f4a5b6c7",
   "metadata": {},
   "source": [
    "## Key Takeaways\n",
    "\n",
    "Congratulations! You've just built your first neural network with BrainState. Here are the key concepts we covered:\n",
    "\n",
    "1. **States** wrap mutable variables and make them compatible with JAX transformations\n",
    "2. **Modules** (subclasses of `brainstate.nn.Module`) organize neural network components and the states they own\n",
    "3. **Transformations** like `jit`, `grad`, and `for_loop` can be applied directly to stateful code\n",
    "4. **Random number generation** is stateful, yet reproducible once you set a seed\n",
    "\n",
    "## What's Next?\n",
    "\n",
    "Now that you understand the basics, continue with the following tutorials:\n",
    "\n",
    "1. **State Management** - Deep dive into different types of states and advanced state management techniques\n",
    "2. **Random Number Generation** - Learn about BrainState's random number generation system\n",
    "3. **Neural Network Modules** - Explore pre-built layers and learn to create custom modules\n",
    "4. **Program Transformations** - Master JIT compilation, automatic differentiation, and vectorization\n",
    "\n",
    "## Additional Resources\n",
    "\n",
    "- 📚 [BrainState Documentation](https://brainstate.readthedocs.io/)\n",
    "- 🌐 [BrainX Ecosystem](https://brainmodeling.readthedocs.io/)\n",
    "- 💻 [GitHub Repository](https://github.com/chaobrain/brainstate)\n",
    "- 🐛 [Issue Tracker](https://github.com/chaobrain/brainstate/issues)\n",
    "\n",
    "Happy coding with BrainState! 🧠✨"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Ecosystem-py",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
