# Copyright 2026 BrainX Ecosystem Limited. Apache 2.0.
"""Distance-dependent connection kernels (mirrors ``nest.spatial_distributions``).
A kernel is a callable ``p(d) -> probability`` over a (pairwise) distance. The
``distance`` sentinel mirrors NEST's ``nest.spatial.distance`` so kernels read
``gaussian(distance, std=...)``.
Kernels consume an *expression* -- by default the scalar :data:`distance`, but also the
per-axis ``distance.x/.y/.z`` or the ``source_pos``/``target_pos`` accessors. Every
expression evaluates, given the rule's bound sliced positions ``pre_pos (n_pre, d)`` /
``post_pos (n_post, d)``, to an ``(n_pre, n_post)`` grid; every kernel exposes
``_eval_pair(pre_pos, post_pos)`` returning the probability grid, which
:class:`~brainpy_state._nest_spatial.rule.SpatialConnRule` samples (zero seam change).
"""
from __future__ import annotations
import math
import jax.numpy as jnp
from jax.scipy.special import gammaln
import brainunit as u
from brainpy_state._nest_spatial.distance import displacement, pairwise_distance
from brainpy_state._nest_spatial.layers import _LEN, _as_len
# Public API of this module: the spatial expressions (``distance`` and the
# ``pos`` / ``source_pos`` / ``target_pos`` accessors) and the kernel factories.
__all__ = ['distance', 'pos', 'source_pos', 'target_pos',
           'gaussian', 'exponential', 'gamma', 'gabor', 'gaussian2D']
# Axis labels indexed by axis number (0 -> 'x', 1 -> 'y', 2 -> 'z'); used to spell
# expression names such as ``distance.x`` in reprs and error messages.
_AXES = ('x', 'y', 'z')
# ---------------------------------------------------------------------------
# Expression family (axis / scalar values over the (n_pre, n_post) pair grid)
# ---------------------------------------------------------------------------
class _Expr:
    """Abstract spatial expression over bound (source, target) positions.

    Concrete subclasses implement :meth:`_eval_pair`, producing an ``(n_pre, n_post)``
    grid from the rule's sliced positions. Expressions that also have a per-node value
    override :meth:`_eval_nodes`; by default that is rejected.
    """
    __module__ = 'brainpy.state'

    def _eval_pair(self, pre_pos, post_pos):
        """Evaluate over all (source, target) pairs; subclasses must override this."""
        raise NotImplementedError()

    def _eval_nodes(self, coords):
        """Reject per-node evaluation: a generic expression exists only over node pairs."""
        message = ('this expression has no single-node value; it is only defined '
                   'over connected (source, target) pairs')
        raise ValueError(message)
class _AxisDistance(_Expr):
    """Absolute distance along a single axis, ``|target_a - source_a|`` (NEST ``distance.x/.y/.z``)."""

    def __init__(self, axis):
        self.axis = int(axis)

    def _eval_pair(self, pre_pos, post_pos):
        """Return the ``(n_pre, n_post)`` grid of absolute displacements along ``self.axis``.

        Raises ``ValueError`` when the layer has fewer dimensions than the axis needs.
        """
        n_dims = pre_pos.shape[-1]
        if self.axis < n_dims:
            # displacement() stacks per-axis offsets as (n_pre, n_post, d); keep one component.
            return u.math.abs(displacement(pre_pos, post_pos)[..., self.axis])
        label = _AXES[self.axis]
        raise ValueError(f'distance.{label} needs a {self.axis + 1}-D layer, got {n_dims}-D')

    def __repr__(self):
        return 'spatial.distance.' + _AXES[self.axis]
