Simulating whole-brain Drosophila spiking model with brainstate#

Colab Open in Kaggle

Reproducing the simulation in the paper of:

Overview#

This notebook builds a whole-brain spiking network for Drosophila using brainstate, driven by the FlyWire-derived connectome and stimulus protocols that mirror the paper.

What you will find here:

  • Connectome ingestion (FlyWire 630 mirror) and neuron indexing.

  • A leaky integrate-and-fire (LIF) population model with event-driven synapses, delays, and Poisson drive.

  • Convenience helpers to run experiments, silence subsets, and aggregate activity.

  • Code that reproduces/illustrates figure-style experiments (e.g., GRN/JON activation and MN/DN responses).

Prerequisites and runtime notes:

  • JAX runs on CPU by default; a GPU accelerates long runs but is not required.

  • The first build of sparse connectivity can take time proportional to the size of the edge list. Subsequent runs are faster due to caching.

Navigation:

  • Start with the Connectome and Model sections, then use the Example Trial to sanity-check the setup.

  • The later sections script specific figure-like experiments (sugar, water, bitter GRN; JON subtypes; MN9/DN readouts).

import itertools
import os
import pickle
from pathlib import Path
from typing import Callable, Union, Sequence, Dict, Optional

import brainpy.state
import brainevent
import brainstate
import braintools
import brainunit as u
import jax
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

Connectome dataset#

This section fetches the FlyWire 630 connectome mirror via kagglehub and prepares the neuron index set used throughout.

Data files used by this notebook:

  • 2023_03_23_completeness_630_final.csv: canonical list of FlyWire neuron IDs (these define the row/column ordering of the network).

  • 2023_03_23_connectivity_630_final.parquet: long-format edge list with at least the columns Presynaptic_Index, Postsynaptic_Index, and Excitatory x Connectivity.

Implementation details:

  • We construct a CSR sparse matrix from the edge list, sorted by presynaptic index for efficient iteration.

  • Only the excitatory connectivity term is used to build the weight matrix; the scalar w_syn later rescales these values into mV.

  • The returned data_path points to the locally cached dataset so repeated runs avoid re-downloading.

import kagglehub

# Download latest version of Flywire 630 dataset from Kaggle
data_path = kagglehub.dataset_download("oujago/drosophila-630-data")

Whole-brain Spiking Model#

We model each neuron as a single-compartment LIF unit receiving event-based synaptic input and optional external Poisson drive. State variables:

  • v membrane potential (mV), relaxing to v_rest with time constant tau_m.

  • g a lumped synaptic drive (mV), relaxing with time constant tau_syn and incremented by inputs.

  • t_ref tracks refractory timing; within tau_ref, updates are held constant.

Numerics and spikes:

  • Dynamics are integrated with an exponential Euler step (brainstate.nn.exp_euler_step) for stability on reasonable dt.

  • Spiking uses a smooth surrogate (braintools.surrogate.ReluGrad), which keeps code differentiable; we only simulate (no training) here.

  • On a spike, v is reset to v_reset, g is cleared for that step, and t_ref is stamped.

Network wiring:

  • Connectivity comes from a CSR matrix over the FlyWire index set and is applied via brainstate.nn.SparseLinear.

  • A fixed synaptic delay (default 1.8 ms) is modeled using brainstate.nn.Delay and brainevent event arrays.

  • External drive is applied to selected neurons by brainstate.nn.poisson_input with a per-event weight scaled by w_syn * 250 to balance drive.

  • Sets of neurons can be silenced through a boolean mask; silencing gates their outgoing spikes before they propagate.

Outputs and units:

  • We record spike counts and convert to firing rates (Hz) by dividing by the simulated duration in seconds.

  • All quantities carry physical units via brainunit (e.g., ms, mV, Hz) to help prevent unit mismatches.

class Population(brainpy.state.Neuron):
    """
    A population of neurons with leaky integrate-and-fire dynamics.

    The dynamics of the neurons are given by the following equations::

       dv/dt = (v_0 - v + g) / t_mbr : volt (unless refractory)
       dg/dt = -g / tau               : volt (unless refractory)


    Args:
      size: The number of neurons in the population.
      v_rest: The resting potential of the neurons.
      v_reset: The reset potential of the neurons after a spike.
      v_th: The threshold potential of the neurons for spiking.
      tau_m: The membrane time constant of the neurons.
      tau_syn: The synaptic time constant of the neurons.
      tau_ref: The refractory period of the neurons.
      spk_fun: The spike function of the neurons.
      name: The name of the population.

    """

    def __init__(
        self,
        size: brainstate.typing.Size,
        v_rest: u.Quantity = -52 * u.mV,  # resting potential
        v_reset: u.Quantity = -52 * u.mV,  # reset potential after spike
        v_th: u.Quantity = -45 * u.mV,  # potential threshold for spiking
        tau_m: u.Quantity = 20 * u.ms,  # membrane time constant
        # Jürgensen et al https://doi.org/10.1088/2634-4386/ac3ba6
        tau_syn: u.Quantity = 5 * u.ms,  # synaptic time constant
        # Lazar et al https://doi.org/10.7554/eLife.62362
        tau_ref: u.Quantity = 2.2 * u.ms,  # refractory period
        spk_fun: Callable = braintools.surrogate.ReluGrad(),  # spike function
        name: str = None,
    ):
        super().__init__(size, name=name)

        self.v_rest = v_rest
        self.v_reset = v_reset
        self.v_th = v_th
        self.tau_m = tau_m
        self.tau_syn = tau_syn
        self.tau_ref = u.math.full(self.varshape, tau_ref)
        self.spk_fun = spk_fun

    def init_state(self, batch_size=None):
        self.v = brainstate.ShortTermState(
            braintools.init.param(braintools.init.Constant(self.v_rest), self.varshape, batch_size))
        self.g = brainstate.ShortTermState(
            braintools.init.param(braintools.init.Constant(0.0 * u.mV), self.varshape, batch_size))
        self.t_ref = brainstate.ShortTermState(
            braintools.init.param(braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size))
        self.spike_count = brainstate.ShortTermState(
            braintools.init.param(braintools.init.Constant(0.), self.varshape, batch_size))

    def reset_state(self, batch_size=None):
        self.v.value = u.math.ones(self.varshape) * self.v_rest
        self.g.value = u.math.zeros(self.varshape) * u.mV
        self.t_ref.value = u.math.full(self.varshape, -1e7 * u.ms)

    def get_spike(self, v=None):
        v = self.v.value if v is None else v
        return self.spk_fun((v - self.v_th) / (20. * u.mV))

    def update(self, x):
        t = brainstate.environ.get('t')

        # numerical integration
        dv = lambda v, t, g: (self.v_rest - v + g) / self.tau_m
        dg = lambda g, t: -g / self.tau_syn
        v = brainstate.nn.exp_euler_step(dv, self.v.value, t, self.g.value)
        g = brainstate.nn.exp_euler_step(dg, self.g.value, t)
        g += x  # external input current
        v = self.sum_delta_inputs(v)

        # refractory period
        ref = (t - self.t_ref.value) <= self.tau_ref
        v = u.math.where(ref, self.v.value, v)
        g = u.math.where(ref, self.g.value, g)

        # spikes
        spk = self.get_spike(v)
        self.spike_count.value += spk

        # update states
        self.v.value = spk * (self.v_reset - v) + v
        self.g.value = g - spk * g
        self.t_ref.value = u.math.where(spk, t, self.t_ref.value)
        return spk
