# Copyright 2026 BrainX Ecosystem Limited. Apache 2.0.
"""Simulator — explicit NEST-flavored network builder and runner.
The :class:`Simulator` builds a flat module graph (populations, generators,
recorders, and delta-event projections) and runs it through a single
``brainstate.transform.for_loop``. Populations expose their per-step spikes via
a Simulator-managed :class:`_SpikeHolder` (NEST models do not persist a
``.spike`` state), so projections read the previous step's spikes — matching the
projection-before-dynamics convention. Recording is collected as a stacked JAX
array (the ``spike_recorder`` device mutates Python lists and cannot run inside
the jitted loop).
"""
from __future__ import annotations
import copy
import inspect
import itertools
from typing import Optional
import brainstate
import jax
import jax.numpy as jnp
import brainunit as u
from brainpy_state._base import Neuron
from brainpy_state._nest_device.ac_generator import ac_generator as _ac_generator
from brainpy_state._nest_device.dc_generator import dc_generator as _dc_generator
from brainpy_state._nest_device.multimeter import multimeter as _multimeter
from brainpy_state._nest_device.noise_generator import noise_generator as _noise_generator
from brainpy_state._nest_synapse.sic_connection import sic_connection as _sic_connection
from brainpy_state._nest_device.spike_recorder import spike_recorder as _spike_recorder
from brainpy_state._nest_device.step_current_generator import step_current_generator as _step_current_generator
from brainpy_state._nest_device.volume_transmitter import volume_transmitter as _volume_transmitter
from brainpy_state._nest_network.event_plastic import EventPlasticProj, VoltageCoupledPlasticProj
from brainpy_state._nest_network.event_proj import EventProjection
from brainpy_state._nest_network.nodeview import NodeView, _Segment, _flat_size
from brainpy_state._nest_network.rules import all_to_all, one_to_one, _ExplicitEdges
__all__ = ['Simulator', 'SimulationResult']
# -- Recordable resolution ---------------------------------------------------
# NEST recordable names are translated to brainpy.state model State attributes
# by the ``_RECORDABLE_ALIAS`` table, which follows the resolver helpers below.
# NEST exposes the membrane potential as ``V_m`` while the models store it on
# ``self.V``. Synaptic currents are spelled ``I_syn_ex``/``I_syn_in`` on the alpha
# family (``iaf_psc_alpha``) but ``i_syn_ex``/``i_syn_in`` on the exp family
# (``iaf_psc_exp``), so each maps to a tuple of candidate attributes tried in
# order. Recordables not listed in the table (``g_ex``, ``g_in``, ``w``, …) resolve
# by their own name via ``getattr`` (e.g. ``iaf_cond_alpha`` exposes ``g_ex``/``g_in``).
def _asc_sum(pop):
"""Total after-spike current: a precomputed sum state if present, else summed."""
state = getattr(pop, '_asc_sum_state', None)
if state is not None:
return state.value
return sum(s.value for s in pop._asc_states)
def _psc_sum(pop):
"""Total post-synaptic current ``I_syn``.
NEST reports ``I_syn`` as the sum of every receptor's PSC. Models expose this
differently: ``glif_psc`` keeps a single ``y2`` list of per-port PSC States
(``glif_psc.cpp``: ``S_.I_syn_ += S_.y2_[i]``), while ``glif_psc_double_alpha``
splits each port into fast/slow components and provides a ``get_I_syn()`` that
sums them. Prefer the model's own ``get_I_syn()`` when present, else sum ``y2``.
"""
get = getattr(pop, 'get_I_syn', None)
if callable(get):
return get()
return sum(s.value for s in pop.y2)
def _g_port(k):
"""Resolver for NEST per-port conductance ``g_k`` (1-indexed).
The multi-receptor models lay conductance out three different ways: a
``g_syn`` *list* of per-receptor States (``glif_cond*``), a ``g`` *list* of
per-receptor States (``gif_cond_exp_multisynapse``), or a single ``g`` State
with the receptor on the last axis (``aeif_cond_beta_multisynapse``). This
returns a ``pop -> value`` reader that handles all three.
"""
idx = k - 1
def read(pop):
g_syn = getattr(pop, 'g_syn', None)
if g_syn is not None: # glif_cond: list of States
return g_syn[idx].value
g = getattr(pop, 'g', None)
if g is None:
raise KeyError(f'g_{k}: population exposes neither g_syn nor g')
if isinstance(g, (list, tuple)): # gif: list of States
return g[idx].value
return g.value[..., idx] # aeif: single State, last axis
return read
def _adaptive_threshold(pop):
"""NEST ``V_th`` recordable for the multi-timescale adaptive-threshold family.
``mat2_psc_exp``/``amat2_psc_exp`` do not store the firing threshold; it is
*computed* as ``V_th = omega + V_th_1 + V_th_2`` (mat2_psc_exp.cpp), where
``omega`` is a parameter (absolute mV) and ``V_th_1``/``V_th_2`` are States
that jump by ``alpha_1``/``alpha_2`` on the model's own spikes and decay back.
``amat2_psc_exp`` adds the voltage-dependent component ``V_th_v``.
Models that instead expose the threshold *directly* as a ``V_th`` State (e.g.
``aeif_psc_delta_clopath``) fall through to that State, so this single alias
serves both the computed-threshold and stored-threshold conventions.
"""
omega = getattr(pop, 'omega', None)
if omega is not None and getattr(pop, 'V_th_1', None) is not None:
total = omega + pop.V_th_1.value + pop.V_th_2.value
v_th_v = getattr(pop, 'V_th_v', None) # amat2 only
if v_th_v is not None:
total = total + v_th_v.value
return total
state = getattr(pop, 'V_th', None) # clopath: threshold is a State
if state is not None and hasattr(state, 'value'):
return state.value
raise KeyError(
'V_th: population is neither a MAT model (omega + V_th_1 + V_th_2) nor '
f'exposes a V_th State on {type(pop).__name__}'
)
def _mc_comp(attr, idx):
"""Resolver for a multi-compartment recordable ``<attr>.<s|p|d>``.
``iaf_cond_alpha_mc`` stacks the three compartments on the *last axis* of a
single State (SOMA=0, PROX=1, DIST=2); NEST spells the per-compartment
recordables ``V_m.s``/``V_m.p``/``V_m.d``, ``g_ex.s``/…, ``g_in.s``/…. This
returns a ``pop -> value`` reader selecting one compartment column, preserving
the leading (neuron) axis so the analog tap can index by neuron ordinal.
"""
def read(pop):
state = getattr(pop, attr, None)
if state is None:
raise KeyError(
f'{attr!r} compartment recordable: {type(pop).__name__} has no '
f'{attr!r} stacked-compartment State'
)
return state.value[..., idx]
return read
# NEST recordable name -> how to read it from a brainpy.state model.
#
# Each value is one of:
#   * a tuple of candidate State attribute names, tried in order (e.g. the exp
#     family spells ``i_syn_ex`` where the alpha family spells ``I_syn_ex``); or
#   * a callable ``pop -> value`` for derived/indexed recordables (per-port
#     conductance, summed adaptation currents, computed thresholds,
#     compartment columns).
#
# Names missing from this table resolve by their own name via ``getattr``
# (e.g. ``iaf_cond_alpha`` exposes ``g_ex`` directly).
_RECORDABLE_ALIAS = {
    # --- membrane / input currents ---------------------------------------
    'V_m': ('V',),                       # most neurons store V_m on ``V``
    'I': ('I_stim',),                    # current-generator input buffer (NEST S_.I_)
    'I_SIC': ('I_sic',),                 # astrocyte slow-inward current (aeif_cond_alpha_astro)
    'I_syn_ex': ('I_syn_ex', 'i_syn_ex'),
    'I_syn_in': ('I_syn_in', 'i_syn_in'),

    # --- HH gating (NEST Act_m / Inact_h / Act_n) -------------------------
    'Act_m': ('m',),
    'Inact_h': ('h',),
    'Act_n': ('n',),

    # --- GIF adaptation (NEST E_sfa / I_stc) ------------------------------
    'E_sfa': ('_sfa_val_state',),
    'I_stc': ('_stc_val_state',),

    # --- GLIF threshold components (relative to E_L) ----------------------
    'threshold': ('_threshold_state',),
    'threshold_spike': ('_threshold_spike_state',),
    'threshold_voltage': ('_threshold_voltage_state',),

    # --- per-port conductance g_k and summed currents ---------------------
    'g_1': _g_port(1),
    'g_2': _g_port(2),
    'g_3': _g_port(3),
    'g_4': _g_port(4),
    'ASCurrents_sum': _asc_sum,
    'I_syn': _psc_sum,                   # glif_psc: sum of per-port PSC states

    # --- izhikevich recovery variable (NEST U_m) --------------------------
    'U_m': ('U',),

    # --- threshold: computed for MAT, stored State otherwise (clopath) ----
    'V_th': _adaptive_threshold,

    # --- iaf_cond_alpha_mc compartments (SOMA=0, PROX=1, DIST=2) ----------
    'V_m.s': _mc_comp('V', 0),
    'V_m.p': _mc_comp('V', 1),
    'V_m.d': _mc_comp('V', 2),
    'g_ex.s': _mc_comp('g_ex', 0),
    'g_ex.p': _mc_comp('g_ex', 1),
    'g_ex.d': _mc_comp('g_ex', 2),
    'g_in.s': _mc_comp('g_in', 0),
    'g_in.p': _mc_comp('g_in', 1),
    'g_in.d': _mc_comp('g_in', 2),
}
def _read_recordable(pop, name):
    """Read NEST recordable ``name`` from ``pop`` (Quantity or array).

    Resolution order:

    1. A callable ``_RECORDABLE_ALIAS`` entry is invoked as ``entry(pop)``.
    2. Otherwise the entry (default: ``(name,)``) is a tuple of candidate
       attribute names. The first attribute that exists and is not ``None``
       supplies ``.value``.
    3. Otherwise a per-instance ``pop.read_recordable(name)`` is used. Models
       such as cm_default have dynamic recordable names (``v_comp0``,
       ``m_Na_1``, ``g_r_AN_AMPA_1``, …) known only to the instance.
    4. Otherwise ``KeyError`` is raised.
    """
    entry = _RECORDABLE_ALIAS.get(name, (name,))
    if callable(entry):
        return entry(pop)
    candidates = (getattr(pop, attr, None) for attr in entry)
    found = next((state for state in candidates if state is not None), None)
    if found is not None:
        return found.value
    resolver = getattr(pop, 'read_recordable', None)
    if resolver is None:
        raise KeyError(
            f'recordable {name!r} (tried {entry}) is not available on '
            f'{type(pop).__name__}'
        )
    return resolver(name)
class _SpikeHolder(brainstate.nn.Module):
    """Holds one population's most recently captured spike/count vector.

    ``init_state`` allocates ``spk``: a zero-initialised ``ShortTermState``
    vector of length ``n``, using the float dtype from
    ``brainstate.environ.dftype()``.
    """
    __module__ = 'brainpy.state'

    def __init__(self, n: int):
        super().__init__()
        self._n = int(n)

    def init_state(self, *args, **kwargs):
        initial = jnp.zeros(self._n, dtype=brainstate.environ.dftype())
        self.spk = brainstate.ShortTermState(initial)
class _GeneratorSpec:
"""A deferred generator (model class + params), realised per target."""
def __init__(self, model_cls, params):
self.model_cls = model_cls
self.params = params
class _GenSegment:
"""A NodeView segment carrying a deferred generator spec (size unknown)."""
def __init__(self, spec: _GeneratorSpec):
self.spec = spec
self.population = None
self.indices = jnp.arange(0)
def _holder_reader(holder: _SpikeHolder):
return lambda: holder.spk.value
def _is_generator(model_cls) -> bool:
name = getattr(model_cls, '__name__', '')
return 'generator' in name or 'injector' in name
# Generators whose output is a *current* (pA), not spike events. They feed the
# neuron's current-input seam (NEST current ring buffer, one-step delay)
# rather than the delta-event path used by spike generators.
_CURRENT_GENERATORS = (
    _noise_generator,
    _dc_generator,
    _step_current_generator,
    _ac_generator,
)


def _is_current_generator(model_cls) -> bool:
    """True if ``model_cls`` is a class derived from one of ``_CURRENT_GENERATORS``.

    Non-class arguments (instances, functions, ``None``) return ``False``.
    """
    if not isinstance(model_cls, type):
        return False
    return issubclass(model_cls, _CURRENT_GENERATORS)
