Utility Functions

brainmass provides several utility functions for common operations in neural mass modeling.

API Reference

sys2nd

Second-order system dynamics.

sigmoid

Sigmoidal firing rate function.

bounded_input

Apply tanh bounding to input signal.

process_sequence

Stack a sequence of data items along a new dimension.

sys2nd

brainmass.sys2nd(A, a, u, x, v)

Second-order system dynamics.

Implements the derivative of a second-order linear system:

d²x/dt² + 2a·dx/dt + a²·x = A·a·u

Introducing the velocity v = dx/dt, the equation becomes:

dv/dt = A·a·u - 2·a·v - a²·x

Parameters:
  • A (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Amplitude gain parameter.

  • a (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Rate constant (inverse of a time constant, units of 1/time).

  • u (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Input signal.

  • x (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Position state (integrated output).

  • v (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Velocity state (derivative of position).

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity

Returns:

dv/dt - the acceleration (derivative of velocity).

Together with dx/dt = v, this turns the second-order differential equation into a first-order system that standard ODE integrators can step.

Mathematical Background:

A second-order ODE:

\[\ddot{x} = f(x, \dot{x}, t)\]

can be rewritten as a first-order system:

\[\begin{split}\dot{x} &= v \\ \dot{v} &= f(x, v, t)\end{split}\]
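For the specific system that sys2nd evaluates, the characteristic polynomial is a perfect square, which gives a closed-form picture of its behavior. The following is derived from the equation above, not taken from the brainmass source:

\[s^2 + 2as + a^2 = (s + a)^2 \quad\Rightarrow\quad H(s) = \frac{X(s)}{U(s)} = \frac{A\,a}{(s + a)^2}\]

\[h(t) = A\,a\,t\,e^{-a t}, \qquad x^{*} = H(0)\,u = \frac{A\,u}{a}\]

The double root at s = −a makes the unit critically damped: its impulse response h(t) rises and decays without oscillating, peaking at t = 1/a, and a constant input u settles at x* = A·u/a.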

Example:

import brainmass

# Second-order unit: x'' + 2a·x' + a²·x = A·a·u
A = 3.25    # amplitude gain
a = 100.0   # rate constant (1/time)
u = 1.0     # input signal

x = 0.0     # position
v = 0.0     # velocity

# Acceleration from the second-order equation
dv_dt = brainmass.sys2nd(A, a, u, x, v)

# The first-order system is dx/dt = v, dv/dt = sys2nd(...)
dx_dt = v

# Euler integration
dt = 1e-4
x_new = x + dx_dt * dt
v_new = v + dv_dt * dt
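To see the resulting dynamics over many steps, the sketch below re-implements the documented formula in NumPy and integrates a step input with forward Euler. sys2nd_ref and step_response are hypothetical helpers, not part of brainmass, and the parameter values are illustrative. For this critically damped system the output should rise to the steady state A·u/a without overshooting.

```python
import numpy as np

def sys2nd_ref(A, a, u, x, v):
    """Hypothetical copy of the documented formula: dv/dt = A·a·u - 2·a·v - a²·x."""
    return A * a * u - 2.0 * a * v - a**2 * x

def step_response(A, a, u, dt, n_steps):
    """Forward-Euler integration of the first-order system x' = v, v' = sys2nd."""
    x, v = 0.0, 0.0
    xs = np.empty(n_steps)
    for i in range(n_steps):
        dv = sys2nd_ref(A, a, u, x, v)
        # Both updates use the state from the start of the step
        x, v = x + v * dt, v + dv * dt
        xs[i] = x
    return xs

A, a, u = 3.25, 100.0, 1.0
# 2000 steps of 1e-4 cover 20 time constants (1/a = 0.01)
xs = step_response(A, a, u, dt=1e-4, n_steps=2000)

print(xs[-1], A * u / a)
```

The step size keeps a·dt = 0.01 small; much larger steps make explicit Euler inaccurate for stiff choices of a.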

sigmoid

brainmass.sigmoid(x, vmax, v0, r)

Sigmoidal firing rate function.

Converts membrane potential to firing rate using a sigmoid function.

S(x) = vmax / (1 + exp(r·(v0 - x)))

Parameters:
  • x (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Input membrane potential.

  • vmax (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Maximum firing rate.

  • v0 (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Firing threshold (potential at half-max rate).

  • r (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Steepness of the sigmoid.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity

Returns:

Firing rate in range (0, vmax).

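This is a scaled and shifted logistic function. Writing σ for the standard logistic:

Formula:

\[S(x) = v_{\max}\,\sigma\big(r\,(x - v_0)\big), \qquad \sigma(z) = \frac{1}{1 + e^{-z}}\]

Properties:
  • Range: (0, vmax)
  • S(v0) = vmax / 2
  • Smooth and differentiable everywhere, with maximum slope vmax·r / 4 at x = v0
  • With vmax = 1, v0 = 0, r = 1 it reduces to the standard logistic function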
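A quick numerical check of these properties, using a NumPy copy of the documented formula. sigmoid_ref is a hypothetical stand-in for brainmass.sigmoid, and the parameter values are illustrative only.

```python
import numpy as np

def sigmoid_ref(x, vmax, v0, r):
    """Hypothetical copy of the documented formula: vmax / (1 + exp(r·(v0 - x)))."""
    return vmax / (1.0 + np.exp(r * (v0 - x)))

vmax, v0, r = 5.0, 6.0, 0.56
x = np.linspace(-20.0, 30.0, 5001)   # grid spacing 0.01
s = sigmoid_ref(x, vmax, v0, r)

half = sigmoid_ref(v0, vmax, v0, r)  # value at the threshold
slope = np.gradient(s, x)            # numerical dS/dx

print(half, slope.max(), vmax * r / 4)
```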

Example:

import brainmass
import jax.numpy as jnp

x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])

# vmax=1, v0=0, r=1 reduces S to the standard logistic function
y = brainmass.sigmoid(x, 1.0, 0.0, 1.0)
# y ≈ [0.119, 0.269, 0.5, 0.731, 0.881]

# Common in firing rate models: membrane potential -> firing rate
def firing_rate(membrane_potential, threshold, gain, max_rate):
    return brainmass.sigmoid(membrane_potential, max_rate, threshold, gain)

bounded_input

brainmass.bounded_input(x, bound=500.0)

Apply tanh bounding to input signal.

Prevents numerical instability by limiting the magnitude of inputs to the second-order system.

Parameters:
  • x (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Input signal.

  • bound (float) – Maximum absolute value.

Returns:

bound * tanh(x / bound)

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity

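The result is a smooth, symmetric saturation: approximately x when |x| is much smaller than bound, approaching ±bound as |x| grows, and never exceeding bound in magnitude. Because it is symmetric about zero, it does not enforce one-sided ranges such as non-negative firing rates.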
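These properties can be checked with a NumPy copy of the documented formula. bounded_input_ref is a hypothetical stand-in for brainmass.bounded_input.

```python
import numpy as np

def bounded_input_ref(x, bound=500.0):
    """Hypothetical copy of the documented formula: bound * tanh(x / bound)."""
    return bound * np.tanh(x / bound)

x = np.array([-1e6, -2000.0, -1.0, 0.0, 1.0, 2000.0, 1e6])
y = bounded_input_ref(x)

# Relative deviation from the identity for a small input (x = 1)
small_error = abs(y[4] - x[4]) / abs(x[4])

print(y)
print(small_error)
```

The output is odd-symmetric and monotone, so the ordering of inputs is preserved while extreme values are compressed toward ±bound.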

Example:

import brainmass
import jax.numpy as jnp

x = jnp.array([-2000.0, -10.0, 0.0, 10.0, 2000.0])

# Default bound=500: small inputs pass almost unchanged,
# large inputs saturate smoothly toward ±500
y = brainmass.bounded_input(x)
# y ≈ [-499.7, -10.0, 0.0, 10.0, 499.7]

# Tighter bound for a signal expected to stay within about ±100
y_tight = brainmass.bounded_input(x, bound=100.0)

process_sequence

brainmass.process_sequence(data, mode='stack')

Stack a sequence of data items along a new dimension.

This is the inverse operation of slice_data: while slice_data reduces a dimension via aggregation, this function creates a new dimension by stacking multiple items together.

Returns:

  • mode='stack': Array/dict/tuple/list with a new dimension holding the stacked items

  • mode='last'/'first': Single item (same as data[-1] or data[0])

  • mode='avg'/'mean'/'max'/'min': Aggregated tensor/dict/tuple/list

  • mode=callable: Result of applying the callable to the stacked data

For dicts/tuples/lists, aggregation is applied recursively to each element.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity | Tuple | List | Dict | Sequence

Raises:
  • ValueError – If data is empty (cannot infer structure/type) or if mode is an unknown string.

  • TypeError – If sequence contains mixed or incompatible types, or if mode is not a string or callable.

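Parameters:
  • data (Sequence) – Non-empty sequence of items to combine: arrays, or dicts/tuples/lists of arrays with identical structure.

  • mode (str | Callable) – How to combine the items: 'stack', 'first', 'last', 'avg', 'mean', 'max', 'min', or a callable applied to the stacked data. Defaults to 'stack'.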

Notes

  • All elements in data must have the same type and structure

  • NumPy arrays are automatically converted to float32 tensors

  • Dictionary keys must match across all elements

  • Tuple/list lengths must match across all elements

  • Recursive: handles nested structures (e.g., dict of tuples)

Useful for collecting the per-step outputs of a simulation loop into a single array or nested structure, or for reducing them to a summary such as the mean or the final value.

Example:

import brainmass
import jax.numpy as jnp

# Per-step outputs collected in a Python list (e.g., from a simulation loop)
outputs = [jnp.full((5,), float(t)) for t in range(100)]

# Default mode='stack': one array with a new dimension of length 100
stacked = brainmass.process_sequence(outputs)

# Keep only the final step (same as outputs[-1])
final = brainmass.process_sequence(outputs, mode='last')

# Average over the sequence
mean = brainmass.process_sequence(outputs, mode='mean')

# Nested structures are combined element-wise
records = [{'x': jnp.zeros(3), 'v': jnp.ones(3)} for _ in range(10)]
trajectory = brainmass.process_sequence(records)
# trajectory is a dict with keys 'x' and 'v'
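The documented rules (recursion into dicts, tuples, and lists; selection for 'first' and 'last'; stacking followed by an optional reduction otherwise) can be summarized in a small NumPy sketch. process_sequence_sketch is hypothetical and is not brainmass's implementation; stacking along axis 0 and converting every leaf to float32 are assumptions (the Notes only state that NumPy arrays are converted to float32).

```python
import numpy as np

def process_sequence_sketch(data, mode="stack"):
    """Hypothetical re-implementation of the documented semantics (axis 0 assumed)."""
    if len(data) == 0:
        raise ValueError("cannot infer structure from an empty sequence")
    if not (isinstance(mode, str) or callable(mode)):
        raise TypeError("mode must be a string or a callable")
    first = data[0]
    # Recurse into dicts, tuples, and lists element by element
    if isinstance(first, dict):
        return {k: process_sequence_sketch([d[k] for d in data], mode) for k in first}
    if isinstance(first, (tuple, list)):
        parts = [process_sequence_sketch([d[i] for d in data], mode)
                 for i in range(len(first))]
        return type(first)(parts)
    # Leaf level: select one item, or stack and then optionally aggregate
    if mode == "first":
        return first
    if mode == "last":
        return data[-1]
    stacked = np.stack([np.asarray(d, dtype=np.float32) for d in data], axis=0)
    if callable(mode):
        return mode(stacked)
    reducers = {
        "stack": lambda s: s,
        "avg": lambda s: s.mean(axis=0),
        "mean": lambda s: s.mean(axis=0),
        "max": lambda s: s.max(axis=0),
        "min": lambda s: s.min(axis=0),
    }
    if mode not in reducers:
        raise ValueError(f"unknown mode: {mode!r}")
    return reducers[mode](stacked)

# Usage: a list of nested per-step records
steps = [{"x": np.full(3, t), "pair": (np.zeros(2), np.ones(2) * t)} for t in range(4)]
traj = process_sequence_sketch(steps)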

Common Usage Patterns

Building Custom Activation Functions

import brainmass
import jax.numpy as jnp

def threshold_linear(x, threshold=0.0, slope=1.0):
    """Rectified linear response with a custom threshold"""
    return jnp.maximum(slope * (x - threshold), 0.0)

def soft_threshold(x, threshold=0.0, gain=1.0):
    """Smooth threshold using the sigmoid, output in (0, 1)"""
    return brainmass.sigmoid(x, 1.0, threshold, gain)

Implementing Custom Second-Order Dynamics

import brainmass

class CustomOscillator:
    """Critically damped second-order unit driven by an external input.

    Integrates x'' + 2a·x' + a²·x = A·a·u with forward Euler.
    """

    def __init__(self, A, a, dt=0.001):
        self.A = A      # amplitude gain
        self.a = a      # rate constant (1/time)
        self.dt = dt    # integration step
        self.x = 0.0    # position (output)
        self.v = 0.0    # velocity

    def update(self, u):
        # Acceleration from the second-order equation
        dv_dt = brainmass.sys2nd(self.A, self.a, u, self.x, self.v)
        dx_dt = self.v

        # Integrate the equivalent first-order system (Euler)
        self.x = self.x + dx_dt * self.dt
        self.v = self.v + dv_dt * self.dt

        return self.x
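sys2nd fixes the damping term at 2a, so it covers only the critically damped case of the general oscillator ẍ + 2ζωẋ + ω²x = F. The NumPy check below (both helpers are hypothetical reference copies, not brainmass APIs) confirms that with ζ = 1, ω = a and F = A·a·u the two right-hand sides coincide; for other damping ratios, compute the acceleration directly as damped_oscillator_acc does.

```python
import numpy as np

def damped_oscillator_acc(x, v, force, omega, zeta):
    """General damped oscillator: x'' = F - 2·zeta·omega·x' - omega²·x."""
    return force - 2.0 * zeta * omega * v - omega**2 * x

def sys2nd_ref(A, a, u, x, v):
    """Hypothetical copy of the documented sys2nd formula: A·a·u - 2·a·v - a²·x."""
    return A * a * u - 2.0 * a * v - a**2 * x

A, a, u = 2.0, 50.0, 0.3
rng = np.random.default_rng(0)
x = rng.normal(size=5)
v = rng.normal(size=5)

# zeta = 1, omega = a, F = A·a·u reproduces sys2nd exactly
general = damped_oscillator_acc(x, v, force=A * a * u, omega=a, zeta=1.0)
special = sys2nd_ref(A, a, u, x, v)

# An underdamped choice differs whenever the velocity is nonzero
underdamped = damped_oscillator_acc(x, v, force=A * a * u, omega=a, zeta=0.2)
```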

Bounded Sigmoid for Physiological Ranges

import brainmass
import jax.numpy as jnp

def physiological_sigmoid(x, x_min, x_max, v_min, v_max):
    """Map input range [x_min, x_max] approximately onto output [v_min, v_max]"""
    # Normalize to [0, 1]
    x_norm = (x - x_min) / (x_max - x_min)

    # Logistic with midpoint 0.5 and steepness 10 for a smooth transition
    y_norm = brainmass.sigmoid(x_norm, 1.0, 0.5, 10.0)

    # Scale to output range
    return v_min + (v_max - v_min) * y_norm

# Example: membrane potential to firing rate
V = jnp.array([-70.0, -60.0, -55.0, -50.0, -40.0])  # mV
rates = physiological_sigmoid(V, x_min=-70, x_max=-40, v_min=0, v_max=100)
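Because the logistic never reaches exactly 0 or 1, this mapping only approximately attains v_min and v_max at the ends of the input range. A NumPy copy of the function makes the offset explicit; physiological_sigmoid_ref is hypothetical, with the steepness of 10 exposed as a parameter. At x_min and x_max the output misses its target by (v_max − v_min) / (1 + e^(steepness/2)), about 0.67 Hz for the example above.

```python
import numpy as np

def physiological_sigmoid_ref(x, x_min, x_max, v_min, v_max, steepness=10.0):
    """Hypothetical NumPy copy of physiological_sigmoid with explicit steepness."""
    x_norm = (x - x_min) / (x_max - x_min)
    y_norm = 1.0 / (1.0 + np.exp(-steepness * (x_norm - 0.5)))
    return v_min + (v_max - v_min) * y_norm

V = np.array([-70.0, -55.0, -40.0])  # mV: x_min, midpoint, x_max
rates = physiological_sigmoid_ref(V, -70.0, -40.0, 0.0, 100.0)

# Predicted gap between the output and the target at either endpoint
offset = 100.0 / (1.0 + np.exp(5.0))

print(rates, offset)
```

If exact endpoints matter, a larger steepness shrinks the gap at the cost of a sharper transition.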

Tips and Best Practices

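Numerical Stability:
  • Use bounded_input to keep the inputs that drive second-order dynamics within ±bound and avoid overflow
  • Clip gradients when training models with oscillatory dynamics
  • Check for NaN/Inf values in long simulations (for example with jnp.isfinite)

Unit Safety:
  • The utilities accept both unitless arrays and brainunit.Quantity
  • exp and tanh need dimensionless arguments, so r·(v0 − x) in sigmoid and x / bound in bounded_input should be unitless
  • Ensure consistent units when combining utility outputs with model states

Performance:
  • sys2nd, sigmoid, and bounded_input are element-wise array expressions; call them inside jax.jit-compiled functions (such as a model's update step) so they are compiled together with the rest of the computation

Debugging:
  • bounded_input hides out-of-range values rather than reporting them; during development, inspect inputs before bounding them
  • Use process_sequence to collect per-step outputs into arrays for inspection and plotting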

See Also

  • Neural Mass Models - Neural mass models that use these utilities

  • types - Type aliases for function signatures

  • JAX documentation for jax.numpy and jax.nn functions