class _DistanceSentinel(_Expr):
    """Euclidean source-target distance, the default kernel input (NEST ``spatial.distance``).

    Reads as ``gaussian(distance, std=...)``; the ``.x`` / ``.y`` / ``.z`` attributes
    give the absolute per-axis distances instead.
    """
    __module__ = 'brainpy.state'

    def __repr__(self):
        return 'spatial.distance'

    def _eval_pair(self, pre_pos, post_pos):
        """Return the ``(n_pre, n_post)`` grid of pairwise Euclidean distances."""
        grid = pairwise_distance(pre_pos, post_pos)
        return grid

    x = property(lambda self: _AxisDistance(0),
                 doc='Per-axis absolute distance on the x-axis (NEST ``distance.x``).')
    y = property(lambda self: _AxisDistance(1),
                 doc='Per-axis absolute distance on the y-axis (NEST ``distance.y``).')
    z = property(lambda self: _AxisDistance(2),
                 doc='Per-axis absolute distance on the z-axis (NEST ``distance.z``).')


#: Singleton representing the pairwise Euclidean distance between two nodes.
distance = _DistanceSentinel()
class _SourcePos(_Expr):
    """Coordinate of the source node along one axis, repeated across targets (NEST ``source_pos.x/.y/.z``)."""

    def __init__(self, axis):
        self.axis = int(axis)

    def _eval_pair(self, pre_pos, post_pos):
        """Return an ``(n_pre, n_post)`` grid whose row ``i`` holds source ``i``'s coordinate."""
        ax = self.axis
        if ax >= pre_pos.shape[-1]:
            raise ValueError(f'source_pos.{_AXES[ax]} needs a {ax + 1}-D layer')
        grid_shape = (pre_pos.shape[0], post_pos.shape[0])
        column = pre_pos[:, ax][:, None]  # (n_pre, 1), stretched along the target axis
        return u.math.broadcast_to(column, grid_shape)

    def __repr__(self):
        return 'spatial.source_pos.' + _AXES[self.axis]
class _TargetPos(_Expr):
    """Coordinate of the target node along one axis, repeated across sources (NEST ``target_pos.x/.y/.z``)."""

    def __init__(self, axis):
        self.axis = int(axis)

    def _eval_pair(self, pre_pos, post_pos):
        """Return an ``(n_pre, n_post)`` grid whose column ``j`` holds target ``j``'s coordinate."""
        ax = self.axis
        if ax >= post_pos.shape[-1]:
            raise ValueError(f'target_pos.{_AXES[ax]} needs a {ax + 1}-D layer')
        grid_shape = (pre_pos.shape[0], post_pos.shape[0])
        row = post_pos[:, ax][None, :]  # (1, n_post), stretched along the source axis
        return u.math.broadcast_to(row, grid_shape)

    def __repr__(self):
        return 'spatial.target_pos.' + _AXES[self.axis]
class _NodePos(_Expr):
    """Position of a single node along one axis (NEST ``pos.x/.y/.z``).

    Only meaningful per node through :meth:`_eval_nodes`; using it as a connection
    kernel input raises, pointing the user at the two-node alternatives.
    """

    def __init__(self, axis):
        self.axis = int(axis)

    def _eval_pair(self, pre_pos, post_pos):
        """Always raise: a single-node position has no value over (source, target) pairs."""
        a = _AXES[self.axis]
        raise ValueError(
            f'pos.{a} is a single-node position parameter and cannot be used when connecting; '
            f'use source_pos.{a} / target_pos.{a} (two-node) or distance.{a}')

    def _eval_nodes(self, coords):
        """Return the ``self.axis`` column of the per-node coordinate array."""
        ax = self.axis
        if ax < coords.shape[-1]:
            return coords[:, ax]
        raise ValueError(f'pos.{_AXES[ax]} needs a {ax + 1}-D layer')

    def __repr__(self):
        return 'spatial.pos.%s' % (_AXES[self.axis],)
