{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 8: Animation and Dynamics\n",
    "\n",
    "This comprehensive tutorial demonstrates how to create dynamic visualizations and animations of neural data using BrainTools. We'll explore techniques for visualizing temporal evolution, creating movies of network dynamics, and optimizing performance for large datasets.\n",
    "\n",
    "## Learning Objectives\n",
    "\n",
    "By the end of this tutorial, you will be able to:\n",
    "- Create animated visualizations of neural activity over time\n",
    "- Apply 1D and 2D animation techniques to neural data\n",
    "- Generate movies of evolving network dynamics\n",
    "- Visualize learning processes through time-lapse animations\n",
    "- Display dynamic changes in connectivity patterns\n",
    "- Export animations in various video formats\n",
    "- Optimize performance for large-scale temporal datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Setup and Imports <a id='setup'></a>"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from IPython.display import HTML, display\n",
    "from matplotlib.animation import FuncAnimation, PillowWriter, FFMpegWriter\n",
    "\n",
    "import braintools\n",
    "\n",
    "# Set up output directory\n",
    "output_dir = Path('animations')\n",
    "output_dir.mkdir(exist_ok=True)\n",
    "\n",
    "# Enable interactive matplotlib in Jupyter\n",
    "%matplotlib inline\n",
    "\n",
    "# For animations in notebook\n",
    "from matplotlib import rc\n",
    "\n",
    "rc('animation', html='jshtml')\n",
    "\n",
    "# Set random seed\n",
    "np.random.seed(42)"
   ],
   "outputs": [],
   "execution_count": null
  },
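  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Which export formats work depends on what is installed. The GIF (Pillow) and HTML writers need only matplotlib and its required dependencies, while MP4 needs an external `ffmpeg` binary. The cell below is a small sketch that reports writer availability with `matplotlib.animation.writers.is_available` and defines `save_animation`, a hypothetical helper introduced here (not part of braintools or matplotlib) that saves MP4 when ffmpeg is found and falls back to GIF otherwise. For example, `save_animation(anim, output_dir / 'demo')` would write either `animations/demo.mp4` or `animations/demo.gif`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "from pathlib import Path\n",
    "\n",
    "from matplotlib import animation\n",
    "from matplotlib.animation import FFMpegWriter, PillowWriter\n",
    "\n",
    "\n",
    "def available_writers(names=('ffmpeg', 'pillow', 'html')):\n",
    "    \"\"\"Map each writer name to whether matplotlib can use it in this environment.\"\"\"\n",
    "    return {name: animation.writers.is_available(name) for name in names}\n",
    "\n",
    "\n",
    "def save_animation(anim, stem, fps=10):\n",
    "    \"\"\"Hypothetical helper: save as MP4 if ffmpeg exists, else as GIF; return the path.\"\"\"\n",
    "    stem = Path(stem)\n",
    "    if FFMpegWriter.isAvailable():\n",
    "        path = stem.with_suffix('.mp4')\n",
    "        anim.save(path, writer=FFMpegWriter(fps=fps))\n",
    "    else:\n",
    "        path = stem.with_suffix('.gif')\n",
    "        anim.save(path, writer=PillowWriter(fps=fps))\n",
    "    return path\n",
    "\n",
    "\n",
    "print(available_writers())"
   ],
   "outputs": [],
   "execution_count": null
  },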
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Generate Dynamic Neural Data"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Generate synthetic dynamic neural data\n",
    "\n",
    "def generate_dynamic_data(n_frames=100, n_neurons=50, duration=10.0):\n",
    "    \"\"\"Generate time-varying neural data for animation demonstrations.\"\"\"\n",
    "\n",
    "    data = {}\n",
    "\n",
    "    # 1. Time-varying spike trains (propagating wave)\n",
    "    spike_data = []\n",
    "    for frame in range(n_frames):\n",
    "        frame_spikes = []\n",
    "        wave_center = frame / n_frames * duration\n",
    "\n",
    "        for neuron in range(n_neurons):\n",
    "            # Create wave of activity\n",
    "            phase = 2 * np.pi * neuron / n_neurons\n",
    "            rate = 10 * (1 + np.sin(2 * np.pi * frame / n_frames + phase))\n",
    "\n",
    "            n_spikes = np.random.poisson(rate * duration / n_frames)\n",
    "            spike_times = np.random.uniform(\n",
    "                wave_center - 0.5, wave_center + 0.5, n_spikes\n",
    "            )\n",
    "            spike_times = spike_times[(spike_times >= 0) & (spike_times <= duration)]\n",
    "            frame_spikes.append(spike_times)\n",
    "\n",
    "        spike_data.append(frame_spikes)\n",
    "    data['spike_trains'] = spike_data\n",
    "\n",
    "    # 2. Oscillating membrane potential\n",
    "    time_points = np.linspace(0, duration, 1000)\n",
    "    membrane_data = np.zeros((n_frames, len(time_points)))\n",
    "\n",
    "    for frame in range(n_frames):\n",
    "        phase_shift = 2 * np.pi * frame / n_frames\n",
    "        base = -70 + 5 * np.sin(phase_shift)\n",
    "        oscillation = 10 * np.sin(2 * np.pi * time_points + phase_shift)\n",
    "        noise = np.random.normal(0, 2, len(time_points))\n",
    "        membrane_data[frame] = base + oscillation + noise\n",
    "\n",
    "    data['membrane_potential'] = membrane_data\n",
    "    data['time'] = time_points\n",
    "\n",
    "    # 3. 2D activity patterns (traveling waves)\n",
    "    grid_size = 50\n",
    "    activity_frames = np.zeros((n_frames, grid_size, grid_size))\n",
    "\n",
    "    for frame in range(n_frames):\n",
    "        x = np.linspace(-2, 2, grid_size)\n",
    "        y = np.linspace(-2, 2, grid_size)\n",
    "        X, Y = np.meshgrid(x, y)\n",
    "\n",
    "        # Traveling wave\n",
    "        wave_pos = -2 + 4 * frame / n_frames\n",
    "        activity = np.exp(-((X - wave_pos) ** 2 + Y ** 2) / 0.5)\n",
    "\n",
    "        # Rotating spiral\n",
    "        theta = np.arctan2(Y, X) + 2 * np.pi * frame / n_frames\n",
    "        spiral = 0.5 * (1 + np.sin(3 * theta - np.sqrt(X ** 2 + Y ** 2)))\n",
    "\n",
    "        activity_frames[frame] = activity + 0.3 * spiral\n",
    "        activity_frames[frame] += np.random.normal(0, 0.1, (grid_size, grid_size))\n",
    "\n",
    "    data['activity_2d'] = activity_frames\n",
    "\n",
    "    # 4. Evolving connectivity\n",
    "    connectivity_frames = np.zeros((n_frames, 20, 20))\n",
    "    base_connectivity = np.random.randn(20, 20)\n",
    "\n",
    "    for frame in range(n_frames):\n",
    "        # Gradually strengthen/weaken connections\n",
    "        modulation = np.sin(2 * np.pi * frame / n_frames)\n",
    "        connectivity_frames[frame] = base_connectivity * (1 + 0.5 * modulation)\n",
    "\n",
    "        # Add plastic changes\n",
    "        if frame > 0:\n",
    "            plastic_change = np.random.normal(0, 0.01, (20, 20))\n",
    "            connectivity_frames[frame] += plastic_change\n",
    "\n",
    "    data['connectivity'] = connectivity_frames\n",
    "\n",
    "    # 5. Learning curve data\n",
    "    learning_steps = 100\n",
    "    performance = np.zeros(learning_steps)\n",
    "    error = np.zeros(learning_steps)\n",
    "\n",
    "    for step in range(learning_steps):\n",
    "        performance[step] = 1 - np.exp(-step / 20) + np.random.normal(0, 0.05)\n",
    "        error[step] = np.exp(-step / 15) + np.random.normal(0, 0.02)\n",
    "\n",
    "    data['learning'] = {'performance': performance, 'error': error}\n",
    "\n",
    "    return data\n",
    "\n",
    "\n",
    "# Generate all dynamic data\n",
    "print(\"Generating dynamic neural data...\")\n",
    "dynamic_data = generate_dynamic_data(n_frames=100)\n",
    "\n",
    "print(\"\\nGenerated data:\")\n",
    "print(f\"  Spike trains: {len(dynamic_data['spike_trains'])} frames\")\n",
    "print(f\"  Membrane potential: shape {dynamic_data['membrane_potential'].shape}\")\n",
    "print(f\"  2D activity: shape {dynamic_data['activity_2d'].shape}\")\n",
    "print(f\"  Connectivity: shape {dynamic_data['connectivity'].shape}\")\n",
    "print(f\"  Learning data: {len(dynamic_data['learning']['performance'])} steps\")"
   ],
   "outputs": [],
   "execution_count": null
  },
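  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The generator above rebuilds the coordinate grid inside its frame loop. For larger grids or longer movies, NumPy broadcasting can produce every frame in a single expression. The sketch below reproduces only the Gaussian traveling-wave term of `activity_2d` (no spiral and no noise); `traveling_wave_frames` is an illustrative name, and its defaults are chosen to match the values used above."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def traveling_wave_frames(n_frames=100, grid_size=50, extent=2.0, width=0.5):\n",
    "    \"\"\"Gaussian bump moving along x, returned with shape (n_frames, grid_size, grid_size).\"\"\"\n",
    "    coords = np.linspace(-extent, extent, grid_size)\n",
    "    X, Y = np.meshgrid(coords, coords)  # each (grid_size, grid_size)\n",
    "    # Wave centre per frame, shaped (n_frames, 1, 1) so it broadcasts against the grid\n",
    "    wave_pos = (-extent + 2 * extent * np.arange(n_frames) / n_frames)[:, None, None]\n",
    "    return np.exp(-((X - wave_pos) ** 2 + Y ** 2) / width)\n",
    "\n",
    "\n",
    "wave_frames = traveling_wave_frames()\n",
    "print(wave_frames.shape)  # (100, 50, 50)"
   ],
   "outputs": [],
   "execution_count": null
  },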
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Basic Animation Techniques <a id='basic'></a>\n",
    "\n",
    "Understanding the fundamentals of matplotlib animations before applying them to neural data."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Basic animation using FuncAnimation\n",
    "\n",
    "# Example 1: Simple oscillating sine wave\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n",
    "\n",
    "# Setup axes\n",
    "ax1.set_xlim(0, 2 * np.pi)\n",
    "ax1.set_ylim(-1.5, 1.5)\n",
    "ax1.set_xlabel('x')\n",
    "ax1.set_ylabel('y')\n",
    "ax1.set_title('Animated Sine Wave')\n",
    "ax1.grid(True, alpha=0.3)\n",
    "\n",
    "# Create line object\n",
    "x = np.linspace(0, 2 * np.pi, 100)\n",
    "line, = ax1.plot([], [], 'b-', linewidth=2)\n",
    "\n",
    "\n",
    "# Animation function\n",
    "def animate_sine(frame):\n",
    "    y = np.sin(x + 0.1 * frame)\n",
    "    line.set_data(x, y)\n",
    "    return line,\n",
    "\n",
    "\n",
    "# Example 2: Growing scatter plot\n",
    "ax2.set_xlim(-2, 2)\n",
    "ax2.set_ylim(-2, 2)\n",
    "ax2.set_xlabel('x')\n",
    "ax2.set_ylabel('y')\n",
    "ax2.set_title('Animated Scatter Plot')\n",
    "ax2.grid(True, alpha=0.3)\n",
    "\n",
    "# Create scatter plot\n",
    "scat = ax2.scatter([], [], c=[], s=50, cmap='viridis', vmin=0, vmax=100)\n",
    "\n",
    "\n",
    "def animate_scatter(frame):\n",
    "    # Generate random points\n",
    "    n_points = min(frame + 1, 100)\n",
    "    x_data = np.random.randn(n_points) * (1 + frame / 100)\n",
    "    y_data = np.random.randn(n_points) * (1 + frame / 100)\n",
    "    colors = np.arange(n_points)\n",
    "\n",
    "    # Update scatter plot\n",
    "    data = np.c_[x_data, y_data]\n",
    "    scat.set_offsets(data)\n",
    "    scat.set_array(colors)\n",
    "    return scat,\n",
    "\n",
    "\n",
    "# Combine animations\n",
    "def animate_both(frame):\n",
    "    animate_sine(frame)\n",
    "    animate_scatter(frame)\n",
    "    return line, scat\n",
    "\n",
    "\n",
    "# Create animation\n",
    "anim = FuncAnimation(fig, animate_both, frames=100, interval=50, blit=True)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Display animation in notebook\n",
    "display(HTML(anim.to_jshtml()))\n",
    "\n",
    "print(\"\\nBasic Animation Techniques:\")\n",
    "print(\"- FuncAnimation: Updates plot elements frame by frame\")\n",
    "print(\"- Blit=True: Optimizes rendering by only updating changed elements\")\n",
    "print(\"- Interval: Controls frame delay in milliseconds\")\n",
    "print(\"- Multiple subplots can be animated simultaneously\")"
   ],
   "outputs": [],
   "execution_count": null
  },
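  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`FuncAnimation` also accepts an `init_func`, which is called once before the first frame to draw a clean background; if it is omitted, the first frame itself serves that purpose. With `blit=True`, `init_func` must return the artists to be redrawn, just like the update function. A minimal sketch on a single line (the names `init`, `update`, and the `_demo` variables are arbitrary and chosen so they do not overwrite the variables above):"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from matplotlib.animation import FuncAnimation\n",
    "\n",
    "x_demo = np.linspace(0, 2 * np.pi, 200)\n",
    "fig_demo, ax_demo = plt.subplots(figsize=(6, 3))\n",
    "ax_demo.set_xlim(0, 2 * np.pi)\n",
    "ax_demo.set_ylim(-1.5, 1.5)\n",
    "(demo_line,) = ax_demo.plot([], [], lw=2)\n",
    "\n",
    "\n",
    "def init():\n",
    "    # Start from an empty line; with blit=True the artists must be returned\n",
    "    demo_line.set_data([], [])\n",
    "    return (demo_line,)\n",
    "\n",
    "\n",
    "def update(frame):\n",
    "    # Shift the phase by 0.1 rad per frame\n",
    "    demo_line.set_data(x_demo, np.sin(x_demo + 0.1 * frame))\n",
    "    return (demo_line,)\n",
    "\n",
    "\n",
    "anim_init = FuncAnimation(fig_demo, update, frames=60, init_func=init,\n",
    "                          interval=50, blit=True)\n",
    "plt.close(fig_demo)\n",
    "\n",
    "# As the last expression in a cell, this renders as a JavaScript widget\n",
    "# because of rc('animation', html='jshtml') in the setup cell\n",
    "anim_init"
   ],
   "outputs": [],
   "execution_count": null
  },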
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Neural Activity Animation <a id='neural-activity'></a>\n",
    "\n",
    "Animating spike trains and neural population activity over time."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Animate spike raster plot with sliding window\n",
    "\n",
    "fig, axes = plt.subplots(2, 1, figsize=(12, 8))\n",
    "fig.suptitle('Neural Activity Animation', fontsize=14, fontweight='bold')\n",
    "\n",
    "# Parameters\n",
    "window_size = 2.0  # seconds\n",
    "n_neurons = 30\n",
    "total_duration = 10.0\n",
    "\n",
    "# Generate continuous spike trains\n",
    "all_spike_trains = []\n",
    "for neuron in range(n_neurons):\n",
    "    rate = np.random.uniform(5, 20)  # Hz\n",
    "    n_spikes = np.random.poisson(rate * total_duration)\n",
    "    spike_times = np.sort(np.random.uniform(0, total_duration, n_spikes))\n",
    "    all_spike_trains.append(spike_times)\n",
    "\n",
    "# Setup spike raster axis\n",
    "ax1 = axes[0]\n",
    "ax1.set_xlim(0, window_size)\n",
    "ax1.set_ylim(0, n_neurons)\n",
    "ax1.set_xlabel('Time (s)')\n",
    "ax1.set_ylabel('Neuron ID')\n",
    "ax1.set_title('Sliding Window Spike Raster')\n",
    "\n",
    "# Create scatter plot for spikes\n",
    "spike_scatter = ax1.scatter([], [], c='black', s=2, marker='|')\n",
    "\n",
    "# Setup population rate axis\n",
    "ax2 = axes[1]\n",
    "ax2.set_xlim(0, window_size)\n",
    "ax2.set_ylim(0, 30)\n",
    "ax2.set_xlabel('Time (s)')\n",
    "ax2.set_ylabel('Population Rate (Hz)')\n",
    "ax2.set_title('Population Firing Rate')\n",
    "ax2.grid(True, alpha=0.3)\n",
    "\n",
    "# Create line for population rate\n",
    "time_bins = np.linspace(0, window_size, 100)\n",
    "rate_line, = ax2.plot([], [], 'b-', linewidth=2)\n",
    "rate_fill = None\n",
    "\n",
    "\n",
    "def animate_neural_activity(frame):\n",
    "    global rate_fill\n",
    "\n",
    "    # Calculate window position\n",
    "    window_start = frame * 0.05  # Slide by 50ms per frame\n",
    "    window_end = window_start + window_size\n",
    "\n",
    "    # Get spikes in current window\n",
    "    spike_x = []\n",
    "    spike_y = []\n",
    "\n",
    "    for neuron_id, spike_train in enumerate(all_spike_trains):\n",
    "        window_spikes = spike_train[(spike_train >= window_start) &\n",
    "                                    (spike_train < window_end)]\n",
    "        # Shift to window coordinates\n",
    "        window_spikes_shifted = window_spikes - window_start\n",
    "        spike_x.extend(window_spikes_shifted)\n",
    "        spike_y.extend([neuron_id] * len(window_spikes_shifted))\n",
    "\n",
    "    # Update spike scatter\n",
    "    if spike_x:\n",
    "        spike_scatter.set_offsets(np.c_[spike_x, spike_y])\n",
    "\n",
    "    # Calculate population rate\n",
    "    counts, _ = np.histogram([s for s in spike_x], bins=time_bins)\n",
    "    rate = counts / (time_bins[1] - time_bins[0]) / n_neurons\n",
    "\n",
    "    # Smooth the rate\n",
    "    from scipy.ndimage import gaussian_filter1d\n",
    "    rate_smooth = gaussian_filter1d(rate, sigma=2)\n",
    "\n",
    "    # Update rate plot\n",
    "    rate_line.set_data(time_bins[:-1], rate_smooth)\n",
    "\n",
    "    # Update fill area (remove old and add new)\n",
    "    if rate_fill is not None:\n",
    "        rate_fill.remove()\n",
    "    rate_fill = ax2.fill_between(time_bins[:-1], 0, rate_smooth, alpha=0.3)\n",
    "\n",
    "    # Update time indicator\n",
    "    ax1.set_title(f'Sliding Window Spike Raster (t = {window_start:.1f}s)')\n",
    "\n",
    "    return spike_scatter, rate_line, rate_fill\n",
    "\n",
    "\n",
    "# Create animation\n",
    "n_frames = int((total_duration - window_size) / 0.05)\n",
    "anim = FuncAnimation(fig, animate_neural_activity, frames=n_frames,\n",
    "                     interval=50, blit=False)\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "# Save as GIF\n",
    "print(\"Saving neural activity animation...\")\n",
    "anim.save(output_dir / 'neural_activity.gif', writer='pillow', fps=20)\n",
    "\n",
    "# Display in notebook\n",
    "display(HTML(anim.to_jshtml()))\n",
    "\n",
    "print(\"\\nNeural Activity Animation Features:\")\n",
    "print(\"- Sliding window through spike trains\")\n",
    "print(\"- Real-time population rate calculation\")\n",
    "print(\"- Synchronized multi-panel updates\")\n",
    "print(f\"- Animation saved to: {output_dir / 'neural_activity.gif'}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
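  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The update function above applies a boolean mask to every spike train on every frame, which scans all of its spikes each time. Because each train was sorted with `np.sort`, `np.searchsorted` can find the window boundaries by binary search instead. The sketch below is one possible drop-in for the windowing loop; `spikes_in_window` is an illustrative name, and it returns offsets in the same window-relative coordinates, with shape `(0, 2)` when no spikes fall in the window. Inside `animate_neural_activity`, `spikes_in_window(all_spike_trains, window_start, window_end)` could replace the loop, and column 0 of its result gives the spike times for the rate histogram."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def spikes_in_window(spike_trains, t_start, t_end):\n",
    "    \"\"\"Return an (N, 2) array of (time - t_start, neuron_id) for spikes in [t_start, t_end).\n",
    "\n",
    "    Assumes each spike train is sorted in ascending order.\n",
    "    \"\"\"\n",
    "    xs, ys = [], []\n",
    "    for neuron_id, train in enumerate(spike_trains):\n",
    "        # Binary search for the first spike >= t_start and the first spike >= t_end\n",
    "        lo, hi = np.searchsorted(train, [t_start, t_end], side='left')\n",
    "        segment = train[lo:hi] - t_start\n",
    "        xs.append(segment)\n",
    "        ys.append(np.full(segment.size, neuron_id))\n",
    "    if not xs:\n",
    "        return np.empty((0, 2))\n",
    "    return np.column_stack([np.concatenate(xs), np.concatenate(ys)])\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "trains = [np.sort(rng.uniform(0, 10, rng.poisson(50))) for _ in range(5)]\n",
    "offsets = spikes_in_window(trains, 2.0, 4.0)\n",
    "print(offsets.shape)"
   ],
   "outputs": [],
   "execution_count": null
  },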
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. 1D Signal Animation <a id='1d-animation'></a>\n",
    "\n",
    "Animating time series data such as membrane potentials and LFP signals."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Animate 1D signals with multiple channels\n",
    "\n",
    "# Generate multi-channel signals\n",
    "n_channels = 5\n",
    "signal_length = 1000\n",
    "time = np.linspace(0, 10, signal_length)\n",
    "\n",
    "# Create signals with different frequencies\n",
    "signals = np.zeros((n_channels, signal_length))\n",
    "for i in range(n_channels):\n",
    "    freq = 0.5 + i * 0.5  # Different frequencies\n",
    "    phase = np.random.uniform(0, 2 * np.pi)\n",
    "    signals[i] = np.sin(2 * np.pi * freq * time + phase) * (1 + 0.2 * i)\n",
    "    signals[i] += np.random.normal(0, 0.1, signal_length)\n",
    "\n",
    "# Setup figure\n",
    "fig, axes = plt.subplots(3, 1, figsize=(12, 10))\n",
    "fig.suptitle('1D Signal Animation', fontsize=14, fontweight='bold')\n",
    "\n",
    "# 1. Oscilloscope-style display\n",
    "ax1 = axes[0]\n",
    "ax1.set_xlim(0, 2)\n",
    "ax1.set_ylim(-3, 3)\n",
    "ax1.set_xlabel('Time (s)')\n",
    "ax1.set_ylabel('Amplitude')\n",
    "ax1.set_title('Oscilloscope View')\n",
    "ax1.grid(True, alpha=0.3)\n",
    "\n",
    "lines_osc = []\n",
    "colors = plt.cm.viridis(np.linspace(0, 1, n_channels))\n",
    "for i in range(n_channels):\n",
    "    line, = ax1.plot([], [], color=colors[i], linewidth=1.5,\n",
    "                     label=f'Ch {i + 1}')\n",
    "    lines_osc.append(line)\n",
    "ax1.legend(loc='upper right')\n",
    "\n",
    "# 2. Waterfall plot\n",
    "ax2 = axes[1]\n",
    "ax2.set_xlim(0, 2)\n",
    "ax2.set_ylim(-1, n_channels * 2)\n",
    "ax2.set_xlabel('Time (s)')\n",
    "ax2.set_ylabel('Channel')\n",
    "ax2.set_title('Waterfall Display')\n",
    "\n",
    "lines_waterfall = []\n",
    "for i in range(n_channels):\n",
    "    line, = ax2.plot([], [], color=colors[i], linewidth=1.5)\n",
    "    lines_waterfall.append(line)\n",
    "\n",
    "# 3. Spectrogram-style heatmap\n",
    "ax3 = axes[2]\n",
    "ax3.set_xlabel('Time (s)')\n",
    "ax3.set_ylabel('Channel')\n",
    "ax3.set_title('Signal Intensity Heatmap')\n",
    "\n",
    "# Initialize heatmap data\n",
    "heatmap_data = np.zeros((n_channels, 100))\n",
    "im = ax3.imshow(heatmap_data, aspect='auto', cmap='hot',\n",
    "                extent=[0, 2, 0, n_channels], vmin=-2, vmax=2)\n",
    "plt.colorbar(im, ax=ax3, label='Amplitude')\n",
    "\n",
    "\n",
    "def animate_1d_signals(frame):\n",
    "    # Window parameters\n",
    "    window_size = 200  # samples\n",
    "    start_idx = frame * 5\n",
    "    end_idx = start_idx + window_size\n",
    "\n",
    "    if end_idx > signal_length:\n",
    "        return lines_osc + lines_waterfall + [im]\n",
    "\n",
    "    # Time window\n",
    "    time_window = time[start_idx:end_idx] - time[start_idx]\n",
    "\n",
    "    # Update oscilloscope\n",
    "    for i, line in enumerate(lines_osc):\n",
    "        line.set_data(time_window, signals[i, start_idx:end_idx])\n",
    "\n",
    "    # Update waterfall (with offset)\n",
    "    for i, line in enumerate(lines_waterfall):\n",
    "        offset_signal = signals[i, start_idx:end_idx] + i * 2\n",
    "        line.set_data(time_window, offset_signal)\n",
    "\n",
    "    # Update heatmap\n",
    "    heatmap_slice = signals[:, start_idx:end_idx:2]  # Downsample\n",
    "    im.set_array(heatmap_slice)\n",
    "\n",
    "    # Update titles with time\n",
    "    current_time = time[start_idx]\n",
    "    ax1.set_title(f'Oscilloscope View (t = {current_time:.1f}s)')\n",
    "\n",
    "    return lines_osc + lines_waterfall + [im]\n",
    "\n",
    "\n",
    "# Create animation\n",
    "n_frames = (signal_length - 200) // 5\n",
    "anim = FuncAnimation(fig, animate_1d_signals, frames=n_frames,\n",
    "                     interval=50, blit=True)\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "# Display in notebook\n",
    "display(HTML(anim.to_jshtml()))\n",
    "\n",
    "print(\"\\n1D Signal Animation Techniques:\")\n",
    "print(\"- Oscilloscope: Real-time signal display\")\n",
    "print(\"- Waterfall: Multiple channels with vertical offset\")\n",
    "print(\"- Heatmap: Intensity representation over time\")\n",
    "print(\"- Synchronized visualization across different representations\")"
   ],
   "outputs": [],
   "execution_count": null
  },
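  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each frame above slices a 200-sample window that starts 5 samples after the previous one. `numpy.lib.stride_tricks.sliding_window_view` (available in NumPy 1.20 and later) exposes all of those windows at once as a read-only view of the signal, so per-frame data can be precomputed without copying. A sketch on random stand-in signals; note that it yields 161 windows rather than the 160 frames used above, because it also includes the last window ending exactly at the final sample."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import numpy as np\n",
    "from numpy.lib.stride_tricks import sliding_window_view\n",
    "\n",
    "n_ch, n_samples = 5, 1000\n",
    "win, step = 200, 5\n",
    "rng = np.random.default_rng(0)\n",
    "sig = rng.standard_normal((n_ch, n_samples))  # stand-in for `signals`\n",
    "\n",
    "# Every length-200 window along time, keeping every 5th start:\n",
    "# shape (channels, frames, window), a read-only view with no data copied\n",
    "windows = sliding_window_view(sig, window_shape=win, axis=1)[:, ::step]\n",
    "print(windows.shape)  # (5, 161, 200)\n",
    "\n",
    "# Frame k covers samples [k * step, k * step + win)\n",
    "k = 3\n",
    "assert np.array_equal(windows[:, k], sig[:, k * step:k * step + win])"
   ],
   "outputs": [],
   "execution_count": null
  },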
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. 2D Spatial Animation <a id='2d-animation'></a>\n",
    "\n",
    "Animating spatial patterns of neural activity, such as traveling waves and spreading activation."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Use braintools animator for 2D activity\n",
    "\n",
    "# Create figure for 2D animation\n",
    "fig, axes = plt.subplots(2, 2, figsize=(12, 10))\n",
    "fig.suptitle('2D Spatial Activity Animation', fontsize=14, fontweight='bold')\n",
    "\n",
    "# Generate different types of 2D patterns\n",
    "n_frames = 100\n",
    "grid_size = 50\n",
    "\n",
    "# 1. Traveling wave\n",
    "wave_data = np.zeros((n_frames, grid_size, grid_size))\n",
    "for frame in range(n_frames):\n",
    "    x = np.linspace(-3, 3, grid_size)\n",
    "    y = np.linspace(-3, 3, grid_size)\n",
    "    X, Y = np.meshgrid(x, y)\n",
    "\n",
    "    # Wave position\n",
    "    wave_pos = -3 + 6 * frame / n_frames\n",
    "    wave_data[frame] = np.exp(-((X - wave_pos) ** 2 + Y ** 2) / 0.8)\n",
    "\n",
    "# 2. Rotating spiral\n",
    "spiral_data = np.zeros((n_frames, grid_size, grid_size))\n",
    "for frame in range(n_frames):\n",
    "    x = np.linspace(-2, 2, grid_size)\n",
    "    y = np.linspace(-2, 2, grid_size)\n",
    "    X, Y = np.meshgrid(x, y)\n",
    "\n",
    "    theta = np.arctan2(Y, X)\n",
    "    r = np.sqrt(X ** 2 + Y ** 2)\n",
    "    spiral_data[frame] = 0.5 * (1 + np.sin(3 * theta - r + 0.2 * frame))\n",
    "\n",
    "# 3. Spreading activation\n",
    "spread_data = np.zeros((n_frames, grid_size, grid_size))\n",
    "center = grid_size // 2\n",
    "for frame in range(n_frames):\n",
    "    x, y = np.ogrid[:grid_size, :grid_size]\n",
    "    radius = frame * 0.5\n",
    "    mask = (x - center) ** 2 + (y - center) ** 2 <= radius ** 2\n",
    "    spread_data[frame][mask] = np.exp(-frame / 20)\n",
    "\n",
    "# 4. Random hotspots\n",
    "hotspot_data = np.zeros((n_frames, grid_size, grid_size))\n",
    "n_hotspots = 5\n",
    "hotspot_centers = [(np.random.randint(10, 40), np.random.randint(10, 40))\n",
    "                   for _ in range(n_hotspots)]\n",
    "\n",
    "for frame in range(n_frames):\n",
    "    for cx, cy in hotspot_centers:\n",
    "        x, y = np.ogrid[:grid_size, :grid_size]\n",
    "        dist = np.sqrt((x - cx) ** 2 + (y - cy) ** 2)\n",
    "        intensity = np.sin(2 * np.pi * frame / 20 + cx / 10) * 0.5 + 0.5\n",
    "        hotspot_data[frame] += intensity * np.exp(-dist ** 2 / 50)\n",
    "\n",
    "# Create animations using braintools animator\n",
    "ax1 = axes[0, 0]\n",
    "ax1.set_title('Traveling Wave')\n",
    "anim1 = braintools.visualize.animator(wave_data, fig, ax1, interval=50, cmap='viridis')\n",
    "\n",
    "ax2 = axes[0, 1]\n",
    "ax2.set_title('Rotating Spiral')\n",
    "anim2 = braintools.visualize.animator(spiral_data, fig, ax2, interval=50, cmap='plasma')\n",
    "\n",
    "ax3 = axes[1, 0]\n",
    "ax3.set_title('Spreading Activation')\n",
    "anim3 = braintools.visualize.animator(spread_data, fig, ax3, interval=50, cmap='hot')\n",
    "\n",
    "ax4 = axes[1, 1]\n",
    "ax4.set_title('Dynamic Hotspots')\n",
    "anim4 = braintools.visualize.animator(hotspot_data, fig, ax4, interval=50, cmap='jet')\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "# Display first animation\n",
    "display(HTML(anim1.to_jshtml()))\n",
    "\n",
    "print(\"\\n2D Spatial Animation Patterns:\")\n",
    "print(\"- Traveling waves: Common in cortical dynamics\")\n",
    "print(\"- Spiral waves: Found in cardiac and neural tissue\")\n",
    "print(\"- Spreading activation: Models seizure propagation\")\n",
    "print(\"- Dynamic hotspots: Represents changing activity centers\")\n",
    "print(\"\\nUsing braintools.visualize.animator for efficient 2D animation\")"
   ],
   "outputs": [],
   "execution_count": null
  },
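  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each `animator` call above is given a single axes and returns its own animation object, so exporting `anim1` is not expected to advance the other three panels. If you want one movie in which every panel moves together, one option is to drive all panels from a single `FuncAnimation` built on plain matplotlib `imshow` artists. The sketch below uses two small stand-in stacks rather than the arrays above; `stacks`, `panel_images`, and `update_all` are illustrative names."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from matplotlib.animation import FuncAnimation\n",
    "\n",
    "n_demo_frames, grid = 40, 32\n",
    "t = np.arange(n_demo_frames)[:, None, None]  # (frames, 1, 1)\n",
    "yy, xx = np.mgrid[-2:2:grid * 1j, -2:2:grid * 1j]  # (grid, grid) coordinate arrays\n",
    "\n",
    "# Two small stand-in pattern stacks, each of shape (frames, grid, grid)\n",
    "stacks = {\n",
    "    'Traveling wave': np.exp(-((xx - (-2 + 4 * t / n_demo_frames)) ** 2 + yy ** 2) / 0.8),\n",
    "    'Rotating spiral': 0.5 * (1 + np.sin(3 * np.arctan2(yy, xx) - np.hypot(xx, yy) + 0.2 * t)),\n",
    "}\n",
    "\n",
    "fig_multi, axes_multi = plt.subplots(1, len(stacks), figsize=(8, 4))\n",
    "panel_images = []\n",
    "for ax_i, (title, stack) in zip(axes_multi, stacks.items()):\n",
    "    # Fixed color limits per panel so brightness stays comparable across frames\n",
    "    panel_images.append(ax_i.imshow(stack[0], origin='lower', cmap='viridis',\n",
    "                                    vmin=stack.min(), vmax=stack.max()))\n",
    "    ax_i.set_title(title)\n",
    "    ax_i.axis('off')\n",
    "\n",
    "\n",
    "def update_all(frame):\n",
    "    # A single update function advances every panel, so any export shows all of them\n",
    "    for im_i, stack in zip(panel_images, stacks.values()):\n",
    "        im_i.set_array(stack[frame])\n",
    "    return panel_images\n",
    "\n",
    "\n",
    "anim_all = FuncAnimation(fig_multi, update_all, frames=n_demo_frames,\n",
    "                         interval=50, blit=True)\n",
    "plt.close(fig_multi)\n",
    "anim_all"
   ],
   "outputs": [],
   "execution_count": null
  },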
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Network Dynamics Movies <a id='network-dynamics'></a>\n",
    "\n",
    "Creating movies that show evolving network structure and activity."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Animate network dynamics with changing connectivity and node activity\n",
    "\n",
    "# Network parameters\n",
    "n_nodes = 20\n",
    "n_frames = 100\n",
    "\n",
    "# Generate node positions (fixed)\n",
    "np.random.seed(42)\n",
    "theta = np.linspace(0, 2 * np.pi, n_nodes, endpoint=False)\n",
    "node_pos = np.column_stack([np.cos(theta), np.sin(theta)]) * 2\n",
    "\n",
    "# Add some noise to positions\n",
    "node_pos += np.random.normal(0, 0.1, node_pos.shape)\n",
    "\n",
    "# Generate time-varying connectivity\n",
    "base_connectivity = np.random.rand(n_nodes, n_nodes)\n",
    "base_connectivity = (base_connectivity + base_connectivity.T) / 2\n",
    "base_connectivity[base_connectivity < 0.7] = 0\n",
    "np.fill_diagonal(base_connectivity, 0)\n",
    "\n",
    "# Generate time-varying node activity\n",
    "node_activity = np.zeros((n_frames, n_nodes))\n",
    "for frame in range(n_frames):\n",
    "    # Oscillating activity\n",
    "    phase = 2 * np.pi * frame / 30\n",
    "    for i in range(n_nodes):\n",
    "        node_activity[frame, i] = 0.5 + 0.5 * np.sin(phase + i * np.pi / 5)\n",
    "\n",
    "# Setup figure\n",
    "fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n",
    "fig.suptitle('Network Dynamics Animation', fontsize=14, fontweight='bold')\n",
    "\n",
    "# Network visualization axis\n",
    "ax1 = axes[0]\n",
    "ax1.set_xlim(-3, 3)\n",
    "ax1.set_ylim(-3, 3)\n",
    "ax1.set_aspect('equal')\n",
    "ax1.set_title('Network Structure')\n",
    "ax1.axis('off')\n",
    "\n",
    "# Connectivity matrix axis\n",
    "ax2 = axes[1]\n",
    "ax2.set_title('Connectivity Matrix')\n",
    "im = ax2.imshow(base_connectivity, cmap='RdBu_r', vmin=-1, vmax=1)\n",
    "plt.colorbar(im, ax=ax2, label='Weight')\n",
    "\n",
    "# Initialize network plot elements\n",
    "nodes = ax1.scatter(node_pos[:, 0], node_pos[:, 1],\n",
    "                    s=200, c=node_activity[0], cmap='hot',\n",
    "                    vmin=0, vmax=1, edgecolors='black', linewidths=1)\n",
    "\n",
    "# Create edge lines\n",
    "edges = []\n",
    "for i in range(n_nodes):\n",
    "    for j in range(i + 1, n_nodes):\n",
    "        if base_connectivity[i, j] > 0:\n",
    "            line, = ax1.plot([node_pos[i, 0], node_pos[j, 0]],\n",
    "                             [node_pos[i, 1], node_pos[j, 1]],\n",
    "                             'k-', alpha=0.3, linewidth=1)\n",
    "            edges.append((line, i, j))\n",
    "\n",
    "\n",
    "def animate_network(frame):\n",
    "    # Update node colors based on activity\n",
    "    nodes.set_array(node_activity[frame])\n",
    "\n",
    "    # Update edge properties based on connectivity strength\n",
    "    connectivity_frame = base_connectivity * (1 + 0.5 * np.sin(2 * np.pi * frame / 50))\n",
    "\n",
    "    for line, i, j in edges:\n",
    "        strength = abs(connectivity_frame[i, j])\n",
    "        line.set_alpha(min(strength, 1.0))\n",
    "        line.set_linewidth(strength * 3)\n",
    "\n",
    "        # Color based on correlation between nodes\n",
    "        if node_activity[frame, i] * node_activity[frame, j] > 0.5:\n",
    "            line.set_color('red')\n",
    "        else:\n",
    "            line.set_color('blue')\n",
    "\n",
    "    # Update connectivity matrix\n",
    "    im.set_array(connectivity_frame)\n",
    "\n",
    "    # Update title with time\n",
    "    ax1.set_title(f'Network Structure (t = {frame:.0f})')\n",
    "\n",
    "    return [nodes, im] + [e[0] for e in edges]\n",
    "\n",
    "\n",
    "# Create animation\n",
    "anim = FuncAnimation(fig, animate_network, frames=n_frames,\n",
    "                     interval=100, blit=False)\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "# Save as MP4 if ffmpeg is available\n",
    "try:\n",
    "    Writer = FFMpegWriter(fps=10, bitrate=1800)\n",
    "    anim.save(output_dir / 'network_dynamics.mp4', writer=Writer)\n",
    "    print(f\"Network dynamics movie saved to: {output_dir / 'network_dynamics.mp4'}\")\n",
    "except:\n",
    "    print(\"FFMpeg not available, saving as GIF instead\")\n",
    "    anim.save(output_dir / 'network_dynamics.gif', writer='pillow', fps=10)\n",
    "\n",
    "# Display in notebook\n",
    "display(HTML(anim.to_jshtml()))\n",
    "\n",
    "print(\"\\nNetwork Dynamics Features:\")\n",
    "print(\"- Node activity represented by color intensity\")\n",
    "print(\"- Edge strength shown by line width and opacity\")\n",
    "print(\"- Edge color indicates correlation between connected nodes\")\n",
    "print(\"- Synchronized matrix view shows connectivity changes\")"
   ],
   "outputs": [],
   "execution_count": null
  },
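  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The edge colors above are set from the product of two nodes' current activities, which measures co-activation rather than correlation. If you want edges colored by correlation, one option is a Pearson correlation computed over a trailing window of frames. The sketch below shows one possible way to do this; the function name `windowed_correlation` and the 10-frame default window are assumptions, and the check uses the same phase-shifted sinusoids as `node_activity`. Inside `animate_network`, an edge could then be colored by the sign of `windowed_correlation(node_activity, frame)[i, j]`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def windowed_correlation(activity, frame, window=10):\n",
    "    \"\"\"Pearson correlation between nodes over frames [frame - window + 1, frame].\n",
    "\n",
    "    `activity` has shape (n_frames, n_nodes). Constant traces make np.corrcoef\n",
    "    return NaN; those entries are set to 0.\n",
    "    \"\"\"\n",
    "    start = max(0, frame - window + 1)\n",
    "    segment = activity[start:frame + 1]\n",
    "    n_nodes = activity.shape[1]\n",
    "    if segment.shape[0] < 2:\n",
    "        # Correlation is undefined with fewer than two samples\n",
    "        return np.zeros((n_nodes, n_nodes))\n",
    "    with np.errstate(invalid='ignore', divide='ignore'):\n",
    "        corr = np.corrcoef(segment, rowvar=False)\n",
    "    return np.nan_to_num(corr)\n",
    "\n",
    "\n",
    "# Same construction as node_activity: 30-frame period, phase step of pi/5 per node\n",
    "demo_frames = np.arange(100)[:, None]\n",
    "demo_nodes = np.arange(20)[None, :]\n",
    "demo_activity = 0.5 + 0.5 * np.sin(2 * np.pi * demo_frames / 30 + demo_nodes * np.pi / 5)\n",
    "\n",
    "corr = windowed_correlation(demo_activity, frame=50)\n",
    "# Nodes 0 and 10 are in phase (offset 2*pi); nodes 0 and 5 are in antiphase (offset pi)\n",
    "print(round(corr[0, 10], 3), round(corr[0, 5], 3))"
   ],
   "outputs": [],
   "execution_count": null
  },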
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Learning Process Visualization <a id='learning'></a>\n",
    "\n",
    "Time-lapse visualization of learning dynamics, weight evolution, and performance metrics."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Animate learning process with multiple metrics\n",
    "\n",
    "# Simulate learning data\n",
    "n_epochs = 100\n",
    "n_neurons = 10\n",
    "\n",
    "# Performance metrics\n",
    "accuracy = np.zeros(n_epochs)\n",
    "loss = np.zeros(n_epochs)\n",
    "\n",
    "# Weight evolution\n",
    "weights = np.random.randn(n_epochs, n_neurons, n_neurons) * 0.1\n",
    "\n",
    "# Simulate learning\n",
    "for epoch in range(n_epochs):\n",
    "    # Learning curves\n",
    "    accuracy[epoch] = 1 - np.exp(-epoch / 20) + np.random.normal(0, 0.02)\n",
    "    loss[epoch] = 2 * np.exp(-epoch / 15) + np.random.normal(0, 0.05)\n",
    "\n",
    "    # Weight updates (Hebbian-like)\n",
    "    if epoch > 0:\n",
    "        weights[epoch] = weights[epoch - 1] + np.random.randn(n_neurons, n_neurons) * 0.01\n",
    "        weights[epoch] *= 0.99  # Decay\n",
    "\n",
    "# Setup figure\n",
    "fig = plt.figure(figsize=(15, 10))\n",
    "gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)\n",
    "\n",
    "# Axes\n",
    "ax1 = fig.add_subplot(gs[0, :])\n",
    "ax2 = fig.add_subplot(gs[1, 0])\n",
    "ax3 = fig.add_subplot(gs[1, 1])\n",
    "ax4 = fig.add_subplot(gs[1, 2])\n",
    "ax5 = fig.add_subplot(gs[2, :])\n",
    "\n",
    "fig.suptitle('Learning Process Animation', fontsize=14, fontweight='bold')\n",
    "\n",
    "# 1. Learning curves\n",
    "ax1.set_xlim(0, n_epochs)\n",
    "ax1.set_ylim(0, 1.2)\n",
    "ax1.set_xlabel('Epoch')\n",
    "ax1.set_ylabel('Performance')\n",
    "ax1.set_title('Learning Curves')\n",
    "ax1.grid(True, alpha=0.3)\n",
    "\n",
    "acc_line, = ax1.plot([], [], 'g-', linewidth=2, label='Accuracy')\n",
    "loss_line, = ax1.plot([], [], 'r-', linewidth=2, label='Loss (scaled)')\n",
    "ax1.legend(loc='right')\n",
    "\n",
    "# 2. Weight matrix\n",
    "im_weights = ax2.imshow(weights[0], cmap='RdBu_r', vmin=-1, vmax=1)\n",
    "ax2.set_title('Weight Matrix')\n",
    "ax2.set_xlabel('Post')\n",
    "ax2.set_ylabel('Pre')\n",
    "\n",
    "# 3. Weight histogram\n",
    "ax3.set_xlim(-2, 2)\n",
    "ax3.set_ylim(0, 30)\n",
    "ax3.set_xlabel('Weight Value')\n",
    "ax3.set_ylabel('Count')\n",
    "ax3.set_title('Weight Distribution')\n",
    "n_bins = 30\n",
    "counts, bins = np.histogram(weights[0].flatten(), bins=n_bins)\n",
    "hist_bars = ax3.bar(bins[:-1], counts, width=bins[1] - bins[0],\n",
    "                    color='blue', alpha=0.7)\n",
    "\n",
    "# 4. Network activity\n",
    "activity = np.random.rand(n_neurons)\n",
    "activity_bars = ax4.bar(range(n_neurons), activity, color='orange')\n",
    "ax4.set_xlim(-0.5, n_neurons - 0.5)\n",
    "ax4.set_ylim(0, 1)\n",
    "ax4.set_xlabel('Neuron')\n",
    "ax4.set_ylabel('Activity')\n",
    "ax4.set_title('Neural Activity')\n",
    "\n",
    "# 5. Feature importance evolution\n",
    "feature_importance = np.random.rand(n_epochs, 5)\n",
    "for i in range(1, n_epochs):\n",
    "    feature_importance[i] = feature_importance[i - 1] * 0.9 + np.random.rand(5) * 0.1\n",
    "\n",
    "ax5.set_xlim(0, n_epochs)\n",
    "ax5.set_ylim(0, 1)\n",
    "ax5.set_xlabel('Epoch')\n",
    "ax5.set_ylabel('Importance')\n",
    "ax5.set_title('Feature Importance Evolution')\n",
    "\n",
    "feature_lines = []\n",
    "for i in range(5):\n",
    "    line, = ax5.plot([], [], linewidth=2, label=f'Feature {i + 1}')\n",
    "    feature_lines.append(line)\n",
    "ax5.legend(loc='right')\n",
    "\n",
    "\n",
    "def animate_learning(frame):\n",
    "    # Update learning curves\n",
    "    epochs_so_far = np.arange(frame + 1)\n",
    "    acc_line.set_data(epochs_so_far, accuracy[:frame + 1])\n",
    "    loss_line.set_data(epochs_so_far, loss[:frame + 1] / 2)  # Scale for display\n",
    "\n",
    "    # Update weight matrix\n",
    "    im_weights.set_array(weights[frame])\n",
    "\n",
    "    # Update weight histogram\n",
    "    counts, _ = np.histogram(weights[frame].flatten(), bins=bins)\n",
    "    for bar, count in zip(hist_bars, counts):\n",
    "        bar.set_height(count)\n",
    "\n",
    "    # Update activity (simulate changing patterns)\n",
    "    new_activity = np.abs(np.sum(weights[frame], axis=0))\n",
    "    new_activity = new_activity / np.max(new_activity) if np.max(new_activity) > 0 else new_activity\n",
    "    for bar, act in zip(activity_bars, new_activity):\n",
    "        bar.set_height(act)\n",
    "\n",
    "    # Update feature importance\n",
    "    for i, line in enumerate(feature_lines):\n",
    "        line.set_data(epochs_so_far, feature_importance[:frame + 1, i])\n",
    "\n",
    "    # Update main title\n",
    "    fig.suptitle(f'Learning Process Animation (Epoch {frame})',\n",
    "                 fontsize=14, fontweight='bold')\n",
    "\n",
    "    return [acc_line, loss_line, im_weights] + list(hist_bars) + \\\n",
    "        list(activity_bars) + feature_lines\n",
    "\n",
    "\n",
    "# Create animation\n",
    "anim = FuncAnimation(fig, animate_learning, frames=n_epochs,\n",
    "                     interval=100, blit=False)\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "# Save animation\n",
    "anim.save(output_dir / 'learning_process.gif', writer='pillow', fps=10)\n",
    "\n",
    "# Display in notebook\n",
    "display(HTML(anim.to_jshtml()))\n",
    "\n",
    "print(\"\\nLearning Process Visualization:\")\n",
    "print(\"- Performance metrics tracked over epochs\")\n",
    "print(\"- Weight matrix evolution shows plasticity\")\n",
    "print(\"- Weight distribution changes during learning\")\n",
    "print(\"- Neural activity patterns evolve with training\")\n",
    "print(\"- Feature importance traces (synthetic here) show how tracked features change\")\n",
    "print(f\"- Animation saved to: {output_dir / 'learning_process.gif'}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
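  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The weight histogram above uses one bar artist per bin, and every bar is updated in a Python loop on each frame. A minimal sketch of an alternative, assuming Matplotlib 3.4 or later (where `Axes.stairs` was added): draw the histogram as a single `StepPatch` and refresh it with `StepPatch.set_data`. The random-walk array `demo_weights` and the names `update_hist` and `hist_anim` are hypothetical stand-ins, not part of the cells above.\n",
    "\n",
    "```python\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from matplotlib.animation import FuncAnimation\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "# Hypothetical stand-in for the weights array: 50 epochs of a 20x20 matrix\n",
    "demo_weights = np.cumsum(rng.normal(0, 0.05, (50, 20, 20)), axis=0)\n",
    "\n",
    "# Fixed bin edges shared by every frame\n",
    "hist_edges = np.linspace(-2, 2, 31)\n",
    "counts0, _ = np.histogram(demo_weights[0], bins=hist_edges)\n",
    "\n",
    "fig_h, ax_h = plt.subplots(figsize=(5, 3))\n",
    "ax_h.set_xlim(-2, 2)\n",
    "ax_h.set_ylim(0, demo_weights[0].size)\n",
    "ax_h.set_xlabel('Weight Value')\n",
    "ax_h.set_ylabel('Count')\n",
    "# A single StepPatch replaces one Rectangle artist per bin\n",
    "hist_patch = ax_h.stairs(counts0, hist_edges, fill=True, alpha=0.7)\n",
    "\n",
    "\n",
    "def update_hist(frame):\n",
    "    counts, _ = np.histogram(demo_weights[frame], bins=hist_edges)\n",
    "    hist_patch.set_data(counts)  # bin edges are unchanged\n",
    "    return (hist_patch,)\n",
    "\n",
    "\n",
    "hist_anim = FuncAnimation(fig_h, update_hist, frames=len(demo_weights),\n",
    "                          interval=100, blit=True)\n",
    "```\n",
    "\n",
    "To view it, render it like the cells above, for example with `display(HTML(hist_anim.to_jshtml()))`. Because only one artist changes per frame, it also works with `blit=True`."
   ]
  },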
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9. Dynamic Connectivity <a id='connectivity'></a>\n",
    "\n",
    "Visualizing how connectivity patterns change over time, including synaptic plasticity and network reorganization."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Animate dynamic connectivity with plasticity rules\n",
    "\n",
    "# Network parameters\n",
    "n_neurons = 15\n",
    "n_frames = 150\n",
    "\n",
    "# Initialize connectivity\n",
    "connectivity = np.random.randn(n_frames, n_neurons, n_neurons) * 0.1\n",
    "connectivity[0] = (connectivity[0] + connectivity[0].T) / 2\n",
    "\n",
    "# Simulate STDP-like plasticity\n",
    "for t in range(1, n_frames):\n",
    "    # Copy previous connectivity\n",
    "    connectivity[t] = connectivity[t - 1].copy()\n",
    "\n",
    "    # Random spike times for STDP\n",
    "    pre_spikes = np.random.rand(n_neurons) < 0.1\n",
    "    post_spikes = np.random.rand(n_neurons) < 0.1\n",
    "\n",
    "    # STDP update\n",
    "    for i in range(n_neurons):\n",
    "        for j in range(n_neurons):\n",
    "            if i != j:\n",
    "                if pre_spikes[i] and post_spikes[j]:\n",
    "                    # Potentiation\n",
    "                    connectivity[t, i, j] += 0.01\n",
    "                elif post_spikes[i] and pre_spikes[j]:\n",
    "                    # Depression\n",
    "                    connectivity[t, i, j] -= 0.005\n",
    "\n",
    "    # Add noise\n",
    "    connectivity[t] += np.random.normal(0, 0.001, (n_neurons, n_neurons))\n",
    "\n",
    "    # Bounds\n",
    "    connectivity[t] = np.clip(connectivity[t], -1, 1)\n",
    "\n",
    "    # Maintain symmetry for display\n",
    "    connectivity[t] = (connectivity[t] + connectivity[t].T) / 2\n",
    "\n",
    "# Setup figure\n",
    "fig, axes = plt.subplots(2, 2, figsize=(12, 10))\n",
    "fig.suptitle('Dynamic Connectivity Patterns', fontsize=14, fontweight='bold')\n",
    "\n",
    "# 1. Connectivity matrix\n",
    "ax1 = axes[0, 0]\n",
    "im1 = ax1.imshow(connectivity[0], cmap='RdBu_r', vmin=-0.5, vmax=0.5)\n",
    "ax1.set_title('Connectivity Matrix')\n",
    "ax1.set_xlabel('Post-synaptic')\n",
    "ax1.set_ylabel('Pre-synaptic')\n",
    "plt.colorbar(im1, ax=ax1, label='Weight')\n",
    "\n",
    "# 2. Connection strength distribution\n",
    "ax2 = axes[0, 1]\n",
    "ax2.set_xlim(-0.5, 0.5)\n",
    "ax2.set_ylim(0, 50)\n",
    "ax2.set_xlabel('Connection Strength')\n",
    "ax2.set_ylabel('Count')\n",
    "ax2.set_title('Weight Distribution')\n",
    "bins = np.linspace(-0.5, 0.5, 30)\n",
    "counts, _ = np.histogram(connectivity[0].flatten(), bins=bins)\n",
    "# align='edge' places each bar between its two bin edges\n",
    "bars = ax2.bar(bins[:-1], counts, width=bins[1] - bins[0], align='edge', alpha=0.7)\n",
    "\n",
    "# 3. Graph visualization\n",
    "ax3 = axes[1, 0]\n",
    "ax3.set_xlim(-2, 2)\n",
    "ax3.set_ylim(-2, 2)\n",
    "ax3.set_aspect('equal')\n",
    "ax3.set_title('Network Graph')\n",
    "ax3.axis('off')\n",
    "\n",
    "# Node positions (circular layout)\n",
    "theta = np.linspace(0, 2 * np.pi, n_neurons, endpoint=False)\n",
    "node_x = 1.5 * np.cos(theta)\n",
    "node_y = 1.5 * np.sin(theta)\n",
    "\n",
    "# Draw nodes\n",
    "nodes = ax3.scatter(node_x, node_y, s=200, c='lightblue',\n",
    "                    edgecolors='black', linewidths=1, zorder=3)\n",
    "\n",
    "# Initialize edges\n",
    "edge_lines = []\n",
    "for i in range(n_neurons):\n",
    "    for j in range(i + 1, n_neurons):\n",
    "        line, = ax3.plot([node_x[i], node_x[j]],\n",
    "                         [node_y[i], node_y[j]],\n",
    "                         'k-', alpha=0, linewidth=1, zorder=1)\n",
    "        edge_lines.append((line, i, j))\n",
    "\n",
    "# 4. Connectivity metrics\n",
    "ax4 = axes[1, 1]\n",
    "ax4.set_xlim(0, n_frames)\n",
    "ax4.set_ylim(-0.1, 0.5)\n",
    "ax4.set_xlabel('Time')\n",
    "ax4.set_ylabel('Metric Value')\n",
    "ax4.set_title('Network Metrics')\n",
    "ax4.grid(True, alpha=0.3)\n",
    "\n",
    "# Calculate metrics: mean absolute weight, and the fraction of connections\n",
    "# stronger than the 0.1 edge-display threshold (connection density)\n",
    "mean_strength = np.mean(np.abs(connectivity), axis=(1, 2))\n",
    "sparsity = np.mean(np.abs(connectivity) > 0.1, axis=(1, 2))\n",
    "\n",
    "mean_line, = ax4.plot([], [], 'b-', label='Mean Strength')\n",
    "sparsity_line, = ax4.plot([], [], 'r-', label='Density (|w| > 0.1)')\n",
    "ax4.legend()\n",
    "\n",
    "\n",
    "def animate_connectivity(frame):\n",
    "    # Update connectivity matrix\n",
    "    im1.set_array(connectivity[frame])\n",
    "\n",
    "    # Update histogram\n",
    "    counts, _ = np.histogram(connectivity[frame].flatten(), bins=bins)\n",
    "    for bar, count in zip(bars, counts):\n",
    "        bar.set_height(count)\n",
    "\n",
    "    # Update graph edges\n",
    "    for line, i, j in edge_lines:\n",
    "        weight = connectivity[frame, i, j]\n",
    "        if abs(weight) > 0.1:  # Threshold for visibility\n",
    "            line.set_alpha(min(abs(weight) * 2, 1.0))\n",
    "            line.set_linewidth(abs(weight) * 5)\n",
    "            if weight > 0:\n",
    "                line.set_color('red')\n",
    "            else:\n",
    "                line.set_color('blue')\n",
    "        else:\n",
    "            line.set_alpha(0)\n",
    "\n",
    "    # Update metrics\n",
    "    frames_so_far = np.arange(frame + 1)\n",
    "    mean_line.set_data(frames_so_far, mean_strength[:frame + 1])\n",
    "    sparsity_line.set_data(frames_so_far, sparsity[:frame + 1])\n",
    "\n",
    "    # Update title\n",
    "    fig.suptitle(f'Dynamic Connectivity Patterns (t = {frame})',\n",
    "                 fontsize=14, fontweight='bold')\n",
    "\n",
    "    return [im1] + list(bars) + [e[0] for e in edge_lines] + \\\n",
    "        [mean_line, sparsity_line]\n",
    "\n",
    "\n",
    "# Create animation\n",
    "anim = FuncAnimation(fig, animate_connectivity, frames=n_frames,\n",
    "                     interval=50, blit=False)\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "# Display in notebook\n",
    "display(HTML(anim.to_jshtml()))\n",
    "\n",
    "print(\"\\nDynamic Connectivity Features:\")\n",
    "print(\"- STDP-like plasticity rules applied\")\n",
    "print(\"- Weight distribution evolves over time\")\n",
    "print(\"- Graph visualization shows strong connections\")\n",
    "print(\"- Network metrics track global changes\")\n",
    "print(\"- Red edges: excitatory, Blue edges: inhibitory\")"
   ],
   "outputs": [],
   "execution_count": null
  },
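  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The nested loops in the STDP update above visit all N² neuron pairs in interpreted Python on every time step. A minimal sketch of an equivalent vectorised update with `np.outer`, assuming the same rule as the loop: +0.01 where neuron i spikes as pre-synaptic and j as post-synaptic, otherwise -0.005 where i spikes as post-synaptic and j as pre-synaptic, and no self-connections. The function names are hypothetical; the loop version is repeated only so the two can be compared.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def stdp_update_loop(w, pre, post, a_plus=0.01, a_minus=0.005):\n",
    "    # Reference implementation: the same rule as the nested loop above\n",
    "    w = w.copy()\n",
    "    n = len(pre)\n",
    "    for i in range(n):\n",
    "        for j in range(n):\n",
    "            if i != j:\n",
    "                if pre[i] and post[j]:\n",
    "                    w[i, j] += a_plus\n",
    "                elif post[i] and pre[j]:\n",
    "                    w[i, j] -= a_minus\n",
    "    return w\n",
    "\n",
    "\n",
    "def stdp_update_vectorized(w, pre, post, a_plus=0.01, a_minus=0.005):\n",
    "    # pot[i, j] is True where i fired as pre-synaptic and j as post-synaptic\n",
    "    pot = np.outer(pre, post)\n",
    "    # dep mirrors the elif branch, so it only applies where pot is False\n",
    "    dep = np.outer(post, pre) & ~pot\n",
    "    delta = a_plus * pot - a_minus * dep\n",
    "    np.fill_diagonal(delta, 0.0)  # the loop skips i == j\n",
    "    return w + delta\n",
    "\n",
    "\n",
    "# Compare both versions on one random step (higher spike rate to hit both branches)\n",
    "rng = np.random.default_rng(1)\n",
    "w0 = rng.normal(0, 0.1, (15, 15))\n",
    "pre = rng.random(15) < 0.3\n",
    "post = rng.random(15) < 0.3\n",
    "print(np.allclose(stdp_update_loop(w0, pre, post), stdp_update_vectorized(w0, pre, post)))\n",
    "```\n",
    "\n",
    "Inside the simulation loop, the nested `for` loops could then be replaced by `connectivity[t] = stdp_update_vectorized(connectivity[t], pre_spikes, post_spikes)`."
   ]
  },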
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10. Export and Optimization <a id='export'></a>\n",
    "\n",
    "Best practices for exporting animations and optimizing performance for large datasets."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import time\n",
    "\n",
    "# Demonstration of export options and performance optimization\n",
    "\n",
    "print(\"Animation Export Options and Optimization\")\n",
    "print(\"=\" * 50)\n",
    "\n",
    "# Create a simple test animation\n",
    "fig, ax = plt.subplots(figsize=(6, 4))\n",
    "ax.set_xlim(0, 10)\n",
    "ax.set_ylim(-1, 1)\n",
    "line, = ax.plot([], [], 'b-')\n",
    "\n",
    "\n",
    "def simple_animate(frame):\n",
    "    x = np.linspace(0, 10, 100)\n",
    "    y = np.sin(x + 0.1 * frame)\n",
    "    line.set_data(x, y)\n",
    "    return line,\n",
    "\n",
    "\n",
    "anim = FuncAnimation(fig, simple_animate, frames=50, interval=50, blit=True)\n",
    "\n",
    "# 1. Export formats\n",
    "print(\"\\n1. Export Formats:\")\n",
    "print(\"-\" * 30)\n",
    "\n",
    "# GIF export\n",
    "try:\n",
    "    writer = PillowWriter(fps=20)\n",
    "    anim.save(output_dir / 'test.gif', writer=writer)\n",
    "    print(\"✓ GIF: Saved using PillowWriter\")\n",
    "    print(\"  - Good for: Small animations, web display\")\n",
    "    print(\"  - File size: Larger than MP4 for the same content\")\n",
    "    print(\"  - Quality: Good for simple graphics\")\n",
    "except Exception as e:\n",
    "    print(f\"✗ GIF export failed: {e}\")\n",
    "\n",
    "# MP4 export (requires the ffmpeg executable on the PATH)\n",
    "if FFMpegWriter.isAvailable():\n",
    "    # fps=20 matches interval=50 ms, so playback speed equals the preview\n",
    "    writer = FFMpegWriter(fps=20, bitrate=1800)\n",
    "    anim.save(output_dir / 'test.mp4', writer=writer)\n",
    "    print(\"\\n✓ MP4: Saved using FFMpegWriter\")\n",
    "    print(\"  - Good for: Large animations, presentations\")\n",
    "    print(\"  - File size: Small (compressed)\")\n",
    "    print(\"  - Quality: Excellent\")\n",
    "else:\n",
    "    print(\"\\n✗ MP4: FFmpeg executable not found\")\n",
    "    print(\"  Install with: conda install ffmpeg\")\n",
    "\n",
    "# HTML export\n",
    "html_str = anim.to_jshtml()\n",
    "with open(output_dir / 'test.html', 'w') as f:\n",
    "    f.write(html_str)\n",
    "print(\"\\n✓ HTML: Saved with JavaScript controls\")\n",
    "print(\"  - Good for: Interactive viewing, notebooks\")\n",
    "print(\"  - File size: Large (embedded data)\")\n",
    "print(\"  - Quality: Excellent\")\n",
    "\n",
    "plt.close(fig)\n",
    "\n",
    "# 2. Performance optimization techniques\n",
    "print(\"\\n2. Performance Optimization:\")\n",
    "print(\"-\" * 30)\n",
    "\n",
    "# Technique 1: Blitting\n",
    "print(\"\\na) Blitting:\")\n",
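    "print(\"   - Use blit=True in FuncAnimation\")\n",
    "print(\"   - Redraws only the artists returned by the update function\")\n",
    "print(\"   - Titles and other artists outside the axes are not refreshed\")\n",
    "print(\"   - Speed-up depends on how much of the figure is static\")\n",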
    "\n",
    "# Technique 2: Data decimation\n",
    "print(\"\\nb) Data Decimation:\")\n",
    "print(\"   Example: Downsample large datasets\")\n",
    "\n",
    "# Generate large dataset\n",
    "large_data = np.random.randn(1000, 1000)\n",
    "print(f\"   Original size: {large_data.shape}\")\n",
    "\n",
    "# Decimate\n",
    "decimated_data = large_data[::5, ::5]  # Every 5th row and column: 25x fewer elements\n",
    "print(f\"   Decimated size: {decimated_data.shape}\")\n",
    "print(f\"   Memory reduction: {(1 - decimated_data.size / large_data.size) * 100:.1f}%\")\n",
    "\n",
    "# Technique 3: Frame caching\n",
    "print(\"\\nc) Frame Caching:\")\n",
    "print(\"   - Pre-compute expensive operations\")\n",
    "print(\"   - Store results in memory\")\n",
    "print(\"   - Trade memory for speed\")\n",
    "\n",
    "# Example: Pre-compute frames\n",
    "n_frames = 100\n",
    "frame_cache = []\n",
    "print(\"   Pre-computing frames...\", end=\"\")\n",
    "start_time = time.time()\n",
    "\n",
    "for i in range(n_frames):\n",
    "    # Expensive computation\n",
    "    frame_data = np.sin(np.linspace(0, 10, 1000) + i * 0.1)\n",
    "    frame_cache.append(frame_data)\n",
    "\n",
    "cache_time = time.time() - start_time\n",
    "print(f\" Done in {cache_time:.2f}s\")\n",
    "\n",
    "# Technique 4: Reduce artists\n",
    "print(\"\\nd) Reduce Number of Artists:\")\n",
    "print(\"   - Combine multiple lines into LineCollection\")\n",
    "print(\"   - Use single scatter plot instead of multiple\")\n",
    "print(\"   - Batch updates when possible\")\n",
    "\n",
    "# 3. Memory management\n",
    "print(\"\\n3. Memory Management:\")\n",
    "print(\"-\" * 30)\n",
    "print(\"- Close finished figures: plt.close(fig)\")\n",
    "print(\"- Disable frame-data caching: FuncAnimation(..., cache_frame_data=False)\")\n",
    "print(\"- Use generators for data streaming\")\n",
    "print(\"- Delete large arrays after use\")\n",
    "\n",
    "# 4. Format comparison\n",
    "print(\"\\n4. Format Comparison Table:\")\n",
    "print(\"-\" * 30)\n",
    "print(\"Format | Quality | Size | Speed | Use Case\")\n",
    "print(\"-------|---------|------|-------|----------\")\n",
    "print(\"GIF    | Medium  | Large| Fast  | Web, Email\")\n",
    "print(\"MP4    | High    | Small| Medium| Presentation\")\n",
    "print(\"AVI    | High    | Large| Slow  | Editing\")\n",
    "print(\"HTML   | High    | Large| Fast  | Interactive\")\n",
    "print(\"Frames | Highest | Huge | Slow  | Post-process\")\n",
    "\n",
    "print(\"\\n✓ Export examples saved to:\", output_dir.absolute())"
   ],
   "outputs": [],
   "execution_count": null
  },
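  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Section 9 draws one `Line2D` per neuron pair (105 artists for 15 neurons) and restyles each in a Python loop. The *Reduce Number of Artists* tip above can be sketched with a single `matplotlib.collections.LineCollection` whose per-segment colours and widths are set from arrays. The circular layout, the 0.1 visibility threshold and the red/blue sign convention follow Section 9; the random symmetric matrix `demo_w` and the names `update_edges` and `edge_collection` are hypothetical.\n",
    "\n",
    "```python\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from matplotlib.collections import LineCollection\n",
    "\n",
    "n_nodes = 15\n",
    "angles = np.linspace(0, 2 * np.pi, n_nodes, endpoint=False)\n",
    "node_xy = 1.5 * np.column_stack([np.cos(angles), np.sin(angles)])\n",
    "\n",
    "# One segment per pair i < j, as in the edge loop of Section 9\n",
    "iu, ju = np.triu_indices(n_nodes, k=1)\n",
    "segments = np.stack([node_xy[iu], node_xy[ju]], axis=1)  # shape (n_pairs, 2, 2)\n",
    "\n",
    "fig_g, ax_g = plt.subplots(figsize=(4, 4))\n",
    "ax_g.set_xlim(-2, 2)\n",
    "ax_g.set_ylim(-2, 2)\n",
    "ax_g.set_aspect('equal')\n",
    "ax_g.axis('off')\n",
    "edge_collection = LineCollection(segments, zorder=1)\n",
    "ax_g.add_collection(edge_collection)\n",
    "ax_g.scatter(node_xy[:, 0], node_xy[:, 1], s=200, c='lightblue',\n",
    "             edgecolors='black', zorder=3)\n",
    "\n",
    "\n",
    "def update_edges(w, threshold=0.1):\n",
    "    weights = w[iu, ju]\n",
    "    rgba = np.zeros((len(weights), 4))\n",
    "    rgba[weights > 0, 0] = 1.0   # red channel for positive weights\n",
    "    rgba[weights <= 0, 2] = 1.0  # blue channel for negative weights\n",
    "    # Alpha as in Section 9: hidden below the threshold, else min(2|w|, 1)\n",
    "    strong = np.abs(weights) > threshold\n",
    "    rgba[:, 3] = np.where(strong, np.minimum(np.abs(weights) * 2, 1.0), 0.0)\n",
    "    edge_collection.set_color(rgba)\n",
    "    edge_collection.set_linewidth(np.abs(weights) * 5)\n",
    "    return edge_collection\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(2)\n",
    "demo_w = rng.normal(0, 0.2, (n_nodes, n_nodes))\n",
    "demo_w = (demo_w + demo_w.T) / 2  # symmetric, like the connectivity above\n",
    "update_edges(demo_w)\n",
    "```\n",
    "\n",
    "With the collection added to `ax3` instead of `ax_g`, the loop over `edge_lines` in `animate_connectivity` could become one call, `update_edges(connectivity[frame])`, which returns a single artist instead of 105."
   ]
  },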
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 11. Advanced Techniques <a id='advanced'></a>\n",
    "\n",
    "Advanced animation techniques including interactive controls and custom animations."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Advanced animation with custom controls\n",
    "\n",
    "from matplotlib.widgets import Slider, Button\n",
    "\n",
    "# Create figure with controls\n",
    "fig = plt.figure(figsize=(12, 8))\n",
    "\n",
    "# Main plot\n",
    "ax_main = plt.subplot2grid((4, 3), (0, 0), colspan=3, rowspan=3)\n",
    "\n",
    "# Control axes\n",
    "ax_speed = plt.subplot2grid((4, 3), (3, 0), colspan=2)\n",
    "ax_button = plt.subplot2grid((4, 3), (3, 2))\n",
    "\n",
    "# Generate complex data\n",
    "t = np.linspace(0, 10, 1000)\n",
    "frequencies = [0.5, 1.0, 2.0]\n",
    "amplitudes = [1.0, 0.5, 0.3]\n",
    "\n",
    "# Setup main plot\n",
    "ax_main.set_xlim(0, 10)\n",
    "ax_main.set_ylim(-2, 2)\n",
    "ax_main.set_xlabel('Time')\n",
    "ax_main.set_ylabel('Signal')\n",
    "ax_main.set_title('Interactive Animation Control')\n",
    "ax_main.grid(True, alpha=0.3)\n",
    "\n",
    "# Create lines\n",
    "lines = []\n",
    "for i, (freq, amp) in enumerate(zip(frequencies, amplitudes)):\n",
    "    line, = ax_main.plot([], [], linewidth=2,\n",
    "                         label=f'f={freq}Hz')\n",
    "    lines.append(line)\n",
    "ax_main.legend()\n",
    "\n",
    "# Speed slider\n",
    "speed_slider = Slider(ax_speed, 'Speed', 0.1, 5.0,\n",
    "                      valinit=1.0, valstep=0.1)\n",
    "\n",
    "# Play/Pause button\n",
    "button = Button(ax_button, 'Pause')\n",
    "is_paused = False\n",
    "\n",
    "# Animation state\n",
    "phase = 0\n",
    "speed_multiplier = 1.0\n",
    "\n",
    "\n",
    "def update_speed(val):\n",
    "    global speed_multiplier\n",
    "    speed_multiplier = speed_slider.val\n",
    "\n",
    "\n",
    "def toggle_pause(event):\n",
    "    global is_paused\n",
    "    is_paused = not is_paused\n",
    "    button.label.set_text('Play' if is_paused else 'Pause')\n",
    "\n",
    "\n",
    "speed_slider.on_changed(update_speed)\n",
    "button.on_clicked(toggle_pause)\n",
    "\n",
    "\n",
    "def animate_advanced(frame):\n",
    "    global phase\n",
    "\n",
    "    if not is_paused:\n",
    "        phase += 0.1 * speed_multiplier\n",
    "\n",
    "    for i, (line, freq, amp) in enumerate(zip(lines, frequencies, amplitudes)):\n",
    "        y = amp * np.sin(2 * np.pi * freq * t + phase + i * np.pi / 3)\n",
    "        line.set_data(t, y)\n",
    "\n",
    "    ax_main.set_title(f'Interactive Animation (Phase: {phase:.1f})')\n",
    "\n",
    "    return lines\n",
    "\n",
    "\n",
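    "# Create animation. blit=False so the title, which changes every frame, is redrawn;\n",
    "# cache_frame_data=False because frames=None makes the animation run indefinitely\n",
    "anim = FuncAnimation(fig, animate_advanced, interval=50, blit=False,\n",
    "                     cache_frame_data=False)\n",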
    "\n",
    "plt.show()\n",
    "\n",
    "print(\"Advanced Animation Features:\")\n",
    "print(\"- Interactive speed control with slider\")\n",
    "print(\"- Play/Pause functionality\")\n",
    "print(\"- Real-time parameter updates\")\n",
    "print(\"- Custom widget integration\")\n",
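    "print(\"\\nNote: The slider and button need an interactive backend\")\n",
    "print(\"(a standalone window, or %matplotlib widget with ipympl installed).\")\n",
    "print(\"With %matplotlib inline, only a static frame is rendered.\")"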
   ],
   "outputs": [],
   "execution_count": null
  },
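  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The memory-management list in Section 10 recommends generators for streaming data. `FuncAnimation` accepts a generator function as `frames`; each yielded value is passed to the update function, so only the current chunk needs to be in memory. A minimal sketch, where `stream_chunks` is a hypothetical source standing in for data read from disk or produced by a running simulation. Because the number of frames cannot be inferred from a generator, `save_count` tells Matplotlib how many frames to render when saving, and `cache_frame_data=False` avoids keeping every yielded chunk.\n",
    "\n",
    "```python\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from matplotlib.animation import FuncAnimation\n",
    "\n",
    "n_channels, chunk_len, n_chunks = 4, 200, 30\n",
    "\n",
    "\n",
    "def stream_chunks():\n",
    "    # Hypothetical source: yields one (n_channels, chunk_len) block at a time\n",
    "    rng = np.random.default_rng(3)\n",
    "    state = np.zeros(n_channels)\n",
    "    for _ in range(n_chunks):\n",
    "        steps = rng.normal(0, 0.02, (n_channels, chunk_len))\n",
    "        chunk = state[:, None] + np.cumsum(steps, axis=1)\n",
    "        state = chunk[:, -1]  # continue each trace from where it stopped\n",
    "        yield chunk\n",
    "\n",
    "\n",
    "fig_s, ax_s = plt.subplots(figsize=(6, 3))\n",
    "ax_s.set_xlim(0, chunk_len)\n",
    "ax_s.set_ylim(-6, 6)\n",
    "ax_s.set_xlabel('Sample within chunk')\n",
    "ax_s.set_ylabel('Signal')\n",
    "stream_x = np.arange(chunk_len)\n",
    "stream_lines = [ax_s.plot([], [], lw=1)[0] for _ in range(n_channels)]\n",
    "\n",
    "\n",
    "def update_stream(chunk):\n",
    "    # chunk is the value yielded by stream_chunks for this frame\n",
    "    for line, trace in zip(stream_lines, chunk):\n",
    "        line.set_data(stream_x, trace)\n",
    "    return stream_lines\n",
    "\n",
    "\n",
    "stream_anim = FuncAnimation(fig_s, update_stream, frames=stream_chunks,\n",
    "                            save_count=n_chunks, cache_frame_data=False,\n",
    "                            interval=100, blit=True)\n",
    "```\n",
    "\n",
    "It can be displayed like the earlier animations, for example with `display(HTML(stream_anim.to_jshtml()))`; `save_count` then limits rendering to `n_chunks` frames."
   ]
  },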
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary and Best Practices\n",
    "\n",
    "This tutorial covered the following animation techniques for neural data visualization:\n",
    "\n",
    "1. **Basic Animation**\n",
    "   - FuncAnimation for frame-by-frame updates\n",
    "   - ArtistAnimation for pre-computed frames\n",
    "   - Blitting for performance optimization\n",
    "\n",
    "2. **Neural-Specific Animations**\n",
    "   - Spike raster sliding windows\n",
    "   - Population activity dynamics\n",
    "   - Membrane potential oscillations\n",
    "\n",
    "3. **Spatial Animations**\n",
    "   - 2D activity patterns\n",
    "   - Traveling waves and spirals\n",
    "   - Spreading activation\n",
    "\n",
    "4. **Network Dynamics**\n",
    "   - Evolving connectivity\n",
    "   - Node activity changes\n",
    "   - Graph visualization\n",
    "\n",
    "5. **Learning Visualization**\n",
    "   - Performance metrics over time\n",
    "   - Weight evolution\n",
    "   - Feature importance dynamics\n",
    "\n",
    "\n",
    "Export Options:\n",
    "\n",
    "| Format | Writer | Best For | Pros | Cons |\n",
    "|--------|--------|----------|------|------|\n",
    "| GIF | `PillowWriter` | Web, email | Universal support | Large files, 256-colour palette |\n",
    "| MP4 | `FFMpegWriter` | Presentations | Small files, high quality | Requires FFmpeg |\n",
    "| HTML | `to_jshtml()` / `HTMLWriter` | Notebooks | Interactive playback controls | Large files, needs a browser |\n",
    "| Frames | `fig.savefig` per frame | Post-processing | Full control | Very large |"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