def _n_channels(size) -> int:
"""Flatten a ``create`` size spec to a scalar channel count."""
if isinstance(size, (tuple, list)):
n = 1
for s in size:
n *= int(s)
return n
return int(size)
def _is_len_vector(val, k: int) -> bool:
    """True if ``val`` has a leading axis of length ``k``.

    Accepted forms:

    * a ``Quantity`` whose mantissa has ``ndim >= 1`` and ``shape[0] == k``;
    * a list/tuple with ``len == k``;
    * any object with a ``shape`` satisfying the same ndim/shape test.

    Note that arrays with more than one dimension also qualify, as long as the
    first axis has length ``k``. Everything else (including scalars) is
    ``False``.
    """
    if isinstance(val, u.Quantity):
        candidate = val.mantissa
    elif isinstance(val, (list, tuple)):
        return len(val) == k
    elif hasattr(val, 'shape'):
        candidate = val
    else:
        return False
    return jnp.ndim(candidate) >= 1 and candidate.shape[0] == k
def _index_channel(val, i: int, k: int):
    """Pick channel ``i`` of a length-``k`` vector, else pass ``val`` through.

    Splits a vector-valued generator parameter (e.g. ``rate=[r0, r1] * u.Hz``)
    or a per-segment ``weight`` into one value per channel. Quantities keep
    their unit. Values that are not length-``k`` vectors (per
    :func:`_is_len_vector`) are returned unchanged, which is how scalars
    broadcast to every channel.
    """
    if _is_len_vector(val, k):
        if isinstance(val, u.Quantity):
            return u.maybe_decimal(val.mantissa[i] * u.get_unit(val))
        return val[i]
    return val
class SimulationResult:
    """Recorded spikes and analog traces from a :meth:`Simulator.simulate` run.

    Spike recorders are read with :meth:`spikes` / :meth:`n_events` / :meth:`rate`.
    Analog recorders (``voltmeter`` / ``multimeter``, connected in NEST's reversed
    direction) are read with :meth:`trace`, and the common time axis with
    :attr:`times`. Weight trajectories of projections registered via
    :meth:`Simulator.record_weight` are read with :meth:`weight_trace`.

    Parameters
    ----------
    recordings : dict
        ``{id(recorder): (n_steps, n_recorded) array}`` spike matrices.
    duration : brainunit.Quantity
        Simulated time; :meth:`rate` converts it to seconds.
    dt : brainunit.Quantity
        Simulation timestep of the run (stored as ``_dt``).
    traces : dict, optional
        ``{f'{id(recorder)}|{recordable}': (n_steps, n) Quantity}`` analog traces.
    times : brainunit.Quantity, optional
        ``(n_steps,)`` time axis of the run.
    weights : dict, optional
        ``{id(proj): (n_steps, n_edges)}`` weight trajectories.
    """
    __module__ = 'brainpy.state'

    def __init__(self, recordings: dict, duration, dt, *, traces=None, times=None,
                 weights=None):
        self._rec = recordings  # {id(recorder): (T, n_rec) array}
        self._T = duration
        self._dt = dt
        self._traces = dict(traces or {})  # {f'{id(rec)}|{recordable}': (T, n) Quantity}
        self._times = times  # (T,) Quantity, the for_loop time axis
        self._weights = dict(weights or {})  # {id(proj): (T, E) weight trajectory}

    @staticmethod
    def _key(node):
        # A NodeView is keyed by the module behind its first segment, which is
        # how the Simulator keys its recorder taps (by id of the recorder).
        if isinstance(node, NodeView):
            return id(node.segments[0].population)
        return id(node)

    @staticmethod
    def _trace_key(rid, recordable):
        return f'{rid}|{recordable}'

    def _spike_matrix(self, node):
        """Return the recorded spike matrix for ``node``.

        Raises a descriptive ``KeyError`` instead of the bare ``KeyError(<id>)``
        a raw dict lookup would produce.
        """
        try:
            return self._rec[self._key(node)]
        except KeyError:
            raise KeyError(
                'no spikes were recorded for this node; pass the spike_recorder '
                'handle connected via sim.connect(pop, rec) before simulate()'
            ) from None

    def spikes(self, node):
        """Per-step spike matrix ``(n_steps, n_recorded)`` for a recorder/source.

        Raises
        ------
        KeyError
            If nothing was recorded under ``node``.
        """
        return self._spike_matrix(node)

    def n_events(self, node) -> int:
        """Number of ``(step, neuron)`` entries with a positive spike value.

        Each entry greater than zero counts once, whatever its magnitude.

        Raises
        ------
        KeyError
            If nothing was recorded under ``node``.
        """
        return int(jnp.sum(self._spike_matrix(node) > 0))

    def rate(self, node) -> float:
        """Mean firing rate in spikes/second over all recorded neurons.

        Computed as ``n_events / n_recorded / duration``, with ``duration`` in
        seconds. Each ``(step, neuron)`` entry above zero counts as one spike.

        Raises
        ------
        KeyError
            If nothing was recorded under ``node``.
        """
        spk = self._spike_matrix(node)
        n = spk.shape[1]
        t_s = float(self._T.to_decimal(u.second))
        return float(jnp.sum(spk > 0)) / n / t_s

    def trace(self, recorder, recordable: str = 'V_m'):
        """Analog trace ``(n_steps, n_recorded)`` for an analog recorder.

        Parameters
        ----------
        recorder : NodeView
            The ``voltmeter`` / ``multimeter`` handle returned by
            :meth:`Simulator.create` and connected via ``connect(recorder, pop)``.
        recordable : str, optional
            Recordable name (NEST vocabulary, e.g. ``'V_m'``, ``'g_ex'``).
            Default is ``'V_m'``.

        Returns
        -------
        brainunit.Quantity
            ``(n_steps, n_recorded)`` trace in the model state's natural unit.

        Raises
        ------
        KeyError
            If ``recordable`` was not recorded by this recorder.
        """
        rid = self._key(recorder)
        key = self._trace_key(rid, recordable)
        if key not in self._traces:
            available = sorted(k.split('|', 1)[1] for k in self._traces
                               if k.startswith(f'{rid}|'))
            raise KeyError(
                f'recordable {recordable!r} was not recorded by this recorder; '
                f'recorded: {available}'
            )
        return self._traces[key]

    def weight_trace(self, proj):
        """Per-step weight trajectory ``(n_steps, n_edges)`` for a recorded proj.

        Parameters
        ----------
        proj : EventPlasticProj
            The plastic-projection handle returned by ``connect(..., synapse=spec)``
            and registered via :meth:`Simulator.record_weight` before the run.

        Returns
        -------
        brainunit.Quantity
            ``(n_steps, n_edges)`` weights in the synapse weight unit (pA), in CSR
            (sorted-by-pre) edge order — the same order the rule kernel sees.

        Raises
        ------
        KeyError
            If this projection's weight was not recorded (no
            :meth:`Simulator.record_weight` before :meth:`Simulator.simulate`).
        """
        rid = id(proj)
        if rid not in self._weights:
            raise KeyError(
                "this projection's weight was not recorded; call "
                'sim.record_weight(proj) before simulate()'
            )
        return self._weights[rid]

    @property
    def times(self):
        """The common time axis ``(n_steps,)`` of the run (brainunit Quantity)."""
        return self._times
class Simulator(brainstate.nn.Module):
"""Explicit NEST-flavored network builder and runner.
Parameters
----------
dt : brainunit.Quantity
Simulation timestep; set into ``brainstate.environ`` at construction.
Examples
--------
.. code-block:: python
>>> import brainunit as u
>>> from brainpy.state import iaf_psc_alpha, poisson_generator, spike_recorder
>>> from brainpy.state import Simulator, all_to_all
>>> sim = Simulator(dt=0.1 * u.ms)
>>> pop = sim.create(iaf_psc_alpha, 10)
>>> noise = sim.create(poisson_generator, rate=8000. * u.Hz)
>>> rec = sim.create(spike_recorder)
>>> sim.connect(noise, pop, weight=20. * u.pA, delay=1.5 * u.ms, rule=all_to_all)
>>> sim.connect(pop, rec)
>>> res = sim.simulate(100. * u.ms)
>>> rate = res.rate(rec)
"""
__module__ = 'brainpy.state'
def __init__(self, *, dt):
    """Create an empty network that steps with timestep ``dt``.

    ``dt`` is set into ``brainstate.environ`` and kept on ``self._dt``. Every
    registry below starts empty and is filled while the network is built.
    """
    super().__init__()
    brainstate.environ.set(dt=dt)
    self._dt = dt
    # Recorder / weight taps (read each step; no projection is built for them).
    self._taps = {}               # id(spike_recorder) -> (id(source), idx)
    self._analog_taps = {}        # id(multimeter) -> (id(pop), idx, recordables)
    self._weight_taps = {}        # id(proj) -> EventPlasticProj
    # Non-event drives.
    self._current_injectors = []  # (device, post_pop, post_idx, weight, key, comp, ncomp)
    self._gap_couplers = []       # (G, D, v_reader, post_pop, key)
    self._vt_nodes = []           # volume_transmitter modules, updated in phase 0
    # Bookkeeping.
    self._proj_counter = itertools.count()
    self._connections = []        # (pre_pop, post_pop, model_name, proj), registration order
    self._positions = {}          # id(pop) -> (n, ndim) coordinate Quantity (spatial pops)
# -- node creation -----------------------------------------------------
def create(self, model_cls, size=1, *, params=None, positions=None, **kw) -> NodeView:
    """Instantiate a population/device and return a :class:`NodeView`.

    Generators are deferred (realised per target at :meth:`connect`) so each
    target receives an independent train, mirroring NEST fan-out.

    ``positions=<spatial layer>`` attaches node coordinates to a neuron
    population (NEST ``Create(model, positions=spatial.grid/free(...))``):

    * A :func:`~brainpy_state._nest_spatial.grid` or concrete
      :func:`~brainpy_state._nest_spatial.free` layer *derives* the population
      size from its coordinates, and the ``size`` argument is ignored.
    * A deferred ``free`` layer (sampled from a distribution) draws ``size``
      positions.

    The coordinates are stored under ``id(population)`` so a
    :func:`~brainpy_state._nest_spatial.spatial_pairwise_bernoulli` rule can be
    bound to them at :meth:`connect`.

    Parameters
    ----------
    model_cls : type
        Neuron, device, or generator class. Generators/injectors are recognised
        by name (``'generator'`` / ``'injector'`` in ``__name__``).
    size : int or tuple of int, default 1
        Population size, or a generator's channel count (for generators a
        tuple is flattened to the product of its entries). Ignored when
        ``positions`` is a concrete spatial layer.
    params : dict, optional
        Model keyword arguments; merged with ``**kw`` (``kw`` wins on clashes).
    positions : spatial layer, optional
        Node coordinates for a neuron population (see above).
    **kw
        Additional model keyword arguments.

    Returns
    -------
    NodeView
        A view over the created module. A generator with more than one channel
        yields one deferred segment per channel. Vector-valued parameters are
        split per channel; scalars are broadcast.

    Raises
    ------
    ValueError
        If ``positions`` is given for a generator/injector.
    """
    p = dict(params or {})
    p.update(kw)
    is_generator = _is_generator(model_cls)
    coords = None
    if positions is not None:
        if is_generator:
            raise ValueError(
                'create(positions=...) is not supported for generators/injectors; '
                'spatial layers attach to neuron populations.'
            )
        size, coords = self._resolve_positions(positions, size)
    if is_generator:
        k = _n_channels(size)
        if k > 1:
            # Multi-channel generator (Extension D2): one independent segment
            # per channel, each a scalar-param spec. Vector params (e.g.
            # ``rate=[r0, r1]``) are split per channel; scalars broadcast.
            return NodeView([
                _GenSegment(_GeneratorSpec(
                    model_cls, {key: _index_channel(v, i, k) for key, v in p.items()}))
                for i in range(k)
            ])
        return NodeView([_GenSegment(_GeneratorSpec(model_cls, p))])
    mod = model_cls(size, **p)
    # Attach as an attribute so the module becomes a child of this Simulator
    # (presumably so brainstate discovers its States — confirm).
    setattr(self, f'_node_{id(mod)}', mod)
    if coords is not None:
        self._positions[id(mod)] = coords  # spatial layer -> per-population coords
    # Volume transmitters are driven in phase 0 (before projections) and expose
    # the dopamine concentration ``n`` as State; they emit no spikes, so they
    # get no _SpikeHolder (phase 2 skips them) and are registered for phase 0.
    if isinstance(mod, _volume_transmitter):
        self._vt_nodes.append(mod)
        return NodeView.of(mod)
    # Recorders are tapped, not driven: spike recorders read captured spikes,
    # analog recorders (voltmeter/multimeter) read model State per step. Neither
    # gets a _SpikeHolder.
    if isinstance(mod, (_spike_recorder, _multimeter)):
        return NodeView([_Segment(mod, jnp.arange(1))])
    # Current-injecting devices (host_current_drive) are driven via the
    # _current_injectors path (phase 1b), like a dc_generator: they emit pA, not
    # spikes, so they get no _SpikeHolder (phase 2 would otherwise double-drive
    # the schedule counter). connect() registers them as injectors.
    if getattr(mod, '_injects_current', False):
        return NodeView.of(mod)
    holder = _SpikeHolder(_flat_size(mod))
    setattr(self, f'_holder_{id(mod)}', holder)
    # A neuron that integrates short-term plasticity *presynaptically* (declares
    # ``_emission_attr`` -- iaf_tum_2000's released efficacy ``spike_offset``)
    # also gets an emission holder. A TSODYKS connection delivers that graded
    # efficacy (captured in phase 2 alongside the binary spike) rather than
    # ``weight * spike``; the binary holder still serves every other connection.
    if getattr(mod, '_emission_attr', None) is not None:
        setattr(self, f'_emit_holder_{id(mod)}', _SpikeHolder(_flat_size(mod)))
    return NodeView.of(mod)
