# Source code for brainpy_state._nest_network.rules

# Copyright 2026 BrainX Ecosystem Limited. Apache 2.0.
"""Connection rules as values, wrapping the internal connectivity samplers."""
from __future__ import annotations

import jax
import jax.numpy as jnp

from brainpy_state._nest_network.connectivity import (
    ConnSpec,
    sample_all_to_all,
    sample_one_to_one,
    sample_fixed_indegree,
    sample_pairwise_bernoulli,
    sample_fixed_total_number,
    build_pool_map,
    sample_third_factor_pairing,
)

# Names exported by ``from ... import *``: the rule base class, the two
# ready-made rule instances (``all_to_all`` / ``one_to_one``) and the rule
# factory functions defined in this module.
__all__ = ['ConnRule', 'all_to_all', 'one_to_one', 'fixed_indegree',
           'pairwise_bernoulli', 'fixed_total_number',
           'third_factor_bernoulli_with_pool', 'explicit_edges']


class ConnRule:
    """Common interface implemented by every connection rule.

    Concrete rules override :meth:`sample` to turn the population sizes
    ``(n_pre, n_post)`` and the sampling flags into a
    :class:`~brainpy_state._nest_network.connectivity.ConnSpec` that holds
    the resulting edge indices.
    """
    __module__ = 'brainpy.state'

    def sample(
        self,
        n_pre,
        n_post,
        *,
        key,
        pre_is_post,
        allow_autapses,
        allow_multapses,
    ) -> ConnSpec:
        """Return the edge set for one projection (must be overridden).

        All arguments after ``n_post`` are keyword-only.

        Raises
        ------
        NotImplementedError
            Always, when called on the base class itself.
        """
        raise NotImplementedError


class _AllToAll(ConnRule):
    """Rule that delegates to :func:`sample_all_to_all`.

    Only ``pre_is_post`` and ``allow_autapses`` are forwarded; the PRNG
    ``key`` and ``allow_multapses`` are accepted to satisfy the
    :class:`ConnRule` interface but are not passed on.
    """
    __module__ = 'brainpy.state'

    def sample(self, n_pre, n_post, *, key, pre_is_post, allow_autapses, allow_multapses):
        autapse_options = {
            'pre_is_post': pre_is_post,
            'allow_autapses': allow_autapses,
        }
        return sample_all_to_all(n_pre, n_post, **autapse_options)


class _OneToOne(ConnRule):
    """Rule that delegates to :func:`sample_one_to_one`.

    Only the two population sizes are forwarded; ``key`` and all sampling
    flags are accepted for interface compatibility and otherwise unused.
    """
    __module__ = 'brainpy.state'

    def sample(self, n_pre, n_post, *, key, pre_is_post, allow_autapses, allow_multapses):
        spec = sample_one_to_one(n_pre, n_post)
        return spec


class _FixedIndegree(ConnRule):
    """Rule giving every post-synaptic neuron exactly ``K`` incoming edges.

    Sampling is delegated to :func:`sample_fixed_indegree`.
    """
    __module__ = 'brainpy.state'

    def __init__(self, K: int):
        # Stored as a plain Python int regardless of what the caller passed.
        self.K = int(K)

    def sample(self, n_pre, n_post, *, key, pre_is_post, allow_autapses, allow_multapses):
        flags = dict(
            pre_is_post=pre_is_post,
            allow_autapses=allow_autapses,
            allow_multapses=allow_multapses,
        )
        return sample_fixed_indegree(n_pre, n_post, K=self.K, key=key, **flags)


class _PairwiseBernoulli(ConnRule):
    """Rule connecting each ordered ``(pre, post)`` pair independently with prob. ``p``.

    Sampling is delegated to :func:`sample_pairwise_bernoulli`.
    """
    __module__ = 'brainpy.state'

    def __init__(self, p: float):
        # Stored as a plain Python float regardless of what the caller passed.
        self.p = float(p)

    def sample(self, n_pre, n_post, *, key, pre_is_post, allow_autapses, allow_multapses):
        flags = dict(
            pre_is_post=pre_is_post,
            allow_autapses=allow_autapses,
            allow_multapses=allow_multapses,
        )
        return sample_pairwise_bernoulli(n_pre, n_post, p=self.p, key=key, **flags)


