Hodgkin-Huxley Neuron Model
This tutorial demonstrates how to implement biologically realistic neuron models using BrainState. The Hodgkin-Huxley (HH) model is one of the most important models in computational neuroscience, describing action potential generation in neurons through ion channel dynamics.
Learning Objectives
By the end of this tutorial, you will:
Understand the Hodgkin-Huxley neuron model
Use BrainUnit for physical units and dimensional analysis
Implement biophysically detailed neuron dynamics
Simulate and visualize neuron spiking activity
Use BrainState’s `Dynamics` class for continuous-time models
The Hodgkin-Huxley Model
The HH model describes the electrical activity of neurons through:
Membrane voltage ($V$): the electrical potential across the cell membrane.
Ion channels:
Sodium (Na⁺) channels with activation ($m$) and inactivation ($h$) gates
Potassium (K⁺) channels with activation ($n$) gates
Leak channels
Governing Equations:
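$$C \frac{dV}{dt} = -I_{\mathrm{Na}} - I_{\mathrm{K}} - I_{\mathrm{leak}} + I_{\mathrm{ext}}$$
$$I_{\mathrm{Na}} = g_{\mathrm{Na}}\, m^{3} h \,(V - E_{\mathrm{Na}})$$
$$I_{\mathrm{K}} = g_{\mathrm{K}}\, n^{4} \,(V - E_{\mathrm{K}})$$
$$I_{\mathrm{leak}} = g_{L}\,(V - E_{L})$$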
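Each gating variable $x \in \{m, h, n\}$ follows first-order kinetics with voltage-dependent opening and closing rates $\alpha_x(V)$ and $\beta_x(V)$:
$$\frac{dx}{dt} = \alpha_x(V)\,(1 - x) - \beta_x(V)\,x$$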
Setup and Imports
import brainpy
import brainstate
import brainunit as u
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
# Fix BrainState's global random seed so that the random input currents
# generated later with brainstate.random are reproducible between runs.
SEED = 42
brainstate.random.seed(SEED)
Implementing the HH Model
We’ll use BrainState’s nn.Dynamics class for continuous-time dynamics and BrainUnit for physical units:
class HH(brainstate.nn.Dynamics):
    """Hodgkin-Huxley neuron model.

    A biophysically detailed model of action potential generation built from
    three ionic currents:

    - a sodium current with activation gate ``m`` and inactivation gate ``h``,
    - a potassium current with activation gate ``n``,
    - a passive leak current.

    Args:
        in_size: Shape of the neuron population (e.g. ``10``).
        ENa: Sodium reversal potential.
        gNa: Maximal sodium conductance density.
        EK: Potassium reversal potential.
        gK: Maximal potassium conductance density.
        EL: Leak reversal potential.
        gL: Leak conductance density.
        V_th: Threshold used by :meth:`update` to flag spikes.
        C: Membrane capacitance density.
        V_rest: Membrane potential used by :meth:`init_state` as the initial
            voltage; the gates start at their steady-state values for it.
    """

    def __init__(
        self,
        in_size,
        ENa=50. * u.mV,  # Sodium reversal potential
        gNa=120. * u.mS / u.cm ** 2,  # Maximal sodium conductance density
        EK=-77. * u.mV,  # Potassium reversal potential
        gK=36. * u.mS / u.cm ** 2,  # Maximal potassium conductance density
        EL=-54.387 * u.mV,  # Leak reversal potential
        gL=0.03 * u.mS / u.cm ** 2,  # Leak conductance density
        V_th=20. * u.mV,  # Spike-detection threshold
        C=1.0 * u.uF / u.cm ** 2,  # Membrane capacitance density
        V_rest=-65. * u.mV,  # Initial (resting) membrane potential
    ):
        super().__init__(in_size)
        # Store parameters
        self.ENa = ENa
        self.EK = EK
        self.EL = EL
        self.gNa = gNa
        self.gK = gK
        self.gL = gL
        self.C = C
        self.V_th = V_th
        self.V_rest = V_rest

    # The rate functions below take V as a voltage quantity, strip the unit
    # with ``V / u.mV`` and return unitless rates. The ``d*`` methods divide
    # by ``u.ms``, i.e. the rates are interpreted as per-millisecond.

    # --- Sodium channel activation (m) ---
    def m_alpha(self, V):
        # exprel(x) = (exp(x) - 1) / x, so 1 / exprel(x) = x / (exp(x) - 1).
        # With x = -(V + 40) / 10 this equals the textbook
        # 0.1 (V + 40) / (1 - exp(-(V + 40) / 10)) but stays finite at
        # V = -40 mV, where the textbook form is 0 / 0.
        return 1. / u.math.exprel(-(V / u.mV + 40) / 10)

    def m_beta(self, V):
        return 4.0 * jnp.exp(-(V / u.mV + 65) / 18)

    def m_inf(self, V):
        """Steady-state value of ``m`` at voltage ``V``."""
        return self.m_alpha(V) / (self.m_alpha(V) + self.m_beta(V))

    def dm(self, m, t, V):
        return (self.m_alpha(V) * (1 - m) - self.m_beta(V) * m) / u.ms

    # --- Sodium channel inactivation (h) ---
    def h_alpha(self, V):
        return 0.07 * jnp.exp(-(V / u.mV + 65) / 20.)

    def h_beta(self, V):
        return 1 / (1 + jnp.exp(-(V / u.mV + 35) / 10))

    def h_inf(self, V):
        """Steady-state value of ``h`` at voltage ``V``."""
        return self.h_alpha(V) / (self.h_alpha(V) + self.h_beta(V))

    def dh(self, h, t, V):
        return (self.h_alpha(V) * (1 - h) - self.h_beta(V) * h) / u.ms

    # --- Potassium channel activation (n) ---
    def n_alpha(self, V):
        # Same singularity-free exprel form as m_alpha: equals the textbook
        # 0.01 (V + 55) / (1 - exp(-(V + 55) / 10)), finite at V = -55 mV.
        return 0.1 / u.math.exprel(-(V / u.mV + 55) / 10)

    def n_beta(self, V):
        return 0.125 * jnp.exp(-(V / u.mV + 65) / 80)

    def n_inf(self, V):
        """Steady-state value of ``n`` at voltage ``V``."""
        return self.n_alpha(V) / (self.n_alpha(V) + self.n_beta(V))

    def dn(self, n, t, V):
        return (self.n_alpha(V) * (1 - n) - self.n_beta(V) * n) / u.ms

    def init_state(self, batch_size=None):
        """Initialize state variables at rest.

        ``V`` is set to ``V_rest`` and each gate to its steady-state value at
        that voltage.

        Args:
            batch_size: Optional leading batch dimension. When ``None`` the
                states have shape ``varshape``; otherwise they have shape
                ``(batch_size, *varshape)``.
        """
        # Previously batch_size was accepted but silently ignored.
        if batch_size is None:
            shape = self.varshape
        else:
            shape = (batch_size, *self.varshape)
        self.V = brainstate.HiddenState(
            jnp.ones(shape, brainstate.environ.dftype()) * self.V_rest
        )
        self.m = brainstate.HiddenState(self.m_inf(self.V.value))
        self.h = brainstate.HiddenState(self.h_inf(self.V.value))
        self.n = brainstate.HiddenState(self.n_inf(self.V.value))

    def dV(self, V, t, m, h, n, I):
        """Membrane voltage dynamics: C dV/dt = -I_Na - I_K - I_leak + I."""
        # Sodium current: gNa * m^3 * h * (V - ENa)
        I_Na = (self.gNa * m * m * m * h) * (V - self.ENa)
        # Potassium current: gK * n^4 * (V - EK), with n^4 computed as (n^2)^2
        n2 = n * n
        I_K = (self.gK * n2 * n2) * (V - self.EK)
        # Leak current
        I_leak = self.gL * (V - self.EL)
        # Total current divided by capacitance gives the voltage derivative
        dVdt = (- I_Na - I_K - I_leak + I) / self.C
        return dVdt

    def update(self, x=0. * u.mA / u.cm ** 2):
        """
        Update neuron state for one time step.

        All four state variables are advanced from the values held at the
        start of the step. V uses the old gates, and the gates use the old V;
        the new values are written back only after all four steps.

        Args:
            x: Input current density (defaults to zero).

        Returns:
            spike: Boolean array, True where V crossed ``V_th`` from below
                during this step.
        """
        t = brainstate.environ.get('t')
        # Update voltage and gating variables using exponential Euler
        V = brainstate.nn.exp_euler_step(
            self.dV, self.V.value, t,
            self.m.value, self.h.value, self.n.value, x
        )
        m = brainstate.nn.exp_euler_step(self.dm, self.m.value, t, self.V.value)
        h = brainstate.nn.exp_euler_step(self.dh, self.h.value, t, self.V.value)
        n = brainstate.nn.exp_euler_step(self.dn, self.n.value, t, self.V.value)
        # Detect spike (upward threshold crossing within this step)
        spike = jnp.logical_and(self.V.value < self.V_th, V >= self.V_th)
        # Update states
        self.V.value = V
        self.m.value = m
        self.h.value = h
        self.n.value = n
        return spike
Simulating the HH Neuron
Create and Initialize Neuron
# Build a population of 10 Hodgkin-Huxley neurons and put every state
# variable into its initial (resting) value.
n_neurons = 10
hh = HH(n_neurons)
brainstate.nn.init_all_states(hh)

# Integration step, registered globally in the BrainState environment.
dt = 0.01 * u.ms
brainstate.environ.set(dt=dt)

# Report the population shape and the starting voltage of the first neuron.
print(f"Created {hh.varshape} HH neurons")
print(f"Initial membrane potential: {hh.V.value[0]}")
Created (10,) HH neurons
Initial membrane potential: -65.0 * mvolt
Define Simulation Function
def run(t, inp):
    """Advance the ``hh`` population by a single integration step.

    Args:
        t: Simulation time of this step.
        inp: External input current passed to ``hh``.

    Returns:
        The membrane potential of the population after the step.
    """
    # The environment context exposes ``t`` (read by HH.update) and ``dt``
    # for the duration of this one step.
    step_env = brainstate.environ.context(t=t, dt=dt)
    with step_env:
        hh(inp)
        voltage = hh.V.value
    return voltage
Run Simulation with Random Input
# Simulation duration and time grid
duration = 100. * u.ms
times = u.math.arange(0. * u.ms, duration, dt)
# Generate random input currents: an independent draw for every neuron at
# every time step, shape (n_steps, *hh.varshape). All neurons share the same
# parameters and initial state, so drawing only one value per step (shape
# times.shape) would drive every neuron identically and make the population
# statistics and raster plot below degenerate.
inputs = brainstate.random.uniform(1., 10., times.shape + hh.varshape) * u.uA / u.cm ** 2
print(f"Running simulation for {duration}...")
# Run simulation with progress bar; for_loop feeds one (time, input) slice
# per step to `run` and stacks the returned voltages along axis 0.
vs = brainstate.transform.for_loop(
    run,
    times,
    inputs,
    pbar=brainstate.transform.ProgressBar(count=100)
)
print("Simulation complete!")
print("Simulation complete!")
Running simulation for 100.0 * msecond...
Simulation complete!
Visualizing Results
Plot Membrane Voltage Traces
# Convert to plain numbers in milliseconds and millivolts for plotting
times_ms = times.to_decimal(u.ms)
vs_mv = vs.to_decimal(u.mV)
# Read the threshold from the model instead of hard-coding 20 mV, so the
# dashed line stays correct if HH is constructed with a different V_th.
v_th_mv = hh.V_th.to_decimal(u.mV)
# Plot voltage traces for the first 3 neurons, one panel each
fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True)
for i, ax in enumerate(axes):
    ax.plot(times_ms, vs_mv[:, i], linewidth=1.5)
    ax.set_ylabel(f'V{i} (mV)', fontsize=11)
    ax.grid(True, alpha=0.3)
    ax.axhline(y=v_th_mv, color='r', linestyle='--', alpha=0.5, label='Threshold')
# One legend on the top panel suffices: every panel draws the same line.
axes[0].legend(fontsize=9)
axes[-1].set_xlabel('Time (ms)', fontsize=11)
fig.suptitle('Hodgkin-Huxley Neuron Activity', fontsize=14, fontweight='bold')
fig.tight_layout()
plt.show()
Analyze Spiking Statistics
# Detect spikes as upward crossings of the model's own threshold (hh.V_th),
# using the same ">=" rule as HH.update so the two definitions agree.
threshold = hh.V_th.to_decimal(u.mV)
vs_np = np.asarray(vs_mv)
times_np = np.asarray(times_ms)
spike_times = []
spike_counts = []
for i in range(hh.varshape[0]):
    above_threshold = vs_np[:, i] >= threshold
    # np.diff(...)[k] > 0 means sample k is below threshold and sample k + 1
    # is at/above it; k + 1 is the first supra-threshold sample, i.e. the
    # step on which HH.update would report the spike.
    spike_indices = np.flatnonzero(np.diff(above_threshold.astype(int)) > 0) + 1
    spike_times.append(times_np[spike_indices])
    spike_counts.append(len(spike_indices))
print("Spike counts per neuron:")
for i, count in enumerate(spike_counts[:5]):
    print(f" Neuron {i}: {count} spikes")
# Mean spike count divided by the duration in seconds gives a rate in Hz.
print(f"\nAverage firing rate: {np.mean(spike_counts) / (duration.to_decimal(u.ms) / 1000):.2f} Hz")
Spike counts per neuron:
Neuron 0: 6 spikes
Neuron 1: 6 spikes
Neuron 2: 6 spikes
Neuron 3: 6 spikes
Neuron 4: 6 spikes
Average firing rate: 60.00 Hz
Raster Plot
# Show at most the first 10 neurons; the y-range follows the number actually
# plotted instead of being hard-coded to 10 rows.
n_show = min(10, len(spike_times))
plt.figure(figsize=(12, 6))
for i, spikes in enumerate(spike_times[:n_show]):
    # One row per neuron: a vertical tick at each detected spike time.
    plt.scatter(spikes, [i] * len(spikes), marker='|', s=100, color='black')
plt.xlabel('Time (ms)', fontsize=12)
plt.ylabel('Neuron Index', fontsize=12)
plt.title('Spike Raster Plot', fontsize=14, fontweight='bold')
plt.ylim(-0.5, n_show - 0.5)
plt.grid(True, alpha=0.3, axis='x')
plt.tight_layout()
plt.show()
Next Steps
Network Models: Build networks of neurons with synapses
Learning Rules: Implement STDP and other plasticity mechanisms
Brain Regions: Model cortical columns, hippocampus, etc.
Spiking Networks: Training spiking neural networks