brainmass.Fitter

class brainmass.Fitter(model, optimizer=None, *, loss_fn=None, objective=None, predict=None, backend='grad', callbacks=None, transient=None, search_space=None)

Fit a model’s trainable parameters to data with a single .fit() call.

Parameters:
  • model (brainstate.nn.Module) – The model to fit. Its trainable Param parameters (fit=True) are the optimization variables.

  • optimizer (optional) –

    Interpreted by backend:

    • 'grad': a braintools.optim.OptaxOptimizer instance (e.g. braintools.optim.Adam(lr=0.05)). Defaults to Adam(lr=1e-2).

    • 'nevergrad' / 'scipy': an options dict forwarded to the optimizer constructor (e.g. {'method': 'DE', 'n_sample': 8} or {'method': 'L-BFGS-B'}), a method-name str, or None. The optimizer itself is constructed inside fit(), because it needs the loss and the parameter bounds, and is available afterwards through the optimizer property.

  • loss_fn (callable, optional) – loss_fn(model) -> (scalar_loss, aux). When given, it is the entire loss (you own any regularization), and objective, predict, and target are not used. Mutually exclusive with predict.

  • objective (callable, optional) – objective(prediction, target) -> scalar (e.g. from brainmass.objectives). Used with predict. Defaults to brainmass.objectives.timeseries_rmse().

  • predict (callable, optional) – predict(model) -> prediction; typically a brainmass.Simulator closure. Required unless loss_fn is given. The objective-path loss is objective(prediction[transient:], target) + model.reg_loss().

  • backend ({'grad', 'nevergrad', 'scipy'}, default 'grad') – Which optimizer backend to use.

  • callbacks (list of callable, optional) – Each callback(info) -> bool | None is called once per step, where info is a dict with the keys 'step', 'loss', 'best_loss', and 'model'. Returning True stops the run early (grad backend only).

  • transient (int, optional) – Number of leading samples discarded from the prediction (axis 0) before the objective is applied. None keeps the whole prediction.

  • search_space (dict, optional) – {name: (low, high)} bounds on the constrained scale for the derivative-free backends. Entries override or augment the bounds derived from each parameter’s transform.
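
The objective-path loss described under predict, objective, and transient can be sketched in plain Python as below. This is a minimal sketch, not brainmass code: rmse and objective_path_loss are hypothetical helpers, rmse only stands in for brainmass.objectives.timeseries_rmse() (whose exact reduction may differ), and the reg_loss argument stands in for model.reg_loss().

```python
import math
from typing import Any, Callable, Optional, Sequence


def rmse(prediction: Sequence[float], target: Sequence[float]) -> float:
    """Hypothetical stand-in for brainmass.objectives.timeseries_rmse()."""
    if len(prediction) != len(target):
        raise ValueError("prediction and target must have the same length")
    squared = [(p - t) ** 2 for p, t in zip(prediction, target)]
    return math.sqrt(sum(squared) / len(squared))


def objective_path_loss(
    predict: Callable[[Any], Sequence[float]],
    objective: Callable[[Sequence[float], Sequence[float]], float],
    model: Any,
    target: Sequence[float],
    transient: Optional[int] = None,
    reg_loss: float = 0.0,
) -> float:
    """Sketch of objective(prediction[transient:], target) + reg_loss."""
    prediction = predict(model)
    # prediction[None:] keeps every sample, which matches transient=None.
    kept = prediction[transient:]
    # reg_loss stands in for model.reg_loss(), added on the objective path.
    return objective(kept, target) + reg_loss


# Usage: a hypothetical predict whose first two samples are a transient.
def toy_predict(model: Any) -> Sequence[float]:
    return [0.0, 0.0, 1.0, 2.0, 3.0]


loss = objective_path_loss(toy_predict, rmse, None, [1.0, 2.0, 3.0], transient=2)
print(loss)  # 0.0
```

Dropping the transient before comparison means the target only needs to cover the retained samples; with transient=None the target must match the full prediction length.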

See also

brainmass.Simulator – Builds the predict closure.

brainmass.objectives – Composable objective callables.

Notes

The grad backend reproduces the canonical hand-rolled training loop exactly: model.states(ParamState) are registered as trainable weights, the loss is evaluated inside model.param_precompute(), and brainstate.transform.grad(..., has_aux=True, return_value=True) feeds optimizer.step. model.reg_loss() is added automatically on the objective path (a loss_fn owns its own regularization). The derivative-free backends evaluate one candidate at a time, setting the parameters via Param.set_value and then running predict. Because vmap over parameter writes is unsupported, candidates are evaluated in a loop, so these backends are best suited to a small number of scalar parameters.
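
The one-candidate-at-a-time evaluation described above can be illustrated with a plain random search. This is a sketch under stated assumptions: loop_candidates and loss_of are hypothetical names, uniform sampling inside the bounds replaces the nevergrad or scipy proposal step, and loss_of receives a dict of candidate values instead of writing them through Param.set_value. What it does show is the sequential loop and the best-seen bookkeeping.

```python
import math
import random
from typing import Callable, Dict, List, Tuple

Bounds = Dict[str, Tuple[float, float]]


def loop_candidates(
    loss_of: Callable[[Dict[str, float]], float],
    bounds: Bounds,
    n_candidates: int,
    seed: int = 0,
) -> Tuple[Dict[str, float], float, List[float]]:
    """Evaluate candidates one by one and keep the best-seen one.

    Plain random search, used only to show the sequential pattern; it is
    not the algorithm nevergrad or scipy actually run.
    """
    if n_candidates < 1:
        raise ValueError("n_candidates must be >= 1")
    rng = random.Random(seed)
    best_params: Dict[str, float] = {}
    best_loss = math.inf
    history: List[float] = []
    for _ in range(n_candidates):
        # Propose one candidate inside the (low, high) bounds.
        candidate = {name: rng.uniform(low, high) for name, (low, high) in bounds.items()}
        # Fitter would write these values with Param.set_value and run predict;
        # this sketch hands the dict straight to the hypothetical loss_of.
        loss = loss_of(candidate)
        history.append(loss)
        if loss < best_loss:
            best_params, best_loss = dict(candidate), loss
    return best_params, best_loss, history


# Usage: one scalar parameter k with bounds (0.5, 3.0) and a minimum at 2.0.
best, best_loss, history = loop_candidates(
    lambda p: (p["k"] - 2.0) ** 2, {"k": (0.5, 3.0)}, n_candidates=200
)
```

Each candidate costs one full evaluation, so the total cost grows linearly with the number of candidates, which is why this style of search pays off mainly when only a few scalar parameters are free.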

Examples

>>> import braintools, brainstate, brainunit as u
>>> import jax.numpy as jnp
>>> import brainmass
>>> from brainstate.nn import Param, SigmoidT
>>> class Toy(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.k = Param(1.0, t=SigmoidT(0.5, 3.0))
...     def update(self):
...         return self.k.value()
>>> def predict(m):
...     sim = brainmass.Simulator(m, dt=0.1 * u.ms)
...     return jnp.mean(sim.run(1.0 * u.ms, monitors=None)['output'])
>>> fitter = brainmass.Fitter(Toy(), braintools.optim.Adam(lr=0.1), predict=predict)
>>> result = fitter.fit(target=jnp.asarray(2.0), n_steps=30)
>>> result.backend
'grad'
>>> list(result.best_params)
['k']
>>> bool(result.best_loss <= result.history[0])
True
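
Callbacks follow the protocol given for the callbacks parameter: they receive the info dict once per step and return True to stop. A possible early-stopping callback is sketched below; make_early_stop and its threshold and patience arguments are hypothetical, and the closure reads only the documented 'best_loss' key.

```python
import math
from typing import Any, Callable, Dict, Optional

Info = Dict[str, Any]


def make_early_stop(
    threshold: float = 0.0, patience: Optional[int] = None
) -> Callable[[Info], Optional[bool]]:
    """Build a callback that stops once best_loss <= threshold.

    With patience set, it also stops after `patience` consecutive calls
    in which best_loss did not improve.
    """
    state = {"best": math.inf, "stale": 0}

    def callback(info: Info) -> Optional[bool]:
        best = float(info["best_loss"])
        if best <= threshold:
            return True  # target quality reached
        if best < state["best"]:
            state["best"], state["stale"] = best, 0
        else:
            state["stale"] += 1
        if patience is not None and state["stale"] >= patience:
            return True  # no improvement for `patience` calls
        return None  # keep running

    return callback
```

It would be passed as callbacks=[make_early_stop(threshold=1e-3, patience=10)]; per the parameter description, returning True stops the run early only on the grad backend.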

Methods

__init__(model[, optimizer, loss_fn, ...])

fit([target, n_steps, verbose])

Run the optimization and return a FitResult.

Attributes

optimizer

The underlying optimizer (constructed lazily for derivative-free backends).

__init__(model, optimizer=None, *, loss_fn=None, objective=None, predict=None, backend='grad', callbacks=None, transient=None, search_space=None)
fit(target=None, n_steps=100, *, verbose=False)

Run the optimization and return a FitResult.

Parameters:
  • target (Any, optional) – The data the objective compares the prediction against. Ignored when a loss_fn was supplied, since loss_fn is expected to consume the data itself.

  • n_steps (int, default 100) – Number of optimization iterations: optax steps (grad), generations (nevergrad), or random restarts (scipy). Must be >= 1.

  • verbose (bool, default False) – Print per-iteration progress.

Returns:

A FitResult holding the best loss seen during the run, the corresponding parameters, the loss history, and the fitted model.

Return type:

FitResult

property optimizer

The underlying optimizer (constructed lazily for derivative-free backends).