class Network(brainstate.nn.Module):
    def __init__(
        self,
        path_neu: Union[Path, str],
        path_syn: Union[Path, str],
        neuron_to_excite: Sequence[Dict] = (),
        neuron_to_inhibit: Sequence[Dict] = (),
        w_syn: u.Quantity = 0.275 * u.mV,
        neuron_params: Dict = None,
    ):
        super().__init__()

        self.path_neu = Path(path_neu)
        self.path_syn = Path(path_syn)

        # neuron ids
        flywire_ids = pd.read_csv(self.path_neu, index_col=0)
        self.n_neuron = len(flywire_ids)
        self._flyid2i = {f: i for i, f in enumerate(flywire_ids.index)}
        self._i2flyid = {i: f for i, f in enumerate(flywire_ids.index)}

        # neurons ids to excite
        self.neuron_to_excite = tuple([
            {
                'indices': (
                    u.math.asarray(exc['indices'])
                    if 'indices' in exc else
                    np.asarray([self._flyid2i[id_] for id_ in exc['ids']])
                ),  # neuron ids to excite
                'rate': exc['rate'],  # Poisson rate
                'w_syn': w_syn.clone() * 250,  # synaptic weight, 250 is the scaling factor for Poisson synapse
            }
            for exc in neuron_to_excite
        ])
        # neuron ids to inhibit
        neuron_to_inhibit = [
            (
                inh['indices']
                if 'indices' in inh else
                [self._flyid2i[i] for i in inh['ids']]
            )
            for inh in neuron_to_inhibit
        ]
        self.neuron_mask = jax.numpy.ones(self.n_neuron, dtype=bool)
        if len(neuron_to_inhibit) > 0:
            for indices in neuron_to_inhibit:
                self.neuron_mask = self.neuron_mask.at[indices].set(False)

        # neuronal and synaptic dynamics
        neuron_params = neuron_params or {}
        self.pop = Population(size=self.n_neuron, **neuron_params)

        # delay for changes in post-synaptic neuron
        # Paul et al 2015 doi: https://doi.org/10.3389/fncel.2015.00029
        self.delay = brainstate.nn.Delay(
            jax.ShapeDtypeStruct(self.pop.varshape, brainstate.environ.dftype()),
            entries={'D': 1.8 * u.ms}
        )

        # synapses: CSR connectivity matrix
        flywire_conns = pd.read_parquet(self.path_syn)
        i_pre = flywire_conns.loc[:, 'Presynaptic_Index'].values
        i_post = flywire_conns.loc[:, 'Postsynaptic_Index'].values
        weight = flywire_conns.loc[:, 'Excitatory x Connectivity'].values
        sort_indices = np.argsort(i_pre)
        i_pre = i_pre[sort_indices]
        i_post = i_post[sort_indices]
        weight = weight[sort_indices]

        values, counts = np.unique(i_pre, return_counts=True)
        indptr = np.zeros(self.n_neuron + 1, dtype=int)
        indptr[values + 1] = counts
        indptr = np.cumsum(indptr)
        indices = i_post

        self.conn = brainstate.nn.SparseLinear(
            brainevent.CSR(
                (weight * w_syn, indices, indptr),
                shape=(self.n_neuron, self.n_neuron)
            ),
            b_init=None,
        )

        # Poisson input
        for exc in self.neuron_to_excite:
            self.pop.tau_ref[exc['indices']] = 0. * u.ms  # set refractory period to 0.5 ms

    def update(self, *args, **kwargs):
        # excite neurons
        for exc in self.neuron_to_excite:
            brainpy.state.poisson_input(
                freq=exc['rate'],
                num_input=1,
                weight=exc['w_syn'],
                target=self.pop.v,
                indices=exc['indices'],
            )

        # delayed spikes
        pre_spk = self.delay.at('D')

        # inhibit neurons
        if len(self.neuron_mask):
            pre_spk = pre_spk * self.neuron_mask

        # compute recurrent connections and update neurons
        spk = self.pop(self.conn(brainevent.EventArray(pre_spk)))

        # update delay spikes
        self.delay.update(spk)
        return spk

    def step_run(self, i, ret_val: str = 'none'):
        with brainstate.environ.context(t=i * brainstate.environ.get_dt(), i=i):
            spk = self.update()
            if ret_val == 'spike':
                return spk
            elif ret_val == 'voltage':
                return self.pop.v.value
            elif ret_val == 'conductance':
                return self.pop.g.value
            elif ret_val == 'spike_count':
                return self.pop.spike_count.value
            elif ret_val == 'none':
                return None
            else:
                raise ValueError('ret_val must be "spike", "voltage", "conductance", or "spike_count"')

Utility Functions#

This section provides helpers used across experiments:

  • run_one_exp(...): build the network with specified excitation/inhibition, run for a duration, and return per-neuron firing rates.

  • flywire_ids(path): load the canonical neuron ID ordering.

  • ndarray_to_dataframe(...): reshape multi-dimensional arrays into tidy Pandas DataFrames for plotting/analysis.

Later sections also define targeted wrappers (e.g., which_*, *_activate_*) that bundle stimulus configurations for the figure-style analyses.

def run_one_exp(
    neurons_to_excite: Sequence[dict],
    neuron_to_inhibit: Sequence = (),
    duration: brainstate.typing.ArrayLike = 1000 * u.ms,
    dt: brainstate.typing.ArrayLike = 0.1 * u.ms,
    pbar: int = 1000,
):
    with brainstate.environ.context(dt=dt):
        indices = np.arange(int(duration / brainstate.environ.get_dt()))

        assert isinstance(neurons_to_excite, (tuple, list)), 'neurons_to_excite must be a list or tuple'
        for exc in neurons_to_excite:
            assert isinstance(exc, dict), 'neurons_to_excite must be a list of dictionaries'
            assert 'ids' in exc or 'indices' in exc, 'neurons_to_excite must have a key "ids" or "indices"'
            assert 'rate' in exc, 'neurons_to_excite must have a key "rate"'
        assert isinstance(neuron_to_inhibit, (tuple, list)), 'neuron_to_inhibit must be a list or tuple'

        net = Network(
            path_neu=os.path.join(data_path, './2023_03_23_completeness_630_final.csv'),
            path_syn=os.path.join(data_path, './2023_03_23_connectivity_630_final.parquet'),
            neuron_to_excite=neurons_to_excite,
            neuron_to_inhibit=neuron_to_inhibit,
        )
        brainstate.nn.init_all_states(net)
        brainstate.transform.for_loop(
            net.step_run,
            indices,
            pbar=pbar
        )
        return net.pop.spike_count.value / duration.to_decimal(u.second)
def flywire_ids(neu_path: str | Path):
    df = pd.read_csv(neu_path, index_col=0)
    return np.asarray(df.index)
