Simulating asynchronous or Gamma oscillations using AdEx neurons in brainpy.state#

Colab Open in Kaggle

Implementation of the paper:

  • Susin, Eduarda, and Alain Destexhe. “Integration, coincidence detection and resonance in networks of spiking neurons expressing gamma oscillations and asynchronous states.” PLoS computational biology 17.9 (2021): e1009416. https://pmc.ncbi.nlm.nih.gov/articles/PMC8478196/

import braintools
import brainstate
import brainpy.state
import brainunit as u
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import kaiserord, lfilter, firwin, hilbert

Parameters#

# Table 1: specific neuron model parameters
RS_par = dict(
    Vth=-40 * u.mV, delta=2. * u.mV, tau_ref=5. * u.ms, tau_w=500 * u.ms,
    a=4 * u.nS, b=20 * u.pA, C=150 * u.pF, gL=10 * u.nS, EL=-65 * u.mV, V_reset=-65 * u.mV,
    E_e=0. * u.mV, E_i=-80. * u.mV
)
FS_par = dict(
    Vth=-47.5 * u.mV, delta=0.5 * u.mV, tau_ref=5. * u.ms, tau_w=500 * u.ms,
    a=0 * u.nS, b=0 * u.pA, C=150 * u.pF, gL=10 * u.nS, EL=-65 * u.mV, V_reset=-65 * u.mV,
    E_e=0. * u.mV, E_i=-80. * u.mV
)
Ch_par = dict(
    Vth=-47.5 * u.mV, delta=0.5 * u.mV, tau_ref=1. * u.ms, tau_w=50 * u.ms,
    a=80 * u.nS, b=150 * u.pA, C=150 * u.pF, gL=10 * u.nS, EL=-58 * u.mV, V_reset=-65 * u.mV,
    E_e=0. * u.mV, E_i=-80. * u.mV,
)

Neuron model#

class AdEx(brainpy.state.Neuron):
    def __init__(
        self,
        in_size,
        # neuronal parameters
        Vth=-40 * u.mV, delta=2. * u.mV, tau_ref=5. * u.ms, tau_w=500 * u.ms,
        a=4 * u.nS, b=20 * u.pA, C=150 * u.pF,
        gL=10 * u.nS, EL=-65 * u.mV, V_reset=-65 * u.mV, V_sp_th=-40. * u.mV,
        # synaptic parameters
        tau_e=1.5 * u.ms, tau_i=7.5 * u.ms, E_e=0. * u.mV, E_i=-80. * u.mV,
        # other parameters
        V_initializer=braintools.init.Uniform(-65., -50., unit=u.mV),
        w_initializer=braintools.init.Constant(0. * u.pA),
        ge_initializer=braintools.init.Constant(0. * u.nS),
        gi_initializer=braintools.init.Constant(0. * u.nS),
    ):
        super().__init__(in_size=in_size)

        # neuronal parameters
        self.Vth = Vth
        self.delta = delta
        self.tau_ref = tau_ref
        self.tau_w = tau_w
        self.a = a
        self.b = b
        self.C = C
        self.gL = gL
        self.EL = EL
        self.V_reset = V_reset
        self.V_sp_th = V_sp_th

        # synaptic parameters
        self.tau_e = tau_e
        self.tau_i = tau_i
        self.E_e = E_e
        self.E_i = E_i

        # other parameters
        self.V_initializer = V_initializer
        self.w_initializer = w_initializer
        self.ge_initializer = ge_initializer
        self.gi_initializer = gi_initializer

    def init_state(self):
        # neuronal variables
        self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape))
        self.w = brainstate.HiddenState(braintools.init.param(self.w_initializer, self.varshape))
        self.t_last_spike = brainstate.HiddenState(braintools.init.param(braintools.init.Constant(-1e7 * u.ms), self.varshape))
        self.spike = brainstate.HiddenState(braintools.init.param(lambda s: u.math.zeros(s, bool), self.varshape))

        # synaptic parameters
        self.ge = brainstate.HiddenState(braintools.init.param(self.ge_initializer, self.varshape))
        self.gi = brainstate.HiddenState(braintools.init.param(self.gi_initializer, self.varshape))

    def dV(self, V, w, ge, gi, Iext):
        I = ge * (self.E_e - V) + gi * (self.E_i - V)
        Iext = self.sum_current_inputs(Iext)
        dVdt = (self.gL * self.delta * u.math.exp((V - self.Vth) / self.delta)
                - w + self.gL * (self.EL - V) + I + Iext) / self.C
        return dVdt

    def dw(self, w, V):
        dwdt = (self.a * (V - self.EL) - w) / self.tau_w
        return dwdt

    def update(self, x=0. * u.pA):
        # numerical integration
        ge = brainstate.nn.exp_euler_step(lambda g: -g / self.tau_e, self.ge.value)
        ge = self.sum_delta_inputs(ge, label='ge')
        gi = brainstate.nn.exp_euler_step(lambda g: -g / self.tau_i, self.gi.value)
        gi = self.sum_delta_inputs(gi, label='gi')
        V = brainstate.nn.exp_euler_step(self.dV, self.V.value, self.w.value, self.ge.value, self.gi.value, x)
        V = self.sum_delta_inputs(V, label='V')
        w = brainstate.nn.exp_euler_step(self.dw, self.w.value, self.V.value)
        # spike detection
        t = brainstate.environ.get('t')
        refractory = (t - self.t_last_spike.value) <= self.tau_ref
        V = u.math.where(refractory, self.V.value, V)
        spike = V >= self.V_sp_th
        self.V.value = u.math.where(spike, self.V_reset, V)
        self.w.value = u.math.where(spike, w + self.b, w)
        self.ge.value = ge
        self.gi.value = gi
        self.spike.value = spike
        self.t_last_spike.value = u.math.where(spike, t, self.t_last_spike.value)
        return spike

