Source code for brainpy_state._nest_spatial.helpers

# Copyright 2026 BrainX Ecosystem Limited. Apache 2.0.
"""Spatial query helpers (NEST ``FindCenterElement`` / ``Distance`` / target queries).

:func:`center_element` and :func:`Distance` are pure layer-level queries. :func:`target_nodes`
and :func:`target_positions` read the *realized* adjacency back out of a built network
(via :meth:`brainpy_state._nest_network.simulator.Simulator.get_connections`), mirroring NEST's
``GetTargetNodes`` / ``GetTargetPositions``.
"""
from __future__ import annotations

import jax.numpy as jnp
import numpy as np
import brainunit as u

from brainpy_state._nest_spatial.distance import pairwise_distance
from brainpy_state._nest_spatial.layers import _LEN, _as_len

# Public API of this module: layer-level spatial queries plus realized-connectivity readers.
__all__ = [
    'center_element',
    'Distance',
    'nearest_element',
    'select_nodes_by_mask',
    'dump_layer_nodes',
    'dump_layer_connections',
    'target_nodes',
    'target_positions',
]


def _fmt(v) -> str:
    """Format a scalar at round-trip precision (clean ints, ``.12g`` floats)."""
    return f"{float(v):.12g}"


def center_element(layer) -> int:
    """Return the population-local index of the node closest to the layer's centroid.

    Mirrors NEST ``FindCenterElement``. When several nodes are equally close, the lowest index
    is returned (as NEST does).

    Parameters
    ----------
    layer : Layer
        A concrete position layer.

    Returns
    -------
    int
        The population-local index of the most central node.

    Examples
    --------
    .. code-block:: python

        >>> from brainpy import state as bp
        >>> bp.spatial.center_element(bp.spatial.grid([4, 3], extent=[2.0, 1.5]))
        4
    """
    pts = layer.coords
    offsets = pts - u.math.mean(pts, axis=0)          # displacement of every node from the centroid
    sq_dist = u.math.sum(offsets ** 2, axis=-1)       # squared distance suffices for the argmin
    # argmin reports the first minimum, i.e. the lowest index among tied nodes.
    return int(jnp.argmin(u.get_magnitude(sq_dist)))
def Distance(layer_a, layer_b):
    """Compute all pairwise Euclidean distances between the nodes of two layers (NEST ``Distance``).

    Parameters
    ----------
    layer_a, layer_b : Layer
        Concrete position layers.

    Returns
    -------
    Quantity
        An ``(n_a, n_b)`` distance matrix; entry ``[i, j]`` is the distance from node ``i`` of
        ``layer_a`` to node ``j`` of ``layer_b``.
    """
    coords_a, coords_b = layer_a.coords, layer_b.coords
    return pairwise_distance(coords_a, coords_b)
def nearest_element(layer, locations, find_all=False):
    """Return the local index of the node nearest each query location (NEST ``FindNearestElement``).

    Parameters
    ----------
    layer : Layer
        A concrete position layer.
    locations : sequence or Quantity
        Either one coordinate ``(ndim,)`` or a stack of coordinates ``(m, ndim)``. Bare floats are
        interpreted as micrometres.
    find_all : bool, optional
        ``False`` (default): one index per location, with ties going to the lowest index.
        ``True``: every node whose distance ``d`` satisfies ``|d - d_min| <= 1e-14 * d_min``.

    Returns
    -------
    int or list
        For a single location, an ``int`` (``find_all=False``) or a list of ints
        (``find_all=True``). For several locations, a list holding one such entry per location.

    Examples
    --------
    .. code-block:: python

        >>> from brainpy import state as bp
        >>> bp.spatial.nearest_element(bp.spatial.grid([3, 1], extent=[3.0, 1.0]), [0.9, 0.0])
        2
    """
    node_um = u.get_magnitude(layer.coords.to(_LEN))      # (n, ndim) bare micrometres
    query_um = u.get_magnitude(_as_len(locations).to(_LEN))
    one_query = (query_um.ndim == 1)
    if one_query:
        query_um = query_um[None, :]                      # promote a lone point to (1, ndim)

    def _closest(point):
        # Distances from this query point to every node.
        dist = np.sqrt(((node_um - point[None, :]) ** 2).sum(axis=-1))
        best = float(dist.min())
        if find_all:
            # A relative tolerance keeps floating-point near-ties together.
            return [int(k) for k in np.nonzero(dist - best <= 1e-14 * best)[0]]
        return int(np.argmin(dist))                       # first minimum == lowest index

    answers = [_closest(point) for point in query_um]
    return answers[0] if one_query else answers