def ndarray_to_dataframe(
    arr,
    dim_names: Sequence,
    axis_names: Optional[Sequence[str]] = None,
    column_axis: Optional[int] = None,
    row_axis: Optional[int] = None
):
    """
    Convert a multidimensional NumPy array to a two-dimensional Pandas DataFrame.

    Parameters:
    - arr (numpy.ndarray): Input multi-dimensional array.
    - dim_names (list of list of str): List of names for each dimension.
    - column_axis (int): Specifies which dimension to use as DataFrame columns.
    - row_axis (int): Specifies which dimension to use as DataFrame rows, the remaining
        dimensions will be flattened into columns.

    Returns:
    - pd.DataFrame: Converted two-dimensional DataFrame.
    """

    # Validate input
    if not isinstance(arr, (np.ndarray, jax.Array)):
        raise TypeError("Input must be a NumPy array.")

    if not isinstance(dim_names, list) or len(dim_names) != arr.ndim:
        raise ValueError("dim_names must be a list with the same length "
                         "as the number of dimensions of the array.")
    for i, names in enumerate(dim_names):
        if len(names) != arr.shape[i]:
            raise ValueError(f"Length of dim_names[{i}] must be equal to "
                             f"the size of the corresponding dimension.")

    # Get indices for all dimensions
    axes = list(range(arr.ndim))

    if column_axis is None:
        assert row_axis is not None, 'row_axis must be specified if column_axis is None'

        if not isinstance(row_axis, int):
            if row_axis < 0:
                row_axis = row_axis + arr.ndim
            if not (0 <= row_axis < arr.ndim):
                raise ValueError("row_axis must be a valid dimension index.")

        # Separate row dimension and column dimensions
        row_dim = row_axis
        column_dims = axes[:row_dim] + axes[row_dim + 1:]

        # Get row labels
        row_labels = dim_names[row_dim]

        # Get labels for each column dimension
        column_dim_names = [dim_names[dim] for dim in column_dims]

        # Generate all combinations of column labels
        if column_dim_names:
            column_tuples = list(itertools.product(*column_dim_names))
            # Convert tuples to string labels, e.g., ('Feature_X', 'Measurement_I') -> 'Feature_X_Measurement_I'
            column_labels = ['_'.join(col) for col in column_tuples]
        else:
            # If there is only one dimension as rows, there are no columns
            column_labels = []

        # Rearrange array axes to move row axis to the first position,
        # keeping the order of the remaining dimensions
        reordered_axes = [row_dim] + column_dims
        arr_reordered = np.transpose(arr, axes=reordered_axes)

        # Get new shape
        new_shape = arr_reordered.shape
        num_rows = new_shape[0]
        num_cols = np.prod(new_shape[1:]).astype(int) if len(new_shape) > 1 else 1

        # Flatten all dimensions except the row dimension into columns
        arr_flat = arr_reordered.reshape(num_rows, num_cols)

        # If there are multiple column dimensions, generate combined
        # column labels; otherwise, use a single column label
        if len(column_dims) > 0:
            if len(column_dims) == 1:
                # Only one column dimension, directly use its labels
                column_labels = dim_names[column_dims[0]]
            else:
                # Multiple column dimensions, combine labels
                column_labels = [
                    '_'.join([dim_names[dim][i] for dim, i in zip(column_dims, idx)])
                    for idx in itertools.product(*[range(len(dim_names[dim])) for dim in column_dims])
                ]

        # Create row index
        index = pd.Index(row_labels, name=dim_names[row_dim][0] if len(dim_names[row_dim]) > 0 else f"dim_{row_dim}")

        # Create column index
        if column_labels:
            columns = column_labels
        else:
            # If there are no column labels, create a default single column
            columns = ['Value']
            arr_flat = arr_flat.flatten()

    else:
        assert row_axis is None, 'row_axis must be None if column_axis is specified'

        if not isinstance(column_axis, int):
            if column_axis < 0:
                column_axis = column_axis + arr.ndim
            if not (0 <= column_axis < arr.ndim):
                raise ValueError("column_axis must be a valid dimension index.")

        # Separate column dimension and row dimensions
        column_dim = column_axis
        row_dims = axes[:column_dim] + axes[column_dim + 1:]

        # Get column labels
        columns = dim_names[column_dim]

        # Get labels for each row dimension
        row_dim_names = [dim_names[dim] for dim in row_dims]

        # Generate all combinations of row labels
        row_tuples = list(itertools.product(*row_dim_names))

        # Rearrange array axes to move column axis to the last position
        reordered_axes = row_dims + [column_dim]
        arr_reordered = arr.transpose(reordered_axes)

        # Calculate number of rows and columns
        num_rows = arr_reordered.shape[:-1]
        num_cols = arr_reordered.shape[-1]

        # Flatten all dimensions except the column dimension
        arr_flat = arr_reordered.reshape(-1, num_cols)

        # Create row index, use MultiIndex if there are multiple row dimensions
        if len(row_dims) > 1:
            if axis_names is None:
                names = [f"dim_{dim}" for dim in row_dims]
            else:
                assert len(axis_names) == arr.ndim, 'axis_names must have the same length as the number of dimensions'
                names = [axis_names[dim] for dim in range(arr.ndim)]
            index = pd.MultiIndex.from_tuples(row_tuples, names=names)
        else:
            index = pd.Index(row_tuples, )

    # Create DataFrame
    df = pd.DataFrame(arr_flat, index=index, columns=columns)
    return df

Model settings#

Global parameters used by the examples that follow:

  • n_trial: how many repeated runs to average over for robustness.

  • duration: simulation length per run (with dt typically set in context, e.g., 0.1 ms).

  • all_neuron_ids and flywire_id_to_index: mappings between FlyWire IDs and zero-based indices into the state vectors.

n_trial = 5
duration = 100. * u.ms

all_neuron_ids = flywire_ids(os.path.join(data_path, './2023_03_23_completeness_630_final.csv'))
flywire_id_to_index = {f: i for i, f in enumerate(all_neuron_ids)}

Example Trial#

A minimal end-to-end run that drives a subset of sugar-sensing GRNs (right hemisphere) with a Poisson process and visualizes the resulting spikes across the whole network.

How to adapt:

  • Replace neu_sugar with another neuron list (e.g., water or bitter GRNs, JON subtypes, or arbitrary IDs).

  • Tweak the input rate (in Hz), the duration, or neuron parameters (e.g., tau_m, tau_syn, w_syn) to explore sensitivity.

  • Use ret_val='spike'|'voltage'|'conductance' in step_run to capture different observables.

# define inputs
# list of the labellar sugar-sensing gustatory receptor neurons on right hemisphere
neu_sugar = [
    720575940624963786, 720575940630233916, 720575940637568838, 720575940638202345, 720575940617000768,
    720575940630797113, 720575940632889389, 720575940621754367, 720575940621502051, 720575940640649691,
    720575940639332736, 720575940616885538, 720575940639198653, 720575940620900446, 720575940617937543,
    720575940632425919, 720575940633143833, 720575940612670570, 720575940628853239, 720575940629176663,
    720575940611875570,
]
def example_to_run(duration=1000 * u.ms):
    with brainstate.environ.context(dt=0.1 * u.ms):
        net = Network(
            path_neu=os.path.join(data_path, './2023_03_23_completeness_630_final.csv'),
            path_syn=os.path.join(data_path, './2023_03_23_connectivity_630_final.parquet'),
            neuron_to_excite=[
                # sugar neuron
                {
                    'ids': neu_sugar,
                    'rate': 150 * u.Hz
                },
                # other neuron groups
            ],
        )
        brainstate.nn.init_all_states(net)

        indices = np.arange(int(duration / brainstate.environ.get_dt()))
        times = indices * brainstate.environ.get_dt()
        spks = brainstate.transform.for_loop(lambda i: net.step_run(i, ret_val='spike'), indices, pbar=1000)

    fig, gs = braintools.visualize.get_figure(1, 1, 5, 5)
    ax = fig.add_subplot(gs[0, 0])
    i_times, i_indices = np.where(spks)
    plt.scatter(times[i_times], i_indices, s=1)
    plt.xlabel('Time [ms]')
    plt.ylabel('Neuron index')
    plt.ylim(0, net.n_neuron)
    plt.xlim(0 * u.ms, duration)
    plt.show()
example_to_run()
../_images/b545e04a627b8c2dc5426f1f02c808ace026cf42380416e4927c3841ab38a897.png

Figure 1#

The following cells script experiments that mirror Figure 1 panels in the paper. For each panel we:

  • Define the stimulated neuron set(s) and their Poisson drive rates.

  • Optionally silence subsets (to assess causal contribution).

  • Run the simulation and aggregate whole-brain activity or specific readouts (e.g., MN9).

Note: Exact numerical replication may differ due to model simplifications (single-compartment LIF, weight rescaling) and environment details (JAX device, random seeds). The qualitative patterns and causal trends should match.

fig1_path_res = Path('./figure_1')
os.makedirs(fig1_path_res, exist_ok=True)

Figure 1 D: Whole Brain Activity when Activating Sugar GRN#

# frequencies
freqs = np.arange(10, 201, 20)
def whole_brain_activation():
    @brainstate.transform.vmap2
    def run_with_freq(freq_):
        return run_one_exp(
            neurons_to_excite=[{'ids': neu_sugar, 'rate': freq_ * u.Hz}],
            duration=duration, pbar=10,
        )

    @brainstate.transform.jit
    @brainstate.transform.vmap2(axis_size=n_trial)
    def run_simulations():
        return run_with_freq(freqs)

    if not os.path.exists(fig1_path_res / 'firing_rate.csv'):
        rates = run_simulations()

        df = ndarray_to_dataframe(
            rates,
            [
                [f'Trial {i}' for i in range(n_trial)],
                [f'{freq} Hz' for freq in freqs],
                all_neuron_ids
            ],
            column_axis=0
        )
        df.to_csv(fig1_path_res / 'firing_rate.csv')

        df_rate = df.loc['190 Hz'].mean(axis=1)
        df_rate_std = df.loc['190 Hz'].std(axis=1)

        # save 200 neurons for next simulation
        id_top200 = df_rate.sort_values(ascending=False).index[:200]

        print(df_rate.sort_values(ascending=False))

        with open(fig1_path_res / 'fig_1d_id_top200.pickle', 'wb') as f:
            pickle.dump(id_top200, f)
