# Copyright 2026 BrainX Ecosystem Limited. Apache 2.0.
"""EventPlasticProj — JAX-native, event-driven plastic projection substrate.
This is the first of the three typed plasticity primitives (the module also defines primitive #2, :class:`VoltageCoupledPlasticProj`). It owns the *compute*: a CSR edge layout (reusing the
:class:`~brainpy_state._nest_network.projections._SparseEventMatMul` convention),
an :class:`~brainpy_state._brainpy.delay.InputDelay` axonal delay seam,
rule-declared per-edge / per-neuron :class:`brainstate.State` allocation, and
the ``brainevent.CSR`` event matmul that delivers weighted spikes into
``post.add_delta_input``.
The *fidelity* lives in the synapse spec (``_nest/<model>_synapse.py``): a
frozen :class:`PlasticSynapse` exposes the NEST parameter spec plus a pure,
vectorized rule kernel ``update(state, ctx) -> (state, w_eff)``. The substrate
hands the kernel a :class:`KernelContext` each step and gates delivery by the
actual presynaptic spikes.
The whole hot path is ``jit`` / ``vmap`` / ``for_loop`` safe: every State has a
static shape fixed at trace time, there is no Python-list growth, no
data-dependent host control flow, and all elementwise math is ``jnp``.
"""
from __future__ import annotations
from typing import Callable, NamedTuple, Optional, Protocol
import brainevent
import brainstate
import jax
import jax.numpy as jnp
import numpy as np
import brainunit as u
from brainpy_state._brainpy.delay import InputDelay
from brainpy_state._nest_network.event_proj import EventProjection
# Public API of this module; ``_StaticTestRule`` is deliberately left out
# (it is a test-only helper).
__all__ = [
    'EventPlasticProj',
    'VoltageCoupledPlasticProj',
    'KernelContext',
    'PlasticSynapse',
]
def _trace_spec(attr, mode='all_to_all'):
"""Normalize a rule's per-side trace-tau declaration.
Returns ``None`` (no trace), ``('single', tau_ms, mode)`` for a scalar
``Quantity`` (cluster-01 contract: 1-D per-neuron State), or
``('multi', (tau0_ms, ...), mode)`` for a tuple/list of taus (``stdp_triplet``,
later clopath: one per-neuron column per tau). ``mode`` is the per-side
pairing mode ``'all_to_all'`` (decay-then-add, the default) or ``'nearest'``
(reset-to-1 on spike; cluster-05 nearest-neighbour STDP).
"""
if attr is None:
return None
if isinstance(attr, (tuple, list)):
return ('multi', tuple(float(u.Quantity(t).to_decimal(u.ms)) for t in attr), mode)
return ('single', float(u.Quantity(attr).to_decimal(u.ms)), mode)
class KernelContext(NamedTuple):
    """Per-step inputs the substrate hands a rule kernel.

    All per-edge arrays are length ``E = n_edges`` in CSR (sorted-by-pre) edge
    order; scalars are 0-d. Every value is a unit-free mantissa — the substrate
    re-attaches the pA unit on delivery.

    The last four fields default to ``None`` so a context can be built from
    the seven positional core fields alone. The substrate always fills
    ``pre_traces`` / ``post_traces``. ``post_states`` / ``signals`` stay
    ``None`` when the projection reads none.

    Attributes
    ----------
    pre_spike : jax.Array
        ``(E,)`` — did this edge's (delayed) presynaptic neuron fire this step.
    post_spike : jax.Array
        ``(E,)`` — did this edge's postsynaptic neuron fire this step (zeros if
        the projection has no post-spike reader).
    pre_trace : jax.Array
        ``(E,)`` presynaptic trace gathered per edge (zeros if the rule declares
        no ``pre_trace_tau``). When the rule declares a *tuple* of taus this is
        the first column of ``pre_traces``.
    post_trace : jax.Array
        ``(E,)`` postsynaptic trace gathered per edge (zeros if none); first
        column of ``post_traces`` for multi-trace rules.
    t_now : jax.Array
        Scalar current simulation time, ms mantissa.
    dt : jax.Array
        Scalar timestep, ms mantissa.
    key : jax.Array
        Per-step PRNG key (used by stochastic rules; deterministic rules ignore).
    pre_traces : jax.Array or None
        ``(E, k)`` all presynaptic traces gathered per edge, one column per tau
        when ``pre_trace_tau`` is a tuple (``stdp_triplet`` / clopath seam).
        ``(E, 1)`` for a scalar tau, ``(E, 0)`` for none. ``pre_trace`` aliases
        column 0.
    post_traces : jax.Array or None
        ``(E, k)`` all postsynaptic traces gathered per edge (see ``pre_traces``).
    post_states : dict or None
        ``{name: (E,)}`` per-edge gathers of the post-neuron analog State variables
        named by the rule's ``post_state_reads`` (the primitive-#2
        :class:`VoltageCoupledPlasticProj` reader; e.g. ``'V'`` / ``'u_bar_plus'`` /
        ``'u_bar_minus'`` for ``clopath_synapse``). Each value is a unit-stripped
        mantissa in the State's stored unit, in CSR (sorted-by-pre) edge order.
        ``None`` for the base :class:`EventPlasticProj` (primitive #1 reads no post
        State).
    signals : dict or None
        ``{name: scalar}`` broadcast modulatory signals read from a bound third-party
        node (the :class:`~brainpy_state._nest_device.volume_transmitter`) and shared by
        **every edge** of the projection (cluster-08; e.g. ``'n'``, the dopamine
        concentration, for ``stdp_dopamine_synapse``). Each value is a unit-stripped
        scalar mantissa broadcast against the ``(E,)`` per-edge arrays — a superset
        of :attr:`post_states` (1->E broadcast vs N->E gather). ``None`` for any
        projection that reads no broadcast signal (primitives #1 and post-only #2).
    """

    pre_spike: jax.Array
    post_spike: jax.Array
    pre_trace: jax.Array
    post_trace: jax.Array
    t_now: jax.Array
    dt: jax.Array
    key: jax.Array
    # Optional tail: defaults of None require Optional[...] annotations.
    pre_traces: Optional[jax.Array] = None
    post_traces: Optional[jax.Array] = None
    post_states: Optional[dict] = None
    signals: Optional[dict] = None