Utility functions#

def get_inputs(c_low, c_high, t_transition, t_min_plato, t_max_plato, t_gap, t_total):
    t = 0
    dt = brainstate.environ.get_dt()
    num_gap = int(t_gap / dt)
    num_total = int(t_total / dt)
    num_transition = int(t_transition / dt)

    inputs = []
    ramp_up = u.math.linspace(c_low, c_high, num_transition)
    ramp_down = u.math.linspace(c_high, c_low, num_transition)
    plato_base = u.math.ones(num_gap) * c_low
    while t < num_total:
        num_plato = int(brainstate.random.uniform(low=t_min_plato, high=t_max_plato) / dt)
        inputs.extend([plato_base, ramp_up, np.ones(num_plato) * c_high, ramp_down])
        t += (num_gap + num_transition + num_plato + num_transition)
    return u.math.concatenate(inputs)[:num_total]


def signal_phase_by_Hilbert(signal, signal_time, low_cut, high_cut, sampling_space):
    # sampling_space: in seconds (no units)
    # signal_time: in seconds (no units)
    # low_cut: in Hz (no units)(band to filter)
    # high_cut: in Hz (no units)(band to filter)

    signal = signal - np.mean(signal)
    width = 5.0  # The desired width in Hz of the transition from pass to stop
    ripple_db = 60.0  # The desired attenuation in the stop band, in dB.
    sampling_rate = 1. / sampling_space
    Nyquist = sampling_rate / 2.

    num_taps, beta = kaiserord(ripple_db, width / Nyquist)
    if num_taps % 2 == 0:
        num_taps = num_taps + 1  # Numtaps must be odd
    taps = firwin(num_taps, [low_cut / Nyquist, high_cut / Nyquist], window=('kaiser', beta),
                  pass_zero=False, scale=True)
    filtered_signal = lfilter(taps, 1.0, signal)
    delay = 0.5 * (num_taps - 1) / sampling_rate  # To corrected to zero-phase
    delay_index = int(np.floor(delay * sampling_rate))
    filtered_signal = filtered_signal[num_taps - 1:]  # taking out the "corrupted" signal
    # correcting the delay and taking out the "corrupted" signal part
    filtered_time = signal_time[num_taps - 1:] - delay
    cutted_signal = signal[(num_taps - 1 - delay_index): (len(signal) - (num_taps - 1 - delay_index))]

    # --------------------------------------------------------------------------
    # The hilbert transform are very slow when the signal has odd lenght,
    # This part check if the length is odd, and if this is the case it adds a zero in the end
    # of all the vectors related to the filtered Signal:
    if len(filtered_signal) % 2 != 0:  # If the lengh is odd
        tmp1 = filtered_signal.tolist()
        tmp1.append(0)
        tmp2 = filtered_time.tolist()
        tmp2.append((len(filtered_time) + 1) * sampling_space + filtered_time[0])
        tmp3 = cutted_signal.tolist()
        tmp3.append(0)
        filtered_signal = np.asarray(tmp1)
        filtered_time = np.asarray(tmp2)
        cutted_signal = np.asarray(tmp3)
    # --------------------------------------------------------------------------

    ht_filtered_signal = hilbert(filtered_signal)
    envelope = np.abs(ht_filtered_signal)
    phase = np.angle(ht_filtered_signal)  # The phase is between -pi and pi in radians

    return filtered_time, filtered_signal, cutted_signal, envelope, phase