whole_brain_activation()

Figure 1 E: Which neurons activate MN9?#

# define inputs
# flywire ID for MN9
id_mn9 = 720575940660219265

index_id_mn9 = flywire_id_to_index[id_mn9]

# identify the top 200 neurons that respond due to 200 Hz sugar firing (Figure 1D)
with open(fig1_path_res / 'fig_1d_id_top200.pickle', 'rb') as f:
    id_top200 = pickle.load(f)
# neuron indices
indices = np.asarray([flywire_id_to_index[id_] for id_ in id_top200])
indices = np.expand_dims(indices, axis=1)
def which_neuron_activate_MN9(freqs):
    @brainstate.transform.vmap2(in_axes=(0, None))
    def run_with_freq(freq_, indices_):
        assert indices_.ndim == 1 and indices_.size == 1
        rate = run_one_exp(
            neurons_to_excite=[{'indices': indices_, 'rate': freq_ * u.Hz}],
            duration=duration,
            pbar=10
        )
        return rate[index_id_mn9]

    def run_with_indices(indices_):
        return run_with_freq(freqs, indices_)

    @brainstate.transform.jit
    @brainstate.transform.vmap2(axis_size=n_trial)
    def simulating():
        # due to the limited memory of the device, we choose using "brainstate.transform.map" instead of "brainstate.transform.vmap2"

        # return bst.transform.vmap2(run_with_indices)(bst.random.split_key(len(indices)), indices)
        return brainstate.transform.map(run_with_indices, indices, batch_size=20)

    path = fig1_path_res / f'which_neuron_activate_NM9_firing_rate.csv'
    if not os.path.exists(path):
        rates = simulating()
        df = ndarray_to_dataframe(
            rates,
            [
                [f'Trial {i}' for i in range(n_trial)],
                id_top200,
                [f'{freq} Hz' for freq in freqs],
            ],
            column_axis=0
        )
        df.to_csv(path)
which_neuron_activate_MN9(
    np.asarray([25, 50, 75, 100, 125, 150, 175, 200])
)
df = pd.read_csv(fig1_path_res / f'which_neuron_activate_NM9_firing_rate.csv')

trial_columns = [f'Trial {i}' for i in range(n_trial)]
df['mean'] = df[trial_columns].mean(axis=1)
pivot_df = df.pivot(index='dim_1', columns='dim_2', values='mean')
pivot_df.sort_index(ascending=False, inplace=True)
pivot_df = pivot_df.fillna(0)
pivot_df = pivot_df.reindex([f'{f} Hz' for f in [25, 50, 75, 100, 125, 150, 175, 200]], axis=1)

sns.heatmap(pivot_df)
plt.show()
../_images/2b7d6317a454d39d8ae4a7af251e22923c3988972e0a0a34200f017690dbfca9.png

Figure 1F: Silencing#

# neuron indices
silencing_indices = np.asarray([flywire_id_to_index[id_] for id_ in id_top200])
silencing_indices = np.expand_dims(silencing_indices, axis=1)
def slice_neuron_and_activate_MN9(freqs):
    @brainstate.transform.vmap2(in_axes=(0, None))
    def run_with_freq(freq_, indices_):
        rate = run_one_exp(
            neurons_to_excite=[{'ids': neu_sugar, 'rate': freq_ * u.Hz}],
            neuron_to_inhibit=[{'indices': indices_}],
            duration=duration,
            pbar=10
        )
        return rate[index_id_mn9]

    def run_with_indices(indices_):
        return run_with_freq(freqs, indices_)

    @brainstate.transform.jit
    @brainstate.transform.vmap2(axis_size=n_trial)
    def simulating():
        # Due to the limited memory of the device, we choose using "brainstate.transform.map" instead of "brainstate.transform.vmap2"
        # Decrease "batch_size" when out of memory

        # return bst.transform.vmap2(run_with_indices)(bst.random.split_key(len(silencing_indices)), silencing_indices)
        return brainstate.transform.map(run_with_indices, silencing_indices, batch_size=20)

    path = fig1_path_res / f'sugarR-silencing_which_neuron_firing_rate.csv'
    if not os.path.exists(path):
        rates = simulating()
        df = ndarray_to_dataframe(
            rates,
            [
                [f'Trial {i}' for i in range(n_trial)],
                id_top200,
                [f'{freq} Hz' for freq in freqs],
            ],
            column_axis=0
        )
        df.to_csv(path)
slice_neuron_and_activate_MN9(
    np.asarray([50, 60, 70, 80, 90, 100, 110, 120])
)
df = pd.read_csv(fig1_path_res / f'sugarR-silencing_which_neuron_firing_rate.csv')

trial_columns = [f'Trial {i}' for i in range(n_trial)]
df['mean'] = df[trial_columns].mean(axis=1)
pivot_df = df.pivot(index='dim_1', columns='dim_2', values='mean')
pivot_df = pivot_df.fillna(0)
pivot_df = pivot_df.reindex([f'{f} Hz' for f in [50, 60, 70, 80, 90, 100, 110, 120]], axis=1)

sns.heatmap(pivot_df)
plt.show()
../_images/6dedea9eeb4e532d110bdbc470daab66100efc9ce8767d8191de4604261eb783.png

Figure 2: Which Cell Type Activates MN9?#

# output path
fig2_path_res = Path('./figure_2')
os.makedirs(fig2_path_res, exist_ok=True)
import requests
# define input
ids_two_mn9 = [720575940660219265, 720575940645521262]  # left and right

index_id_mn9 = np.asarray(
    [flywire_id_to_index[ids_two_mn9[0]],
     flywire_id_to_index[ids_two_mn9[1]]]
)

url = 'https://raw.githubusercontent.com/chaobrain/drosophila_whole_brain_snn_simulation/main/sez_neurons.pickle'
types_and_ids = pickle.loads(requests.get(url).content)
def which_cell_type_activates_MN9(cell_ids, type_name, freqs):
    @brainstate.transform.vmap2
    def run_with_freq(freq_):
        rate = run_one_exp(
            neurons_to_excite=[{'ids': cell_ids, 'rate': freq_ * u.Hz}],
            duration=duration,
            pbar=10
        )
        return rate[index_id_mn9]

    @brainstate.transform.jit
    @brainstate.transform.vmap2(axis_size=n_trial)
    def simulating():
        return run_with_freq(freqs)

    path = fig2_path_res / f'cell_type={type_name}.csv'
    if not os.path.exists(path):
        rates = simulating()
        df = ndarray_to_dataframe(
            rates,
            [
                [f'Trial {i}' for i in range(n_trial)],
                [f'{freq} Hz' for freq in freqs],
                ids_two_mn9
            ],
            column_axis=0
        )
        df.to_csv(path)
for cell_type_name, neurons in types_and_ids.items():
    print('Activating MN9 with cell type', cell_type_name)
    which_cell_type_activates_MN9(
        neurons,
        cell_type_name,
        np.asarray([50, 100, 150, 200])
    )
