Source code for brainpy_state._nest_network.connection_introspection

# Copyright 2026 BrainX Ecosystem Limited. Apache 2.0.
"""Connection enumeration + introspection (NEST ``GetConnections`` analogue).

This module is a thin, **additive** convenience layer over the projections the
:class:`~brainpy_state._nest_network.simulator.Simulator` already builds — it does not
change how any projection stores its edges. Each projection family exposes a
``realized_edges()`` accessor (in ``_event_proj.py`` / ``_event_plastic.py``)
returning a :class:`ProjEdges` view of its realized synapses; this module wraps a
filtered set of those across projections in a :class:`SynapseCollection`, the
object ``Simulator.get_connections`` returns.

Edge-order contract
-------------------
Within one projection, edges are enumerated grouped by **ascending
population-local source index** (stable within a group — the projection's native
CSR / row-major order). Across projections, results are concatenated in
**registration order** (the order ``connect`` built them). NEST orders
``GetConnections`` globally by ``(source, target, …)``; for a single-projection
query the source grouping coincides (targets within a source group follow the
native storage order and are not necessarily ascending), and the two ported demos
scatter into a ``W[src, trg] += w`` matrix (order-insensitive), so any remaining
ordering difference is immaterial (documented divergence).

Node indices are **population-local** — ``brainpy.state`` has no global node-id
space, so ``source`` / ``target`` are indices within the source / target
population (the matrix-building demos index their weight matrices with them
directly, without NEST's ``min(node_id)`` subtraction).
"""
from __future__ import annotations

import dataclasses
from typing import Callable, Optional

import brainstate
import jax.numpy as jnp
import numpy as np
import brainunit as u

from brainpy_state._brainpy.delay import InputDelay

# Names exported by ``from ... import *``. The module's other public helpers
# (``canonical_order``, ``event_proj_edges``, ``plastic_proj_edges``,
# ``collect_connections``) are not listed here.
__all__ = ['ProjEdges', 'SynapseCollection']

# Rule class names (``type(rule).__name__``) that ``plastic_proj_edges`` treats as
# non-evolving: only plastic projections whose rule is one of these get a weight
# writer; every other plastic rule refuses weight writes (``write_weight = None``).
_STATIC_PLASTIC_MODELS = ('static_synapse', 'static_synapse_hom_w')


@dataclasses.dataclass(frozen=True)
class ProjEdges:
    """Immutable snapshot of one projection's realized synapses, canonically ordered.

    Built by a projection's ``realized_edges()``. The index arrays are
    population-local. ``weight`` / ``delay`` are read from the live projection
    when the snapshot is built, so a plastic projection reports its current
    (possibly evolved) weights. The two optional callables are the write-back
    hooks used by :meth:`SynapseCollection.set`; a ``None`` hook marks that
    field as read-only for this projection.

    Parameters
    ----------
    source, target : numpy.ndarray
        ``(E,)`` population-local pre- / post-synaptic index of each edge.
    weight : brainunit.Quantity
        ``(E,)`` weight per edge (a plain array when the weight is unitless). A
        projection with one shared weight repeats it for every edge.
    delay : brainunit.Quantity
        ``(E,)`` axonal delay per edge. The projection-wide value is repeated,
        and it is ``0 ms`` everywhere when the projection has no delay.
    is_homogeneous_weight : bool
        Whether the projection keeps a single shared weight. A per-edge ``set``
        is rejected in that case.
    is_plastic : bool
        Whether the edges come from an ``EventPlasticProj`` /
        ``VoltageCoupledPlasticProj``.
    model_name : str
        Synapse-model name: ``'static_synapse'`` for the static event path,
        otherwise the rule's class name (e.g. ``'stdp_synapse'``).
    write_weight : Callable or None
        ``write_weight(local_edge_idx, mantissa_values)`` stores weight
        mantissas, expressed in the projection's weight unit, at canonical edge
        positions. It is ``None`` when weight writes are refused (a
        weight-evolving plastic rule).
    write_delay : Callable or None
        ``write_delay(quantity)`` sets the projection-wide axonal delay, snapped
        to the ``dt`` grid. It is ``None`` when the delay cannot be changed.
    """
    # Per-edge population-local index arrays, shape (E,).
    source: np.ndarray
    target: np.ndarray
    # Live per-edge values, shape (E,).
    weight: u.Quantity
    delay: u.Quantity
    # Projection-level flags and metadata.
    is_homogeneous_weight: bool
    is_plastic: bool
    model_name: str
    # Guarded write-back hooks; None means "not settable".
    write_weight: Optional[Callable] = dataclasses.field(default=None)
    write_delay: Optional[Callable] = dataclasses.field(default=None)