class PlasticSynapse(Protocol):
    """Structural (duck-typed) interface satisfied by every rebuilt ``_nest`` synapse spec.

    Concrete specs live in ``_nest/<model>_synapse.py``. The substrate depends
    only on the attributes declared below plus :meth:`edge_state_init` and
    :meth:`update`.

    Optional attributes, looked up with a default by the substrate:

    * ``pre_trace_mode`` / ``post_trace_mode`` (default ``'all_to_all'``):
      ``'nearest'`` resets the trace to 1 on each spike instead of
      accumulating (cluster-05 nearest-neighbour STDP).
    * ``fractional_delay`` (default ``False``): when ``True`` the substrate
      honours a sub-dt delay by delivering at the integer floor delay and
      splitting the post amplitude across the two bracketing grid steps
      (cluster-06 ``cont_delay_synapse``).
    """

    # Per-edge weight initializer (pA); a scalar when homogeneous.
    weight: object
    # Homogeneous axonal delay (Quantity), or None for no delay.
    delay: object
    # True -> the 'weight' State is a single shared 0-d scalar.
    is_homogeneous_weight: bool
    # True -> the kernel consumes ctx.key.
    stochastic: bool
    # None | Quantity | tuple[Quantity, ...] (multi-trace), per side.
    pre_trace_tau: object
    post_trace_tau: object
    # Unit re-attached to the delivered weight (pA).
    weight_unit: object

    def edge_state_init(self) -> dict:
        """Return ``{name: initial_scalar}`` for the rule's per-edge auxiliary States."""
        ...

    def update(self, state: dict, ctx: KernelContext) -> tuple[dict, jax.Array]:
        """Pure, vectorized rule kernel mapping ``(state, ctx)`` to ``(new_state, w_eff)``."""
        ...
class _StaticTestRule:
    """Constant-weight rule intended for tests only.

    ``update`` hands back the stored ``'weight'`` unchanged as the effective
    weight; the rule declares no auxiliary per-edge State and no traces. It
    lives here so the substrate is testable before the real ``_nest`` specs are
    rebuilt; production code uses the specs in ``_nest/``.
    """

    # Static rule declarations read by the projection substrate.
    is_homogeneous_weight = False
    stochastic = False
    pre_trace_tau = None
    post_trace_tau = None
    weight_unit = u.pA

    def __init__(self, weight=1.0, delay=None):
        self.weight, self.delay = weight, delay

    def edge_state_init(self) -> dict:
        """No auxiliary per-edge State."""
        return dict()

    def update(self, state, ctx):
        """Leave ``state`` untouched and deliver its weight as ``w_eff``."""
        w_eff = state['weight']
        return state, w_eff