def select_nodes_by_mask(layer, anchor, mask):
    """Return the local indices of the layer nodes inside ``mask`` centred on ``anchor``.

    Mirrors NEST ``SelectNodesByMask``. The mask is queried with ``anchor`` as its only source
    node and every layer node as a candidate target. Directional masks (``box`` /
    ``rectangular`` / rotated ellipses) therefore see the ``target - anchor`` displacement.

    Parameters
    ----------
    layer : Layer
        A concrete position layer.
    anchor : sequence or Quantity
        The ``(ndim,)`` reference point the mask is centred on. Bare floats are micrometres.
    mask : object
        Any spatial mask providing ``contains(pre_pos, post_pos) -> bool (n_pre, n_post)``.

    Returns
    -------
    numpy.ndarray
        The selected population-local indices, in ascending order.

    Examples
    --------
    .. code-block:: python

        >>> from brainpy import state as bp
        >>> layer = bp.spatial.grid([3, 3], extent=[2.0, 2.0])
        >>> bp.spatial.select_nodes_by_mask(layer, [0.0, 0.0], bp.spatial.circular(0.7)).tolist()
        [1, 3, 4, 5, 7]
    """
    source_pos = _as_len(anchor)[None, :]                 # (1, ndim): the anchor is the lone source
    membership = np.asarray(mask.contains(source_pos, layer.coords))
    row = membership[0]                                   # inclusion flags for the single source
    return np.nonzero(row)[0]
def dump_layer_nodes(sim, pop, outname) -> str:
    """Write each node's local index and coordinates to ``outname`` (NEST ``DumpLayerNodes``).

    Writes one whitespace-separated, newline-terminated line per node: ``idx x y [z]``. The
    index is the node's position in the coordinate array returned for ``pop``. Coordinates are
    micrometre magnitudes, and lines follow ascending node order. An empty population produces
    an empty file, consistent with :func:`dump_layer_connections` for an empty edge set.

    Parameters
    ----------
    sim : Simulator
        The simulator holding the population's positions.
    pop : NodeView
        A population/view created with ``create(positions=...)``.
    outname : str
        Path of the text file to write.

    Returns
    -------
    str
        The text that was written to ``outname``, returned so callers can assert on it directly.
    """
    coords = np.asarray(u.get_magnitude(sim.get_position(pop).to(_LEN)))
    # Terminate every record with '\n' (instead of join + trailing '\n') so that zero nodes
    # give '' rather than a file containing a single blank line.
    text = ''.join(
        ' '.join([str(i)] + [_fmt(v) for v in row]) + '\n'
        for i, row in enumerate(coords)
    )
    with open(outname, 'w', encoding='utf-8') as fh:
        fh.write(text)
    return text
def dump_layer_connections(sim, source, target, outname) -> str:
    """Write every realized edge of ``source -> target`` to ``outname`` (NEST ``DumpLayerConnections``).

    Each edge becomes one whitespace-separated line ``src tgt weight delay dx dy [dz]``, where
    ``(dx, dy[, dz]) = target_pos - source_pos`` is the displacement seen from the source.
    Weight is written as a pA magnitude, delay as an ms magnitude, and coordinates in
    micrometres. With no edges the file is left empty.

    Parameters
    ----------
    sim : Simulator
        The simulator holding the realized connections and positions.
    source, target : NodeView
        The source and target populations/views (created with ``positions=...``).
    outname : str
        Path of the text file to write.

    Returns
    -------
    str
        The text that was written to ``outname``, returned so callers can assert on it directly.
    """
    conns = sim.get_connections(source=source, target=target)
    pre_idx = np.asarray(conns.source)
    post_idx = np.asarray(conns.target)
    w_pa = np.asarray(u.get_magnitude(conns.get('weight').to(u.pA)))
    d_ms = np.asarray(u.get_magnitude(conns.get('delay').to(u.ms)))
    pre_xyz = np.asarray(u.get_magnitude(sim.get_position(source).to(_LEN)))
    post_xyz = np.asarray(u.get_magnitude(sim.get_position(target).to(_LEN)))

    # NOTE(review): pre_xyz / post_xyz are indexed with the ids reported by get_connections;
    # this assumes get_position(view) returns rows in that same index space — confirm for
    # sub-population views.
    def _row(s, t, w, d):
        fields = [str(int(s)), str(int(t)), _fmt(w), _fmt(d)]
        fields.extend(_fmt(v) for v in post_xyz[t] - pre_xyz[s])
        return ' '.join(fields)

    rows = [_row(*edge) for edge in zip(pre_idx, post_idx, w_pa, d_ms)]
    # Newline-terminate each record; an empty edge list therefore yields ''.
    text = ''.join(r + '\n' for r in rows)
    with open(outname, 'w') as fh:
        fh.write(text)
    return text