def canonical_order(source) -> np.ndarray:
    """Return the permutation that puts edges into canonical order.

    Edges are grouped by ascending population-local source index. The sort is
    stable, so within each source group the projection's native storage order is
    kept (CSR column order for the event/plastic CSR paths, row-major for dense).
    This is the documented edge-order contract.

    Parameters
    ----------
    source : array_like
        Population-local source index of every edge, in native storage order.

    Returns
    -------
    numpy.ndarray
        Integer indices ``p`` such that ``source[p]`` is ascending.
    """
    keys = np.asarray(source)
    return keys.argsort(kind='stable')


# ---------------------------------------------------------------------------
# Unit / delay helpers
# ---------------------------------------------------------------------------

def _split_weight(w):
    """Separate a weight into its bare mantissa and its unit.

    Returns ``(jnp.asarray(w.mantissa), w.unit)`` for a brainunit Quantity and
    ``(jnp.asarray(w), u.UNITLESS)`` for any other (plain) value.
    """
    if not isinstance(w, u.Quantity):
        return jnp.asarray(w), u.UNITLESS
    return jnp.asarray(w.mantissa), w.unit


def _as_quantity(mantissa, unit):
    """Attach ``unit`` to ``mantissa``; the reverse of ``_split_weight``.

    The ``u.UNITLESS`` sentinel (compared by identity) means "no unit". In that
    case the mantissa is returned unchanged as a plain array.
    """
    if unit is u.UNITLESS:
        return mantissa
    return u.Quantity(mantissa, unit=unit)


def _broadcast_delay(delay, E: int) -> u.Quantity:
    """Expand a projection-wide axonal delay into an ``(E,)`` Quantity in ms.

    ``None`` (no delay) yields ``E`` zeros. A Quantity is converted to ms, and
    any other value is taken to be in ms already.
    """
    if delay is None:
        return u.Quantity(jnp.zeros(E), unit=u.ms)
    if isinstance(delay, u.Quantity):
        in_ms = delay.to_decimal(u.ms)
    else:
        in_ms = jnp.asarray(delay)
    per_edge = jnp.broadcast_to(jnp.asarray(in_ms), (E,))
    return u.Quantity(per_edge, unit=u.ms)


def _set_event_delay(proj, q) -> u.Quantity:
    """Set an EventProjection's homogeneous axonal delay, rounded to the ``dt`` grid.

    NEST stores delays as an integer multiple of the resolution. To match that,
    the requested delay is rounded to the nearest ``dt`` and the ``InputDelay``
    seam is rebuilt, or cleared when the rounded delay is ``0``. A rebuilt seam
    is re-initialised with ``brainstate.nn.init_all_states``.

    Parameters
    ----------
    proj : EventProjection
        Projection whose ``_delay`` / ``delay_seam`` attributes are replaced.
        Its ``_n_pre_pop`` sizes the new seam.
    q : brainunit.Quantity or float
        Requested delay. A plain number is interpreted as milliseconds.

    Returns
    -------
    brainunit.Quantity
        The grid-rounded delay actually applied (``0 ms`` when the delay was
        cleared).

    Raises
    ------
    ValueError
        If the requested delay is negative or not finite (NaN / inf). The
        projection is left untouched in that case.
    """
    dt = brainstate.environ.get_dt()
    dt_ms = float(u.Quantity(dt).to_decimal(u.ms))
    q_ms = float(u.Quantity(q).to_decimal(u.ms)) if isinstance(q, u.Quantity) else float(q)
    # Reject invalid requests before mutating anything. Otherwise a negative value
    # rounds to <= 0 steps, silently clears the delay, and is still returned as if
    # it had been applied.
    if not np.isfinite(q_ms) or q_ms < 0.0:
        raise ValueError(
            f"axonal delay must be a finite, non-negative time; got {q_ms} ms")
    steps = int(round(q_ms / dt_ms))
    rounded = (steps * dt_ms) * u.ms
    if steps == 0:
        # Rounds to zero steps: drop the delay seam entirely.
        proj._delay = None
        proj.delay_seam = None
    else:
        proj._delay = rounded
        proj.delay_seam = InputDelay((proj._n_pre_pop,), rounded)
        brainstate.nn.init_all_states(proj.delay_seam)
    return rounded