def get_position(self, view: NodeView):
    """Node coordinates of a spatial population/view (NEST ``GetPosition``).

    Parameters
    ----------
    view : NodeView
        A single-segment view over a population created with ``create(positions=...)``.

    Returns
    -------
    Quantity
        ``(n, ndim)`` coordinates (length units) in the view's node order.

    Raises
    ------
    ValueError
        If the population was not created with ``positions=``.
    NotImplementedError
        If ``view`` does not consist of exactly one segment (returning only the
        first segment's coordinates would silently drop nodes).

    Examples
    --------
    .. code-block:: python

        >>> from brainpy import state as bp
        >>> import brainunit as u
        >>> sim = bp.Simulator(dt=0.1 * u.ms)
        >>> pop = sim.create(bp.iaf_psc_alpha, positions=bp.spatial.grid([4, 3], extent=[2.0, 1.5]))
        >>> [round(float(v), 2) for v in u.get_magnitude(sim.get_position(pop).to(u.um))[0]]
        [-0.75, 0.5]
    """
    if len(view.segments) != 1:
        raise NotImplementedError(
            'get_position requires a single-segment view; query each '
            'population segment separately.'
        )
    seg = view.segments[0]
    try:
        coords = self._positions[id(seg.population)]
    except KeyError:
        # The KeyError only carries an id() integer; keep the traceback clean.
        raise ValueError(
            'get_position requires a population created with create(positions=...).'
        ) from None
    # Slice by the segment's population-local indices to follow view order.
    return coords[seg.indices]
def _resolve_positions(self, layer, size):
    """Turn a spatial layer into ``(n, coords)`` for :meth:`create`.

    A concrete layer (``grid`` / array-backed ``free``) fixes both the size and
    the coordinates. A deferred ``free`` layer draws ``size`` positions from
    its distribution with a reproducible per-population key. The key is salted
    with the number of spatial populations registered so far.
    """
    if not layer.is_deferred:
        return layer.n, layer.coords
    n = _n_channels(size)
    key = jax.random.key(self._derive_seed(None, len(self._positions)))
    return n, layer.sample(n, key)
def _bind_spatial_coords(self, rule, pre_seg, post_seg):
    """Bind a spatial rule to the sliced pre/post coordinates of this connect.

    Looks up each population's stored coordinates (from
    ``create(positions=...)``). It slices them by the segment's local indices
    so the rows align with the edge index space the projection uses, and
    returns a coordinate-bound rule clone.

    Returns
    -------
    rule
        ``rule.with_coords(pre_coords, post_coords)``.

    Raises
    ------
    ValueError
        If either side is a device/generator (no population) or its population
        was not created with positions.
    """
    pre_pop = pre_seg.population
    post_pop = post_seg.population
    if pre_pop is None or post_pop is None:
        raise ValueError(
            'a spatial connection rule requires neuron populations on both sides; '
            'generators/devices carry no positions.'
        )
    try:
        pre_all = self._positions[id(pre_pop)]
        post_all = self._positions[id(post_pop)]
    except KeyError:
        # The KeyError only carries an id() integer; suppress it so the
        # traceback shows the actionable ValueError alone.
        raise ValueError(
            'a spatial connection rule requires both populations to be created '
            'with create(positions=...).'
        ) from None
    return rule.with_coords(pre_all[pre_seg.indices], post_all[post_seg.indices])
# -- connection --------------------------------------------------------
def connect(self, pre: NodeView, post: NodeView, *, rule=all_to_all,
            weight=None, delay=None, comm: str = 'dense', receptor_type=None,
            synapse=None, vt=None, allow_autapses: bool = True,
            allow_multapses: bool = True, seed: Optional[int] = None):
    """Connect ``pre`` to ``post`` (or register a recorder tap).

    ``comm='sparse'`` routes the connectivity through a sparse CSR event
    matmul (memory-light for large fan-out); ``'dense'`` (default) uses a
    dense weight matrix. Both yield identical results for the same rule/seed.

    ``receptor_type='uniform'`` routes each edge to a uniformly-drawn receptor
    port of a multi-receptor post population (``iaf_psc_exp_multisynapse``).

    ``synapse=<spec>`` builds a plastic :class:`EventPlasticProj` from a
    rebuilt ``_nest`` synapse spec (``static_synapse``, the ``tsodyks*``
    family, ``quantal_stp_synapse``); ``weight``/``delay`` here override the
    spec's defaults. ``synapse=None`` (default) keeps the static
    :class:`EventProjection` path unchanged.

    ``connect(dopa_pool, vt)`` (reverse direction, ``post`` a
    ``volume_transmitter`` view) registers each presynaptic segment as a
    dopaminergic source on the transmitter and builds no projection.
    ``vt=<volume_transmitter view>`` binds a transmitter to a synapse spec that
    reads a broadcast signal (``signal_reads``, e.g. ``stdp_dopamine_synapse``);
    such a spec raises if no ``vt`` is supplied.

    Analog recorders (``voltmeter`` / ``multimeter``) are connected in NEST's
    reversed direction --- ``connect(recorder, pop)`` --- because the recorder
    *observes* the population. This registers a per-step State tap; no
    projection is built.

    Returns
    -------
    EventProjection or EventPlasticProj or list or None
        The projection handle(s) built by this call (a single handle when one
        projection is built, a list for multi-segment fan-out). A plastic
        handle (``synapse=spec``) can be passed to :meth:`record_weight`.
        Recorder-tap connects (and current injectors) return ``None``.

    Raises
    ------
    NotImplementedError
        If an analog recorder is connected to a multi-segment view, or a
        multi-segment view is connected to a single spike recorder.
    """
    # Analog recorder, NEST reversed direction: connect(multimeter, pop).
    if len(pre.segments) == 1 and isinstance(pre.segments[0].population, _multimeter):
        rec = pre.segments[0].population
        if len(post.segments) != 1:
            raise NotImplementedError(
                'a voltmeter/multimeter records a single population segment'
            )
        seg = post.segments[0]
        self._analog_taps[id(rec)] = (id(seg.population), seg.indices,
                                      tuple(rec.record_from))
        return None
    # Spike recorder tap: connect(pop, spike_recorder).
    if len(post.segments) == 1 and isinstance(post.segments[0].population, _spike_recorder):
        if len(pre.segments) != 1:
            raise NotImplementedError(
                'recording a multi-segment view requires one recorder per segment'
            )
        seg = pre.segments[0]
        self._taps[id(post.segments[0].population)] = (id(seg.population), seg.indices)
        return None
    if len(post.segments) == 1 and isinstance(post.segments[0].population, _volume_transmitter):
        # reverse-direction bind: connect(dopa_pool, vt) registers each dopa
        # source on the transmitter (no projection built, like a recorder tap).
        vt_node = post.segments[0].population
        for pre_seg in pre.segments:
            self._bind_dopa_source(pre_seg, vt_node)
        return None
    seg_weights = self._segment_weights(weight, len(pre.segments))
    projs = []
    for pre_seg, w_seg in zip(pre.segments, seg_weights):
        for post_seg in post.segments:
            proj = self._connect_pair(pre_seg, post_seg, rule, w_seg, delay,
                                      allow_autapses, allow_multapses, seed, comm,
                                      receptor_type, synapse, vt)
            if proj is None:
                continue
            # A ``mult_coupling`` rate pair returns its two sign-split
            # (rate_ex/rate_in) projections as a list; flatten them in.
            if isinstance(proj, list):
                projs.extend(proj)
            else:
                projs.append(proj)
    if not projs:
        return None
    return projs[0] if len(projs) == 1 else projs
@staticmethod
def _tripartite_segment(view: NodeView, role: str) -> _Segment:
    """Return the lone population segment behind a tripartite ``role``.

    ``tripartite_connect`` works on single-population, single-segment views.
    The astrocyte demos all use single populations, and the Brunel excitatory
    source is a prefix slice ``pop[:N_ex]``. A multi-segment view (e.g.
    ``a + b``) or a deferred generator raises ``NotImplementedError`` with a
    message naming ``role``.
    """
    n_segments = len(view.segments)
    if n_segments != 1:
        raise NotImplementedError(
            f'tripartite_connect requires a single-population view for {role!r}, '
            f'got {n_segments} segments; concatenated views are not supported.')
    segment = view.segments[0]
    if not isinstance(segment, _GenSegment):
        return segment
    raise NotImplementedError(
        f'tripartite_connect {role!r} must be a created population, not a '
        'deferred generator.')
def tripartite_connect(self, pre: NodeView, post: NodeView, third: NodeView, *,
                       conn_spec, third_factor_conn_spec, syn_specs=None,
                       seed: Optional[int] = None, comm: str = 'dense',
                       allow_autapses: bool = True, allow_multapses: bool = True):
    r"""Wire a tripartite ``pre -> post`` + astrocyte network (NEST ``TripartiteConnect``).

    Samples the **primary** ``pre -> post`` edges **once** via ``conn_spec`` and
    shares that single realization across three projections (NEST's tripartite
    semantics):

    1. **primary** ``pre -> post`` (the direct synapse, ``syn_specs['primary']``).
    2. For each realized primary edge ``(pre_i -> post_j)``, the
       ``third_factor_conn_spec`` (:func:`~brainpy_state._nest_network.rules.third_factor_bernoulli_with_pool`)
       runs a Bernoulli(``p``) trial. If it succeeds, the edge is paired with one
       astrocyte drawn from ``post_j``'s pool, creating
       **third_in** ``pre_i -> astro`` (``syn_specs['third_in']``, delta IP3) and
       **third_out** ``astro -> post_j`` (``syn_specs['third_out']``, a
       :class:`~brainpy_state.sic_connection`).

    Reuses the existing static :class:`~brainpy_state._nest_network.event_proj.EventProjection`
    path for primary + third_in and the merged ``sic_connection`` (``as_current``)
    path for third_out; no new deposit primitive. Each arm is registered for
    :meth:`get_connections`.

    Parameters
    ----------
    pre, post, third : NodeView
        Single-population views for the presynaptic, postsynaptic, and
        astrocyte populations. ``pre`` may be a prefix slice (e.g. the Brunel
        excitatory population ``neurons[:N_ex]``).
    conn_spec : ConnRule
        The primary one-directional rule, e.g.
        :func:`~brainpy_state.pairwise_bernoulli` or
        :func:`~brainpy_state.fixed_indegree`.
    third_factor_conn_spec : _ThirdFactorBernoulliWithPool
        The astrocyte-pool rule from
        :func:`~brainpy_state.third_factor_bernoulli_with_pool`.
    syn_specs : dict, optional
        ``{'primary': {...}, 'third_in': {...}, 'third_out': {...}}``. Each value
        is a dict of :meth:`connect` keyword arguments (``weight``, ``delay``,
        ``receptor_type``, ``synapse``). ``third_out`` typically carries
        ``{'synapse': sic_connection(...)}``. Missing entries default to ``{}``
        (a plain unit-weight static delta).
    seed : int, optional
        Base PRNG seed for the whole tripartite sample (primary + pairing +
        pool). The same seed reproduces the same realized edges. ``None``
        uses the fixed base key ``0``.
    comm : {'dense', 'sparse'}, default 'dense'
        Communication mode for the primary / third_in arms; third_out (the
        ``sic_connection``) always rides ``'dense'`` (a graded current).
    allow_autapses, allow_multapses : bool, default True
        Passed to the primary ``conn_spec.sample``.

    Returns
    -------
    tuple
        ``(primary_proj, third_in_proj, third_out_proj)``. ``third_in_proj`` and
        ``third_out_proj`` are ``None`` when no primary edge was paired (e.g.
        ``p=0``).

    Raises
    ------
    NotImplementedError
        If ``pre``, ``post`` or ``third`` is a multi-segment view or a deferred
        generator.

    Examples
    --------
    .. code-block:: python

        >>> import brainunit as u
        >>> from brainpy import state as bp
        >>> sim = bp.Simulator(dt=0.1 * u.ms)
        >>> pre = sim.create(bp.aeif_cond_alpha_astro, 10, params={'I_e': 1000.0 * u.pA})
        >>> post = sim.create(bp.aeif_cond_alpha_astro, 10, params={'I_e': 1000.0 * u.pA})
        >>> astro = sim.create(bp.astrocyte_lr_1994, 5, params={'delta_IP3': 0.2})
        >>> prim, tin, tout = sim.tripartite_connect(
        ...     pre, post, astro,
        ...     conn_spec=bp.pairwise_bernoulli(1.0),
        ...     third_factor_conn_spec=bp.third_factor_bernoulli_with_pool(
        ...         p=1.0, pool_size=1, pool_type='block'),
        ...     syn_specs={'primary': {'weight': 1.0 * u.nS, 'receptor_type': 1},
        ...                'third_in': {'weight': 1.0},
        ...                'third_out': {'synapse': bp.sic_connection(weight=1.0)}},
        ...     seed=0)
        >>> tout._channel_label
        'I_SIC'
    """
    pre_seg = self._tripartite_segment(pre, 'pre')
    post_seg = self._tripartite_segment(post, 'post')
    third_seg = self._tripartite_segment(third, 'third')
    n_pre = int(pre_seg.indices.shape[0])
    n_post = int(post_seg.indices.shape[0])
    n_third = int(third_seg.indices.shape[0])
    syn_specs = syn_specs or {}
    # One shared sample: the primary edges, then the derived third arms.
    # NOTE(review): seed=None falls back to the fixed base key 0, so repeated
    # calls with seed=None reuse the same key — pass explicit seeds if
    # independent realizations are wanted.
    key = jax.random.key(0 if seed is None else int(seed))
    k_primary, k_third = jax.random.split(key)
    pre_is_post = pre_seg.population is post_seg.population
    primary_spec = conn_spec.sample(
        n_pre, n_post, key=k_primary, pre_is_post=pre_is_post,
        allow_autapses=allow_autapses, allow_multapses=allow_multapses)
    third_in_spec, third_out_spec = third_factor_conn_spec.sample_third(
        primary_spec, n_post, n_third, key=k_third)
    primary_proj = self._connect_tripartite_arm(
        pre_seg, post_seg, primary_spec, syn_specs.get('primary', {}), seed, comm)
    if third_in_spec.n_edges > 0:
        third_in_proj = self._connect_tripartite_arm(
            pre_seg, third_seg, third_in_spec, syn_specs.get('third_in', {}), seed, comm)
        # third_out (astro -> post) rides the sic_connection / as_current path,
        # which requires comm='dense' (a graded current).
        third_out_proj = self._connect_tripartite_arm(
            third_seg, post_seg, third_out_spec, syn_specs.get('third_out', {}),
            seed, 'dense')
    else:
        third_in_proj = third_out_proj = None
    return primary_proj, third_in_proj, third_out_proj
