Source code for brainpy_state._nest_spatial.plot

# Copyright 2026 BrainX Ecosystem Limited. Apache 2.0.
"""Spatial plotting helpers (NEST ``PlotLayer`` / ``PlotTargets`` / ``PlotSources`` /
``PlotProbabilityParameter``).

matplotlib is an *optional* dependency: it is imported lazily inside each function (never at
module import) so that the rest of ``brainpy.state.spatial`` works without it. Each helper returns
the :class:`matplotlib.figure.Figure` it drew on, so callers can further annotate or save it.
"""
from __future__ import annotations

import numpy as np
import brainunit as u

from brainpy_state._nest_spatial.distance import pairwise_distance
from brainpy_state._nest_spatial.helpers import target_positions
from brainpy_state._nest_spatial.layers import _LEN, _as_len

# Names exported by ``from ... import *``; the ``_``-prefixed helpers in this module are internal.
__all__ = ['plot_layer', 'plot_targets', 'plot_sources', 'plot_probability_parameter']


def _import_mpl():
    """Import :mod:`matplotlib.pyplot`, raising a clear error if matplotlib is absent."""
    try:
        import matplotlib.pyplot as plt
    except ImportError as exc:                  # pragma: no cover - exercised via monkeypatch
        raise ImportError(
            'spatial plotting requires matplotlib; install it with `pip install matplotlib`.'
        ) from exc
    return plt


def _axes(plt, fig, ndim):
    """Return ``(fig, ax)``: reuse ``fig``'s first axis, else add one with the right projection."""
    if fig is None:
        fig = plt.figure()
    if fig.axes:
        ax = fig.axes[0]
    else:
        ax = fig.add_subplot(111, projection='3d' if ndim == 3 else None)
    return fig, ax


def _scatter(ax, coords, ndim, color, size):
    """Scatter an ``(n, ndim)`` magnitude array on a 2-D or 3-D axis."""
    ax.scatter(*(coords[:, i] for i in range(ndim)), c=color, s=size)


def _mag(coords):
    """Convert a coordinate Quantity to the spatial length unit ``_LEN`` and drop the unit.

    Returns a plain NumPy array of magnitudes (micrometres, per this module's docs).
    """
    in_length_unit = coords.to(_LEN)
    return np.asarray(u.get_magnitude(in_length_unit))


def plot_layer(layer, fig=None, nodecolor='b', nodesize=20):
    """Scatter a layer's node positions (NEST ``PlotLayer``).

    Parameters
    ----------
    layer : Layer
        A concrete 2-D / 3-D position layer; its ``coords`` and ``ndim`` are read.
    fig : matplotlib.figure.Figure, optional
        Existing figure to draw on (its first axis is reused); a new one is created when ``None``.
    nodecolor : color, optional
        Marker colour. Default ``'b'``.
    nodesize : float, optional
        Marker size. Default ``20``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure drawn on.

    Raises
    ------
    ImportError
        If matplotlib is not installed.

    Examples
    --------
    .. code-block:: python

        >>> from brainpy import state as bp
        >>> fig = bp.spatial.plot_layer(bp.spatial.grid([10, 10]))  # doctest: +SKIP
    """
    pyplot = _import_mpl()
    positions = _mag(layer.coords)
    dims = layer.ndim
    fig, ax = _axes(pyplot, fig, dims)
    _scatter(ax, positions, dims, nodecolor, nodesize)
    return fig
def plot_targets(sim, src_node, target, fig=None, src_color='red', src_size=50, tgt_color='b',
                 tgt_size=20):
    """Highlight one source node's realized targets (NEST ``PlotTargets``).

    Parameters
    ----------
    sim : Simulator
        The simulator holding the realized connections.
    src_node : NodeView
        The source node whose targets are drawn. If it spans several nodes, only the first one
        is used, both for the targets and for the highlighted source marker.
    target : NodeView
        The candidate-target population (created with ``positions=``).
    fig : matplotlib.figure.Figure, optional
        Existing figure to draw on (its first axis is reused); a new one is created when ``None``.
    src_color, tgt_color : color, optional
        Marker colours for the source node and its targets.
    src_size, tgt_size : float, optional
        Marker sizes for the source node and its targets.

    Returns
    -------
    matplotlib.figure.Figure
        The figure drawn on.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    """
    plt = _import_mpl()
    # Entry ``[0]`` is taken as the first source node's targets (assumes ``target_positions``
    # returns one coordinate array per node of ``src_node`` -- TODO confirm).
    tgt_coords = _mag(target_positions(sim, src_node, target)[0])
    # Keep only the first source row so the highlighted marker belongs to the same node whose
    # targets are drawn; ``[:1]`` preserves the 2-D ``(n, ndim)`` shape used below.
    src_coords = _mag(sim.get_position(src_node))[:1]
    ndim = src_coords.shape[1]
    fig, ax = _axes(plt, fig, ndim)
    if tgt_coords.size:  # the node may have no realized targets
        _scatter(ax, tgt_coords, ndim, tgt_color, tgt_size)
    # Drawn after the targets so the source marker is layered above them.
    _scatter(ax, src_coords, ndim, src_color, src_size)
    return fig