class _FixedTotalNumber(ConnRule):
    """Rule drawing exactly ``N`` edges uniformly over the ``(pre, post)`` grid.

    Sampling is delegated to :func:`sample_fixed_total_number`.
    """
    __module__ = 'brainpy.state'

    def __init__(self, N: int):
        # Stored as a plain Python int regardless of what the caller passed.
        self.N = int(N)

    def sample(self, n_pre, n_post, *, key, pre_is_post, allow_autapses, allow_multapses):
        flags = dict(
            pre_is_post=pre_is_post,
            allow_autapses=allow_autapses,
            allow_multapses=allow_multapses,
        )
        return sample_fixed_total_number(n_pre, n_post, N=self.N, key=key, **flags)


# Ready-made rule instances. Neither class takes constructor arguments or
# stores per-instance state, so one shared object of each is exported.
all_to_all = _AllToAll()
one_to_one = _OneToOne()


def fixed_indegree(K: int) -> _FixedIndegree:
    """Return a fixed-indegree rule: each post neuron gets exactly ``K`` edges.

    Parameters
    ----------
    K : int
        Number of incoming edges per post-synaptic neuron; converted with
        ``int`` before it is checked and stored.

    Raises
    ------
    ValueError
        If ``int(K)`` is negative.
    """
    indegree = int(K)
    if indegree < 0:
        raise ValueError(f'K must be >= 0, got {K}')
    return _FixedIndegree(indegree)
def pairwise_bernoulli(p: float) -> _PairwiseBernoulli:
    """Return a pairwise-Bernoulli rule: connect each ``(pre, post)`` pair with prob. ``p``.

    Parameters
    ----------
    p : float
        Connection probability; converted with ``float`` before it is
        checked and stored.

    Raises
    ------
    ValueError
        If ``p`` does not satisfy ``0 <= p <= 1`` (this includes NaN).
    """
    prob = float(p)
    # Written as a positive range test so that NaN, which fails every
    # comparison, is rejected as well.
    inside_unit_interval = 0.0 <= prob <= 1.0
    if not inside_unit_interval:
        raise ValueError(f'p must be in [0, 1], got {prob}')
    return _PairwiseBernoulli(prob)
def fixed_total_number(N: int) -> _FixedTotalNumber:
    """Return a fixed-total-number rule: exactly ``N`` edges over the ``(pre, post)`` grid.

    Parameters
    ----------
    N : int
        Total number of edges; converted with ``int`` before it is checked
        and stored.

    Raises
    ------
    ValueError
        If ``int(N)`` is negative.
    """
    total = int(N)
    if total < 0:
        raise ValueError(f'N must be >= 0, got {N}')
    return _FixedTotalNumber(total)
class _ExplicitEdges(ConnRule):
    """A rule that hands back a *precomputed* :class:`ConnSpec`.

    Lets :meth:`~brainpy_state._nest_network.simulator.Simulator.tripartite_connect`
    feed a single shared edge sample into the existing
    :class:`~brainpy_state._nest_network.event_proj.EventProjection` /
    ``_connect_pair`` / ``_connect_sic`` paths -- which would otherwise
    re-sample their own connectivity from ``(rule, seed)``.

    ``sample`` returns the wrapped spec verbatim, so the population sizes and
    the key / autapse / multapse flags (already applied when the spec was
    built) are deliberately ignored.
    """
    __module__ = 'brainpy.state'

    def __init__(self, spec: ConnSpec):
        self._spec = spec

    def sample(self, n_pre, n_post, *, key, pre_is_post, allow_autapses, allow_multapses):
        # Every argument is intentionally unused: the edges are already fixed.
        stored_spec = self._spec
        return stored_spec