def _connect_tripartite_arm(self, pre_seg, post_seg, spec, syn_spec, seed, comm):
    """Realize one arm of a tripartite connection from an already-sampled edge set.

    The arm's edges were drawn once by :meth:`tripartite_connect`. Wrapping them in
    :class:`~brainpy_state._nest_network.rules._ExplicitEdges` makes
    :meth:`_connect_pair` wire exactly those edges instead of drawing new ones.
    The arm is always wired with ``allow_autapses=True`` and ``allow_multapses=True``.

    Parameters
    ----------
    pre_seg, post_seg
        Source and target segments of this arm.
    spec
        The precomputed edge specification for this arm.
    syn_spec : dict
        May carry ``weight``, ``delay``, ``receptor_type`` and ``synapse``.
        Missing keys are forwarded as ``None``.
    seed : int or None
        Base seed forwarded to :meth:`_connect_pair`.
    comm : str
        Communication mode of the arm.

    Returns
    -------
    Whatever :meth:`_connect_pair` returns for this arm.
    """
    arm = {key: syn_spec.get(key)
           for key in ('weight', 'delay', 'receptor_type', 'synapse')}
    explicit_rule = _ExplicitEdges(spec)
    return self._connect_pair(
        pre_seg,
        post_seg,
        explicit_rule,
        arm['weight'],
        arm['delay'],
        allow_autapses=True,
        allow_multapses=True,
        seed=seed,
        comm=comm,
        receptor_type=arm['receptor_type'],
        synapse=arm['synapse'],
    )
[docs]
def get_connections(self, source=None, target=None, synapse=None):
    """List realized synapses across all projections (NEST ``GetConnections``).

    Builds a
    :class:`~brainpy_state._nest_network.connection_introspection.SynapseCollection`
    over every registered edge that passes the filters. Weights and delays can then
    be read or written without keeping each projection handle around. The
    collection is lazy: every
    :meth:`~brainpy_state._nest_network.connection_introspection.SynapseCollection.get`
    re-reads the live projections. A query created before :meth:`simulate`
    therefore still sees the plastic weights as they evolved during the run.

    Parameters
    ----------
    source : NodeView, optional
        Keep only edges whose presynaptic neuron belongs to this view (matched
        by population identity and population-local index). ``None`` keeps all.
    target : NodeView, optional
        Keep only edges whose postsynaptic neuron belongs to this view.
        ``None`` keeps all.
    synapse : str, optional
        Keep only projections of this synapse model: ``'static_synapse'`` for
        the static event path, otherwise the plastic spec's class name (e.g.
        ``'stdp_synapse'``). ``None`` keeps all.

    Returns
    -------
    SynapseCollection
        Lazy, filtered view of the matching edges; empty when nothing matches.

    See Also
    --------
    connect : Builds the projections enumerated here.

    Examples
    --------
    .. code-block:: python

        >>> import brainunit as u
        >>> from brainpy import state as bp
        >>> sim = bp.Simulator(dt=0.1 * u.ms)
        >>> exc = sim.create(bp.iaf_psc_exp, 4)
        >>> _ = sim.connect(exc, exc, rule=bp.all_to_all, weight=20. * u.pA,
        ...                 allow_autapses=False, comm='sparse')
        >>> conns = sim.get_connections(source=exc, target=exc)
        >>> len(conns)
        12
        >>> bool(u.math.allclose(conns.get('weight'), 20. * u.pA))
        True
    """
    # Deferred import: the introspection module is only needed when queried.
    from brainpy_state._nest_network import connection_introspection as _introspect
    registry = self._connections
    return _introspect.collect_connections(registry, source, target, synapse)
[docs]
def record_weight(self, proj):
    """Tap the per-step weights of a plastic projection for recording.

    ``proj`` must be the handle returned by ``connect(..., synapse=spec)``.
    Once :meth:`simulate` has run, the stacked ``(n_steps, n_edges)`` weight
    trajectory (edges in CSR sorted-by-pre order) is available through
    :meth:`SimulationResult.weight_trace`. This works like the analog-recorder
    tap, except that it reads the projection's ``weight`` State instead of a
    population recordable.

    Parameters
    ----------
    proj : EventPlasticProj
        Plastic projection handle to tap. Taps are keyed by ``id(proj)``.

    Returns
    -------
    EventPlasticProj
        ``proj`` itself, so calls can be chained.

    Raises
    ------
    TypeError
        If ``proj`` is not an ``EventPlasticProj``. Static projections have no
        plastic ``weight`` State to record.
    """
    if isinstance(proj, EventPlasticProj):
        self._weight_taps[id(proj)] = proj
        return proj
    raise TypeError(
        'record_weight requires a plastic projection handle from '
        f'connect(..., synapse=spec); got {type(proj).__name__}'
    )
@staticmethod
def _segment_weights(weight, n_seg: int):
"""One weight per pre-segment (Extension D2).
A ``weight`` vector whose length equals the number of pre-segments is
indexed per segment (``weight[i]`` -> segment ``i``); any other ``weight``
(scalar, or a per-edge vector for a single-segment ``all_to_all``) is
passed through unchanged to every segment.
"""
if n_seg > 1 and _is_len_vector(weight, n_seg):
return [_index_channel(weight, i, n_seg) for i in range(n_seg)]
return [weight] * n_seg
@staticmethod
def _derive_seed(base, ordinal: int) -> int:
"""Distinct, reproducible seed per realized projection/generator.
Fan-out (one ``connect`` to several post segments, or one generator to
several targets) must draw independently; ``jax.random`` derives element
``j`` from counter ``j`` regardless of array length, so sharing a base
seed would otherwise duplicate trains/connectivity across segments.
"""
b = 0 if base is None else int(base)
return (b * 1_000_003 + ordinal + 1) & 0x7FFFFFFF
@staticmethod
def _resolve_synapse(synapse, weight, delay):
"""Shallow-copy a plastic synapse spec, applying connect-level overrides."""
spec = copy.copy(synapse)
if weight is not None:
if isinstance(weight, u.Quantity):
spec.weight = weight
else:
# preserve the spec's own weight unit (pA for current synapses,
# mV for the delta-model clopath_synapse) instead of assuming pA
spec.weight = weight * u.get_unit(spec.weight)
spec.weight_unit = u.get_unit(spec.weight)
if delay is not None:
spec.delay = delay
return spec
@staticmethod
def _plastic_proj_cls(synapse):
    """Select the plastic-projection primitive that a synapse spec needs.

    A spec with a non-empty ``post_state_reads`` (per-edge gather of post State,
    e.g. ``clopath_synapse``) or a non-empty ``signal_reads`` (broadcast
    modulator, e.g. ``stdp_dopamine_synapse``) gets the voltage-coupled reader
    (primitive #2). Every other plastic spec uses the event-driven primitive #1.
    """
    reads_extra = (getattr(synapse, 'post_state_reads', ())
                   or getattr(synapse, 'signal_reads', ()))
    return VoltageCoupledPlasticProj if reads_extra else EventPlasticProj
@staticmethod
def _build_signal_sources(synapse, vt):
"""Resolve a spec's ``signal_reads`` names to ``{name: (vt_module, attr)}``.
Each broadcast signal name resolves to the same-named State attribute on the
bound :class:`~brainpy_state._nest_device.volume_transmitter` (``'n'`` -> ``vt.n``).
Returns ``None`` for a spec that reads no signal (clopath); raises if a
signal-reading spec is given no transmitter.
"""
names = tuple(getattr(synapse, 'signal_reads', ()) or ())
if not names:
return None
if vt is None:
raise ValueError(
f'{type(synapse).__name__} reads broadcast signal(s) {names} and '
'requires a bound volume_transmitter; pass '
'connect(..., vt=<volume_transmitter view>).'
)
vt_mod = vt.segments[0].population if isinstance(vt, NodeView) else vt
if not isinstance(vt_mod, _volume_transmitter):
raise ValueError('vt= must be a volume_transmitter view.')
return {name: (vt_mod, name) for name in names}
def _bind_dopa_source(self, pre_seg, vt):
    """Attach one presynaptic segment to ``vt`` as a dopaminergic source.

    A population segment is bound through the spike holder the Simulator already
    captured for it, using the segment's ``indices``. A deferred generator segment
    (e.g. ``spike_generator``) is realized here as a single-channel generator with
    its own ``_SpikeHolder(1)`` and bound on channel ``jnp.arange(1)``. That
    holder's one-step lag plays the role of NEST's
    ``spike_generator -> parrot -> volume_transmitter`` relay.

    Raises
    ------
    ValueError
        If a population segment has no captured spike holder.
    """
    if not isinstance(pre_seg, _GenSegment):
        captured = getattr(self, f'_holder_{id(pre_seg.population)}', None)
        if captured is None:
            raise ValueError(
                'the dopaminergic source for a volume_transmitter must be a '
                'spiking population or generator (no captured spikes found).'
            )
        vt.bind_dopa(_holder_reader(captured), pre_seg.indices)
        return

    # Generator source: draw an ordinal so a stochastic generator gets its own seed.
    ordinal = next(self._proj_counter)
    gen_spec = pre_seg.spec
    kwargs = dict(gen_spec.params)
    if 'rng_seed' in inspect.signature(gen_spec.model_cls.__init__).parameters:
        kwargs['rng_seed'] = self._derive_seed(kwargs.get('rng_seed'), ordinal)
    source = gen_spec.model_cls(1, **kwargs)
    setattr(self, f'_node_{id(source)}', source)
    relay = _SpikeHolder(1)
    setattr(self, f'_holder_{id(source)}', relay)
    vt.bind_dopa(_holder_reader(relay), jnp.arange(1))
