Utilities & Types

brainmass provides several utility functions for common operations in neural mass modeling, plus a handful of type aliases used throughout the public signatures.

API Reference

  • sys2nd – Compute the acceleration of a second-order linear system.

  • sigmoid – Convert membrane potential to firing rate via a sigmoid.

  • bounded_input – Apply a tanh bound to an input signal.

  • process_sequence – Aggregate a sequence of data items along the leading dimension.

  • delay_index – Build the neuron-index matrix used to address per-connection delays.

delay_index() converts a delay (in time units) into an integer buffer index for the circular delay buffers that delayed coupling reads from.
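The signature of delay_index is not documented on this page, so the following is only a generic sketch of circular-buffer delay addressing, assuming delays are rounded to the nearest integration step. delay_to_steps and read_delayed are hypothetical helpers, not brainmass API. The broadcast column index plays the role of a neuron-index matrix: it pairs each connection's delayed row with its presynaptic neuron.

```python
import numpy as np


def delay_to_steps(delay, dt):
    """Hypothetical helper: round delays (time units) to whole integration steps."""
    return np.rint(np.asarray(delay) / dt).astype(int)


def read_delayed(buffer, write_pos, steps):
    """Hypothetical helper: read per-connection delayed values from a circular buffer.

    buffer: (n_buf, n_pre) history of presynaptic values, newest row at write_pos
    steps:  (n_post, n_pre) integer delay of each connection, each < n_buf
    """
    rows = (write_pos - steps) % buffer.shape[0]   # wrap around the ring
    cols = np.arange(buffer.shape[1])[None, :]     # presynaptic neuron index, broadcast over rows
    return buffer[rows, cols]                      # (n_post, n_pre) delayed values


# Ring of length 11 holding 25 steps of a 2-neuron signal: neuron 0 = t, neuron 1 = 100 + t
dt = 0.1
steps = delay_to_steps([[0.0, 0.3], [1.0, 0.1]], dt)  # [[0, 3], [10, 1]]
buffer = np.zeros((11, 2))
for t in range(25):
    buffer[t % 11] = [t, 100 + t]
delayed = read_delayed(buffer, write_pos=24 % 11, steps=steps)
# delayed == [[24, 121], [14, 123]], i.e. the value emitted `steps` steps before t = 24
```

The buffer must be longer than the largest delay in steps; otherwise the oldest required sample has already been overwritten.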

Type Aliases

These aliases annotate the public model / coupling signatures. They are ordinary typing constructs (not classes), so pass any value matching the alias.

brainmass.Initializer

Union[Callable, brainstate.typing.ArrayLike] – a parameter/state initializer: either an array-like value or a callable shape -> array (e.g. a braintools.init initializer).

brainmass.Array

brainstate.typing.ArrayLike – any array-like input (NumPy / JAX array, scalar, or a unit-carrying brainunit.Quantity).

brainmass.Parameter

Union[Callable, brainstate.typing.ArrayLike, brainstate.nn.Param] – a model parameter: an array-like value, an initializer callable, or a wrapped Param (constrainable / trainable).
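Because these aliases are plain unions, code receiving an Initializer has to branch on which form it got. The sketch below illustrates that with local stand-in aliases and a hypothetical resolve_initializer helper; it assumes NumPy arrays only, whereas the real ArrayLike also covers JAX arrays and brainunit Quantities, and brainmass's own resolution logic (including its handling of brainstate.nn.Param) may differ.

```python
from typing import Callable, Union

import numpy as np

# Local stand-ins mirroring brainmass.Array / brainmass.Initializer (assumption:
# the real ArrayLike is broader and also accepts JAX arrays and Quantities)
ArrayLike = Union[np.ndarray, float, int]
Initializer = Union[Callable, ArrayLike]


def resolve_initializer(init: Initializer, shape: tuple) -> np.ndarray:
    """Hypothetical helper: materialise an Initializer-typed value for a given shape."""
    if callable(init):
        # Callable form: shape -> array
        return np.asarray(init(shape))
    # Array-like form: broadcast the value to the requested shape (writable copy)
    return np.array(np.broadcast_to(np.asarray(init), shape))


# Both forms yield an array of the requested shape
a = resolve_initializer(0.5, (3,))
b = resolve_initializer(np.zeros, (2, 2))
```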

Model discovery

  • list_models() – Return the catalogue of public neural-mass models.

  • ModelInfo(name, category, n_state_vars, use_case) – A typed catalogue record describing one public model.

list_models() returns a list of typed ModelInfo records, one per public neural-mass model, giving each model’s name, category (phenomenological / physiological / network), number of state variables, and a one-line typical use case. It powers the howto/choose_a_model guide, and brainmass.list_models.to_table() renders the same catalogue as a table that is easy to scan in the REPL.

>>> import brainmass
>>> models = brainmass.list_models()
>>> {m.name for m in models} >= {'HopfStep', 'JansenRitStep'}
True
>>> next(m for m in models if m.name == 'JansenRitStep').category
'physiological'
>>> print(brainmass.list_models.to_table())
name...category...#states...use_case...

sys2nd

brainmass.sys2nd(A, a, u, x, v)

Compute the acceleration of a second-order linear system.

Implements the canonical second-order kinetic block used by neural-mass models (e.g. Jansen-Rit). The system

\[\frac{d^2 x}{dt^2} + 2 a \frac{dx}{dt} + a^2 x = A\,a\,u\]

is written in state-space form with v = dx/dt as

\[\frac{dv}{dt} = A\,a\,u - 2 a v - a^2 x .\]
Parameters:
  • A (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Amplitude gain parameter.

  • a (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Rate constant, i.e. the reciprocal of the time constant (units of inverse time).

  • u (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Input signal.

  • x (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Position state (the integrated output).

  • v (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Velocity state (the derivative of x).

Returns:

dv/dt, the acceleration, i.e. the derivative of the velocity state.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity

Reduces the second-order equation above to a first-order system.

Mathematical Background:

A second-order ODE:

\[\ddot{x} = f(x, \dot{x}, t)\]

can be rewritten as a first-order system:

\[\begin{split}\dot{x} &= v \\ \dot{v} &= f(x, v, t)\end{split}\]

For this kernel f(x, v, t) = A·a·u - 2·a·v - a²·x. sys2nd returns only dv/dt; the caller integrates dx/dt = v itself.

Example:

import brainmass

# Jansen-Rit-style synaptic kernel driven by a constant input
A, a = 3.25, 100.0   # gain (mV) and rate constant (1/s)
u = 200.0            # input pulse density (1/s)
x, v = 0.0, 0.0      # position (PSP, mV) and velocity states

# sys2nd supplies dv/dt; integrate x' = v, v' = dv/dt with Euler steps
dt = 1e-4
for _ in range(1000):
    dv_dt = brainmass.sys2nd(A, a, u, x, v)
    x, v = x + v * dt, v + dv_dt * dt

# x approaches the steady state A * u / a = 6.5 mV
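Inside JAX code the same integration is usually written as a compiled scan rather than a Python loop. The sketch below uses sys2nd_ref, a hypothetical transcription of the documented formula (not brainmass's source), so it runs without brainmass; jax.lax.scan carries (x, v) and records x at every step.

```python
import jax
import jax.numpy as jnp


def sys2nd_ref(A, a, u, x, v):
    # dv/dt = A*a*u - 2*a*v - a**2*x, the state-space form documented above
    return A * a * u - 2.0 * a * v - a ** 2 * x


def simulate(A=3.25, a=100.0, u=200.0, dt=1e-4, n_steps=2000):
    """Integrate x' = v, v' = sys2nd_ref(...) from rest with explicit Euler."""
    def step(carry, _):
        x, v = carry
        dv_dt = sys2nd_ref(A, a, u, x, v)
        x_new = x + dt * v
        v_new = v + dt * dv_dt
        return (x_new, v_new), x_new

    init = (jnp.float32(0.0), jnp.float32(0.0))
    _, xs = jax.lax.scan(step, init, None, length=n_steps)
    return xs  # trajectory of x, shape (n_steps,)


xs = simulate()
```

The Euler fixed point is exactly x = A·u / a with v = 0, and 2000 steps of 0.1 ms cover 20 time constants 1/a, so xs[-1] sits at about 6.5.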

sigmoid

brainmass.sigmoid(x, vmax, v0, r)

Convert membrane potential to firing rate via a sigmoid.

\[S(x) = \frac{v_{max}}{1 + \exp\!\big(r (v_0 - x)\big)}\]
Parameters:
  • x (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Input membrane potential.

  • vmax (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Maximum firing rate.

  • v0 (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Firing threshold (potential at half-maximum rate).

  • r (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Steepness of the sigmoid.

Returns:

Firing rate in the open interval (0, vmax).

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity

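With vmax = 1, v0 = 0 and r = 1 this reduces to the standard sigmoid (logistic) function:

\[\sigma(x) = \frac{1}{1 + e^{-x}}\]

Properties:

  • Range: (0, vmax), with S(v0) = vmax / 2

  • Smooth and differentiable everywhere; the slope at x = v0 is r·vmax / 4

  • Used to map membrane potential to firing rate in neural mass models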

Example:

import brainmass
import jax.numpy as jnp

x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
y = brainmass.sigmoid(x, vmax=1.0, v0=0.0, r=1.0)  # standard logistic
# y ≈ [0.119, 0.269, 0.5, 0.731, 0.881]

# Common in firing rate models (defaults are Jansen-Rit-style values)
def firing_rate(membrane_potential, threshold=6.0, gain=0.56, max_rate=5.0):
    return brainmass.sigmoid(membrane_potential, vmax=max_rate, v0=threshold, r=gain)
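For checking results, here is a minimal JAX transcription of S(x); sigmoid_ref and sigmoid_via_logistic are hypothetical names, not brainmass API. It shows that the documented form equals vmax * jax.nn.sigmoid(r * (x - v0)), which is the identity behind the logistic special case.

```python
import jax
import jax.numpy as jnp


def sigmoid_ref(x, vmax, v0, r):
    # Direct transcription of S(x) = vmax / (1 + exp(r * (v0 - x)))
    return vmax / (1.0 + jnp.exp(r * (v0 - x)))


def sigmoid_via_logistic(x, vmax, v0, r):
    # Same curve as a scaled, shifted logistic: vmax * sigma(r * (x - v0))
    return vmax * jax.nn.sigmoid(r * (x - v0))


v = jnp.linspace(-20.0, 30.0, 11)  # membrane potentials (mV)
rates = sigmoid_ref(v, vmax=5.0, v0=6.0, r=0.56)
```

At the threshold the rate is exactly half of vmax: sigmoid_ref(6.0, 5.0, 6.0, 0.56) gives 2.5.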

bounded_input

brainmass.bounded_input(x, bound=500.0)

Apply a tanh bound to an input signal.

Prevents numerical instability by smoothly limiting the magnitude of the inputs fed to a second-order system.

Parameters:
  • x (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Input signal.

  • bound (float) – Maximum absolute value of the output.

Returns:

The bounded input bound * tanh(x / bound).

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity

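Because tanh saturates smoothly, the output magnitude never exceeds bound, while inputs much smaller than bound pass through almost unchanged (bound * tanh(x / bound) ≈ x). The bound is symmetric about zero, so it does not take separate lower and upper limits; unlike hard clipping, which has zero slope beyond the limit, it keeps a positive slope.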

Example:

import brainmass
import jax.numpy as jnp

# Large drives are squashed toward ±bound; small ones pass almost unchanged
drive = jnp.array([-2000.0, -10.0, 0.0, 10.0, 2000.0])
bounded = brainmass.bounded_input(drive, bound=500.0)
# bounded ≈ [-499.66, -10.0, 0.0, 10.0, 499.66]

# A tighter bound for a smaller operating range
bounded_small = brainmass.bounded_input(drive, bound=50.0)
# bounded_small ≈ [-50.0, -9.87, 0.0, 9.87, 50.0]
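The practical difference from hard clipping shows up in gradients. The sketch compares bounded_ref, a hypothetical transcription of the documented formula, with jnp.clip, and evaluates both slopes with jax.grad.

```python
import jax
import jax.numpy as jnp


def bounded_ref(x, bound=500.0):
    # Same formula as documented: bound * tanh(x / bound)
    return bound * jnp.tanh(x / bound)


def hard_clip(x, bound=500.0):
    # Hard clipping, for comparison
    return jnp.clip(x, -bound, bound)


# Slope of each map at a few inputs (vmap over scalar gradients)
xs = jnp.array([0.0, 250.0, 600.0, 2000.0])
grad_tanh = jax.vmap(jax.grad(bounded_ref))(xs)
grad_clip = jax.vmap(jax.grad(hard_clip))(xs)
```

Beyond ±bound the clipped map has zero slope, so gradient-based fitting receives no signal there. The tanh bound keeps a small positive slope instead, about 0.0013 at x = 4·bound.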

process_sequence

brainmass.process_sequence(data, mode='stack')

Aggregate a sequence of data items along the leading dimension.

Each leaf is expected to hold the items stacked along its leading axis; the data is then reduced along that axis according to mode. The reduction is applied recursively over arbitrary PyTrees (dicts, tuples, lists, and nested combinations thereof).

Parameters:
  • data (PyTree) – A stacked PyTree of items, with the items enumerated along the leading axis of every leaf. All leaves must share a compatible structure.

  • mode (str | Callable) –

    How to reduce along the leading axis:

    • 'stack' : return the data unchanged (identity).

    • 'last' / 'first' : take the last/first item.

    • 'avg' / 'mean' : mean over the leading axis.

    • 'max' / 'min' : max/min over the leading axis.

    • callable : apply the callable to each leaf.

Returns:

Aggregated data whose structure matches the input leaves:

  • mode='stack' : structure with the new leading dimension retained.

  • mode='last' / 'first' : a single item.

  • mode='avg'/'mean'/'max'/'min' : reduced PyTree.

  • mode callable : the result of applying the callable.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity | Tuple | List | Dict | Sequence

Raises:

ValueError – If mode is an unknown string.

Notes

  • All leaves must share the same type and structure.

  • Dictionary keys (and tuple/list lengths) must match across elements.

  • The reduction is recursive and handles nested structures, e.g. a dict of tuples of arrays.

A typical use is to summarise per-step outputs that a simulation loop has already stacked along the leading (time) axis.

Example:

import brainmass
import jax.numpy as jnp

# Four time steps of outputs for a 3-node network, plus a per-step scalar
data = {'x': jnp.arange(12.0).reshape(4, 3), 'y': jnp.arange(4.0)}

last = brainmass.process_sequence(data, mode='last')
# last == {'x': [9., 10., 11.], 'y': 3.}

mean = brainmass.process_sequence(data, mode='mean')
# mean == {'x': [4.5, 5.5, 6.5], 'y': 1.5}

# A callable is applied to every leaf, e.g. keep every other step
thinned = brainmass.process_sequence(data, mode=lambda leaf: leaf[::2])
# thinned['x'].shape == (2, 3)
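To make the mode semantics concrete, here is a minimal re-implementation, assuming the behaviour described in the parameter list (every leaf already stacked along axis 0). process_sequence_sketch is a hypothetical name, and the library's own code may differ, for instance in how it validates structures.

```python
import jax
import jax.numpy as jnp

# Reductions along the leading axis, keyed by the documented mode strings
_REDUCERS = {
    'stack': lambda leaf: leaf,
    'last': lambda leaf: leaf[-1],
    'first': lambda leaf: leaf[0],
    'avg': lambda leaf: jnp.mean(leaf, axis=0),
    'mean': lambda leaf: jnp.mean(leaf, axis=0),
    'max': lambda leaf: jnp.max(leaf, axis=0),
    'min': lambda leaf: jnp.min(leaf, axis=0),
}


def process_sequence_sketch(data, mode='stack'):
    if callable(mode):
        reducer = mode
    elif mode in _REDUCERS:
        reducer = _REDUCERS[mode]
    else:
        raise ValueError(f"Unknown mode: {mode!r}")
    # tree_map recurses through dicts, tuples and lists, reducing each array leaf
    return jax.tree_util.tree_map(lambda leaf: reducer(jnp.asarray(leaf)), data)


data = {'x': jnp.arange(12.0).reshape(4, 3), 'y': (jnp.arange(4.0), [jnp.ones((4, 2))])}
last = process_sequence_sketch(data, mode='last')
mean = process_sequence_sketch(data, mode='mean')
```

jax.tree_util.tree_map preserves the container types, so a dict of tuples of lists comes back with the same nesting and only the leaves reduced.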

Common Usage Patterns

Building Custom Activation Functions

import brainmass
import jax.numpy as jnp

def threshold_linear(x, threshold=0.0, slope=1.0, max_value=100.0):
    """Rectified linear with custom threshold, saturating smoothly at max_value"""
    # bounded_input is symmetric about zero, so rectify first
    rectified = jnp.maximum(slope * (x - threshold), 0.0)
    return brainmass.bounded_input(rectified, bound=max_value)

def soft_threshold(x, threshold=0.0, gain=1.0):
    """Smooth threshold using sigmoid (output in (0, 1))"""
    return brainmass.sigmoid(x, vmax=1.0, v0=threshold, r=gain)

Implementing Custom Second-Order Dynamics

import brainmass

class CustomOscillator:
    """Critically damped second-order block: x'' + 2a x' + a^2 x = A a u

    sys2nd fixes the damping ratio at 1 (natural frequency a); for other
    damping ratios, compute the acceleration directly instead.
    """

    def __init__(self, A, a, dt=0.001):
        self.A = A      # amplitude gain
        self.a = a      # rate constant (1/time)
        self.dt = dt    # Euler step; keep a * dt well below 1
        self.x = 0.0    # position state
        self.v = 0.0    # velocity state

    def update(self, u):
        # Acceleration of the second-order equation
        dv_dt = brainmass.sys2nd(self.A, self.a, u, self.x, self.v)

        # Integrate the first-order system x' = v, v' = dv_dt (Euler)
        self.x, self.v = self.x + self.v * self.dt, self.v + dv_dt * self.dt

        return self.x

Bounded Sigmoid for Physiological Ranges

import brainmass
import jax.numpy as jnp

def physiological_sigmoid(x, x_min, x_max, v_min, v_max):
    """Map input range [x_min, x_max] smoothly onto (v_min, v_max)"""
    # Normalize to [0, 1]
    x_norm = (x - x_min) / (x_max - x_min)

    # Apply sigmoid for smoothness (midpoint 0.5, steepness 10)
    y_norm = brainmass.sigmoid(x_norm, vmax=1.0, v0=0.5, r=10.0)

    # Scale to output range; the ends are approached, not reached exactly
    return v_min + (v_max - v_min) * y_norm

# Example: membrane potential to firing rate
V = jnp.array([-70.0, -60.0, -55.0, -50.0, -40.0])  # mV
rates = physiological_sigmoid(V, x_min=-70, x_max=-40, v_min=0, v_max=100)
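Since the normalisation is affine, the composite map collapses to a single sigmoid of the documented form. With Δ = x_max - x_min, the exponent satisfies 10(0.5 - (x - x_min)/Δ) = (10/Δ)((x_min + x_max)/2 - x), so

\[v_{min} + \frac{v_{max} - v_{min}}{1 + \exp\!\big(10\,(0.5 - \tfrac{x - x_{min}}{\Delta})\big)} = v_{min} + \frac{v_{max} - v_{min}}{1 + \exp\!\big(\tfrac{10}{\Delta}\,(\tfrac{x_{min} + x_{max}}{2} - x)\big)}\]

which is S(x) with vmax = v_max - v_min, v0 = (x_min + x_max)/2 and r = 10/Δ. Under that reading, v_min + brainmass.sigmoid(x, vmax=v_max - v_min, v0=(x_min + x_max) / 2, r=10 / (x_max - x_min)) should give the same rates in a single call.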

Tips and Best Practices

Numerical Stability:

  • Use bounded_input to keep large drives from overflowing second-order blocks

  • Clip gradients when training oscillator dynamics

  • Check for NaN/Inf values in long simulations

Unit Safety:

  • Utilities work with both unitless arrays and brainunit.Quantity

  • Ensure consistent units when combining utility outputs with model states

Performance:

  • These utilities are plain JAX array functions; they are compiled when called inside jax.jit-decorated functions or scanned loops, so call them there rather than from Python loops

Debugging:

  • bounded_input silently squashes large values, so inspect the raw inputs when hunting for out-of-range values

  • process_sequence(..., mode='stack') keeps the full per-step history for inspection
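For the NaN/Inf check, a small helper can scan every leaf of a state or output PyTree. The name all_finite is hypothetical. Because it calls bool(...), it must run outside jax.jit, for example between simulation chunks. Inside jitted code, JAX's jax_debug_nans configuration flag is the usual alternative for NaNs.

```python
import jax
import jax.numpy as jnp


def all_finite(tree):
    # True only if every array leaf in the PyTree is free of NaN and Inf
    leaves = jax.tree_util.tree_leaves(tree)
    return all(bool(jnp.isfinite(leaf).all()) for leaf in leaves)


ok = all_finite({'x': jnp.ones(3), 'v': (jnp.zeros(2),)})
bad = all_finite({'x': jnp.array([1.0, jnp.nan])})
```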

See Also

  • Neural Mass Models - Neural mass models that use these utilities

  • JAX documentation for jax.numpy and jax.nn functions