Adding Coupling#

This tutorial focuses on coupling mechanisms for connecting brain regions in networks.

See Building Multi-Region Networks for general network setup; this tutorial covers the coupling details.

Coupling Types#

Diffusive Coupling#

Drives nodes toward their neighbors’ states:

\[C_i = k \sum_j W_{ij} (x_j - x_i)\]
import brainmass
import jax.numpy as jnp

W = jnp.ones((10, 10)) * 0.1  # connectivity matrix
W = W.at[jnp.diag_indices(10)].set(0.)

coupling = brainmass.DiffusiveCoupling(conn=W, k=0.5)
coupling.init_all_states()

# Use in network (source_activity and target_activity are the regions' current states)
coupled_input = coupling(source_activity, target_activity)
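
As a sanity check on the formula, the diffusive term can be written directly in jax.numpy without brainmass. This is a minimal sketch, assuming W[i, j] weights the influence of node j on node i; `diffusive_term` is an illustrative helper and not part of the brainmass API.

```python
import jax.numpy as jnp

def diffusive_term(W, x, k):
    """C_i = k * sum_j W[i, j] * (x[j] - x[i]) (illustrative helper)."""
    # sum_j W[i, j] * x[j] is a matrix-vector product;
    # sum_j W[i, j] * x[i] is the row sum of W times x[i].
    return k * (W @ x - W.sum(axis=1) * x)

W = jnp.ones((3, 3)) * 0.1
W = W.at[jnp.diag_indices(3)].set(0.)
x = jnp.array([0.0, 1.0, 2.0])

C = diffusive_term(W, x, k=0.5)
# Node 0 sits below its neighbors and is pushed up; node 2 is pushed down.
```

The middle node already equals the average of its neighbors, so its coupling term is zero.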

Additive Coupling#

Weighted sum of neighbor inputs:

\[C_i = k \sum_j W_{ij} x_j\]
coupling = brainmass.AdditiveCoupling(conn=W, k=0.3)
coupled_input = coupling(source_activity)

When to Use Each#

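  • Diffusive: Most common for brain networks; the coupling term vanishes when connected regions share the same state, so it models synchronization

  • Additive: Direct input summation; simpler but less realistic, since regions keep receiving input even when they are already synchronized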
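
The difference between the two rules shows up most clearly when all regions share the same state. The following sketch evaluates both formulas by hand in jax.numpy (it does not call the brainmass classes), assuming the same W[i, j] convention as the equations above.

```python
import jax.numpy as jnp

W = jnp.ones((4, 4)) * 0.1
W = W.at[jnp.diag_indices(4)].set(0.)
k = 0.5

x_sync = jnp.full((4,), 2.0)   # all regions share the same state

# Diffusive: k * sum_j W[i, j] * (x[j] - x[i])
diffusive = k * (W @ x_sync - W.sum(axis=1) * x_sync)
# Additive: k * sum_j W[i, j] * x[j]
additive = k * (W @ x_sync)
# diffusive is zero for every node; additive is 0.5 * (0.3 * 2.0) = 0.3 per node
```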

Coupling Strength#

The global coupling parameter k controls network integration:

# Weak coupling: nearly independent regions
coupling_weak = brainmass.DiffusiveCoupling(conn=W, k=0.01)

# Moderate coupling: partial synchronization
coupling_mod = brainmass.DiffusiveCoupling(conn=W, k=0.2)

# Strong coupling: strongly synchronized regions
coupling_strong = brainmass.DiffusiveCoupling(conn=W, k=1.0)

Finding a suitable k:

  • Start with k ~ 0.1

  • Increase k until the network shows the desired level of synchronization

  • Too high → all regions fully synchronized (unrealistic)

  • Too low → regions stay independent, so no functional connectivity emerges
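
The tuning procedure above can be sketched as a parameter sweep. The example below is a toy illustration only: it uses leaky linear nodes with heterogeneous input rather than a neural mass model, and it takes the across-node spread of the final state as a crude proxy for how desynchronized the network is. The names `drive` and `final_spread` are hypothetical.

```python
import jax
import jax.numpy as jnp

N = 10
W = jnp.ones((N, N)) * 0.1
W = W.at[jnp.diag_indices(N)].set(0.)
drive = jnp.linspace(0.0, 1.0, N)   # heterogeneous input, one value per node

def final_spread(k, n_steps=500, dt=0.1):
    """Std across nodes after integrating the toy model with Euler steps."""
    def step(_, x):
        coupled = k * (W @ x - W.sum(axis=1) * x)   # diffusive coupling
        return x + dt * (-x + drive + coupled)
    x_final = jax.lax.fori_loop(0, n_steps, step, jnp.zeros(N))
    return jnp.std(x_final)

k_values = jnp.array([0.01, 0.1, 0.2, 0.5, 1.0])
spreads = jax.vmap(final_spread)(k_values)   # one spread value per k
```

In this toy setting the spread shrinks monotonically as k grows. With a real model, a measure such as simulated functional connectivity would take the place of the spread.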

Multi-Modal Coupling#

Couple different state variables:

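import brainmass
import brainstate

# W_E, W_I (connectivity per population) and nodes (a model exposing rE and rI
# states) are assumed to be defined beforehand.
# Separate coupling for E and I populations
coupling_E = brainmass.DiffusiveCoupling(conn=W_E, k=0.2)
coupling_I = brainmass.DiffusiveCoupling(conn=W_I, k=0.1)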

def multi_modal_step(i):
    rE = nodes.rE.value
    rI = nodes.rI.value

    coupled_E = coupling_E(rE, rE)
    coupled_I = coupling_I(rI, rI)

    nodes.update(rE_inp=coupled_E, rI_inp=coupled_I)

Laplacian Coupling#

Use graph Laplacian for diffusive coupling:

# Compute Laplacian
L = brainmass.laplacian_connectivity(W, normalize=False)
# L[i,i] = sum_j W[i,j], L[i,j] = -W[i,j] for i≠j

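# Diffusive coupling via Laplacian (k: coupling strength, activity: state vector)
# -k * (L @ x) equals k * sum_j W[i,j] * (x_j - x_i)
coupled = -k * (L @ activity)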

Normalized Laplacian:

L_norm = brainmass.laplacian_connectivity(W, normalize=True)
# Symmetric normalization: D^(-1/2) L D^(-1/2)
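
To see why the Laplacian form reproduces the diffusive sum, the sketch below builds L = D - W by hand in jax.numpy and compares both expressions. It follows the definition given in the comments above and does not call `brainmass.laplacian_connectivity`; the normalized variant assumes every node has nonzero degree.

```python
import jax.numpy as jnp

W = jnp.array([[0.0, 0.2, 0.0],
               [0.2, 0.0, 0.5],
               [0.0, 0.5, 0.0]])
x = jnp.array([1.0, 0.0, -1.0])
k = 0.3

deg = W.sum(axis=1)
L = jnp.diag(deg) - W                     # unnormalized Laplacian D - W

via_laplacian = -k * (L @ x)
# Direct evaluation of k * sum_j W[i, j] * (x[j] - x[i])
via_sum = k * jnp.sum(W * (x[None, :] - x[:, None]), axis=1)

# Symmetric normalization D^(-1/2) L D^(-1/2); assumes every node has deg > 0
d_inv_sqrt = 1.0 / jnp.sqrt(deg)
L_norm = d_inv_sqrt[:, None] * L * d_inv_sqrt[None, :]
```

After symmetric normalization the diagonal of the Laplacian is 1 for every node.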

Time-Delayed Coupling#

Implement axonal transmission delays:

import jax
import jax.numpy as jnp

max_delay = 10  # time steps
N_regions = 90

# History buffer: buffer[0] holds the oldest state, buffer[-1] the newest
history_buffer = jnp.zeros((max_delay, N_regions))

def step_with_delay(i, buffer):
    # Current activity
    current = nodes.update()

    # Apply coupling with oldest buffer value (max delay)
    delayed_activity = buffer[0]
    coupled = coupling(delayed_activity, current)
    nodes.x.value += coupled

    # Update buffer
    buffer = jnp.roll(buffer, shift=-1, axis=0)
    buffer = buffer.at[-1].set(current)

    return current, buffer

# Run with scan for stateful buffer
def scan_fn(buffer, i):
    output, new_buffer = step_with_delay(i, buffer)
    return new_buffer, output

final_buffer, outputs = jax.lax.scan(
    scan_fn,
    history_buffer,
    jnp.arange(1000)
)
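
The buffer logic can be checked in isolation with a self-contained version that does not depend on `nodes` or `coupling`. This is a sketch under stated assumptions: toy leaky dynamics, a single delay shared by all connections, and a constant pre-history equal to the initial state. `simulate_delayed` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

def simulate_delayed(W, k, x0, delay_steps, n_steps, dt=0.1):
    deg = W.sum(axis=1)
    # Constant pre-history: buffer[0] is the oldest state, buffer[-1] the newest.
    buffer0 = jnp.tile(x0, (delay_steps, 1))

    def step(carry, _):
        x, buffer = carry
        delayed = buffer[0]                      # state from delay_steps ago
        coupled = k * (W @ delayed - deg * x)    # delayed diffusive term
        x_new = x + dt * (-x + coupled)          # Euler step of leaky dynamics
        # Drop the oldest entry and append the state used in this step.
        buffer = jnp.roll(buffer, shift=-1, axis=0).at[-1].set(x)
        return (x_new, buffer), (x, delayed)

    _, (xs, delayed_xs) = jax.lax.scan(step, (x0, buffer0), None, length=n_steps)
    return xs, delayed_xs

W = jnp.ones((4, 4)) * 0.1
W = W.at[jnp.diag_indices(4)].set(0.)
x0 = jnp.array([1.0, 0.5, -0.5, -1.0])
xs, delayed_xs = simulate_delayed(W, k=0.5, x0=x0, delay_steps=3, n_steps=50)
```

Here `delayed_xs[t]` equals `xs[t - 3]` once the pre-history has been flushed, which confirms that reading `buffer[0]` yields the maximum delay.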

Custom Coupling Functions#

Implement custom coupling rules:

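def nonlinear_coupling(source, target, conn, k, threshold=0.5):
    """Couple only from sources whose activity exceeds the threshold.

    conn[i, j] weights input from j to i, matching C_i = k * sum_j W_ij x_j.
    The target argument is not used by this rule.
    """
    active_source = jnp.where(source > threshold, source, 0.)
    return k * (conn @ active_source)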

# Use in network
def custom_network_step(i):
    activity = nodes.x.value
    coupled = nonlinear_coupling(activity, activity, W, k=0.2, threshold=0.3)
    nodes.update()
    nodes.x.value += coupled
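
A small worked example makes the gating step concrete. The values below are made up for illustration, and W is symmetric so that the row/column orientation of the matrix product does not affect the result.

```python
import jax.numpy as jnp

W = jnp.array([[0.0, 1.0, 1.0],
               [1.0, 0.0, 1.0],
               [1.0, 1.0, 0.0]])   # symmetric, so orientation does not matter
source = jnp.array([0.1, 0.6, 0.9])
threshold, k = 0.5, 0.2

# Node 0 is below threshold and is silenced: active == [0.0, 0.6, 0.9]
active = jnp.where(source > threshold, source, 0.)
coupled = k * (W @ active)
```

Node 0 still receives input from nodes 1 and 2, but contributes nothing to them.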

Directional Coupling#

Different coupling for different directions:

# Asymmetric connectivity
W_forward = jnp.load('feedforward_connections.npy')
W_backward = jnp.load('feedback_connections.npy')

coupling_ff = brainmass.AdditiveCoupling(conn=W_forward, k=0.3)
coupling_fb = brainmass.AdditiveCoupling(conn=W_backward, k=0.1)

def directional_step(i):
    activity = nodes.x.value

    # Feedforward coupling
    ff_input = coupling_ff(activity)

    # Feedback coupling (different strength)
    fb_input = coupling_fb(activity)

    # Combine
    total_coupling = ff_input + fb_input
    nodes.update()
    nodes.x.value += total_coupling
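
When separate feedforward and feedback matrices are not available, one possible way to obtain them is to split a single directed matrix along an assumed hierarchy. The sketch below is hypothetical: it assumes regions are indexed from the lowest to the highest hierarchy level and that W[i, j] is the weight from region j to region i.

```python
import jax.numpy as jnp

W = jnp.array([[0.0, 0.2, 0.1],
               [0.8, 0.0, 0.3],
               [0.5, 0.9, 0.0]])   # W[i, j]: weight from region j to region i

# Assumption: regions are indexed from lowest (0) to highest hierarchy level,
# so entries below the diagonal (i > j) carry feedforward connections.
W_ff = jnp.tril(W, k=-1)
W_fb = jnp.triu(W, k=1)

k_ff, k_fb = 0.3, 0.1
x = jnp.array([1.0, 0.5, 0.2])
# Additive coupling with a separate strength per direction
total = k_ff * (W_ff @ x) + k_fb * (W_fb @ x)
```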

Optimization and Performance#

JIT Compilation#

import jax

@jax.jit
def fast_network_step(state_x, state_y):
    # Keep the coupling call inside the jitted function so it is compiled
    # together with the update
    coupled = coupling(state_x, state_x)
    # ... update logic ... (produces new_state_x and new_state_y)
    return new_state_x, new_state_y
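
For a version that runs as-is, the sketch below jit-compiles a pure step function written in plain jax.numpy. The leaky dynamics are a toy stand-in for the elided update logic, and `network_step` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

W = jnp.ones((10, 10)) * 0.1
W = W.at[jnp.diag_indices(10)].set(0.)

def network_step(x, k, dt):
    """One Euler step of leaky nodes with diffusive coupling (toy dynamics)."""
    coupled = k * (W @ x - W.sum(axis=1) * x)
    return x + dt * (-x + coupled)

fast_step = jax.jit(network_step)

x = jnp.linspace(-1.0, 1.0, 10)
x_next = fast_step(x, 0.2, 0.1)   # first call compiles; later calls reuse it
```

Keeping the step a pure function of its inputs lets the same function be passed to `jax.lax.scan` for multi-step simulation.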

Sparse Connectivity#

For large sparse networks, consider sparse representations:

# Dense: (N, N) array
# Sparse: use JAX BCOO (jax.experimental.sparse, experimental)
# Or zero out small weights (the result is still a dense array)
W_sparse = jnp.where(W > 0.01, W, 0.)
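
Thresholding alone keeps the matrix dense. A minimal sketch of converting the thresholded matrix to JAX's experimental BCOO format follows; whether the brainmass coupling classes accept a BCOO matrix is not covered here, so the product is computed by hand.

```python
import jax.numpy as jnp
from jax.experimental import sparse

W = jnp.array([[0.0, 0.5, 0.001],
               [0.5, 0.0, 0.0],
               [0.001, 0.3, 0.0]])
W_thresh = jnp.where(W > 0.01, W, 0.)      # still a dense (N, N) array

W_sp = sparse.BCOO.fromdense(W_thresh)     # stores only the nonzero entries
x = jnp.array([1.0, 2.0, 3.0])
y = W_sp @ x                               # sparse matrix-vector product
```

Sparse storage tends to pay off only when N is large and most weights are zero; for small matrices the dense product is usually faster.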

Best Practices#

  1. Start Simple: Test with diffusive coupling before trying custom coupling

  2. Normalize SC: Divide the structural connectivity (SC) matrix by its maximum or by its row/column sums to prevent instability

  3. Tune k Systematically: Scan a range of coupling strengths to find the desired dynamical regime

  4. Monitor FC: Compare simulated functional connectivity (FC) to empirical data

  5. Check Units: Ensure coupling input has correct units for the model
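
The normalization options mentioned in item 2 can be sketched as follows. The matrix values are made up, and the guard against empty rows is an added assumption for nodes without incoming connections.

```python
import jax.numpy as jnp

W = jnp.array([[0.0, 4.0, 1.0],
               [4.0, 0.0, 5.0],
               [1.0, 5.0, 0.0]])

W_max = W / W.max()                                   # largest weight becomes 1

row_sums = W.sum(axis=1, keepdims=True)
W_row = W / jnp.where(row_sums > 0, row_sums, 1.0)    # each nonzero row sums to 1
```

Note that row normalization generally breaks the symmetry of W, while division by the maximum preserves it.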

Troubleshooting#

Network Explodes:

  • Reduce coupling strength k

  • Normalize connectivity matrix

  • Add damping/noise

No Synchronization:

  • Increase k

  • Check connectivity topology (any isolated nodes?)

  • Ensure sufficient simulation time
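
A quick way to check for isolated nodes is to look for regions with neither incoming nor outgoing weight. This sketch assumes W[i, j] is the weight from region j to region i; the example matrix is made up.

```python
import jax.numpy as jnp

W = jnp.array([[0.0, 0.3, 0.0],
               [0.3, 0.0, 0.0],
               [0.0, 0.0, 0.0]])

in_strength = jnp.abs(W).sum(axis=1)    # total input received by each node
out_strength = jnp.abs(W).sum(axis=0)   # total output sent by each node
# Indices of nodes with no connections in either direction
isolated = jnp.where((in_strength == 0) & (out_strength == 0))[0]
```

In this example only region 2 is isolated, so no value of k can synchronize it with the others.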

Unphysical Dynamics:

  • Check coupling sign (diffusive coupling should drive each node toward the weighted average of its neighbors)

  • Verify connectivity matrix orientation (row vs column convention)
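
The orientation check can be done with a single directed connection and a single active region. The sketch assumes the convention of the formulas in this tutorial, where W[i, j] weights input from region j to region i.

```python
import jax.numpy as jnp

# Only one directed connection: region 1 sends to region 0.
W = jnp.zeros((3, 3)).at[0, 1].set(1.0)   # W[i, j]: from j to i
x = jnp.array([0.0, 1.0, 0.0])            # only region 1 is active

row_convention = W @ x      # input arrives at region 0, as intended
# W.T treats the entry as a connection from region 0 to region 1;
# region 0 is silent, so no input arrives anywhere.
col_convention = W.T @ x
```

If a simulation behaves as though inputs are missing or arrive at the wrong regions, swapping `W` for `W.T` in a test like this one is a fast way to find out which convention the data follows.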

Next Steps#

See Also#