# Source code for brainpy_state._nest.mip_generator

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-


import math

import brainstate
import saiunit as u
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size

from ._base import NESTDevice

# Public API of this module: only the generator class is exported.
__all__ = [
    'mip_generator',
]

# Private sentinel used by ``mip_generator.set`` to tell "argument omitted"
# apart from an explicit ``None`` (``stop=None`` means "unbounded").
_UNSET = object()


class mip_generator(NESTDevice):
    r"""Correlated spike trains from a Multiple Interaction Process (MIP).

    ``mip_generator`` reproduces NEST's ``mip_generator`` device by combining
    one shared parent Poisson process with independent copy operations for
    each child output train.

    **1. Parent-child process model and derivation**

    Let :math:`r = \mathrm{rate}` in spikes/s and simulation step
    :math:`\Delta t` in ms. For each step :math:`n`:

    .. math::

       N_n \sim \mathrm{Poisson}(\lambda), \qquad
       \lambda = r \, \Delta t / 1000.

    For each child train :math:`i \in \{1,\dots,M\}` and each parent spike
    :math:`m \in \{1,\dots,N_n\}`, draw
    :math:`B_{i,m} \sim \mathrm{Bernoulli}(p_{\mathrm{copy}})` independently
    across :math:`i` and :math:`m`. The emitted multiplicity is

    .. math::

       K_{i,n} = \sum_{m=1}^{N_n} B_{i,m}.

    Marginally, :math:`K_{i,n}` is Poisson with parameter
    :math:`p_{\mathrm{copy}} \lambda` (Poisson thinning), so each child has
    mean rate :math:`p_{\mathrm{copy}} r`. Shared parent fluctuations induce
    cross-child covariance:

    .. math::

       \mathrm{Cov}(K_{i,n}, K_{j,n}) = p_{\mathrm{copy}}^2 \lambda,\quad
       \mathrm{Var}(K_{i,n}) = p_{\mathrm{copy}} \lambda,\quad
       \rho_{ij} = p_{\mathrm{copy}} \quad (i \neq j).

    **2. Source-equivalent sampling order and computational implications**

    The update path mirrors ``models/mip_generator.cpp``:

    1. Check whether the stimulation device is active at current step.
    2. Draw parent multiplicity from the parent Poisson process.
    3. For each output train, run explicit Bernoulli trials for each parent
       spike and count copied spikes.

    This implementation intentionally preserves NEST's explicit Bernoulli loop
    (instead of vectorised Binomial sampling). Runtime per active step is
    :math:`O(M N_n)` random comparisons in the general case, with fast paths
    for ``p_copy <= 0`` and ``p_copy >= 1``. RNG sampling uses
    ``numpy.random.Generator`` (seeded by ``rng_seed``), so draws are CPU
    NumPy-based rather than JAX-key-based.

    **3. Timing semantics and grid constraints**

    Activity follows NEST stimulation-device semantics:

    .. math::

       t_{\min} < t \le t_{\max}, \qquad
       t_{\min} = \mathrm{origin} + \mathrm{start},\quad
       t_{\max} = \mathrm{origin} + \mathrm{stop}.

    Therefore ``start`` is exclusive and ``stop`` is inclusive. Internally,
    finite times are projected to integer steps with
    :math:`\mathrm{round}(t / \Delta t)` and checked as
    ``t_min_step < curr_step <= t_max_step``. Finite ``origin``, ``start``,
    and ``stop`` must be on the simulation grid (absolute tolerance ``1e-12``
    in ``time/dt`` ratio), otherwise :class:`ValueError` is raised.

    Parameters
    ----------
    in_size : Size, optional
        Output size specification consumed by :class:`brainstate.nn.Dynamics`.
        ``self.varshape`` is derived from ``in_size`` and determines the exact
        shape of arrays emitted by :meth:`update`. Each element of
        ``self.varshape`` corresponds to one child process. Default is ``1``.
    rate : ArrayLike, optional
        Scalar parent Poisson rate :math:`r` in spikes/s (Hz), shape ``()``
        after conversion. Accepts a single-element numeric ``ArrayLike`` or a
        :class:`saiunit.Quantity` convertible to ``u.Hz``.
        Must satisfy ``rate >= 0``. Default is ``0.0 * u.Hz``.
    p_copy : ArrayLike, optional
        Scalar copy probability :math:`p_{\mathrm{copy}}` for each parent
        spike and each child process, shape ``()`` after conversion. Must be
        scalar-convertible to ``float64`` and satisfy ``0 <= p_copy <= 1``.
        Default is ``1.0``.
    start : ArrayLike, optional
        Scalar relative start time in ms (exclusive lower bound after adding
        ``origin``), shape ``()`` after conversion. Must be
        scalar-convertible to ``float64`` and, when ``dt`` is available,
        representable on the simulation grid. Default is ``0.0 * u.ms``.
    stop : ArrayLike or None, optional
        Scalar relative stop time in ms (inclusive upper bound after adding
        ``origin``), shape ``()`` after conversion. ``None`` maps to
        ``+inf``. If finite, must be scalar-convertible and
        grid-representable when ``dt`` is available. Must satisfy
        ``stop >= start`` after conversion. Default is ``None``.
    origin : ArrayLike, optional
        Scalar time offset in ms added to both ``start`` and ``stop``,
        shape ``()`` after conversion. Must be scalar-convertible and
        grid-representable when ``dt`` is available.
        Default is ``0.0 * u.ms``.
    rng_seed : int, optional
        Seed passed to :class:`numpy.random.SeedSequence` and split into two
        independent RNG streams (parent Poisson and child-copy Bernoulli).
        Default is ``0``.
    name : str or None, optional
        Optional dynamics node name passed to :class:`brainstate.nn.Dynamics`.

    Parameter Mapping
    -----------------
    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 20 18 22 40

       * - Parameter
         - Default
         - Math symbol
         - Semantics
       * - ``rate``
         - ``0.0 * u.Hz``
         - :math:`r`
         - Parent Poisson intensity in spikes/s.
       * - ``p_copy``
         - ``1.0``
         - :math:`p_{\mathrm{copy}}`
         - Copy probability per parent spike and per child train.
       * - ``start``
         - ``0.0 * u.ms``
         - :math:`t_{\mathrm{start,rel}}`
         - Relative exclusive lower activity bound.
       * - ``stop``
         - ``None``
         - :math:`t_{\mathrm{stop,rel}}`
         - Relative inclusive upper activity bound; ``None`` maps to ``+\infty``.
       * - ``origin``
         - ``0.0 * u.ms``
         - :math:`t_0`
         - Time offset added to ``start`` and ``stop``.
       * - ``in_size``
         - ``1``
         - :math:`M`
         - Number/shape of child processes (``M = prod(varshape)``).
       * - ``rng_seed``
         - ``0``
         - -
         - Entropy source for parent/child RNG stream initialization.

    Raises
    ------
    ValueError
        If ``rate`` is negative or NaN; if ``p_copy`` is NaN or outside
        ``[0, 1]``; if ``stop < start``; if scalar conversion fails due to
        non-scalar inputs; or if finite ``origin``/``start``/``stop`` are not
        multiples of ``dt`` when simulation resolution is available.
    TypeError
        If conversion of unitful inputs to ``u.Hz`` or ``u.ms`` is invalid.
    KeyError
        At update time, if the simulation environment does not provide
        required entries such as ``dt`` via ``brainstate.environ.get_dt()``.

    Notes
    -----
    - Outputs are multiplicities ``0, 1, 2, ...`` per discrete step, matching
      NEST ``SpikeEvent`` multiplicity semantics rather than binary spike
      flags.
    - :meth:`init_state` creates two independent RNG instances to mirror
      NEST's separation of parent and child stochastic paths.
    - :meth:`set` updates cached timing boundaries immediately when ``dt``
      is already available in ``brainstate.environ``; otherwise the cache is
      invalidated and rebuilt on the next :meth:`update` or :meth:`simulate`.

    Examples
    --------
    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     gen = brainpy.state.mip_generator(
       ...         in_size=(2, 3),
       ...         rate=800.0 * u.Hz,
       ...         p_copy=0.25,
       ...         start=5.0 * u.ms,
       ...         stop=40.0 * u.ms,
       ...         rng_seed=7,
       ...     )
       ...     with brainstate.environ.context(t=10.0 * u.ms):
       ...         counts = gen.update()
       ...     _ = counts.shape, counts.dtype

    .. code-block:: python

       >>> import brainpy
       >>> import saiunit as u
       >>> gen = brainpy.state.mip_generator(rate=1200.0 * u.Hz, p_copy=0.1)
       >>> gen.set(start=2.0 * u.ms, stop=None, origin=1.0 * u.ms)
       >>> params = gen.get()
       >>> _ = params['rate'], params['p_copy'], params['stop']

    See Also
    --------
    poisson_generator : Independent Poisson trains without shared parent process.
    poisson_generator_ps : Precise-time Poisson generator with dead time.
    inhomogeneous_poisson_generator : Time-varying Poisson rate generator.

    References
    ----------
    .. [1] NEST source: ``models/mip_generator.h`` and
           ``models/mip_generator.cpp``.
    .. [2] NEST docs:
           https://nest-simulator.readthedocs.io/en/stable/models/mip_generator.html
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size = 1,
        rate: ArrayLike = 0. * u.Hz,
        p_copy: ArrayLike = 1.0,
        start: ArrayLike = 0. * u.ms,
        stop: ArrayLike | None = None,
        origin: ArrayLike = 0. * u.ms,
        rng_seed: int = 0,
        name: str | None = None,
    ):
        super().__init__(in_size=in_size, name=name)

        # Public parameters are stored as plain Python floats in Hz / ms.
        self.rate = self._to_scalar_rate_hz(rate)
        self.p_copy = self._to_scalar_float(p_copy, name='p_copy')
        self.start = self._to_scalar_time_ms(start)
        self.stop = np.inf if stop is None else self._to_scalar_time_ms(stop)
        self.origin = self._to_scalar_time_ms(origin)
        self.rng_seed = int(rng_seed)

        self._validate_parameters(
            rate=self.rate,
            p_copy=self.p_copy,
            start=self.start,
            stop=self.stop,
        )

        self._num_targets = int(np.prod(self.varshape))
        # NaN marks the step-bound cache as "not yet computed for any dt".
        self._dt_cache_ms = np.nan
        self._t_min_step = 0
        self._t_max_step = np.iinfo(np.int64).max
        dt_ms = self._maybe_dt_ms()
        if dt_ms is not None:
            self._refresh_timing_cache(dt_ms)

    @staticmethod
    def _to_scalar_time_ms(value: ArrayLike) -> float:
        """Convert a single-element time value to a Python float in ms.

        ``saiunit.Quantity`` inputs are converted with ``to_decimal(u.ms)``;
        other inputs are used as given. Raises ``ValueError`` when the input
        does not hold exactly one element.
        """
        dftype = brainstate.environ.dftype()
        if isinstance(value, u.Quantity):
            arr = np.asarray(value.to_decimal(u.ms), dtype=dftype)
        else:
            arr = np.asarray(u.math.asarray(value, dtype=dftype), dtype=dftype)
        if arr.size != 1:
            raise ValueError('Time parameters must be scalar.')
        return float(arr.reshape(()))

    @staticmethod
    def _to_scalar_rate_hz(value: ArrayLike) -> float:
        """Convert a single-element rate value to a Python float in Hz.

        ``saiunit.Quantity`` inputs are converted with ``to_decimal(u.Hz)``;
        other inputs are used as given. Raises ``ValueError`` when the input
        does not hold exactly one element.
        """
        dftype = brainstate.environ.dftype()
        if isinstance(value, u.Quantity):
            arr = np.asarray(value.to_decimal(u.Hz), dtype=dftype)
        else:
            arr = np.asarray(u.math.asarray(value, dtype=dftype), dtype=dftype)
        if arr.size != 1:
            raise ValueError('rate must be scalar.')
        return float(arr.reshape(()))

    @staticmethod
    def _to_scalar_float(value: ArrayLike, *, name: str) -> float:
        """Convert a single-element value to a Python float.

        ``name`` is only used in the ``ValueError`` message raised when the
        input does not hold exactly one element.
        """
        dftype = brainstate.environ.dftype()
        arr = np.asarray(u.math.asarray(value, dtype=dftype), dtype=dftype)
        if arr.size != 1:
            raise ValueError(f'{name} must be scalar.')
        return float(arr.reshape(()))

    @staticmethod
    def _validate_parameters(
        *,
        rate: float,
        p_copy: float,
        start: float,
        stop: float,
    ):
        """Check value constraints on already-converted scalar parameters.

        Raises
        ------
        ValueError
            If ``rate`` is negative or NaN, if ``p_copy`` is NaN or outside
            ``[0, 1]``, or if ``stop < start``.
        """
        # Negated comparisons so that NaN (which compares False to
        # everything) is rejected instead of silently slipping through.
        if not rate >= 0.0:
            raise ValueError('Rate must be non-negative.')
        if not 0.0 <= p_copy <= 1.0:
            raise ValueError('Copy probability must be in [0, 1].')
        if stop < start:
            raise ValueError('stop >= start required.')

    @staticmethod
    def _time_to_step(time_ms: float, dt_ms: float) -> int:
        """Return the integer step index nearest to ``time_ms / dt_ms``."""
        return int(np.rint(time_ms / dt_ms))

    @staticmethod
    def _assert_grid_time(name: str, time_ms: float, dt_ms: float):
        """Raise ``ValueError`` if a finite ``time_ms`` is off the ``dt`` grid.

        Non-finite values are accepted without checking. The test uses an
        absolute tolerance of ``1e-12`` on the ``time_ms / dt_ms`` ratio.
        """
        if not np.isfinite(time_ms):
            return
        ratio = time_ms / dt_ms
        nearest = np.rint(ratio)
        if not math.isclose(ratio, nearest, rel_tol=0.0, abs_tol=1e-12):
            raise ValueError(f'{name} must be a multiple of the simulation resolution.')

    @classmethod
    def _compute_step_bounds(
        cls,
        *,
        origin: float,
        start: float,
        stop: float,
        dt_ms: float,
    ) -> tuple:
        """Validate grid alignment and return ``(t_min_step, t_max_step)``.

        This has no side effects, so callers can validate candidate values
        before committing them to the instance. A non-finite ``stop`` maps to
        the largest ``int64`` value.
        """
        cls._assert_grid_time('origin', origin, dt_ms)
        cls._assert_grid_time('start', start, dt_ms)
        cls._assert_grid_time('stop', stop, dt_ms)

        t_min_step = cls._time_to_step(origin + start, dt_ms)
        if np.isfinite(stop):
            t_max_step = cls._time_to_step(origin + stop, dt_ms)
        else:
            t_max_step = np.iinfo(np.int64).max
        return t_min_step, t_max_step

    def _dt_ms(self) -> float:
        """Return the environment ``dt`` in ms via ``brainstate.environ.get_dt()``."""
        dt = brainstate.environ.get_dt()
        return self._to_scalar_time_ms(dt)

    def _maybe_dt_ms(self) -> float | None:
        """Return the environment ``dt`` in ms, or ``None`` if it is not set."""
        dt = brainstate.environ.get('dt', default=None)
        if dt is None:
            return None
        return self._to_scalar_time_ms(dt)

    def _current_time_ms(self) -> float:
        """Return the environment time ``t`` in ms, defaulting to ``0.0``."""
        t = brainstate.environ.get('t', default=0. * u.ms)
        if t is None:
            return 0.0
        return self._to_scalar_time_ms(t)

    def _refresh_timing_cache(self, dt_ms: float):
        """Recompute the cached step bounds for resolution ``dt_ms``.

        The bounds are computed first and assigned together, so a
        ``ValueError`` from grid validation leaves the previous cache intact.
        """
        bounds = self._compute_step_bounds(
            origin=self.origin,
            start=self.start,
            stop=self.stop,
            dt_ms=dt_ms,
        )
        self._t_min_step, self._t_max_step = bounds
        self._dt_cache_ms = float(dt_ms)

    def _is_active(self, curr_step: int) -> bool:
        """Return whether ``t_min_step < curr_step <= t_max_step``."""
        return (self._t_min_step < curr_step) and (curr_step <= self._t_max_step)

    def _prepare_step(self) -> float:
        """Ensure RNG state and the timing cache are ready; return ``dt`` in ms.

        Shared prelude of :meth:`update` and :meth:`simulate`: lazily calls
        :meth:`init_state` and rebuilds the step-bound cache when it has never
        been computed or when ``dt`` differs from the cached value.
        """
        if not hasattr(self, '_rng_parent'):
            self.init_state()

        dt_ms = self._dt_ms()
        cache_is_stale = (not np.isfinite(self._dt_cache_ms)) or (
            not math.isclose(dt_ms, self._dt_cache_ms, rel_tol=0.0, abs_tol=1e-15)
        )
        if cache_is_stale:
            self._refresh_timing_cache(dt_ms)
        return dt_ms

    def init_state(self, batch_size: int | None = None, **kwargs):
        r"""Initialize RNG state for parent and child stochastic paths.

        Spawns two independent ``numpy.random.Generator`` instances from
        ``rng_seed`` via :class:`numpy.random.SeedSequence`, mirroring NEST's
        separation of parent Poisson draws and per-child Bernoulli draws.

        Parameters
        ----------
        batch_size : int or None, optional
            Unused placeholder for :class:`brainstate.nn.Dynamics` API
            compatibility. Ignored by this implementation.
        **kwargs
            Additional keyword arguments accepted for API compatibility.
            Ignored.

        Raises
        ------
        ValueError
            If ``rng_seed`` cannot be consumed by
            :class:`numpy.random.SeedSequence`.
        TypeError
            If ``rng_seed`` has an invalid type for NumPy RNG initialization.
        """
        del batch_size, kwargs
        seed_seq = np.random.SeedSequence(self.rng_seed)
        parent_seed, child_seed = seed_seq.spawn(2)
        self._rng_parent = np.random.default_rng(parent_seed)
        self._rng_child = np.random.default_rng(child_seed)

    def set(
        self,
        *,
        rate: ArrayLike | object = _UNSET,
        p_copy: ArrayLike | object = _UNSET,
        start: ArrayLike | object = _UNSET,
        stop: ArrayLike | object = _UNSET,
        origin: ArrayLike | object = _UNSET,
    ):
        r"""Update public generator parameters with NEST-compatible semantics.

        Any parameter left at the internal sentinel ``_UNSET`` retains its
        current value. All provided values are converted and validated,
        including grid alignment when ``dt`` is available, before any
        attribute is mutated, so the generator state remains consistent on
        failure. If ``dt`` is currently available in ``brainstate.environ``,
        the cached step bounds are replaced together with the parameters;
        otherwise the cache is invalidated so the next :meth:`update` or
        :meth:`simulate` call rebuilds it.

        Parameters
        ----------
        rate : ArrayLike or object, optional
            New scalar parent Poisson rate in Hz. If omitted, keep current
            value. Must satisfy ``rate >= 0`` after scalar conversion.
        p_copy : ArrayLike or object, optional
            New scalar copy probability in ``[0, 1]``. If omitted, keep
            current value.
        start : ArrayLike or object, optional
            New scalar relative start time in ms. If omitted, keep current
            value.
        stop : ArrayLike, None, or object, optional
            New scalar relative stop time in ms. ``None`` maps to ``+inf``.
            If omitted, keep current value.
        origin : ArrayLike or object, optional
            New scalar time origin in ms. If omitted, keep current value.

        Raises
        ------
        ValueError
            If any provided parameter is non-scalar, violates parameter
            constraints (for example ``p_copy`` outside ``[0, 1]`` or
            ``stop < start``), or finite times are off the simulation grid
            when ``dt`` is available.
        TypeError
            If unit conversion or scalar coercion fails for provided values.
        """
        new_rate = self.rate if rate is _UNSET else self._to_scalar_rate_hz(rate)
        new_p_copy = (
            self.p_copy
            if p_copy is _UNSET
            else self._to_scalar_float(p_copy, name='p_copy')
        )
        new_start = self.start if start is _UNSET else self._to_scalar_time_ms(start)
        if stop is _UNSET:
            new_stop = self.stop
        elif stop is None:
            new_stop = np.inf
        else:
            new_stop = self._to_scalar_time_ms(stop)
        new_origin = self.origin if origin is _UNSET else self._to_scalar_time_ms(origin)

        self._validate_parameters(
            rate=new_rate,
            p_copy=new_p_copy,
            start=new_start,
            stop=new_stop,
        )

        # Grid validation happens on the candidate values, before mutation,
        # so an off-grid time cannot leave the instance half-updated.
        dt_ms = self._maybe_dt_ms()
        bounds = None
        if dt_ms is not None:
            bounds = self._compute_step_bounds(
                origin=new_origin,
                start=new_start,
                stop=new_stop,
                dt_ms=dt_ms,
            )

        self.rate = new_rate
        self.p_copy = new_p_copy
        self.start = new_start
        self.stop = new_stop
        self.origin = new_origin

        if bounds is not None:
            self._t_min_step, self._t_max_step = bounds
            self._dt_cache_ms = float(dt_ms)
        else:
            # No dt available: mark the cache stale so bounds computed from
            # the previous parameters are never reused.
            self._dt_cache_ms = np.nan

    def get(self) -> dict:
        r"""Return current public parameters as plain Python scalars.

        Returns
        -------
        out : dict
            ``dict`` with keys ``'rate'``, ``'p_copy'``, ``'start'``,
            ``'stop'``, and ``'origin'``. Values are Python ``float`` in
            public units: Hz for ``rate`` and ms for all time fields.
            ``'stop'`` is ``math.inf`` if unbounded (i.e., ``stop=None`` was
            supplied at construction or via :meth:`set`).
        """
        return {
            'rate': float(self.rate),
            'p_copy': float(self.p_copy),
            'start': float(self.start),
            'stop': float(self.stop),
            'origin': float(self.origin),
        }

    def _sample_parent_spikes(self, lam: float) -> int:
        """Draw one parent multiplicity from ``Poisson(lam)``."""
        return int(self._rng_parent.poisson(lam))

    def _sample_child_spikes(self, n_parent_spikes: int) -> np.ndarray:
        """Copy ``n_parent_spikes`` parent spikes into each child train.

        Returns a flat integer array of length ``self._num_targets`` holding
        the number of copied spikes per child. ``p_copy <= 0`` and
        ``p_copy >= 1`` are handled without drawing random numbers.
        """
        ditype = brainstate.environ.ditype()
        out = np.zeros(self._num_targets, dtype=ditype)
        if n_parent_spikes <= 0 or self._num_targets == 0:
            return out
        if self.p_copy <= 0.0:
            return out
        if self.p_copy >= 1.0:
            out.fill(int(n_parent_spikes))
            return out

        # Explicit Bernoulli trials per child, one uniform draw per parent
        # spike, kept deliberately instead of a Binomial draw.
        for i in range(self._num_targets):
            copied = np.count_nonzero(
                self._rng_child.random(n_parent_spikes) < self.p_copy
            )
            out[i] = int(copied)
        return out

    def simulate(self, n_steps: int) -> np.ndarray:
        r"""Run ``n_steps`` simulation steps in one vectorised NumPy call.

        Statistically equivalent to calling :meth:`update` in a loop with
        ``brainstate.environ.context(t=k*dt)`` for
        ``k = 0, 1, ..., n_steps-1``, but avoids per-step Python overhead by
        batching all random draws. The RNG streams are consumed differently
        from :meth:`update` (parent draws are made for every step and then
        masked; child copies use Binomial draws), so results are not
        draw-for-draw identical for the same seed.

        Parameters
        ----------
        n_steps : int
            Number of simulation steps to run. Assumes step index ``k``
            corresponds to time ``t = k * dt``. Must be non-negative.

        Returns
        -------
        out : numpy.ndarray
            Integer array of shape ``(n_steps, *self.varshape)`` with spike
            multiplicities per step and per child train.

        Raises
        ------
        ValueError
            If ``n_steps`` is negative, or if finite timing parameters are
            not aligned to the simulation grid.

        Notes
        -----
        ``Binomial(n, p)`` is used in place of ``n`` independent
        ``Bernoulli(p)`` trials — statistically equivalent but faster.
        """
        n = int(n_steps)
        if n < 0:
            raise ValueError('n_steps must be non-negative.')

        dt_ms = self._prepare_step()
        ditype = brainstate.environ.ditype()
        out_shape = (n,) + tuple(self.varshape)

        if self.rate <= 0.0 or self._num_targets == 0:
            return np.zeros(out_shape, dtype=ditype)

        steps = np.arange(n, dtype=np.int64)
        active = (steps > self._t_min_step) & (steps <= self._t_max_step)

        lam = self.rate * dt_ms / 1000.0
        n_parents = self._rng_parent.poisson(lam, n).astype(np.int64)
        # Parent spikes outside the activity window are discarded.
        n_parents = np.where(active, n_parents, 0)

        if self.p_copy <= 0.0:
            mat = np.zeros((n, self._num_targets), dtype=ditype)
        elif self.p_copy >= 1.0:
            # Every child receives every parent spike; astype() copies the
            # read-only broadcast view into a fresh writable array.
            mat = np.broadcast_to(
                n_parents[:, np.newaxis], (n, self._num_targets)
            ).astype(ditype)
        else:
            mat = self._rng_child.binomial(
                n_parents[:, np.newaxis],
                self.p_copy,
                size=(n, self._num_targets),
            ).astype(ditype)

        return mat.reshape(out_shape)

    def update(self):
        r"""Advance one simulation step and emit child spike multiplicities.

        Executes the source-equivalent MIP sampling pipeline: lazily
        initialises state if needed, refreshes the timing cache when ``dt``
        changes, gates activity with :math:`t_{\min} < t \le t_{\max}`, draws
        parent spike multiplicity from
        :math:`\mathrm{Poisson}(r \Delta t / 1000)`, then independently
        copies each parent spike into each child train with probability
        ``p_copy``.

        Returns
        -------
        out : numpy.ndarray
            Integer array of shape ``self.varshape`` with dtype
            ``brainstate.environ.ditype()``. Entries are per-step spike
            multiplicities for each child train. Returns all zeros when the
            generator is inactive, when ``rate <= 0``, or when the parent
            draw yields zero spikes.

        Raises
        ------
        KeyError
            If the simulation context does not provide ``dt`` required by
            ``brainstate.environ.get_dt()``.
        ValueError
            If finite timing parameters are not aligned to the simulation
            grid after a ``dt`` change.
        TypeError
            If simulation-time values in the environment cannot be converted
            to scalar milliseconds.
        """
        dt_ms = self._prepare_step()
        ditype = brainstate.environ.ditype()

        if self.rate <= 0.0:
            return np.zeros(self.varshape, dtype=ditype)

        curr_step = self._time_to_step(self._current_time_ms(), dt_ms)
        if not self._is_active(curr_step):
            return np.zeros(self.varshape, dtype=ditype)

        lam = self.rate * dt_ms / 1000.0
        n_parent_spikes = self._sample_parent_spikes(lam)
        if n_parent_spikes <= 0:
            return np.zeros(self.varshape, dtype=ditype)

        child_counts = self._sample_child_spikes(n_parent_spikes)
        return child_counts.reshape(self.varshape)