class _AxisHolder:
    """Namespace object whose ``.x`` / ``.y`` / ``.z`` build per-axis expressions.

    ``factory`` is called with the axis index (0, 1 or 2) on every attribute access;
    ``name`` is used only for the repr (e.g. NEST's ``pos`` / ``source_pos``).
    """
    __module__ = 'brainpy.state'

    def __init__(self, factory, name):
        self._factory, self._name = factory, name

    def __repr__(self):
        return 'spatial.{}'.format(self._name)

    x = property(lambda self: self._factory(0), doc='Per-axis expression on the x-axis.')
    y = property(lambda self: self._factory(1), doc='Per-axis expression on the y-axis.')
    z = property(lambda self: self._factory(2), doc='Per-axis expression on the z-axis.')
# ``pos.x`` etc. build single-node expressions (rejected when connecting);
# ``source_pos.x`` / ``target_pos.x`` build two-node expressions broadcast over the pair grid.
#: Per-node position accessors (NEST ``nest.spatial.pos`` / ``source_pos`` / ``target_pos``).
pos = _AxisHolder(_NodePos, 'pos')
source_pos = _AxisHolder(_SourcePos, 'source_pos')
target_pos = _AxisHolder(_TargetPos, 'target_pos')
def _as_input(x):
    """Return ``x`` unchanged if it can be evaluated over a pair grid (has ``_eval_pair``).

    Raises
    ------
    ValueError
        If ``x`` is not a spatial expression.
    """
    if hasattr(x, '_eval_pair'):
        return x
    raise ValueError(
        'kernel input must be a spatial expression (spatial.distance, distance.x/.y/.z, '
        'source_pos/target_pos.x/.y/.z)')
# ---------------------------------------------------------------------------
# Kernels
# ---------------------------------------------------------------------------
class _GaussianKernel:
    r"""Gaussian distance kernel ``p(d) = \exp(-(d-\mu)^2 / (2\,\mathrm{std}^2))`` (peak 1 at ``d=\mu``).

    Parameters
    ----------
    std : float or Quantity
        Standard deviation (length), converted with ``_as_len``. Must be ``> 0``.
    mean : float or Quantity, optional
        Mean (length), converted with ``_as_len``. Default ``0``.
    x : spatial expression, optional
        Input evaluated over the pair grid. Default :data:`distance`.

    Raises
    ------
    ValueError
        If ``std`` is not strictly positive, or ``x`` is not a spatial expression.
    """
    __module__ = 'brainpy.state'

    def __init__(self, std, mean=0.0, x=distance):
        self.std = _as_len(std)
        self.mean = _as_len(mean)
        # A zero std would make every probability 0 (or NaN where d == mean, 0/0) at
        # connect time without any error; reject it up front.
        std_mag = jnp.asarray(u.get_magnitude(self.std.to(_LEN)))
        if not bool(jnp.all(std_mag > 0)):
            raise ValueError(f'gaussian kernel requires std > 0, got std={std!r}')
        self._input = _as_input(x)

    def __call__(self, d):
        """Return ``exp(-((d - mean) / std)^2 / 2)`` for a length-valued ``d``."""
        r = (d - self.mean) / self.std  # dimensionless ratio
        return u.math.exp(-(r ** 2) / 2.0)

    def _eval_pair(self, pre_pos, post_pos):
        """Evaluate the input expression on the pair grid and map it through the kernel."""
        return self(self._input._eval_pair(pre_pos, post_pos))


def gaussian(x=distance, mean=0.0, std=1.0) -> _GaussianKernel:
    r"""Gaussian distance-dependent connection probability.

    Returns a callable ``p(d) = exp(-(d-mean)^2 / (2 std^2))`` matching NEST's
    ``nest.spatial_distributions.gaussian(distance, mean, std)``.

    Parameters
    ----------
    x : object, optional
        The :data:`distance` sentinel (or a per-axis expression such as ``distance.x``).
    mean : float or Quantity, optional
        Distribution mean (length); bare floats are taken in micrometres. Default ``0``.
    std : float or Quantity, optional
        Standard deviation (length); bare floats are taken in micrometres. Must be
        ``> 0``. Default ``1``.

    Returns
    -------
    callable
        ``p(d)`` mapping a distance (Quantity) to a connection probability.

    Raises
    ------
    ValueError
        If ``std <= 0`` or ``x`` is not a spatial expression.

    Examples
    --------
    .. code-block:: python

        >>> from brainpy import state as bp
        >>> import brainunit as u
        >>> p = bp.spatial.gaussian(bp.spatial.distance, std=0.5)
        >>> float(u.get_magnitude(p(0.0 * u.um)))
        1.0
    """
    return _GaussianKernel(std, mean=mean, x=x)