def _resolve_stp_emission(self, pre_pop, post_pop, receptor_type, spike_holder,
comm='dense'):
"""Choose what a neuron->neuron static connection delivers per step.
Most neurons emit a binary spike (the spike holder -> ``weight * spike``).
A neuron that integrates a *presynaptic graded efficacy* declares
``_emission_attr`` (the State holding the per-step emission, 0 off spike)
and ``_emission_receptor`` (the one NEST receptor that carries it). Over
that receptor the connection delivers ``weight * emission`` instead of
``weight * spike``; the DEFAULT/``None`` connection -- and every other
receptor and every non-emitting model -- still delivers the binary spike
with its receptor unchanged.
A third, *receptorless* class emits **continuously**: a rate neuron
declares ``_emission_continuous=True`` and ``_emission_attr='rate'`` (or
``'phi_rate'``). Its static connection delivers ``weight * rate`` into the
post's DEFAULT delta channel every step (read back via
``sum_delta_inputs``) -- there is no NEST receptor, so this branch is
resolved first and ignores ``receptor_type``. The two receptor-gated
spike-offset emitters are:
* ``iaf_tum_2000`` -- released short-term-plasticity efficacy
``spike_offset = u * x`` over its TSODYKS receptor. The post is
single-port, so ``receptor_type`` collapses to ``None`` (a plain
delta-input delivery; the post integrates ``weight * efficacy`` as its
excitatory PSC) and the post must be the same model.
* ``iaf_bw_2001`` -- presynaptic NMDA gate increment
``spike_offset = k0 + k1 * s_NMDA_pre`` over its ``NMDA`` receptor. The
post is multi-port, so ``receptor_type`` is *preserved* and the graded
deposit is routed into the post's NMDA delta channel (via
``delta_label_for_receptor``); AMPA/GABA on other connections stay
binary and land in their own channels.
The graded emission rides the **dense** matmul (``x @ W``); the sparse
event path binarizes the presynaptic value, so ``comm='sparse'`` over an
emitting receptor is rejected.
Parameters
----------
pre_pop, post_pop : Neuron
The presynaptic and postsynaptic populations.
receptor_type : int or str or None
The connection's NEST receptor type as passed to :meth:`connect`.
spike_holder : _SpikeHolder
The presynaptic population's binary-spike holder (the default source).
comm : str, default 'dense'
The connection's communication mode. ``'sparse'`` is rejected for the
graded-emission path because it binarizes the presynaptic value.
Returns
-------
tuple
``(pre_spike_reader, effective_receptor_type)`` -- the callable the
:class:`EventProjection` reads each step, and the receptor type to
build it with (``None`` for the single-port efficacy collapse, the
original receptor for the multi-port routed path).
Raises
------
ValueError
If the graded emission would ride ``comm='sparse'`` (binarized); if a
single-port efficacy connection targets a post of a different model
(the efficacy is delivered as that model's PSC, so the post must be the
same model); or if the emission holder is absent.
"""
emit_attr = getattr(pre_pop, '_emission_attr', None)
# Continuous graded emitter (rate neurons): the connection delivers
# ``weight * emission`` (e.g. ``weight * rate``) into the post's DEFAULT
# delta channel every step -- a *receptorless* continuous coupling that
# rides the ordinary delta seam (read back via ``sum_delta_inputs``),
# not a NEST receptor. It is resolved first and independently of
# ``receptor_type`` (rate connections carry none), so it cannot collide
# with the receptor-gated STP/NMDA branches below.
if emit_attr is not None and getattr(pre_pop, '_emission_continuous', False):
emit_holder = getattr(self, f'_emit_holder_{id(pre_pop)}', None)
if emit_holder is None:
raise ValueError(
f'{type(pre_pop).__name__} declares _emission_continuous but no '
'emission holder was allocated by create().')
if comm == 'sparse':
raise ValueError(
f'a continuous rate connection from {type(pre_pop).__name__} '
f'delivers weight * {emit_attr} and must ride the dense matmul; '
'comm="sparse" binarizes the presynaptic value. Use comm="dense".')
return _holder_reader(emit_holder), None
emit_receptor = getattr(pre_pop, '_emission_receptor', None)
if emit_receptor is None:
emit_receptor = getattr(pre_pop, 'RECEPTOR_TYPES', {}).get('TSODYKS')
if emit_attr is None or receptor_type is None or receptor_type != emit_receptor:
return _holder_reader(spike_holder), receptor_type
emit_holder = getattr(self, f'_emit_holder_{id(pre_pop)}', None)
if emit_holder is None:
raise ValueError(
f'{type(pre_pop).__name__} declares _emission_attr={emit_attr!r} but '
'no emission holder was allocated by create().')
if comm == 'sparse':
raise ValueError(
f'a graded-emission connection from {type(pre_pop).__name__} over '
f'receptor_type={receptor_type} delivers weight * {emit_attr} and '
'must ride the dense matmul; comm="sparse" binarizes the '
'presynaptic value. Use comm="dense".')
# Multi-port post routes by receptor into a named delta channel: keep the
# receptor so EventProjection deposits the graded value into the right
# channel (e.g. iaf_bw_2001 NMDA).
if hasattr(post_pop, 'delta_label_for_receptor') or hasattr(post_pop, 'n_receptors'):
return _holder_reader(emit_holder), receptor_type
# Single-port post (iaf_tum_2000): the efficacy IS the PSC; collapse the
# receptor to None and require the same integrate-and-fire-with-STP model.
if not isinstance(post_pop, type(pre_pop)):
raise ValueError(
f'a graded-efficacy (receptor_type={receptor_type}) connection from '
f'{type(pre_pop).__name__} delivers presynaptic efficacy as the post '
f'PSC and requires a {type(pre_pop).__name__} post; got '
f'{type(post_pop).__name__}.')
return _holder_reader(emit_holder), None
@staticmethod
def _is_continuous_rate(pop):
"""Whether ``pop`` couples through seam-(H) continuous graded emission."""
return (getattr(pop, '_emission_continuous', False)
and getattr(pop, '_emission_attr', None) is not None)
def _check_rate_phi_homogeneity(self, pre_pop, post_pop):
r"""Guard a rate->rate connection against a φ / summation-mode mismatch.
A continuous rate connection emits the *sender's* per-step value and the
*receiver* integrates it through its own coupling path. With
``linear_summation=True`` the sender emits the raw ``rate`` and the receiver
applies its own φ to the summed input, so the two φ may differ freely; with
``linear_summation=False`` the sender emits ``φ_pre(rate)`` and the receiver
adds it *as-is* (it would otherwise have applied ``φ_post``), which is exact
only for a homogeneous φ. The summation modes must also agree --- a raw
``rate`` integrated where ``φ(rate)`` was expected (or vice versa) is silently
wrong. Both conditions are enforced here and raise at ``connect()`` time
(the user's "guard to homogeneous-φ" decision, spec §3.3).
A rate -> non-rate connection (a rate neuron driving a spiking/current
target) carries no φ contract and is left unchecked.
"""
if not self._is_continuous_rate(post_pop):
return
pre_ls = bool(getattr(pre_pop, 'linear_summation', True))
post_ls = bool(getattr(post_pop, 'linear_summation', True))
if pre_ls != post_ls:
raise ValueError(
f'rate connection {type(pre_pop).__name__} -> {type(post_pop).__name__}: '
f'linear_summation must match on both sides (pre={pre_ls}, post={post_ls}). '
'The receiver integrates the raw rate or φ(rate) the sender emits, and the '
'two summation modes are not interchangeable.')
if not pre_ls:
sig_pre = getattr(pre_pop, '_phi_signature', None)
sig_post = getattr(post_pop, '_phi_signature', None)
if sig_pre != sig_post:
raise ValueError(
f'rate connection {type(pre_pop).__name__} -> {type(post_pop).__name__} '
'with linear_summation=False requires a homogeneous input nonlinearity φ: '
'the sender emits φ(rate) and the receiver integrates it directly, which is '
f'exact only when both φ agree (got pre φ={sig_pre!r}, post φ={sig_post!r}). '
'Use linear_summation=True so the receiver applies its own φ, or match the '
'models and their gain parameters.')
@staticmethod
def _sign_split_weight(weight):
"""Split a connection weight into ``(max(W,0), min(W,0))`` by sign.
The two halves sum back to ``W``; they feed the ``mult_coupling`` excitatory
and inhibitory channels (spec §3.2). The weight must be concrete (a scalar or
array) --- a random initializer cannot be sign-split, so ``mult_coupling``
rejects callable / absent weights.
"""
if weight is None or callable(weight):
raise ValueError(
'mult_coupling rate connections sign-split the weight into the '
'rate_ex/rate_in channels and require a concrete weight (scalar or '
'array); a callable initializer or an absent weight cannot be split.')
return jnp.maximum(weight, 0.0), jnp.minimum(weight, 0.0)
def _build_rate_dual_channel(self, pre_spike, pre_pop, pre_seg, post_pop, post_seg,
                             rule, weight, delay, comm, seed, ordinal):
    r"""Realize a ``mult_coupling`` rate connection as two sign-split projections.

    NEST's multiplicative rate coupling splits the presynaptic drive by weight
    sign: :math:`\sum_\mathrm{ex} w\,r` over the ``w>0`` edges and
    :math:`\sum_\mathrm{in} w\,r` over the ``w<0`` edges. The receiver scales
    these by ``H_ex`` / ``H_in`` (spec §3.2). Here the split becomes two ordinary
    :class:`EventProjection` objects reading the same ``pre_spike`` emission:

    * ``max(W, 0)`` deposited into the post's ``'rate_ex'`` delta channel;
    * ``min(W, 0)`` deposited into ``'rate_in'``.

    The receiver reads them with ``sum_delta_inputs(label='rate_ex'|'rate_in')``.
    Both projections use one derived seed, so a random rule realizes the same
    edge set in both halves. The first projection is stored under ``ordinal``;
    the second draws a fresh ordinal from ``self._proj_counter``. Both are
    registered in ``self._connections`` as ``'static_synapse'``.

    NOTE(review): this helper takes no ``allow_autapses`` / ``allow_multapses``
    and passes none to ``EventProjection``, so the projection's defaults apply
    whatever ``connect()`` was given. Confirm that this is intended.

    Returns
    -------
    list of EventProjection
        ``[rate_ex projection, rate_in projection]``.
    """
    w_ex, w_in = self._sign_split_weight(weight)
    conn_seed = self._derive_seed(seed, ordinal)
    second_ordinal = next(self._proj_counter)
    channels = (
        ('rate_ex', w_ex, ordinal),
        ('rate_in', w_in, second_ordinal),
    )
    recurrent = pre_pop is post_pop
    built = []
    for label, w_part, proj_ordinal in channels:
        half = EventProjection(
            pre_spike=pre_spike,
            n_pre_pop=_flat_size(pre_pop),
            pre_local_idx=pre_seg.indices,
            post=post_pop,
            post_local_idx=post_seg.indices,
            rule=rule,
            weight=w_part,
            delay=delay,
            comm=comm,
            channel_label=label,
            pre_is_post=recurrent,
            seed=conn_seed,
        )
        self._connections.append((pre_pop, post_pop, 'static_synapse', half))
        setattr(self, f'_proj_{proj_ordinal}', half)
        built.append(half)
    return built
def _build_siegert_diffusion(self, pre_seg, post_seg, rule, weight, delay,
                             synapse, comm, seed, ordinal):
    r"""Realize a ``diffusion_connection`` as two labelled seam-(H) deposits (goal 15c).

    NEST sends one ``DiffusionConnectionEvent`` carrying the presynaptic rate
    :math:`r_j`. The target ``siegert_neuron`` accumulates a drift
    :math:`g_\mu r_j \to \mu` and a variance :math:`g_\sigma r_j \to \sigma^2`
    from it. Here that event becomes two :class:`EventProjection` objects reading
    the same emission reader (resolved via :meth:`_resolve_stp_emission`):

    * ``drift_factor`` into the post's ``'diffusion_mu'`` delta channel;
    * ``diffusion_factor`` into ``'diffusion_sigma2'``.

    The labels are distinct, so :math:`\mu` and :math:`\sigma^2` never mix (an
    unlabelled read would sum both). Both halves share one derived connectivity
    seed. The first is stored under ``ordinal``; the second draws a fresh
    ordinal from ``self._proj_counter``.

    NOTE(review): no ``allow_autapses`` / ``allow_multapses`` are accepted or
    forwarded here, so ``EventProjection``'s defaults apply. Confirm that this
    is intended.

    Raises
    ------
    ValueError
        If the source is a generator or not a continuous-rate emitter, if
        ``comm='sparse'``, or if a ``delay`` or ``weight`` is given (the one-step
        lag of the seam holder stands in for NEST ``min_delay=1``).

    Returns
    -------
    list of EventProjection
        ``[mu projection, sigma^2 projection]``.
    """
    if isinstance(pre_seg, _GenSegment):
        raise ValueError(
            'a diffusion_connection source must be a continuous-rate neuron '
            '(e.g. siegert_neuron), not a generator.')
    pre_pop, post_pop = pre_seg.population, post_seg.population
    if not self._is_continuous_rate(pre_pop):
        raise ValueError(
            f'a diffusion_connection requires a continuous-rate source '
            f'(_emission_continuous + _emission_attr); {type(pre_pop).__name__} '
            'does not emit a rate.')
    if comm == 'sparse':
        raise ValueError(
            'a diffusion_connection delivers drift_factor * rate / diffusion_factor '
            '* rate over the dense matmul; comm="sparse" binarizes the rate. '
            'Use comm="dense".')
    if delay is not None:
        raise ValueError('diffusion_connection has no delay.')
    if weight is not None:
        raise ValueError(
            'diffusion_connection has no weight; use drift_factor / diffusion_factor.')

    spike_holder = getattr(self, f'_holder_{id(pre_pop)}')
    rate_reader, _unused_receptor = self._resolve_stp_emission(
        pre_pop, post_pop, None, spike_holder, comm)
    arms = (
        ('diffusion_mu', float(synapse.drift_factor)),
        ('diffusion_sigma2', float(synapse.diffusion_factor)),
    )
    conn_seed = self._derive_seed(seed, ordinal)
    ordinals = (ordinal, next(self._proj_counter))
    recurrent = pre_pop is post_pop
    built = []
    for (label, factor), proj_ordinal in zip(arms, ordinals):
        arm = EventProjection(
            pre_spike=rate_reader,
            n_pre_pop=_flat_size(pre_pop),
            pre_local_idx=pre_seg.indices,
            post=post_pop,
            post_local_idx=post_seg.indices,
            rule=rule,
            weight=factor,
            delay=None,
            comm=comm,
            channel_label=label,
            pre_is_post=recurrent,
            seed=conn_seed,
        )
        self._connections.append((pre_pop, post_pop, 'diffusion_connection', arm))
        setattr(self, f'_proj_{proj_ordinal}', arm)
        built.append(arm)
    return built
