# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Callable, Optional

import brainstate
import braintools
import brainunit as u
from brainstate.nn import Param

from ._base import NeuralMassDynamics
from .noise import Noise
from .typing import Parameter
# Names exported by ``from <this module> import *``; only the model class is
# part of the module's public API.
__all__ = [
'CoombesByrneStep',
]
class CoombesByrneStep(NeuralMassDynamics):
    r"""Coombes-Byrne next-generation neural mass model (2D).

    Exact mean-field reduction of an infinite population of all-to-all coupled
    quadratic-integrate-and-fire (QIF) / :math:`\theta`-neurons *with
    conductance-based synapses*, obtained through the Ott-Antonsen ansatz [1]_.
    Like the Montbrio-Pazo-Roxin model [2]_
    (:class:`~brainmass.MontbrioPazoRoxinStep`) it tracks the population firing
    rate :math:`r(t)` and mean membrane potential :math:`v(t)`, but it adds a
    synaptic conductance proportional to the firing rate,
    :math:`g = \kappa\,\pi\,r`, which couples reciprocally into both equations:

    .. math::

        \begin{aligned}
        \dot r(t) &= \frac{\Delta}{\pi} + 2\,v\,r - g\,r, \\
        \dot v(t) &= v^2 - (\pi r)^2 + \eta + (v_{\mathrm{syn}} - v)\,g + I(t),
        \end{aligned}

    with :math:`g = \kappa\,\pi\,r`. Here :math:`\Delta` is the half-width at
    half-maximum of the Lorentzian background-excitability distribution,
    :math:`\eta` the mean excitability, :math:`\kappa` the synaptic conductance
    scale, :math:`v_{\mathrm{syn}}` the synaptic reversal potential, and
    :math:`I(t)` an external/coupling input to the mean potential.

    The conductance term makes the rate equation quadratically damped in
    :math:`r` (the :math:`-g\,r = -\kappa\pi r^2` term), giving richer dynamics
    than the standard QIF mean field.

    Parameters
    ----------
    in_size : brainstate.typing.Size
        Spatial shape of the population. An ``int`` or tuple of ``int``; all
        parameters are broadcastable to this shape.
    Delta : Parameter, optional
        HWHM of the Lorentzian excitability distribution (dimensionless).
        Default is ``1.0``.
    eta : Parameter, optional
        Mean background excitability (dimensionless). Default is ``2.0``.
    k : Parameter, optional
        Synaptic conductance scaling :math:`\kappa` (dimensionless). Setting
        ``k = 0`` removes the conductance and recovers the Montbrio-Pazo-Roxin
        field with ``J = 0``. Default is ``1.0``.
    v_syn : Parameter, optional
        Synaptic reversal potential :math:`v_{\mathrm{syn}}` (dimensionless).
        Default is ``-4.0``.
    init_r : Callable, optional
        Initializer for the firing-rate state ``r``. Default is
        ``braintools.init.Constant(0.1)``.
    init_v : Callable, optional
        Initializer for the mean-potential state ``v``. Default is
        ``braintools.init.Constant(0.0)``.
    noise_r : Noise or None, optional
        Additive noise process for the rate dynamics. If provided, its output
        is added to ``r_inp`` at each update. Default is ``None``.
    noise_v : Noise or None, optional
        Additive noise process for the potential dynamics. If provided, its
        output is added to ``v_inp`` at each update. Default is ``None``.
    method : str, optional
        Integration method. Either ``'exp_euler'`` (default) or any method in
        ``braintools.quad`` (e.g. ``'rk4'``, ``'rk2'``, ``'heun'``).

    Attributes
    ----------
    r : brainstate.HiddenState
        Population firing rate (dimensionless). Shape ``(batch?,) + in_size``.
        Created by :meth:`init_state`.
    v : brainstate.HiddenState
        Population mean membrane potential (dimensionless). Created by
        :meth:`init_state`.

    Raises
    ------
    AssertionError
        If ``init_r`` / ``init_v`` is not callable, or ``noise_r`` /
        ``noise_v`` is neither ``None`` nor a :class:`Noise` instance.

    Notes
    -----
    - State variables are dimensionless; the per-variable right-hand sides
      returned by :meth:`dr` / :meth:`dv` carry unit ``1/ms`` so an
      exponential-Euler step with ``dt`` in milliseconds is consistent (the
      same convention used by :class:`~brainmass.FitzHughNagumoStep`).
    - In addition to the equations above, the implementation adds the rate
      input ``r_inp`` (plus ``noise_r`` output, if any) to the right-hand side
      of the rate equation; with the default ``r_inp=None`` and no noise this
      term is zero.
    - **Relationship to Montbrio-Pazo-Roxin.** With :math:`\kappa = 0` the
      conductance :math:`g` vanishes and the equations collapse to
      :math:`\dot r = \Delta/\pi + 2vr`,
      :math:`\dot v = v^2 - (\pi r)^2 + \eta + I`,
      i.e. :class:`~brainmass.MontbrioPazoRoxinStep` with recurrent coupling
      ``J = 0`` (at unit time constant). The conductance regime (``k > 0``) is
      what distinguishes the next-generation mass.
    - The model can equivalently be written in the complex Kuramoto-Daido form
      via :math:`Z = (1 - \bar W)/(1 + \bar W)` with
      :math:`\bar W = \pi r - i v`; this implementation uses the real
      ``(r, v)`` coordinates.

    References
    ----------
    .. [1] S. Coombes and A. Byrne (2019). Next generation neural mass models.
       In *Nonlinear Dynamics in Computational Neuroscience*, pp. 1-16.
       Springer. https://doi.org/10.1007/978-3-319-71048-8_1
    .. [2] E. Montbrió, D. Pazó, A. Roxin (2015). Macroscopic description for
       networks of spiking neurons. Physical Review X, 5:021028.

    Examples
    --------
    .. code-block:: python

        >>> import brainmass
        >>> import brainstate
        >>> import brainunit as u
        >>> model = brainmass.CoombesByrneStep(in_size=1)
        >>> _ = brainstate.nn.init_all_states(model)
        >>> with brainstate.environ.context(dt=0.1 * u.ms):
        ...     r = model.update()
        >>> r.shape
        (1,)
    """
    __module__ = 'brainmass'

    def __init__(
        self,
        in_size: brainstate.typing.Size,
        # model parameters
        Delta: Parameter = 1.0,
        eta: Parameter = 2.0,
        k: Parameter = 1.0,
        v_syn: Parameter = -4.0,
        # initializers / noise
        init_r: Callable = braintools.init.Constant(0.1),
        init_v: Callable = braintools.init.Constant(0.0),
        noise_r: Optional[Noise] = None,
        noise_v: Optional[Noise] = None,
        method: str = 'exp_euler',
    ):
        super().__init__(in_size)

        # the HWHM of the Lorentzian excitability distribution
        self.Delta = Param.init(Delta, self.varshape)
        # the mean background excitability
        self.eta = Param.init(eta, self.varshape)
        # the synaptic conductance scaling (kappa)
        self.k = Param.init(k, self.varshape)
        # the synaptic reversal potential
        self.v_syn = Param.init(v_syn, self.varshape)

        # Validate explicitly instead of with ``assert`` so the checks are not
        # stripped under ``python -O``. AssertionError and the messages are
        # kept so callers observe the same exception as before.
        if not callable(init_r):
            raise AssertionError('init_r must be callable')
        if not callable(init_v):
            raise AssertionError('init_v must be callable')
        if noise_r is not None and not isinstance(noise_r, Noise):
            raise AssertionError('noise_r must be a Noise instance or None')
        if noise_v is not None and not isinstance(noise_v, Noise):
            raise AssertionError('noise_v must be a Noise instance or None')

        # initializers and noise
        self.init_r = init_r
        self.init_v = init_v
        self.noise_r = noise_r
        self.noise_v = noise_v
        # NOTE(review): ``method`` is stored unvalidated; it is presumably
        # consumed by the base-class ``_solve_step`` -- confirm there.
        self.method = method

    def init_state(self, batch_size=None, **kwargs):
        """Allocate firing-rate and mean-potential states.

        Parameters
        ----------
        batch_size : int or None, optional
            Optional leading batch dimension. If ``None``, no batch dimension
            is used. Default is ``None``.
        **kwargs
            Accepted for interface compatibility; not used here.
        """
        self.r = brainstate.HiddenState.init(self.init_r, self.varshape, batch_size)
        self.v = brainstate.HiddenState.init(self.init_v, self.varshape, batch_size)

    def _g(self, r):
        """Synaptic conductance ``g = k * pi * r`` (dimensionless)."""
        return self.k.value() * u.math.pi * r

    def dr(self, r, v, r_ext):
        """Right-hand side for the firing rate ``r``.

        Parameters
        ----------
        r : array-like
            Current firing rate (dimensionless).
        v : array-like
            Current mean membrane potential (dimensionless), broadcastable to
            ``r``.
        r_ext : array-like or scalar
            External input to the rate equation (includes noise if enabled).

        Returns
        -------
        array-like
            Time derivative ``dr/dt`` with unit ``1/ms``.
        """
        Delta = self.Delta.value()
        g = self._g(r)
        # ``- g * r`` is the conductance-induced quadratic damping of the rate.
        return (Delta / u.math.pi + 2.0 * v * r - g * r + r_ext) / u.ms

    def dv(self, v, r, v_ext):
        """Right-hand side for the mean membrane potential ``v``.

        Parameters
        ----------
        v : array-like
            Current mean membrane potential (dimensionless).
        r : array-like
            Current firing rate (dimensionless), broadcastable to ``v``.
        v_ext : array-like or scalar
            External input to the potential equation (includes noise/coupling
            if enabled).

        Returns
        -------
        array-like
            Time derivative ``dv/dt`` with unit ``1/ms``.
        """
        eta = self.eta.value()
        v_syn = self.v_syn.value()
        g = self._g(r)
        return (v ** 2 - (u.math.pi * r) ** 2 + eta + (v_syn - v) * g + v_ext) / u.ms

    def derivative(self, state, t, r_ext, v_ext):
        """Joint right-hand side of the ``(r, v)`` system.

        Evaluates :meth:`dr` and :meth:`dv` on the same state. Presumably this
        is the vector field handed to the ``braintools.quad`` integrators by
        the base-class solver (not visible in this file).

        Parameters
        ----------
        state : tuple
            Pair ``(r, v)`` of the current firing rate and mean potential.
        t : scalar
            Current time. Accepted for the integrator interface but unused:
            the right-hand side has no explicit time dependence.
        r_ext : array-like or scalar
            External input to the rate equation.
        v_ext : array-like or scalar
            External input to the potential equation.

        Returns
        -------
        tuple
            ``(dr/dt, dv/dt)``, each with unit ``1/ms``.
        """
        r, v = state
        drdt = self.dr(r, v, r_ext)
        dvdt = self.dv(v, r, v_ext)
        return drdt, dvdt

    def update(self, r_inp=None, v_inp=None):
        """Advance the population by one time step.

        Parameters
        ----------
        r_inp : array-like or scalar or None, optional
            External input to the rate equation. If ``None``, treated as zero.
            If ``noise_r`` is set, its output is added. Default is ``None``.
        v_inp : array-like or scalar or None, optional
            External input to the potential equation (the coupling port). If
            ``None``, treated as zero. If ``noise_v`` is set, its output is
            added. Default is ``None``.

        Returns
        -------
        array-like
            The updated firing rate ``r`` (the coupling observable), same shape
            as the internal state.
        """
        r_inp = 0.0 if r_inp is None else r_inp
        if self.noise_r is not None:
            r_inp = r_inp + self.noise_r()
        v_inp = 0.0 if v_inp is None else v_inp
        if self.noise_v is not None:
            v_inp = v_inp + self.noise_v()

        # Snapshot the pre-step state once so every solver argument below sees
        # exactly the same values.
        r_now = self.r.value
        v_now = self.v.value

        # ``_solve_step`` comes from the base class; it receives both the
        # per-variable specs and the joint ODE description (presumably picking
        # one according to ``self.method`` -- confirm in ``_base``).
        r, v = self._solve_step(
            exp_euler_specs=(
                (self.dr, r_now, v_now, r_inp),
                (self.dv, v_now, r_now, v_inp),
            ),
            ode_state=(r_now, v_now),
            ode_inputs=(r_inp, v_inp),
        )
        self.r.value = r
        self.v.value = v
        return r