class _ExponentialKernel:
    r"""Exponential distance kernel ``p(d) = \exp(-d / \beta)`` (peak 1 at ``d=0``).

    Parameters
    ----------
    beta : float or Quantity
        Decay length, converted with ``_as_len``. Must be ``> 0``.
    x : spatial expression, optional
        Input evaluated over the pair grid. Default :data:`distance`.

    Raises
    ------
    ValueError
        If ``beta`` is not strictly positive, or ``x`` is not a spatial expression.
    """
    __module__ = 'brainpy.state'

    def __init__(self, beta, x=distance):
        self.beta = _as_len(beta)
        # beta == 0 gives 0/0 = NaN at d == 0, and beta < 0 gives exp(+d/|beta|) > 1,
        # i.e. values that are not probabilities.
        beta_mag = jnp.asarray(u.get_magnitude(self.beta.to(_LEN)))
        if not bool(jnp.all(beta_mag > 0)):
            raise ValueError(f'exponential kernel requires beta > 0, got beta={beta!r}')
        self._input = _as_input(x)

    def __call__(self, d):
        """Return ``exp(-d / beta)`` for a length-valued ``d``."""
        return u.math.exp(-(d / self.beta))  # dimensionless ratio

    def _eval_pair(self, pre_pos, post_pos):
        """Evaluate the input expression on the pair grid and map it through the kernel."""
        return self(self._input._eval_pair(pre_pos, post_pos))


def exponential(x=distance, beta=1.0) -> _ExponentialKernel:
    r"""Exponential distance-dependent connection probability.

    Returns a callable ``p(d) = exp(-d / beta)`` matching NEST's
    ``nest.spatial_distributions.exponential(distance, beta)``.

    Parameters
    ----------
    x : object, optional
        The :data:`distance` sentinel (or a per-axis expression such as ``distance.x``).
    beta : float or Quantity, optional
        Decay length; bare floats are taken in micrometres. Must be ``> 0``. Default ``1``.

    Returns
    -------
    callable
        ``p(d)`` mapping a distance (Quantity) to a connection probability.

    Raises
    ------
    ValueError
        If ``beta <= 0`` or ``x`` is not a spatial expression.

    Examples
    --------
    .. code-block:: python

        >>> from brainpy import state as bp
        >>> import brainunit as u
        >>> p = bp.spatial.exponential(bp.spatial.distance, beta=2.0)
        >>> float(u.get_magnitude(p(0.0 * u.um)))
        1.0
    """
    return _ExponentialKernel(beta, x=x)