# -- gap junctions (goal 15b) ------------------------------------------
@staticmethod
def _is_gap_synapse(synapse):
"""Whether a ``connect(synapse=...)`` argument selects gap-junction coupling.
``gap_junction`` is the only synapse model that requires symmetric
connectivity (``REQUIRES_SYMMETRIC=True``); the Simulator detects it by that
marker and builds its own explicit-lag difference-deposit coupler, never
instantiating the reference model or its waveform-relaxation machinery.
"""
return bool(getattr(synapse, 'REQUIRES_SYMMETRIC', False))
@staticmethod
def _gap_conductance(weight):
"""Resolve a gap connection weight to a scalar conductance (nS mantissa).
The gap weight is the coupling conductance ``g`` (nS); the difference deposit
``g * (V_pre - V_post)`` needs one concrete scalar shared by every gap edge.
A unit ``Quantity`` is read in nS; a bare scalar is taken as nS. Per-edge /
random / callable gap weights are out of scope (like sparse gaps).
"""
if weight is None or callable(weight):
raise ValueError(
'gap_junction coupling requires a concrete scalar conductance weight '
'(nS); a callable initializer or an absent weight cannot build the '
'gap matrix G.')
w = weight
if isinstance(w, u.Quantity):
w = u.get_mantissa(w / u.nS) # nS-valued; raises if not a conductance
w = jnp.asarray(w)
if w.ndim != 0:
raise ValueError(
'gap_junction coupling uses a scalar conductance g (one value for all '
f'gap edges); got a non-scalar weight of shape {tuple(w.shape)}. '
'Per-edge gap weights are out of scope.')
return float(w)
@staticmethod
def _gap_current(G, D, v_pre, v_post):
r"""The explicit-lag gap difference current ``I_gap = G @ V_pre - D * V_post``.
``G`` (nS) is the dense symmetric gap-conductance matrix, ``D = rowsum(G)``
the per-neuron total gap conductance, and ``v_pre`` / ``v_post`` (mV) the
one-step-lagged pre / post membrane voltages. For a recurrent gap the two
voltage vectors are identical, giving ``(G - diag(D)) @ V`` --- the negated
graph Laplacian of the gap graph applied to ``V``. ``nS * mV = pA``, so the
result is the gap current in pA; at equal voltages it is exactly zero.
"""
return G @ v_pre - D * v_post
def _build_gap_coupling(self, pre_seg, post_seg, rule, weight, delay, comm,
                        allow_autapses, allow_multapses, seed, ordinal):
    r"""Register a recurrent gap-junction coupler (the difference deposit, 15b).

    The rule is sampled between ``pre_seg`` and ``post_seg`` in segment-local
    coordinates (``n_pre x n_post``, as in :meth:`tripartite_connect`). Each
    endpoint is then mapped to a population index through its own segment's
    ``indices``. From these edges the method builds the dense **symmetric**
    gap-conductance matrix ``G`` (nS) over the whole population (both directions
    materialized, hollow diagonal). It appends ``(G, D, V reader, post_pop,
    label)`` to ``self._gap_couplers``, where ``D = rowsum(G)``. Each step,
    :meth:`update` uses this entry to deposit ``I_gap = G @ V[n-1] - D * V[n-1]``
    into the post's current channel (``add_current_input`` ->
    ``sum_current_inputs``), under the one-step pipeline lag (the WFR seed,
    cluster 15a). There is no waveform relaxation, and the reference
    ``gap_junction`` WFR class is never used.

    Raises
    ------
    ValueError
        If pre and post are different populations (gap coupling is recurrent),
        a delay is given (gap is instantaneous), ``comm`` is not ``'dense'`` (the
        gap ``G@V`` is a dense matmul), the population does not emit ``V``
        (``_emission_attr='V'``), or the conductance weight is not a concrete
        scalar.
    """
    pre_pop = pre_seg.population
    post_pop = post_seg.population
    if pre_pop is not post_pop:
        raise ValueError(
            'gap_junction coupling is recurrent (electrical coupling within one '
            'population); connect a population to itself, got '
            f'{type(pre_pop).__name__} -> {type(post_pop).__name__}.')
    if delay is not None:
        raise ValueError(
            'gap_junction connections carry no delay (instantaneous electrical '
            'coupling); remove the delay= argument.')
    if comm != 'dense':
        raise ValueError(
            "gap_junction coupling rides comm='dense' (the gap G @ V matmul); "
            "comm='sparse' binarizes the presynaptic voltage. Use comm='dense'.")
    emit_holder = getattr(self, f'_emit_holder_{id(post_pop)}', None)
    if emit_holder is None:
        raise ValueError(
            f'{type(post_pop).__name__} does not emit its membrane voltage for '
            "gap coupling (it must declare _emission_attr='V'); no emission "
            'holder was allocated by create().')
    g = self._gap_conductance(weight)
    n_pop = _flat_size(post_pop)
    pre_idx = jnp.asarray(pre_seg.indices)
    post_idx = jnp.asarray(post_seg.indices)
    # The two segments may be different slices of the same population (e.g.
    # pop[:5] -> pop[5:]). Previously both sides were sampled and mapped with
    # post_seg only, which silently ignored the pre slice. Sample n_pre x n_post
    # and map each side through its own indices instead. The rule's
    # local-index autapse handling (presumably what ``pre_is_post`` controls)
    # only makes sense when both segments select the same neurons in the same
    # order. True self-edges are removed by the hollow diagonal below anyway.
    same_neurons = bool(jnp.array_equal(pre_idx, post_idx))
    spec = rule.sample(int(pre_idx.shape[0]), int(post_idx.shape[0]),
                       key=jax.random.key(self._derive_seed(seed, ordinal)),
                       pre_is_post=same_neurons, allow_autapses=allow_autapses,
                       allow_multapses=allow_multapses)
    # Realized edges -> a hollow, symmetric boolean adjacency over the full
    # population (both directions materialized: NEST's make_symmetric / the
    # all_to_all bidirectionality), then scaled by the scalar conductance.
    gpre = pre_idx[spec.pre_idx]
    gpost = post_idx[spec.post_idx]
    A = jnp.zeros((n_pop, n_pop), dtype=bool).at[gpre, gpost].set(True)
    A = (A | A.T) & ~jnp.eye(n_pop, dtype=bool)
    G = A.astype(brainstate.environ.dftype()) * g  # (n_pop, n_pop) nS
    D = G.sum(axis=1)  # rowsum (n_pop,) nS
    self._gap_couplers.append((G, D, _holder_reader(emit_holder), post_pop,
                               f'gap_inj_{ordinal}'))
def _connect_pair(self, pre_seg, post_seg, rule, weight, delay,
                  allow_autapses, allow_multapses, seed, comm='dense',
                  receptor_type=None, synapse=None, vt=None):
    """Realize and register one (pre segment, post segment) connection.

    A projection ordinal is drawn from ``self._proj_counter`` first. A spatial
    rule is then bound to this pair's coordinates. Dispatch order (first match
    wins):

    1. ``synapse`` flagged ``_IS_DIFFUSION`` -> :meth:`_build_siegert_diffusion`
       (returns its list of two projections).
    2. ``synapse`` flagged ``REQUIRES_SYMMETRIC`` (gap junction) ->
       :meth:`_build_gap_coupling` (returns ``None``).
    3. ``synapse`` is a ``sic_connection`` -> :meth:`_connect_sic`.
    4. Generator pre segment (``_GenSegment``): a current generator is wired as a
       current injector (returns ``None``). Otherwise a generator instance of
       size ``len(post_seg.indices)`` is created with its own ``_SpikeHolder``
       and wired with ``one_to_one`` (plastic ``proj_cls`` when ``synapse`` is
       given, else a static :class:`EventProjection`).
    5. Population pre segment: a population with ``_injects_current`` becomes a
       current injector (returns ``None``). Otherwise a plastic projection (when
       ``synapse`` is given) or a static :class:`EventProjection` is built. For
       continuous-rate sources, φ homogeneity is checked, and the connection is
       split into ``rate_ex`` / ``rate_in`` channels when the post sets
       ``_use_mult_coupling`` (returns that list).

    Paths 4 and 5 register the single projection in ``self._connections`` and
    as ``self._proj_<ordinal>``.

    Parameters
    ----------
    pre_seg, post_seg
        Source segment (population segment or ``_GenSegment``) and target segment.
    rule
        Connectivity rule; spatial rules are bound to this pair's coordinates.
    weight, delay
        Connect-level weight / delay, forwarded to the chosen builder.
    allow_autapses, allow_multapses : bool
        Forwarded to the gap, SIC and population static/plastic paths.
    seed : int or None
        Base seed. Population paths use ``self._derive_seed(seed, ordinal)``;
        the generator projection paths pass ``seed`` unchanged.
    comm : str, default 'dense'
        Communication mode.
    receptor_type : int or str or None
        NEST receptor type.
    synapse : optional
        Synapse spec selecting the diffusion / gap / SIC / plastic paths.
    vt : optional
        Volume-transmitter view for signal-reading plastic specs.

    Returns
    -------
    EventProjection, plastic projection, list of projections, or None
        See the dispatch order above.

    Raises
    ------
    ValueError
        Propagated from the builders, e.g. :meth:`_build_signal_sources` when a
        signal-reading spec has no ``vt``.
    """
    ordinal = next(self._proj_counter)
    post_pop = post_seg.population
    # Spatial rule: bind this connect's sliced pre/post coordinates onto a pure
    # rule clone before any dispatch, so every downstream path (static, plastic,
    # diffusion, gap, sic) samples an identical, coordinate-bound rule.
    if getattr(rule, '_is_spatial', False):
        rule = self._bind_spatial_coords(rule, pre_seg, post_seg)
    # Diffusion coupling (siegert mean-field): a diffusion_connection is not a
    # plastic projection -- the Simulator routes it as a dual-channel seam deposit
    # (drift -> mu, diffusion -> sigma^2). Dispatch before the plastic path.
    if synapse is not None and getattr(synapse, '_IS_DIFFUSION', False):
        return self._build_siegert_diffusion(
            pre_seg, post_seg, rule, weight, delay, synapse, comm, seed, ordinal)
    # Gap-junction coupling (goal 15b) is dispatched on the gap synapse model
    # before the plastic/static paths: it builds its own explicit-lag difference-
    # deposit coupler (no EventProjection, no rule kernel, no WFR).
    if synapse is not None and self._is_gap_synapse(synapse):
        self._build_gap_coupling(pre_seg, post_seg, rule, weight, delay, comm,
                                 allow_autapses, allow_multapses, seed, ordinal)
        return None
    # One-way astrocyte->neuron slow-inward-current edge (NEST sic_connection):
    # dispatched before the plastic path, it builds an ``as_current``
    # EventProjection reading the astrocyte's emission holder.
    if isinstance(synapse, _sic_connection):
        return self._connect_sic(pre_seg, post_seg, rule, weight, delay,
                                 allow_autapses, allow_multapses, seed, ordinal,
                                 synapse, comm)
    # Post-spike reader for plastic rules; None when the post has no spike holder.
    post_holder = getattr(self, f'_holder_{id(post_pop)}', None)
    post_reader = _holder_reader(post_holder) if post_holder is not None else None
    # voltage-coupled reader (#2) also carries broadcast signal sources (the VT n
    # for stdp_dopamine_synapse); primitive #1 and a vt-less spec build neither.
    proj_cls = self._plastic_proj_cls(synapse) if synapse is not None else None
    plastic_extra = {}
    if proj_cls is VoltageCoupledPlasticProj:
        plastic_extra['signal_sources'] = self._build_signal_sources(synapse, vt)
    if isinstance(pre_seg, _GenSegment):
        if _is_current_generator(pre_seg.spec.model_cls):
            self._wire_current_injector(pre_seg, post_seg, weight, ordinal,
                                        receptor_type)
            return
        # One generator channel per post neuron in the segment, wired one-to-one;
        # a generator accepting ``rng_seed`` gets a per-ordinal derived seed.
        n = int(post_seg.indices.shape[0])
        params = dict(pre_seg.spec.params)
        if 'rng_seed' in inspect.signature(pre_seg.spec.model_cls.__init__).parameters:
            params['rng_seed'] = self._derive_seed(params.get('rng_seed'), ordinal)
        gen = pre_seg.spec.model_cls(n, **params)
        source_pop = gen
        setattr(self, f'_node_{id(gen)}', gen)
        holder = _SpikeHolder(n)
        setattr(self, f'_holder_{id(gen)}', holder)
        if synapse is not None:
            proj = proj_cls(
                pre_spike=_holder_reader(holder), n_pre_pop=n,
                pre_local_idx=jnp.arange(n), post=post_pop,
                post_local_idx=post_seg.indices, n_post_pop=_flat_size(post_pop),
                post_spike=post_reader, rule=self._resolve_synapse(synapse, weight, delay),
                conn=one_to_one, seed=seed, receptor_type=receptor_type, **plastic_extra)
        else:
            # NOTE(review): ``comm`` is not forwarded on this generator path, so
            # EventProjection's default applies -- confirm this is intended.
            proj = EventProjection(
                pre_spike=_holder_reader(holder), n_pre_pop=n,
                pre_local_idx=jnp.arange(n), post=post_pop,
                post_local_idx=post_seg.indices, rule=one_to_one, weight=weight,
                delay=delay, receptor_type=receptor_type, seed=seed)
    else:
        pre_pop = pre_seg.population
        if getattr(pre_pop, '_injects_current', False):
            # A non-deferred host_current_drive: register it as a current
            # injector reading its host-set schedule (same ring-buffer path as a
            # dc_generator), keeping the host's stable handle so the schedule can
            # be rewritten between cont() chunks. No holder, no projection.
            comp, ncomp = self._resolve_current_compartment(post_pop, receptor_type)
            self._current_injectors.append(
                (pre_pop, post_pop, post_seg.indices, weight,
                 f'cur_inj_{ordinal}', comp, ncomp))
            return None
        source_pop = pre_pop
        holder = getattr(self, f'_holder_{id(pre_pop)}')
        if synapse is not None:
            proj = proj_cls(
                pre_spike=_holder_reader(holder), n_pre_pop=_flat_size(pre_pop),
                pre_local_idx=pre_seg.indices, post=post_pop,
                post_local_idx=post_seg.indices, n_post_pop=_flat_size(post_pop),
                post_spike=post_reader, rule=self._resolve_synapse(synapse, weight, delay),
                conn=rule, pre_is_post=(pre_pop is post_pop),
                allow_autapses=allow_autapses, allow_multapses=allow_multapses,
                seed=self._derive_seed(seed, ordinal), receptor_type=receptor_type,
                **plastic_extra)
        else:
            # Static path: choose spike vs graded emission and the effective receptor.
            pre_spike, eff_receptor = self._resolve_stp_emission(
                pre_pop, post_pop, receptor_type, holder, comm)
            # Continuous rate coupling: enforce the φ / summation-mode contract,
            # and split into the labelled rate_ex/rate_in channels when the
            # receiver uses multiplicative coupling (spec §3.2-3.3).
            if self._is_continuous_rate(pre_pop):
                self._check_rate_phi_homogeneity(pre_pop, post_pop)
                if getattr(post_pop, '_use_mult_coupling', False):
                    return self._build_rate_dual_channel(
                        pre_spike, pre_pop, pre_seg, post_pop, post_seg,
                        rule, weight, delay, comm, seed, ordinal)
            proj = EventProjection(
                pre_spike=pre_spike, n_pre_pop=_flat_size(pre_pop),
                pre_local_idx=pre_seg.indices, post=post_pop,
                post_local_idx=post_seg.indices, rule=rule, weight=weight,
                delay=delay, comm=comm, receptor_type=eff_receptor,
                pre_is_post=(pre_pop is post_pop),
                allow_autapses=allow_autapses, allow_multapses=allow_multapses,
                seed=self._derive_seed(seed, ordinal))
    # Register for get_connections (NEST GetConnections analogue). The static
    # event path is 'static_synapse'; a plastic path takes its spec class name.
    # NOTE(review): the label keys on isinstance(proj, EventPlasticProj); if
    # VoltageCoupledPlasticProj is not a subclass of it, those projections would
    # be registered as 'static_synapse' -- confirm the class hierarchy.
    model_name = (type(proj.rule).__name__ if isinstance(proj, EventPlasticProj)
                  else 'static_synapse')
    self._connections.append((source_pop, post_pop, model_name, proj))
    setattr(self, f'_proj_{ordinal}', proj)
    return proj