def visualize_simulation_results(
    times, spikes, example_potentials, varied_rates,
    xlim=None, t_lfp_start=None, t_lfp_end=None, filename=None
):
    times = times.to_decimal(u.ms)
    varied_rates = varied_rates.to_decimal(u.Hz)
    example_potentials = {k: v.to_decimal(u.mV) for k, v in example_potentials.items()}

    fig, gs = braintools.visualize.get_figure(7, 1, 1, 12)
    # 1. input firing rate
    ax = fig.add_subplot(gs[0])
    plt.plot(times, varied_rates)
    if xlim is None:
        xlim = (0, times[-1])
    else:
        xlim = (xlim[0].to_decimal(u.ms), xlim[1].to_decimal(u.ms))
    ax.set_xlim(*xlim)
    ax.set_xticks([])
    ax.set_ylabel('External\nRate (Hz)')

    # 2. inhibitory cell rater plot
    ax = fig.add_subplot(gs[1: 3])
    i = 0
    y_ticks = ([], [])
    for key, (sp_matrix, sp_type) in spikes.items():
        iis, sps = np.where(sp_matrix)
        tts = times[iis]
        plt.scatter(tts, sps + i, s=1, label=key)
        y_ticks[0].append(i + sp_matrix.shape[1] / 2)
        y_ticks[1].append(key)
        i += sp_matrix.shape[1]
    ax.set_xlim(*xlim)
    ax.set_xlabel('')
    ax.set_ylabel('Neuron Index')
    ax.set_xticks([])
    ax.set_yticks(*y_ticks)

    # 3. example membrane potential
    ax = fig.add_subplot(gs[3: 5])
    for key, potential in example_potentials.items():
        vs = np.where(spikes[key][0][:, 0], 0, potential)
        plt.plot(times, vs, label=key)
    ax.set_xlim(*xlim)
    ax.set_xticks([])
    ax.set_ylabel('V (mV)')
    ax.legend()

    # 4. LFP
    ax = fig.add_subplot(gs[5:7])
    ax.set_xlim(*xlim)
    t1 = int(t_lfp_start / brainstate.environ.get_dt()) if t_lfp_start is not None else 0
    t2 = int(t_lfp_end / brainstate.environ.get_dt()) if t_lfp_end is not None else len(times)
    times = times[t1: t2]
    lfp = 0
    for sp_matrix, sp_type in spikes.values():
        lfp += braintools.metric.unitary_LFP(times, sp_matrix[t1: t2], sp_type)
    phase_ts, filtered, cutted, envelope, _ = signal_phase_by_Hilbert(
        lfp, times * 1e-3, 30, 50, brainstate.environ.get_dt() / u.second
    )
    plt.plot(phase_ts * 1e3, cutted, color='k', label='Raw LFP')
    plt.plot(phase_ts * 1e3, filtered, color='orange', label="Filtered LFP (30-50 Hz)")
    plt.plot(phase_ts * 1e3, envelope, color='purple', label="Hilbert Envelope")
    plt.legend(loc='best')
    plt.xlabel('Time (ms)')

    # save or show
    if filename:
        plt.savefig(filename, dpi=500)
    plt.show()

Neuron dynamics simulation#

def simulate_adex_neuron(ax_v, ax_I, pars, title):
    with brainstate.environ.context(dt=0.1 * u.ms):
        # neuron
        adex = brainstate.nn.init_all_states(AdEx(1, **pars))

        def run_step(t, x):
            with brainstate.environ.context(t=t):
                adex.update(x)
                return adex.V.value

        # simulation
        duration = 1.5e3 * u.ms
        times = u.math.arange(0. * u.ms, duration, brainstate.environ.get_dt())
        inputs = get_inputs(0. * u.nA, 0.5 * u.nA, t_transition=50. * u.ms,
                            t_min_plato=500 * u.ms, t_max_plato=500 * u.ms,
                            t_gap=500 * u.ms, t_total=duration)
        vs = brainstate.transform.for_loop(run_step, times, inputs, pbar=brainstate.transform.ProgressBar(10))

        # visualization
        ax_v.plot(times.to_decimal(u.ms), vs.to_decimal(u.mV))
        ax_v.set_title(title)
        ax_v.set_ylabel('V (mV)')
        ax_v.set_xlim(0.4 * u.second / u.ms, 1.2 * u.second / u.ms)

        ax_I.plot(times.to_decimal(u.ms), inputs.to_decimal(u.nA))
        ax_I.set_ylabel('I (nA)')
        ax_I.set_xlabel('Time (ms)')
        ax_I.set_xlim(0.4 * u.second / u.ms, 1.2 * u.second / u.ms)