class _GammaKernel:
    r"""Gamma distance kernel ``p(d) = d^{\kappa-1} e^{-d/\theta} / (\theta^\kappa \Gamma(\kappa))``.

    Unlike the Gaussian / exponential kernels this is a density, not peak-normalised:
    its values can exceed 1 (e.g. ``kappa=1``, ``theta=0.1`` gives ``p(0)=10``).

    Parameters
    ----------
    kappa : float
        Shape parameter. Must be ``> 0``.
    theta : float or Quantity
        Scale (length), converted with ``_as_len``. Must be ``> 0``.
    x : spatial expression, optional
        Input evaluated over the pair grid. Default :data:`distance`.

    Raises
    ------
    ValueError
        If ``kappa`` or ``theta`` is not strictly positive, or ``x`` is not a
        spatial expression.
    """
    __module__ = 'brainpy.state'

    def __init__(self, kappa, theta, x=distance):
        self.kappa = float(kappa)
        self.theta = _as_len(theta)
        # theta <= 0 makes log(theta) NaN and kappa <= 0 puts gammaln outside the gamma
        # distribution's domain; both would otherwise only show up as bad probabilities.
        if not self.kappa > 0:
            raise ValueError(f'gamma kernel requires kappa > 0, got kappa={kappa!r}')
        if not float(u.get_magnitude(self.theta.to(_LEN))) > 0:
            raise ValueError(f'gamma kernel requires theta > 0, got theta={theta!r}')
        self._input = _as_input(x)

    def __call__(self, d):
        """Return the gamma density at length-valued ``d`` (bare magnitudes in ``_LEN``)."""
        # Mirror NEST GammaParameter (bare magnitudes in the canonical length unit).
        x = u.get_magnitude(d.to(_LEN))
        th = float(u.get_magnitude(self.theta.to(_LEN)))
        # 1 / (theta^kappa * Gamma(kappa)) via gammaln, so a large kappa does not
        # overflow Gamma(kappa) directly.
        delta = jnp.exp(-self.kappa * jnp.log(th) - gammaln(self.kappa))
        return x ** (self.kappa - 1.0) * jnp.exp(-x / th) * delta

    def _eval_pair(self, pre_pos, post_pos):
        """Evaluate the input expression on the pair grid and map it through the kernel."""
        return self(self._input._eval_pair(pre_pos, post_pos))


def gamma(x=distance, kappa=1.0, theta=1.0) -> _GammaKernel:
    r"""Gamma distance-dependent connection probability.

    Returns a callable ``p(d) = d^{kappa-1} exp(-d/theta) / (theta^kappa Gamma(kappa))``
    matching NEST's ``nest.spatial_distributions.gamma(distance, kappa, theta)``.

    Parameters
    ----------
    x : object, optional
        The :data:`distance` sentinel (or a per-axis expression such as ``distance.x``).
    kappa : float, optional
        Shape parameter. Must be ``> 0``. Default ``1``.
    theta : float or Quantity, optional
        Scale parameter (length); bare floats are taken in micrometres. Must be ``> 0``.
        Default ``1``.

    Returns
    -------
    callable
        ``p(d)`` mapping a distance (Quantity) to a connection probability.

    Raises
    ------
    ValueError
        If ``kappa <= 0``, ``theta <= 0`` or ``x`` is not a spatial expression.

    Examples
    --------
    .. code-block:: python

        >>> from brainpy import state as bp
        >>> import brainunit as u
        >>> p = bp.spatial.gamma(bp.spatial.distance, kappa=2.0, theta=1.5)
        >>> float(u.get_magnitude(p(1.5 * u.um))) > 0.0
        True
    """
    return _GammaKernel(kappa, theta, x=x)
def _axis_mag(expr, pre_pos, post_pos):
    """Evaluate a two-node ``expr`` over the pair grid and drop units after converting to ``_LEN``."""
    grid = expr._eval_pair(pre_pos, post_pos)
    in_canonical_unit = grid.to(_LEN)
    return u.get_magnitude(in_canonical_unit)