# ---------------------------------------------------------------------------
# Write-back closures (guarded by SynapseCollection.set)
# ---------------------------------------------------------------------------

def _make_event_writer(target, order):
    """Build a per-edge weight writer for an :class:`EventProjection` storage mode.

    ``order`` maps a canonical edge position (the order ``realized_edges`` returns)
    to a native storage position, so a write at canonical indices lands on the
    right ``_data`` / matrix cell / shared scalar.
    """
    kind = target[0]
    if kind == 'data':                     # sparse / receptor: per-edge _data State
        state = target[1]

        def write(local_idx, mantissa_values):
            pos = jnp.asarray(order[np.asarray(local_idx)])
            state.value = state.value.at[pos].set(jnp.asarray(mantissa_values))
        return write

    if kind == 'dense':                    # dense matrix cell (row, col) per edge
        wstate, rows, cols, w_unit = target[1], target[2], target[3], target[4]

        def write(local_idx, mantissa_values):
            pos = order[np.asarray(local_idx)]
            r = jnp.asarray(rows[pos]); c = jnp.asarray(cols[pos])
            mant, _ = _split_weight(wstate.value)
            mant = jnp.asarray(mant).at[r, c].set(jnp.asarray(mantissa_values))
            wstate.value = _as_quantity(mant, w_unit)
        return write

    # 'scalar' — one_to_one shares a single weight; only a uniform (scalar) set
    # reaches here (SynapseCollection refuses a per-edge set on a homogeneous proj).
    proj, w_unit = target[1], target[2]

    def write(local_idx, mantissa_values):
        v = np.asarray(mantissa_values).reshape(-1)
        scalar = jnp.asarray(v[0] if v.size else 0.0)
        proj._weight = _as_quantity(scalar, w_unit)
    return write


def _make_plastic_writer(proj, order):
    """Build a per-edge weight writer for a plastic projection (static rule only).

    Writes the live ``weight`` State when it has been allocated (post-``init_state``),
    else the pre-simulation ``_w_init`` array, so an introspection-time edit survives
    into the run. A homogeneous (0-d) weight takes the uniform scalar.

    Parameters
    ----------
    proj : plastic projection
        Owner of ``weight`` (a ``brainstate.State`` once allocated) or of the
        pre-simulation ``_w_init`` array.
    order : numpy.ndarray
        Canonical-to-native edge permutation, shape ``(E,)``.

    Returns
    -------
    Callable
        ``write(local_idx, mantissa_values)`` storing mantissas, expressed in the
        projection's weight unit, at canonical positions ``local_idx``. An empty
        selection is a no-op.
    """
    def write(local_idx, mantissa_values):
        flat = np.asarray(mantissa_values).reshape(-1)
        if flat.size == 0:
            # Nothing selected: do not overwrite a shared weight with 0.
            return
        pos = jnp.asarray(order[np.asarray(local_idx)])
        w_state = getattr(proj, 'weight', None)
        if isinstance(w_state, brainstate.State):
            old = w_state.value
            cur = jnp.asarray(u.get_mantissa(old))
            if cur.ndim == 0:
                new = jnp.asarray(flat[0])
            else:
                new = cur.at[pos].set(jnp.asarray(mantissa_values))
            # plastic_proj_edges reads this State through u.get_mantissa, so it may
            # hold a Quantity. Re-attach that unit so the write does not change the
            # value's type. NOTE(review): assumes the State's unit equals
            # proj._w_unit (the unit the mantissas were converted to), the same
            # assumption the reader makes.
            if isinstance(old, u.Quantity):
                new = u.Quantity(new, unit=old.unit)
            w_state.value = new
        else:
            init = jnp.asarray(proj._w_init)
            if init.ndim == 0:
                proj._w_init = jnp.asarray(flat[0])
            else:
                proj._w_init = init.at[pos].set(jnp.asarray(mantissa_values))
    return write