def simulate_adex_neurons():
    fig, gs = braintools.visualize.get_figure(2, 3, 4, 6)
    simulate_adex_neuron(fig.add_subplot(gs[0, 0]), fig.add_subplot(gs[1, 0]), RS_par, 'Regular Spiking')
    simulate_adex_neuron(fig.add_subplot(gs[0, 1]), fig.add_subplot(gs[1, 1]), FS_par, 'Fast Spiking')
    simulate_adex_neuron(fig.add_subplot(gs[0, 2]), fig.add_subplot(gs[1, 2]), Ch_par, 'Chattering')
    plt.show()
simulate_adex_neurons()
../_images/0353524dc07734f839472f52a0d27405b3ba078f5971c83e7dcd27f7d0d904e8.png

PING Network for Generating Gamma Oscillation#

class PINGNet(brainstate.nn.Module):
    def __init__(self):
        super().__init__()

        self.num_exc = 20000
        self.num_inh = 5000
        self.exc_syn_tau = 1. * u.ms
        self.inh_syn_tau = 7.5 * u.ms
        self.exc_syn_weight = 5. * u.nS
        self.inh_syn_weight = 3.34 * u.nS
        self.ext_weight = 4. * u.nS
        self.delay = 1.5 * u.ms

        # neuronal populations
        RS_par_ = RS_par.copy()
        FS_par_ = FS_par.copy()
        RS_par_.update(Vth=-50 * u.mV, V_sp_th=-40 * u.mV)
        FS_par_.update(Vth=-50 * u.mV, V_sp_th=-40 * u.mV)
        self.rs_pop = AdEx(self.num_exc, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, **RS_par_)
        self.fs_pop = AdEx(self.num_inh, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, **FS_par_)
        self.ext_pop = brainpy.state.PoissonEncoder(self.num_exc)

        # Poisson inputs
        self.ext_to_FS = brainpy.state.DeltaProj(
            comm=brainstate.nn.EventFixedProb(self.num_exc, self.num_inh, 0.02, self.ext_weight),
            post=self.fs_pop,
            label='ge'
        )
        self.ext_to_RS = brainpy.state.DeltaProj(
            comm=brainstate.nn.EventFixedProb(self.num_exc, self.num_exc, 0.02, self.ext_weight),
            post=self.rs_pop,
            label='ge'
        )

        # synaptic projections
        self.RS_to_FS = brainpy.state.DeltaProj(
            self.rs_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_exc, self.num_inh, 0.02, self.exc_syn_weight),
            post=self.fs_pop,
            label='ge'
        )
        self.RS_to_RS = brainpy.state.DeltaProj(
            self.rs_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_exc, self.num_exc, 0.02, self.exc_syn_weight),
            post=self.rs_pop,
            label='ge'
        )
        self.FS_to_RS = brainpy.state.DeltaProj(
            self.fs_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_inh, self.num_exc, 0.02, self.inh_syn_weight),
            post=self.rs_pop,
            label='gi'
        )
        self.FS_to_FS = brainpy.state.DeltaProj(
            self.fs_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_inh, self.num_inh, 0.02, self.inh_syn_weight),
            post=self.fs_pop,
            label='gi'
        )

    def update(self, i, t, freq):
        with brainstate.environ.context(t=t, i=i):
            ext_spikes = self.ext_pop(freq)
            self.ext_to_FS(ext_spikes)
            self.ext_to_RS(ext_spikes)

            self.RS_to_RS()
            self.RS_to_FS()

            self.FS_to_RS()
            self.FS_to_FS()

            self.rs_pop()
            self.fs_pop()

            return {
                'FS.V0': self.fs_pop.V.value[0],
                'RS.V0': self.rs_pop.V.value[0],
                'FS.spike': self.fs_pop.spike.value,
                'RS.spike': self.rs_pop.spike.value
            }
def simulate_ping_net():
    with brainstate.environ.context(dt=0.1 * u.ms):
        # inputs
        duration = 6e3 * u.ms
        varied_rates = get_inputs(2. * u.Hz, 3. * u.Hz, 50. * u.ms, 3150 * u.ms, 600 * u.ms, 1e3 * u.ms, duration)

        # network
        net = brainstate.nn.init_all_states(PINGNet())

        # simulation
        times = u.math.arange(0. * u.ms, duration, brainstate.environ.get_dt())
        indices = u.math.arange(0, len(times))
        returns = brainstate.transform.for_loop(net.update, indices, times, varied_rates,
                                              pbar=brainstate.transform.ProgressBar(100))

        # visualization
        visualize_simulation_results(
            times=times,
            spikes={'FS': (returns['FS.spike'], 'inh'),
                    'RS': (returns['RS.spike'], 'exc')},
            example_potentials={'FS': returns['FS.V0'],
                                'RS': returns['RS.V0']},
            varied_rates=varied_rates,
            xlim=(2e3 * u.ms, 3.4e3 * u.ms),
            t_lfp_start=1e3 * u.ms,
            t_lfp_end=5e3 * u.ms
        )