def explicit_edges(pre_idx, post_idx) -> _ExplicitEdges:
    """Return a rule wiring exactly the given ``(pre_idx[i], post_idx[i])`` edges.

    A connection rule built from a *precomputed* edge list, for topologies
    that are easier to enumerate directly than to express through a sampling
    rule (e.g. the structured inhibitory graph of a constraint-satisfaction
    network). The edges are handed to the same projection paths used by the
    sampling rules, so a single call realizes the whole topology as **one**
    projection -- avoiding the per-edge / per-group projection explosion (and
    the per-``cont()`` retrace) of many small
    :meth:`~brainpy.state.Simulator.connect` calls.

    Parameters
    ----------
    pre_idx, post_idx : array_like of int
        Equal-length 1-D arrays of segment-local source/target neuron
        indices, in the coordinate frame of the populations passed to
        :meth:`~brainpy.state.Simulator.connect`. Edge ``i`` connects
        ``pre_idx[i] -> post_idx[i]``. Order is preserved and duplicates are
        kept -- the caller controls multapses. Zero-length inputs (e.g. two
        empty lists) are accepted and describe an empty edge set.

    Returns
    -------
    _ExplicitEdges
        A connection rule returning the precomputed edge set verbatim (the
        sampling ``key`` / autapse / multapse flags are ignored).

    Raises
    ------
    ValueError
        If ``pre_idx`` / ``post_idx`` are not 1-D, differ in length, or are
        non-empty and not integer-typed.

    See Also
    --------
    brainpy.state.all_to_all
    brainpy.state.Simulator.connect

    Examples
    --------
    .. code-block:: python

        >>> import numpy as np
        >>> import brainunit as u
        >>> from brainpy import state as bp
        >>> sim = bp.Simulator(dt=0.1 * u.ms)
        >>> pop = sim.create(bp.iaf_psc_exp, 5)
        >>> # wire 0->1, 0->4, 2->3 as one sparse projection
        >>> pre = np.array([0, 0, 2])
        >>> post = np.array([1, 4, 3])
        >>> _ = sim.connect(pop, pop, rule=bp.explicit_edges(pre, post),
        ...                 weight=-0.2 * u.pA, comm='sparse')
        >>> len(sim.get_connections(source=pop, target=pop))
        3
    """
    pre = jnp.asarray(pre_idx)
    post = jnp.asarray(post_idx)
    if pre.ndim != 1 or post.ndim != 1:
        raise ValueError(
            f'pre_idx/post_idx must be 1-D, got {pre.ndim}-D and {post.ndim}-D')
    if pre.shape[0] != post.shape[0]:
        raise ValueError(
            f'pre_idx/post_idx must have equal length, got {pre.shape[0]} and {post.shape[0]}')
    n_edges = int(pre.shape[0])
    if n_edges == 0:
        # An empty list/array carries no elements to infer an integer dtype
        # from, so ``asarray`` falls back to a float dtype. Cast it to the
        # default integer dtype so an empty edge list is not rejected by the
        # dtype check below.
        pre, post = (
            arr if jnp.issubdtype(arr.dtype, jnp.integer) else arr.astype(int)
            for arr in (pre, post)
        )
    if not (jnp.issubdtype(pre.dtype, jnp.integer)
            and jnp.issubdtype(post.dtype, jnp.integer)):
        raise ValueError(
            f'pre_idx/post_idx must be integer arrays, got dtypes {pre.dtype} and {post.dtype}')
    return _ExplicitEdges(ConnSpec(pre, post, n_edges))