Activating MN9 with cell type aDT6
Activating MN9 with cell type amulet
Activating MN9 with cell type aSG1
Activating MN9 with cell type aSG7
Activating MN9 with cell type aster
Activating MN9 with cell type Asteroid
Activating MN9 with cell type bamboo
Activating MN9 with cell type basket
Activating MN9 with cell type bluebell
Activating MN9 with cell type bobber
Activating MN9 with cell type box
Activating MN9 with cell type bract
Activating MN9 with cell type bridle
Activating MN9 with cell type brontosaraus
Activating MN9 with cell type broom
Activating MN9 with cell type buddy
Activating MN9 with cell type clavicle
Activating MN9 with cell type coy
Activating MN9 with cell type crab
Activating MN9 with cell type cradle
Activating MN9 with cell type damsel
Activating MN9 with cell type diatom
Activating MN9 with cell type doublescoop
Activating MN9 with cell type earmuff
Activating MN9 with cell type eiffel
Activating MN9 with cell type Fdg
Activating MN9 with cell type fluff
Activating MN9 with cell type FMIn
Activating MN9 with cell type foxglove
Activating MN9 with cell type fudog
Activating MN9 with cell type G2N_1
Activating MN9 with cell type gallinule
Activating MN9 with cell type gazebo
Activating MN9 with cell type genie
Activating MN9 with cell type gnome
Activating MN9 with cell type good_dog
Activating MN9 with cell type gumdrop
Activating MN9 with cell type handle
Activating MN9 with cell type handup
Activating MN9 with cell type haystack
Activating MN9 with cell type horntail
Activating MN9 with cell type horseshoe
Activating MN9 with cell type horn
Activating MN9 with cell type hound
Activating MN9 with cell type hyacinth
Activating MN9 with cell type justice
Activating MN9 with cell type kelp
Activating MN9 with cell type kitty
Activating MN9 with cell type knees
Activating MN9 with cell type kokopelli
Activating MN9 with cell type landslide
Activating MN9 with cell type lion
Activating MN9 with cell type mandala
Activating MN9 with cell type marge
Activating MN9 with cell type meteor
Activating MN9 with cell type mime
Activating MN9 with cell type mist
Activating MN9 with cell type moor
Activating MN9 with cell type mothership
Activating MN9 with cell type mute
Activating MN9 with cell type nagini
Activating MN9 with cell type nori
Activating MN9 with cell type oink
Activating MN9 with cell type OinkT
Activating MN9 with cell type oval
Activating MN9 with cell type peacock
Activating MN9 with cell type peahen
Activating MN9 with cell type peafowl
Activating MN9 with cell type phantom
Activating MN9 with cell type planter
Activating MN9 with cell type pleco
Activating MN9 with cell type pringle
Activating MN9 with cell type pSG1
Activating MN9 with cell type puddle
Activating MN9 with cell type rattle
Activating MN9 with cell type rocket
Activating MN9 with cell type rose
Activating MN9 with cell type roundup
Activating MN9 with cell type ruby
Activating MN9 with cell type Salivary_MN13
Activating MN9 with cell type seagull
Activating MN9 with cell type shark
Activating MN9 with cell type sink_sync
Activating MN9 with cell type snake
Activating MN9 with cell type spray
Activating MN9 with cell type spirit
Activating MN9 with cell type sullivan
Activating MN9 with cell type sundrop
Activating MN9 with cell type tentacular
Activating MN9 with cell type TH_VUM
Activating MN9 with cell type tinctoria
Activating MN9 with cell type tophat
Activating MN9 with cell type TPN4
Activating MN9 with cell type trident
Activating MN9 with cell type trogon
Activating MN9 with cell type trumpet
Activating MN9 with cell type tulip
Activating MN9 with cell type tundra
Activating MN9 with cell type turner
Activating MN9 with cell type twirl
Activating MN9 with cell type usnea
Activating MN9 with cell type vice
Activating MN9 with cell type wafflecone
Activating MN9 with cell type weaver
Activating MN9 with cell type web
Activating MN9 with cell type whisker
# dfs = [
#     pd.read_csv(fig2_path_res / f'cell_type={cell_type_name}.csv').mean(axis=1)
#     for cell_type_name in types_and_ids.keys()
# ]
#
# # process results
# for freq in freqs:
#     plt.figure(figsize=(12, 6))
#     x_labels = list(types_and_ids.keys())
#     y_values = [df.loc[ids_two_mn9].mean() for df in dfs]
#
#     plt.bar(x_labels, y_values, color='skyblue')
#     plt.xlabel('Cell Type')
#     plt.ylabel('Firing of MN9')
#     plt.title(f'Relationship Between Cell Type and MN9 Firing (Frequency = {freq} Hz)')
#     plt.xticks(rotation=45, ha='right')
#     plt.tight_layout()
#     plt.show()

Figure 3#

fig3_path_res = Path('./figure_3')
os.makedirs(fig3_path_res, exist_ok=True)

Figure 3 A: Combination of Bitter and Sugar GRN#

# define inputs
# lists of all the labelled sugar-sensing gustatory receptor neurons on right hemisphere
neu_bitter = [
    720575940621778381, 720575940602353632, 720575940617094208, 720575940619197093, 720575940626287336,
    720575940618600651, 720575940627692048, 720575940630195909, 720575940646212996, 720575940610483162,
    720575940645743412, 720575940627578156, 720575940622298631, 720575940621008895, 720575940629146711,
    720575940610259370, 720575940610481370, 720575940619028208, 720575940614281266, 720575940613061118,
    720575940604027168
]

neu_ir94e = [
    720575940614211295, 720575940638218173, 720575940628832256, 720575940626016017, 720575940621375231,
    720575940612920386, 720575940614273292, 720575940628198503, 720575940626241636, 720575940619387814,
    720575940624604560, 720575940615274425, 720575940610683315, 720575940627265265, 720575940624079544,
    720575940629211607, 720575940615089369, 720575940631082124
]

id_mn9 = 720575940660219265  # left
index_mn9 = flywire_id_to_index[id_mn9]
def sugar_bitter_activate_MN9(freqs):
    @brainstate.transform.vmap2(in_axes=(0, None))
    def run_with_sugar_freq(sugar_freq, bitter_freq):
        rate = run_one_exp(
            neurons_to_excite=[{'ids': neu_sugar, 'rate': sugar_freq * u.Hz},
                               {'ids': neu_bitter, 'rate': bitter_freq * u.Hz}],
            duration=duration,
            pbar=10,
        )
        return rate[index_mn9]

    @brainstate.transform.vmap2
    def run_with_bitter_freq(bitter_freq):
        return run_with_sugar_freq(freqs, bitter_freq)

    @brainstate.transform.jit
    @brainstate.transform.vmap2(axis_size=n_trial)
    def simulating():
        return run_with_bitter_freq(freqs)

    path = fig3_path_res / f'sugar_bitter_activate_MN9.csv'
    if not os.path.exists(path):
        rates = simulating()
        df = ndarray_to_dataframe(
            rates,
            [
                [f'Trial {i}' for i in range(n_trial)],
                [f'bitter {freq} Hz' for freq in freqs],
                [f'sugar {freq} Hz' for freq in freqs],
            ],
            column_axis=0
        )
        df.to_csv(path)
sugar_bitter_activate_MN9(
    np.asarray([0, 20, 40, 60, 80, 100, 120, 140, 160, 180, 200])
)
df = pd.read_csv(fig3_path_res / f'sugar_bitter_activate_MN9.csv')

trial_columns = [f'Trial {i}' for i in range(n_trial)]
df['mean'] = df[trial_columns].mean(axis=1)
pivot_df = df.pivot(index='dim_1', columns='dim_2', values='mean')
pivot_df = pivot_df.fillna(0)
pivot_df = pivot_df.reindex([f'sugar {f} Hz' for f in [0, 20, 40, 60, 80, 100, 120, 140, 160, 180, 200]], axis=1)
pivot_df = pivot_df.reindex([f'bitter {f} Hz' for f in [0, 20, 40, 60, 80, 100, 120, 140, 160, 180, 200]], axis=0)

sns.heatmap(pivot_df)
plt.show()
../_images/5d009693d542f05f775385a17a9006df2366c0e8da77446e3aa7a66494c712ae.png

Figure 3 B: Combination of Sugar and IR94e GRN#

def sugar_ir94e_activates_MN9(freqs):
    @brainstate.transform.vmap2(in_axes=(0, None))
    def run_with_sugar_freq(sugar_freq, ir94e_freq):
        rate = run_one_exp(
            neurons_to_excite=[{'ids': neu_sugar, 'rate': sugar_freq * u.Hz},
                               {'ids': neu_ir94e, 'rate': ir94e_freq * u.Hz}],
            duration=duration,
            pbar=10,
        )
        return rate[index_mn9]

    @brainstate.transform.vmap2
    def run_with_ir94e_freq(ir94e_freq):
        return run_with_sugar_freq(freqs, ir94e_freq)

    @brainstate.transform.jit
    @brainstate.transform.vmap2(axis_size=n_trial)
    def simulating():
        return run_with_ir94e_freq(freqs)

    path = fig3_path_res / f'sugar_ir94e_activates_MN9.csv'
    if not os.path.exists(path):
        rates = simulating()
        df = ndarray_to_dataframe(
            rates,
            [
                [f'Trial {i}' for i in range(n_trial)],
                [f'ir94e {freq} Hz' for freq in freqs],
                [f'sugar {freq} Hz' for freq in freqs],
            ],
            column_axis=0
        )
        df.to_csv(path)