simulate_ping_net()
../_images/e76c464005dd101c1bf9bfaa1e2b253dd35d3f42c548765a61279e8828dbac34.png

ING Network for Generating Gamma Oscillation#

class INGNet(brainstate.nn.Module):
    def __init__(self):
        super().__init__()

        self.num_rs = 20000
        self.num_fs = 4000
        self.num_fs2 = 1000
        self.exc_syn_tau = 5. * u.ms
        self.inh_syn_tau = 5. * u.ms
        self.ext_weight = 0.9 * u.nS
        self.exc_syn_weight = 1. * u.nS
        self.inh_syn_weight = 5. * u.nS
        self.delay = 1.5 * u.ms

        # neuronal populations
        RS_par_ = RS_par.copy()
        FS_par_ = FS_par.copy()
        FS2_par_ = FS_par.copy()
        RS_par_.update(Vth=-50 * u.mV, V_sp_th=-40 * u.mV)
        FS_par_.update(Vth=-50 * u.mV, V_sp_th=-40 * u.mV)
        FS2_par_.update(Vth=-50 * u.mV, V_sp_th=-40 * u.mV)
        self.rs_pop = AdEx(self.num_rs, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, **RS_par_)
        self.fs_pop = AdEx(self.num_fs, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, **FS_par_)
        self.fs2_pop = AdEx(self.num_fs2, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, **FS2_par_)
        self.ext_pop = brainpy.state.PoissonEncoder(self.num_rs)

        # Poisson inputs
        self.ext_to_FS = brainpy.state.DeltaProj(
            comm=brainstate.nn.EventFixedProb(self.num_rs, self.num_fs, 0.02, self.ext_weight),
            post=self.fs_pop,
            label='ge'
        )
        self.ext_to_RS = brainpy.state.DeltaProj(
            comm=brainstate.nn.EventFixedProb(self.num_rs, self.num_rs, 0.02, self.ext_weight),
            post=self.rs_pop,
            label='ge'
        )
        self.ext_to_RS2 = brainpy.state.DeltaProj(
            comm=brainstate.nn.EventFixedProb(self.num_rs, self.num_fs2, 0.02, self.ext_weight),
            post=self.fs2_pop,
            label='ge'
        )

        # synaptic projections
        self.RS_to_FS = brainpy.state.DeltaProj(
            self.rs_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_rs, self.num_fs, 0.02, self.exc_syn_weight),
            post=self.fs_pop,
            label='ge'
        )
        self.RS_to_RS = brainpy.state.DeltaProj(
            self.rs_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_rs, self.num_rs, 0.02, self.exc_syn_weight),
            post=self.rs_pop,
            label='ge'
        )
        self.RS_to_FS2 = brainpy.state.DeltaProj(
            self.rs_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_rs, self.num_fs2, 0.15, self.exc_syn_weight),
            post=self.fs2_pop,
            label='ge'
        )

        self.FS_to_RS = brainpy.state.DeltaProj(
            self.fs_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_fs, self.num_rs, 0.02, self.inh_syn_weight),
            post=self.rs_pop,
            label='gi'
        )
        self.FS_to_FS = brainpy.state.DeltaProj(
            self.fs_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_fs, self.num_fs, 0.02, self.inh_syn_weight),
            post=self.fs_pop,
            label='gi'
        )
        self.FS_to_FS2 = brainpy.state.DeltaProj(
            self.fs_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_fs, self.num_fs2, 0.03, self.inh_syn_weight),
            post=self.fs2_pop,
            label='gi'
        )

        self.FS2_to_RS = brainpy.state.DeltaProj(
            self.fs2_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_fs2, self.num_rs, 0.15, self.exc_syn_weight),
            post=self.rs_pop,
            label='gi'
        )
        self.FS2_to_FS = brainpy.state.DeltaProj(
            self.fs2_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_fs2, self.num_fs, 0.15, self.exc_syn_weight),
            post=self.fs_pop,
            label='gi'
        )
        self.FS2_to_FS2 = brainpy.state.DeltaProj(
            self.fs2_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_fs2, self.num_fs2, 0.6, self.exc_syn_weight),
            post=self.fs2_pop,
            label='gi'
        )

    def update(self, i, t, freq):
        with brainstate.environ.context(t=t, i=i):
            ext_spikes = self.ext_pop(freq)
            self.ext_to_FS(ext_spikes)
            self.ext_to_RS(ext_spikes)
            self.ext_to_RS2(ext_spikes)

            self.RS_to_RS()
            self.RS_to_FS()
            self.RS_to_FS2()

            self.FS_to_RS()
            self.FS_to_FS()
            self.FS_to_FS2()

            self.FS2_to_RS()
            self.FS2_to_FS()
            self.FS2_to_FS2()

            self.rs_pop()
            self.fs_pop()
            self.fs2_pop()

            return {
                'FS.V0': self.fs_pop.V.value[0],
                'FS2.V0': self.fs2_pop.V.value[0],
                'RS.V0': self.rs_pop.V.value[0],
                'FS.spike': self.fs_pop.spike.value,
                'FS2.spike': self.fs2_pop.spike.value,
                'RS.spike': self.rs_pop.spike.value
            }