class EventPlasticProj(brainstate.nn.Module):
    """Event-driven plastic projection from one population segment.

    Each step it reads the pre population's captured spike vector via
    ``pre_spike()``, applies the axonal :class:`InputDelay`, restricts to this
    projection's pre/post segments, maintains any rule-declared per-neuron
    traces, calls the synapse rule kernel to obtain the per-edge effective
    weight, and delivers it through a ``brainevent.CSR`` event matmul into
    ``post.add_delta_input`` (summing multapses, scattering into sub-segments).

    Parameters
    ----------
    pre_spike : Callable[[], jax.Array]
        Returns the full pre-population spike vector, shape ``(n_pre_pop,)``.
    n_pre_pop : int
        Size of the full pre population.
    pre_local_idx : array_like
        Local indices into the pre population selected by this projection.
    post : Dynamics or None
        Post-synaptic population (receives ``add_delta_input``). ``None`` is
        permitted for substrate-only tests that do not deliver.
    post_local_idx : array_like
        Local indices into the post population targeted by this projection.
    rule : PlasticSynapse
        The synapse spec carrying the pure ``update`` kernel and the parameter
        declarations (weight, delay, traces, stochastic, homogeneity).
    conn : ConnRule, optional
        Connectivity sampler; used when explicit ``pre_idx``/``post_idx`` edges
        are not supplied.
    pre_idx, post_idx : array_like, optional
        Explicit segment-local edges (skip sampling). Both or neither. Each
        must be 1-D and of equal length, with
        ``0 <= pre_idx < len(pre_local_idx)`` and
        ``0 <= post_idx < len(post_local_idx)``.
    n_post_pop : int, optional
        Size of the full post population (defaults to ``len(post_local_idx)``).
    post_spike : Callable[[], jax.Array], optional
        Returns the full post-population spike vector (STDP seam).
    pre_is_post, allow_autapses, allow_multapses : bool
        Forwarded to the connectivity sampler.
    seed : int, optional
        Connectivity sampling seed (also seeds the per-step rng of stochastic
        rules).
    delta_key : str, optional
        Unique ``add_delta_input`` key (defaults to one derived from ``id``).
    receptor_type : int, optional
        Named-channel routing for a multi-compartment / named-channel post (a
        post exposing ``delta_label_for_receptor`` and no ``n_receptors``, e.g.
        ``pp_cond_exp_mc_urbanczik``). Resolved once to a delta-input channel
        label (mirroring the static :class:`EventProjection` seam); each per-step
        deposit is then tagged with that label so the deposit reaches the correct
        compartment/sign. ``None`` (default) delivers to the unlabeled key.

    Raises
    ------
    ValueError
        If only one of ``pre_idx`` / ``post_idx`` is given; if neither explicit
        edges nor ``conn`` are given; or if the (explicit or sampled) edge
        arrays are not 1-D, differ in length, or fall outside the segment
        bounds.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp, brainstate, brainunit as u
        >>> from brainpy_state._nest_network.event_plastic import EventPlasticProj, _StaticTestRule
        >>> class _Sink:
        ...     def add_delta_input(self, key, val): self.last = val
        >>> sink = _Sink()
        >>> brainstate.environ.set(dt=0.1 * u.ms)
        >>> proj = EventPlasticProj(
        ...     pre_spike=lambda: jnp.array([1., 0.]), n_pre_pop=2,
        ...     pre_local_idx=jnp.arange(2), post=sink,
        ...     post_local_idx=jnp.arange(2), n_post_pop=2,
        ...     pre_idx=jnp.array([0, 1]), post_idx=jnp.array([0, 1]),
        ...     rule=_StaticTestRule(weight=jnp.array([3., 4.]) * u.pA))
        >>> _ = brainstate.nn.init_all_states(proj)
        >>> with brainstate.environ.context(t=0.1 * u.ms, i=1):
        ...     _ = proj.update()
        >>> u.get_mantissa(sink.last).tolist()
        [3.0, 0.0]
    """

    __module__ = 'brainpy.state'

    def __init__(
        self,
        *,
        pre_spike: Callable[[], jax.Array],
        n_pre_pop: int,
        pre_local_idx,
        post,
        post_local_idx,
        rule: PlasticSynapse,
        conn=None,
        pre_idx=None,
        post_idx=None,
        n_post_pop: Optional[int] = None,
        post_spike: Optional[Callable[[], jax.Array]] = None,
        pre_is_post: bool = False,
        allow_autapses: bool = True,
        allow_multapses: bool = True,
        seed: Optional[int] = None,
        delta_key: Optional[str] = None,
        receptor_type=None,
    ):
        super().__init__()
        self.pre_spike = pre_spike
        self.post_spike = post_spike
        self.post = post
        self.rule = rule
        self.pre_local_idx = jnp.asarray(pre_local_idx)
        self.post_local_idx = jnp.asarray(post_local_idx)
        self._n_pre_pop = int(n_pre_pop)
        self._n_post_pop = int(n_post_pop if n_post_pop is not None
                               else self.post_local_idx.shape[0])
        self._delta_key = delta_key or f'event_plastic_{id(self)}'
        # Named-channel single-label routing (mc post, e.g. the Urbanczik
        # neuron): resolve receptor_type -> delta channel label once, then tag
        # every deposit so the plastic weight reaches the right compartment.
        # ``None`` for an unrouted or stacked-n_receptors post (delivers
        # unlabeled, the historical behavior).
        self._channel_label = EventProjection._resolve_channel_label(post, receptor_type)
        # Runtime-rng seed for stochastic rules. ``init_state`` keys the per-step
        # ``rng`` from this so the seed survives ``init_all_states`` (which the
        # Simulator runs inside ``simulate``); ``None`` keeps the historical key(0)
        # default. Shares the integer with connectivity sampling (mirrors NEST's
        # single ``rng_seed``); the two consume independent key instances.
        self._rng_seed = seed
        n_pre = int(self.pre_local_idx.shape[0])
        n_post = int(self.post_local_idx.shape[0])

        # -- edges (segment-local), sorted once by pre into CSR order ---------
        if (pre_idx is None) != (post_idx is None):
            # Exactly one side given: it would otherwise be silently ignored
            # in favour of sampling, contradicting "both or neither".
            raise ValueError(
                'EventPlasticProj: explicit edges need both pre_idx and '
                'post_idx; only one was given.')
        if pre_idx is None:
            if conn is None:
                raise ValueError('EventPlasticProj needs either explicit edges '
                                 '(pre_idx, post_idx) or a connectivity rule (conn).')
            key = jax.random.key(0 if seed is None else int(seed))
            spec = conn.sample(n_pre, n_post, key=key, pre_is_post=pre_is_post,
                               allow_autapses=allow_autapses,
                               allow_multapses=allow_multapses)
            pre_idx, post_idx = spec.pre_idx, spec.post_idx
        pre_np = np.asarray(pre_idx)
        post_np = np.asarray(post_idx)
        self._validate_edges(pre_np, post_np, n_pre, n_post)
        order = np.argsort(pre_np, kind='stable')  # group edges by pre row
        self._pre_idx = jnp.asarray(pre_np[order])
        self._post_idx = jnp.asarray(post_np[order])
        self._indices = jnp.asarray(post_np[order])
        # Row pointers: bincount has exactly n_pre bins because the indices
        # were bounds-checked above, so indptr has length n_pre + 1.
        self._indptr = jnp.asarray(
            np.concatenate([[0], np.cumsum(np.bincount(pre_np, minlength=n_pre))]))
        self._shape = (n_pre, n_post)
        self._E = int(pre_np.shape[0])

        # -- weight init mantissa + unit (kept off the hot path) --------------
        w = rule.weight
        if isinstance(w, u.Quantity):
            self._w_unit = w.unit
            w_m = jnp.asarray(w.mantissa)
        else:
            self._w_unit = getattr(rule, 'weight_unit', u.UNITLESS)
            w_m = jnp.asarray(w)
        self._w_init = (jnp.reshape(w_m, ()) if rule.is_homogeneous_weight
                        else jnp.broadcast_to(w_m, (self._E,)))

        # -- axonal delay seam (identity when delay is None) ------------------
        # A rule may opt into a *sub-dt* delay via ``fractional_delay = True``
        # (cont_delay_synapse): the substrate then delivers the binary event at
        # the integer floor delay (so ``BinaryArray`` is not fed a fractional
        # vector) and splits the post amplitude across the two bracketing grid
        # steps with a 1-step output carry (built in ``init_state``, where ``dt``
        # is known). Default-off: every other rule keeps the unchanged path.
        self._fractional_delay = bool(getattr(rule, 'fractional_delay', False))
        self.delay_seam = (InputDelay((self._n_pre_pop,), rule.delay)
                           if (rule.delay is not None and not self._fractional_delay)
                           else None)

        # -- pre-computed (Python bool) full-post fast path -------------------
        self._post_is_full = (
            n_post == self._n_post_pop
            and bool(jnp.all(self.post_local_idx == jnp.arange(self._n_post_pop)))
        )

        # -- per-side trace specs (None | single | multi) + mode, fixed at trace time
        self._pre_trace_spec = _trace_spec(
            rule.pre_trace_tau, getattr(rule, 'pre_trace_mode', 'all_to_all'))
        self._post_trace_spec = _trace_spec(
            rule.post_trace_tau, getattr(rule, 'post_trace_mode', 'all_to_all'))

    @staticmethod
    def _validate_edges(pre_np, post_np, n_pre, n_post):
        """Reject edge arrays that cannot form a valid ``(n_pre, n_post)`` CSR.

        Raises
        ------
        ValueError
            If the arrays are not 1-D / equal length, or any index lies outside
            ``[0, n_pre)`` (pre) or ``[0, n_post)`` (post).
        """
        if pre_np.ndim != 1 or pre_np.shape != post_np.shape:
            raise ValueError(
                'EventPlasticProj: pre_idx and post_idx must be 1-D arrays of '
                f'equal length; got shapes {pre_np.shape} and {post_np.shape}.')
        if pre_np.size == 0:
            return
        if pre_np.min() < 0 or pre_np.max() >= n_pre:
            raise ValueError(
                f'EventPlasticProj: pre_idx must lie in [0, {n_pre}) '
                f'(segment-local); got range [{pre_np.min()}, {pre_np.max()}].')
        if post_np.min() < 0 or post_np.max() >= n_post:
            raise ValueError(
                f'EventPlasticProj: post_idx must lie in [0, {n_post}) '
                f'(segment-local); got range [{post_np.min()}, {post_np.max()}].')

    def init_state(self, *args, **kwargs):
        """Allocate weight, per-edge aux, trace, rng and fractional-delay States."""
        dftype = brainstate.environ.dftype()
        self.weight = brainstate.ParamState(jnp.asarray(self._w_init, dtype=dftype))
        self.aux = {
            name: brainstate.HiddenState(jnp.full((self._E,), float(v), dtype=dftype))
            for name, v in self.rule.edge_state_init().items()
        }
        self.pre_trace = self._alloc_trace(self._pre_trace_spec, self._n_pre_pop, dftype)
        self.post_trace = self._alloc_trace(self._post_trace_spec, self._n_post_pop, dftype)
        self.rng = (brainstate.State(jax.random.key(
            0 if self._rng_seed is None else int(self._rng_seed)))
            if self.rule.stochastic else None)
        # -- sub-dt (fractional) delay seam (default-off) ---------------------
        # Decompose the homogeneous delay into an integer floor ``k_lo`` and a
        # fraction ``frac = d/dt - k_lo`` (NEST cont_delay's ``delay_steps`` /
        # ``delay_offset_``). Deliver the binary event at ``k_lo`` steps (clean
        # floor frame for ``BinaryArray``) and FIR-split the post amplitude
        # ``[1-frac, frac]`` across the two bracketing grid steps. ``frac == 0``
        # (integer delay) -> no carry -> byte-identical to a plain grid delay.
        self.delay_carry = None
        self._delay_frac = 0.0
        if self._fractional_delay and self.rule.delay is not None:
            dt = brainstate.environ.get_dt()
            steps = (float(u.Quantity(self.rule.delay).to_decimal(u.ms))
                     / float(u.Quantity(dt).to_decimal(u.ms)))
            # The 1e-9 epsilons absorb float round-off around integer step counts.
            k_lo = int(np.floor(steps + 1e-9))
            frac = float(steps - k_lo)
            self._delay_frac = 0.0 if frac < 1e-9 else frac
            if k_lo >= 1:
                self.delay_seam = InputDelay((self._n_pre_pop,), k_lo * dt)
                brainstate.nn.init_all_states(self.delay_seam)
            if self._delay_frac > 0.0:
                self.delay_carry = brainstate.HiddenState(
                    jnp.zeros((self._shape[1],), dtype=dftype))

    # -- helpers -----------------------------------------------------------
    @staticmethod
    def _alloc_trace(spec, n, dftype):
        """Per-neuron trace State: 1-D ``(n,)`` for single tau, ``(n, k)`` for multi."""
        if spec is None:
            return None
        kind, taus, _mode = spec
        shape = (n,) if kind == 'single' else (n, len(taus))
        return brainstate.HiddenState(jnp.zeros(shape, dtype=dftype))

    @staticmethod
    def _advance_trace(state_obj, spec, x_full, dt, gather, E):
        """Decay a per-neuron trace, gather per edge; store per the trace mode.

        Returns ``(trace_edge (E,), traces_edge (E, k))``. ``spec`` ``None`` ->
        zeros. Single tau keeps the 1-D State + ``(E,)`` gather (``(E, 1)`` matrix
        view); a tuple of taus decays each column by its own tau.
        Both modes GATHER the same ``decayed + spike`` (so the cluster-04 kernel
        exclusion ``k = ctx.trace - ctx.spike`` reduces to the strictly-prior
        value, which for ``nearest`` is NEST's "second-latest preceding partner"
        on a coinciding step); they differ only in what is STORED for the next
        step -- ``all_to_all`` accumulates ``decayed + spike``, ``nearest`` resets
        to 1 on the spike (``where(spike, 1, decayed)``).
        """
        if spec is None:
            return jnp.zeros((E,)), jnp.zeros((E, 0))
        kind, taus, mode = spec
        if kind == 'single':
            decayed = state_obj.value * jnp.exp(-dt / taus)
            accumulated = decayed + x_full  # kernel-facing
            state_obj.value = (jnp.where(x_full > 0, 1.0, decayed)
                               if mode == 'nearest' else accumulated)
            e = accumulated[gather]  # (E,)
            return e, e[:, None]  # (E, 1)
        taus_arr = jnp.asarray(taus)  # (k,)
        decayed = state_obj.value * jnp.exp(-dt / taus_arr)
        accumulated = decayed + x_full[:, None]
        state_obj.value = (jnp.where(x_full[:, None] > 0, 1.0, decayed)
                           if mode == 'nearest' else accumulated)
        g = accumulated[gather]  # (E, k)
        return g[:, 0], g

    @staticmethod
    def _t_dt_ms():
        """Current ``(t, dt)`` from the brainstate environment as ms mantissas."""
        t = brainstate.environ.get('t')
        dt = brainstate.environ.get_dt()
        t_ms = u.Quantity(t).to_decimal(u.ms) if isinstance(t, u.Quantity) else jnp.asarray(t)
        dt_ms = u.Quantity(dt).to_decimal(u.ms)
        return jnp.asarray(t_ms), jnp.asarray(dt_ms)

    def _scatter(self, y):
        """Place per-segment contributions into a full ``(n_post_pop,)`` vector."""
        base = jnp.zeros(self._n_post_pop, dtype=y.dtype)
        return base.at[self.post_local_idx].add(y)

    def _gather_post_states(self):
        """Post-neuron analog-State reads for the kernel (``None`` for primitive #1).

        The base :class:`EventPlasticProj` reads no post State; the override in
        :class:`VoltageCoupledPlasticProj` returns the ``{name: (E,)}`` per-edge
        gather declared by ``rule.post_state_reads``.
        """
        return None

    def _gather_signals(self):
        """Broadcast modulatory-signal reads for the kernel (``None`` by default).

        The base :class:`EventPlasticProj` reads no broadcast signal; the override
        in :class:`VoltageCoupledPlasticProj` returns the ``{name: scalar}`` dict
        declared by ``rule.signal_reads`` and bound via ``signal_sources`` (the
        cluster-08 ``volume_transmitter`` seam). Read-only.
        """
        return None

    # -- step --------------------------------------------------------------
    def update(self):
        """Advance one step: delay, traces, rule kernel, CSR delivery.

        Returns
        -------
        jax.Array or Quantity
            The ``(n_post_pop,)`` delivered contribution (wrapped in the weight
            unit unless that unit is ``u.UNITLESS``); also deposited into
            ``post.add_delta_input`` when ``post`` is not ``None``.
        """
        x_full = jnp.asarray(self.pre_spike())
        if self.delay_seam is not None:
            x_full = jnp.asarray(self.delay_seam.update(x_full))
        pre_seg = x_full[self.pre_local_idx]
        pre_fired = pre_seg[self._pre_idx]
        if self.post_spike is not None:
            post_full = jnp.asarray(self.post_spike())
            post_seg = post_full[self.post_local_idx]
            post_fired = post_seg[self._post_idx]
        else:
            post_full = None
            post_fired = jnp.zeros((self._E,))
        t_now, dt = self._t_dt_ms()
        # per-neuron traces (decay-then-add, gather post-update). Single tau ->
        # 1-D State + (E,) edge trace; a tuple of taus -> (N, k) State + (E, k).
        pre_trace_edge, pre_traces_edge = self._advance_trace(
            self.pre_trace, self._pre_trace_spec, x_full, dt,
            self.pre_local_idx[self._pre_idx], self._E)
        if self.post_trace is not None and post_full is not None:
            post_trace_edge, post_traces_edge = self._advance_trace(
                self.post_trace, self._post_trace_spec, post_full, dt,
                self.post_local_idx[self._post_idx], self._E)
        else:
            post_trace_edge = jnp.zeros((self._E,))
            post_traces_edge = jnp.zeros((self._E, 0))
        key = jax.random.key(0)
        if self.rng is not None:
            key, sub = jax.random.split(self.rng.value)
            self.rng.value = key
            key = sub
        state = {'weight': self.weight.value, **{k: v.value for k, v in self.aux.items()}}
        ctx = KernelContext(pre_fired, post_fired, pre_trace_edge, post_trace_edge,
                            t_now, dt, key, pre_traces_edge, post_traces_edge,
                            self._gather_post_states(), self._gather_signals())
        new_state, w_eff = self.rule.update(state, ctx)
        self.weight.value = new_state['weight']
        for k, v in self.aux.items():
            v.value = new_state[k]
        # deliver via CSR event-matmul (gated by the actual pre spikes)
        w_eff = jnp.broadcast_to(jnp.asarray(w_eff), (self._E,))
        csr = brainevent.CSR((w_eff, self._indices, self._indptr), shape=self._shape)
        y = jnp.asarray(brainevent.BinaryArray(pre_seg) @ csr)  # (n_post_seg,)
        if self.delay_carry is not None:
            # sub-dt delay: deliver (1-frac) of this floor-delayed amplitude now,
            # carry frac to the next grid step (1-step FIR -> exact first moment).
            frac = self._delay_frac
            y, self.delay_carry.value = (1.0 - frac) * y + self.delay_carry.value, frac * y
        contrib = y if self._post_is_full else self._scatter(y)
        if self._w_unit is not u.UNITLESS:
            contrib = u.Quantity(contrib, unit=self._w_unit)
        if self.post is not None:
            if self._channel_label is not None:
                # Named-channel post: tag the deposit so it lands on the resolved
                # compartment/sign channel (read back via sum_delta_inputs(label=)).
                self.post.add_delta_input(self._delta_key, contrib, label=self._channel_label)
            else:
                self.post.add_delta_input(self._delta_key, contrib)
        return contrib

    def realized_edges(self):
        """Enumerate this plastic projection's realized edges (``GetConnections`` view).

        Reads the live (post-simulation evolved) ``weight`` State when allocated,
        else the pre-simulation init. A weight-evolving rule exposes no weight
        write-back. See
        :func:`~brainpy_state._nest_network.connection_introspection.plastic_proj_edges`.

        Returns
        -------
        ProjEdges
            Population-local ``source`` / ``target`` plus live ``weight`` / ``delay``
            in the canonical edge order.
        """
        from brainpy_state._nest_network.connection_introspection import plastic_proj_edges
        return plastic_proj_edges(self)
