Building Multi-Region Networks#
This tutorial covers creating and simulating large-scale brain networks with multiple regions.
Basic Network Setup#
A brain network consists of:
Nodes: Neural mass models representing brain regions
Edges: Structural connectivity between regions
Coupling: Mechanism for inter-regional communication
Simple Network Example#
import brainmass
import jax.numpy as jnp
import brainunit as u
import brainstate

N_regions = 10

# 1. Node dynamics: one (initially uncoupled) Hopf oscillator per region,
#    all with a 10 Hz intrinsic frequency.
nodes = brainmass.HopfOscillator(
    in_size=N_regions,
    omega=2 * jnp.pi * 10 * u.Hz,
    a=0.1,
)

# 2. Structural connectivity: weight 0.05 between every pair of distinct
#    regions and 0 on the diagonal (no self-connections).
off_diagonal = jnp.ones((N_regions, N_regions)) - jnp.eye(N_regions)
W = 0.05 * off_diagonal

# 3. Coupling rule that turns node activity into inter-regional input.
coupling = brainmass.DiffusiveCoupling(conn=W, k=0.2)

# 4. Allocate the state variables of both components (nodes first).
for component in (nodes, coupling):
    component.init_all_states()


# 5. Simulation loop
def network_step(i):
    """Advance the network by one step and return the node output.

    ``i`` is the loop index supplied by ``for_loop``; it is not used here.
    The coupling term is computed from the state *before* the node update
    and then added to the node state *after* the update.
    """
    state_before = nodes.x.value
    delta = coupling(state_before, state_before)
    result = nodes.update()
    # NOTE(review): the coupling term is added directly to the state after
    # ``update()`` and is not scaled by the integration step — confirm this
    # matches how brainmass expects external input to be applied.
    nodes.x.value = nodes.x.value + delta
    return result


network_activity = brainstate.transform.for_loop(network_step, jnp.arange(1000))
Structural Connectivity#
Loading from DTI#
import jax.numpy as jnp

# Load structural connectivity from DTI tractography
# Common formats: .npy, .mat, .txt (this example reads a NumPy .npy file)
SC = jnp.load('structural_connectivity.npy')  # shape (N, N)
# Convention assumed here: SC[i, j] = fiber density from region j -> region i

# Normalize (common preprocessing). A region without any fibers has a zero
# row/column sum; dividing by it would turn that whole row/column into NaN,
# which then propagates through the entire simulation. Zero sums are
# therefore replaced by 1, so such regions simply keep all-zero weights.

# Column normalization: each source region's outgoing weights sum to 1
col_sum = SC.sum(axis=0, keepdims=True)
SC_norm = SC / jnp.where(col_sum != 0, col_sum, 1.0)

# Or row normalization: each target region's incoming weights sum to 1
row_sum = SC.sum(axis=1, keepdims=True)
SC_norm = SC / jnp.where(row_sum != 0, row_sum, 1.0)
Creating Synthetic Networks#
Random Network:
import jax

N = 90
key = jax.random.PRNGKey(0)

# Independent weights drawn uniformly from [0, 0.1) for every ordered pair
W_random = 0.1 * jax.random.uniform(key, (N, N))
# Zero the diagonal so that no region connects to itself
W_random = jnp.where(jnp.eye(N, dtype=bool), 0.0, W_random)
Small-World Network:
# Simplified small-world network (Watts-Strogatz style): a ring lattice in
# which a fraction of the lattice edges is rewired into random shortcuts.
N = 90
k = 4  # nearest neighbors per node (k // 2 on each side); should be even
weight = 0.1  # weight of every existing connection

