Adding Coupling#
This tutorial covers the coupling mechanisms that connect brain regions in a network.
For general network setup, see Building Multi-Region Networks.
Coupling Types#
Diffusive Coupling#
Drives nodes toward their neighbors’ states:
import brainmass
import jax.numpy as jnp
W = jnp.ones((10, 10)) * 0.1 # connectivity matrix
W = W.at[jnp.diag_indices(10)].set(0.)
coupling = brainmass.DiffusiveCoupling(conn=W, k=0.5)
coupling.init_all_states()
# Use in network
coupled_input = coupling(source_activity, target_activity)
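To make the arithmetic concrete, the sketch below writes out one common form of diffusive coupling in plain JAX, assuming the rule `k * sum_j W[i, j] * (x[j] - x[i])`. This is an illustrative assumption, not a description of how `brainmass.DiffusiveCoupling` is implemented, and the helper name `diffusive_input` is hypothetical.

```python
import jax.numpy as jnp

def diffusive_input(W, x, k):
    """Hypothetical helper: k * sum_j W[i, j] * (x[j] - x[i])."""
    diff = x[None, :] - x[:, None]        # diff[i, j] = x[j] - x[i]
    return k * jnp.sum(W * diff, axis=1)

n = 10
W = jnp.ones((n, n)) * 0.1
W = W.at[jnp.diag_indices(n)].set(0.)     # no self-connections

x = jnp.linspace(0., 1., n)               # example node states
inp = diffusive_input(W, x, k=0.5)
# Nodes below the neighbor average receive positive input, nodes above it negative.
```

With identical node states the input vanishes, which is why this form is associated with synchronization: only differences between regions generate a drive.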
Additive Coupling#
Weighted sum of neighbor inputs:
coupling = brainmass.AdditiveCoupling(conn=W, k=0.3)
coupled_input = coupling(source_activity)
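As a rough illustration of the weighted sum, the following plain-JAX sketch assumes the rule `k * sum_j W[i, j] * x[j]`. The helper `additive_input` is hypothetical and the row/column orientation is an assumption; it is not taken from the `brainmass.AdditiveCoupling` source.

```python
import jax.numpy as jnp

def additive_input(W, x, k):
    """Hypothetical helper: k * sum_j W[i, j] * x[j]."""
    return k * (W @ x)

n = 10
W = jnp.ones((n, n)) * 0.1
W = W.at[jnp.diag_indices(n)].set(0.)

x = jnp.linspace(0., 1., n)
add_inp = additive_input(W, x, k=0.3)

# Under these assumptions, diffusive input is the additive input minus a
# self-term proportional to each node's total incoming weight.
diff_inp = add_inp - 0.3 * W.sum(axis=1) * x
```

Unlike the diffusive form, the additive input does not vanish when all regions share the same state, so it can shift the operating point of every node.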
When to Use Each#
Diffusive: Most common for brain networks; models synchronization
Additive: Direct input summation; simpler but less realistic
Coupling Strength#
The global coupling parameter k controls network integration:
# Weak coupling: independent regions
coupling_weak = brainmass.DiffusiveCoupling(conn=W, k=0.01)
# Moderate coupling: balanced
coupling_mod = brainmass.DiffusiveCoupling(conn=W, k=0.2)
# Strong coupling: synchronized
coupling_strong = brainmass.DiffusiveCoupling(conn=W, k=1.0)
Finding optimal k:
Start with k ~ 0.1
Increase until network shows desired synchronization
Too high → all regions synchronized (unrealistic)
Too low → no functional connectivity
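The scan over k can be sketched with a toy model. The example below uses Kuramoto phase oscillators as a stand-in for the node dynamics (an assumption made for brevity; it is not a brainmass model) and reports a time-averaged synchronization index R between 0 and 1 for several coupling strengths. The function name `mean_sync` and all parameter values are illustrative.

```python
import jax
import jax.numpy as jnp

def mean_sync(W, k, omega, theta0, dt=0.01, n_steps=5000):
    """Time-averaged Kuramoto order parameter R in [0, 1] (toy model)."""
    def step(theta, _):
        diff = theta[None, :] - theta[:, None]          # theta[j] - theta[i]
        dtheta = omega + k * jnp.sum(W * jnp.sin(diff), axis=1)
        theta = theta + dt * dtheta                     # Euler step
        r = jnp.abs(jnp.mean(jnp.exp(1j * theta)))      # phase coherence
        return theta, r
    _, r = jax.lax.scan(step, theta0, None, length=n_steps)
    return jnp.mean(r[n_steps // 2:])                   # discard the transient

n = 10
W = jnp.ones((n, n)) * 0.1
W = W.at[jnp.diag_indices(n)].set(0.)
omega = jnp.linspace(-1., 1., n)                        # heterogeneous frequencies
theta0 = jax.random.uniform(jax.random.PRNGKey(0), (n,), maxval=2 * jnp.pi)

ks = jnp.array([0.01, 0.1, 1.0, 5.0])
sync = jax.vmap(lambda k: mean_sync(W, k, omega, theta0))(ks)
for k, r in zip(ks, sync):
    print(f"k={float(k):.2f}  R={float(r):.2f}")
```

In a real study the synchronization index would be replaced by whatever target the model is tuned against, for example the fit between simulated and empirical functional connectivity.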
Multi-Modal Coupling#
Couple different state variables:
import brainmass
import brainstate
# Separate coupling for E and I populations
coupling_E = brainmass.DiffusiveCoupling(conn=W_E, k=0.2)
coupling_I = brainmass.DiffusiveCoupling(conn=W_I, k=0.1)
def multi_modal_step(i):
    rE = nodes.rE.value
    rI = nodes.rI.value
    coupled_E = coupling_E(rE, rE)
    coupled_I = coupling_I(rI, rI)
    nodes.update(rE_inp=coupled_E, rI_inp=coupled_I)
Laplacian Coupling#
Use graph Laplacian for diffusive coupling:
# Compute Laplacian
L = brainmass.laplacian_connectivity(W, normalize=False)
# L[i,i] = sum_j W[i,j], L[i,j] = -W[i,j] for i≠j
# Diffusive coupling via Laplacian
coupled = -k * (L @ activity)
Normalized Laplacian:
L_norm = brainmass.laplacian_connectivity(W, normalize=True)
# Symmetric normalization: D^(-1/2) L D^(-1/2)
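The Laplacian definitions in the comments above can be checked by hand. This sketch builds both Laplacians directly in JAX, following the formulas stated in the comments, and confirms that `-k * (L @ x)` matches the pairwise diffusive sum. It does not call `brainmass.laplacian_connectivity`, whose internals are not verified here.

```python
import jax.numpy as jnp

n, k = 5, 0.2
W = jnp.ones((n, n)) * 0.1
W = W.at[jnp.diag_indices(n)].set(0.)
x = jnp.linspace(0., 1., n)

deg = W.sum(axis=1)                       # deg[i] = sum_j W[i, j]
L = jnp.diag(deg) - W                     # unnormalized Laplacian, L = D - W

# -k * L @ x equals the pairwise diffusive sum k * sum_j W[i, j] * (x[j] - x[i])
via_laplacian = -k * (L @ x)
pairwise = k * jnp.sum(W * (x[None, :] - x[:, None]), axis=1)

# Symmetric normalization D^(-1/2) L D^(-1/2); assumes every node has deg > 0
d_inv_sqrt = 1.0 / jnp.sqrt(deg)
L_norm = d_inv_sqrt[:, None] * L * d_inv_sqrt[None, :]
```

The normalized form has ones on its diagonal, which makes the effective coupling less sensitive to differences in node degree.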
Time-Delayed Coupling#
Implement axonal transmission delays:
import jax
import jax.numpy as jnp
max_delay = 10 # time steps
N_regions = 90
# Circular buffer for history
history_buffer = jnp.zeros((max_delay, N_regions))
def step_with_delay(i, buffer):
    # Current activity
    current = nodes.update()
    # Apply coupling with oldest buffer value (max delay)
    delayed_activity = buffer[0]
    coupled = coupling(delayed_activity, current)
    nodes.x.value += coupled
    # Update buffer: drop the oldest entry, append the newest
    buffer = jnp.roll(buffer, shift=-1, axis=0)
    buffer = buffer.at[-1].set(current)
    return current, buffer
# Run with scan for stateful buffer
def scan_fn(buffer, i):
    output, new_buffer = step_with_delay(i, buffer)
    return new_buffer, output
final_buffer, outputs = jax.lax.scan(
    scan_fn,
    history_buffer,
    jnp.arange(1000)
)
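To see what the buffer delivers at each step, here is a self-contained sketch of the same roll-and-append pattern with toy leaky nodes in place of a brainmass model. The dynamics, sizes, and constants are assumptions chosen only to make the example runnable; the point is that the value read from `buffer[0]` at step `t` is the activity from step `t - max_delay`.

```python
import jax
import jax.numpy as jnp

n_regions, max_delay, n_steps = 4, 3, 20
dt, k = 0.1, 0.5
W = (jnp.ones((n_regions, n_regions)) - jnp.eye(n_regions)) / n_regions

def step(carry, _):
    x, buffer = carry
    delayed = buffer[0]                        # oldest entry in the buffer
    # Toy leaky dynamics with delayed diffusive input (not a brainmass model)
    coupled = k * (W @ delayed - W.sum(axis=1) * x)
    x_new = x + dt * (-x + coupled + 1.0)
    # Drop the oldest entry and append the activity used at this step
    buffer = jnp.roll(buffer, shift=-1, axis=0).at[-1].set(x)
    return (x_new, buffer), (x, delayed)

x0 = jnp.arange(n_regions, dtype=jnp.float32)
buffer0 = jnp.zeros((max_delay, n_regions))
_, (xs, delayed) = jax.lax.scan(step, (x0, buffer0), None, length=n_steps)

# After the first max_delay steps, the delayed signal is the activity shifted in time.
shift_ok = bool(jnp.array_equal(delayed[max_delay:], xs[:-max_delay]))
```

During the first `max_delay` steps the delayed input comes from the zero-initialized buffer, so the initial history should be set deliberately if the transient matters.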
Custom Coupling Functions#
Implement custom coupling rules:
def nonlinear_coupling(source, target, conn, k, threshold=0.5):
    """Only couple if source activity exceeds threshold"""
    active_source = jnp.where(source > threshold, source, 0.)
    # conn.T @ s gives target j the sum over i of conn[i, j] * s[i];
    # use conn @ s instead if conn[i, j] is the weight from j to i
    return k * (conn.T @ active_source)
# Use in network
def custom_network_step(i):
    activity = nodes.x.value
    coupled = nonlinear_coupling(activity, activity, W, k=0.2, threshold=0.3)
    nodes.update()
    nodes.x.value += coupled
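A two-node example makes the thresholding and the matrix orientation easy to verify by hand. The sketch below repeats the coupling rule in a trimmed form (without the unused `target` argument) and reads `conn[i, j]` as the weight from node `i` to node `j`, which is an assumption that follows from the `conn.T @ ...` product.

```python
import jax.numpy as jnp

def nonlinear_coupling(source, conn, k, threshold=0.5):
    """Pass only supra-threshold source activity through the connectivity."""
    active_source = jnp.where(source > threshold, source, 0.)
    return k * (conn.T @ active_source)

# conn[i, j] is read here as the weight from node i to node j (an assumption)
conn = jnp.array([[0., 1.],
                  [2., 0.]])
source = jnp.array([1.0, 0.2])            # node 1 is below threshold

out = nonlinear_coupling(source, conn, k=0.5)
# Node 0 gets nothing (its only input, node 1, is gated off);
# node 1 gets k * conn[0, 1] * source[0] = 0.5
```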
Directional Coupling#
Different coupling for different directions:
# Asymmetric connectivity
W_forward = jnp.load('feedforward_connections.npy')
W_backward = jnp.load('feedback_connections.npy')
coupling_ff = brainmass.AdditiveCoupling(conn=W_forward, k=0.3)
coupling_fb = brainmass.AdditiveCoupling(conn=W_backward, k=0.1)
def directional_step(i):
    activity = nodes.x.value
    # Feedforward coupling
    ff_input = coupling_ff(activity)
    # Feedback coupling (different strength)
    fb_input = coupling_fb(activity)
    # Combine
    total_coupling = ff_input + fb_input
    nodes.update()
    nodes.x.value += total_coupling
Optimization and Performance#
JIT Compilation#
import jax
@jax.jit
def fast_network_step(state_x, state_y):
    # Move coupling inside JIT for speed
    coupled = coupling(state_x, state_x)
    # ... update logic (must define new_state_x and new_state_y) ...
    return new_state_x, new_state_y
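For a version that runs as-is, the sketch below jit-compiles a pure step function for toy leaky nodes with diffusive input. The dynamics and the name `network_step` are assumptions for illustration; the pattern to note is that all state enters and leaves through function arguments and return values, which is what `jax.jit` expects.

```python
import jax
import jax.numpy as jnp

@jax.jit
def network_step(x, W, k, dt):
    """One Euler step of toy leaky nodes with diffusive input (pure function)."""
    coupled = k * (W @ x - W.sum(axis=1) * x)
    return x + dt * (-x + coupled)

n = 10
W = jnp.ones((n, n)) * 0.1
W = W.at[jnp.diag_indices(n)].set(0.)
x0 = jnp.linspace(0., 1., n)

x = x0
for _ in range(100):        # the first call compiles; later calls reuse the compiled step
    x = network_step(x, W, 0.2, 0.01)
```

For long simulations, moving the loop itself into `jax.lax.scan` or `jax.lax.fori_loop` avoids per-step Python overhead.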
Sparse Connectivity#
For large sparse networks, consider sparse representations:
# Dense: (N, N) array
# Sparse: Use JAX BCOO (jax.experimental.sparse, experimental)
# Or threshold small weights (the result is still a dense array, with zeros)
W_sparse = jnp.where(W > 0.01, W, 0.)
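As a minimal sketch of the BCOO route, the example below converts a toy ring-shaped connectivity matrix to `jax.experimental.sparse.BCOO` and compares the sparse and dense matrix-vector products. Whether the brainmass coupling classes accept a BCOO matrix as `conn` is not verified here; the product is computed directly.

```python
import jax.numpy as jnp
from jax.experimental import sparse

n = 200
# Toy connectivity: a ring where each node connects to its two nearest neighbors
idx = jnp.arange(n)
W = jnp.zeros((n, n))
W = W.at[idx, (idx + 1) % n].set(0.5)
W = W.at[idx, (idx - 1) % n].set(0.5)

W_bcoo = sparse.BCOO.fromdense(W)         # stores only the nonzero entries
x = jnp.linspace(0., 1., n)

dense_out = W @ x
sparse_out = W_bcoo @ x                   # sparse matrix-vector product
print(W_bcoo.nse, "stored entries out of", n * n)
```

Sparse formats tend to pay off only when the matrix is both large and mostly empty; for typical parcellations with tens to hundreds of regions, the dense product is often fast enough.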
Best Practices#
Start Simple: Test with diffusive coupling before trying custom coupling
Normalize SC: Divide the structural connectivity (SC) matrix by its maximum or by its row/column sums to prevent instability
Tune k Systematically: Scan a range of coupling strengths to locate the dynamical regime of interest
Monitor FC: Compare simulated functional connectivity (FC) to empirical data
Check Units: Ensure coupling input has correct units for the model
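Two of the practices above, normalizing the connectivity and monitoring functional connectivity, can be sketched in a few lines of JAX. The helper names are hypothetical, the data are random placeholders, and `normalize_rows` assumes non-negative weights.

```python
import jax
import jax.numpy as jnp

def normalize_max(W):
    """Scale so the largest weight is 1."""
    return W / jnp.max(W)

def normalize_rows(W, eps=1e-12):
    """Scale each row to sum to 1; all-zero rows are left at zero."""
    row_sums = W.sum(axis=1, keepdims=True)
    return W / jnp.maximum(row_sums, eps)

def fc_similarity(ts_sim, fc_emp):
    """Pearson correlation between the upper triangles of simulated and empirical FC.

    ts_sim: (time, regions) simulated activity; fc_emp: (regions, regions).
    """
    fc_sim = jnp.corrcoef(ts_sim.T)                   # (regions, regions)
    iu = jnp.triu_indices(fc_emp.shape[0], k=1)       # off-diagonal entries only
    return jnp.corrcoef(fc_sim[iu], fc_emp[iu])[0, 1]

# Toy data standing in for a structural matrix and simulated time series
key_w, key_ts = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.uniform(key_w, (6, 6))
W = W.at[jnp.diag_indices(6)].set(0.)
ts = jax.random.normal(key_ts, (500, 6))

W_max, W_row = normalize_max(W), normalize_rows(W)
score = fc_similarity(ts, jnp.corrcoef(ts.T))         # FC compared with itself
```

In practice `fc_emp` would come from recorded data, and `score` would be tracked across the coupling strengths being scanned.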
Troubleshooting#
Network Explodes:
Reduce coupling strength k
Normalize connectivity matrix
Add damping/noise
No Synchronization:
Increase k
Check connectivity topology (any isolated nodes?)
Ensure sufficient simulation time
Unphysical Dynamics:
Check coupling sign (diffusive input should drive each node toward the average of its neighbors)
Verify connectivity matrix orientation (row vs column convention)
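Several of these checks can be automated before running a simulation. The function below is a hypothetical helper, not part of brainmass, that flags non-finite values, self-connections, asymmetry, and isolated nodes in a connectivity matrix.

```python
import jax.numpy as jnp

def connectivity_report(W):
    """Basic sanity checks on an (N, N) connectivity matrix."""
    in_out = jnp.abs(W).sum(axis=0) + jnp.abs(W).sum(axis=1)
    return {
        "has_nan_or_inf": bool(~jnp.all(jnp.isfinite(W))),
        "has_self_connections": bool(jnp.any(jnp.diag(W) != 0)),
        "is_symmetric": bool(jnp.allclose(W, W.T)),
        "isolated_nodes": jnp.where(in_out == 0)[0].tolist(),
    }

W = jnp.array([[0., 0.2, 0.],
               [0.2, 0., 0.],
               [0., 0., 0.]])          # node 2 has no connections
report = connectivity_report(W)
```

If `is_symmetric` is true, the row versus column convention does not affect the result; if it is false, the orientation needs to be confirmed against the coupling rule in use.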
Next Steps#
Forward Modeling - Map coupled network to neuroimaging signals
Parameter Fitting - Optimize coupling parameters
Coupling Mechanisms - Full coupling API reference
See Also#
Building Multi-Region Networks - Network construction
Examples - Coupling examples