def target_nodes(sim, source, target):
    r"""Realized target indices of each source node (NEST ``GetTargetNodes``).

    Reads the built network's adjacency back out (via
    :meth:`~brainpy_state._nest_network.simulator.Simulator.get_connections`) and groups the
    realized target indices by source node.

    The edge list is sorted once by source id, so each node's targets form one contiguous run
    located with a binary search. The cost is ``O(E log E + S log E)`` rather than scanning all
    ``E`` edges once per source node (``O(S * E)``).

    Parameters
    ----------
    sim : Simulator
        The simulator holding the realized connections.
    source : NodeView
        A single-segment source view; targets are grouped per node in this view's order. Only
        ``source.segments[0]`` is read.
    target : NodeView
        The candidate-target population view.

    Returns
    -------
    list of numpy.ndarray
        One entry per source node (in ``source`` order): the sorted unique population-local
        target indices that node connects to (empty if it has no outgoing edges).

    See Also
    --------
    target_positions : the same query returning target coordinates.

    Examples
    --------
    .. code-block:: python

        >>> from brainpy import state as bp
        >>> import brainunit as u
        >>> sim = bp.Simulator(dt=0.1 * u.ms)
        >>> pop = sim.create(bp.iaf_psc_alpha, positions=bp.spatial.grid([3, 1], extent=[3.0, 1.0]))
        >>> _ = sim.connect(pop, pop,
        ...     rule=bp.spatial.spatial_pairwise_bernoulli(p=1.0, mask=bp.spatial.circular(1.2)),
        ...     weight=1.0 * u.pA, delay=1.0 * u.ms)
        >>> [t.tolist() for t in bp.spatial.target_nodes(sim, pop, pop)]
        [[0, 1], [0, 1, 2], [1, 2]]
    """
    sc = sim.get_connections(source=source, target=target)
    src = np.asarray(sc.source)
    tgt = np.asarray(sc.target)

    # Group edges by source id: after a stable sort every node's edges are contiguous.
    order = np.argsort(src, kind='stable')
    src_sorted = src[order]
    tgt_sorted = tgt[order]

    nodes = np.asarray([int(s) for s in source.segments[0].indices], dtype=np.int64)
    starts = np.searchsorted(src_sorted, nodes, side='left')
    stops = np.searchsorted(src_sorted, nodes, side='right')
    # np.unique sorts and de-duplicates each run (multapses collapse to a single index).
    return [np.unique(tgt_sorted[lo:hi]) for lo, hi in zip(starts, stops)]
def target_positions(sim, source, target):
    r"""Coordinates of the realized targets of each source node (NEST ``GetTargetPositions``).

    Parameters
    ----------
    sim : Simulator
        The simulator holding the realized connections and target positions.
    source : NodeView
        A single-segment source view; one entry is returned per node.
    target : NodeView
        The candidate-target population view (created with ``positions=``).

    Returns
    -------
    list of Quantity
        One ``(k_i, ndim)`` coordinate array per source node, in ``source`` order.

    See Also
    --------
    target_nodes : the underlying realized-target index query.

    Examples
    --------
    .. code-block:: python

        >>> from brainpy import state as bp
        >>> import brainunit as u
        >>> sim = bp.Simulator(dt=0.1 * u.ms)
        >>> pop = sim.create(bp.iaf_psc_alpha, positions=bp.spatial.grid([3, 1], extent=[3.0, 1.0]))
        >>> _ = sim.connect(pop, pop,
        ...     rule=bp.spatial.spatial_pairwise_bernoulli(p=1.0, mask=bp.spatial.circular(1.2)),
        ...     weight=1.0 * u.pA, delay=1.0 * u.ms)
        >>> [tuple(p.shape) for p in bp.spatial.target_positions(sim, pop, pop)]
        [(2, 2), (3, 2), (2, 2)]
    """
    # Whole-population coordinate table, looked up by the population object's id() in the
    # simulator's private position store; it is indexed by population-local node index.
    population = target.segments[0].population
    all_coords = sim._positions[id(population)]
    per_source = target_nodes(sim, source, target)
    return [all_coords[hits] for hits in per_source]