Tutorial 2: Building Inputs with the Functional API#

The Functional API provides a traditional approach to generating current inputs through direct function calls. This tutorial covers all functional input types with comprehensive examples.

Learning Objectives#

By the end of this tutorial, you will:

  • Master all functional input types in braintools.input

  • Understand when to use the functional vs composable API

  • Generate complex waveforms and stochastic processes

  • Create realistic experimental protocols

  • Optimize performance for large-scale simulations

Setup#

import brainstate
import brainunit as u
import matplotlib.pyplot as plt
import numpy as np
from scipy import signal as scipy_signal

import braintools

# Set up plotting
plt.style.use('default')
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 11
# Simulation parameters
dt = 0.1 * u.ms
brainstate.environ.set(dt=dt)
def plot_signal(signal, duration, title, ax=None):
    """Plot a generated current trace against time.

    Parameters
    ----------
    signal : array-like or brainunit quantity
        Trace to draw. Only its magnitude is plotted.
    duration : brainunit quantity or number
        Nominal length of the trace. Retained so existing calls keep
        working; the time axis itself is derived from the number of
        samples in ``signal`` and the module-level ``dt``.
    title : str
        Title placed above the axes.
    ax : matplotlib axes, optional
        Axes to draw on. A new 10x4 figure is created when omitted.

    Returns
    -------
    The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 4))

    dt_val = u.get_magnitude(dt)
    signal_val = np.asarray(u.get_magnitude(signal))

    # Build the time axis from the sample count rather than
    # np.arange(0, duration, dt): with a float step, arange can return one
    # element more or less than expected, and matplotlib then refuses to plot
    # x and y of different lengths.
    time_points = np.arange(signal_val.shape[0]) * dt_val

    ax.plot(time_points, signal_val, 'b-', linewidth=1.5)
    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('Current (nA)')
    ax.set_title(title)
    ax.grid(True, alpha=0.3)

    return ax

1. Basic Functional Inputs#

The functional API provides direct functions that return current arrays immediately.

1.1 Constant Currents#

# Simple constant current
const_simple = braintools.input.constant([(5.0, 500 * u.ms)])[0]

# Multi-phase constant current
const_multi = braintools.input.constant([
    (2.0, 200 * u.ms),  # 2 nA for 200 ms
    (8.0, 300 * u.ms),  # 8 nA for 300 ms
    (1.0, 100 * u.ms)  # 1 nA for 100 ms
])[0]

# Holding current for patch clamp
holding = braintools.input.constant([(-0.5, 1000 * u.ms)])[0]  # -0.5 nA holding

# Plot examples
fig, axes = plt.subplots(3, 1, figsize=(12, 10))

plot_signal(const_simple, 500 * u.ms, 'Simple Constant: 5 nA', axes[0])
plot_signal(const_multi, 600 * u.ms, 'Multi-phase Constant Current', axes[1])
plot_signal(holding, 1000 * u.ms, 'Holding Current: -0.5 nA', axes[2])

plt.tight_layout()
plt.show()
../_images/41842ad564764ebd89a225145ea31397d28c7a8ca0824fb5477e4dddacfff3e4.png

1.2 Step Functions#

# Classic step protocol for I-V measurements
iv_steps = braintools.input.step(
    [-2, 0, 2, 5, 8, 10, 5, 0],
    [50 * u.ms, 50 * u.ms, 50 * u.ms, 100 * u.ms, 100 * u.ms, 100 * u.ms, 50 * u.ms, 50 * u.ms],
    duration=550 * u.ms
)

# Voltage clamp family (-80 to +40 mV in 20 mV steps)
# Note: these would be voltage values in real voltage clamp
vc_family = braintools.input.step(
    [-80, -60, -40, -20, 0, 20, 40],
    [100 * u.ms] * 7,
    duration=700 * u.ms
)

# Brief depolarizing steps
brief_steps = braintools.input.step(
    [0, 15, 0, 20, 0, 25, 0],
    [200 * u.ms, 10 * u.ms, 200 * u.ms, 10 * u.ms, 200 * u.ms, 10 * u.ms, 200 * u.ms],
    duration=830 * u.ms
)

fig, axes = plt.subplots(3, 1, figsize=(12, 12))

plot_signal(iv_steps, 550 * u.ms, 'I-V Step Protocol', axes[0])
plot_signal(vc_family, 700 * u.ms, 'Voltage Clamp Family (-80 to +40 mV)', axes[1])
plot_signal(brief_steps, 830 * u.ms, 'Brief Depolarizing Steps', axes[2])

plt.tight_layout()
plt.show()
../_images/61305b2101ddc6c036bd5636b0d766c0d90d0c52bad8e3749c07db17c8e55af8.png

1.3 Ramp Currents#

# Linear ramp for threshold detection
threshold_ramp = braintools.input.ramp(0, 20, 2000 * u.ms)

# Triangular wave using two ramps
ramp_up = braintools.input.ramp(-5, 10, 500 * u.ms)
ramp_down = braintools.input.ramp(10, -5, 500 * u.ms)
triangular = np.concatenate([ramp_up, ramp_down])

# Fast ramp for I-V relationships
fast_ramp = braintools.input.ramp(-10, 15, 100 * u.ms)

# Slow adaptation protocol
slow_ramp = braintools.input.ramp(1, 8, 5000 * u.ms)

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

plot_signal(threshold_ramp, 2000 * u.ms, 'Threshold Detection Ramp', axes[0])
plot_signal(triangular * u.nA, 1000 * u.ms, 'Triangular Wave (2 ramps)', axes[1])
plot_signal(fast_ramp, 100 * u.ms, 'Fast I-V Ramp', axes[2])
plot_signal(slow_ramp, 5000 * u.ms, 'Slow Adaptation Protocol', axes[3])

plt.tight_layout()
plt.show()
../_images/c6f9cdba7c17cf9ef91503ddb5c49a1928b34a997e697124270f3b8e438cb5e7.png

1.4 Section-based Protocols#

with brainstate.environ.context(dt=dt):
    # Complex experimental protocol
    experiment_protocol = braintools.input.constant([
        (0, 500 * u.ms),  # baseline
        (3, 200 * u.ms),  # small step
        (0, 300 * u.ms),  # return to baseline
        (6, 400 * u.ms),  # medium step
        (0, 200 * u.ms),  # return
        (10, 300 * u.ms),  # large step
        (0, 500 * u.ms)  # recovery
    ])[0]

    # Paired pulse protocol
    paired_pulse = braintools.input.constant([
        (0, 100 * u.ms),  # baseline
        (8, 5 * u.ms),  # first pulse
        (0, 45 * u.ms),  # inter-pulse interval
        (8, 5 * u.ms),  # second pulse
        (0, 100 * u.ms)  # recovery
    ])[0]

    # Hyperpolarizing pre-pulse protocol
    prepulse = braintools.input.constant([
        (0, 200 * u.ms),  # baseline
        (-5, 100 * u.ms),  # hyperpolarizing prepulse
        (15, 50 * u.ms),  # test pulse
        (0, 200 * u.ms)  # recovery
    ])[0]

fig, axes = plt.subplots(3, 1, figsize=(12, 12))

plot_signal(experiment_protocol, 2400 * u.ms, 'Complex Experimental Protocol', axes[0])
plot_signal(paired_pulse, 255 * u.ms, 'Paired Pulse Protocol (50 ms interval)', axes[1])
plot_signal(prepulse, 550 * u.ms, 'Hyperpolarizing Pre-pulse Protocol', axes[2])

plt.tight_layout()
plt.show()
../_images/d96f3d16dcf858098912f9b5c011fc702b1afcc030bff01d30b590968b591233.png

2. Waveform Functions#

Generate periodic and complex waveforms for rhythmic stimulation.

2.1 Sinusoidal Currents#

with brainstate.environ.context(dt=dt):
    # Theta rhythm (8 Hz)
    theta = braintools.input.sinusoidal(3.0, 8 * u.Hz, 1000 * u.ms)

    # Gamma rhythm (40 Hz) 
    gamma = braintools.input.sinusoidal(1.5, 40 * u.Hz, 1000 * u.ms)

    # Low frequency for membrane oscillations
    membrane_osc = braintools.input.sinusoidal(2.0, 2 * u.Hz, 2000 * u.ms)

    # High frequency for channel kinetics
    channel_test = braintools.input.sinusoidal(0.5, 100 * u.Hz, 200 * u.ms)

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

plot_signal(theta, 1000 * u.ms, 'Theta Rhythm: 3 nA, 8 Hz', axes[0])
plot_signal(gamma, 1000 * u.ms, 'Gamma Rhythm: 1.5 nA, 40 Hz', axes[1])
plot_signal(membrane_osc, 2000 * u.ms, 'Membrane Oscillation: 2 nA, 2 Hz', axes[2])
plot_signal(channel_test, 200 * u.ms, 'Channel Kinetics Test: 0.5 nA, 100 Hz', axes[3])

plt.tight_layout()
plt.show()
../_images/ca89d749f86dad9938f24e606a9562b78f37a7ecd64a7cb3a8b35d6fda9b88a4.png

2.2 Square Waves#

with brainstate.environ.context(dt=dt):
    # Standard square wave (50% duty cycle)
    standard_square = braintools.input.square(4.0, 5 * u.Hz, 800 * u.ms)

    # Pulse train (20% duty cycle) 
    pulse_train = braintools.input.square(6.0, 10 * u.Hz, 500 * u.ms, duty_cycle=0.2)

    # Wide pulses (80% duty cycle)
    wide_pulses = braintools.input.square(3.0, 3 * u.Hz, 1000 * u.ms, duty_cycle=0.8)

    # High frequency clock signal
    clock = braintools.input.square(1.0, 50 * u.Hz, 200 * u.ms, duty_cycle=0.5)

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

plot_signal(standard_square, 800 * u.ms, 'Standard Square Wave: 4 nA, 5 Hz', axes[0])
plot_signal(pulse_train, 500 * u.ms, 'Pulse Train: 6 nA, 10 Hz, 20% duty', axes[1])
plot_signal(wide_pulses, 1000 * u.ms, 'Wide Pulses: 3 nA, 3 Hz, 80% duty', axes[2])
plot_signal(clock, 200 * u.ms, 'Clock Signal: 1 nA, 50 Hz', axes[3])

plt.tight_layout()
plt.show()
../_images/f573fad4d7ccab8e3512bdaa69c1c4af3feb1437e7ed17f747768b55f2c1c074.png

2.3 Triangular and Sawtooth Waves#

with brainstate.environ.context(dt=dt):
    # Triangular wave for I-V curves
    triangular_iv = braintools.input.triangular(10.0, 1 * u.Hz, 2000 * u.ms)

    # Fast triangular for membrane tests
    fast_triangular = braintools.input.triangular(5.0, 10 * u.Hz, 500 * u.ms)

    # The two signals below are plotted and titled as sawtooth waves, so they
    # are generated with the sawtooth function rather than triangular.
    # Sawtooth for ramp protocols
    sawtooth_ramp = braintools.input.sawtooth(8.0, 2 * u.Hz, 1500 * u.ms)

    # Fast reset sawtooth
    fast_sawtooth = braintools.input.sawtooth(6.0, 20 * u.Hz, 300 * u.ms)

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

plot_signal(triangular_iv, 2000 * u.ms, 'Triangular I-V: 10 nA, 1 Hz', axes[0])
plot_signal(fast_triangular, 500 * u.ms, 'Fast Triangular: 5 nA, 10 Hz', axes[1])
plot_signal(sawtooth_ramp, 1500 * u.ms, 'Sawtooth Ramp: 8 nA, 2 Hz', axes[2])
plot_signal(fast_sawtooth, 300 * u.ms, 'Fast Reset Sawtooth: 6 nA, 20 Hz', axes[3])

plt.tight_layout()
plt.show()
../_images/114110dedde251b6861db539291751b98de5a9c2a100be8221831661f6991e95.png

2.4 Chirp Signals (Frequency Sweeps)#

with brainstate.environ.context(dt=dt):
    # Linear chirp for impedance measurements
    linear_chirp = braintools.input.chirp(2.0, 1 * u.Hz, 50 * u.Hz, 3000 * u.ms, method='linear')

    # Logarithmic chirp for resonance detection
    log_chirp = braintools.input.chirp(1.5, 0.1 * u.Hz, 100 * u.Hz, 5000 * u.ms, method='logarithmic')

    # Reverse chirp (high to low frequency)
    reverse_chirp = braintools.input.chirp(3.0, 100 * u.Hz, 1 * u.Hz, 2000 * u.ms)

    # Theta-gamma chirp for neural oscillations
    neural_chirp = braintools.input.chirp(1.0, 4 * u.Hz, 80 * u.Hz, 10000 * u.ms, method='logarithmic')

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

plot_signal(linear_chirp, 3000 * u.ms, 'Linear Chirp: 1-50 Hz in 3s', axes[0])
# Show only first 2 seconds of log chirp
plot_signal(log_chirp[:20000], 2000 * u.ms, 'Log Chirp: 0.1-100 Hz (first 2s)', axes[1])
plot_signal(reverse_chirp, 2000 * u.ms, 'Reverse Chirp: 100-1 Hz', axes[2])
# Show only first 2 seconds of neural chirp
plot_signal(neural_chirp[:20000], 2000 * u.ms, 'Neural Chirp: 4-80 Hz (first 2s)', axes[3])

plt.tight_layout()
plt.show()
../_images/f85bf5c6f8e0f3b5bf59608429828cb7c852dd4cf1d2fbd5cec812eb1b87fcc5.png

3. Pulse Functions#

Generate realistic synaptic-like pulses and bursts.

3.1 Spike Trains#

with brainstate.environ.context(dt=dt):
    spikes = braintools.input.spike(
        sp_times=[10, 50, 90] * u.ms,
        sp_lens=1 * u.ms,
        sp_sizes=0.8 * u.nA,
        duration=150 * u.ms,
    )
    burst = braintools.input.burst(
        burst_amp=0.5 * u.nA,
        burst_freq=20 * u.Hz,
        burst_duration=20 * u.ms,
        inter_burst_interval=50 * u.ms,
        n_bursts=3,
        duration=250 * u.ms,
    )

# Two stacked panels: the spike train on top, the burst input below.
fig, axes = plt.subplots(2, 1, figsize=(14, 10))
axes = axes.flatten()

plot_signal(spikes, 150 * u.ms, 'spike input', axes[0])
# The burst is generated with burst_freq=20 Hz, so the title states 20 Hz.
plot_signal(burst, 250 * u.ms, 'Burst Input: 20 Hz', axes[1])

plt.tight_layout()
plt.show()
../_images/c03d852cccd20ae6ac84879a4e3bc61556cee923902137c65e7f9fe2f683a667.png

3.2 Gaussian Pulses#

with brainstate.environ.context(dt=dt):
    # Narrow Gaussian pulse (synaptic-like)
    narrow_pulse = braintools.input.gaussian_pulse(5.0, 100 * u.ms, 2 * u.ms, 250 * u.ms)

    # Wide Gaussian pulse (neuromodulation-like)
    wide_pulse = braintools.input.gaussian_pulse(3.0, 200 * u.ms, 20 * u.ms, 500 * u.ms)

    # Multiple Gaussian pulses
    multi_gaussian = (
        braintools.input.gaussian_pulse(4.0, 100 * u.ms, 5 * u.ms, 400 * u.ms) +
        braintools.input.gaussian_pulse(3.0, 200 * u.ms, 8 * u.ms, 400 * u.ms) +
        braintools.input.gaussian_pulse(5.0, 300 * u.ms, 3 * u.ms, 400 * u.ms)
    )

    # Gaussian envelope for burst
    envelope = braintools.input.gaussian_pulse(1.0, 150 * u.ms, 30 * u.ms, 300 * u.ms)

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

plot_signal(narrow_pulse, 250 * u.ms, 'Narrow Gaussian: σ=2ms, Synaptic-like', axes[0])
plot_signal(wide_pulse, 500 * u.ms, 'Wide Gaussian: σ=20ms, Neuromodulation', axes[1])
plot_signal(multi_gaussian, 400 * u.ms, 'Multiple Gaussian Pulses', axes[2])
plot_signal(envelope, 300 * u.ms, 'Gaussian Envelope: σ=30ms', axes[3])

plt.tight_layout()
plt.show()
../_images/cfc95b042cc1c5be8b29ab6b6bbd1b7d10d540150e079dfff5c5da419a52b73e.png

3.3 Exponential Pulses#

with brainstate.environ.context(dt=dt):
    # EPSC-like single exponential
    epsc = braintools.input.exponential_decay(8.0, 50 * u.ms, 200 * u.ms)

    # IPSC-like single exponential (longer decay)
    ipsc = braintools.input.exponential_decay(6.0, 100 * u.ms, 300 * u.ms)

    # Double exponential (realistic synapse)
    double_exp = braintools.input.double_exponential(10.0, 1 * u.ms, 15 * u.ms, 250 * u.ms)

    # Fast AMPA-like response
    ampa_like = braintools.input.double_exponential(12.0, 0.5 * u.ms, 5 * u.ms, 200 * u.ms)

    # Slow NMDA-like response  
    nmda_like = braintools.input.double_exponential(8.0, 2 * u.ms, 50 * u.ms, 300 * u.ms)

# 3x2 grid: five exponential-type pulses, last panel left blank.
fig, axes = plt.subplots(3, 2, figsize=(14, 12))
axes = axes.flatten()

# Titles quote the time constants actually passed to exponential_decay
# (50 ms for the EPSC-like trace, 100 ms for the IPSC-like trace).
plot_signal(epsc, 200 * u.ms, 'EPSC-like: τ=50ms', axes[0])
plot_signal(ipsc, 300 * u.ms, 'IPSC-like: τ=100ms', axes[1])
plot_signal(double_exp, 250 * u.ms, 'Double Exponential: τ₁=1ms, τ₂=15ms', axes[2])
plot_signal(ampa_like, 200 * u.ms, 'AMPA-like: Fast (0.5/5 ms)', axes[3])
plot_signal(nmda_like, 300 * u.ms, 'NMDA-like: Slow (2/50 ms)', axes[4])

# Leave last subplot empty
axes[5].axis('off')

plt.tight_layout()
plt.show()
../_images/2efaf96067de1e9d787053006cd11e5d6d5bd3399e3ddd3dd8dc0f532680b6de.png

4. Stochastic Processes#

Generate realistic noise and random processes for neural modeling.

4.1 Wiener Processes (Brownian Motion)#

with brainstate.environ.context(dt=dt):
    # Standard Wiener process for channel noise
    channel_noise = braintools.input.wiener_process(sigma=0.5, duration=1000 * u.ms, seed=42)

    # Strong Wiener process for membrane fluctuations
    membrane_noise = braintools.input.wiener_process(sigma=2.0, duration=500 * u.ms, seed=123)

    # Weak Wiener process for baseline fluctuations
    baseline_noise = braintools.input.wiener_process(sigma=0.1, duration=2000 * u.ms, seed=456)

    # Multiple realizations for comparison
    wiener_1 = braintools.input.wiener_process(sigma=1.0, duration=800 * u.ms, seed=10)
    wiener_2 = braintools.input.wiener_process(sigma=1.0, duration=800 * u.ms, seed=20)
    wiener_3 = braintools.input.wiener_process(sigma=1.0, duration=800 * u.ms, seed=30)

fig, axes = plt.subplots(3, 2, figsize=(14, 12))

plot_signal(channel_noise, 1000 * u.ms, 'Channel Noise: σ=0.5 nA', axes[0, 0])
plot_signal(membrane_noise, 500 * u.ms, 'Membrane Noise: σ=2.0 nA', axes[0, 1])
plot_signal(baseline_noise, 2000 * u.ms, 'Baseline Noise: σ=0.1 nA', axes[1, 0])

# Plot multiple realizations on same axes
ax = axes[1, 1]
dt_val = u.get_magnitude(dt)
time_points = np.arange(0, 800, dt_val)
ax.plot(time_points, u.get_magnitude(wiener_1), 'b-', alpha=0.7, label='Realization 1')
ax.plot(time_points, u.get_magnitude(wiener_2), 'r-', alpha=0.7, label='Realization 2')
ax.plot(time_points, u.get_magnitude(wiener_3), 'g-', alpha=0.7, label='Realization 3')
ax.set_xlabel('Time (ms)')
ax.set_ylabel('Current (nA)')
ax.set_title('Multiple Wiener Realizations: σ=1.0 nA')
ax.legend()
ax.grid(True, alpha=0.3)

# Statistics analysis
ax = axes[2, 0]
# Power spectral density
f, psd = scipy_signal.welch(u.get_magnitude(channel_noise), fs=1000 / u.get_magnitude(dt))
ax.loglog(f[1:], psd[1:])
ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('PSD (nA²/Hz)')
ax.set_title('Wiener Process PSD (should be flat)')
ax.grid(True, alpha=0.3)

# Histogram
ax = axes[2, 1]
ax.hist(u.get_magnitude(channel_noise), bins=50, density=True, alpha=0.7)
ax.set_xlabel('Current (nA)')
ax.set_ylabel('Probability Density')
ax.set_title('Wiener Process Distribution')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
../_images/a8cef90b9b010fda71be98a282fa8a7989ac288032c0ff7b20df8763054da437.png

4.2 Ornstein-Uhlenbeck Processes (Colored Noise)#

with brainstate.environ.context(dt=dt):
    # Fast OU process (short correlation time)
    fast_ou = braintools.input.ou_process(0., sigma=1.0, tau=5 * u.ms, duration=500 * u.ms, seed=42)

    # Slow OU process (long correlation time)
    slow_ou = braintools.input.ou_process(0., sigma=1.0, tau=50 * u.ms, duration=500 * u.ms, seed=42)

    # Synaptic background activity
    synaptic_bg = braintools.input.ou_process(1., sigma=0.8, tau=20 * u.ms, duration=1000 * u.ms, seed=123)

    # Strong correlated noise
    strong_ou = braintools.input.ou_process(0., sigma=3.0, tau=100 * u.ms, duration=800 * u.ms, seed=456)

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

plot_signal(fast_ou, 500 * u.ms, 'Fast OU Process: τ=5ms', axes[0])
plot_signal(slow_ou, 500 * u.ms, 'Slow OU Process: τ=50ms', axes[1])
plot_signal(synaptic_bg, 1000 * u.ms, 'Synaptic Background: μ=1nA, τ=20ms', axes[2])
plot_signal(strong_ou, 800 * u.ms, 'Strong OU: σ=3nA, τ=100ms', axes[3])

plt.tight_layout()
plt.show()

# Compare OU processes with different correlation times
duration = 1000 * u.ms
ou_1ms = braintools.input.ou_process(0., sigma=1.0, tau=1 * u.ms, duration=duration, seed=42)
ou_10ms = braintools.input.ou_process(0., sigma=1.0, tau=10 * u.ms, duration=duration, seed=42)
ou_100ms = braintools.input.ou_process(0., sigma=1.0, tau=100 * u.ms, duration=duration, seed=42)

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Time series comparison
dt_val = u.get_magnitude(dt)
time_points = np.arange(0, 1000, dt_val)
ax = axes[0, 0]
ax.plot(time_points, u.get_magnitude(ou_1ms), 'b-', alpha=0.8, label='τ=1ms')
ax.plot(time_points, u.get_magnitude(ou_10ms), 'r-', alpha=0.8, label='τ=10ms')
ax.plot(time_points, u.get_magnitude(ou_100ms), 'g-', alpha=0.8, label='τ=100ms')
ax.set_xlabel('Time (ms)')
ax.set_ylabel('Current (nA)')
ax.set_title('OU Processes with Different τ')
ax.legend()
ax.grid(True, alpha=0.3)

# Autocorrelation comparison
from scipy.signal import correlate


def autocorr(x):
    """Return the one-sided autocorrelation of ``x``, scaled to 1 at lag 0.

    The mean of ``x`` is subtracted before correlating, and only the
    non-negative lags are returned.
    """
    centered = x - np.mean(x)
    both_sides = correlate(centered, centered, mode='full')
    # The zero-lag term sits in the middle of the 'full' output; everything
    # from there onward is the non-negative-lag half.
    zero_lag = both_sides.size // 2
    return both_sides[zero_lag:] / both_sides[zero_lag]


ax = axes[0, 1]
lags = np.arange(0, 200) * dt_val  # First 200 points
ax.plot(lags, autocorr(u.get_magnitude(ou_1ms))[:200], 'b-', label='τ=1ms')
ax.plot(lags, autocorr(u.get_magnitude(ou_10ms))[:200], 'r-', label='τ=10ms')
ax.plot(lags, autocorr(u.get_magnitude(ou_100ms))[:200], 'g-', label='τ=100ms')
ax.set_xlabel('Lag (ms)')
ax.set_ylabel('Autocorrelation')
ax.set_title('Autocorrelation Functions')
ax.legend()
ax.grid(True, alpha=0.3)

# Power spectral densities
ax = axes[1, 0]
fs = 1000 / dt_val
f1, psd1 = scipy_signal.welch(u.get_magnitude(ou_1ms), fs=fs)
f2, psd2 = scipy_signal.welch(u.get_magnitude(ou_10ms), fs=fs)
f3, psd3 = scipy_signal.welch(u.get_magnitude(ou_100ms), fs=fs)

ax.loglog(f1[1:], psd1[1:], 'b-', label='τ=1ms')
ax.loglog(f2[1:], psd2[1:], 'r-', label='τ=10ms')
ax.loglog(f3[1:], psd3[1:], 'g-', label='τ=100ms')
ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('PSD (nA²/Hz)')
ax.set_title('Power Spectral Densities')
ax.legend()
ax.grid(True, alpha=0.3)

# Statistics
axes[1, 1].axis('off')
stats_text = f"""OU Process Statistics:

τ=1ms:   Mean={np.mean(u.get_magnitude(ou_1ms)):.3f} nA
         Std={np.std(u.get_magnitude(ou_1ms)):.3f} nA

τ=10ms:  Mean={np.mean(u.get_magnitude(ou_10ms)):.3f} nA
         Std={np.std(u.get_magnitude(ou_10ms)):.3f} nA

τ=100ms: Mean={np.mean(u.get_magnitude(ou_100ms)):.3f} nA
         Std={np.std(u.get_magnitude(ou_100ms)):.3f} nA
"""
axes[1, 1].text(0.1, 0.5, stats_text, fontsize=12, family='monospace')

plt.tight_layout()
plt.show()
../_images/5d3d18038bc89cfcee6aaf35cc04309cdcb1d92b4d0e7ba5ae541b3559dba2aa.png ../_images/eb5159d831fa55e009c9418345e13b782381249d4e3a939eeafe085d3cc766fc.png

4.3 Poisson Spike Trains#

with brainstate.environ.context(dt=dt):
    # Low rate Poisson spikes
    low_rate = braintools.input.poisson(amplitude=5.0, rate=10 * u.Hz, duration=1000 * u.ms, seed=42)

    # Medium rate Poisson spikes
    medium_rate = braintools.input.poisson(amplitude=3.0, rate=50 * u.Hz, duration=1000 * u.ms, seed=123)

    # High rate Poisson spikes
    high_rate = braintools.input.poisson(amplitude=2.0, rate=200 * u.Hz, duration=1000 * u.ms, seed=456)

    # Variable amplitude Poisson
    var_poisson = braintools.input.poisson(
        amplitude=np.random.normal(4.0, 1.0, 20000),  # Variable amplitudes
        rate=30 * u.Hz,
        duration=2000 * u.ms,
        seed=789
    )

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

plot_signal(low_rate, 1000 * u.ms, 'Low Rate Poisson: 10 Hz', axes[0])
plot_signal(medium_rate, 1000 * u.ms, 'Medium Rate Poisson: 50 Hz', axes[1])
plot_signal(high_rate, 1000 * u.ms, 'High Rate Poisson: 200 Hz', axes[2])
plot_signal(var_poisson, 2000 * u.ms, 'Variable Amplitude Poisson: 30 Hz', axes[3])

plt.tight_layout()
plt.show()

# Analyze Poisson statistics
# Generate multiple realizations
n_trials = 10
rate = 25 * u.Hz
duration = 1000 * u.ms

spike_counts = []
isis = []  # Inter-spike intervals

for i in range(n_trials):
    poisson_train = braintools.input.poisson(amplitude=1.0, rate=rate, duration=duration, seed=i * 100)

    # Count spikes
    spike_times = np.where(u.get_magnitude(poisson_train) > 0.5)[0] * u.get_magnitude(dt)
    spike_counts.append(len(spike_times))

    # Calculate ISIs
    if len(spike_times) > 1:
        trial_isis = np.diff(spike_times)
        isis.extend(trial_isis)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Spike count distribution
axes[0].hist(spike_counts, bins=10, density=True, alpha=0.7)
expected_count = u.get_magnitude(rate * duration)
axes[0].axvline(expected_count, color='r', linestyle='--',
                label=f'Expected: {expected_count:.1f}')
axes[0].axvline(np.mean(spike_counts), color='g', linestyle='-',
                label=f'Observed: {np.mean(spike_counts):.1f}')
axes[0].set_xlabel('Spike Count')
axes[0].set_ylabel('Probability Density')
axes[0].set_title('Spike Count Distribution (Poisson)')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# ISI distribution
axes[1].hist(isis, bins=30, density=True, alpha=0.7)
# Theoretical exponential distribution.
# The ISIs are in ms while the rate is in Hz (events per second), so the rate
# is converted to events per ms before evaluating the density; otherwise the
# overlay decays 1000x too fast and does not match the histogram.
isi_theory = np.linspace(0, max(isis), 100)
rate_val = u.get_magnitude(rate)  # Hz; also used by the summary printout
rate_per_ms = rate_val / 1000.0
exp_theory = rate_per_ms * np.exp(-rate_per_ms * isi_theory)
axes[1].plot(isi_theory, exp_theory, 'r-', linewidth=2,
             label='Theoretical Exponential')
axes[1].set_xlabel('Inter-Spike Interval (ms)')
axes[1].set_ylabel('Probability Density')
axes[1].set_title('ISI Distribution (Exponential)')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Sample raster plot
axes[2].set_xlim(0, 500)  # Show first 500 ms
for i in range(5):  # Show first 5 trials
    poisson_train = braintools.input.poisson(amplitude=1.0, rate=rate, duration=500 * u.ms, seed=i * 100)
    spike_times = np.where(u.get_magnitude(poisson_train) > 0.5)[0] * u.get_magnitude(dt)
    axes[2].scatter(spike_times, [i] * len(spike_times), s=20, c='black')

axes[2].set_xlabel('Time (ms)')
axes[2].set_ylabel('Trial')
axes[2].set_title('Poisson Raster Plot (25 Hz)')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Summary of the multi-trial Poisson analysis.
print("Poisson Statistics (25 Hz, 1000 ms):")
print(f"Expected spikes per trial: {expected_count:.1f}")
print(f"Observed mean: {np.mean(spike_counts):.1f} ± {np.std(spike_counts):.1f}")
# For Poisson counts the variance equals the mean, so the coefficient of
# variation of the counts is 1/sqrt(mean), not 1.
print(f"Coefficient of variation: {np.std(spike_counts) / np.mean(spike_counts):.2f} "
      f"(should be ~{1 / np.sqrt(expected_count):.2f})")
print(f"Mean ISI: {np.mean(isis):.1f} ms (expected: {1000 / rate_val:.1f} ms)")
../_images/3d50d565fa152a58d11e5f3c4c87da709acb87e8477be91ad3d9b5cf7e956f44.png ../_images/9b3c05ccc69b46e47d3fd17eb37e6c9d958b6e9ae7f7bae174d7afe2de904449.png
Poisson Statistics (25 Hz, 1000 ms):
Expected spikes per trial: 25.0
Observed mean: 24.7 ± 5.0
Coefficient of variation: 0.20 (should be ~1.0)
Mean ISI: 39.0 ms (expected: 40.0 ms)

5. Advanced Functional Techniques#

Performance optimization and complex protocol generation.

5.1 Batch Generation#

import time

# Generate multiple sinusoids efficiently
with brainstate.environ.context(dt=dt):
    frequencies = np.arange(1, 21) * u.Hz  # 1-20 Hz

    # Method 1: Loop - one library call per frequency
    # perf_counter is the monotonic, high-resolution clock meant for timing.
    start_time = time.perf_counter()
    sines_loop = []
    for freq in frequencies:
        sine = braintools.input.sinusoidal(1.0, freq, 500 * u.ms)
        sines_loop.append(sine)
    loop_time = time.perf_counter() - start_time

    # Method 2: Vectorized - all frequencies in one broadcast expression
    start_time = time.perf_counter()
    duration = 500 * u.ms
    dt_val = u.get_magnitude(dt)
    duration_val = u.get_magnitude(duration)
    t = np.arange(0, duration_val, dt_val)  # time axis in ms

    freq_vals = np.asarray(u.get_magnitude(frequencies))  # Hz
    # (n_freq, 1) * (n_time,) broadcasts to (n_freq, n_time), so row i holds
    # the sinusoid for frequencies[i]; dividing by 1000 converts ms to s.
    sines_vectorized = np.sin(2 * np.pi * freq_vals[:, None] * t / 1000) * u.nA
    vectorized_time = time.perf_counter() - start_time

print(f"Loop method time: {loop_time:.4f} seconds")
print(f"Vectorized method time: {vectorized_time:.4f} seconds")

# Verify results are similar
diff = np.abs(u.get_magnitude(sines_loop[10]) - u.get_magnitude(sines_vectorized[10]))
print(f"Maximum difference: {np.max(diff):.2e}")
Loop method time: 0.0033 seconds
Vectorized method time: 0.0050 seconds
Maximum difference: 2.98e-08

5.2 Complex Protocol Assembly#

# Build a complex experimental protocol using functional API
with brainstate.environ.context(dt=dt):
    # Protocol: baseline + sine + ramp + spikes + recovery

    # Phase 1: Baseline (500 ms)
    baseline = braintools.input.constant([(0, 500 * u.ms)])[0]

    # Phase 2: Sinusoidal stimulation (1000 ms, 10 Hz)
    sine_stim = braintools.input.sinusoidal(3.0, 10 * u.Hz, 1000 * u.ms)

    # Phase 3: Ramp up (300 ms, 0 to 8 nA)
    ramp_up = braintools.input.ramp(0, 8, 300 * u.ms)

    # Phase 4: Brief spikes on top of constant (400 ms)
    constant_bg = braintools.input.constant([(5, 400 * u.ms)])[0]
    spike_times = [50, 150, 250, 350] * u.ms
    spikes = braintools.input.spike(spike_times, 1. * u.ms, 10.0, 400 * u.ms)
    spikes_on_constant = constant_bg + spikes

    # Phase 5: Recovery ramp (200 ms, 5 to 0 nA)
    ramp_down = braintools.input.ramp(5, 0, 200 * u.ms)

    # Phase 6: Final baseline (300 ms)
    final_baseline = braintools.input.constant([(0, 300 * u.ms)])[0]

    # Assemble complete protocol
    complete_protocol = np.concatenate([
        baseline,
        sine_stim,
        ramp_up,
        spikes_on_constant,
        ramp_down,
        final_baseline
    ])

    total_duration = 500 + 1000 + 300 + 400 + 200 + 300  # 2700 ms

# Plot the complete protocol
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10))

# Full protocol
plot_signal(complete_protocol, total_duration * u.ms, 'Complete Experimental Protocol', ax1)

# Zoomed view of spike phase
spike_start = int((500 + 1000 + 300) / u.get_magnitude(dt))
spike_end = int((500 + 1000 + 300 + 400) / u.get_magnitude(dt))
spike_phase = complete_protocol[spike_start:spike_end]
plot_signal(spike_phase, 400 * u.ms, 'Zoom: Spikes on Constant Background', ax2)

# Add phase annotations
phase_starts = [0, 500, 1500, 1800, 2200, 2400]
# Each phase ends where the next begins; the last one ends at the end of the
# protocol. Without this closing boundary the final 'Recovery' phase would be
# left unshaded and unlabelled.
phase_ends = phase_starts[1:] + [total_duration]
phase_labels = ['Baseline', 'Sine\n10 Hz', 'Ramp\nUp', 'Spikes +\nConstant', 'Ramp\nDown', 'Recovery']

for i, (start, end, label) in enumerate(zip(phase_starts, phase_ends, phase_labels)):
    ax1.axvspan(start, end, alpha=0.1, color=f'C{i}')
    ax1.text(start + (end - start) / 2, 12, label, ha='center', va='bottom',
             bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8))

ax1.set_ylim(-2, 14)

plt.tight_layout()
plt.show()

print(f"Protocol duration: {total_duration} ms")
print(f"Array length: {len(complete_protocol)} points")
print(f"Memory usage: ~{len(complete_protocol) * 8 / 1024:.1f} KB")
../_images/e08f450464c37af63b5cc4479595d184eeb0d901cdffe709c0f6f5e1c801ce3d.png
Protocol duration: 2700 ms
Array length: 27000 points
Memory usage: ~210.9 KB

5.3 Parameter Sweeps#

# Generate parameter sweeps for systematic studies
with brainstate.environ.context(dt=dt):
    # Amplitude sweep for I-V curves
    amplitudes = np.linspace(-2, 10, 13)  # -2 to 10 nA in 13 steps
    iv_sweeps = []
    for amp in amplitudes:
        # Each sweep: 100ms baseline + 200ms step + 100ms recovery
        iv_sweeps.append(
            braintools.input.section([0, amp, 0], [100, 200, 100] * u.ms)
        )

    # Frequency sweep for resonance testing
    frequencies = np.logspace(0, 2, 10) * u.Hz  # 1 to 100 Hz, log spaced
    freq_sweeps = []
    for freq in frequencies:
        freq_sweeps.append(
            braintools.input.sinusoidal(2.0, freq, 500 * u.ms)
        )

    # Noise amplitude sweep
    noise_sigmas = [0.1, 0.5, 1.0, 2.0, 5.0]
    noise_sweeps = []
    for sigma in noise_sigmas:
        # OU process with constant mean
        noise_sweeps.append(
            braintools.input.ou_process(sigma=sigma, tau=10 * u.ms, mean=2.0, duration=800 * u.ms, seed=42)
        )

# Visualize parameter sweeps
fig, axes = plt.subplots(3, 1, figsize=(15, 15))

# I-V sweep waterfall plot
ax = axes[0]
for i, (sweep, amp) in enumerate(zip(iv_sweeps[::2], amplitudes[::2])):
    time_points = np.arange(0, 400, u.get_magnitude(dt))
    offset = i * 3  # Vertical offset
    ax.plot(time_points, u.get_magnitude(sweep) + offset,
            label=f'{amp:.1f} nA' if i % 2 == 0 else '')

ax.set_xlabel('Time (ms)')
ax.set_ylabel('Current + Offset (nA)')
ax.set_title('I-V Sweep (Every 2nd Trace Shown)')
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax.grid(True, alpha=0.3)

# Frequency sweep spectrogram-style
ax = axes[1]
for i, (sweep, freq) in enumerate(zip(freq_sweeps, frequencies)):
    time_points = np.arange(0, 500, u.get_magnitude(dt))
    offset = i * 1.5
    ax.plot(time_points, u.get_magnitude(sweep) + offset,
            label=f'{u.get_magnitude(freq):.1f} Hz' if i % 2 == 0 else '')

ax.set_xlabel('Time (ms)')
ax.set_ylabel('Current + Offset (nA)')
ax.set_title('Frequency Sweep (Log Spaced, 1-100 Hz)')
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax.grid(True, alpha=0.3)

# Noise amplitude sweep
ax = axes[2]
for i, (noise, sigma) in enumerate(zip(noise_sweeps, noise_sigmas)):
    time_points = np.arange(0, 800, u.get_magnitude(dt))
    offset = i * 4
    ax.plot(time_points, u.get_magnitude(noise) + offset, alpha=0.8,
            label=f'σ={sigma} nA')

ax.set_xlabel('Time (ms)')
ax.set_ylabel('Current + Offset (nA)')
ax.set_title('OU Process Noise Amplitude Sweep')
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Generated {len(iv_sweeps)} I-V traces")
print(f"Generated {len(freq_sweeps)} frequency responses")
print(f"Generated {len(noise_sweeps)} noise conditions")
print(f"Frequency range: {u.get_magnitude(frequencies[0]):.1f} - {u.get_magnitude(frequencies[-1]):.1f} Hz")
../_images/72b3fd3a0b7235d3ae2cc263dbd10654b75a3354527a521b04c00af1cccfdb7f.png
Generated 13 I-V traces
Generated 10 frequency responses
Generated 5 noise conditions
Frequency range: 1.0 - 100.0 Hz

6. Real-World Applications#

Examples of experimental protocols you might actually use.

6.1 Patch Clamp Protocols#

# Standard patch clamp protocols using functional API

with brainstate.environ.context(dt=dt):
    # 1. Rheobase determination (find threshold current)
    rheobase_currents = np.arange(0, 10, 0.5)  # 0 to 10 nA in 0.5 nA steps
    rheobase_protocols = []

    for current in rheobase_currents:
        protocol = braintools.input.constant([
            (0, 200 * u.ms),  # baseline
            (current, 500 * u.ms),  # test current
            (0, 200 * u.ms)  # recovery
        ])[0]
        rheobase_protocols.append(protocol)

    # 2. Input resistance measurement
    # Small hyperpolarizing steps to stay in linear range
    test_currents = [-0.1, -0.2, -0.3, -0.4, -0.5]  # Small negative currents
    rin_protocols = []

    for current in test_currents:
        protocol = braintools.input.constant([
            (0, 100 * u.ms),
            (current, 300 * u.ms),
            (0, 100 * u.ms)
        ])[0]
        rin_protocols.append(protocol)

    # 3. Adaptation protocol
    adaptation_protocol = braintools.input.constant([
        (0, 500 * u.ms),  # baseline
        (8, 2000 * u.ms),  # long depolarizing step
        (0, 500 * u.ms)  # recovery
    ])[0]

    # 4. Frequency-current (F-I) relationship
    fi_currents = np.arange(2, 20, 2)  # 2 to 18 nA in 2 nA steps
    fi_protocols = []

    for current in fi_currents:
        protocol = braintools.input.constant([
            (0, 200 * u.ms),
            (current, 1000 * u.ms),  # 1 second for rate measurement
            (0, 200 * u.ms)
        ])[0]
        fi_protocols.append(protocol)

# Plot representative protocols
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Rheobase (show every 4th trace)
ax = axes[0, 0]
for i in range(0, len(rheobase_protocols), 4):
    time_points = np.arange(0, 900, u.get_magnitude(dt))
    offset = i * 0.5
    ax.plot(time_points, u.get_magnitude(rheobase_protocols[i]) + offset,
            label=f'{rheobase_currents[i]:.1f} nA')
ax.set_xlabel('Time (ms)')
ax.set_ylabel('Current + Offset (nA)')
ax.set_title('Rheobase Determination')
ax.legend()
ax.grid(True, alpha=0.3)

# Input resistance
ax = axes[0, 1]
for i, (protocol, current) in enumerate(zip(rin_protocols, test_currents)):
    time_points = np.arange(0, 500, u.get_magnitude(dt))
    offset = i * 0.2
    ax.plot(time_points, u.get_magnitude(protocol) + offset,
            label=f'{current:.1f} nA')
ax.set_xlabel('Time (ms)')
ax.set_ylabel('Current + Offset (nA)')
ax.set_title('Input Resistance Protocol')
ax.legend()
ax.grid(True, alpha=0.3)

# Adaptation
plot_signal(adaptation_protocol, 3000 * u.ms, 'Adaptation Protocol', axes[1, 0])

# F-I relationship (show subset)
ax = axes[1, 1]
for i in range(0, len(fi_protocols), 2):
    time_points = np.arange(0, 1400, u.get_magnitude(dt))
    offset = i * 2
    ax.plot(time_points, u.get_magnitude(fi_protocols[i]) + offset,
            label=f'{fi_currents[i]:.0f} nA')
ax.set_xlabel('Time (ms)')
ax.set_ylabel('Current + Offset (nA)')
ax.set_title('F-I Relationship Protocol')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Generated {len(rheobase_protocols)} rheobase protocols")
print(f"Generated {len(rin_protocols)} input resistance protocols")
print(f"Generated {len(fi_protocols)} F-I protocols")
print(f"Total protocols: {len(rheobase_protocols) + len(rin_protocols) + len(fi_protocols) + 1}")
../_images/27429089e3865a0acbf92839ab5d67cfd041d6d0da62e5166895a34e5f33fa39.png
Generated 20 rheobase protocols
Generated 5 input resistance protocols
Generated 9 F-I protocols
Total protocols: 35

6.2 Dynamic Clamp Simulation#

import functools
# Build a dynamic-clamp style stimulus from four ingredients: two OU
# background currents, theta-locked EPSC bursts, Poisson-triggered EPSCs and
# a square test pulse. All signals are generated inside one `dt` context.
with brainstate.environ.context(dt=dt):
    duration = 2000 * u.ms

    # Plain-number views used for index <-> time conversions. The step is
    # taken from `dt` (the value this context is entered with) instead of a
    # hard-coded constant, so conversions stay correct if `dt` changes.
    dt_ms = u.get_magnitude(dt)
    duration_ms = u.get_magnitude(duration)
    n_steps = int(round(duration_ms / dt_ms))

    # 1. Background synaptic activity (OU processes)
    excitatory_bg = braintools.input.ou_process(
        mean=0.5, sigma=0.8, tau=15 * u.ms, duration=duration, seed=42
    ) * u.nA

    inhibitory_bg = braintools.input.ou_process(
        mean=-0.8, sigma=1.2, tau=25 * u.ms, duration=duration, seed=123
    ) * u.nA

    # 2. Periodic excitatory inputs (theta rhythm)
    theta_times = np.arange(200, 1800, 125)  # 8 Hz theta (in ms)
    gamma_offsets = np.array([0, 10, 20, 30, 40])  # in ms, relative to each theta event
    theta_inputs = []
    for t in theta_times:
        # Each theta event triggers a short burst of gamma-spaced EPSCs
        for gt in t + gamma_offsets:
            if gt < duration_ms:
                epsc = braintools.input.double_exponential(
                    amplitude=3.0,
                    tau_rise=1 * u.ms,
                    tau_decay=8 * u.ms,
                    duration=duration,
                    t_start=gt * u.ms,
                )
                theta_inputs.append(epsc)

    # Sum all theta-locked EPSCs; fall back to a zero trace of matching length
    if theta_inputs:
        theta_total = functools.reduce(lambda a, b: a + b, theta_inputs) * u.nA
    else:
        theta_total = np.zeros(n_steps) * u.nA

    # 3. Random excitatory events (Poisson)
    poisson_spikes = braintools.input.poisson(
        rate=15 * u.Hz,
        duration=duration,
        amplitude=1.0,
        seed=456,
    )

    # Convert spike samples to onset times (sample index * dt, in ms)
    spike_indices = np.where(u.get_magnitude(poisson_spikes) > 0.5)[0]
    spike_times = spike_indices * dt_ms
    random_epscs = []

    for st in spike_times[:20]:  # Limit for performance
        if st < duration_ms:
            epsc = braintools.input.double_exponential(
                amplitude=2.5,
                tau_rise=0.8 * u.ms,
                tau_decay=6 * u.ms,
                duration=duration,
                t_start=st * u.ms,
            )
            random_epscs.append(epsc)

    if random_epscs:
        random_total = functools.reduce(lambda a, b: a + b, random_epscs) * u.nA
    else:
        random_total = np.zeros(n_steps) * u.nA

    # 4. Combine all inputs
    total_current = excitatory_bg + inhibitory_bg + theta_total + random_total

    # 5. Add a brief test pulse: 5 (scaled to nA) from 1000 ms to 1100 ms
    test_pulse = braintools.input.section(
        values=[0, 5, 0],
        durations=[1000 * u.ms, 100 * u.ms, 900 * u.ms]
    ) * u.nA

    final_protocol = total_current + test_pulse

# Plot the dynamic clamp simulation: one row per component, full protocol last
fig, axes = plt.subplots(4, 1, figsize=(15, 16))

# Individual components
plot_signal(excitatory_bg, duration, 'Excitatory Background (OU Process)', axes[0])
plot_signal(inhibitory_bg, duration, 'Inhibitory Background (OU Process)', axes[1])
plot_signal(theta_total + random_total, duration, 'Synaptic Events (Theta + Random)', axes[2])
plot_signal(final_protocol, duration, 'Complete Dynamic Clamp Protocol', axes[3])

# Mark test pulse (1000-1100 ms, matching the section durations used to build it)
axes[3].axvspan(1000, 1100, alpha=0.3, color='red', label='Test Pulse')
axes[3].legend()

plt.tight_layout()
plt.show()

# Statistics, computed on unit-stripped magnitudes
exc_vals = u.get_magnitude(excitatory_bg)
inh_vals = u.get_magnitude(inhibitory_bg)
total_vals = u.get_magnitude(final_protocol)

print("Dynamic Clamp Statistics:")
print(f"Excitatory background: {np.mean(exc_vals):.2f} ± {np.std(exc_vals):.2f} nA")
print(f"Inhibitory background: {np.mean(inh_vals):.2f} ± {np.std(inh_vals):.2f} nA")
print(f"Total current range: {np.min(total_vals):.2f} to {np.max(total_vals):.2f} nA")
print(f"Number of theta events: {len(theta_times)}")
# Count the EPSCs actually injected; `spike_times` holds every detected
# Poisson spike, of which only a capped subset is converted to EPSCs.
print(f"Number of random EPSCs: {len(random_epscs)}")
../_images/3919dba7106e60e08f4f353d4026b0d828e4ed495d84ca20b32948c9a80f1d93.png
Dynamic Clamp Statistics:
Excitatory background: 0.72 ± 2.05 nA
Inhibitory background: 0.22 ± 4.69 nA
Total current range: -14.12 to 18.57 nA
Number of theta events: 13
Number of random EPSCs: 28