class _GaborKernel:
    r"""Rectified Gabor kernel on the (x, y) displacements (NEST ``gabor``).

    Parameters mirror :func:`gabor`: ``theta`` and ``psi`` are in degrees, while ``std``
    and ``lam`` are lengths converted with ``_as_len``.

    Raises
    ------
    ValueError
        If ``std`` or ``lam`` is not strictly positive, or ``x`` / ``y`` is not a
        spatial expression.
    """
    __module__ = 'brainpy.state'

    def __init__(self, x, y, theta=0.0, gamma=1.0, std=1.0, lam=1.0, psi=0.0):
        self._x = _as_input(x)
        self._y = _as_input(y)
        self.theta = float(theta)
        self.gamma = float(gamma)
        self.std = _as_len(std)
        self.lam = _as_len(lam)
        self.psi = float(psi)
        # std == 0 or lam == 0 divide by zero inside jnp in _eval_pair, which yields NaN
        # probabilities rather than an error; reject them at construction.
        if not float(u.get_magnitude(self.std.to(_LEN))) > 0:
            raise ValueError(f'gabor kernel requires std > 0, got std={std!r}')
        if not float(u.get_magnitude(self.lam.to(_LEN))) > 0:
            raise ValueError(f'gabor kernel requires lam > 0, got lam={lam!r}')

    def _eval_pair(self, pre_pos, post_pos):
        """Return the ``(n_pre, n_post)`` Gabor probability grid (bare magnitudes in ``_LEN``)."""
        X = _axis_mag(self._x, pre_pos, post_pos)
        Y = _axis_mag(self._y, pre_pos, post_pos)
        std = float(u.get_magnitude(self.std.to(_LEN)))
        lam = float(u.get_magnitude(self.lam.to(_LEN)))
        # Rotate the (x, y) frame by theta degrees.
        c, s = math.cos(math.radians(self.theta)), math.sin(math.radians(self.theta))
        xp = X * c + Y * s
        yp = -X * s + Y * c
        env = jnp.exp(-(self.gamma ** 2 * xp ** 2 + yp ** 2) / (2.0 * std ** 2))
        # Half-wave rectified carrier: negative lobes are clipped to zero probability.
        carrier = jnp.maximum(jnp.cos(2.0 * jnp.pi * yp / lam + math.radians(self.psi)), 0.0)
        return env * carrier


def gabor(x=None, y=None, theta=0.0, gamma=1.0, std=1.0, lam=1.0, psi=0.0) -> _GaborKernel:
    r"""Rectified-Gabor connection probability on the (x, y) displacements (NEST ``gabor``).

    ``p = [cos(2\pi y'/\lambda + \psi)]^+ \exp(-(\gamma^2 x'^2 + y'^2)/(2\,\mathrm{std}^2))`` with
    ``x' = x\cos\theta + y\sin\theta``, ``y' = -x\sin\theta + y\cos\theta`` (``\theta``, ``\psi``
    in degrees). The ``x`` / ``y`` inputs default to ``distance.x`` / ``distance.y`` (the absolute
    per-axis displacements), matching NEST.

    Parameters
    ----------
    x, y : object, optional
        Per-axis expressions for the x / y displacement. Default ``distance.x`` / ``distance.y``.
    theta : float, optional
        Orientation of the profile in degrees. Default ``0``.
    gamma : float, optional
        Spatial aspect ratio (major/minor axis). Default ``1``.
    std : float or Quantity, optional
        Envelope standard deviation (length). Must be ``> 0``. Default ``1``.
    lam : float or Quantity, optional
        Wavelength of the carrier (length). Must be ``> 0``. Default ``1``.
    psi : float, optional
        Carrier phase in degrees. Default ``0``.

    Returns
    -------
    callable
        A kernel with ``_eval_pair(pre, post)`` returning the ``(n_pre, n_post)`` probability grid.

    Raises
    ------
    ValueError
        If ``std <= 0``, ``lam <= 0`` or ``x`` / ``y`` is not a spatial expression.

    Examples
    --------
    .. code-block:: python

        >>> from brainpy import state as bp
        >>> k = bp.spatial.gabor(bp.spatial.distance.x, bp.spatial.distance.y, theta=45.0, lam=2.0)
    """
    x = distance.x if x is None else x
    y = distance.y if y is None else y
    return _GaborKernel(x, y, theta=theta, gamma=gamma, std=std, lam=lam, psi=psi)