# ---------------------------------------------------------------------------
# Per-projection edge enumeration
# ---------------------------------------------------------------------------

def event_proj_edges(proj) -> 'ProjEdges':
    """Enumerate an :class:`EventProjection`'s realized edges as a :class:`ProjEdges`.

    Handles all four comm modes (dense matrix, sparse CSR, per-receptor scatter,
    element-wise ``one_to_one``); maps segment-local storage indices to
    population-local ``source`` / ``target`` and sorts to the canonical edge order.

    Parameters
    ----------
    proj : EventProjection
        The projection to enumerate. Attributes read:

        * ``pre_local_idx`` / ``post_local_idx``: segment-local to
          population-local index maps;
        * ``_receptor`` / ``_one_to_one`` / ``comm``: storage-mode selectors;
        * the mode's weight storage (``comm._data`` + ``comm._unit``,
          ``_weight``, or ``_W``);
        * the optional homogeneous ``_delay``.

    Returns
    -------
    ProjEdges
        Edges in canonical order with live weight / delay, a weight writer built
        by ``_make_event_writer`` for the mode's storage, and a delay writer that
        calls ``_set_event_delay``. ``is_plastic`` is ``False`` and
        ``model_name`` is always ``'static_synapse'``.
    """
    from brainpy_state._nest_network.projections import _DenseMatMul

    # Segment-local storage index -> population-local neuron index.
    pre_map = np.asarray(proj.pre_local_idx)
    post_map = np.asarray(proj.post_local_idx)

    if proj._receptor:
        # Per-receptor scatter: explicit (pre, post) segment index per stored edge,
        # per-edge weight mantissas in comm._data with the unit kept in comm._unit.
        comm = proj.comm
        pre_seg = np.asarray(comm._pre_idx)
        post_seg = np.asarray(comm._post_idx)
        w_mant = np.asarray(comm._data.value)
        w_unit = comm._unit
        is_homogeneous = False
        target = ('data', comm._data)
    elif proj._one_to_one:
        # Element i connects pre segment index i to post segment index i.
        # NOTE(review): assumes post_local_idx has the same length as
        # pre_local_idx; confirm against the projection builder.
        n = int(pre_map.shape[0])
        pre_seg = np.arange(n)
        post_seg = np.arange(n)
        w_m, w_unit = _split_weight(proj._weight)
        # A 0-d _weight is one weight shared by all n edges.
        is_homogeneous = (w_m.ndim == 0)
        w_mant = np.broadcast_to(np.asarray(w_m), (n,))
        target = ('scalar', proj, w_unit)
    elif isinstance(proj.comm, _DenseMatMul):
        # Dense matrix: an edge is a nonzero cell of W (row = pre, col = post,
        # enumerated in row-major order). NOTE: a connection whose weight is
        # exactly 0 is therefore not enumerated, and writing a 0 weight through
        # this view removes that edge from later enumerations.
        W = proj._W.value
        w_m, w_unit = _split_weight(W)
        Wm = np.asarray(w_m)
        rows, cols = np.nonzero(Wm)
        pre_seg = rows
        post_seg = cols
        w_mant = Wm[rows, cols]
        is_homogeneous = False
        target = ('dense', proj._W, rows, cols, w_unit)
    else:                                  # sparse CSR event matmul
        # Expand the CSR row pointer into one pre (row) index per stored edge;
        # comm._indices holds the matching post (column) index.
        comm = proj.comm
        indptr = np.asarray(comm._indptr)
        pre_seg = np.repeat(np.arange(indptr.shape[0] - 1), np.diff(indptr))
        post_seg = np.asarray(comm._indices)
        w_mant = np.asarray(comm._data.value)
        w_unit = comm._unit
        is_homogeneous = False
        target = ('data', comm._data)

    # Map to population-local indices, then sort into canonical order
    # (ascending source, stable). ``order`` is kept so the writer can map a
    # canonical edge position back to its native storage position.
    source = pre_map[pre_seg]
    dst = post_map[post_seg]
    order = canonical_order(source)
    source = source[order]
    dst = dst[order]
    w_mant = np.asarray(w_mant)[order]
    E = int(source.shape[0])

    weight = _as_quantity(jnp.asarray(w_mant), w_unit)
    delay = _broadcast_delay(getattr(proj, '_delay', None), E)
    write_weight = _make_event_writer(target, order)
    # Delay writes are projection-wide (homogeneous), not per edge.
    write_delay = lambda q: _set_event_delay(proj, q)
    return ProjEdges(source, dst, weight, delay, is_homogeneous, False,
                     'static_synapse', write_weight, write_delay)


