{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cc708b848687e9db",
   "metadata": {},
   "source": [
    "# State Management\n",
    "\n",
    "Dynamical brain models are built around time-varying state variables, such as the membrane potential `V` of a neuron or the firing rate `r` in a firing-rate model. **BrainState** provides the `State` data structure, which lets you define and manage such variables explicitly and intuitively.\n",
    "\n",
    "This tutorial provides a detailed introduction to state management in BrainState. By following this tutorial, you will learn:\n",
    "\n",
    "- The basic concepts and fundamental usage of `State` objects\n",
    "- How to create `State` objects and use their subclasses: `ShortTermState`, `LongTermState`, `HiddenState`, and `ParamState`\n",
    "- How `State` objects hold arbitrary JAX PyTrees\n",
    "- How to use `StateTraceStack` to track which `State` objects your code reads and writes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b33843f51bfdecdd",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:04:04.517725Z",
     "start_time": "2025-10-10T10:04:02.896222Z"
    }
   },
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "import brainstate"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db2715ae6e9f10d9",
   "metadata": {},
   "source": [
    "## 1. Basic Concepts and Usage of State Objects\n",
    "\n",
    "`State` is a key data structure in **BrainState** used to encapsulate state variables in models. These variables primarily represent values that change over time within the model.\n",
    "\n",
    "### Why States?\n",
    "\n",
    "JAX is built on functional programming principles, which means:\n",
    "- JAX arrays are immutable: an update produces a new array instead of modifying the old one\n",
    "- Functions transformed by `jax.jit`, `jax.grad`, and similar are expected to be pure; Python side effects inside them are not part of the transformed computation\n",
    "- State must therefore be passed into and returned from computations explicitly\n",
    "\n",
    "This clashes with how we naturally think about neural models, in terms of mutable quantities such as weights and membrane potentials. **BrainState's `State`** solves this by:\n",
    "\n",
    "✅ Providing a mutable interface for state variables  \n",
    "✅ Automatically managing state updates during JAX transformations  \n",
    "✅ Maintaining compatibility with JAX's functional paradigm  \n",
    "\n",
    "### Creating States\n",
    "\n",
    "A `State` can wrap numbers, arrays, `jax.Array` objects, or any PyTree built from them, such as dictionaries or lists of arrays. Unlike an ordinary Python variable, whose value is frozen into a program when it is JIT-compiled, the data held by a `State` can still be updated after compilation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3a6b10691cf03e3b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:04:04.642100Z",
     "start_time": "2025-10-10T10:04:04.525569Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "State(\n",
       "  value=ShapedArray(float32[10])\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Create a simple State with an array\n",
    "example = brainstate.State(jnp.ones(10))\n",
    "example"
   ]
  },
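  {
   "cell_type": "markdown",
   "id": "functional_vs_state_md",
   "metadata": {},
   "source": [
    "The contrast described under *Why States?* can be sketched with a toy leaky integrator. The update rule and the names `step_pure` and `step_stateful` are illustrative assumptions, not BrainState APIs. In the functional version the state is passed in and returned on every call. In the `State` version the function reads and writes `V_state.value` directly. Both loops produce the same values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "functional_vs_state_code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "\n",
    "# Purely functional style: the state is threaded through each call explicitly\n",
    "def step_pure(V, I):\n",
    "    return V + 0.1 * (I - V)  # returns a new array; V itself is never modified\n",
    "\n",
    "V = jnp.zeros(3)\n",
    "for _ in range(3):\n",
    "    V = step_pure(V, 1.0)\n",
    "\n",
    "# State style: the function updates the wrapped value instead of returning it\n",
    "V_state = brainstate.State(jnp.zeros(3))\n",
    "\n",
    "def step_stateful(I):\n",
    "    V_state.value = V_state.value + 0.1 * (I - V_state.value)\n",
    "\n",
    "for _ in range(3):\n",
    "    step_stateful(1.0)\n",
    "\n",
    "print(V)              # functional result\n",
    "print(V_state.value)  # same values, held by the State"
   ]
  },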
  {
   "cell_type": "markdown",
   "id": "ba96083a3de0451f",
   "metadata": {},
   "source": [
    "### States and PyTrees\n",
    "\n",
    "`State` supports arbitrary [PyTree](https://jax.readthedocs.io/en/latest/working-with-pytrees.html) structures, which means you can encapsulate complex nested data structures within a `State` object. This is particularly useful for models with hierarchical state representations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9c63ad908b861dc",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:04:04.703292Z",
     "start_time": "2025-10-10T10:04:04.663370Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "State(\n",
       "  value={\n",
       "    'a': ShapedArray(float32[3]),\n",
       "    'b': ShapedArray(float32[4])\n",
       "  }\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# State can hold complex PyTree structures\n",
    "example2 = brainstate.State({'a': jnp.ones(3), 'b': jnp.zeros(4)})\n",
    "example2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "pytree_example",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:04:04.781969Z",
     "start_time": "2025-10-10T10:04:04.706798Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Complex state structure:\n",
      "State(\n",
      "  value={\n",
      "    'neurons': {\n",
      "      'V': ShapedArray(float32[100]),\n",
      "      'u': ShapedArray(float32[100])\n",
      "    },\n",
      "    'synapses': {\n",
      "      'g': ShapedArray(float32[100,100]),\n",
      "      'weights': ShapedArray(float32[100,100])\n",
      "    }\n",
      "  }\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# State can also hold nested structures\n",
    "complex_state = brainstate.State({\n",
    "    'neurons': {\n",
    "        'V': jnp.zeros(100),\n",
    "        'u': jnp.zeros(100)\n",
    "    },\n",
    "    'synapses': {\n",
    "        'g': jnp.zeros((100, 100)),\n",
    "        'weights': jnp.ones((100, 100)) * 0.1\n",
    "    }\n",
    "})\n",
    "print(\"Complex state structure:\")\n",
    "print(complex_state)"
   ]
  },
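  {
   "cell_type": "markdown",
   "id": "pytree_tree_map_md",
   "metadata": {},
   "source": [
    "Because the value of such a `State` is an ordinary PyTree, standard JAX tree utilities work on it directly. In the sketch below, the nested structure and the 0.5 scaling factor are arbitrary. It rescales every leaf with `jax.tree_util.tree_map` and checks that the tree structure is unchanged, which is what a later assignment to `value` requires."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pytree_tree_map_code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "\n",
    "nested = brainstate.State({\n",
    "    'neurons': {'V': jnp.zeros(4), 'u': jnp.ones(4)},\n",
    "    'weights': jnp.ones((4, 4)),\n",
    "})\n",
    "\n",
    "structure_before = jax.tree_util.tree_structure(nested.value)\n",
    "\n",
    "# Apply the same operation to every leaf array; the dict layout is preserved\n",
    "nested.value = jax.tree_util.tree_map(lambda x: x * 0.5, nested.value)\n",
    "\n",
    "structure_after = jax.tree_util.tree_structure(nested.value)\n",
    "print(structure_before == structure_after)  # True: same PyTree structure\n",
    "print(nested.value['neurons']['u'])"
   ]
  },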
  {
   "cell_type": "markdown",
   "id": "330fafcd49aa2712",
   "metadata": {},
   "source": [
    "### Accessing and Updating States\n",
    "\n",
    "Read and write the data held by a `State` through its `value` attribute. Assigning to `value` replaces the stored data, while the `State` object itself stays the same."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "558f8e729e2ccdb",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:04:04.797286Z",
     "start_time": "2025-10-10T10:04:04.792191Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Current value: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n"
     ]
    }
   ],
   "source": [
    "# Access the state value\n",
    "print(\"Current value:\", example.value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "de54d47f46b09325",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:04:05.011882Z",
     "start_time": "2025-10-10T10:04:04.821731Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Updated state:\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "State(\n",
       "  value=ShapedArray(float32[10])\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Update the state value (same shape and dtype as the original)\n",
    "example.value = brainstate.random.random(10)\n",
    "print(\"Updated state:\")\n",
    "example"
   ]
  },
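  {
   "cell_type": "markdown",
   "id": "value_snapshot_md",
   "metadata": {},
   "source": [
    "Assigning to `value` replaces the stored array rather than modifying it. An array obtained from `value` earlier is therefore a snapshot that later updates do not touch. Here is a minimal illustration; the variable names are arbitrary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "value_snapshot_code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "\n",
    "s = brainstate.State(jnp.zeros(3))\n",
    "\n",
    "snapshot = s.value   # the array currently held by the State\n",
    "s.value += 1.0       # shorthand for s.value = s.value + 1.0\n",
    "\n",
    "print(snapshot)  # [0. 0. 0.] (unchanged: JAX arrays are immutable)\n",
    "print(s.value)   # [1. 1. 1.]"
   ]
  },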
  {
   "cell_type": "markdown",
   "id": "9d29f342a73eb2ee",
   "metadata": {},
   "source": [
    "### Core Features of State\n",
    "\n",
    "**✅ Mutable after compilation**: State values can be updated even in JIT-compiled functions\n",
    "\n",
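    "**✅ Structure checking**: Inside `brainstate.check_state_value_tree()`, assigning a value whose PyTree structure differs from the original raises an error\n",
    "\n",
    "**✅ Integration with JAX**: BrainState's own transformations (JIT compilation, gradients, vectorization) track which States a function reads and writes\n",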
    "\n",
    "### Important Notes\n",
    "\n",
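    "⚠️ **Static Data in JIT Compilation**: Values that are not wrapped in a `State` but are captured by a JIT-compiled function, such as a Python number or an array from an enclosing scope, are treated as compile-time constants. Reassigning them inside the compiled function does not persist between calls.\n",
    "\n",
    "⚠️ **Constraints on Modifying State Data**: A new value assigned through `value` should have the same PyTree structure as the original. This is enforced when `brainstate.check_state_value_tree()` is active, as shown below. Keep the shape and dtype unchanged as well: compiled functions are traced against the original shapes and dtypes, so changing them can trigger recompilation or errors."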
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4159d20ade4f2bdb",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:04:05.101311Z",
     "start_time": "2025-10-10T10:04:05.023371Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✓ Successfully updated state with matching structure\n",
      "✗ Error: The given value PyTreeDef((*, *)) does not match with the origin tree structure PyTreeDef(*).\n"
     ]
    }
   ],
   "source": [
    "# Demonstrate tree structure checking\n",
    "state = brainstate.State(jnp.zeros((2, 3)))\n",
    "\n",
    "with brainstate.check_state_value_tree():\n",
    "    # This works - same tree structure\n",
    "    state.value = jnp.zeros((2, 3))\n",
    "    print(\"✓ Successfully updated state with matching structure\")\n",
    "    \n",
    "    # This fails - different tree structure\n",
    "    try:\n",
    "        state.value = (jnp.zeros((2, 3)), jnp.zeros((2, 3)))\n",
    "    except Exception as e:\n",
    "        print(f\"✗ Error: {e}\")"
   ]
  },
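  {
   "cell_type": "markdown",
   "id": "jit_static_vs_state_md",
   "metadata": {},
   "source": [
    "The note on static data can be checked directly. In the sketch below, a Python counter incremented inside a `jax.jit`-compiled function changes only while the function is being traced. A `State` updated inside a BrainState-compiled function, by contrast, keeps every update. The JIT entry point has moved between BrainState releases: recent versions use `brainstate.transform.jit` and older ones use `brainstate.compile.jit`. The sketch assumes one of the two is available."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "jit_static_vs_state_code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "\n",
    "# Assumption: newer releases expose brainstate.transform.jit,\n",
    "# older ones brainstate.compile.jit\n",
    "try:\n",
    "    bst_jit = brainstate.transform.jit\n",
    "except AttributeError:\n",
    "    bst_jit = brainstate.compile.jit\n",
    "\n",
    "# 1) A plain Python value is not part of the compiled program\n",
    "trace_count = 0\n",
    "\n",
    "@jax.jit\n",
    "def add_one(x):\n",
    "    global trace_count\n",
    "    trace_count += 1  # Python side effect: runs only while tracing\n",
    "    return x + 1.0\n",
    "\n",
    "add_one(1.0)\n",
    "add_one(2.0)  # same input shape/dtype, so the cached compilation is reused\n",
    "print('Python counter:', trace_count)  # 1, not 2\n",
    "\n",
    "# 2) A State is updated on every call of the compiled function\n",
    "calls = brainstate.State(jnp.array(0.0))\n",
    "\n",
    "@bst_jit\n",
    "def record_call():\n",
    "    calls.value = calls.value + 1.0\n",
    "\n",
    "record_call()\n",
    "record_call()\n",
    "print('State counter:', calls.value)  # 2.0"
   ]
  },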
  {
   "cell_type": "markdown",
   "id": "6fc49d593597132",
   "metadata": {},
   "source": [
    "## 2. Subclasses of State\n",
    "\n",
    "**BrainState** provides several subclasses of `State` to help organize different types of state variables in your models. These subclasses behave like the base `State` class; their main role is to act as semantic markers that:\n",
    "\n",
    "- 📝 Improve code readability\n",
    "- 🔍 Enable selective filtering (e.g., finding all trainable parameters)\n",
    "- 🎯 Clarify the role of each state variable\n",
    "\n",
    "### Overview of State Types\n",
    "\n",
    "| State Type | Purpose | Examples |\n",
    "|------------|---------|----------|\n",
    "| `ParamState` | Trainable parameters | Weights, biases |\n",
    "| `HiddenState` | Hidden activations | Membrane potentials, RNN hidden states |\n",
    "| `ShortTermState` | Transient states | Last spike time, current input |\n",
    "| `LongTermState` | Persistent states | Running averages, momentum |\n",
    "\n",
    "### 2.1 ParamState - Trainable Parameters\n",
    "\n",
    "`ParamState` is used for trainable parameters in neural networks. These are the values that get updated during training via gradient descent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "13df56647fb5434d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:04:05.267487Z",
     "start_time": "2025-10-10T10:04:05.106112Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weight:\n",
      "ParamState(\n",
      "  value=ShapedArray(float32[10,10])\n",
      ")\n",
      "\n",
      "Bias:\n",
      "ParamState(\n",
      "  value=ShapedArray(float32[10])\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Example: Neural network parameters\n",
    "weight = brainstate.ParamState(brainstate.random.randn(10, 10) * 0.1)\n",
    "bias = brainstate.ParamState(jnp.zeros(10))\n",
    "\n",
    "print(\"Weight:\")\n",
    "print(weight)\n",
    "print(\"\\nBias:\")\n",
    "print(bias)"
   ]
  },
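  {
   "cell_type": "markdown",
   "id": "param_grad_step_md",
   "metadata": {},
   "source": [
    "To make the role of `ParamState` concrete, the sketch below performs a few gradient-descent steps on a `ParamState` by applying plain `jax.grad` to its raw value. This is a deliberately minimal stand-in. BrainState also has its own gradient transformations that work on `State` objects directly, but they are not used here. The quadratic loss and the learning rate are arbitrary choices for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "param_grad_step_code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "\n",
    "w = brainstate.ParamState(jnp.array([2.0, -3.0]))\n",
    "\n",
    "def loss_fn(w_value):\n",
    "    # Simple quadratic loss with its minimum at w = 0\n",
    "    return jnp.sum(w_value ** 2)\n",
    "\n",
    "lr = 0.1\n",
    "losses = []\n",
    "for _ in range(3):\n",
    "    grads = jax.grad(loss_fn)(w.value)  # gradient w.r.t. the raw array\n",
    "    w.value = w.value - lr * grads      # write the updated parameters back\n",
    "    losses.append(float(loss_fn(w.value)))\n",
    "\n",
    "print(losses)  # the loss decreases at every step\n",
    "print(w.value)"
   ]
  },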
  {
   "cell_type": "markdown",
   "id": "1ca2fb96e9d2ab44",
   "metadata": {},
   "source": [
    "### 2.2 HiddenState - Hidden Activations\n",
    "\n",
    "`HiddenState` encapsulates hidden activation variables in models. These states are updated during every simulation iteration and retained between iterations, representing the internal dynamics of the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "711dc3f1ecb61594",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:04:05.325878Z",
     "start_time": "2025-10-10T10:04:05.272284Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Membrane potential:\n",
      "HiddenState(\n",
      "  value=ShapedArray(float32[10], weak_type=True)\n",
      ")\n",
      "\n",
      "RNN hidden state:\n",
      "HiddenState(\n",
      "  value=ShapedArray(float32[32,128])\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Example: Neuron membrane potential\n",
    "V = brainstate.HiddenState(jnp.full(10, -70.0))  # Resting potential\n",
    "\n",
    "# Example: RNN hidden state\n",
    "h = brainstate.HiddenState(jnp.zeros((32, 128)))  # (batch_size, hidden_dim)\n",
    "\n",
    "print(\"Membrane potential:\")\n",
    "print(V)\n",
    "print(\"\\nRNN hidden state:\")\n",
    "print(h)"
   ]
  },
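  {
   "cell_type": "markdown",
   "id": "hidden_state_loop_md",
   "metadata": {},
   "source": [
    "Because a `HiddenState` keeps its value between steps, a simulation loop only needs to update it. The sketch below relaxes a membrane potential toward an assumed resting value of -70 mV with time constant `tau = 10` (forward Euler, `dt = 1`). These numbers are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "hidden_state_loop_code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "\n",
    "V_rest, tau = -70.0, 10.0\n",
    "V = brainstate.HiddenState(jnp.full(3, -60.0))  # start 10 mV above rest\n",
    "\n",
    "for step in range(5):\n",
    "    # Forward-Euler step of tau * dV/dt = V_rest - V (dt = 1)\n",
    "    V.value = V.value + (V_rest - V.value) / tau\n",
    "\n",
    "print(V.value)  # each entry is about -70 + 10 * 0.9**5"
   ]
  },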
  {
   "cell_type": "markdown",
   "id": "6fc49d593597132_2",
   "metadata": {},
   "source": [
    "### 2.3 ShortTermState - Transient States\n",
    "\n",
    "`ShortTermState` is designed for short-term, transient state variables. These states hold instantaneous values, such as the most recent input or spike time, that are overwritten as the simulation proceeds rather than accumulated over long periods."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "abcfc3b1a9aba883",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:04:05.360859Z",
     "start_time": "2025-10-10T10:04:05.353966Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Last spike times:\n",
      "ShortTermState(\n",
      "  value=ShapedArray(float32[10], weak_type=True)\n",
      ")\n",
      "\n",
      "Current input:\n",
      "ShortTermState(\n",
      "  value=ShapedArray(float32[10])\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Example: Last spike time\n",
    "t_last_spike = brainstate.ShortTermState(jnp.full(10, -1e7))  # Far in the past: no spike yet\n",
    "\n",
    "# Example: Current input\n",
    "current_input = brainstate.ShortTermState(jnp.zeros(10))\n",
    "\n",
    "print(\"Last spike times:\")\n",
    "print(t_last_spike)\n",
    "print(\"\\nCurrent input:\")\n",
    "print(current_input)"
   ]
  },
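  {
   "cell_type": "markdown",
   "id": "short_term_spike_md",
   "metadata": {},
   "source": [
    "A typical `ShortTermState` is simply overwritten as events arrive. The sketch below records each neuron's most recent spike time from a small hand-written spike raster; the raster is made up for illustration. A neuron that never spikes keeps the initial sentinel value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "short_term_spike_code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "\n",
    "t_last_spike = brainstate.ShortTermState(jnp.full(3, -1e7))  # -1e7: no spike yet\n",
    "\n",
    "# Rows are time steps, columns are neurons (True = spike)\n",
    "raster = jnp.array([\n",
    "    [False, True,  False],\n",
    "    [True,  False, False],\n",
    "    [False, True,  False],\n",
    "])\n",
    "\n",
    "for t in range(raster.shape[0]):\n",
    "    # Overwrite the stored time only where a spike occurred at step t\n",
    "    t_last_spike.value = jnp.where(raster[t], float(t), t_last_spike.value)\n",
    "\n",
    "print(t_last_spike.value)  # neuron 0 -> 1, neuron 1 -> 2, neuron 2 keeps -1e7"
   ]
  },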
  {
   "cell_type": "markdown",
   "id": "ab7a845cfa4313bb",
   "metadata": {},
   "source": [
    "### 2.4 LongTermState - Persistent States\n",
    "\n",
    "`LongTermState` is used for long-term state variables that accumulate information over many iterations. These are commonly used for statistics tracking and optimization algorithms."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c5e750e5fe2970c6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:04:05.394335Z",
     "start_time": "2025-10-10T10:04:05.362985Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running mean:\n",
      "LongTermState(\n",
      "  value=ShapedArray(float32[64])\n",
      ")\n",
      "\n",
      "Momentum:\n",
      "LongTermState(\n",
      "  value=ShapedArray(float32[100,100])\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Example: Running mean for batch normalization\n",
    "running_mean = brainstate.LongTermState(jnp.zeros(64))\n",
    "running_var = brainstate.LongTermState(jnp.ones(64))\n",
    "\n",
    "# Example: Optimizer momentum\n",
    "momentum = brainstate.LongTermState(jnp.zeros((100, 100)))\n",
    "\n",
    "print(\"Running mean:\")\n",
    "print(running_mean)\n",
    "print(\"\\nMomentum:\")\n",
    "print(momentum)"
   ]
  },
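  {
   "cell_type": "markdown",
   "id": "long_term_ema_md",
   "metadata": {},
   "source": [
    "A `LongTermState` such as a running mean accumulates information across many updates instead of being overwritten. The sketch below maintains an exponential moving average of per-feature batch means with an assumed momentum of 0.9, the kind of statistic batch normalization keeps. The constant input batches are only there to make the result easy to check."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "long_term_ema_code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "\n",
    "running_mean = brainstate.LongTermState(jnp.zeros(2))\n",
    "momentum = 0.9\n",
    "\n",
    "# A batch of 8 samples with 2 features each, reused three times\n",
    "batch = jnp.tile(jnp.array([1.0, 2.0]), (8, 1))\n",
    "\n",
    "for _ in range(3):\n",
    "    batch_mean = batch.mean(axis=0)\n",
    "    # Exponential moving average: keep most of the old value, blend in the new one\n",
    "    running_mean.value = momentum * running_mean.value + (1 - momentum) * batch_mean\n",
    "\n",
    "print(running_mean.value)  # (1 - 0.9**3) * [1, 2] = [0.271, 0.542]"
   ]
  },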
  {
   "cell_type": "markdown",
   "id": "practical_example",
   "metadata": {},
   "source": [
    "### Practical Example: LIF Neuron Model\n",
    "\n",
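    "The following leaky integrate-and-fire (LIF) model combines three state types: a `HiddenState` for the membrane potential, a `ShortTermState` for the time of the last spike, and a `ParamState` for its input weights:"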
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "lif_example",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:04:05.780898Z",
     "start_time": "2025-10-10T10:04:05.401758Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
       "Initial state:\n",
       "V: [0. 0. 0. 0. 0.]\n",
       "t=10: Spikes at neurons [0 1 2 3 4]\n"
      ]
     }
    ],
    "source": [
     "class LIFNeuron(brainstate.nn.Module):\n",
     "    \"\"\"Leaky Integrate-and-Fire neuron model (forward Euler, dt = 1).\"\"\"\n",
     "    \n",
     "    def __init__(self, n_neurons, tau=10.0, V_th=1.0, V_reset=0.0):\n",
     "        super().__init__()\n",
     "        self.tau = tau\n",
     "        self.V_th = V_th\n",
     "        self.V_reset = V_reset\n",
     "        \n",
     "        # Hidden state: membrane potential (evolves continuously)\n",
     "        self.V = brainstate.HiddenState(jnp.full(n_neurons, V_reset))\n",
     "        \n",
     "        # Short-term state: time of each neuron's most recent spike\n",
     "        self.t_last_spike = brainstate.ShortTermState(jnp.full(n_neurons, -1e7))\n",
     "        \n",
     "        # Parameter: per-neuron input weights (trainable)\n",
     "        self.w_in = brainstate.ParamState(jnp.ones(n_neurons))\n",
     "    \n",
     "    def __call__(self, I_ext, t):\n",
     "        # Membrane dynamics: tau * dV/dt = -V + w_in * I_ext\n",
     "        I = self.w_in.value * I_ext\n",
     "        dV = (-self.V.value + I) / self.tau\n",
     "        self.V.value = self.V.value + dV\n",
     "        \n",
     "        # Spike generation\n",
     "        spike = self.V.value >= self.V_th\n",
     "        \n",
     "        # Reset spiking neurons and record their spike time\n",
     "        self.V.value = jnp.where(spike, self.V_reset, self.V.value)\n",
     "        self.t_last_spike.value = jnp.where(spike, t, self.t_last_spike.value)\n",
     "        \n",
     "        return spike\n",
     "\n",
     "# Create and test the neuron\n",
     "neuron = LIFNeuron(n_neurons=5)\n",
     "print(\"Initial state:\")\n",
     "print(f\"V: {neuron.V.value}\")\n",
     "\n",
     "# Simulate 20 steps with a constant current; V approaches 1.5 > V_th, so neurons fire\n",
     "for t in range(20):\n",
     "    I_ext = jnp.ones(5) * 1.5  # External current\n",
     "    spikes = neuron(I_ext, t)\n",
     "    if jnp.any(spikes):\n",
     "        print(f\"t={t}: Spikes at neurons {jnp.where(spikes)[0]}\")"
   ]
  },
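  {
   "cell_type": "markdown",
   "id": "filter_by_type_md",
   "metadata": {},
   "source": [
    "Because every subclass is its own Python type, states can be grouped by role with a plain `isinstance` check, which is the idea behind the selective filtering mentioned at the start of this section. The helper `select_states` and the dictionary of named states below are hypothetical and written only for this sketch. Keep in mind that `isinstance` also matches subclasses. If you filter by a broader type, check how the state classes relate to each other in your BrainState version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "filter_by_type_code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "\n",
    "# Hypothetical collection of named states, similar to what a model might own\n",
    "states = {\n",
    "    'w_in': brainstate.ParamState(jnp.ones(5)),\n",
    "    'bias': brainstate.ParamState(jnp.zeros(5)),\n",
    "    'V': brainstate.HiddenState(jnp.zeros(5)),\n",
    "    't_last_spike': brainstate.ShortTermState(jnp.full(5, -1e7)),\n",
    "}\n",
    "\n",
    "def select_states(states, state_type):\n",
    "    # Hypothetical helper: keep only states that are instances of state_type\n",
    "    return {name: s for name, s in states.items() if isinstance(s, state_type)}\n",
    "\n",
    "trainable = select_states(states, brainstate.ParamState)\n",
    "hidden = select_states(states, brainstate.HiddenState)\n",
    "\n",
    "print(sorted(trainable))  # ['bias', 'w_in']\n",
    "print(sorted(hidden))     # ['V']"
   ]
  },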
  {
   "cell_type": "markdown",
   "id": "a7c738bf3d3cb1ea",
   "metadata": {},
   "source": [
    "## 3. State Tracking with StateTraceStack\n",
    "\n",
    "`StateTraceStack` records which `State` objects are read and written while code runs inside its context, which makes it useful for debugging and introspection.\n",
    "\n",
    "### Why Track States?\n",
    "\n",
    "- 🔍 **Debugging**: Understand which states are being read/written\n",
    "- 📊 **Profiling**: Identify state access patterns\n",
    "- 🎯 **Selective updates**: Apply operations only to specific state types\n",
    "- 🧪 **Testing**: Verify expected state interactions\n",
    "\n",
    "### Basic Usage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7073382edc49a0aa",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:04:06.024112Z",
     "start_time": "2025-10-10T10:04:05.787349Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "States read: 2\n",
      "States written: 2\n"
     ]
    }
   ],
   "source": [
    "class Linear(brainstate.nn.Module):\n",
    "    def __init__(self, d_in, d_out):\n",
    "        super().__init__()\n",
    "        self.w = brainstate.ParamState(brainstate.random.randn(d_in, d_out) * 0.1)\n",
    "        self.b = brainstate.ParamState(jnp.zeros(d_out))\n",
    "        self.y = brainstate.HiddenState(jnp.zeros(d_out))\n",
    "    \n",
    "    def __call__(self, x):\n",
    "        self.y.value = x @ self.w.value + self.b.value\n",
    "        return self.y.value\n",
    "\n",
    "model = Linear(2, 5)\n",
    "\n",
    "# Track state access\n",
    "with brainstate.StateTraceStack() as stack:\n",
    "    output = model(brainstate.random.randn(2))\n",
    "    \n",
    "    # Get accessed states\n",
    "    read_states = list(stack.get_read_states())\n",
    "    write_states = list(stack.get_write_states())\n",
    "\n",
    "print(f\"States read: {len(read_states)}\")\n",
    "print(f\"States written: {len(write_states)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93f47f34ffb76a62",
   "metadata": {},
   "source": [
    "### Inspecting State Access\n",
    "\n",
    "`StateTraceStack` provides four main methods:\n",
    "\n",
    "- `get_read_states()`: Returns State objects that were read\n",
    "- `get_read_state_values()`: Returns the values of read states\n",
    "- `get_write_states()`: Returns State objects that were written\n",
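    "- `get_write_state_values()`: Returns the values of written states\n",
    "\n",
    "A `State` that is both read and written inside the context is reported only among the written states. In the example above, `y` is written and then read, and `brainstate.random.randn` reads and advances the default random state. Both therefore count as written, which is why the output shows two read states (the `ParamState`s `w` and `b`) and two written states."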
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "fb38926d3d364535",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:04:06.034909Z",
     "start_time": "2025-10-10T10:04:06.029509Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== Read States ===\n",
      "1. ParamState: shape=(2, 5)\n",
      "2. ParamState: shape=(5,)\n"
     ]
    }
   ],
   "source": [
    "# Inspect read states\n",
    "print(\"=== Read States ===\")\n",
    "for i, state in enumerate(read_states):\n",
    "    print(f\"{i+1}. {type(state).__name__}: shape={state.value.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "7bc150b0c2937665",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T10:04:06.089619Z",
     "start_time": "2025-10-10T10:04:06.085314Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== Written States ===\n",
      "1. RandomState: shape=(2,)\n",
      "2. HiddenState: shape=(5,)\n"
     ]
    }
   ],
   "source": [
    "# Inspect written states\n",
    "print(\"=== Written States ===\")\n",
    "for i, state in enumerate(write_states):\n",
    "    print(f\"{i+1}. {type(state).__name__}: shape={state.value.shape if hasattr(state.value, 'shape') else 'N/A'}\")"
   ]
  },
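  {
   "cell_type": "markdown",
   "id": "trace_stack_testing_md",
   "metadata": {},
   "source": [
    "The *Testing* use case listed earlier can be sketched as follows: run a function inside `StateTraceStack` and check which states it touched. The function `step` and its two states are made up for this example. As in the outputs above, a state that is read and then written (here `counter`) is looked up only among the written states."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "trace_stack_testing_code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "\n",
    "counter = brainstate.ShortTermState(jnp.array(0))\n",
    "scale = brainstate.ParamState(jnp.array(2.0))\n",
    "\n",
    "def step(x):\n",
    "    counter.value = counter.value + 1  # read and write\n",
    "    return x * scale.value             # read only\n",
    "\n",
    "with brainstate.StateTraceStack() as stack:\n",
    "    out = step(3.0)\n",
    "    read = list(stack.get_read_states())\n",
    "    written = list(stack.get_write_states())\n",
    "\n",
    "# Compare by identity: the stack returns the State objects themselves\n",
    "print('scale read:', any(s is scale for s in read))\n",
    "print('scale written:', any(s is scale for s in written))\n",
    "print('counter written:', any(s is counter for s in written))"
   ]
  },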
  {
   "cell_type": "markdown",
   "id": "summary",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "In this tutorial, you learned:\n",
    "\n",
    "✅ **States** provide mutable variables compatible with JAX  \n",
    "✅ Different **state types** serve different purposes:  \n",
    "  - `ParamState` for trainable parameters  \n",
    "  - `HiddenState` for hidden activations  \n",
    "  - `ShortTermState` for transient states  \n",
    "  - `LongTermState` for persistent states  \n",
    "✅ **StateTraceStack** tracks state access for debugging  \n",
    "✅ States support **PyTree structures** for complex data  \n",
    "\n",
    "### Best Practices\n",
    "\n",
    "1. 🎯 Use specific state types (`ParamState`, etc.) rather than generic `State`\n",
    "2. 📝 Keep state updates simple and explicit\n",
    "3. 🔍 Use `StateTraceStack` for debugging unexpected behavior\n",
    "4. ⚠️ Remember: inside JIT-compiled code, only updates to `State` values persist; other captured Python values are treated as static constants\n",
    "\n",
    "### Next Steps\n",
    "\n",
    "Continue with:\n",
    "- **Random Number Generation** - Learn about stateful random number generation\n",
    "- **Neural Network Modules** - Build complex models using states\n",
    "- **Program Transformations** - Use states with JIT, grad, and vmap"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Ecosystem-py",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