def _connect_sic(self, pre_seg, post_seg, rule, weight, delay,
                 allow_autapses, allow_multapses, seed, ordinal, synapse, comm):
    r"""Wire a one-way astrocyte->neuron slow-inward-current (SIC) connection.

    The astrocyte (``astrocyte_lr_1994``) is the sender: phase 2 captures its
    per-step graded ``SIC`` into the emission holder. The ``sic_connection``
    deposits ``weight·SIC`` into the receiver's (``aeif_cond_alpha_astro``)
    labelled ``'I_SIC'`` *current* channel through an ``as_current``
    :class:`EventProjection`. The edge is one-way (a SICEvent, no back-channel),
    and the NEST sender/receiver contract is enforced by
    :meth:`sic_connection.check_connection`.

    Timing follows the deleted host queue's ``base_offset = delay_steps - 1``:
    ``delay_steps=1`` (the NEST default / minimum delay) rides the substrate's
    intrinsic pipeline latency with no extra :class:`InputDelay`; a larger
    ``delay_steps`` adds ``(delay_steps - 1)`` steps. The small residual offset
    vs NEST is absorbed by ``align_steps`` in parity.

    Parameters
    ----------
    pre_seg, post_seg : _Segment
        The presynaptic astrocyte and postsynaptic neuron segments.
    rule : ConnRule
        Connectivity rule (fan-out from astrocytes to neurons).
    weight : ArrayLike or None
        Connect-level weight override (unitless); ``None`` uses the spec's
        ``synapse.weight``.
    delay : Quantity or None
        Connect-level delay override; ``None`` uses ``delay_steps``.
    allow_autapses, allow_multapses : bool
        Passed through to the connectivity sampler.
    seed : int or None
        Base seed for the connectivity draw.
    ordinal : int
        Projection ordinal (registration key).
    synapse : sic_connection
        The connection spec carrying ``weight`` / ``delay_steps`` and the
        sender/receiver enforcement.
    comm : str
        Communication mode; ``'sparse'`` is rejected (a graded current must
        ride the dense matmul).

    Returns
    -------
    EventProjection
        The ``as_current`` projection delivering ``weight·SIC`` into ``'I_SIC'``.

    Raises
    ------
    ValueError
        If the source is a deferred generator (checked first so it gets its
        dedicated message), if the sender/receiver models are not the
        supported SIC pair, if ``comm='sparse'``, or if no emission holder was
        allocated for the astrocyte.
    """
    # Reject a deferred-generator source BEFORE the model-pair contract check.
    # A generator's model can never be the SIC sender, so check_connection
    # would otherwise raise first with its generic pair message, and this
    # dedicated error would be unreachable.
    if isinstance(pre_seg, _GenSegment):
        raise ValueError(
            'a sic_connection source must be an astrocyte_lr_1994 population, '
            'not a deferred generator.')
    pre_pop = pre_seg.population
    post_pop = post_seg.population
    # Enforce the NEST sender/receiver contract (astrocyte_lr_1994 sends a
    # SICEvent; aeif_cond_alpha_astro handles it). Raises ValueError otherwise.
    synapse.check_connection(type(pre_pop).__name__, type(post_pop).__name__)
    if comm == 'sparse':
        raise ValueError(
            "a sic_connection delivers a graded current and must ride "
            "comm='dense'; comm='sparse' binarises the presynaptic value.")
    emit_holder = getattr(self, f'_emit_holder_{id(pre_pop)}', None)
    if emit_holder is None:
        raise ValueError(
            f'{type(pre_pop).__name__} declares _emission_attr but no emission '
            'holder was allocated by create().')
    # Resolve weight (connect override else the spec weight) and delay (connect
    # override else ``(delay_steps - 1)`` extra steps). The SIC value and weight
    # are unitless; the neuron attaches pA on write-back into I_sic.
    eff_weight = weight if weight is not None else synapse.weight
    if delay is not None:
        eff_delay = delay
    elif int(getattr(synapse, 'delay_steps', 1)) > 1:
        eff_delay = (int(synapse.delay_steps) - 1) * brainstate.environ.get_dt()
    else:
        # delay_steps == 1: no extra delay beyond the pipeline's one-step lag.
        eff_delay = None
    label = getattr(pre_pop, '_emission_current_label', 'I_SIC')
    proj = EventProjection(
        pre_spike=_holder_reader(emit_holder), n_pre_pop=_flat_size(pre_pop),
        pre_local_idx=pre_seg.indices, post=post_pop,
        post_local_idx=post_seg.indices, rule=rule, weight=eff_weight,
        delay=eff_delay, comm='dense', as_current=True, channel_label=label,
        pre_is_post=False, allow_autapses=allow_autapses,
        allow_multapses=allow_multapses, seed=self._derive_seed(seed, ordinal))
    self._connections.append((pre_pop, post_pop, 'sic_connection', proj))
    setattr(self, f'_proj_{ordinal}', proj)
    return proj
def _wire_current_injector(self, pre_seg, post_seg, weight, ordinal, receptor_type=None):
"""Realize a current generator at the post size and register it as an injector.
Current generators (``dc_generator`` / ``step_current_generator`` /
``noise_generator`` / ``ac_generator``) inject a *current* (pA) into the
post population's current-input seam each step --- the neuron's own
``sum_current_inputs`` ring buffer (one-step delay, matching NEST's
current ring buffer) --- rather than the delta-event path used by spike
generators. No spike holder and no projection are built.
The device is realized at ``n = n_post`` so a single generator fans out
one independent channel per target (``noise_generator`` draws ``randn(n)``
each step); a per-connect derived seed keeps separate connects
independent.
``receptor_type`` selects a target compartment for a multi-compartment post
(``iaf_cond_alpha_mc``: 7=soma, 8=proximal, 9=distal); the resolved
``(comp, ncomp)`` makes the injection land in one compartment instead of
broadcasting across all of them. It is ``(None, None)`` for an ordinary
single-compartment post.
"""
post_pop = post_seg.population
comp, ncomp = self._resolve_current_compartment(post_pop, receptor_type)
n = int(post_seg.indices.shape[0])
params = dict(pre_seg.spec.params)
if 'seed' in inspect.signature(pre_seg.spec.model_cls.__init__).parameters:
params['seed'] = self._derive_seed(params.get('seed'), ordinal)
device = pre_seg.spec.model_cls(n, **params)
setattr(self, f'_node_{id(device)}', device)
key = f'cur_inj_{ordinal}'
self._current_injectors.append(
(device, post_pop, post_seg.indices, weight, key, comp, ncomp))
@staticmethod
def _resolve_current_compartment(post_pop, receptor_type):
"""Resolve a compartmental post's current ``receptor_type`` to ``(comp, ncomp)``.
A multi-compartment post (``iaf_cond_alpha_mc``) exposes
``current_compartment_for_receptor`` and ``NCOMP``; its current generators MUST
name a target compartment via ``receptor_type`` (7=soma, 8=proximal, 9=distal)
so the injected current lands in one compartment instead of broadcasting across
all of them. A non-compartmental post takes no current ``receptor_type``.
Parameters
----------
post_pop : Dynamics
The postsynaptic population.
receptor_type : int or None
The connection's NEST receptor type.
Returns
-------
tuple
``(comp, ncomp)`` --- the target compartment index and compartment count for
a compartmental post, or ``(None, None)`` for an ordinary single-compartment
post.
Raises
------
ValueError
If a compartmental post is given no ``receptor_type``, a non-compartmental
post is given one, or the receptor type is not a valid current receptor.
"""
resolver = getattr(post_pop, 'current_compartment_for_receptor', None)
if resolver is None:
if receptor_type is not None:
raise ValueError(
f'receptor_type={receptor_type} was given for current injection into '
f'{type(post_pop).__name__}, which has no current receptors.')
return None, None
if receptor_type is None:
raise ValueError(
f'{type(post_pop).__name__} requires a current receptor_type for current '
f'input (soma_curr=7, proximal_curr=8, distal_curr=9); none was given.')
comp = resolver(receptor_type) # raises for spike / out-of-range
return comp, int(post_pop.NCOMP)
@staticmethod
def _scatter_current(cur, pop, idx):
    """Expand a device's per-target ``(n,)`` current into the population's ``(n_pop,)`` frame.

    ``cur`` carries one current entry per target, in device order. Those
    entries are added at positions ``idx`` of an all-zero vector spanning the
    whole population, so neurons outside the connection get zero current.
    This handles full, partial, and reordered target views alike.
    """
    pop_size = _flat_size(pop)
    values = u.get_mantissa(cur)
    frame = jnp.zeros(pop_size, dtype=values.dtype).at[idx].add(values)
    # Re-attach the current's unit; maybe_decimal unwraps a unitless result.
    return u.maybe_decimal(frame * u.get_unit(cur))