def simulate_ing_net():
    with brainstate.environ.context(dt=0.1 * u.ms):
        # inputs
        duration = 6e3 * u.ms
        varied_rates = get_inputs(2. * u.Hz, 3. * u.Hz, 50. * u.ms, 350 * u.ms, 600 * u.ms, 1e3 * u.ms, duration)

        # network
        net = brainstate.nn.init_all_states(INGNet())

        # simulation
        times = u.math.arange(0. * u.ms, duration, brainstate.environ.get_dt())
        indices = u.math.arange(0, len(times))
        returns = brainstate.transform.for_loop(net.update, indices, times, varied_rates,
                                              pbar=brainstate.transform.ProgressBar(100))

        # visualization
        visualize_simulation_results(
            times=times,
            spikes={'FS': (returns['FS.spike'], 'inh'),
                    'FS2': (returns['FS2.spike'], 'inh'),
                    'RS': (returns['RS.spike'], 'exc')},
            example_potentials={'FS': returns['FS.V0'],
                                'FS2': returns['FS2.V0'],
                                'RS': returns['RS.V0']},
            varied_rates=varied_rates,
            xlim=(2e3 * u.ms, 3.4e3 * u.ms),
            t_lfp_start=1e3 * u.ms,
            t_lfp_end=5e3 * u.ms
        )
simulate_ing_net()
../_images/9065da3668de670132410088cf2d75ba6613f4c38e67c98f321a6320b5a80e4c.png

CHING Network for Generating Gamma Oscillation#