sugar_ir94e_activates_MN9(
    np.asarray([0, 20, 40, 60, 80, 100, 120, 140, 160, 180, 200])
)
df = pd.read_csv(fig3_path_res / f'sugar_ir94e_activates_MN9.csv')

trial_columns = [f'Trial {i}' for i in range(n_trial)]
df['mean'] = df[trial_columns].mean(axis=1)
pivot_df = df.pivot(index='dim_1', columns='dim_2', values='mean')
pivot_df = pivot_df.fillna(0)
pivot_df = pivot_df.reindex([f'sugar {f} Hz' for f in [0, 20, 40, 60, 80, 100, 120, 140, 160, 180, 200]], axis=1)
pivot_df = pivot_df.reindex([f'ir94e {f} Hz' for f in [0, 20, 40, 60, 80, 100, 120, 140, 160, 180, 200]], axis=0)

sns.heatmap(pivot_df)
plt.show()
../_images/73a082d92debd608b47bbcf9fb605d7704b1bfcaeba5a60dd33f014f89d24166.png

Figure 4#

fig4_path_res = Path('./figure_4')
os.makedirs(fig4_path_res, exist_ok=True)

Figure 4a: Whole Brain Activity when Activating Water GRN#

# define inputs
# list of the labelled water-sensing gustatory receptor neurons on right hemisphere
neu_water = [
    720575940612950568, 720575940631898285, 720575940606002609, 720575940612579053, 720575940622902535,
    720575940616177458, 720575940660292225, 720575940622486922, 720575940613786774, 720575940629852866,
    720575940625861168, 720575940613996959, 720575940617857694, 720575940644965399, 720575940625203504,
    720575940630553415, 720575940635172191, 720575940634796536
]
def water_GRN_whole_brain_activation(freqs):
    @brainstate.transform.vmap2
    def run_with_freq(freq_):
        return run_one_exp(
            neurons_to_excite=[
                {
                    'ids': neu_water,
                    'rate': freq_ * u.Hz
                }
            ],
            duration=duration,
            pbar=10,
        )

    @brainstate.transform.jit
    @brainstate.transform.vmap2(axis_size=n_trial)
    def run_simulations():
        return run_with_freq(freqs)

    path = fig4_path_res / 'water_GRN_whole_brain_firing_rate.csv'
    if not os.path.exists(path):
        rates = run_simulations()

        df = ndarray_to_dataframe(
            rates,
            [
                [f'Trial {i}' for i in range(n_trial)],
                [f'{freq} Hz' for freq in freqs],
                all_neuron_ids
            ],
            column_axis=0
        )
        df.to_csv(path)

        df_rate = df.loc['260 Hz'].mean(axis=1)
        df_rate_std = df.loc['260 Hz'].std(axis=1)

        # save 200 neurons for next simulation
        id_top200 = df_rate.sort_values(ascending=False).index[:200]

        print(df_rate.sort_values(ascending=False))

        with open(fig4_path_res / 'id_top200_water.pickle', 'wb') as f:
            pickle.dump(id_top200, f)
water_GRN_whole_brain_activation(
    np.asarray([20, 40, 60, 80, 100, 120, 140, 160, 180, 200, 220, 240, 260])
)

Figure 4b: Which Water GRN activates MN9?#

# define inputs
# flywire ID for MN9
id_mn9 = 720575940660219265

index_id_mn9 = flywire_id_to_index[id_mn9]

# identify the top 200 neurons that respond due to 200 Hz water firing (Figure 1D)
with open(fig4_path_res / 'id_top200_water.pickle', 'rb') as f:
    id_top200_water = pickle.load(f)

# neuron indices
water_top200_indices = np.asarray([flywire_id_to_index[id_] for id_ in id_top200_water])
water_top200_indices = np.expand_dims(water_top200_indices, axis=1)
def which_water_neuron_activate_MN9(freqs):
    @brainstate.transform.vmap2(in_axes=(0, None))
    def run_with_freq(freq_, indices_):
        assert indices_.ndim == 1 and indices_.size == 1
        rate = run_one_exp(
            neurons_to_excite=[{'indices': indices_, 'rate': freq_ * u.Hz}],
            duration=duration,
            pbar=10
        )
        return rate[index_id_mn9]

    def run_with_indices(indices_):
        return run_with_freq(freqs, indices_)

    @brainstate.transform.jit
    @brainstate.transform.vmap2(axis_size=n_trial)
    def simulating():
        # due to the limited memory of the device, we choose using "brainstate.transform.map" instead of "brainstate.transform.vmap2"

        # return bst.transform.vmap2(run_with_indices)(bst.random.split_key(len(water_top200_indices)), water_top200_indices)
        return brainstate.transform.map(run_with_indices,
                                        water_top200_indices,
                                        batch_size=10)

    path = fig4_path_res / f'which_water_neuron_activate_NM9_firing_rate.csv'
    if not os.path.exists(path):
        rates = simulating()
        df = ndarray_to_dataframe(
            rates,
            [
                [f'Trial {i}' for i in range(n_trial)],
                id_top200_water,
                [f'{freq} Hz' for freq in freqs],
            ],
            column_axis=0
        )
        df.to_csv(path)
which_water_neuron_activate_MN9(
    np.asarray([25, 50, 75, 100, 125, 150, 175, 200])
)
df = pd.read_csv(fig4_path_res / f'which_water_neuron_activate_NM9_firing_rate.csv')

trial_columns = [f'Trial {i}' for i in range(n_trial)]
df['mean'] = df[trial_columns].mean(axis=1)
pivot_df = df.pivot(index='dim_1', columns='dim_2', values='mean')
pivot_df = pivot_df.fillna(0)

sns.heatmap(pivot_df)
plt.title('Which Water GRN activates NM9?')
plt.show()
../_images/5b84feae0276b936d8bd838d704fe72d78417c2cfe82a21a4c7d6a02ada62bf3.png

Figure 4c: Silencing Most Active Neurons while Activating Water GRN#

# neuron indices
water_silencing_indices = np.asarray([flywire_id_to_index[id_] for id_ in id_top200])
water_silencing_indices = np.expand_dims(water_silencing_indices, axis=1)
def slice_water_neuron_and_activate_MN9(freqs):
    @brainstate.transform.vmap2(in_axes=(0, None))
    def run_with_freq(freq_, indices_):
        rate = run_one_exp(
            neurons_to_excite=[{'ids': neu_water, 'rate': freq_ * u.Hz}],
            neuron_to_inhibit=[{'indices': indices_}],
            duration=duration,
            pbar=10
        )
        return rate[index_id_mn9]

    def run_with_indices(indices_):
        return run_with_freq(freqs, indices_)

    @brainstate.transform.jit
    @brainstate.transform.vmap2(axis_size=n_trial)
    def simulating():
        # due to the limited memory of the device, we choose using "brainstate.transform.map" instead of "brainstate.transform.vmap2"

        # return bst.transform.vmap2(run_with_indices)(bst.random.split_key(len(water_silencing_indices)), water_silencing_indices)
        return brainstate.transform.map(run_with_indices, water_silencing_indices, batch_size=20)

    path = fig4_path_res / f'water_neuron_silencing_firing_rate.csv'
    if not os.path.exists(path):
        rates = simulating()
        df = ndarray_to_dataframe(
            rates,
            [
                [f'Trial {i}' for i in range(n_trial)],
                id_top200,
                [f'{freq} Hz' for freq in freqs],
            ],
            column_axis=0
        )
        df.to_csv(path)
slice_water_neuron_and_activate_MN9(
    np.asarray([160, 170, 180, 190, 200, 210, 220])
)
df = pd.read_csv(fig4_path_res / f'water_neuron_silencing_firing_rate.csv')

trial_columns = [f'Trial {i}' for i in range(n_trial)]
df['mean'] = df[trial_columns].mean(axis=1)
pivot_df = df.pivot(index='dim_1', columns='dim_2', values='mean')
pivot_df = pivot_df.fillna(0)