class _ThirdFactorBernoulliWithPool:
    """The ``third_factor_bernoulli_with_pool`` connection spec (NEST tripartite rule).

    Stores the conditional pairing probability and the pool parameters. Given
    a realized primary edge sample, :meth:`sample_third` derives the
    ``third_in`` (pre->astro) and ``third_out`` (astro->post) edge sets via
    :func:`build_pool_map` + :func:`sample_third_factor_pairing`.

    Built by :func:`third_factor_bernoulli_with_pool` and consumed by
    :meth:`~brainpy_state._nest_network.simulator.Simulator.tripartite_connect`.
    """
    __module__ = 'brainpy.state'

    def __init__(self, p: float, pool_size: int, pool_type: str):
        self.p = float(p)
        self.pool_size = int(pool_size)
        self.pool_type = pool_type

    def sample_third(self, primary_spec: ConnSpec, n_post: int, n_third: int, *, key):
        """Derive the ``(third_in, third_out)`` :class:`ConnSpec` pair from the primary sample.

        Parameters
        ----------
        primary_spec : ConnSpec
            The realized primary ``pre->post`` edges (segment-local indices).
        n_post : int
            Target-population size (for the pool map / block index).
        n_third : int
            Third-factor (astrocyte) population size.
        key : jax.Array
            PRNG key; split into a pool-assignment key and a pairing key.

        Returns
        -------
        third_in : ConnSpec
            ``pre_i -> astro`` edges (segment-local), one per paired primary edge.
        third_out : ConnSpec
            ``astro -> post_j`` edges (segment-local), one per paired primary edge.
        """
        # Separate sub-keys for the pool assignment and for the pairing draw.
        pool_key, pairing_key = jax.random.split(key)
        astro_pools = build_pool_map(
            n_post,
            n_third,
            pool_size=self.pool_size,
            pool_type=self.pool_type,
            key=pool_key,
        )
        paired_pre, paired_astro, paired_post = sample_third_factor_pairing(
            primary_spec.pre_idx,
            primary_spec.post_idx,
            astro_pools,
            p=self.p,
            key=pairing_key,
        )
        n_paired = int(paired_pre.shape[0])
        # Both specs share the astrocyte index array and the same edge count.
        third_in = ConnSpec(paired_pre, paired_astro, n_paired)
        third_out = ConnSpec(paired_astro, paired_post, n_paired)
        return third_in, third_out
def third_factor_bernoulli_with_pool(*, p: float, pool_size: int, pool_type: str):
    """Return a ``third_factor_bernoulli_with_pool`` spec (NEST tripartite astro-pool rule).

    Used as the ``third_factor_conn_spec`` of
    :meth:`~brainpy_state._nest_network.simulator.Simulator.tripartite_connect`:
    for each realized primary ``pre->post`` edge, an independent
    Bernoulli(``p``) trial decides whether it is paired with one astrocyte
    drawn from the target neuron's pool of ``pool_size`` astrocytes.

    Parameters
    ----------
    p : float
        Conditional pairing probability ``p_third_if_primary`` in ``[0, 1]``.
    pool_size : int
        Astrocytes per target pool (``>= 1``; the upper bound
        ``<= n_third`` is checked when the populations are known).
    pool_type : {'random', 'block'}
        Pool-assignment scheme (see :func:`build_pool_map`).

    Returns
    -------
    _ThirdFactorBernoulliWithPool
        The spec object.

    Raises
    ------
    ValueError
        If ``p`` is outside ``[0, 1]``, ``pool_size < 1``, or ``pool_type``
        is unknown.

    See Also
    --------
    brainpy.state.Simulator.tripartite_connect

    Examples
    --------
    .. code-block:: python

        >>> from brainpy import state as bp
        >>> rule = bp.third_factor_bernoulli_with_pool(
        ...     p=0.5, pool_size=10, pool_type='random')
        >>> rule.p, rule.pool_size, rule.pool_type
        (0.5, 10, 'random')

    The returned spec is passed as the ``third_factor_conn_spec`` of
    :meth:`~brainpy.state.Simulator.tripartite_connect`.
    """
    prob = float(p)
    # Positive range test so that NaN, which fails every comparison, is
    # rejected too.
    inside_unit_interval = 0.0 <= prob <= 1.0
    if not inside_unit_interval:
        raise ValueError(f'p must be in [0, 1], got {prob}')

    n_pool = int(pool_size)
    if n_pool < 1:
        raise ValueError(f'pool_size must be >= 1, got {pool_size}')

    known_pool_types = ('random', 'block')
    if pool_type not in known_pool_types:
        raise ValueError(f"pool_type must be 'random' or 'block', got {pool_type!r}")

    return _ThirdFactorBernoulliWithPool(prob, n_pool, pool_type)