class CHINGNet(brainstate.nn.Module):
    def __init__(self):
        super().__init__()

        self.num_rs = 19000
        self.num_fs = 5000
        self.num_ch = 1000
        self.exc_syn_tau = 5. * u.ms
        self.inh_syn_tau = 5. * u.ms
        self.exc_syn_weight = 1. * u.nS
        self.inh_syn_weight1 = 7. * u.nS
        self.inh_syn_weight2 = 5. * u.nS
        self.ext_weight1 = 1. * u.nS
        self.ext_weight2 = 0.75 * u.nS
        self.delay = 1.5 * u.ms

        # neuronal populations
        RS_par_ = RS_par.copy()
        FS_par_ = FS_par.copy()
        Ch_par_ = Ch_par.copy()
        RS_par_.update(Vth=-50 * u.mV, V_sp_th=-40 * u.mV)
        FS_par_.update(Vth=-50 * u.mV, V_sp_th=-40 * u.mV)
        Ch_par_.update(Vth=-50 * u.mV, V_sp_th=-40 * u.mV)
        self.rs_pop = AdEx(self.num_rs, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, **RS_par_)
        self.fs_pop = AdEx(self.num_fs, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, **FS_par_)
        self.ch_pop = AdEx(self.num_ch, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, **Ch_par_)
        self.ext_pop = brainpy.state.PoissonEncoder(self.num_rs)

        # Poisson inputs
        self.ext_to_FS = brainpy.state.DeltaProj(
            comm=brainstate.nn.EventFixedProb(self.num_rs, self.num_fs, 0.02, self.ext_weight2),
            post=self.fs_pop,
            label='ge',
        )
        self.ext_to_RS = brainpy.state.DeltaProj(
            comm=brainstate.nn.EventFixedProb(self.num_rs, self.num_rs, 0.02, self.ext_weight1),
            post=self.rs_pop,
            label='ge',
        )
        self.ext_to_CH = brainpy.state.DeltaProj(
            comm=brainstate.nn.EventFixedProb(self.num_rs, self.num_ch, 0.02, self.ext_weight1),
            post=self.ch_pop,
            label='ge',
        )

        # synaptic projections
        self.RS_to_FS = brainpy.state.DeltaProj(
            self.rs_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_rs, self.num_fs, 0.02, self.exc_syn_weight),
            post=self.fs_pop,
            label='ge',
        )
        self.RS_to_RS = brainpy.state.DeltaProj(
            self.rs_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_rs, self.num_rs, 0.02, self.exc_syn_weight),
            post=self.rs_pop,
            label='ge',
        )
        self.RS_to_Ch = brainpy.state.DeltaProj(
            self.rs_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_rs, self.num_ch, 0.02, self.exc_syn_weight),
            post=self.ch_pop,
            label='ge',
        )

        # inhibitory projections
        self.FS_to_RS = brainpy.state.DeltaProj(
            self.fs_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_fs, self.num_rs, 0.02, self.inh_syn_weight1),
            post=self.rs_pop,
            label='gi',
        )
        self.FS_to_FS = brainpy.state.DeltaProj(
            self.fs_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_fs, self.num_fs, 0.02, self.inh_syn_weight2),
            post=self.fs_pop,
            label='gi',
        )
        self.FS_to_Ch = brainpy.state.DeltaProj(
            self.fs_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_fs, self.num_ch, 0.02, self.inh_syn_weight1),
            post=self.ch_pop,
            label='gi',
        )

        # chatter cell projections
        self.Ch_to_RS = brainpy.state.DeltaProj(
            self.ch_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_ch, self.num_rs, 0.02, self.exc_syn_weight),
            post=self.rs_pop,
            label='ge',
        )
        self.Ch_to_FS = brainpy.state.DeltaProj(
            self.ch_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_ch, self.num_fs, 0.02, self.exc_syn_weight),
            post=self.fs_pop,
            label='ge',
        )
        self.Ch_to_Ch = brainpy.state.DeltaProj(
            self.ch_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_ch, self.num_ch, 0.02, self.exc_syn_weight),
            post=self.ch_pop,
            label='ge',
        )

    def update(self, i, t, freq):
        with brainstate.environ.context(i=i, t=t):
            ext_spikes = self.ext_pop(freq)
            self.ext_to_FS(ext_spikes)
            self.ext_to_RS(ext_spikes)
            self.ext_to_CH(ext_spikes)

            self.RS_to_FS()
            self.RS_to_RS()
            self.RS_to_Ch()

            self.FS_to_RS()
            self.FS_to_FS()
            self.FS_to_Ch()

            self.Ch_to_RS()
            self.Ch_to_FS()
            self.Ch_to_Ch()

            self.rs_pop()
            self.fs_pop()
            self.ch_pop()

            return {
                'FS.V0': self.fs_pop.V.value[0],
                'CH.V0': self.ch_pop.V.value[0],
                'RS.V0': self.rs_pop.V.value[0],
                'FS.spike': self.fs_pop.spike.value,
                'CH.spike': self.ch_pop.spike.value,
                'RS.spike': self.rs_pop.spike.value
            }
def simulate_ching_net():
    with brainstate.environ.context(dt=0.1 * u.ms):
        # inputs
        duration = 6e3 * u.ms
        varied_rates = get_inputs(1. * u.Hz, 2. * u.Hz, 50. * u.ms, 150 * u.ms, 600 * u.ms, 1e3 * u.ms, duration)

        # network
        net = brainstate.nn.init_all_states(CHINGNet())

        # simulation
        times = u.math.arange(0. * u.ms, duration, brainstate.environ.get_dt())
        indices = u.math.arange(0, len(times))
        returns = brainstate.transform.for_loop(net.update, indices, times, varied_rates,
                                              pbar=brainstate.transform.ProgressBar(100))

        # visualization
        visualize_simulation_results(
            times=times,
            spikes={'FS': (returns['FS.spike'], 'inh'),
                    'CH': (returns['CH.spike'], 'exc'),
                    'RS': (returns['RS.spike'], 'exc')},
            example_potentials={'FS': returns['FS.V0'],
                                'CH': returns['CH.V0'],
                                'RS': returns['RS.V0']},
            varied_rates=varied_rates,
            xlim=(2e3 * u.ms, 3.4e3 * u.ms),
            t_lfp_start=1e3 * u.ms,
            t_lfp_end=5e3 * u.ms
        )
simulate_ching_net()
../_images/c71e0742bb330072c2efe94da6786b0d4d71894f7a35d01a0b40df377e8e51ce.png

Asynchronous Network State#

