Parameter Fitting#

This tutorial covers optimizing model parameters to match empirical data.

Overview#

Parameter fitting workflow:

  1. Define a loss function comparing model output to data

  2. Choose an optimization method

  3. Run optimization to find best parameters

  4. Validate fitted parameters

Loss Functions#

Functional Connectivity Loss#

Most common for fMRI data:

import jax.numpy as jnp

def fc_loss(params, SC, FC_empirical):
    """Loss based on functional connectivity matching"""

    # Run simulation with params
    bold_sim = simulate_network(params, SC)

    # Compute simulated FC
    FC_sim = jnp.corrcoef(bold_sim.T)

    # Compare to empirical FC
    # Option 1: Correlation
    FC_corr = jnp.corrcoef(FC_sim.flatten(), FC_empirical.flatten())[0, 1]
    loss = 1.0 - FC_corr  # minimize (1 - correlation)

    # Option 2: MSE
    # loss = jnp.mean((FC_sim - FC_empirical) ** 2)

    return loss

Time Series Loss#

For EEG/MEG or other time-domain data:

def timeseries_loss(params, data_empirical):
    """Direct time series matching"""

    # Simulate
    data_sim = simulate_eeg(params)

    # MSE loss
    loss = jnp.mean((data_sim - data_empirical) ** 2)

    return loss

Power Spectrum Loss#

Match frequency content:

from scipy import signal

def psd_loss(params, psd_empirical, freqs_empirical):
    """Match power spectral density"""

    # Simulate and compute PSD
    ts_sim = simulate(params)
    freqs_sim, psd_sim = signal.welch(ts_sim, fs=1000)

    # Interpolate to match empirical frequencies
    psd_sim_interp = jnp.interp(freqs_empirical, freqs_sim, psd_sim)

    # Log-space MSE (better for power spectra)
    loss = jnp.mean((jnp.log(psd_sim_interp) - jnp.log(psd_empirical)) ** 2)

    return loss

Optimization Methods#

Gradient-Free (Nevergrad)#

Best for non-differentiable objectives:

import nevergrad as ng
import brainmass
import jax.numpy as jnp

# Load data
SC = jnp.load('SC.npy')
FC_emp = jnp.load('FC_empirical.npy')

# Define simulation function
def simulate_network(coupling_strength):
    nodes = brainmass.WongWangModel(in_size=90)
    coupling = brainmass.DiffusiveCoupling(conn=SC, k=coupling_strength)
    bold = brainmass.BOLDSignal(in_size=90)

    nodes.init_all_states()
    coupling.init_all_states()
    bold.init_all_states()

    # Simulate (simplified)
    neural_ts = []
    for t in range(10000):
        S_E = nodes.S_E.value
        coupled = coupling(S_E, S_E)
        out = nodes.update(S_E_ext=coupled)
        neural_ts.append(out)

    neural_ts = jnp.stack(neural_ts)

    # Generate BOLD
    bold_ts = []
    for z in neural_ts:
        bold.update(z=z)
        bold_ts.append(bold.bold())

    return jnp.stack(bold_ts)[::2000]  # downsample

# Define loss
def objective(coupling_strength):
    bold_sim = simulate_network(coupling_strength)
    FC_sim = jnp.corrcoef(bold_sim.T)
    corr = jnp.corrcoef(FC_sim.flatten(), FC_emp.flatten())[0, 1]
    return 1.0 - corr  # minimize

# Setup optimization
instrum = ng.p.Scalar(init=0.2, lower=0.0, upper=1.0)
optimizer = ng.optimizers.NGOpt(parametrization=instrum, budget=50)

# Run optimization
for i in range(50):
    x = optimizer.ask()
    loss = objective(x.value)
    optimizer.tell(x, loss)
    print(f"Iteration {i}: k={x.value:.3f}, loss={loss:.4f}")

# Best parameters
recommendation = optimizer.provide_recommendation()
best_k = recommendation.value
print(f"Best coupling strength: {best_k:.3f}")

Gradient-Based (JAX + Optax)#

For differentiable models:

import jax
import jax.numpy as jnp
import optax
import brainmass

# Define differentiable simulation
@jax.jit
def simulate_and_loss(k, SC, FC_emp):
    # Simplified differentiable simulation
    # (full network simulation needs careful state management)

    # ... simulation code ...

    FC_sim = jnp.corrcoef(bold_sim.T)
    loss = jnp.mean((FC_sim - FC_emp) ** 2)

    return loss

# Setup optimizer
learning_rate = 1e-3
optimizer = optax.adam(learning_rate)

# Initial parameters
params = {'k': 0.2}
opt_state = optimizer.init(params)

# Training loop
for epoch in range(100):
    # Compute gradients
    loss, grads = jax.value_and_grad(simulate_and_loss)(
        params['k'], SC, FC_emp
    )

    # Update parameters
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    if epoch % 10 == 0:
        print(f"Epoch {epoch}: k={params['k']:.3f}, loss={loss:.4f}")

SciPy Optimizers#

Good for moderate-dimensional problems:

from scipy.optimize import minimize
import jax.numpy as jnp

def objective(params):
    k, noise_sigma = params
    # ... simulate ...
    loss = fc_loss(k, noise_sigma, FC_emp)
    return loss

# Initial guess
x0 = [0.2, 0.01]

# Optimize
result = minimize(
    objective,
    x0,
    method='Nelder-Mead',
    options={'maxiter': 100}
)

best_params = result.x
print(f"Best parameters: k={best_params[0]:.3f}, sigma={best_params[1]:.4f}")

