# Copyright 2026 BrainX Ecosystem Limited. Apache 2.0.
"""Spatial connectivity masks.
A mask is a hard candidate cutoff anchored on the *source* node: ``contains`` returns
a boolean ``(n_pre, n_post)`` matrix selecting which target nodes are eligible. The
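distance kernel ``p(d)`` then applies only within the mask. Mirrors NEST's
``{"circular": {"radius": r}}`` / ``{"spherical": {"radius": r}}`` /
``{"box": {"lower_left": ll, "upper_right": ur}}`` masks, plus the ``rectangular``
(optionally rotated), ``doughnut``, ``elliptical`` and ``ellipsoidal`` variants.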
"""
from __future__ import annotations
import math
import brainunit as u
from brainpy_state._nest_spatial.distance import displacement, pairwise_distance
from brainpy_state._nest_spatial.layers import _LEN, _as_len
# Public API of this module: the mask factory functions defined below. The mask classes they
# return are private (underscore-prefixed) and are not exported.
__all__ = ['circular', 'spherical', 'box', 'rectangular', 'doughnut',
           'elliptical', 'ellipsoidal']
class _RadialMask:
    """Hard distance cutoff anchored on the source node.

    A target is eligible when its distance to the source is at most ``radius`` (the boundary
    itself is included). Backs both NEST ``circular`` (2-D) and ``spherical`` (3-D).
    """
    __module__ = 'brainpy.state'

    def __init__(self, radius):
        # Normalised through ``_as_len`` so comparisons happen between length quantities.
        self.radius = _as_len(radius)

    def contains(self, pre_pos, post_pos):
        """Return the boolean ``(n_pre, n_post)`` eligibility matrix.

        Entry ``[i, j]`` is ``True`` iff target ``j`` lies within ``radius`` of source ``i``
        (inclusive).
        """
        dist = pairwise_distance(pre_pos, post_pos)
        within = dist <= self.radius
        return within
class _BoxMask:
    """Axis-aligned box on the source-anchored displacement ``post - pre`` (NEST ``box``).

    A target is eligible iff every component of its displacement from the source lies in the
    closed interval ``[lower_left[k], upper_right[k]]``.
    """
    __module__ = 'brainpy.state'

    def __init__(self, lower_left, upper_right):
        self.lower_left = _as_len(lower_left)
        self.upper_right = _as_len(upper_right)

    def contains(self, pre_pos, post_pos):
        """Return the boolean ``(n_pre, n_post)`` matrix of targets inside the box.

        Both corners are inclusive; the test is applied component-wise on the last axis of the
        displacement.
        """
        offsets = displacement(pre_pos, post_pos)  # (n_pre, n_post, d)
        above_lower = offsets >= self.lower_left
        below_upper = offsets <= self.upper_right
        # Reduce over the spatial axis: every coordinate must satisfy both bounds.
        return u.math.all(above_lower, axis=-1) & u.math.all(below_upper, axis=-1)
def _rotate_disp_2d(disp, lower_left, upper_right, az_deg):
    """Rotate 2-D displacements by ``R(-azimuth)`` about the box centre (NEST ``BoxMask<2>``).

    Parameters
    ----------
    disp : Quantity
        Displacements whose last axis holds ``(x, y)``.
    lower_left, upper_right : Quantity
        Box corners; only their midpoint (the rotation centre) is used here.
    az_deg : float
        Azimuth of the box in degrees.

    Returns
    -------
    Quantity
        Displacements expressed in the unrotated box frame, last axis ``(x, y)``.
    """
    center = (lower_left + upper_right) / 2.0
    theta = math.radians(az_deg)
    c, s = math.cos(theta), math.sin(theta)
    # Shift to the box centre, apply R(-theta) = [[c, s], [-s, c]], then shift back.
    x0 = disp[..., 0] - center[0]
    y0 = disp[..., 1] - center[1]
    rotated = [
        x0 * c + y0 * s + center[0],
        -x0 * s + y0 * c + center[1],
    ]
    return u.math.stack(rotated, axis=-1)
class _DoughnutMask:
    """Annulus cutoff ``inner < d <= outer`` (NEST ``doughnut``).

    Equivalent to NEST's outer ball minus inner ball: the outer edge is inclusive, the inner
    edge exclusive.
    """
    __module__ = 'brainpy.state'

    def __init__(self, inner_radius, outer_radius):
        self.inner = _as_len(inner_radius)
        self.outer = _as_len(outer_radius)

    def contains(self, pre_pos, post_pos):
        """Return the boolean ``(n_pre, n_post)`` matrix of targets inside the ring.

        A target qualifies when its distance ``d`` to the source satisfies
        ``inner < d <= outer``.
        """
        dist = pairwise_distance(pre_pos, post_pos)
        outside_hole = dist > self.inner
        inside_ring = dist <= self.outer
        return outside_hole & inside_ring
class _RectangularMask:
    """Axis-aligned (optionally rotated) box on the displacement ``post - pre`` (NEST ``rectangular``).

    Without rotation a target is eligible iff ``lower_left <= post - pre <= upper_right``
    component-wise (both corners inclusive). With a non-zero ``azimuth_angle`` (degrees), the
    displacement is first rotated by ``R(-azimuth)`` about the box centre. It is then tested
    against bounds widened by a tiny absolute slack, as NEST's ``BoxMask<2>::inside`` does, so
    floating-point round-off from the rotation cannot drop targets lying exactly on an edge.
    """
    __module__ = 'brainpy.state'

    # Boundary slack for the rotated case, in ``_LEN`` units (NEST ``BoxMask`` uses
    # ``eps_ = 1e-12``).
    _ROTATION_EPS = 1e-12

    def __init__(self, lower_left, upper_right, azimuth_angle=0.0):
        self.lower_left = _as_len(lower_left)
        self.upper_right = _as_len(upper_right)
        self.azimuth_angle = float(azimuth_angle)

    def contains(self, pre_pos, post_pos):
        """Boolean ``(n_pre, n_post)``: displacement within ``[lower_left, upper_right]`` (rotated).

        Parameters
        ----------
        pre_pos, post_pos
            Source and target positions, forwarded to ``displacement``.

        Returns
        -------
        Boolean array of shape ``(n_pre, n_post)``.
        """
        disp = displacement(pre_pos, post_pos)  # (n_pre, n_post, 2)
        lower, upper = self.lower_left, self.upper_right
        if self.azimuth_angle != 0.0:
            # The rotation centre is computed from the *unwidened* corners.
            disp = _rotate_disp_2d(disp, lower, upper, self.azimuth_angle)
            # Widen only in the rotated case, mirroring NEST: the unrotated test does no
            # arithmetic on the displacement and is therefore exact.
            slack = self._ROTATION_EPS * _LEN
            lower, upper = lower - slack, upper + slack
        ge = u.math.all(disp >= lower, axis=-1)
        le = u.math.all(disp <= upper, axis=-1)
        return ge & le
class _EllipseMask:
    """Rotated ellipse (2-D) / ellipsoid (3-D) on the displacement (NEST ``elliptical`` / ``ellipsoidal``).

    A target is included iff the source-anchored displacement, shifted by ``anchor`` and rotated
    into the ellipse frame, satisfies ``sum_k (n_k / semi_k)**2 <= 1`` where ``semi = axis / 2``.
    Mirrors NEST's ``EllipseMask<2>``/``EllipseMask<3>`` (``mask.cpp``).

    - The ``major``/``minor``/``polar`` arguments are the *full* axis lengths.
    - The angles are in degrees.
    - The major axis points along ``azimuth_angle``, measured counter-clockwise from +x.
    - Each rotated coordinate is squared, so the overall sign of an individual coordinate has
      no effect. For example, NEST's minor-axis coordinate is the negation of the
      conventional rotated ``y``.
    - The rotation direction itself *does* matter.

    Raises
    ------
    ValueError
        If any supplied axis length is not strictly positive. NEST likewise requires
        ``axis > 0``.
    """
    __module__ = 'brainpy.state'

    def __init__(self, major_axis, minor_axis, polar_axis=None,
                 azimuth_angle=0.0, polar_angle=0.0, anchor=None):
        self.major = _as_len(major_axis)
        self.minor = _as_len(minor_axis)
        self.polar = _as_len(polar_axis) if polar_axis is not None else None
        self.azimuth_angle = float(azimuth_angle)
        self.polar_angle = float(polar_angle)
        self.anchor = _as_len(anchor) if anchor is not None else None
        self.ndim = 2 if polar_axis is None else 3
        # Reject bad axes up front. A zero axis would otherwise surface later as a
        # ZeroDivisionError inside ``contains``. A negative one would silently behave like its
        # absolute value, because the axis is squared. ``not x > 0`` also rejects NaN.
        for name, axis in (('major_axis', self.major), ('minor_axis', self.minor),
                           ('polar_axis', self.polar)):
            if axis is not None and not self._to_um(axis) > 0.0:
                raise ValueError(f'{name} must be > 0, got {axis}')

    @staticmethod
    def _to_um(q):
        """Bare float magnitude of a length quantity expressed in ``_LEN`` units."""
        return float(u.get_magnitude(q.to(_LEN)))

    def contains(self, pre_pos, post_pos):
        """Boolean ``(n_pre, n_post)``: displacement inside the (rotated) ellipse / ellipsoid."""
        disp = displacement(pre_pos, post_pos)  # (n_pre, n_post, d) Quantity
        d = u.get_magnitude(disp.to(_LEN))  # bare micrometres
        if self.anchor is not None:
            d = d - u.get_magnitude(self.anchor.to(_LEN))
        dx, dy = d[..., 0], d[..., 1]
        az = math.radians(self.azimuth_angle)
        ac, asn = math.cos(az), math.sin(az)
        # (n/semi)**2 = n**2 * 4/axis**2
        xs = 4.0 / self._to_um(self.major) ** 2
        ys = 4.0 / self._to_um(self.minor) ** 2
        # Coordinate along the azimuthally rotated major axis (before any polar tilt), and
        # along the minor axis (NEST sign convention; irrelevant once squared).
        base = dx * ac + dy * asn
        ny = dx * asn - dy * ac
        if self.ndim == 2:
            return (base ** 2 * xs + ny ** 2 * ys) <= 1.0
        # 3-D: tilt the (major, z) plane by the polar angle.
        dz = d[..., 2]
        pol = math.radians(self.polar_angle)
        pc, ps = math.cos(pol), math.sin(pol)
        zs = 4.0 / self._to_um(self.polar) ** 2
        nx = base * pc - dz * ps
        nz = base * ps + dz * pc
        return (nx ** 2 * xs + ny ** 2 * ys + nz ** 2 * zs) <= 1.0
[docs]
def circular(radius) -> _RadialMask:
    """Circular mask (2-D): target within ``radius`` of source.

    Parameters
    ----------
    radius : float or Quantity
        Inclusive cutoff distance (length); bare floats are taken in micrometres.

    Returns
    -------
    _RadialMask
        A hard-cutoff distance mask.

    Examples
    --------
    .. code-block:: python

        >>> from brainpy import state as bp
        >>> mask = bp.spatial.circular(0.5)
    """
    mask = _RadialMask(radius)
    return mask
[docs]
def spherical(radius) -> _RadialMask:
    """Spherical mask (3-D): target within ``radius`` of source (same cutoff as circular).

    Parameters
    ----------
    radius : float or Quantity
        Inclusive cutoff distance (length); bare floats are taken in micrometres.

    Returns
    -------
    _RadialMask
        The same distance-cutoff mask type that :func:`circular` returns.
    """
    mask = _RadialMask(radius)
    return mask
[docs]
def box(lower_left, upper_right) -> _BoxMask:
    """Box mask (2-D/3-D): target displacement within ``[lower_left, upper_right]``.

    Parameters
    ----------
    lower_left, upper_right : sequence of float or Quantity
        Opposite corners of the axis-aligned box on the source-anchored displacement
        ``post - pre``. Both corners are inclusive.

    Returns
    -------
    _BoxMask
        A hard-cutoff box mask.

    Examples
    --------
    .. code-block:: python

        >>> from brainpy import state as bp
        >>> mask = bp.spatial.box([-0.75, -0.75, -0.75], [0.75, 0.75, 0.75])
    """
    mask = _BoxMask(lower_left, upper_right)
    return mask
[docs]
def rectangular(lower_left, upper_right, azimuth_angle=0.0) -> _RectangularMask:
    """Rectangular mask (2-D): target displacement within ``[lower_left, upper_right]``.

    Parameters
    ----------
    lower_left, upper_right : sequence of float or Quantity
        Corners of the rectangle on the source-anchored displacement, before any rotation.
    azimuth_angle : float, optional
        Rotation of the rectangle about its own centre, in degrees (NEST parity). Default ``0``.

    Returns
    -------
    _RectangularMask
        A hard-cutoff mask; the 2-D, rotatable counterpart of :func:`box`.

    Examples
    --------
    .. code-block:: python

        >>> from brainpy import state as bp
        >>> mask = bp.spatial.rectangular([-0.5, -0.5], [0.5, 0.5], azimuth_angle=30.0)
    """
    return _RectangularMask(
        lower_left=lower_left,
        upper_right=upper_right,
        azimuth_angle=azimuth_angle,
    )
[docs]
def doughnut(inner_radius, outer_radius) -> _DoughnutMask:
    """Doughnut (annulus) mask (2-D): ``inner_radius < d <= outer_radius``.

    Boundaries follow NEST's outer-ball-minus-inner-ball ``DifferenceMask``: the inner one is
    exclusive and the outer one inclusive. Consequences:

    - ``inner_radius == outer_radius`` gives an empty mask.
    - ``inner_radius == 0`` behaves like :func:`circular` except at the exact centre.

    Parameters
    ----------
    inner_radius, outer_radius : float or Quantity
        Radii of the hole and of the outer edge (length); bare floats are taken in micrometres.

    Returns
    -------
    _DoughnutMask
        A hard-cutoff annulus mask.

    Examples
    --------
    .. code-block:: python

        >>> from brainpy import state as bp
        >>> mask = bp.spatial.doughnut(0.3, 0.7)
    """
    return _DoughnutMask(inner_radius=inner_radius, outer_radius=outer_radius)
[docs]
def elliptical(major_axis, minor_axis, azimuth_angle=0.0, anchor=None) -> _EllipseMask:
    """Elliptical mask (2-D): displacement inside a rotated ellipse (NEST ``elliptical``).

    Parameters
    ----------
    major_axis, minor_axis : float or Quantity
        Full lengths of the two principal axes, not semi-axes. With ``major_axis == minor_axis``
        the mask reduces to :func:`circular` of radius ``major_axis / 2``.
    azimuth_angle : float, optional
        Rotation of the ellipse about its anchor, in degrees. Default ``0``.
    anchor : sequence of float or Quantity, optional
        Offset of the ellipse centre from the source node, applied to the displacement.
        Defaults to the origin.

    Returns
    -------
    _EllipseMask
        A hard-cutoff elliptical mask.

    Examples
    --------
    .. code-block:: python

        >>> from brainpy import state as bp
        >>> mask = bp.spatial.elliptical(4.0, 2.0, azimuth_angle=45.0)
    """
    # No polar axis -> the mask operates in 2-D.
    return _EllipseMask(
        major_axis,
        minor_axis,
        polar_axis=None,
        azimuth_angle=azimuth_angle,
        anchor=anchor,
    )
[docs]
def ellipsoidal(major_axis, minor_axis, polar_axis, azimuth_angle=0.0,
                polar_angle=0.0, anchor=None) -> _EllipseMask:
    """Ellipsoidal mask (3-D): displacement inside a rotated ellipsoid (NEST ``ellipsoidal``).

    Parameters
    ----------
    major_axis, minor_axis, polar_axis : float or Quantity
        Full lengths of the three principal axes, not semi-axes. With all three equal the mask
        reduces to :func:`spherical` of radius ``major_axis / 2``.
    azimuth_angle : float, optional
        Rotation about the polar (z) axis, in degrees. Default ``0``.
    polar_angle : float, optional
        Tilt of the polar axis away from z, in degrees. It is applied after the azimuthal
        rotation (NEST ``R_y(-polar) . R_z(-azimuth)``). Default ``0``.
    anchor : sequence of float or Quantity, optional
        Offset of the ellipsoid centre from the source node. Defaults to the origin.

    Returns
    -------
    _EllipseMask
        A hard-cutoff ellipsoidal mask.

    Examples
    --------
    .. code-block:: python

        >>> from brainpy import state as bp
        >>> mask = bp.spatial.ellipsoidal(4.0, 2.0, 1.0, azimuth_angle=30.0, polar_angle=15.0)
    """
    return _EllipseMask(
        major_axis,
        minor_axis,
        polar_axis=polar_axis,
        azimuth_angle=azimuth_angle,
        polar_angle=polar_angle,
        anchor=anchor,
    )