# Copyright 2026 BrainX Ecosystem Limited. Apache 2.0.
"""EventProjection — single-population delta + delay event projection.
Routes delayed, weighted (pA) pre-synaptic spike events into
``post.add_delta_input`` (or ``post.add_current_input`` when
``as_current=True``), matching how NEST current-based neurons (e.g.
``iaf_psc_alpha``) ingest spikes: the weight is a current amplitude in pA,
sign-split into excitatory/inhibitory channels inside the neuron.
"""
from __future__ import annotations
import inspect
import itertools
from typing import Callable, Optional
import brainstate
import jax
import jax.numpy as jnp
import brainunit as u
from brainpy_state._brainpy.delay import InputDelay
from brainpy_state._nest_network.connectivity import resolve_param
from brainpy_state._nest_network.nodeview import _flat_size
from brainpy_state._nest_network.projections import _DenseMatMul, _ReceptorScatter, _SparseEventMatMul
from brainpy_state._nest_network.rules import ConnRule, _OneToOne
__all__ = ["EventProjection"]

# A single process-wide counter gives every projection its own delta-input key:
# brainstate does not auto-assign a usable ``self.name`` here, and several
# projections may deposit onto the same post population.
_DELTA_KEY_COUNTER = itertools.count(start=0)
class EventProjection(brainstate.nn.Module):
    """Delayed, weighted delta-event projection from one population segment.

    Each step it reads the pre population's captured spike via ``pre_spike()``
    (a callable returning the full pre-population spike/counts vector), applies
    an :class:`~brainpy_state._brainpy.delay.InputDelay` (full-delay
    convention, as in ``AlignPostProj``), restricts to this projection's pre
    segment, maps it to the post segment (dense or sparse weighted matmul,
    per-receptor scatter, or element-wise for ``one_to_one``), and registers
    the result as a delta input (or, with ``as_current``, a current input) on
    ``post``.

    Parameters
    ----------
    pre_spike : Callable[[], jax.Array]
        Returns the full pre-population spike (or generator counts) vector,
        shape ``(n_pre_pop,)``.
    n_pre_pop : int
        Size of the full pre population (the dimension ``pre_spike`` returns).
    pre_local_idx : jax.Array
        Local indices into the pre population selected by this projection.
    post : Dynamics
        Post-synaptic population; receives ``add_delta_input`` (or
        ``add_current_input`` when ``as_current=True``).
    post_local_idx : jax.Array
        Local indices into the post population targeted by this projection.
    rule : ConnRule
        Connection rule. ``one_to_one`` triggers the element-wise path.
    weight : ArrayLike or Quantity or Callable
        Synaptic weight in pA (signed: positive excitatory, negative
        inhibitory). Resolved per edge via ``resolve_param`` on the sampled
        paths; on the plain ``one_to_one`` path it is multiplied element-wise
        as given.
    delay : ArrayLike or Quantity or None, optional
        Axonal delay; ``None`` for instantaneous delivery.
    comm : {'dense', 'sparse'}, default 'dense'
        Communication backend for the plain (non-receptor, non-``one_to_one``)
        path: a dense ``(n_pre, n_post)`` weight matrix, or a CSR event matmul
        that is lighter on memory for large fan-out. Only validated on that
        path.
    receptor_type : int or str or None, optional
        NEST receptor type. For a stacked multi-receptor post (one exposing
        ``n_receptors``) an int ``k`` (1-based) routes every edge to internal
        port ``k-1`` and ``'uniform'`` draws a port per edge. For a
        named-channel post (one exposing ``delta_label_for_receptor``) it is
        resolved once to a delta-input label.
    channel_label : str or None, optional
        Explicit input-channel label; bypasses the receptor lookup.
    as_current : bool, default False
        Deposit mode. When ``False`` (default) the contribution is registered
        as a delta input (``post.add_delta_input``), matching how
        current-based neurons ingest spikes. When ``True`` it is registered as
        a **current** input (``post.add_current_input``) instead — the route
        the astrocyte slow-inward current (SIC) takes, since SIC is a pA
        current entering ``dV/dt`` rather than a delta/conductance. Requires
        ``comm='dense'`` (a graded current cannot ride the binarising sparse
        event matmul). Only honored on the non-receptor path: per-receptor
        routing always deposits delta inputs.
    pre_is_post : bool, default False
        Forwarded to ``rule.sample`` (not used on the ``one_to_one`` paths).
    allow_autapses : bool, default True
        Forwarded to ``rule.sample`` (not used on the ``one_to_one`` paths).
    allow_multapses : bool, default True
        Forwarded to ``rule.sample`` (not used on the ``one_to_one`` paths).
    seed : int or None, optional
        Seed of the ``jax.random`` key split into the connectivity, weight and
        receptor-draw keys; ``None`` uses seed ``0``.

    Raises
    ------
    ValueError
        If ``as_current`` is combined with ``comm='sparse'``; if ``comm`` is
        not ``'dense'``/``'sparse'`` on the plain sampled path; if a string
        ``receptor_type`` is not ``'uniform'`` or an int one is out of range;
        if a parrot (``_relays_multiplicity``) post gets a non-unit weight; or
        as propagated from ``post.delta_label_for_receptor``.
    """

    __module__ = 'brainpy.state'

    def __init__(
        self,
        *,
        pre_spike: Callable[[], jnp.ndarray],
        n_pre_pop: int,
        pre_local_idx: jnp.ndarray,
        post,
        post_local_idx: jnp.ndarray,
        rule: ConnRule,
        weight,
        delay=None,
        comm: str = 'dense',
        receptor_type=None,
        channel_label=None,
        as_current: bool = False,
        pre_is_post: bool = False,
        allow_autapses: bool = True,
        allow_multapses: bool = True,
        seed: Optional[int] = None,
    ):
        super().__init__()
        # Graded current deposit (SIC) must ride the dense matmul; the sparse event
        # matmul binarises the presynaptic value, so reject the combination eagerly.
        # NOTE(review): ``as_current`` is also not applied on the per-receptor path
        # (``update`` always uses ``add_delta_input`` there); consider rejecting
        # that combination eagerly too.
        if as_current and comm == 'sparse':
            raise ValueError(
                "as_current graded-current deposit requires comm='dense'; "
                "comm='sparse' binarises the presynaptic value.")
        self._as_current = bool(as_current)
        self._delta_key = f'event_proj_{next(_DELTA_KEY_COUNTER)}'
        self.pre_spike = pre_spike
        self.post = post
        self.pre_local_idx = jnp.asarray(pre_local_idx)
        self.post_local_idx = jnp.asarray(post_local_idx)
        self._n_pre_pop = int(n_pre_pop)
        self._n_post_pop = _flat_size(post)

        # A multiplicity-relaying post (parrot_neuron) reads its summed delta input
        # AS the spike count, so an incoming connection MUST carry the unit gate
        # weight 1.0 (unitless). NEST ignores weights into a parrot; a non-unit
        # weight here would silently scale the relayed multiplicity, so reject it.
        if getattr(post, '_relays_multiplicity', False):
            self._check_unit_gate_weight(weight)

        self._one_to_one = isinstance(rule, _OneToOne)

        # Labeled single-channel routing (named-channel post, e.g.
        # ``iaf_cond_alpha_mc``): each ``connect(device, post, receptor_type=k)`` feeds
        # exactly ONE named delta channel (``'w_ex_s'`` ...), which the model reads via
        # ``sum_delta_inputs(label=...)``. Such a post exposes
        # ``delta_label_for_receptor`` and has no ``n_receptors``. Resolve the label
        # once and fall through to the ordinary plain comm path (a single
        # ``(n_post,)`` contribution), tagging the deposit with that label.
        # ``channel_label`` (an explicit string) bypasses the receptor lookup: the
        # rate dual-channel split (``mult_coupling``) deposits the sign-split
        # ``weight·rate`` into the post's ``'rate_ex'``/``'rate_in'`` delta channels
        # directly, since those are not NEST receptors.
        self._channel_label = (channel_label if channel_label is not None
                               else self._resolve_channel_label(post, receptor_type))

        # Per-receptor routing (stacked multi-receptor post, e.g.
        # ``iaf_psc_exp_multisynapse``): ``receptor_type='uniform'`` draws a port per
        # edge, an int ``k`` (1-based, NEST convention) routes every edge to internal
        # port ``k-1``. A resolved named channel uses the plain path instead.
        self._receptor = receptor_type is not None and self._channel_label is None
        self._n_receptors = int(post.n_receptors) if self._receptor else 0

        # Deposit mode. Models exposing ``w_by_rec`` in their update signature
        # (``iaf``/``aeif``/``gif_cond_exp_multisynapse``) are Simulator-bridged
        # from one ``(N, n_receptors)`` blob; all others (the GLIF models)
        # self-pull each port via ``sum_delta_inputs(label='receptor_k')``.
        self._receptor_keyed = self._receptor and (
            'w_by_rec' not in inspect.signature(type(post).update).parameters
        )

        n_pre = int(self.pre_local_idx.shape[0])
        n_post = int(self.post_local_idx.shape[0])
        key = jax.random.key(0 if seed is None else int(seed))
        k_conn, k_w, k_rec = jax.random.split(key, 3)

        if self._receptor:
            if self._one_to_one:
                pre_idx, post_idx, n_edges = jnp.arange(n_pre), jnp.arange(n_post), n_pre
            else:
                spec = rule.sample(n_pre, n_post, key=k_conn, pre_is_post=pre_is_post,
                                   allow_autapses=allow_autapses, allow_multapses=allow_multapses)
                pre_idx, post_idx, n_edges = spec.pre_idx, spec.post_idx, spec.n_edges
            # ``'uniform'`` draws a port per edge; an int ``k`` (1-based, NEST
            # ``receptor_type`` convention) routes every edge to internal port ``k-1``.
            if isinstance(receptor_type, str):
                if receptor_type != 'uniform':
                    raise ValueError(f"receptor_type string must be 'uniform', got {receptor_type!r}")
                rec_idx = jax.random.randint(k_rec, (n_edges,), 0, self._n_receptors)
            else:
                k = int(receptor_type)
                if not (1 <= k <= self._n_receptors):
                    raise ValueError(f"receptor_type {k} out of range [1, {self._n_receptors}]")
                rec_idx = jnp.full((n_edges,), k - 1, dtype=jnp.int32)
            w_mant, w_unit = self._edge_weight(weight, n_edges, k_w)
            self.comm = _ReceptorScatter(pre_idx, post_idx, rec_idx, w_mant, w_unit,
                                         n_post=n_post, n_receptors=self._n_receptors)
        elif self._one_to_one:
            # Element-wise: a scalar pA weight applied per matched element.
            self._weight = weight
            self.comm = None
        else:
            if comm not in ('dense', 'sparse'):
                raise ValueError(f"comm must be 'dense' or 'sparse', got {comm!r}")
            spec = rule.sample(n_pre, n_post, key=k_conn, pre_is_post=pre_is_post,
                               allow_autapses=allow_autapses, allow_multapses=allow_multapses)
            if comm == 'dense':
                if spec.n_edges == 0:
                    # No edges: an all-zero matrix. Still carry a concrete Quantity
                    # weight's unit so this (zero) deposit stays dimensionally
                    # consistent with the populated case and sibling projections.
                    W = jnp.zeros((n_pre, n_post))
                    w_unit = (u.split_mantissa_unit(weight)[1]
                              if isinstance(weight, u.Quantity) else u.UNITLESS)
                else:
                    w_mant, w_unit = self._edge_weight(weight, spec.n_edges, k_w)
                    W = (jnp.zeros((n_pre, n_post), dtype=w_mant.dtype)
                         .at[spec.pre_idx, spec.post_idx].add(w_mant))
                W_with_unit = u.Quantity(W, unit=w_unit) if w_unit is not u.UNITLESS else W
                self._W = brainstate.ParamState(W_with_unit)
                self.comm = _DenseMatMul(self._W)
            else:  # sparse CSR event matmul — memory-light for large fan-out
                w_mant, w_unit = self._edge_weight(weight, spec.n_edges, k_w)
                self.comm = _SparseEventMatMul(spec.pre_idx, spec.post_idx, w_mant, w_unit,
                                               n_pre=n_pre, n_post=n_post)

        # Delay buffers the FULL pre-population vector (axonal granularity).
        self.delay_seam = InputDelay((self._n_pre_pop,), delay) if delay is not None else None
        # Retained (additive) so ``realized_edges`` can report the homogeneous
        # axonal delay and ``SynapseCollection.set('delay')`` can rebuild the seam.
        self._delay = delay
        # Precompute (Python bool) whether this projection targets the whole
        # post population, so update() can skip the scatter on the fast path.
        # The ``and`` short-circuit guarantees the element-wise compare only runs
        # when both index vectors have the same length.
        self._post_is_full = (
            n_post == self._n_post_pop
            and bool(jnp.all(self.post_local_idx == jnp.arange(self._n_post_pop)))
        )

    @staticmethod
    def _check_unit_gate_weight(weight):
        """Enforce the parrot-relay contract: an incoming weight must be the unit gate.

        A ``_relays_multiplicity`` post (``parrot_neuron``) re-emits the *summed*
        delta input as the spike count, which only equals the true multiplicity when
        every incoming edge carries weight ``1.0`` (unitless). NEST ignores weights
        on connections into a parrot; a non-unit weight here would silently scale the
        relayed count, so reject it eagerly with a clear message.

        Parameters
        ----------
        weight : ArrayLike or Quantity or Callable or None
            The connection weight passed to the projection. ``None`` and callable
            initializers are not validated (no concrete value to check).

        Raises
        ------
        ValueError
            If ``weight`` carries a physical unit, or is a concrete value other
            than ``1.0``.
        """
        if weight is None or callable(weight):
            return  # no concrete scalar/array to validate
        w = weight
        if isinstance(w, u.Quantity):
            mant, unit = u.split_mantissa_unit(w)
            if unit is not u.UNITLESS:
                raise ValueError(
                    "connections into a parrot_neuron must use the unit gate weight "
                    "1.0 (unitless): the relay reads the summed input as the spike "
                    "count and NEST ignores weights into a parrot, but got a weight "
                    f"with physical units ({weight!r}).")
            w = mant
        if not bool(jnp.all(jnp.asarray(w) == 1.0)):
            raise ValueError(
                "connections into a parrot_neuron must use the unit gate weight 1.0: "
                "the relay reads the summed input as the spike count and NEST ignores "
                f"weights into a parrot, but got weight={weight!r}.")

    @staticmethod
    def _resolve_channel_label(post, receptor_type):
        """Resolve a named-channel post's ``receptor_type`` to its delta-input label.

        A multi-compartment / named-channel post (``iaf_cond_alpha_mc``) routes each
        connection to ONE named delta channel rather than a stacked ``n_receptors``
        port. It exposes ``delta_label_for_receptor(rt) -> label`` and has no
        ``n_receptors``.

        Parameters
        ----------
        post : Dynamics
            The postsynaptic population.
        receptor_type : int or str or None
            The connection's NEST receptor type.

        Returns
        -------
        str or None
            The target delta-input channel label, or ``None`` when there is no
            receptor routing, the post is a stacked-``n_receptors`` model (the
            ``_ReceptorScatter`` path handles those), or the post exposes no
            ``delta_label_for_receptor``.

        Raises
        ------
        ValueError
            Propagated from ``post.delta_label_for_receptor`` for a receptor type the
            named-channel post does not accept (e.g. a current receptor on a spike
            projection, or an out-of-range type).
        """
        if receptor_type is None or hasattr(post, 'n_receptors'):
            return None
        resolver = getattr(post, 'delta_label_for_receptor', None)
        if resolver is None:
            return None
        return resolver(receptor_type)

    @staticmethod
    def _edge_weight(weight, n_edges, key):
        """Resolve ``weight`` to a per-edge ``(mantissa, unit)`` pair.

        ``weight`` is expanded to shape ``(n_edges,)`` by ``resolve_param`` (using
        ``key`` for any random initializer). A Quantity result is split into
        mantissa and unit; a plain array is returned with ``u.UNITLESS``.
        """
        w_edge = resolve_param(weight, (n_edges,), key)
        if isinstance(w_edge, u.Quantity):
            return u.split_mantissa_unit(w_edge)
        return jnp.asarray(w_edge), u.UNITLESS

    def update(self):
        """Advance one step: read, delay, route and deposit this step's events.

        Reads the full pre vector from ``pre_spike()``, passes it through the
        delay seam (if any), gathers this projection's pre segment, maps it
        through the connectivity, and deposits the result on ``post`` under this
        projection's unique key (scattering into the full post population first
        unless the projection already targets all of it, in order).
        """
        x_full = self.pre_spike()  # (n_pre_pop,)
        if self.delay_seam is not None:
            x_full = self.delay_seam.update(x_full)
        x_seg = jnp.asarray(x_full)[self.pre_local_idx]  # (n_pre,)

        if self._receptor:
            y = self.comm(x_seg)  # (n_post, n_receptors)
            contrib = y if self._post_is_full else self._scatter_receptor(y)
            if self._receptor_keyed:
                # GLIF-style self-pull: one labelled deposit per port. The label
                # composes to key ``'receptor_k // <delta_key>'``, which the post's
                # ``sum_delta_inputs(label='receptor_k')`` selects.
                for k in range(self._n_receptors):
                    self.post.add_delta_input(self._delta_key, contrib[..., k], label=f'receptor_{k}')
            else:
                # Blob: one (n_post, n_receptors) deposit, assembled by the bridge.
                self.post.add_delta_input(self._delta_key, contrib)
            return

        if self._one_to_one:
            y = x_seg * self._weight  # (n_post,) pA
        else:
            y = self.comm(x_seg)  # (n_post,) pA
        contrib = y if self._post_is_full else self._scatter(y)
        # ``as_current`` routes the graded contribution into the post's *current*
        # input channel (SIC, a pA current entering dV/dt); otherwise the default
        # delta channel (current-based spike ingestion).
        deposit = self.post.add_current_input if self._as_current else self.post.add_delta_input
        if self._channel_label is not None:
            # Named/labelled channel: key composes to ``'<label> // <delta_key>'``,
            # which the model selects with ``sum_{delta,current}_inputs(label=<label>)``.
            deposit(self._delta_key, contrib, label=self._channel_label)
        else:
            deposit(self._delta_key, contrib)

    def _scatter(self, y):
        """Place per-segment contributions into a full ``(n_post_pop,)`` vector."""
        return self._scatter_rows(y, (self._n_post_pop,))

    def _scatter_receptor(self, y):
        """Place per-segment ``(n_post, n_receptors)`` contributions into the full population."""
        return self._scatter_rows(y, (self._n_post_pop, self._n_receptors))

    def _scatter_rows(self, y, shape):
        """Scatter-add ``y`` along axis 0 at ``post_local_idx`` into zeros of ``shape``.

        A Quantity ``y`` is scattered on its mantissa and re-wrapped with its unit;
        repeated indices accumulate (``.at[...].add``).
        """
        if isinstance(y, u.Quantity):
            base = jnp.zeros(shape, dtype=y.mantissa.dtype)
            return u.Quantity(base.at[self.post_local_idx].add(y.mantissa), unit=y.unit)
        base = jnp.zeros(shape, dtype=y.dtype)
        return base.at[self.post_local_idx].add(y)

    def realized_edges(self):
        """Enumerate this projection's realized edges (NEST ``GetConnections`` view).

        Returns
        -------
        ProjEdges
            Population-local ``source`` / ``target`` plus live ``weight`` / ``delay``
            in the canonical edge order, with guarded write-back hooks. See
            :func:`~brainpy_state._nest_network.connection_introspection.event_proj_edges`.
        """
        from brainpy_state._nest_network.connection_introspection import event_proj_edges
        return event_proj_edges(self)