Creating an Objective#

An objective scores how well a simulated trajectory matches data. brainmass ships a toolkit in brainmass.objectives (timeseries_rmse, fc_corr, fcd_*, …). This guide shows how to write your own objective so it behaves exactly like the built-ins: composable, unit-aware, and usable with Fitter across all three optimizer backends (gradient, Nevergrad, SciPy).

If you only need to combine existing objectives, see Compose a Custom Objective. This guide covers the authoring contract for writing a new one.

The contract#

An objective is a builder: a function that takes configuration and returns a small callable with the signature (prediction, target) -> scalar. The returned callable must be:

  • pure and traced-array-friendly – built from jax.numpy / brainunit ops so it survives jit, grad, and vmap (the three backends each need a different one of these).

  • unit-aware – strip units with brainunit.get_magnitude where the metric is scale-invariant (correlations, cosine, a variance ratio); keep them on a difference you want unit-checked (subtracting mV from Hz should raise).

  • a single scalar – the optimizers minimise a scalar.

A prediction / target is a (time, regions) trajectory – the natural output of Simulator.run. By convention a builder takes as_loss= so the same metric can be a score to maximise or 1 - score (or its negative) to minimise.
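The as_loss convention can be sketched with a score-type metric. The builder below is hypothetical (pearson_score is not a brainmass function) and assumes plain jax.numpy inputs with no units; it returns the correlation as a score to maximise, or 1 - score as a loss to minimise.

```python
import jax.numpy as jnp


def pearson_score(as_loss=True):
    """Hypothetical builder: Pearson correlation over all entries."""
    def objective(prediction, target):
        p = prediction - jnp.mean(prediction)
        t = target - jnp.mean(target)
        r = jnp.sum(p * t) / jnp.sqrt(jnp.sum(p ** 2) * jnp.sum(t ** 2))
        # Same metric, two directions: a loss to minimise or a score to maximise.
        return 1.0 - r if as_loss else r
    return objective


# A (time, regions) trajectory with three regions of different amplitude.
x = jnp.sin(jnp.linspace(0.0, 6.0, 100))[:, None] * jnp.array([1.0, 0.5, 2.0])
score = pearson_score(as_loss=False)(x, x)   # identical inputs: score near 1
loss = pearson_score(as_loss=True)(x, x)     # identical inputs: loss near 0
```

Both callables return a zero-dimensional array, which is what an optimizer expects to minimise.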

import jax.numpy as jnp
import brainunit as u


def variance_match(as_loss=True):
    """Match the overall temporal variance of two signals.

    Units are stripped, so inputs in any unit are accepted, but the value
    is sensitive to amplitude. This is a builder, like brainmass.objectives.*:
    it returns the objective callable.
    """
    def objective(prediction, target):
        var_p = jnp.var(u.get_magnitude(prediction))
        var_t = jnp.var(u.get_magnitude(target))
        d = (var_p - var_t) ** 2
        return d if as_loss else -d
    return objective


# It is unit-aware: mV and Hz inputs both work, identity gives 0.
loss = variance_match()
x_mV = jnp.ones((50, 3)) * u.mV
print('identity loss (mV) :', float(loss(x_mV, x_mV)))
x_Hz = (jnp.ones((50, 3)) * 2.0) * u.Hz
print('identity loss (Hz) :', float(loss(x_Hz, x_Hz)))
identity loss (mV) : 0.0
identity loss (Hz) : 0.0
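A quick way to check the first point of the contract is to push the callable through jit, grad, and vmap directly. The sketch below assumes a unit-free stand-in, variance_match_plain, which is hypothetical and mirrors the builder above without the brainunit stripping, so it runs on plain jax.numpy arrays.

```python
import jax
import jax.numpy as jnp


def variance_match_plain(as_loss=True):
    # Unit-free stand-in for variance_match: inputs are plain arrays here.
    def objective(prediction, target):
        d = (jnp.var(prediction) - jnp.var(target)) ** 2
        return d if as_loss else -d
    return objective


loss = variance_match_plain()
t = jnp.linspace(0.0, 10.0, 200)[:, None]
target = jnp.sin(t) * jnp.ones((1, 3))      # (time, regions)
pred = 1.5 * target                          # same shape, larger amplitude

jit_val = jax.jit(loss)(pred, target)        # compiles and returns a scalar
grad_val = jax.grad(loss)(pred, target)      # gradient w.r.t. the prediction
batch = jnp.stack([pred, target])            # (batch, time, regions)
vmap_val = jax.vmap(loss, in_axes=(0, None))(batch, target)  # one loss per batch entry
```

If any of the three calls fails, the objective usually contains a Python-side branch on array values or a non-JAX operation.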

It composes with the built-ins#

Because a custom objective has the exact (prediction, target) -> scalar shape, brainmass.objectives.combine mixes it with built-ins as a weighted sum – a common pattern when fitting to several features at once (e.g. FC correlation and an amplitude term).

import numpy as np
from brainmass import objectives

mixed = objectives.combine(
    (1.0, objectives.fc_corr(as_loss=True)),   # match functional connectivity
    (0.5, variance_match(as_loss=True)),       # ... and overall amplitude
)
rng = np.random.default_rng(0)
a = jnp.asarray(rng.standard_normal((200, 4)))
print('combined loss, identity :', round(float(mixed(a, a)), 6))   # both terms 0
b = a * 1.5 + 0.2
print('combined loss, perturbed:', round(float(mixed(a, b)), 6))
combined loss, identity : 0.0
combined loss, perturbed: 0.782201
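The weighted-sum idea itself is small enough to sketch in plain JAX. The helper weighted_sum and the two toy metrics below are hypothetical and are not the brainmass.objectives.combine implementation; they only illustrate why any (prediction, target) -> scalar callable can be mixed with any other.

```python
import jax.numpy as jnp


def weighted_sum(*terms):
    """Hypothetical helper: terms are (weight, objective) pairs."""
    def objective(prediction, target):
        return sum(w * fn(prediction, target) for w, fn in terms)
    return objective


def mse(prediction, target):
    # Pointwise mean squared error.
    return jnp.mean((prediction - target) ** 2)


def mean_gap(prediction, target):
    # Squared difference of the overall means.
    return (jnp.mean(prediction) - jnp.mean(target)) ** 2


mixed = weighted_sum((1.0, mse), (0.5, mean_gap))
a = jnp.arange(12.0).reshape(4, 3)
b = a + 2.0
total = mixed(a, b)   # 1.0 * 4.0 + 0.5 * 4.0 = 6.0
```

Because every term returns a scalar, the result is again a valid objective and can itself be passed to another combination.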

It works across all three Fitter backends#

The payoff of the contract: write the objective once, swap the backend. The objective is the same callable for all three; only the optimizer argument and the model’s parameter bounds differ.

  • grad differentiates through the objective – needs it to be a smooth jax function (ours is).

  • nevergrad / scipy are derivative-free and search a bounded box. They derive that box from the trainable Param’s transform, so the fitted parameter needs a finite transform interval – SigmoidT(lower, upper) gives one.
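Why a sigmoid transform yields a finite search box can be sketched as follows. The functions to_bounded and to_unbounded are hypothetical and assume the usual form lower + (upper - lower) * sigmoid(z); the actual SigmoidT implementation may differ in detail.

```python
import jax
import jax.numpy as jnp


def to_bounded(z, lower, upper):
    # Assumed form of a sigmoid interval transform (not taken from brainstate).
    return lower + (upper - lower) * jax.nn.sigmoid(z)


def to_unbounded(a, lower, upper):
    # Inverse (logit) map from the interval back to the real line.
    p = (a - lower) / (upper - lower)
    return jnp.log(p) - jnp.log1p(-p)


lower, upper = 0.1, 2.0
z = jnp.array([-20.0, 0.0, 20.0])
a = to_bounded(z, lower, upper)                    # never leaves the interval
z0 = to_unbounded(jnp.array(0.3), lower, upper)    # unconstrained value for a = 0.3
a0 = to_bounded(z0, lower, upper)                  # round trip back to about 0.3
```

A gradient optimizer can move freely in z, while a derivative-free search can read (lower, upper) off the transform and use it as its box.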

We fit the Hopf bifurcation parameter a so its settled limit-cycle variance matches a target generated at a* = 1.0.

import brainmass
import braintools
from brainstate.nn import SigmoidT

def make_model(a0=0.3):
    # SigmoidT(0.1, 2.0) -> a bounded, trainable a in [0.1, 2.0]; the nonzero
    # init_x kicks the state off the unstable fixed point so the limit cycle
    # actually has amplitude.
    return brainmass.HopfStep(
        3, a=Param(a0, t=SigmoidT(0.1, 2.0)), w=0.3,
        init_x=braintools.init.Constant(0.5),
    )

def predict(m):
    sim = brainmass.Simulator(m, dt=0.1 * u.ms)
    return sim.run(200. * u.ms, monitors=['x'], transient=50 * u.ms)['x']

target = predict(make_model(1.0))      # ground truth at a* = 1.0
objective = variance_match(as_loss=True)
backends = [
    ('grad',      braintools.optim.Adam(lr=0.05), 40),
    ('scipy',     {'method': 'Nelder-Mead'},      4),
    ('nevergrad', {'method': 'DE', 'n_sample': 6}, 4),
]

results = {}
for backend, opt, n in backends:
    m = make_model(0.3)
    fitter = brainmass.Fitter(m, opt, objective=objective,
                              predict=predict, backend=backend)
    res = fitter.fit(target=target, n_steps=n)
    a_fit = float(list(res.best_params.values())[0])
    results[backend] = a_fit
    print(f'{backend:>9s}:  a = {a_fit:.4f}   best_loss = {res.best_loss:.3e}')

print('\ntrue a* = 1.0; all three recovered it from the SAME objective.')
     grad:  a = 0.9495   best_loss = 1.135e-03
    scipy:  a = 1.0000   best_loss = 5.288e-11
nevergrad:  a = 1.0042   best_loss = 4.530e-06

true a* = 1.0; all three recovered it from the SAME objective.

All three backends drive a from 0.3 to roughly 1.0 (about 0.95 for grad after 40 steps, 1.000 for scipy, 1.004 for nevergrad) using one objective callable. That is the whole point: the objective is decoupled from how it is optimised.
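The same decoupling can be reproduced without brainmass. The sketch below is a toy stand-in, not the Fitter: toy_predict is a hypothetical model whose oscillation amplitude is sqrt(a), and the two optimizers are plain gradient descent via jax.grad and SciPy's Nelder-Mead, both minimising one shared objective.

```python
import jax
import jax.numpy as jnp
from scipy.optimize import minimize


def variance_gap(prediction, target):
    # The shared objective: squared difference of overall variances.
    return (jnp.var(prediction) - jnp.var(target)) ** 2


# Toy stand-in for the simulator: an oscillation with amplitude sqrt(a),
# sampled over one full period and copied into 3 regions -> (time, regions).
phase = jnp.linspace(0.0, 2.0 * jnp.pi, 400, endpoint=False)


def toy_predict(a):
    amp = jnp.sqrt(jnp.maximum(a, 0.0))
    return amp * jnp.cos(phase)[:, None] * jnp.ones((1, 3))


target = toy_predict(1.0)   # ground truth at a = 1.0


def loss_of_a(a):
    return variance_gap(toy_predict(a), target)


# Gradient route: plain gradient descent on a, starting from 0.3.
grad_fn = jax.jit(jax.grad(loss_of_a))
a_grad = 0.3
for _ in range(30):
    a_grad = a_grad - 1.0 * float(grad_fn(a_grad))

# Derivative-free route: SciPy Nelder-Mead on the same loss, same start.
res = minimize(lambda v: float(loss_of_a(v[0])), x0=[0.3], method="Nelder-Mead")
a_nm = float(res.x[0])
```

Only the optimisation loop changes between the two routes; variance_gap is never edited.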

Notes for a gradient-friendly objective#

  • Smoothness matters for grad. A max/argmax (like a KS statistic) is non-smooth – usable for evaluation but a poor gradient loss. Prefer a smooth surrogate (an integral / Wasserstein-style distance) when the objective drives the gradient backend. The built-in fcd_ks vs fcd_wasserstein pair is exactly this trade-off.

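  • Fit a well-conditioned summary, not a raw oscillatory trace. Two trajectories with the same dynamics but a different phase can have a large pointwise error, so a trace-level loss is phase-degenerate (see Building a Data-Driven Workflow).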

  • Reuse braintools.metric rather than re-implementing metrics; the built-in objectives are thin wrappers over it (functional_connectivity, matrix_correlation, power_spectral_density, …).
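The smoothness point can be seen on a toy pair of distribution distances. The functions below are hypothetical illustrations on plain 1-D samples, not the fcd_ks or fcd_wasserstein implementations: a max-of-CDF-gap statistic built from hard thresholds, and a sorted-sample Wasserstein-style distance.

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(-1.0, 1.0, 101)     # reference samples
grid = jnp.linspace(-2.0, 2.0, 81)   # points where the CDF is evaluated


def ecdf(samples):
    # Hard-threshold empirical CDF: fraction of samples <= each grid point.
    hits = (samples[:, None] <= grid[None, :]).astype(jnp.float32)
    return jnp.mean(hits, axis=0)


def ks_like(shift):
    # Max CDF gap between shifted and reference samples: piecewise constant in shift.
    return jnp.max(jnp.abs(ecdf(x + shift) - ecdf(x)))


def wasserstein_like(shift):
    # 1-D Wasserstein-1 for equal-size samples: mean gap of sorted values.
    return jnp.mean(jnp.abs(jnp.sort(x + shift) - jnp.sort(x)))


ks_val = ks_like(0.3)                     # nonzero distance ...
g_ks = jax.grad(ks_like)(0.3)             # ... but zero gradient: nothing to follow
g_w = jax.grad(wasserstein_like)(0.3)     # about 1: the distance grows with the shift
```

The KS-style value is still a perfectly good evaluation metric; it just gives a gradient optimizer no direction, whereas the sorted-sample distance does.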

See Also#