Multi-Parameter Optimization#

Optimizing Multiple Parameters#

import nevergrad as ng

# Define parameter space
instrum = ng.p.Instrumentation(
    coupling_k=ng.p.Scalar(init=0.2, lower=0.0, upper=1.0),
    noise_sigma=ng.p.Scalar(init=0.01, lower=0.0, upper=0.1),
    tau_E=ng.p.Scalar(init=10.0, lower=5.0, upper=50.0),
    tau_I=ng.p.Scalar(init=20.0, lower=10.0, upper=100.0),
)

optimizer = ng.optimizers.NGOpt(parametrization=instrum, budget=200)

def multi_param_objective(coupling_k, noise_sigma, tau_E, tau_I):
    # Simulate with all parameters
    # ...
    return loss

# Optimize
for i in range(200):
    x = optimizer.ask()
    loss = multi_param_objective(**x.kwargs)
    optimizer.tell(x, loss)

Constrained Parameters#

Use ArrayParam for constraints:

import brainmass
import braintools

# Create constrained parameter (must be positive)
tau_param = brainmass.ArrayParam(
    data=10.0,  # initial value
    transform=braintools.SoftplusTransform(),
)

# Optimize in unconstrained space
def objective(tau_unconstrained):
    tau_param.value = tau_unconstrained
    tau_actual = tau_param.data  # always positive

    # Use tau_actual in simulation
    # ...
    return loss

Best Practices#

  1. Start Simple - Fit one parameter at a time initially - Add complexity gradually

  2. Use Multiple Initializations - Run optimization from different starting points - Avoid local minima

  3. Normalize Loss - Scale loss components to similar magnitude - Prevents one term dominating

  4. Validate on Hold-Out Data - Test fitted parameters on unseen data - Check for overfitting

  5. Monitor Convergence - Plot loss vs iteration - Stop when loss plateaus

  6. Set Reasonable Bounds - Use prior knowledge for parameter ranges - Prevents unphysical values

Common Issues#

Optimization Stuck in Local Minimum:

  • Use global optimizer (Nevergrad)

  • Try multiple initializations

  • Increase budget

Loss Not Decreasing:

  • Check loss function implementation

  • Verify simulation is correct

  • Reduce learning rate (gradient-based)

Parameters Unrealistic:

  • Add constraints/bounds

  • Use regularization

  • Check units

Slow Optimization:

  • Reduce simulation time

  • Use simpler model for initial fits

  • Parallelize evaluations (Nevergrad supports this)

Complete Example#

Full FC-Based Parameter Fitting#

import brainmass
import jax.numpy as jnp
import brainstate
import nevergrad as ng

# Load data
SC = jnp.load('SC_AAL90.npy')
FC_emp = jnp.load('FC_empirical.npy')

# Simulation function
def simulate_bold_fc(coupling_k, noise_sigma):
    N = 90
    T = 600000  # 10 min at 1ms

    # Create network
    nodes = brainmass.WongWangModel(in_size=N)
    coupling = brainmass.DiffusiveCoupling(conn=SC, k=coupling_k)
    nodes.noise_E = brainmass.OUProcess(
        in_size=N,
        sigma=noise_sigma,
        tau=100.0,
    )
    bold = brainmass.BOLDSignal(in_size=N)

    # Initialize
    nodes.init_all_states()
    coupling.init_all_states()
    bold.init_all_states()

    # Simulate
    neural_ts = []
    for t in range(T):
        S_E = nodes.S_E.value
        coupled = coupling(S_E, S_E)
        out = nodes.update(S_E_ext=coupled)
        neural_ts.append(out)

        if t % 10000 == 0:
            print(f"  Sim: {t}/{T}")

    neural_ts = jnp.stack(neural_ts)

    # BOLD
    bold_ts = []
    for z in neural_ts:
        bold.update(z=z)
        bold_ts.append(bold.bold())

    bold_downsampled = jnp.stack(bold_ts)[::2000]

    # FC
    FC_sim = jnp.corrcoef(bold_downsampled.T)

    return FC_sim

# Loss function
def objective(coupling_k, noise_sigma):
    print(f"Trying k={coupling_k:.3f}, sigma={noise_sigma:.4f}")

    FC_sim = simulate_bold_fc(coupling_k, noise_sigma)
    corr = jnp.corrcoef(FC_sim.flatten(), FC_emp.flatten())[0, 1]
    loss = 1.0 - corr

    print(f"  Loss: {loss:.4f}, FC corr: {corr:.4f}")
    return loss

# Optimization
instrum = ng.p.Instrumentation(
    coupling_k=ng.p.Scalar(init=0.2, lower=0.0, upper=1.0),
    noise_sigma=ng.p.Scalar(init=0.01, lower=0.0, upper=0.1),
)

optimizer = ng.optimizers.NGOpt(parametrization=instrum, budget=50)

for i in range(50):
    print(f"\n=== Iteration {i+1}/50 ===")
    x = optimizer.ask()
    loss = objective(**x.kwargs)
    optimizer.tell(x, loss)

# Best result
best = optimizer.provide_recommendation()
print(f"\nBest parameters:")
print(f"  Coupling k: {best.kwargs['coupling_k']:.3f}")
print(f"  Noise sigma: {best.kwargs['noise_sigma']:.4f}")

Next Steps#

  • Try parameter fitting on your own data

  • Explore different loss functions

  • Compare optimization methods

  • Examples for advanced optimization examples

See Also#

  • Forward Modeling - Getting observable data from models

  • Nevergrad documentation

  • Optax documentation