class VoltageCoupledPlasticProj(EventPlasticProj):
    """Voltage-coupled plastic projection — primitive #2 of the typed family.

    A superset of :class:`EventPlasticProj` that adds a **post-neuron analog-state
    reader**. The rule declares a tuple of post-neuron ``State`` attribute names in
    ``post_state_reads`` (e.g. ``('u_bar_minus', 'u_bar_plus', 'V')`` for
    ``clopath_synapse``); each step the projection gathers those per-post-neuron
    State columns **per edge** — in CSR (sorted-by-pre) edge order, exactly the
    post-trace gather (``post_local_idx[post_idx]``) — and hands them to the kernel
    as :attr:`KernelContext.post_states`, a ``{name: (E,)}`` dict of unit-stripped
    mantissas (in each State's stored unit). This samples a continuous post-neuron
    quantity (membrane / filtered voltage) that a spike-driven trace cannot
    reconstruct.

    A rule may instead (or additionally) declare ``signal_reads``: names of
    broadcast scalars read each step from a bound node and handed to the kernel
    as :attr:`KernelContext.signals`. Exactly the declared names are gathered.

    Everything else — CSR delivery, axonal delay, rule-declared per-neuron traces
    (``x_bar`` via ``pre_trace_tau``), the weight-recording / ``_stdp_drive`` seams —
    is inherited unchanged. The post population module supplied as ``post`` is the
    read source (``getattr(post, name).value``); it must be present and the rule
    must declare a non-empty ``post_state_reads`` or ``signal_reads``.

    Parameters
    ----------
    *args, **kwargs
        Forwarded to :class:`EventPlasticProj`.
    signal_sources : dict, optional
        ``{name: (node, attr)}`` binding each declared ``signal_reads`` name to
        ``getattr(node, attr).value``.

    Raises
    ------
    ValueError
        If the rule declares neither ``post_state_reads`` nor ``signal_reads``,
        if ``post`` is ``None``, or if a declared signal has no bound source.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp, brainstate, brainunit as u
        >>> from brainpy_state._nest_network.event_plastic import (
        ...     VoltageCoupledPlasticProj, _StaticTestRule)
        >>> class _Post:
        ...     def __init__(self): self.V = type('S', (), {'value': jnp.array([3.]) * u.mV})()
        ...     def add_delta_input(self, key, val): self.last = val
        >>> class _Read(_StaticTestRule):
        ...     post_state_reads = ('V',)
        ...     def update(self, state, ctx):
        ...         return state, state['weight'] + ctx.post_states['V']
        >>> brainstate.environ.set(dt=0.1 * u.ms)
        >>> post = _Post()
        >>> proj = VoltageCoupledPlasticProj(
        ...     pre_spike=lambda: jnp.array([1.]), n_pre_pop=1, pre_local_idx=jnp.arange(1),
        ...     post=post, post_local_idx=jnp.arange(1), n_post_pop=1,
        ...     pre_idx=jnp.array([0]), post_idx=jnp.array([0]),
        ...     rule=_Read(weight=jnp.array([1.]) * u.pA))
        >>> _ = brainstate.nn.init_all_states(proj)
        >>> with brainstate.environ.context(t=0.1 * u.ms, i=1):
        ...     _ = proj.update()
        >>> u.get_mantissa(post.last).tolist()
        [4.0]
    """

    __module__ = 'brainpy.state'

    def __init__(self, *args, signal_sources=None, **kwargs):
        super().__init__(*args, **kwargs)
        self._post_state_reads = tuple(getattr(self.rule, 'post_state_reads', ()) or ())
        self._signal_reads = tuple(getattr(self.rule, 'signal_reads', ()) or ())
        self._signal_sources = dict(signal_sources or {})
        if not self._post_state_reads and not self._signal_reads:
            raise ValueError(
                'VoltageCoupledPlasticProj requires the rule to declare a non-empty '
                "'post_state_reads' (post-neuron State sampled per edge) or "
                "'signal_reads' (a broadcast scalar read from a bound node); use "
                'EventPlasticProj for a projection that reads neither.'
            )
        if self.post is None:
            raise ValueError(
                'VoltageCoupledPlasticProj needs a post population to read State from '
                '(post=None is only valid for the base EventPlasticProj).'
            )
        missing = [name for name in self._signal_reads if name not in self._signal_sources]
        if missing:
            raise ValueError(
                f'VoltageCoupledPlasticProj: rule declares signal_reads={self._signal_reads} '
                f'but no source was bound for {missing}; pass '
                'signal_sources={name: (node, attr)} (the Simulator wires this from '
                'connect(..., vt=...)).'
            )
        # per-edge population-local post index (only the post-state gather needs it)
        self._post_gather = (self.post_local_idx[self._post_idx]
                             if self._post_state_reads else None)

    def _gather_post_states(self):
        """Per-edge gathers of the declared post-neuron States (``None`` if none)."""
        if not self._post_state_reads:
            return None
        return {name: u.get_mantissa(getattr(self.post, name).value)[self._post_gather]
                for name in self._post_state_reads}

    def _gather_signals(self):
        """Read exactly the rule's declared ``signal_reads`` from their bound sources.

        Returns ``None`` when the rule declares no signal reads, even if sources
        were bound. Every declared name is guaranteed bound by the ``missing``
        check in ``__init__``; bound-but-undeclared sources are not read.
        """
        if not self._signal_reads:
            return None
        signals = {}
        for name in self._signal_reads:
            node, attr = self._signal_sources[name]
            signals[name] = u.get_mantissa(getattr(node, attr).value)
        return signals