class _Gaussian2DKernel:
    r"""Bivariate Gaussian kernel on the (x, y) displacements (NEST ``gaussian2D``).

    Parameters mirror :func:`gaussian2D`; means and standard deviations are lengths
    converted with ``_as_len``.

    Raises
    ------
    ValueError
        If ``rho`` is outside ``(-1, 1)``, ``std_x`` / ``std_y`` is not strictly
        positive, or ``x`` / ``y`` is not a spatial expression.
    """
    __module__ = 'brainpy.state'

    def __init__(self, x, y, mean_x=0.0, mean_y=0.0, std_x=1.0, std_y=1.0, rho=0.0):
        self._x = _as_input(x)
        self._y = _as_input(y)
        self.mean_x = _as_len(mean_x)
        self.mean_y = _as_len(mean_y)
        self.std_x = _as_len(std_x)
        self.std_y = _as_len(std_y)
        self.rho = float(rho)
        # |rho| == 1 or a zero std would make _eval_pair divide by zero (a bare
        # ZeroDivisionError at connect time); |rho| > 1 flips the quadratic form's sign
        # so "probabilities" exceed 1. Enforce the documented domain up front.
        if not -1.0 < self.rho < 1.0:
            raise ValueError(f'gaussian2D kernel requires -1 < rho < 1, got rho={rho!r}')
        for label, raw, value in (('std_x', std_x, self.std_x), ('std_y', std_y, self.std_y)):
            if not float(u.get_magnitude(value.to(_LEN))) > 0:
                raise ValueError(f'gaussian2D kernel requires {label} > 0, got {label}={raw!r}')

    def _eval_pair(self, pre_pos, post_pos):
        """Return the ``(n_pre, n_post)`` bivariate-Gaussian grid (bare magnitudes in ``_LEN``)."""
        dx = _axis_mag(self._x, pre_pos, post_pos) - float(u.get_magnitude(self.mean_x.to(_LEN)))
        dy = _axis_mag(self._y, pre_pos, post_pos) - float(u.get_magnitude(self.mean_y.to(_LEN)))
        sx = float(u.get_magnitude(self.std_x.to(_LEN)))
        sy = float(u.get_magnitude(self.std_y.to(_LEN)))
        # Expand -(u^2 - 2 rho u v + v^2) / (2 (1 - rho^2)) with u = dx/sx, v = dy/sy.
        denom = 2.0 * (1.0 - self.rho ** 2)
        cx = 1.0 / (denom * sx ** 2)
        cy = 1.0 / (denom * sy ** 2)
        cxy = 2.0 * self.rho / (denom * sx * sy)
        return jnp.exp(-dx ** 2 * cx - dy ** 2 * cy + dx * dy * cxy)


def gaussian2D(x=None, y=None, mean_x=0.0, mean_y=0.0,
               std_x=1.0, std_y=1.0, rho=0.0) -> _Gaussian2DKernel:
    r"""Bivariate-Gaussian connection probability on the (x, y) displacements (NEST ``gaussian2D``).

    ``p = exp(-(u^2 - 2\rho u v + v^2)/(2(1-\rho^2)))`` with ``u=(x-mean_x)/std_x``,
    ``v=(y-mean_y)/std_y``. The ``x`` / ``y`` inputs default to ``distance.x`` / ``distance.y``.

    Parameters
    ----------
    x, y : object, optional
        Per-axis expressions for the x / y displacement. Default ``distance.x`` / ``distance.y``.
    mean_x, mean_y : float or Quantity, optional
        Means (length). Default ``0``.
    std_x, std_y : float or Quantity, optional
        Standard deviations (length). Must be ``> 0``. Default ``1``.
    rho : float, optional
        Correlation in ``(-1, 1)``. Default ``0``.

    Returns
    -------
    callable
        A kernel with ``_eval_pair(pre, post)`` returning the ``(n_pre, n_post)`` probability grid.

    Raises
    ------
    ValueError
        If ``rho`` is outside ``(-1, 1)``, a standard deviation is not positive, or
        ``x`` / ``y`` is not a spatial expression.

    Examples
    --------
    .. code-block:: python

        >>> from brainpy import state as bp
        >>> k = bp.spatial.gaussian2D(std_x=0.5, std_y=1.0, rho=0.3)
    """
    x = distance.x if x is None else x
    y = distance.y if y is None else y
    return _Gaussian2DKernel(x, y, mean_x=mean_x, mean_y=mean_y,
                             std_x=std_x, std_y=std_y, rho=rho)