def plastic_proj_edges(proj) -> 'ProjEdges':
    """Describe a plastic projection's realized synapses as a :class:`ProjEdges`.

    Weights come from the live ``weight`` State once it exists, so evolved
    post-simulation values are reported; before that they come from the
    pre-simulation ``_w_init``. Only rules whose class name is in
    ``_STATIC_PLASTIC_MODELS`` get a weight writer. Any other (weight-evolving)
    rule gets ``write_weight = None``, so :meth:`SynapseCollection.set` refuses to
    overwrite rule-managed weights. The delay is never writable here.

    Parameters
    ----------
    proj : plastic projection
        Attributes read:

        * ``pre_local_idx`` / ``post_local_idx``: segment-local to
          population-local index maps;
        * ``_pre_idx`` / ``_post_idx``: native per-edge segment indices;
        * ``_w_unit`` and either ``weight`` or ``_w_init``;
        * ``rule``: ``is_homogeneous_weight`` and the optional ``delay``.

    Returns
    -------
    ProjEdges
        Canonically ordered edges with ``is_plastic=True`` and
        ``write_delay=None``.
    """
    rule = proj.rule
    weight_unit = proj._w_unit
    pre_lookup = np.asarray(proj.pre_local_idx)
    post_lookup = np.asarray(proj.post_local_idx)
    pre_native = np.asarray(proj._pre_idx)
    post_native = np.asarray(proj._post_idx)
    n_edges = int(pre_native.shape[0])

    live_weight = getattr(proj, 'weight', None)
    if isinstance(live_weight, brainstate.State):
        raw = np.asarray(u.get_mantissa(live_weight.value))
    else:
        raw = np.asarray(proj._w_init)
    if raw.ndim == 0:
        # One shared weight: repeat it for every edge.
        raw = np.broadcast_to(raw, (n_edges,))

    # Population-local endpoints, then reorder everything canonically.
    unsorted_source = pre_lookup[pre_native]
    perm = canonical_order(unsorted_source)
    sorted_source = unsorted_source[perm]
    sorted_target = post_lookup[post_native][perm]
    sorted_weight = np.asarray(raw)[perm]

    name = type(rule).__name__
    if name in _STATIC_PLASTIC_MODELS:
        writer = _make_plastic_writer(proj, perm)
    else:
        writer = None
    return ProjEdges(
        source=sorted_source,
        target=sorted_target,
        weight=_as_quantity(jnp.asarray(sorted_weight), weight_unit),
        delay=_broadcast_delay(getattr(rule, 'delay', None), n_edges),
        is_homogeneous_weight=bool(rule.is_homogeneous_weight),
        is_plastic=True,
        model_name=name,
        write_weight=writer,
        write_delay=None,
    )


# ---------------------------------------------------------------------------
# SynapseCollection — a filtered, lazy view across projections
# ---------------------------------------------------------------------------

@dataclasses.dataclass(frozen=True)
class _SpannedProj:
    """One projection's contribution to a :class:`SynapseCollection`.

    Holds only the projection handle and the **canonical edge indices** kept by the
    query (plus the cached population-local ``source`` / ``target`` of those edges).
    Weight / delay are never copied — they are re-read live from ``proj`` on demand.
    """
    proj: object
    kept: np.ndarray
    source: np.ndarray
    target: np.ndarray


