Parameter Optimization with SciPy#
SciPy has long provided a broad set of well-tested optimization algorithms. This notebook shows how to use them, with either gradient-based or gradient-free methods, to fit a whole-brain network model to empirical functional connectivity (FC).
Goal: tune global coupling k and noise sigma of a Wilson–Cowan network so simulated FC matches a target FC.
Loss: 1 - corr(FC_target, FC_model).
We JIT-compile the loss for speed (optional) and call a SciPy optimizer.
import brainstate
import brainunit as u
import jax.numpy as jnp
import numpy as np
import brainmass
import braintools
brainstate.environ.set(dt=0.1 * u.ms)
import os.path
import kagglehub
# Download the sample dataset and load the msgpack file.
path = kagglehub.dataset_download("oujago/hcp-gw-data-samples")
data = braintools.file.msgpack_load(os.path.join(path, "hcp-data-sample.msgpack"))
# One FC matrix per BOLD scan, then the mean across scans.
target_fc = [braintools.metric.functional_connectivity(x.T) for x in data['BOLDs']]
target_fc = jnp.mean(jnp.asarray(target_fc), axis=0)
Loading checkpoint from D:\Data\kagglehub\datasets\oujago\hcp-gw-data-samples\versions\1\hcp-data-sample.msgpack
Data and Target FC#
We download a small HCP sample via kagglehub. It provides structural connectivity (Cmat), inter-region distances (Dmat), and BOLD time series (BOLDs). For each BOLD time series, we compute FC and then average across scans to obtain target_fc.
Cmat: weights used for coupling.
Dmat: distances converted to delays using a signal speed.
target_fc: average empirical FC used by the loss.
If fetching fails, point to a local msgpack file, or replace with your own FC target.
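The averaging step described above can be sketched in plain NumPy. This is a minimal illustration, assuming FC means the Pearson correlation between regional time series; whether braintools.metric.functional_connectivity computes exactly this is an assumption. The helper pearson_fc and the random bolds list are hypothetical stand-ins for the library call and for data['BOLDs'].

```python
import numpy as np

def pearson_fc(ts):
    """Pearson correlation between regions; ts has shape (time, regions)."""
    return np.corrcoef(ts, rowvar=False)

# Hypothetical stand-in for data['BOLDs']: 3 scans, each (regions, time) = (80, 200).
rng = np.random.default_rng(0)
bolds = [rng.standard_normal((80, 200)) for _ in range(3)]

# One FC matrix per scan (transposed to time-major), then the mean across scans.
fc_per_scan = np.stack([pearson_fc(x.T) for x in bolds])
target_fc_demo = fc_per_scan.mean(axis=0)

print(target_fc_demo.shape)  # (80, 80)
```

To use your own target, any symmetric regions-by-regions matrix built this way can take the place of target_fc, provided its size matches the number of nodes in the model.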
class Network(brainstate.nn.Module):
    def __init__(self, signal_speed=2., k=1., sigma=0.01):
        super().__init__()
        # Structural connectivity without self-connections.
        conn_weight = data['Cmat'].copy()
        np.fill_diagonal(conn_weight, 0)
        # Delays from distances and signal speed; no self-delay.
        delay_time = data['Dmat'].copy() / signal_speed
        np.fill_diagonal(delay_time, 0)
        # indices_[i, j] = j: the column index for every (i, j) pair.
        indices_ = np.arange(conn_weight.shape[1])
        indices_ = np.tile(np.expand_dims(indices_, axis=0), (conn_weight.shape[0], 1))
        # 80 Wilson-Cowan nodes with OU noise on both populations.
        self.node = brainmass.WilsonCowanStep(
            80,
            noise_E=brainmass.OUProcess(80, sigma=sigma, init=braintools.init.ZeroInit()),
            noise_I=brainmass.OUProcess(80, sigma=sigma, init=braintools.init.ZeroInit()),
        )
        # Diffusive coupling on rE: delayed activity versus current activity.
        self.coupling = brainmass.DiffusiveCoupling(
            self.node.prefetch_delay('rE', (delay_time * u.ms, indices_), init=braintools.init.Uniform(0, 0.05)),
            self.node.prefetch('rE'),
            conn_weight,
            k=k
        )

    def update(self):
        current = self.coupling()
        rE = self.node(current)
        return rE

    def step_run(self, i):
        # Expose the step index and the current time to the model.
        with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):
            return self.update()
def simulation(k, sigma):
    net = Network(k=k, sigma=sigma)
    brainstate.nn.init_all_states(net)
    # Step indices for a 1000 ms run at the configured dt.
    indices = np.arange(0, 1e3 * u.ms // brainstate.environ.get_dt())
    exes = brainstate.transform.for_loop(net.step_run, indices)
    # FC of the simulated excitatory activity, compared with the target.
    fc = braintools.metric.functional_connectivity(exes)
    return braintools.metric.matrix_correlation(target_fc, fc)
Model and Coupling#
We simulate 80 Wilson–Cowan nodes with Ornstein–Uhlenbeck (OU) noise on the excitatory (E) and inhibitory (I) populations. Diffusive coupling is applied to the excitatory rate rE via DiffusiveCoupling:
Global gain k scales Cmat.
Delays are Dmat / signal_speed and handled with prefetch_delay.
The module’s update returns rE, used for FC computation.
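To make the coupling term concrete, here is a NumPy sketch of a delayed diffusive input in its common textbook form, k * sum_j W[i, j] * (x_j(t - d_ij) - x_i(t)). It is an assumption that brainmass.DiffusiveCoupling uses exactly this form. The helpers delay_steps and diffusive_input, the history buffer, and the random matrices are all hypothetical and exist only for illustration.

```python
import numpy as np

def delay_steps(dmat, signal_speed, dt):
    """Convert distances to integer delays, counted in time steps of length dt."""
    delay_time = dmat / signal_speed              # same time unit as dt
    return np.rint(delay_time / dt).astype(int)

def diffusive_input(history, weights, steps, k):
    """Input to node i: k * sum_j W[i, j] * (x_j(t - d_ij) - x_i(t)).

    history has shape (max_delay + 1, n_nodes); history[-1] is the current state.
    """
    n = weights.shape[0]
    cols = np.tile(np.arange(n), (n, 1))          # cols[i, j] = j
    delayed = history[-1 - steps, cols]           # delayed[i, j] = x_j(t - d_ij)
    current = history[-1][:, None]                # x_i(t), broadcast over j
    return k * np.sum(weights * (delayed - current), axis=1)

# Hypothetical 4-node example.
rng = np.random.default_rng(0)
n = 4
dmat = rng.uniform(10.0, 50.0, size=(n, n))
np.fill_diagonal(dmat, 0.0)
weights = rng.uniform(0.0, 1.0, size=(n, n))
np.fill_diagonal(weights, 0.0)

steps = delay_steps(dmat, signal_speed=2.0, dt=0.1)
history = rng.uniform(0.0, 0.05, size=(steps.max() + 1, n))
inp = diffusive_input(history, weights, steps, k=1.0)
print(inp.shape)  # (4,)
```

The cols array plays the same role as indices_ in the Network class: for each pair (i, j) it names the source node whose delayed activity is read. In this form the input vanishes when all nodes have held the same value throughout the delay window, and it scales linearly with k.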
@brainstate.transform.jit
def loss_fn(arr):
    k, sigma = arr
    return 1 - simulation(k, sigma)
Simulation and Loss#
The simulation runs the network for a short window (1000 ms), computes FC from the excitatory activity, and returns its correlation with target_fc. The loss is 1 - correlation, so lower is better. We pack the two parameters (k, sigma) into a single array to match SciPy's API, and optionally JIT-compile the loss for speed.
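The loss itself is simple enough to write out. The sketch below assumes the matrix correlation is the Pearson correlation between the upper-triangular entries of the two FC matrices; whether braintools.metric.matrix_correlation includes or excludes the diagonal is not shown here, so treat that detail as an assumption. The function fc_loss is a hypothetical name.

```python
import numpy as np

def fc_loss(fc_target, fc_model):
    """1 - Pearson correlation between the upper triangles of two FC matrices."""
    iu = np.triu_indices_from(fc_target, k=1)     # off-diagonal entries only
    r = np.corrcoef(fc_target[iu], fc_model[iu])[0, 1]
    return 1.0 - r

# Hypothetical FC matrices built from random time series of shape (time, regions).
rng = np.random.default_rng(0)
fc_a = np.corrcoef(rng.standard_normal((200, 6)), rowvar=False)
fc_b = np.corrcoef(rng.standard_normal((200, 6)), rowvar=False)

print(fc_loss(fc_a, fc_a))   # close to 0: identical matrices
print(fc_loss(fc_a, fc_b))   # somewhere in [0, 2]
```

Because the correlation lies in [-1, 1], this loss lies in [0, 2]. A value above 1 means the simulated FC is negatively correlated with the target.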
opt = braintools.optim.ScipyOptimizer(
    loss_fn, bounds=[(0.5, 3.0), (0.0, 1.)], method='L-BFGS-B'
)
best_r = opt.minimize(n_iter=1)
print(best_r)
message: CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
success: True
status: 0
fun: 1.0078195333480835
x: [ 5.000e-01 1.000e+00]
nit: 1
jac: [-9.956e-02 1.753e-02]
nfev: 6
njev: 6
hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>
SciPy Optimizer Setup#
We use braintools.optim.ScipyOptimizer as a thin wrapper around scipy.optimize.minimize. Bounds and method (e.g., L-BFGS-B, Nelder-Mead) can be customized. Increase n_iter or switch methods for robustness.
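For orientation, this is roughly what a bounded L-BFGS-B call looks like when made directly against scipy.optimize.minimize. It is a sketch on a hypothetical smooth toy_loss, not the notebook's loss_fn, and it makes no claim about how ScipyOptimizer chooses its starting point or interprets n_iter.

```python
import numpy as np
from scipy.optimize import minimize

def toy_loss(params):
    """Hypothetical smooth stand-in for loss_fn; minimum at k=1.2, sigma=0.3."""
    k, sigma = params
    return (k - 1.2) ** 2 + 4.0 * (sigma - 0.3) ** 2

bounds = [(0.5, 3.0), (0.0, 1.0)]   # same bounds as used for (k, sigma) above
x0 = np.array([2.0, 0.8])           # arbitrary starting point inside the bounds

# No jac is given, so SciPy estimates gradients by finite differences.
result = minimize(toy_loss, x0, method="L-BFGS-B", bounds=bounds)
print(result.x, result.fun)
```

The returned object has the same fields as the printed result above (x, fun, nit, nfev, and so on), which makes it easy to compare runs across methods.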
Tips:
Start with short simulations for fast iterations, then refine.
Consider multiple random restarts.
If gradients are unreliable (stochastic noise), prefer derivative-free methods (Nelder-Mead, Powell).
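The last two tips can be combined: a derivative-free method run from several random starting points. The sketch below uses a hypothetical noisy_loss in place of the stochastic simulation. The number of restarts, the tolerances, and the noise level are arbitrary choices, and passing bounds to Nelder-Mead assumes SciPy 1.7 or newer.

```python
import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(0)

def noisy_loss(params):
    """Hypothetical stand-in for a stochastic loss: smooth bowl plus small noise."""
    k, sigma = params
    clean = (k - 1.2) ** 2 + 4.0 * (sigma - 0.3) ** 2
    return clean + 1e-3 * rng.standard_normal()

bounds = [(0.5, 3.0), (0.0, 1.0)]
lows, highs = np.array(bounds).T

best = None
for _ in range(5):
    x0 = rng.uniform(lows, highs)                 # random start inside the bounds
    res = minimize(noisy_loss, x0, method="Nelder-Mead", bounds=bounds,
                   options={"xatol": 1e-2, "fatol": 1e-2, "maxfev": 200})
    if best is None or res.fun < best.fun:
        best = res                                # keep the lowest loss seen so far

print(best.x, best.fun)
```

Since each reported fun is itself a noisy sample, it can help to re-evaluate the final candidates several times and compare their mean losses before picking a winner.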
Notes#
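Check that JAX is running on the backend you intend (CPU, GPU, or TPU). The first call to a jit-compiled loss includes compilation time, so the speedup shows up on later calls.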
Set seeds or fix noise initial states for reproducibility when comparing runs.
You can extend the loss to multi-objective (e.g., also matching power spectra).
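One way to extend the loss is a weighted sum of an FC term and a spectral term. The sketch below is only an illustration in NumPy: the function names, the weights w_fc and w_psd, and the choice to compare region-averaged, normalized power spectra are all assumptions rather than part of the notebook. A fixed seed keeps the example reproducible.

```python
import numpy as np

def power_spectrum(ts):
    """Region-averaged, normalized power spectrum; ts has shape (time, regions)."""
    spec = np.abs(np.fft.rfft(ts - ts.mean(axis=0), axis=0)) ** 2
    spec = spec.mean(axis=1)
    return spec / spec.sum()

def one_minus_corr(a, b):
    """1 - Pearson correlation between two flattened arrays."""
    return 1.0 - np.corrcoef(np.ravel(a), np.ravel(b))[0, 1]

def combined_loss(ts_model, ts_target, w_fc=1.0, w_psd=0.5):
    """Weighted sum of an FC mismatch term and a power-spectrum mismatch term."""
    fc_model = np.corrcoef(ts_model, rowvar=False)
    fc_target = np.corrcoef(ts_target, rowvar=False)
    iu = np.triu_indices_from(fc_target, k=1)
    fc_term = one_minus_corr(fc_target[iu], fc_model[iu])
    psd_term = one_minus_corr(power_spectrum(ts_target), power_spectrum(ts_model))
    return w_fc * fc_term + w_psd * psd_term

rng = np.random.default_rng(0)                    # fixed seed for reproducibility
ts_target = rng.standard_normal((500, 8))
ts_model = rng.standard_normal((500, 8))
print(combined_loss(ts_model, ts_target))
```

The weights set the trade-off between the two objectives. A sensible starting point is to scale them so both terms have a similar magnitude at the initial parameters.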