@staticmethod
def _scatter_current_compartment(cur, pop, idx, comp, ncomp):
    """Expand a device's ``(n,)`` current into column ``comp`` of an ``(n_pop, ncomp)`` frame.

    This is the multi-compartment counterpart of the plain population scatter.
    The per-target current is added at rows ``idx`` of column ``comp`` in an
    otherwise-zero ``(n_pop, ncomp)`` array. The neuron's
    ``sum_current_inputs`` reads a ``(*varshape, ncomp)`` array, so it sees this
    current only in the targeted compartment and never broadcast across all
    compartments.
    """
    pop_size = _flat_size(pop)
    values = u.get_mantissa(cur)
    frame = jnp.zeros((pop_size, ncomp), dtype=values.dtype)
    frame = frame.at[idx, comp].add(values)
    # Re-attach the current's unit; maybe_decimal unwraps a unitless result.
    return u.maybe_decimal(frame * u.get_unit(cur))
# -- run ---------------------------------------------------------------
def update(self, t=None):
    """Advance every module of the network by one simulation step.

    The phases run in a fixed order, and that order is load-bearing:

    0. volume transmitters in ``self._vt_nodes`` update;
    1. event / plastic projections update, reading the spike holders that
       phase 2 filled on the *previous* step;
    1b. current generators in ``self._current_injectors`` compute this step's
        current (scaled by ``weight`` when given) and deposit it into the post
        population via ``add_current_input``;
    1c. gap-junction couplers in ``self._gap_couplers`` deposit a pA current
        into their population;
    2. every other child that owns a ``_holder_<id>`` is updated, and its output
       is written into that holder (plus its emission holder when the module
       declares ``_emission_attr``).

    Parameters
    ----------
    t : optional
        Not read by this method. :meth:`_run_window` passes the step time here
        and also sets it, with the step index, in ``brainstate.environ.context``.
    """
    # Holder dtype used for every captured spike / emission value below.
    dftype = brainstate.environ.dftype()
    # Child modules, presumably only the simulator's direct children
    # (hierarchy level 1); the list is scanned by phases 1 and 2.
    children = list(self.nodes(allowed_hierarchy=(1, 1)).values())
    # 0) volume transmitters advance the broadcast dopamine concentration n from
    # the previous step's captured dopa spikes (the substrate's one-step lag,
    # matching NEST's +1 delivery stamp), so projections in phase 1 read fresh n.
    for vt in self._vt_nodes:
        vt.update()
    # 1) projections route the previous step's spikes into delta inputs
    for m in children:
        if isinstance(m, (EventProjection, EventPlasticProj)):
            m.update()
    # 1b) current-injecting devices (dc/step/noise/ac) push this step's
    # current into each post population's current-input seam. The neuron
    # consumes it in step 2 via ``sum_current_inputs`` (captured into
    # ``y0`` and applied on the *next* step --- a one-step ring buffer,
    # matching NEST's current buffer). Non-callable inputs are popped on
    # consumption, so the contribution is re-added every step.
    for device, pop, idx, weight, key, comp, ncomp in self._current_injectors:
        cur = device.update()
        if weight is not None:
            cur = cur * weight
        # comp is None for a single-compartment post (flat (n_pop,) frame);
        # otherwise the current goes into one column of an (n_pop, ncomp) frame.
        if comp is None:
            pop.add_current_input(key, self._scatter_current(cur, pop, idx))
        else:
            pop.add_current_input(
                key, self._scatter_current_compartment(cur, pop, idx, comp, ncomp))
    # 1c) gap-junction couplers deposit the explicit-lag difference current
    # I_gap = G @ V[n-1] - D * V[n-1] (pA) into the post's current channel.
    # V[n-1] is the previous step's voltage from the V emission holder (the
    # one-step pipeline lag = NEST's use_wfr=False seed, cluster 15a); the
    # deposit rides the same current ring buffer as the device injectors.
    # For a recurrent gap V_pre and V_post are the same population's V.
    # NOTE(review): each coupler tuple holds a single v_reader and passes the
    # same ``v`` as both pre and post voltage -- confirm non-recurrent gaps are
    # never registered here.
    for G, D, v_reader, pop, key in self._gap_couplers:
        v = jnp.asarray(v_reader())  # (n_pop,) mV, one step lagged
        pop.add_current_input(key, self._gap_current(G, D, v, v) * u.pA)
    # 2) drive neurons/generators and capture their output into holders
    for m in children:
        # Projections already ran in phase 1; holders are storage, not dynamics.
        if isinstance(m, (EventProjection, EventPlasticProj, _SpikeHolder)):
            continue
        holder = getattr(self, f'_holder_{id(m)}', None)
        if holder is None:
            continue  # recorders / untracked devices have no holder
        # NOTE: the inspect.signature probe runs in Python each time update()
        # executes; presumably only at trace time, since update() is invoked
        # inside transform.for_loop -- confirm.
        if (isinstance(m, Neuron) and hasattr(m, 'n_receptors')
                and 'w_by_rec' in inspect.signature(type(m).update).parameters):
            # Multi-receptor neuron: gather the per-port delta input and drive
            # the model's JIT-safe ``w_by_rec`` path (its no-arg seam is numpy).
            # ``receptor_input_unit`` scales the gathered mantissa: pA for
            # current-based (iaf), nS for conductance-based (aeif/gif) models.
            runit = getattr(m, 'receptor_input_unit', u.pA)
            init = u.math.zeros(m.varshape + (int(m.n_receptors),)) * runit
            out = m.update(w_by_rec=u.get_mantissa(m.sum_delta_inputs(init) / runit))
        else:
            out = m.update()
        if isinstance(m, Neuron) and not getattr(m, '_relays_multiplicity', False):
            # Binarise the neuron's output: values >= 0.5 become 1, others 0.
            val = (jnp.asarray(u.get_mantissa(out)) >= 0.5).astype(dftype)
        else:
            # Generators and multiplicity-relaying neurons (parrot_neuron)
            # keep their raw per-step count instead of a binarised spike.
            val = jnp.asarray(u.get_mantissa(out), dtype=dftype)
        holder.spk.value = val
        # Presynaptic STP: capture this step's released efficacy (graded, 0 off
        # spike) into the emission holder, time-aligned with the binary spike so a
        # TSODYKS connection delivers it with the same latency as a plain spike.
        emit_attr = getattr(m, '_emission_attr', None)
        if emit_attr is not None:
            emit = getattr(self, f'_emit_holder_{id(m)}', None)
            if emit is not None:
                emit.spk.value = jnp.asarray(
                    u.get_mantissa(getattr(m, emit_attr).value), dtype=dftype)
[docs]
def simulate(self, duration, *, dt=None) -> SimulationResult:
    """Run a single window of length ``duration``, starting from freshly initialised state.

    Spike recorders come back as ``(n_steps, n_recorded)`` stacked arrays.
    Analog recorders (``voltmeter`` / ``multimeter``) sample their target
    population's State after each step's update, producing
    ``(n_steps, n_recorded)`` traces keyed by recordable. The run's time axis
    is available as ``res.times``.

    ALL state is re-initialised here (``init_all_states``) and the window
    starts at ``t = 0``. To keep going across windows *without*
    re-initialising -- e.g. to interleave host-side work such as reading
    recordings, rewriting ``host_drive`` schedules, or overwriting static
    weights between chunks -- use :meth:`cont` instead.
    """
    step = dt if dt is not None else self._dt
    # Fresh rollout: re-init all state and zero the rollout clock.
    self.reset_rollout(dt=step)
    return self._run_window(duration, step)
[docs]
def reset_rollout(self, *, dt=None):
    """Re-initialise every State and restart the persistent rollout clock at ``t = 0``.

    Runs ``brainstate.nn.init_all_states(self)`` and then zeroes the
    accumulated rollout clock (``_base_t`` in the unit of ``dt``, and the step
    counter ``_base_i``). :meth:`simulate` does this automatically; call it
    yourself before a :meth:`cont` loop to (re)start from a clean slate.
    """
    step = dt if dt is not None else self._dt
    brainstate.nn.init_all_states(self)
    # Absolute time offset (unit taken from dt) and step-index offset.
    self._base_t, self._base_i = 0.0 * u.get_unit(step), 0
    # Tells cont() that no lazy initialisation is needed.
    self._rollout_ready = True
[docs]
def cont(self, duration, *, dt=None) -> SimulationResult:
    """Extend the current rollout by ``duration`` while keeping all state.

    In contrast to :meth:`simulate`, nothing is re-initialised: state persists
    between calls and biological time keeps accumulating. A host loop can
    therefore run Python code between chunks -- inspect this window's
    recordings, rewrite a ``host_drive`` schedule, or overwrite static weights
    with ``get_connections(...).set('weight', ...)`` -- while the compiled
    per-chunk ``for_loop`` is reused (no recompilation as long as only State
    *contents* change). The first call, or the first call after
    :meth:`reset_rollout`, initialises lazily. Each call returns a
    :class:`SimulationResult` for its own window, with absolute ``times``
    (offset by the accumulated rollout clock).
    """
    step = dt if dt is not None else self._dt
    already_initialised = getattr(self, '_rollout_ready', False)
    if not already_initialised:
        self.reset_rollout(dt=step)
    return self._run_window(duration, step)
def _run_window(self, duration, dt) -> SimulationResult:
    """Run one ``for_loop`` window of ``duration`` over the rollout clock and
    stack the spike / analog / weight taps, then advance the clock.

    Shared by :meth:`simulate` (after a re-init, ``_base_t == 0``) and
    :meth:`cont` (no re-init): ``times`` / ``indices`` are offset by the
    accumulated ``_base_t`` / ``_base_i`` so the device counters, time axis, and
    ``environ`` step index continue across windows.

    The window's step count is derived explicitly from ``duration / dt``.
    A ratio within floating-point round-off of an integer yields exactly that
    many steps. A genuine remainder rounds up, as ``arange(0, duration, dt)``
    does. Deriving the count from ``arange`` directly is fragile: for example
    ``0.9 / 0.3 == 3.0000000000000004``, so ``arange(0, 0.9, 0.3)`` has a
    spurious fourth step.

    Parameters
    ----------
    duration : Quantity or float
        Window length, in a unit compatible with ``dt``.
    dt : Quantity or float
        Integration step.

    Returns
    -------
    SimulationResult
        Stacked spike recordings, analog traces and weight trajectories for
        this window, with absolute ``times``.
    """
    import math

    import brainstate.transform as transform

    # Robust step count: snap float round-off to the nearest integer, and
    # otherwise round a real partial step up (the previous arange semantics).
    ratio = float(u.maybe_decimal(duration / dt))
    nearest = round(ratio)
    if math.isclose(ratio, nearest, rel_tol=1e-6, abs_tol=1e-9):
        n_steps = int(nearest)
    else:
        n_steps = math.ceil(ratio)
    n_steps = max(n_steps, 0)  # a non-positive duration runs no steps
    local = u.math.arange(n_steps) * dt
    times = self._base_t + local
    indices = self._base_i + u.math.arange(local.size)
    taps = dict(self._taps)
    analog = dict(self._analog_taps)
    weight_taps = dict(self._weight_taps)

    def step(t, i):
        with brainstate.environ.context(t=t, i=i):
            self.update(t)
            spk_out = {rid: getattr(self, f'_holder_{sid}').spk.value[idx]
                       for rid, (sid, idx) in taps.items()}
            ana_out = {}
            for rid, (sid, idx, names) in analog.items():
                pop = getattr(self, f'_node_{sid}')
                for name in names:
                    ana_out[SimulationResult._trace_key(rid, name)] = _read_recordable(pop, name)[idx]
            # weight taps read the projection's (post-update) weight State
            w_out = {rid: proj.weight.value for rid, proj in weight_taps.items()}
        return spk_out, ana_out, w_out

    stacked_spk, stacked_ana, stacked_w = transform.for_loop(step, times, indices)
    recordings = {rid: jnp.asarray(stacked_spk[rid]) for rid in taps}
    traces = dict(stacked_ana)
    weights = {}
    for rid, proj in weight_taps.items():
        arr = jnp.asarray(stacked_w[rid])  # (T, E) mantissa
        unit = getattr(proj, '_w_unit', u.UNITLESS)
        weights[rid] = arr if unit is u.UNITLESS else u.maybe_decimal(arr * unit)
    # Advance the persistent rollout clock so the next cont() window continues.
    self._base_t = self._base_t + local.size * dt
    self._base_i = self._base_i + int(local.size)
    return SimulationResult(recordings, duration, dt, traces=traces, times=times,
                            weights=weights)