def _membership(view):
    """Build ``{id(population): set(local idx)}`` from a ``NodeView`` (``None`` -> no filter)."""
    if view is None:
        return None
    if hasattr(view, 'segments'):
        segments = view.segments
    else:                                  # a bare population module
        from brainpy_state._nest_network.nodeview import NodeView
        segments = NodeView.of(view).segments
    members = {}
    for seg in segments:
        members.setdefault(id(seg.population), set()).update(
            int(i) for i in np.asarray(seg.indices).tolist())
    return members


def collect_connections(connections, source=None, target=None, synapse=None) -> 'SynapseCollection':
    """Assemble a :class:`SynapseCollection` from a Simulator's connection registry.

    Projections are visited in registration order. A projection is skipped
    outright when its synapse-model name differs from ``synapse``, or when its
    pre / post population is absent from the corresponding filter view.
    Otherwise its realized edges are masked by population-local membership, and
    the projection is kept only if at least one edge survives.

    Parameters
    ----------
    connections : list of tuple
        ``(pre_pop, post_pop, model_name, proj)`` entries, in registration order.
    source, target : NodeView or None
        Keep only edges whose pre / post neuron belongs to the view, matched by
        population identity **and** population-local index. ``None`` disables
        that filter.
    synapse : str or None
        Keep only projections whose registered synapse-model name equals this.

    Returns
    -------
    SynapseCollection
        A lazy, filtered view; empty when nothing matches.
    """
    pre_filter = _membership(source)
    post_filter = _membership(target)
    selected = []
    for pre_pop, post_pop, model_name, proj in connections:
        wrong_model = synapse is not None and model_name != synapse
        pre_absent = pre_filter is not None and id(pre_pop) not in pre_filter
        post_absent = post_filter is not None and id(post_pop) not in post_filter
        if wrong_model or pre_absent or post_absent:
            continue

        edges = proj.realized_edges()
        pre_local = np.asarray(edges.source)
        post_local = np.asarray(edges.target)
        mask = np.ones(pre_local.shape[0], dtype=bool)
        if pre_filter is not None:
            allowed_pre = np.fromiter(pre_filter[id(pre_pop)], dtype=int)
            mask &= np.isin(pre_local, allowed_pre)
        if post_filter is not None:
            allowed_post = np.fromiter(post_filter[id(post_pop)], dtype=int)
            mask &= np.isin(post_local, allowed_post)

        hits = np.flatnonzero(mask)
        if hits.size:
            selected.append(_SpannedProj(proj, hits, pre_local[hits], post_local[hits]))
    return SynapseCollection(selected)


def _to_unit_mantissa(value, unit):
    """Express ``value`` as a bare mantissa in ``unit``.

    A plain value is taken as-is, already in ``unit``. A Quantity is converted
    with ``to_decimal(unit)``, which checks its dimension. The exception is the
    ``u.UNITLESS`` target, where the Quantity's raw mantissa is used unchanged.
    """
    if not isinstance(value, u.Quantity):
        return jnp.asarray(value)
    if unit is u.UNITLESS:
        return jnp.asarray(value.mantissa)
    return jnp.asarray(value.to_decimal(unit))