sns.heatmap(pivot_df)
plt.show()
../_images/024a032211c8626f2b35255b2b5bae10865cda2d55ee1e898ba004e770476966.png

Figure 4e: Combination of sugar and water GRN activity#

sugar_freqs = np.asarray([0, 20, 40, 60, 80, 100, 120, 140, 160, 180, 200])
water_freqs = np.asarray([0, 20, 40, 60, 80, 100, 120, 140, 160, 180, 200, 220, 240, 260])
def sugar_water_activates_MN9():
    @brainstate.transform.vmap2(in_axes=(0, None))
    def run_with_sugar_freq(sugar_freq, water_freq):
        rate = run_one_exp(
            neurons_to_excite=[{'ids': neu_sugar, 'rate': sugar_freq * u.Hz},
                               {'ids': neu_water, 'rate': water_freq * u.Hz}],
            duration=duration,
            pbar=10,
        )
        return rate[index_mn9]

    @brainstate.transform.vmap2
    def run_with_water_freq(water_freq):
        return run_with_sugar_freq(sugar_freqs, water_freq)

    @brainstate.transform.jit
    @brainstate.transform.vmap2(axis_size=n_trial)
    def simulating():
        return run_with_water_freq(water_freqs)

    path = fig4_path_res / f'sugar_water_activates_MN9.csv'
    if not os.path.exists(path):
        rates = simulating()
        df = ndarray_to_dataframe(
            rates,
            [
                [f'Trial {i}' for i in range(n_trial)],
                [f'water {freq} Hz' for freq in water_freqs],
                [f'sugar {freq} Hz' for freq in sugar_freqs],
            ],
            column_axis=0
        )
        df.to_csv(path)
sugar_water_activates_MN9()
df = pd.read_csv(fig4_path_res / f'sugar_water_activates_MN9.csv')
# sns.heatmap(df)

trial_columns = [f'Trial {i}' for i in range(n_trial)]
df['mean'] = df[trial_columns].mean(axis=1)
pivot_df = df.pivot(index='dim_1', columns='dim_2', values='mean')
pivot_df = pivot_df.fillna(0)

sns.heatmap(pivot_df)
plt.show()
../_images/2479958658108c0869c8350e34dc1456cb63057adf6f08a2b8d2621eb7a83b7a.png

Figure 5#

fig5_path_res = Path('./figure_5/data')
os.makedirs(fig5_path_res, exist_ok=True)

Figure 5b: Whole Brain Activity when Activating JON#

# define inputs
# list of the  JONs
neu_JON_CE = [
    720575940619341105, 720575940630122015, 720575940611061526, 720575940615848788, 720575940628444667,
    720575940627941431, 720575940632449619, 720575940650244342,
    720575940631866508, 720575940638681845, 720575940628978450, 720575940609522461, 720575940621442224,
    720575940602506208, 720575940629022149, 720575940627109991,
    720575940630020111, 720575940615986459, 720575940618684481, 720575940620382889, 720575940630080071,
    720575940626565455, 720575940630319671, 720575940602720940,
    720575940630564179, 720575940637632419, 720575940615809349, 720575940626042149, 720575940637054835,
    720575940602132509, 720575940614188149, 720575940616951124,
    720575940628101126, 720575940629055721, 720575940616589878, 720575940622449388, 720575940614427195,
    720575940625797617, 720575940638664437, 720575940618467195,
    720575940621729757, 720575940613971485, 720575940627585688, 720575940629650997, 720575940630059847,
    720575940608742409, 720575940614351477, 720575940633153375,
    720575940622937528, 720575940604753437, 720575940611783464, 720575940618599872, 720575940609541917,
    720575940637410869, 720575940630070343, 720575940621397417,
    720575940614035485, 720575940610018266, 720575940626307902, 720575940634634606, 720575940614060829,
    720575940624799290, 720575940641921421, 720575940623298559,
    720575940625559358, 720575940629138959, 720575940621625597, 720575940625962568, 720575940632767383,
    720575940624915230]

neu_JON_F = [
    720575940606239243, 720575940626956777, 720575940604973746, 720575940622222856, 720575940642517284,
    720575940629719404, 720575940616613022, 720575940604299454,
    720575940615473186, 720575940622217992, 720575940606800341, 720575940629267498, 720575940637366335,
    720575940624224408, 720575940609543197, 720575940633364179,
    720575940629502009, 720575940606431189, 720575940625733960, 720575940638529525, 720575940617524053,
    720575940628935564, 720575940624308355, 720575940631170346,
    720575940627704375, 720575940625885512, 720575940614929245, 720575940647493241, 720575940618888368,
    720575940625087546, 720575940606657493, 720575940617273560,
    720575940640591861, 720575940639410035, 720575940621532413, 720575940627523584, 720575940621521917,
    720575940621097398, 720575940625915338, 720575940606222428,
    720575940627868471, 720575940622179497, 720575940608297774, 720575940614026269, 720575940613012959,
    720575940628100614, 720575940606611401, 720575940628649465,
    720575940610008217, 720575940623791152, 720575940625571240, 720575940634923621, 720575940609530653,
    720575940635968745, 720575940625703434, 720575940613105311,
    720575940629386819, 720575940623077389, 720575940625763015, 720575940628359017
]

neu_JON_D_m = [
    720575940630834171, 720575940622892988, 720575940621289537, 720575940641395163, 720575940616064546,
    720575940628978409, 720575940652566177, 720575940627493096,
    720575940619085397, 720575940635545310, 720575940645728803, 720575940629141775, 720575940626557995,
    720575940631098338, 720575940639904475, 720575940635067034]

neu_JON_all = neu_JON_CE + neu_JON_F + neu_JON_D_m

# frequencies
freqs = np.asarray([20, 40, 60, 80, 100, 120, 140, 160, 180, 200, 220])
def JON_GRN_whole_brain_activation():
    @brainstate.transform.vmap2
    def run_with_freq(freq_):
        return run_one_exp(
            neurons_to_excite=[
                {
                    'ids': neu_JON_all,
                    'rate': freq_ * u.Hz
                }
            ],
            duration=duration,
            pbar=10,
        )

    @brainstate.transform.jit
    @brainstate.transform.vmap2(axis_size=n_trial)
    def run_simulations():
        return run_with_freq(freqs)

    path = fig5_path_res / 'JON_GRN_whole_brain_firing_rate.csv'
    if not os.path.exists(path):
        rates = run_simulations()

        df = ndarray_to_dataframe(
            rates,
            [
                [f'Trial {i}' for i in range(n_trial)],
                [f'{freq} Hz' for freq in freqs],
                all_neuron_ids
            ],
            column_axis=0
        )
        df.to_csv(path)

        df_rate = df.loc['220 Hz'].mean(axis=1)
        df_rate_std = df.loc['220 Hz'].std(axis=1)

        # save 200 neurons for next simulation
        id_top300 = df_rate.sort_values(ascending=False).index[:300]

        print(df_rate.sort_values(ascending=False))

        with open(fig5_path_res / 'id_top300_water.pickle', 'wb') as f:
            pickle.dump(id_top300, f)
JON_GRN_whole_brain_activation()

Figure 5c: Which JON activates DN?#

import jax.numpy as jnp
# flywire ID for MN9
id_DN1_1 = 720575940616185531
id_DN2_l = 720575940629806974

# index_DN = [flywire_id_to_index[id_DN1_1], flywire_id_to_index[id_DN2_l]]
index_DN = [flywire_id_to_index[id_DN1_1], flywire_id_to_index[id_DN2_l]]

# identify the top 200 neurons that respond due to 220 Hz JON firing (Figure 5A)
with open(fig5_path_res / 'id_top300_water.pickle', 'rb') as f:
    id_top300_JON = pickle.load(f)
# frequencies
freqs = np.asarray([25, 50, 75, 100, 125, 150, 175, 200])

