Utilities & Types
brainmass provides several utility functions for common operations in neural mass
modeling, plus a handful of type aliases used throughout the public signatures.
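API Reference
- brainmass.sys2nd(A, a, u, x, v): Compute the acceleration of a second-order linear system.
- brainmass.sigmoid(x, vmax, v0, r): Convert membrane potential to firing rate via a sigmoid.
- brainmass.bounded_input(x, bound=500.0): Apply a tanh bound to an input signal.
- brainmass.process_sequence(data, mode='stack'): Aggregate a sequence of data items along the leading dimension.
- Build the neuron-index matrix used to address per-connection delays.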
delay_index() converts a delay (in time units) into an integer buffer index for
the circular delay buffers that delayed coupling reads from.
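The conversion that delay_index() performs can be illustrated with a small NumPy sketch. It assumes delays are rounded to the nearest whole multiple of the integration step dt and that the buffer slot holding the newest value is known; delay_to_steps and read_delayed are hypothetical helpers for illustration, not brainmass signatures.

```python
import numpy as np


def delay_to_steps(delay, dt):
    """Hypothetical helper: convert delays (time units) into whole integration steps."""
    return np.round(np.asarray(delay) / dt).astype(int)


def read_delayed(buffer, write_pos, steps):
    """Hypothetical helper: read the values written `steps` updates ago.

    `buffer` is a circular buffer of shape (buffer_len, ...) and `write_pos`
    is the slot that holds the most recent value.
    """
    buffer_len = buffer.shape[0]
    return buffer[(write_pos - steps) % buffer_len]


dt = 0.1                                        # ms
steps = delay_to_steps([0.0, 0.3, 1.0], dt)     # delays in ms -> [0, 3, 10] steps

# Write the step counter t = 0..19 into a circular buffer of length 16
buffer = np.zeros(16)
for t in range(20):
    write_pos = t % 16
    buffer[write_pos] = t

# Values written 0, 3 and 10 steps before the latest one (t = 19)
delayed = read_delayed(buffer, write_pos, steps)
```

The buffer must be longer than the largest delay in steps; otherwise the modulo wraps onto slots that have already been overwritten.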
Type Aliases
These aliases annotate the public model / coupling signatures. They are ordinary
typing constructs (not classes), so pass any value matching the alias.
- brainmass.Initializer
Union[Callable, brainstate.typing.ArrayLike] – a parameter/state initializer: either an array-like value or a callable shape -> array (e.g. a braintools.init initializer).
- brainmass.Array
brainstate.typing.ArrayLike – any array-like input (NumPy / JAX array, scalar, or a unit-carrying brainunit.Quantity).
- brainmass.Parameter
Union[Callable, brainstate.typing.ArrayLike, brainstate.nn.Param] – a model parameter: an array-like value, an initializer callable, or a wrapped Param (constrainable / trainable).
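To make the Initializer alias concrete, the sketch below shows how a consumer might turn either accepted form into an array. resolve_initializer is a hypothetical helper, and ArrayLike / InitializerLike are simplified stand-ins for the brainstate types; the real resolution logic in brainmass and brainstate may differ, and the Param branch of brainmass.Parameter is not covered.

```python
from typing import Callable, Union

import numpy as np

# Simplified stand-ins (assumptions), not the real brainstate/brainmass aliases
ArrayLike = Union[np.ndarray, float, int]
InitializerLike = Union[Callable, ArrayLike]


def resolve_initializer(init: InitializerLike, shape):
    """Hypothetical helper: materialize an Initializer-style value for `shape`."""
    if callable(init):
        # Callable form: shape -> array
        return np.asarray(init(shape))
    # Array-like form: broadcast the given value to the requested shape
    return np.broadcast_to(np.asarray(init), shape)


zeros = resolve_initializer(lambda shape: np.zeros(shape), (3,))
const = resolve_initializer(0.5, (3,))
```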
Model discovery
- brainmass.list_models(): Return the catalogue of public neural-mass models.
- ModelInfo: A typed catalogue record describing one public model.
list_models() returns a list of typed ModelInfo records – one per
public neural-mass model – giving each model’s name, category
(phenomenological / physiological / network), number of state variables, and a
one-line typical use case. It powers the howto/choose_a_model guide and is
pleasant to scan in the REPL via brainmass.list_models.to_table().
>>> import brainmass
>>> models = brainmass.list_models()
>>> {m.name for m in models} >= {'HopfStep', 'JansenRitStep'}
True
>>> next(m for m in models if m.name == 'JansenRitStep').category
'physiological'
>>> print(brainmass.list_models.to_table())
name...category...#states...use_case...
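Because every record exposes name and category (as the doctest shows), grouping the catalogue takes only a few lines. This sketch relies on nothing beyond those two attributes; ModelRecord is a hypothetical stand-in so the snippet runs without brainmass, and apart from JansenRitStep the model names are placeholders. With brainmass installed, one would pass brainmass.list_models() instead.

```python
from collections import defaultdict
from typing import NamedTuple


class ModelRecord(NamedTuple):
    """Hypothetical stand-in exposing the two ModelInfo attributes used below."""
    name: str
    category: str


def group_by_category(models):
    """Map each category to the sorted names of its models."""
    groups = defaultdict(list)
    for m in models:
        groups[m.category].append(m.name)
    return {cat: sorted(names) for cat, names in groups.items()}


# With brainmass installed: group_by_category(brainmass.list_models())
catalogue = [
    ModelRecord('JansenRitStep', 'physiological'),
    ModelRecord('ExampleModelA', 'phenomenological'),   # placeholder record
    ModelRecord('ExampleModelB', 'physiological'),      # placeholder record
]
groups = group_by_category(catalogue)
```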
sys2nd
- brainmass.sys2nd(A, a, u, x, v)
Compute the acceleration of a second-order linear system.
Implements the canonical second-order kinetic block used by neural-mass models (e.g. Jansen-Rit). The system
\[\frac{d^2 x}{dt^2} + 2 a \frac{dx}{dt} + a^2 x = A\,a\,u\]
is written in state-space form with v = dx/dt as
\[\frac{dv}{dt} = A\,a\,u - 2 a v - a^2 x .\]
- Parameters:
A (ArrayLike) – Amplitude gain parameter.
a (ArrayLike) – Time-constant parameter (units of inverse time).
u (ArrayLike) – Input signal.
x (ArrayLike) – Position state (the integrated output).
v (ArrayLike) – Velocity state (the derivative of x).
- Returns:
dv/dt, the acceleration, i.e. the derivative of the velocity state.
- Return type:
ArrayLike
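This converts a second-order differential equation into a first-order system that a standard ODE integrator can step.
Mathematical Background:
Introducing v = dx/dt turns the second-order ODE above into the first-order system
\[\frac{dx}{dt} = v, \qquad \frac{dv}{dt} = A\,a\,u - 2 a v - a^2 x .\]
sys2nd returns only the second equation; the caller advances x using v. The characteristic polynomial s^2 + 2as + a^2 = (s + a)^2 has a double root at -a, so the block is critically damped, with impulse response A a t e^{-a t}.
Example:
import brainmass
# Jansen-Rit-style kinetics (unitless here for brevity)
A, a = 3.25, 100.0   # amplitude gain and inverse time constant
u = 1.0              # constant input
x, v = 0.0, 0.0      # start at rest
dt = 1e-4
# Acceleration of the second-order block
dv_dt = brainmass.sys2nd(A, a, u, x, v)
# Forward-Euler step of dx/dt = v, dv/dt = dv_dt
x_new = x + v * dt
v_new = v + dv_dt * dt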
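As a cross-check of the formula (not of brainmass's implementation), the snippet below restates dv/dt in NumPy and integrates a constant input from rest. Setting both derivatives to zero gives a^2 x = A a u, i.e. a steady state x* = A u / a, which the integration should approach. sys2nd_reference and step_response are hypothetical helpers, and the constants are unitless stand-ins for the Jansen-Rit values (mV and 1/s).

```python
import numpy as np


def sys2nd_reference(A, a, u, x, v):
    """Hypothetical NumPy restatement of dv/dt = A*a*u - 2*a*v - a**2*x."""
    return A * a * u - 2.0 * a * v - a ** 2 * x


def step_response(A, a, u, dt=1e-4, n_steps=10_000):
    """Forward-Euler integration of dx/dt = v, dv/dt = sys2nd(...) from rest."""
    u = np.asarray(u, dtype=float)
    x = np.zeros_like(u)
    v = np.zeros_like(u)
    for _ in range(n_steps):
        dv_dt = sys2nd_reference(A, a, u, x, v)
        x, v = x + v * dt, v + dv_dt * dt
    return x, v


A, a = 3.25, 100.0                  # unitless stand-ins for mV and 1/s
u = np.array([0.5, 1.0, 2.0])       # three constant inputs, integrated in parallel
x_final, v_final = step_response(A, a, u)
x_expected = A * u / a              # steady state from a**2 * x = A * a * u
```

With dt·a = 0.01 the Euler update matrix has a double eigenvalue of 1 - a·dt = 0.99, so after 10 000 steps the transient has decayed far below floating-point precision.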
sigmoid
- brainmass.sigmoid(x, vmax, v0, r)
Convert membrane potential to firing rate via a sigmoid.
\[S(x) = \frac{v_{max}}{1 + \exp\!\big(r (v_0 - x)\big)}\]
- Parameters:
x (ArrayLike) – Input membrane potential.
vmax (ArrayLike) – Maximum firing rate.
v0 (ArrayLike) – Firing threshold (potential at half-maximum rate).
r (ArrayLike) – Steepness of the sigmoid.
- Returns:
Firing rate in the open interval (0, vmax).
- Return type:
ArrayLike
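This is the logistic function rescaled to the range (0, vmax) and centred on the threshold v0.
Formula:
\[S(x) = v_{max}\,\sigma\big(r (x - v_0)\big), \qquad \sigma(z) = \frac{1}{1 + e^{-z}}\]
Properties:
- Range: (0, vmax), with S(v0) = vmax / 2
- Smooth, differentiable everywhere, and increasing in x for r > 0
- With vmax = 1, v0 = 0, r = 1 it reduces to the standard logistic function
Example:
import brainmass
import jax.numpy as jnp
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
# vmax=1, v0=0, r=1 gives the standard logistic function
y = brainmass.sigmoid(x, vmax=1.0, v0=0.0, r=1.0)
# y ≈ [0.119, 0.269, 0.5, 0.731, 0.881]
# Common in firing rate models
def firing_rate(membrane_potential, threshold, gain, max_rate):
    return brainmass.sigmoid(membrane_potential, vmax=max_rate, v0=threshold, r=gain)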
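A NumPy restatement of the formula makes these properties easy to verify numerically: the value at the threshold is vmax / 2, the output stays inside (0, vmax), and the curve equals vmax times a logistic function of r (x - v0). sigmoid_reference is a hypothetical helper; the constants are the classic Jansen-Rit values, used here without units.

```python
import numpy as np


def sigmoid_reference(x, vmax, v0, r):
    """Hypothetical NumPy restatement of S(x) = vmax / (1 + exp(r * (v0 - x)))."""
    return vmax / (1.0 + np.exp(r * (v0 - x)))


# Jansen-Rit constants: vmax = 5 (1/s), v0 = 6 (mV), r = 0.56 (1/mV)
vmax, v0, r = 5.0, 6.0, 0.56
x = np.linspace(-20.0, 30.0, 11)
rates = sigmoid_reference(x, vmax, v0, r)

# Half-maximum rate at the threshold
half_max = sigmoid_reference(v0, vmax, v0, r)
# The same curve written as a scaled logistic of r * (x - v0)
logistic = vmax / (1.0 + np.exp(-r * (x - v0)))
```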
bounded_input
- brainmass.bounded_input(x, bound=500.0)
Apply a tanh bound to an input signal. Prevents numerical instability by smoothly limiting the magnitude of the inputs fed to a second-order system.
The output magnitude is limited by bound (default 500.0). Unlike a hard clip, tanh saturates smoothly, so the output stays differentiable near the limit. The function takes a single symmetric bound rather than separate lower and upper limits.
Example:
import brainmass
import jax.numpy as jnp
# Inputs spanning several orders of magnitude
u_in = jnp.array([-5000.0, -10.0, 0.0, 10.0, 5000.0])
u_safe = brainmass.bounded_input(u_in, bound=500.0)
# Every entry of u_safe lies within [-500, 500]; large inputs saturate smoothly
# A tighter bound for a more sensitive system
u_tight = brainmass.bounded_input(u_in, bound=50.0)
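The exact expression inside bounded_input is not shown on this page. The sketch below assumes the common form bound * tanh(x / bound), which is close to the identity for |x| much smaller than bound and saturates at ±bound; tanh_bound is a hypothetical stand-in. It compares the result with a hard np.clip, whose slope is exactly zero beyond the limit.

```python
import numpy as np


def tanh_bound(x, bound=500.0):
    """Hypothetical smooth limiter, ASSUMED form bound * tanh(x / bound)."""
    return bound * np.tanh(np.asarray(x, dtype=float) / bound)


x = np.array([-5000.0, -10.0, 0.0, 10.0, 5000.0])
soft = tanh_bound(x)                   # magnitudes stay below 500
hard = np.clip(x, -500.0, 500.0)       # hard clip: slope 0 outside [-500, 500]

# Slope of the assumed soft limiter: d/dx [b * tanh(x / b)] = 1 - tanh(x / b)**2
soft_slope = 1.0 - np.tanh(x / 500.0) ** 2
```

Under this assumption small inputs pass almost unchanged, while the slope beyond the bound shrinks but stays positive, which is what keeps gradient-based fitting informative.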
process_sequence
- brainmass.process_sequence(data, mode='stack')
Aggregate a sequence of data items along the leading dimension.
The sequence is first stacked along a new leading axis and then reduced according to mode. The reduction is applied recursively over arbitrary PyTrees (dicts, tuples, lists, and nested combinations thereof).
- Parameters:
data (PyTree) – A stacked PyTree of items, with the items enumerated along the leading axis of every leaf. All leaves must share a compatible structure.
mode (str or callable) – How to reduce along the leading axis:
'stack': return the data unchanged (identity).
'last' / 'first': take the last / first item.
'avg' / 'mean': mean over the leading axis.
'max' / 'min': max / min over the leading axis.
callable: apply the callable to each leaf.
- Returns:
Aggregated data whose structure matches the input leaves:
mode='stack': structure with the new leading dimension retained.
mode='last' / 'first': a single item.
mode='avg' / 'mean' / 'max' / 'min': reduced PyTree.
mode callable: the result of applying the callable.
- Return type:
ArrayLike | Tuple | List | Dict | Sequence
- Raises:
ValueError – If mode is an unknown string.
Notes
All leaves must share the same type and structure.
Dictionary keys (and tuple/list lengths) must match across elements.
The reduction is recursive and handles nested structures, e.g. a dict of tuples of arrays.
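A typical use is summarizing a recorded history, such as the stacked outputs of a simulation loop, into one value per leaf.
Example:
import brainmass
import jax.numpy as jnp
# A stacked history: every leaf has time as its leading axis, shape (time_steps, in_size)
history = {'x': jnp.ones((100, 5)), 'y': jnp.zeros((100, 5))}
last = brainmass.process_sequence(history, mode='last')   # each leaf: shape (5,)
mean = brainmass.process_sequence(history, mode='mean')   # each leaf: shape (5,)
peak = brainmass.process_sequence(history, mode='max')    # each leaf: shape (5,)
full = brainmass.process_sequence(history, mode='stack')  # unchanged: shape (100, 5)
# Such a history typically comes from stepping a model over time, e.g.
# history = brainstate.transform.for_loop(
#     lambda t: model.update(input_sequence[t]),
#     jnp.arange(100)
# )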
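The recursive reduction described under Parameters and Notes can be sketched in plain Python and NumPy. reduce_leading is a hypothetical re-implementation for illustration, assuming leaves are arrays whose leading axis enumerates the items; brainmass's own version may differ in details such as how it treats namedtuples or non-array leaves.

```python
import numpy as np

# Reductions along the leading axis, keyed like process_sequence's string modes
_REDUCERS = {
    'stack': lambda arr: arr,
    'first': lambda arr: arr[0],
    'last': lambda arr: arr[-1],
    'avg': lambda arr: np.mean(arr, axis=0),
    'mean': lambda arr: np.mean(arr, axis=0),
    'max': lambda arr: np.max(arr, axis=0),
    'min': lambda arr: np.min(arr, axis=0),
}


def reduce_leading(data, mode='stack'):
    """Hypothetical sketch: reduce every leaf of a nested dict/tuple/list."""
    if isinstance(data, dict):
        return {k: reduce_leading(v, mode) for k, v in data.items()}
    if isinstance(data, (tuple, list)):
        # Rebuild the same container type around the reduced children
        return type(data)(reduce_leading(v, mode) for v in data)
    if callable(mode):
        return mode(data)
    if mode not in _REDUCERS:
        raise ValueError(f'Unknown mode: {mode!r}')
    return _REDUCERS[mode](np.asarray(data))


# Three stacked items: a dict holding an array and a tuple of arrays
history = {
    'x': np.arange(6.0).reshape(3, 2),
    'pair': (np.array([1.0, 5.0, 3.0]), np.zeros((3, 4))),
}
last = reduce_leading(history, 'last')
peak = reduce_leading(history, 'max')
```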
Common Usage Patterns
Building Custom Activation Functions
import brainmass
import jax.numpy as jnp
def threshold_linear(x, threshold=0.0, slope=1.0, bound=500.0):
    """Rectified linear with custom threshold, smoothly capped by `bound`"""
    rectified = jnp.maximum(slope * (x - threshold), 0.0)
    return brainmass.bounded_input(rectified, bound=bound)
def soft_threshold(x, threshold=0.0, gain=1.0):
    """Smooth threshold using sigmoid, output in (0, 1)"""
    return brainmass.sigmoid(x, vmax=1.0, v0=threshold, r=gain)
Implementing Custom Second-Order Dynamics
import brainmass
class SecondOrderKernel:
    """x'' + 2a x' + a^2 x = A a u, stepped with forward Euler.

    sys2nd fixes the damping term at 2a (critical damping); a general
    damping ratio needs its own acceleration formula.
    """
    def __init__(self, A, a, dt=1e-4):
        self.A = A      # amplitude gain
        self.a = a      # inverse time constant
        self.dt = dt    # integration step
        self.x = 0.0    # position (kernel output)
        self.v = 0.0    # velocity dx/dt
    def update(self, u):
        # Acceleration of the second-order block
        dv_dt = brainmass.sys2nd(self.A, self.a, u, self.x, self.v)
        # Forward-Euler step of dx/dt = v, dv/dt = dv_dt
        self.x = self.x + self.v * self.dt
        self.v = self.v + dv_dt * self.dt
        return self.x
Bounded Sigmoid for Physiological Ranges
import brainmass
import jax.numpy as jnp
def physiological_sigmoid(x, x_min, x_max, v_min, v_max, steepness=10.0):
    """Map input range [x_min, x_max] smoothly onto output (v_min, v_max)"""
    # Normalize to [0, 1]
    x_norm = (x - x_min) / (x_max - x_min)
    # Unit-height sigmoid centred at the middle of the input range
    y_norm = brainmass.sigmoid(x_norm, vmax=1.0, v0=0.5, r=steepness)
    # Scale to output range (the endpoints are approached, not reached exactly)
    return v_min + (v_max - v_min) * y_norm
# Example: membrane potential to firing rate
V = jnp.array([-70.0, -60.0, -55.0, -50.0, -40.0])  # mV
rates = physiological_sigmoid(V, x_min=-70.0, x_max=-40.0, v_min=0.0, v_max=100.0)
Tips and Best Practices
Numerical Stability:
- Use bounded_input to prevent overflow/underflow
- Clip gradients for oscillator dynamics
- Check for NaN/Inf values in long simulations
Unit Safety:
- Utilities work with both unitless arrays and brainunit.Quantity
- Ensure consistent units when combining utility outputs with model states
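Performance:
- These utilities are JAX array functions, so they compile together with the code that calls them
- Call them inside jax.jit-decorated functions (or brainstate transforms) rather than eagerly in Python loops
Debugging:
- bounded_input silently squashes out-of-range values, so bypass it temporarily when tracking down the source of an instability
- process_sequence with mode='max' or mode='min' quickly summarizes a recorded history, which helps spot runaway values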
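For the NaN/Inf tip above, a small helper can locate the first bad step in a recorded trajectory. first_nonfinite_step is a hypothetical NumPy helper that assumes time is the leading axis; after a jitted JAX run the same test can be applied with jnp.isfinite to the returned history.

```python
import numpy as np


def first_nonfinite_step(trajectory):
    """Return the first time index (axis 0) containing NaN or Inf, or None if all finite."""
    traj = np.asarray(trajectory)
    # Flatten everything except time, then flag rows with any non-finite entry
    bad = ~np.isfinite(traj.reshape(traj.shape[0], -1)).all(axis=1)
    idx = np.flatnonzero(bad)
    return int(idx[0]) if idx.size else None


# Toy trajectory of shape (time_steps, n_regions) that blows up at step 3
traj = np.array([[0.0, 1.0], [0.5, 2.0], [1.0, 4.0], [np.inf, 8.0], [np.nan, np.nan]])
```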
See Also
Neural Mass Models - Neural mass models that use these utilities
JAX documentation for jax.numpy and jax.nn functions