def plot_sources(sim, source, tgt_node, fig=None, src_color='b', src_size=20, tgt_color='red',
                 tgt_size=50):
    """Highlight one target node's realized sources (NEST ``PlotSources``).

    Parameters
    ----------
    sim : Simulator
        The simulator holding the realized connections.
    source : NodeView
        The candidate-source population (created with ``positions=``).
    tgt_node : NodeView
        The target node whose sources are drawn.
    fig : matplotlib.figure.Figure, optional
        Existing figure to draw on (its first axis is reused); a new one is created when ``None``.
    src_color, tgt_color : color, optional
        Marker colours for the sources and the target node.
    src_size, tgt_size : float, optional
        Marker sizes for the sources and the target node.

    Returns
    -------
    matplotlib.figure.Figure
        The figure drawn on. When no connection reaches ``tgt_node`` only the target is drawn.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    """
    plt = _import_mpl()
    conns = sim.get_connections(source=source, target=tgt_node)
    # Force an integer dtype: with no realized connections an empty ``conns.source`` would
    # otherwise become a float64 array, which NumPy rejects as an index (IndexError).
    # NOTE(review): these values are used as row indices into ``source``'s positions -- confirm
    # that ``get_connections`` reports population-local indices rather than global node ids.
    src_idx = np.unique(np.asarray(conns.source, dtype=np.intp))
    src_coords = _mag(sim.get_position(source))[src_idx]
    tgt_coords = _mag(sim.get_position(tgt_node))
    ndim = tgt_coords.shape[1]
    fig, ax = _axes(plt, fig, ndim)
    if src_coords.size:  # skip the scatter when nothing connects to the target
        _scatter(ax, src_coords, ndim, src_color, src_size)
    # Drawn after the sources so the target marker is layered above them.
    _scatter(ax, tgt_coords, ndim, tgt_color, tgt_size)
    return fig
def plot_probability_parameter(kernel, mask=None, extent=(-0.5, 0.5, -0.5, 0.5), shape=(100, 100),
                               fig=None, cmap='Greys'):
    """Heatmap of a connection kernel ``p(d)`` over a 2-D field (NEST ``PlotProbabilityParameter``).

    The kernel is evaluated for a single source at the origin against a regular grid of target
    positions spanning ``extent``. When a ``mask`` is given the probability is zeroed outside it.

    Parameters
    ----------
    kernel : object
        A spatial kernel/expression exposing ``_eval_pair(pre, post)``, or a bare callable
        ``p(distance)`` that receives ``pairwise_distance(pre, post)``.
    mask : object, optional
        A spatial mask whose ``contains(pre, post)`` zeroes the probability outside it.
    extent : tuple of float, optional
        ``(x_min, x_max, y_min, y_max)`` of the sampled field (micrometres). Default the unit box.
    shape : tuple of int, optional
        ``(nx, ny)`` sample counts along x and y. Default ``(100, 100)``.
    fig : matplotlib.figure.Figure, optional
        Existing figure to draw on (its first axis is reused); a new one is created when ``None``.
    cmap : str, optional
        Colormap name. Default ``'Greys'``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure drawn on.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    """
    pyplot = _import_mpl()
    x_lo, x_hi, y_lo, y_hi = extent
    n_x, n_y = shape[0], shape[1]
    # 'xy' indexing gives (n_y, n_x) grids, i.e. rows follow y as ``imshow`` expects.
    gx, gy = np.meshgrid(np.linspace(x_lo, x_hi, n_x), np.linspace(y_lo, y_hi, n_y),
                         indexing='xy')
    field_shape = gx.shape
    origin = _as_len(np.array([[0.0, 0.0]]))  # the single source, placed at (0, 0)
    samples = _as_len(np.stack([gx.ravel(), gy.ravel()], axis=-1))

    if hasattr(kernel, '_eval_pair'):
        raw = kernel._eval_pair(origin, samples)
    else:
        raw = kernel(pairwise_distance(origin, samples))
    prob = np.asarray(u.get_magnitude(raw)).reshape(field_shape)

    if mask is not None:
        inside = np.asarray(mask.contains(origin, samples)).reshape(field_shape)
        prob = np.where(inside, prob, 0.0)

    fig, ax = _axes(pyplot, fig, 2)
    ax.imshow(prob, origin='lower', extent=extent, cmap=cmap, aspect='auto')
    return fig