Utility Functions
brainmass provides several utility functions for common operations in neural mass modeling.
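API Reference
- sys2nd(A, a, u, x, v) – Second-order system dynamics.
- sigmoid(x, vmax, v0, r) – Sigmoidal firing rate function.
- bounded_input(x, bound=500.0) – Apply tanh bounding to input signal.
- process_sequence(data, mode='stack') – Stack a sequence of data items along a new dimension.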
sys2nd
- brainmass.sys2nd(A, a, u, x, v)
Second-order system dynamics.
- Implements the right-hand side of the second-order linear system:
d²x/dt² + 2a·dx/dt + a²·x = A·a·u
- With v = dx/dt, this can be written as:
dv/dt = A·a·u - 2·a·v - a²·x
and sys2nd returns dv/dt.
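Solving this linear ODE shows what the two parameters control. The derivation below is our own worked example (not taken from the brainmass docs), for a constant input u and zero initial state x(0) = v(0) = 0:

```latex
\ddot{x} + 2a\,\dot{x} + a^{2}x = A\,a\,u,
\qquad s^{2} + 2a\,s + a^{2} = (s + a)^{2}
\;\Rightarrow\; s = -a \ \text{(double root: critically damped, no overshoot)}

h(t) = A\,a\,t\,e^{-a t} \qquad \text{(impulse response, peak } A/e \text{ at } t = 1/a\text{)}

x(t) = \frac{A u}{a}\left(1 - e^{-a t} - a t\,e^{-a t}\right)
\;\xrightarrow{\,t \to \infty\,}\; \frac{A u}{a}
```

So a sets the speed of the response (the impulse response peaks at t = 1/a) and A its amplitude; for a constant input the steady-state gain is A/a.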
- Parameters:
A (Array | ndarray | bool | number | int | float | complex | Quantity) – Amplitude gain parameter.
a (Array | ndarray | bool | number | int | float | complex | Quantity) – Rate constant, i.e. the inverse of a time constant (units of 1/time).
u (Array | ndarray | bool | number | int | float | complex | Quantity) – Input signal.
x (Array | ndarray | bool | number | int | float | complex | Quantity) – Position state (integrated output).
v (Array | ndarray | bool | number | int | float | complex | Quantity) – Velocity state (derivative of position).
- Return type:
Array | ndarray | bool | number | int | float | complex | Quantity
- Returns:
dv/dt, the acceleration (derivative of velocity).
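Together with dx/dt = v, this turns the second-order differential equation into a first-order system that standard ODE integrators can step.
Mathematical Background:
A second-order ODE:
d²x/dt² = A·a·u - 2·a·dx/dt - a²·x
can be rewritten as a first-order system:
dx/dt = v
dv/dt = A·a·u - 2·a·v - a²·x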
Example:
import brainmass

# Critically damped second-order kernel driven by a constant input
A = 3.25     # amplitude gain
a = 100.0    # rate constant (1/s)
u = 1.0      # input signal
x = 0.0      # position state
v = 0.0      # velocity state

# Euler integration of the first-order system dx/dt = v, dv/dt = sys2nd(...)
dt = 1e-4
for _ in range(1000):
    dv_dt = brainmass.sys2nd(A, a, u, x, v)
    x = x + v * dt
    v = v + dv_dt * dt
# x approaches the steady state A·u/a
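To check the formula independently of brainmass, the sketch below restates dv/dt = A·a·u − 2·a·v − a²·x in plain Python, integrates it with forward Euler, and compares the result with the closed-form step response x(t) = (A·u/a)(1 − e^(−a·t) − a·t·e^(−a·t)) for constant u and zero initial state. The helper names sys2nd_ref, simulate_step and step_response are hypothetical, and the parameter values are illustrative.

```python
import math


def sys2nd_ref(A, a, u, x, v):
    """Plain-Python restatement of the documented formula:
    dv/dt = A·a·u - 2·a·v - a²·x (acceleration of the position state x)."""
    return A * a * u - 2.0 * a * v - a**2 * x


def simulate_step(A, a, u, dt, n_steps):
    """Forward-Euler integration of dx/dt = v, dv/dt = sys2nd_ref(...) from rest."""
    x, v = 0.0, 0.0
    for _ in range(n_steps):
        dv_dt = sys2nd_ref(A, a, u, x, v)
        # Both updates use the state from the start of the step
        x, v = x + v * dt, v + dv_dt * dt
    return x


def step_response(A, a, u, t):
    """Closed-form solution for constant u and x(0) = v(0) = 0."""
    return (A * u / a) * (1.0 - math.exp(-a * t) - a * t * math.exp(-a * t))


A, a, u = 3.25, 100.0, 1.0
dt, n_steps = 1e-5, 2000  # simulate up to t = 0.02
x_num = simulate_step(A, a, u, dt, n_steps)
x_exact = step_response(A, a, u, dt * n_steps)
print(f"Euler: {x_num:.6f}  exact: {x_exact:.6f}  steady state: {A * u / a:.6f}")
```

With a·dt = 10⁻³ the two values agree closely. Forward Euler on this system becomes unstable once a·dt reaches 2, so the step size should stay well below 1/a.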
sigmoid
- brainmass.sigmoid(x, vmax, v0, r)
Sigmoidal firing rate function.
Converts membrane potential to firing rate using a sigmoid function.
S(x) = vmax / (1 + exp(r·(v0 - x)))
- Parameters:
x (Array | ndarray | bool | number | int | float | complex | Quantity) – Input membrane potential.
vmax (Array | ndarray | bool | number | int | float | complex | Quantity) – Maximum firing rate.
v0 (Array | ndarray | bool | number | int | float | complex | Quantity) – Firing threshold (potential at half-max rate).
r (Array | ndarray | bool | number | int | float | complex | Quantity) – Steepness of the sigmoid.
- Return type:
Array | ndarray | bool | number | int | float | complex | Quantity
- Returns:
Firing rate in range (0, vmax).
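With vmax = 1, v0 = 0 and r = 1 this reduces to the standard logistic function 1 / (1 + exp(-x)).
Properties:
- Range: (0, vmax), with S(v0) = vmax / 2
- Smooth and differentiable everywhere; the slope is steepest at x = v0, where it equals vmax·r/4
- Used to map membrane potentials to firing rates in neural mass models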
Example:
import brainmass
import jax.numpy as jnp

x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
# vmax=1, v0=0, r=1 gives the standard logistic curve
y = brainmass.sigmoid(x, vmax=1.0, v0=0.0, r=1.0)
# y ≈ [0.119, 0.269, 0.5, 0.731, 0.881]

# Common in firing rate models: the threshold maps to v0, the gain to r
def firing_rate(membrane_potential, threshold, gain, max_rate):
    return brainmass.sigmoid(membrane_potential, vmax=max_rate, v0=threshold, r=gain)
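A scalar, plain-Python restatement of the documented formula makes its properties easy to check. The slope follows from differentiating S(x) = vmax / (1 + exp(r·(v0 − x))), and the inverse from solving that formula for x. The helpers sigmoid_ref, sigmoid_slope and sigmoid_inverse are hypothetical, and the parameter values are illustrative.

```python
import math


def sigmoid_ref(x, vmax, v0, r):
    """Scalar restatement of S(x) = vmax / (1 + exp(r·(v0 - x))).
    Not overflow-safe for extreme r·(v0 - x); fine for a sketch."""
    return vmax / (1.0 + math.exp(r * (v0 - x)))


def sigmoid_slope(x, vmax, v0, r):
    """Analytic derivative dS/dx = r·S·(1 - S/vmax)."""
    s = sigmoid_ref(x, vmax, v0, r)
    return r * s * (1.0 - s / vmax)


def sigmoid_inverse(rate, vmax, v0, r):
    """Potential that produces a given rate, valid for 0 < rate < vmax."""
    return v0 - math.log(vmax / rate - 1.0) / r


vmax, v0, r = 5.0, 6.0, 0.56
print(sigmoid_ref(v0, vmax, v0, r))       # half-maximal rate at the threshold
print(sigmoid_slope(v0, vmax, v0, r))     # steepest slope, equal to vmax·r/4
print(sigmoid_inverse(2.5, vmax, v0, r))  # recovers v0
```

The inverse is handy when a target firing rate is known and the corresponding potential is needed, for example to choose initial conditions.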
bounded_input
- brainmass.bounded_input(x, bound=500.0)
Apply tanh bounding to input signal.
Prevents numerical instability by limiting the magnitude of inputs to the second-order system.
The bound argument (default 500.0) sets the magnitude limit applied by the tanh bounding.
Example:
import brainmass
import jax.numpy as jnp

# Apply tanh bounding with an explicit magnitude limit
u = jnp.array([-1000.0, -10.0, 0.0, 10.0, 1000.0])
u_bounded = brainmass.bounded_input(u, bound=500.0)

# Typical use: bound the drive before it enters the second-order system
A, a, x, v = 3.25, 100.0, 0.0, 0.0
dv_dt = brainmass.sys2nd(A, a, brainmass.bounded_input(u), x, v)
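The docstring does not spell out the exact formula. A common way to apply tanh bounding with a magnitude limit is bound·tanh(x/bound), which behaves almost like the identity for |x| much smaller than bound and saturates smoothly at ±bound. The sketch below assumes that form; tanh_bound is a hypothetical name, not part of brainmass, so check the brainmass source before relying on the exact curve.

```python
import math


def tanh_bound(x, bound=500.0):
    """Hypothetical soft clipping (assumed form bound·tanh(x/bound)):
    close to x for |x| << bound, never beyond ±bound."""
    return bound * math.tanh(x / bound)


for value in (1.0, 100.0, 1_000.0, 1e6):
    print(value, tanh_bound(value))
```

Unlike hard clipping, this keeps the mapping smooth and differentiable, which matters when gradients flow through the input.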
process_sequence
- brainmass.process_sequence(data, mode='stack')
Stack a sequence of data items along a new dimension.
With mode='stack' it is the inverse of slice_data: while slice_data reduces a dimension via aggregation, process_sequence creates a new dimension by stacking multiple items together. The other modes select a single item or aggregate over the sequence instead.
- Parameters:
data (Sequence) – Items to combine: arrays, or dicts/tuples/lists of arrays with matching structure.
mode (str | callable) – One of 'stack', 'last', 'first', 'avg', 'mean', 'max', 'min', or a callable applied to the stacked data. Defaults to 'stack'.
- Returns:
mode='stack': Array/dict/tuple/list with a new stacking dimension
mode='last'/'first': Single item (same as data[-1] or data[0])
mode='avg'/'mean'/'max'/'min': Aggregated tensor/dict/tuple/list
mode=callable: Result of applying the callable to the stacked data
For dicts/tuples/lists, aggregation is applied recursively to each element.
- Return type:
Array | ndarray | bool | number | int | float | complex | Quantity | Tuple | List | Dict | Sequence
- Raises:
ValueError – If data is empty (cannot infer structure/type) or if mode is an unknown string.
TypeError – If the sequence contains mixed or incompatible types, or if mode is neither a string nor a callable.
Notes
- All elements in data must have the same type and structure.
- NumPy arrays are automatically converted to float32 tensors.
- Dictionary keys must match across all elements.
- Tuple/list lengths must match across all elements.
- Recursive: handles nested structures (e.g., dict of tuples).
A typical use is collecting the per-step outputs of a simulation loop into a single structure.
Example:
import brainmass
import jax.numpy as jnp

# Per-step outputs, e.g. recorded from a simulation loop
seq = [jnp.full((5,), float(t)) for t in range(100)]

stacked = brainmass.process_sequence(seq, mode='stack')  # stacked along a new dimension
last = brainmass.process_sequence(seq, mode='last')      # same as seq[-1]
avg = brainmass.process_sequence(seq, mode='mean')       # aggregated over the sequence

# Nested structures are handled recursively; dict keys must match across items
records = [{'x': jnp.ones(3) * t, 'v': jnp.zeros(3)} for t in range(10)]
stacked_records = brainmass.process_sequence(records, mode='stack')

# A callable mode receives the stacked data
peak = brainmass.process_sequence(seq, mode=lambda s: s.max())
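The dispatch rules listed above (select, stack, aggregate, recurse into containers) can be modelled without arrays. The sketch below is a hypothetical pure-Python stand-in named combine, not the brainmass implementation: numbers play the role of arrays, a Python list stands in for a stacked array, and it assumes a callable mode is applied at each leaf, which the docs do not spell out.

```python
from statistics import mean

# String modes that reduce a list of leaf values; 'stack' just collects them
_REDUCERS = {"stack": list, "avg": mean, "mean": mean, "max": max, "min": min}


def combine(data, mode="stack"):
    """Hypothetical pure-Python model of process_sequence's documented rules.

    Leaves are plain numbers. Dicts, tuples and lists are combined
    recursively, one field or position at a time.
    """
    if not (isinstance(mode, str) or callable(mode)):
        raise TypeError("mode must be a string or a callable")
    if isinstance(mode, str) and mode not in _REDUCERS and mode not in ("first", "last"):
        raise ValueError(f"unknown mode: {mode!r}")
    data = list(data)
    if not data:
        raise ValueError("cannot infer structure from an empty sequence")
    if mode == "last":
        return data[-1]
    if mode == "first":
        return data[0]

    first = data[0]
    if isinstance(first, dict):
        # Dictionary keys must match across all elements
        if any(not isinstance(item, dict) or item.keys() != first.keys() for item in data):
            raise TypeError("all items must be dicts with the same keys")
        return {key: combine([item[key] for item in data], mode) for key in first}
    if isinstance(first, (tuple, list)):
        # Container type and length must match across all elements
        if any(type(item) is not type(first) or len(item) != len(first) for item in data):
            raise TypeError("all items must be tuples/lists of the same length")
        parts = [combine([item[i] for item in data], mode) for i in range(len(first))]
        return type(first)(parts)
    if not all(isinstance(item, (int, float)) for item in data):
        raise TypeError("mixed or unsupported leaf types")
    # Leaf level: apply the reducer (or the user's callable) to the collected values
    return mode(data) if callable(mode) else _REDUCERS[mode](data)


steps = [{"x": 1.0, "v": (0.0, 2.0)}, {"x": 3.0, "v": (4.0, 6.0)}]
print(combine(steps, "stack"))  # {'x': [1.0, 3.0], 'v': ([0.0, 4.0], [2.0, 6.0])}
print(combine(steps, "mean"))   # {'x': 2.0, 'v': (2.0, 4.0)}
```

The recursion mirrors the Notes: structure is checked once per level, and the actual reduction only happens at the leaves.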
Common Usage Patterns
Building Custom Activation Functions
import brainmass
import jax.numpy as jnp

def threshold_linear(x, threshold=0.0, slope=1.0, bound=500.0):
    """Rectified linear with custom threshold; bounded_input limits the output magnitude"""
    return brainmass.bounded_input(
        slope * jnp.maximum(x - threshold, 0.0),
        bound=bound,
    )

def soft_threshold(x, threshold=0.0, gain=1.0):
    """Smooth threshold using sigmoid, output in (0, 1)"""
    return brainmass.sigmoid(x, vmax=1.0, v0=threshold, r=gain)
Implementing Custom Second-Order Dynamics
import brainmass

class CustomOscillator:
    """Second-order unit ẍ + 2ζωẋ + ω²x = F, restricted to the case sys2nd
    covers: ζ = 1 (critical damping), ω = a and F = A·a·u."""

    def __init__(self, A, a, dt=1e-4):
        self.A = A      # amplitude gain
        self.a = a      # rate constant (1/time)
        self.dt = dt    # Euler step size
        self.x = 0.0    # position state
        self.v = 0.0    # velocity state

    def update(self, u):
        # Limit the drive before it enters the second-order system
        u = brainmass.bounded_input(u)
        # Acceleration dv/dt = A·a·u - 2·a·v - a²·x
        dv_dt = brainmass.sys2nd(self.A, self.a, u, self.x, self.v)
        # Integrate the first-order system (Euler)
        self.x = self.x + self.v * self.dt
        self.v = self.v + dv_dt * self.dt
        return self.x
Bounded Sigmoid for Physiological Ranges
import brainmass
import jax.numpy as jnp

def physiological_sigmoid(x, x_min, x_max, v_min, v_max):
    """Smoothly map inputs in [x_min, x_max] onto (v_min, v_max)"""
    # Normalize to [0, 1]
    x_norm = (x - x_min) / (x_max - x_min)
    # Logistic curve centred at 0.5 with steepness 10
    y_norm = brainmass.sigmoid(x_norm, vmax=1.0, v0=0.5, r=10.0)
    # Scale to output range
    return v_min + (v_max - v_min) * y_norm

# Example: membrane potential to firing rate
V = jnp.array([-70.0, -60.0, -55.0, -50.0, -40.0])  # mV
rates = physiological_sigmoid(V, x_min=-70.0, x_max=-40.0, v_min=0.0, v_max=100.0)
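Because the normalization is linear, the same curve can come from a single sigmoid call whose parameters absorb it: 10·(x_norm − 0.5) = r·(x − v0) with v0 = (x_min + x_max)/2 and r = 10/(x_max − x_min), and the output scale becomes vmax = v_max − v_min. The plain-Python sketch below checks that equivalence numerically; sigmoid_ref restates the documented sigmoid formula, and all function names are hypothetical.

```python
import math


def sigmoid_ref(x, vmax, v0, r):
    # Scalar restatement of S(x) = vmax / (1 + exp(r·(v0 - x)))
    return vmax / (1.0 + math.exp(r * (v0 - x)))


def physiological_sigmoid_normalized(x, x_min, x_max, v_min, v_max):
    # Same steps as physiological_sigmoid: normalize, squash, rescale
    x_norm = (x - x_min) / (x_max - x_min)
    y_norm = sigmoid_ref(x_norm, 1.0, 0.5, 10.0)
    return v_min + (v_max - v_min) * y_norm


def physiological_sigmoid_direct(x, x_min, x_max, v_min, v_max):
    # Fold the normalization into the sigmoid's own parameters
    v0 = 0.5 * (x_min + x_max)   # midpoint of the input range
    r = 10.0 / (x_max - x_min)   # steepness 10 per unit of normalized input
    return v_min + sigmoid_ref(x, v_max - v_min, v0, r)


for V in (-70.0, -60.0, -55.0, -50.0, -40.0):
    y_norm_path = physiological_sigmoid_normalized(V, -70.0, -40.0, 0.0, 100.0)
    y_direct = physiological_sigmoid_direct(V, -70.0, -40.0, 0.0, 100.0)
    print(V, round(y_norm_path, 4), round(y_direct, 4))
```

Note that the outputs at x_min and x_max only approach v_min and v_max; a larger steepness pushes them closer at the cost of a sharper transition.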
Tips and Best Practices
Numerical Stability:
- Use bounded_input to keep inputs to the second-order dynamics from growing large enough to overflow
- Clip gradients when training models with oscillator dynamics
- Check for NaN/Inf values in long simulations
Unit Safety:
- Utilities work with both unitless arrays and brainunit.Quantity
- Ensure consistent units when combining utility outputs with model states
Performance:
- These utilities are built from JAX operations, so they can be traced and compiled
- Call them inside jax.jit-decorated functions to get compiled code
Debugging:
- bounded_input silently squashes large values, so during development check inputs for out-of-range values before bounding rather than relying on it to reveal them
- process_sequence(..., mode='stack') collects per-step outputs into one structure, which makes sequential models easier to inspect
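For the NaN/Inf check listed under Numerical Stability, scanning a recorded trajectory is often enough. The sketch below is illustrative only: first_nonfinite_step and euler_trajectory are hypothetical helpers, and the update reuses the sys2nd formula in plain Python so the example runs without brainmass. The coarse step dt = 5e-2 (a·dt = 5) is chosen deliberately to make forward Euler diverge.

```python
import math


def first_nonfinite_step(trajectory):
    """Return the index of the first NaN/Inf value in a 1-D trajectory, or None."""
    for i, value in enumerate(trajectory):
        if not math.isfinite(value):
            return i
    return None


def euler_trajectory(A, a, u, dt, n_steps):
    """Euler-integrate dv/dt = A·a·u - 2·a·v - a²·x from rest and record x."""
    x, v, xs = 0.0, 0.0, []
    for _ in range(n_steps):
        dv_dt = A * a * u - 2.0 * a * v - a**2 * x
        x, v = x + v * dt, v + dv_dt * dt
        xs.append(x)
    return xs


stable = euler_trajectory(3.25, 100.0, 1.0, dt=1e-4, n_steps=5_000)
unstable = euler_trajectory(3.25, 100.0, 1.0, dt=5e-2, n_steps=5_000)  # a·dt = 5: too coarse
print(first_nonfinite_step(stable), first_nonfinite_step(unstable))
```

In a JAX simulation the same check would typically apply jnp.isfinite to the recorded array instead of looping in Python.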
See Also
Neural Mass Models - Neural mass models that use these utilities
types - Type aliases for function signatures
JAX documentation for jax.numpy and jax.nn functions