class AINet(brainstate.nn.Module):
    def __init__(self):
        super().__init__()

        self.num_exc = 20000
        self.num_inh = 5000
        self.exc_syn_tau = 5. * u.ms
        self.inh_syn_tau = 5. * u.ms
        self.exc_syn_weight = 1. * u.nS
        self.inh_syn_weight = 5. * u.nS
        self.delay = 1.5 * u.ms
        self.ext_weight = 1.0 * u.nS

        # neuronal populations
        RS_par_ = RS_par.copy()
        FS_par_ = FS_par.copy()
        RS_par_.update(Vth=-50 * u.mV, V_sp_th=-40 * u.mV)
        FS_par_.update(Vth=-50 * u.mV, V_sp_th=-40 * u.mV)
        self.fs_pop = AdEx(self.num_inh, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, **FS_par_)
        self.rs_pop = AdEx(self.num_exc, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, **RS_par_)
        self.ext_pop = brainpy.state.PoissonEncoder(self.num_exc)

        # Poisson inputs
        self.ext_to_FS = brainpy.state.DeltaProj(
            comm=brainstate.nn.EventFixedProb(self.num_exc, self.num_inh, 0.02, self.ext_weight),
            post=self.fs_pop,
            label='ge'
        )
        self.ext_to_RS = brainpy.state.DeltaProj(
            comm=brainstate.nn.EventFixedProb(self.num_exc, self.num_exc, 0.02, self.ext_weight),
            post=self.rs_pop,
            label='ge'
        )

        # synaptic projections
        self.RS_to_FS = brainpy.state.DeltaProj(
            self.rs_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_exc, self.num_inh, 0.02, self.exc_syn_weight),
            post=self.fs_pop,
            label='ge'
        )
        self.RS_to_RS = brainpy.state.DeltaProj(
            self.rs_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_exc, self.num_exc, 0.02, self.exc_syn_weight),
            post=self.rs_pop,
            label='ge'
        )
        self.FS_to_FS = brainpy.state.DeltaProj(
            self.fs_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_inh, self.num_inh, 0.02, self.inh_syn_weight),
            post=self.fs_pop,
            label='gi'
        )
        self.FS_to_RS = brainpy.state.DeltaProj(
            self.fs_pop.prefetch('spike').delay.at(self.delay),
            comm=brainstate.nn.EventFixedProb(self.num_inh, self.num_exc, 0.02, self.inh_syn_weight),
            post=self.rs_pop,
            label='gi'
        )

    def update(self, i, t, freq):
        with brainstate.environ.context(t=t, i=i):
            ext_spikes = self.ext_pop(freq)
            self.ext_to_FS(ext_spikes)
            self.ext_to_RS(ext_spikes)
            self.RS_to_RS()
            self.RS_to_FS()
            self.FS_to_FS()
            self.FS_to_RS()
            self.rs_pop()
            self.fs_pop()
            return {
                'FS.V0': self.fs_pop.V.value[0],
                'RS.V0': self.rs_pop.V.value[0],
                'FS.spike': self.fs_pop.spike.value,
                'RS.spike': self.rs_pop.spike.value
            }
def simulate_ai_net():
    with brainstate.environ.context(dt=0.1 * u.ms):
        # inputs
        duration = 2e3 * u.ms
        varied_rates = get_inputs(2. * u.Hz, 2. * u.Hz, 50. * u.ms, 150 * u.ms, 600 * u.ms, 1e3 * u.ms, duration)

        # network
        net = brainstate.nn.init_all_states(AINet())

        # simulation
        times = u.math.arange(0. * u.ms, duration, brainstate.environ.get_dt())
        indices = u.math.arange(0, len(times))
        returns = brainstate.transform.for_loop(net.update, indices, times, varied_rates, pbar=brainstate.transform.ProgressBar(100))

        # # spike raster plot
        # spikes = returns['FS.spike']
        # fig, gs = braintools.visualize.get_figure(1, 1, 4., 5.)
        # fig.add_subplot(gs[0, 0])
        # times2 = times.to_decimal(u.ms)
        # t_indices, n_indices = u.math.where(spikes)
        # plt.scatter(times2[t_indices], n_indices, s=1, c='k')
        # plt.xlabel('Time (ms)')
        # plt.ylabel('Neuron index')
        # plt.title('Spike raster plot')
        # plt.show()

        # visualization
        visualize_simulation_results(
            times=times,
            spikes={'FS': (returns['FS.spike'], 'inh'),
                    'RS': (returns['RS.spike'], 'exc')},
            example_potentials={'FS': returns['FS.V0'],
                                'RS': returns['RS.V0']},
            varied_rates=varied_rates
        )
simulate_ai_net()
../_images/5528263189001e2995fb66138f328d6038f7aaf5a926fdcd0575ee394adf2278.png