class SynapseCollection:
    """A filtered, lazy view over realized synapses (NEST ``SynapseCollection``).

    Returned by :meth:`~brainpy_state._nest_network.simulator.Simulator.get_connections`.
    It stores, per spanned projection, only the kept edge indices and their
    population-local ``source`` / ``target`` — **no copy** of weight or delay. Each
    :meth:`get` re-reads weight / delay from the live projection, so a query made
    before :meth:`~brainpy_state._nest_network.simulator.Simulator.simulate` still
    reflects post-simulation evolved plastic weights.

    Methods mirror the NEST object: :meth:`get` (one key -> array, a list of keys
    -> dict), :meth:`set` (``'weight'`` / ``'delay'`` write-back, guarded), the
    :attr:`source` / :attr:`target` arrays, ``len()``, iteration over
    ``(source, target)`` pairs, and ``repr``.

    Notes
    -----
    ``source`` / ``target`` are **population-local** indices (``brainpy.state`` has
    no global node-id space).

    ``set('weight', value)`` accepts a single value (broadcast to every edge) or
    a per-edge array of shape ``(len(self),)``. A per-edge array is refused on a
    homogeneous-weight projection, and any write is refused on a projection
    whose plastic rule evolves the weight.

    ``set('delay', scalar)`` is homogeneous per projection (grid-rounded to
    ``dt``) and applies only to the static event path.

    Edges are addressed by their position in each projection's live enumeration.
    Reads and weight writes therefore first check that the enumeration still
    yields the cached ``source`` / ``target`` at those positions, and raise
    :class:`RuntimeError` otherwise. For example, a dense projection enumerates
    only the nonzero cells of its weight matrix, so zeroing a weight changes its
    edge set.
    """
    __module__ = 'brainpy.state'

    def __init__(self, spanned):
        self._spanned = list(spanned)

    def __len__(self) -> int:
        return int(sum(sp.source.shape[0] for sp in self._spanned))

    @property
    def source(self) -> np.ndarray:
        """Population-local source index per edge (``(N,)`` int array)."""
        if not self._spanned:
            return np.zeros(0, dtype=int)
        return np.concatenate([sp.source for sp in self._spanned])

    @property
    def target(self) -> np.ndarray:
        """Population-local target index per edge (``(N,)`` int array)."""
        if not self._spanned:
            return np.zeros(0, dtype=int)
        return np.concatenate([sp.target for sp in self._spanned])

    def get(self, key):
        """Read a connection attribute (or several).

        Parameters
        ----------
        key : str or list of str
            One of ``'source'``, ``'target'``, ``'weight'``, ``'delay'`` — or a
            list of them.

        Returns
        -------
        numpy.ndarray or brainunit.Quantity or dict
            A single array/Quantity for a string key (``weight`` / ``delay`` carry
            their unit and are read live), or a ``{key: value}`` dict for a list.

        Raises
        ------
        KeyError
            An unknown key.
        RuntimeError
            ``'weight'`` / ``'delay'`` requested while a spanned projection's edge
            enumeration no longer matches this collection.
        """
        if isinstance(key, (list, tuple)):
            return {k: self.get(k) for k in key}
        if key == 'source':
            return self.source
        if key == 'target':
            return self.target
        if key == 'weight':
            return self._gather_live('weight', u.pA)
        if key == 'delay':
            return self._gather_live('delay', u.ms)
        raise KeyError(f"unknown connection key {key!r}; expected one of "
                       "'source', 'target', 'weight', 'delay'")

    def _live_edges(self):
        """Re-read every spanned projection's ProjEdges, checking the view still applies.

        Indexing live arrays with stale edge positions would silently read or
        write the wrong synapses. JAX clamps out-of-range gather indices rather
        than raising, so the mismatch is detected here explicitly.
        """
        live = []
        for sp in self._spanned:
            edges = sp.proj.realized_edges()
            src = np.asarray(edges.source)
            tgt = np.asarray(edges.target)
            kept = np.asarray(sp.kept)
            aligned = kept.size == 0 or int(kept.max()) < src.shape[0]
            aligned = (aligned
                       and np.array_equal(src[kept], np.asarray(sp.source))
                       and np.array_equal(tgt[kept], np.asarray(sp.target)))
            if not aligned:
                raise RuntimeError(
                    "a spanned projection's realized edges no longer match this "
                    "SynapseCollection (its edge set changed after the query, e.g. "
                    "a dense projection whose weight was set to 0); call "
                    "get_connections() again.")
            live.append(edges)
        return live

    def _gather_live(self, attr, default_unit):
        """Concatenate a live per-edge Quantity across spanned projections.

        All parts are converted to the unit of the first projection's value. An
        empty collection returns a zero-length Quantity in ``default_unit``.
        """
        if not self._spanned:
            return u.Quantity(jnp.zeros(0), unit=default_unit)
        quantities = [getattr(edges, attr) for edges in self._live_edges()]
        unit = u.get_unit(quantities[0])
        parts = []
        for q, sp in zip(quantities, self._spanned):
            if unit is u.UNITLESS:
                mant = jnp.asarray(u.get_mantissa(q))
            else:
                mant = jnp.asarray(u.Quantity(q).to_decimal(unit))
            parts.append(mant[sp.kept])
        mant = jnp.concatenate(parts)
        return mant if unit is u.UNITLESS else u.Quantity(mant, unit=unit)

    def set(self, key, value):
        """Write a connection attribute back into the live projections.

        Parameters
        ----------
        key : str
            ``'weight'`` or ``'delay'`` (``'source'`` / ``'target'`` are read-only).
        value : scalar, array, or brainunit.Quantity
            For ``'weight'``: a single value (scalar or size-1 array) broadcasts
            to every edge, and an array of shape ``(len(self),)`` sets each edge
            in collection order. A plain number is taken to be in each
            projection's own weight unit; a Quantity is converted.
            For ``'delay'``: a single homogeneous delay applied per projection.

        Raises
        ------
        ValueError
            Raised for any of:

            * a per-edge weight whose shape is not ``(len(self),)``;
            * a per-edge weight on a homogeneous-weight projection;
            * any weight write on a weight-evolving plastic projection;
            * a delay write on a plastic projection;
            * a non-scalar delay.

            No projection is modified when a ``ValueError`` is raised.
        KeyError
            A key other than ``'weight'`` / ``'delay'``.
        RuntimeError
            A weight write while a spanned projection's edge enumeration no longer
            matches this collection.
        """
        if key == 'weight':
            self._set_weight(value)
        elif key == 'delay':
            self._set_delay(value)
        else:
            raise KeyError(f"cannot set {key!r}; only 'weight' and 'delay' are settable")

    def _set_weight(self, value):
        """Validate and convert for every projection first, then write (all-or-nothing)."""
        value_mant = jnp.asarray(u.get_mantissa(value))
        uniform = value_mant.size == 1
        n_edges = len(self)
        if not uniform and value_mant.shape != (n_edges,):
            raise ValueError(
                f"a per-edge weight must have shape ({n_edges},) to match the "
                f"collection's edges; got shape {tuple(value_mant.shape)}")
        edges_list = self._live_edges()
        for edges in edges_list:
            if edges.write_weight is None:
                raise ValueError(
                    f"weight of the '{edges.model_name}' projection is managed by its "
                    "plastic rule and is overwritten during simulate(); set it through "
                    "the synapse spec at connect() time, not SynapseCollection.set().")
            if (not uniform) and edges.is_homogeneous_weight:
                raise ValueError(
                    f"cannot set a per-edge weight on the homogeneous-weight projection "
                    f"'{edges.model_name}' (it shares a single scalar weight); pass a "
                    "scalar to set them all, or rebuild it with a per-edge rule.")
        # Convert every chunk before writing any, so a unit mismatch on a later
        # projection cannot leave earlier projections already modified.
        pending = []
        offset = 0
        for sp, edges in zip(self._spanned, edges_list):
            m = int(sp.kept.shape[0])
            unit = u.get_unit(edges.weight)
            if uniform:
                mant = jnp.reshape(_to_unit_mantissa(value, unit), ())
            else:
                mant = _to_unit_mantissa(value[offset:offset + m], unit)
            pending.append((edges.write_weight, sp.kept, jnp.broadcast_to(mant, (m,))))
            offset += m
        for write, kept, mant in pending:
            write(kept, mant)

    def _set_delay(self, value):
        """Apply one homogeneous delay to every spanned projection (validated first)."""
        edges_list = [sp.proj.realized_edges() for sp in self._spanned]
        for edges in edges_list:
            if edges.write_delay is None:
                raise ValueError(
                    f"delay of the '{edges.model_name}' projection is not settable via "
                    "SynapseCollection (a plastic delay is fixed at connect() time).")
        if jnp.asarray(u.get_mantissa(value)).size != 1:
            raise ValueError(
                "SynapseCollection.set('delay', ...) takes a single (scalar) delay; "
                "delays are homogeneous per projection")
        for edges in edges_list:
            edges.write_delay(value)

    def __iter__(self):
        for s, t in zip(self.source.tolist(), self.target.tolist()):
            yield (s, t)

    def __repr__(self) -> str:
        return (f'SynapseCollection({len(self)} edges over '
                f'{len(self._spanned)} projection(s))')