# Ring lattice, stored as a set of undirected edges {i, (i + j) mod N}
edges = {
    frozenset((i, (i + j) % N))
    for i in range(N)
    for j in range(1, k // 2 + 1)
}

# Add random shortcuts (rewiring): with probability p_rewire, each lattice
# edge (i, i + j) is replaced by an edge (i, t) to a random target t.
# Rewiring is skipped if t == i (self-connection) or if (i, t) already
# exists, so the total number of edges is preserved.
p_rewire = 0.1
key = jax.random.PRNGKey(42)
key_coin, key_target = jax.random.split(key)
# Draw all random numbers up front: one row per node, one column per offset j
do_rewire = (jax.random.uniform(key_coin, (N, k // 2)) < p_rewire).tolist()
targets = jax.random.randint(key_target, (N, k // 2), 0, N).tolist()
for i in range(N):
    for j in range(1, k // 2 + 1):
        t = targets[i][j - 1]
        old_edge = frozenset((i, (i + j) % N))
        new_edge = frozenset((i, t))
        if (do_rewire[i][j - 1] and t != i
                and new_edge not in edges and old_edge in edges):
            edges.remove(old_edge)
            edges.add(new_edge)

# Build the symmetric weight matrix (the diagonal stays 0: no self-connections)
pairs = [tuple(e) for e in edges]
src = jnp.array([a for a, _ in pairs])
dst = jnp.array([b for _, b in pairs])
W = jnp.zeros((N, N))
W = W.at[src, dst].set(weight)
W = W.at[dst, src].set(weight)
Hub Network:
N = 90
N_hubs = 5

# The first N_hubs regions are hubs: every connection that starts or ends
# at a hub has weight 0.2, all other connections are absent.
is_hub = jnp.arange(N) < N_hubs
touches_hub = is_hub[:, None] | is_hub[None, :]
W = jnp.where(touches_hub, 0.2, jnp.zeros((N, N)))

# Remove self-connections
W = jnp.where(jnp.eye(N, dtype=bool), 0.0, W)
Realistic Brain Networks#
Using Anatomical Atlases#
# Example: AAL90 atlas
N_AAL90 = 90
# Load AAL90 connectivity (from DTI; assumed unnormalized, e.g. fiber counts)
SC_AAL90 = jnp.load('AAL90_SC.npy')
if SC_AAL90.shape != (N_AAL90, N_AAL90):
    raise ValueError(
        f"expected a ({N_AAL90}, {N_AAL90}) connectivity matrix, "
        f"got shape {SC_AAL90.shape}"
    )
# Normalize: raw tractography weights can be large enough to destabilize the
# coupled dynamics, so scale them by the strongest connection (this maps
# nonnegative weights onto [0, 1])
SC_AAL90 = SC_AAL90 / SC_AAL90.max()
# Create network
nodes = brainmass.WilsonCowanModel(in_size=N_AAL90)
coupling = brainmass.DiffusiveCoupling(conn=SC_AAL90, k=0.1)
Common Atlases:
AAL (Automated Anatomical Labeling): 90/116 regions
Desikan-Killiany: 68 cortical regions
Destrieux: 148 cortical regions
Schaefer: 100–1000 parcels (in steps of 100)
Distance-Dependent Delays#
For large-scale networks, account for axonal conduction delays. Each delay is the inter-regional distance divided by the conduction velocity:
# Distance matrix (mm)
distances = jnp.load('region_distances.npy')  # shape (N, N)
# Conduction velocity (m/s)
velocity = 6.0  # typical: 3-9 m/s
# Compute delays: mm / (m/s) = ms
delays_ms = distances / velocity
# Convert to integer time steps of size dt_ms, rounding to the nearest step
dt_ms = 1.0  # must match the simulation step size
delays_steps = jnp.round(delays_ms / dt_ms).astype(int)
# Circular buffer (simplified). A delay of d steps needs the activity from
# d steps ago, so the buffer holds max_delay + 1 rows: the current step plus
# the max_delay steps before it. Row (i % n_rows) holds step i's activity.
max_delay = int(delays_steps.max())
history = jnp.zeros((max_delay + 1, distances.shape[0]))


def step_with_delay(i, hist):
    """Return the delayed activity seen through every connection at step i.

    Before calling this, write step i's activity into the buffer, e.g.
    ``hist = hist.at[i % hist.shape[0]].set(x)``; otherwise zero-step delays
    would read stale data. Rows not yet written still hold zeros, so early
    steps see zero activity from the "past".

    Returns an (N, N) array whose entry [t, s] is the activity of region s
    from ``delays_steps[t, s]`` steps ago (assuming row t is the receiving
    and column s the sending region). A delayed coupling input for region t
    is then, for example, ``(W * delayed).sum(axis=1)``.
    """
    n_rows = hist.shape[0]
    # Buffer row to read for each connection; % wraps negative indices
    slots = (i - delays_steps) % n_rows
    # Broadcasts over rows, so column s of the result reads source region s
    sources = jnp.arange(hist.shape[1])
    return hist[slots, sources]
Heterogeneous Networks#
Different Models per Region#
# Thalamus: fast oscillators
N_thal = 10
thalamus = brainmass.HopfOscillator(
    in_size=N_thal,
    omega=2 * jnp.pi * 40 * u.Hz,  # 40 Hz
)
# Cortex: excitatory-inhibitory dynamics
N_cort = 80
cortex = brainmass.WilsonCowanModel(in_size=N_cort)
# Allocate the state variables of both subsystems before simulating
thalamus.init_all_states()
cortex.init_all_states()
# Coupling between subsystems (row = receiving region, column = sending region)
W_thal_cort = jnp.ones((N_cort, N_thal)) * 0.1  # thalamus → cortex
W_cort_thal = jnp.ones((N_thal, N_cort)) * 0.05  # cortex → thalamus


def hetero_network_step(i):
    """Advance thalamus and cortex by one step; return the cortical output.

    ``i`` is the loop index (unused). Each cortical region receives its own
    entry of ``W_thal_cort @ thal_out`` (and each thalamic region its own
    entry of the feedback) instead of a single averaged value, so
    heterogeneous connection matrices are respected.
    """
    # Thalamus dynamics
    thal_out = thalamus.update()
    # Cortex receives thalamic input: one value per cortical region
    # (assumes thal_out has shape (N_thal,) — TODO confirm)
    thal_drive = W_thal_cort @ thal_out
    cort_out = cortex.update(rE_inp=thal_drive, rI_inp=0.)
    # Thalamus receives cortical feedback: one value per thalamic region
    cort_feedback = W_cort_thal @ cortex.rE.value
    thalamus.x.value += cort_feedback * 0.1
    return cort_out
Region-Specific Parameters#
# Different parameters for each region
N = 90
# Example: heterogeneous excitability, one value per region in [0, 0.2)
a_key = jax.random.PRNGKey(0)
a_values = 0.2 * jax.random.uniform(a_key, (N,))
# Applying these per region needs either a custom implementation or a
# batched model constructed with different parameters
Network Analysis#
Computing Functional Connectivity#
# Simulate network (network_step as defined in the simple network example)
activity = brainstate.transform.for_loop(network_step, jnp.arange(10000))  # shape (T, N)
# Compute FC (Pearson correlation between regional time series).
# jnp.corrcoef subtracts each series' mean itself, so no manual centering is
# needed; rowvar=False treats each column (one region) as a variable.
FC = jnp.corrcoef(activity, rowvar=False)  # shape (N, N)
# Visualize
import matplotlib.pyplot as plt
plt.imshow(FC, cmap='RdBu_r', vmin=-1, vmax=1)
plt.colorbar()
plt.title('Functional Connectivity')
plt.show()  # required to display the figure when run as a script
Network Synchrony#
# Kuramoto order parameter
def kuramoto_order(phases, axis=None):
    """Measure global synchronization, R = |mean_j exp(i * phase_j)|.

    Args:
        phases: array of oscillator phases in radians.
        axis: axis over which oscillators are averaged. The default ``None``
            averages over all elements. Pass e.g. ``axis=-1`` for a ``(T, N)``
            phase time series to get one R per time step.

    Returns:
        R in [0, 1]: 1 for perfectly aligned phases, near 0 for incoherent
        (uniformly spread) phases.
    """
    z = jnp.mean(jnp.exp(1j * phases), axis=axis)
    return jnp.abs(z)
# For Kuramoto network
kuramoto_net = brainmass.KuramotoNetwork(in_size=100, omega_mean=10*u.Hz)
kuramoto_net.init_all_states()

# Record the global order parameter of the current phases, then step forward
order_params = []
for _step in range(1000):
    order_params.append(kuramoto_order(kuramoto_net.theta.value))
    kuramoto_net.update()
Complete Network Example#
Whole-Brain Resting-State Simulation#
import brainmass
import jax.numpy as jnp
import brainunit as u
import brainstate

# Parameters
N_regions = 90  # AAL90 atlas
coupling_strength = 0.2
simulation_time = 600  # seconds
dt = 1 * u.ms
# Make the models integrate with the same step size that T_steps assumes
brainstate.environ.set(dt=dt)
# round() before int(): the unit conversion is floating point, and int()
# alone would truncate a value such as 599999.9999 down to 599999
T_steps = int(round((simulation_time * u.second / dt).magnitude))

# Load structural connectivity
SC = jnp.load('AAL90_SC_normalized.npy')

# Create components
nodes = brainmass.WongWangModel(in_size=N_regions)
coupling = brainmass.DiffusiveCoupling(conn=SC, k=coupling_strength)

# Add noise for spontaneous activity
nodes.noise_E = brainmass.OUProcess(
    in_size=N_regions,
    sigma=0.01 * u.Hz,
    tau=100 * u.ms,
)

# BOLD forward model
bold = brainmass.BOLDSignal(in_size=N_regions)

# Initialize
nodes.init_all_states()
coupling.init_all_states()
bold.init_all_states()


def neural_step(t):
    """One network step: couple the current synaptic activity, advance nodes."""
    S_E = nodes.S_E.value  # current synaptic activity
    coupled_input = coupling(S_E, S_E)
    return nodes.update(S_E_ext=coupled_input)


def bold_step(z):
    """Feed one sample of neural activity to the BOLD model; return the BOLD value."""
    bold.update(z=z)
    return bold.bold()


# Simulate. Compiled loops replace Python for-loops: hundreds of thousands
# of eagerly dispatched steps would make this simulation impractically slow.
print("Running simulation...")
neural_activity = brainstate.transform.for_loop(neural_step, jnp.arange(T_steps))

# Generate BOLD (iterates over the leading, time axis of neural_activity)
print("Generating BOLD signal...")
bold_ts = brainstate.transform.for_loop(bold_step, neural_activity)

# Downsample to TR = 2s
TR_steps = int(round((2 * u.second / dt).magnitude))
bold_downsampled = bold_ts[::TR_steps]

# Compute FC
FC_sim = jnp.corrcoef(bold_downsampled.T)
print(f"Simulated FC shape: {FC_sim.shape}")
Best Practices#
Start Small: Test with N=10-20 regions before scaling to N=90+
Normalize Connectivity: Raw SC weights (e.g. fiber counts) can be large enough to destabilize the coupled dynamics; normalize them first
Monitor Dynamics: Plot time series to check for explosions/collapse
Use Noise: Spontaneous fluctuations keep the system from settling into a silent fixed point
Check Timescales: Choose dt well below the fastest timescale in the network
Profile Performance: Use JAX profiling for large networks
Common Issues#
Exploding Activity:
Reduce coupling strength k
Normalize connectivity matrix
Check for positive feedback loops
No Synchronization:
Increase coupling strength
Check connectivity topology
Ensure sufficient simulation time
Slow Simulation:
Use JIT compilation instead of Python for-loops (e.g. brainstate.transform.for_loop for time loops)
Reduce number of regions for testing
Use simpler models (e.g. Hopf instead of Jansen-Rit)
Next Steps#
Adding Coupling - Advanced coupling mechanisms
Forward Modeling - Map network activity to neuroimaging signals
Parameter Fitting - Optimize network parameters
See Also#
Coupling Mechanisms - Coupling API reference
Neural Mass Models - Node model options
Examples - Network simulation examples