# neuron indices
JON_top300_indices = np.asarray([flywire_id_to_index[id_] for id_ in id_top300_JON])
JON_top300_indices = np.expand_dims(JON_top300_indices, axis=1)
def which_JON_neuron_activate_DN():
    @brainstate.transform.vmap2(in_axes=(0, None))
    def run_with_freq(freq_, indices_):
        assert indices_.ndim == 1 and indices_.size == 1
        rate = run_one_exp(
            neurons_to_excite=[{'indices': indices_, 'rate': freq_ * u.Hz}],
            duration=duration,
            pbar=10
        )
        return rate[index_DN]

    def run_with_indices(indices_):
        return run_with_freq(freqs, indices_)

    @brainstate.transform.jit
    @brainstate.transform.vmap2(axis_size=n_trial)
    def simulating():
        # due to the limited memory of the device, we choose using "brainstate.transform.map" instead of "brainstate.transform.vmap2"

        # return bst.transform.vmap2(run_with_indices)(bst.random.split_key(len(JON_top300_indices)), JON_top300_indices)
        return brainstate.transform.map(run_with_indices,
                                        JON_top300_indices,
                                        batch_size=10)

    path = fig5_path_res / f'which_JON_neuron_activate_DN.csv'
    if not os.path.exists(path):
        rates = simulating()
        df = ndarray_to_dataframe(
            rates,
            [
                [f'Trial {i}' for i in range(n_trial)],
                id_top300_JON,
                [f'{freq} Hz' for freq in freqs],
                [id_DN1_1, id_DN2_l]
            ],
            column_axis=0
        )
        df.to_csv(path)
which_JON_neuron_activate_DN()
df = pd.read_csv(fig5_path_res / f'which_JON_neuron_activate_DN.csv')

trial_columns = [f'Trial {i}' for i in range(n_trial)]
df['mean'] = df[trial_columns].mean(axis=1)
# pivot_df = df.pivot(index='dim_1', columns='dim_2', values='mean')
# pivot_df = df.groupby(['dim_1', 'dim_2'])['mean'].first().unstack()
pivot_df = df.pivot_table(index='dim_1', columns='dim_2', values='mean', aggfunc='mean')
pivot_df = pivot_df.fillna(0)

sns.heatmap(pivot_df)
plt.title('Which JON GRN activates NM9?')
plt.show()
../_images/78384a79b5c2dd6840fff7a8f813f27b8494ca4fd99844b0d0cbc1f362af396c.png

Figure 5d: Silencing Most Active Neurons while Activating JON GRN#

# frequencies
freqs = np.asarray([140, 150, 160, 170, 180])

# neuron indices
JON_silencing_indices = np.asarray([flywire_id_to_index[id_] for id_ in id_top300_JON])
JON_silencing_indices = np.expand_dims(JON_silencing_indices, axis=1)

brainstate.random.seed(1)
def slice_water_neuron_and_activate_MN9():
    @brainstate.transform.vmap2(in_axes=(0, None))
    def run_with_freq(freq_, indices_):
        rate = run_one_exp(
            neurons_to_excite=[{'ids': neu_water, 'rate': freq_ * u.Hz}],
            neuron_to_inhibit=[{'indices': indices_}],
            duration=duration,
            pbar=10
        )
        return rate[index_DN]
        
    

    def run_with_indices(indices_):
        return run_with_freq(freqs, indices_)

    @brainstate.transform.jit
    @brainstate.transform.vmap2(axis_size=n_trial)
    def simulating():
        # due to the limited memory of the device, we choose using "brainstate.transform.map" instead of "brainstate.transform.vmap2"

        # return bst.transform.vmap2(run_with_indices)(bst.random.split_key(len(JON_silencing_indices)), JON_silencing_indices)
        return brainstate.transform.map(run_with_indices,
                                        JON_silencing_indices,
                                        batch_size=20)

    path = fig4_path_res / f'JON_neuron_silencing_firing_rate.csv'
    if not os.path.exists(path):
        rates = simulating()
        df = ndarray_to_dataframe(
            rates,
            [
                [f'Trial {i}' for i in range(n_trial)],
                id_top200,
                [f'{freq} Hz' for freq in freqs],
                [id_DN1_1, id_DN2_l]
            ],
            column_axis=0
        )
        df.to_csv(path)
slice_water_neuron_and_activate_MN9()
df = pd.read_csv(fig4_path_res / f'JON_neuron_silencing_firing_rate.csv')

trial_columns = [f'Trial {i}' for i in range(n_trial)]
df['mean'] = df[trial_columns].mean(axis=1)
pivot_df = df.pivot(index='dim_1', columns='dim_2', values='mean')
pivot_df = pivot_df.fillna(0)

sns.heatmap(pivot_df)
plt.show()

Figure 5g + S5: Which JON Type activates aBN?#

id_aBN1 = 720575940630907434

index_aBN1 = flywire_id_to_index[id_aBN1]

# frequencies
freqs = np.asarray([20, 40, 60, 80, 100, 120, 140, 160, 180, 200, 220])

JON-CE#

def which_JONCE_neuron_activate_aBN():
    @brainstate.transform.vmap2
    def run_with_freq(freq_):
        rate = run_one_exp(
            neurons_to_excite=[{'ids': neu_JON_CE, 'rate': freq_ * u.Hz}],
            duration=duration,
            pbar=10
        )
        return rate[index_aBN1]

    @brainstate.transform.jit
    @brainstate.transform.vmap2(axis_size=n_trial)
    def simulating():
        return run_with_freq(freqs)

    path = fig5_path_res / f'JONCE_neuron_activate_aBN.csv'
    if not os.path.exists(path):
        rates = simulating()
        df = ndarray_to_dataframe(
            rates,
            [
                [f'Trial {i}' for i in range(n_trial)],
                [f'{freq} Hz' for freq in freqs],
            ],
            column_axis=0
        )
        df.to_csv(path)
which_JONCE_neuron_activate_aBN()
df = pd.read_csv(fig5_path_res / f'JONCE_neuron_activate_aBN.csv')

trial_columns = [f'Trial {i}' for i in range(n_trial)]
df['mean'] = df[trial_columns].mean(axis=1)
pivot_df = df.pivot(index='dim_1', values='mean')
pivot_df = pivot_df.fillna(0)

JON-F#

def which_JONF_neuron_activate_aBN():
    @brainstate.transform.vmap2
    def run_with_freq(freq_):
        rate = run_one_exp(
            neurons_to_excite=[{'ids': neu_JON_F, 'rate': freq_ * u.Hz}],
            duration=duration,
            pbar=10
        )
        return rate[index_aBN1]

    @brainstate.transform.jit
    @brainstate.transform.vmap2(axis_size=n_trial)
    def simulating():
        return run_with_freq(freqs)

    path = fig5_path_res / f'JONF_neuron_activate_aBN.csv'
    if not os.path.exists(path):
        rates = simulating()
        df = ndarray_to_dataframe(
            rates,
            [
                [f'Trial {i}' for i in range(n_trial)],
                [f'{freq} Hz' for freq in freqs],
            ],
            column_axis=0
        )
        df.to_csv(path)
df = pd.read_csv(fig5_path_res / f'JONF_neuron_activate_aBN.csv')

trial_columns = [f'Trial {i}' for i in range(n_trial)]
df['mean'] = df[trial_columns].mean(axis=1)
pivot_df = df.pivot(index='dim_1', values='mean')
pivot_df = pivot_df.fillna(0)

JON-F + silencing#

three_inhibitory = [720575940636066222, 720575940609957315, 720575940624986407]
def silence_which_JONF_neuron_activate_aBN():
    @brainstate.transform.vmap2
    def run_with_freq(freq_):
        rate = run_one_exp(
            neurons_to_excite=[{'ids': neu_JON_F, 'rate': freq_ * u.Hz}],
            neuron_to_inhibit=[{'ids': three_inhibitory}],
            duration=duration,
            pbar=10
        )
        return rate[index_aBN1]

    @brainstate.transform.jit
    @brainstate.transform.vmap2(axis_size=n_trial)
    def simulating():
        return run_with_freq(freqs)

    path = fig5_path_res / f'silence_JONF_neuron_activate_aBN.csv'
    if not os.path.exists(path):
        rates = simulating()
        df = ndarray_to_dataframe(
            rates,
            [
                [f'Trial {i}' for i in range(n_trial)],
                [f'{freq} Hz' for freq in freqs],
            ],
            column_axis=0
        )
        df.to_csv(path)
silence_which_JONF_neuron_activate_aBN()