brainmass.Fitter
- class brainmass.Fitter(model, optimizer=None, *, loss_fn=None, objective=None, predict=None, backend='grad', callbacks=None, transient=None, search_space=None)
Fit a model’s trainable parameters to data through a single fit() call.
- Parameters:
model (brainstate.nn.Module) – The model to fit. Its trainable Param parameters (fit=True) are the optimization variables.
optimizer (optional) – Interpreted according to backend:
'grad': a braintools.optim.OptaxOptimizer instance (e.g. braintools.optim.Adam(lr=0.05)). Defaults to Adam(lr=1e-2).
'nevergrad' / 'scipy': an options dict forwarded to the optimizer constructor (e.g. {'method': 'DE', 'n_sample': 8} or {'method': 'L-BFGS-B'}), a method-name str, or None. The actual optimizer is constructed inside fit(), since it needs the loss and bounds, and is exposed afterwards as the optimizer attribute.
loss_fn (callable, optional) – loss_fn(model) -> (scalar_loss, aux). When given, it is the entire loss (you own any regularization) and objective, predict and target are unused. Mutually exclusive with predict.
objective (callable, optional) – objective(prediction, target) -> scalar (e.g. from brainmass.objectives). Used with predict. Defaults to brainmass.objectives.timeseries_rmse().
predict (callable, optional) – predict(model) -> prediction; typically a brainmass.Simulator closure. Required unless loss_fn is given. The objective-path loss is objective(prediction[transient:], target) + model.reg_loss().
backend ({'grad', 'nevergrad', 'scipy'}, default 'grad') – Which optimizer backend to use.
callbacks (list of callable, optional) – Each callback(info) -> bool | None is called once per step with a dict info holding the keys 'step', 'loss', 'best_loss' and 'model'. Returning True stops the run early (grad backend only).
transient (int, optional) – Number of leading samples discarded from the prediction (axis 0) before the objective is applied. None keeps the whole prediction.
search_space (dict, optional) – {name: (low, high)} constrained bounds for the derivative-free backends, overriding or augmenting the bounds derived from each parameter’s transform.
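The callbacks contract above can be illustrated with a plain-Python stand-in for a gradient-based loop. This is a minimal sketch under stated assumptions, not brainmass code. fit_loop_sketch is a hypothetical helper. It uses plain gradient descent with a hand-written gradient in place of brainstate.transform.grad and an optax optimizer. It also passes the current parameter value where the real info dict carries the model.

```python
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical callback type: receives the info dict, may return True to stop.
Callback = Callable[[Dict], Optional[bool]]


def fit_loop_sketch(
    loss_and_grad: Callable[[float], Tuple[float, float]],
    k0: float,
    lr: float = 0.1,
    n_steps: int = 100,
    callbacks: Optional[List[Callback]] = None,
) -> Tuple[float, float, List[float]]:
    """Plain gradient descent on one scalar parameter (hypothetical stand-in)."""
    if n_steps < 1:
        raise ValueError("n_steps must be >= 1")
    k = k0
    best_loss, best_k = float("inf"), k0
    history: List[float] = []
    for step in range(n_steps):
        loss, grad = loss_and_grad(k)  # loss and gradient at the current value
        history.append(loss)
        if loss < best_loss:  # keep the best-seen value, not the last one
            best_loss, best_k = loss, k
        k = k - lr * grad  # stands in for optimizer.step
        # Assumption: the parameter value is passed where the real info dict has the model.
        info = {"step": step, "loss": loss, "best_loss": best_loss, "model": k}
        # Call every callback once per step, then stop if any of them returned True.
        results = [cb(info) for cb in (callbacks or [])]
        if any(r is True for r in results):
            break
    return best_k, best_loss, history


# Toy problem: loss (k - 2)^2 with gradient 2 (k - 2); stop once the loss drops below 0.01.
best_k, best_loss, history = fit_loop_sketch(
    lambda k: ((k - 2.0) ** 2, 2.0 * (k - 2.0)),
    k0=1.0,
    n_steps=100,
    callbacks=[lambda info: info["loss"] < 0.01],
)
print(len(history))  # 12
```

Here the loss at step s is 0.64**s, so the callback first returns True at step 11 and the run ends after 12 recorded losses. In Fitter itself, stopping via the return value applies to the grad backend only.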
See also
brainmass.Simulator – Builds the predict closure.
brainmass.objectives – Composable objective callables.
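The objective-path loss given under predict is objective(prediction[transient:], target) + model.reg_loss(). It can be sketched in plain Python. This is a minimal illustration, not brainmass code, and several names are assumptions:

- rmse_objective and objective_path_loss are hypothetical names.
- The objective only mimics the factory-call style of brainmass.objectives.timeseries_rmse(), whose exact definition is not given here.
- reg_loss is a plain number standing in for model.reg_loss().
- Sequences are Python lists indexed along axis 0.

```python
import math
from typing import Callable, Optional, Sequence

Objective = Callable[[Sequence[float], Sequence[float]], float]


def rmse_objective() -> Objective:
    """Hypothetical factory returning an objective(prediction, target) -> scalar callable."""
    def objective(prediction: Sequence[float], target: Sequence[float]) -> float:
        if len(prediction) != len(target):
            raise ValueError("prediction and target must have the same length")
        squared = [(p - t) ** 2 for p, t in zip(prediction, target)]
        return math.sqrt(sum(squared) / len(squared))
    return objective


def objective_path_loss(
    prediction: Sequence[float],
    target: Sequence[float],
    objective: Objective,
    reg_loss: float = 0.0,
    transient: Optional[int] = None,
) -> float:
    """objective(prediction[transient:], target) + reg_loss (hypothetical helper)."""
    # prediction[None:] is the whole sequence, so transient=None discards nothing.
    return objective(prediction[transient:], target) + reg_loss


prediction = [5.0, 4.0, 1.0, 2.0, 3.0]  # the first two samples are a transient
target = [1.0, 2.0, 3.0]
print(objective_path_loss(prediction, target, rmse_objective(), reg_loss=0.5, transient=2))  # 0.5
```

A user-supplied loss_fn bypasses this composition entirely, so on that path any regularization is yours to add.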
Notes
The grad backend reproduces the canonical hand-rolled training loop exactly:
- model.states(ParamState) are registered as trainable weights.
- The loss is evaluated inside model.param_precompute().
- brainstate.transform.grad(..., has_aux=True, return_value=True) feeds optimizer.step.
- On the objective path, model.reg_loss() is added to the loss automatically.
The derivative-free backends evaluate one candidate at a time: each candidate's parameters are written with Param.set_value, and predict is then run. Because vmap over parameter writes is unsupported, candidates are evaluated in a loop, which is why these backends suit a small number of scalar parameters.
Examples
>>> import braintools, brainstate, brainunit as u
>>> import jax.numpy as jnp
>>> import brainmass
>>> from brainstate.nn import Param, SigmoidT
>>> class Toy(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.k = Param(1.0, t=SigmoidT(0.5, 3.0))
...     def update(self):
...         return self.k.value()
>>> def predict(m):
...     sim = brainmass.Simulator(m, dt=0.1 * u.ms)
...     return jnp.mean(sim.run(1.0 * u.ms, monitors=None)['output'])
>>> fitter = brainmass.Fitter(Toy(), braintools.optim.Adam(lr=0.1), predict=predict)
>>> result = fitter.fit(target=jnp.asarray(2.0), n_steps=30)
>>> result.backend
'grad'
>>> list(result.best_params)
['k']
>>> bool(result.best_loss <= result.history[0])
True
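The Notes explain that the derivative-free backends evaluate candidates one at a time, and fit()'s n_steps counts random restarts for the scipy backend. The sketch below shows the random-restart idea in plain Python, assuming box bounds in the search_space {name: (low, high)} format. random_restart_sketch and compass_search are hypothetical helpers. Compass (pattern) search is swapped in for the scipy method a real run would use (e.g. 'L-BFGS-B'). This is not how brainmass drives scipy or nevergrad.

```python
import random
from typing import Callable, Dict, Optional, Tuple

Params = Dict[str, float]
Bounds = Dict[str, Tuple[float, float]]  # mimics the search_space {name: (low, high)} format


def compass_search(loss: Callable[[Params], float], start: Params, bounds: Bounds,
                   step: float = 0.25, tol: float = 1e-6,
                   max_iter: int = 1000) -> Tuple[Params, float]:
    """Derivative-free local search: probe +/- step per coordinate, halve step when stuck."""
    x, fx = dict(start), loss(start)
    for _ in range(max_iter):
        if step < tol:
            break
        improved = False
        for name, (lo, hi) in bounds.items():
            for direction in (1.0, -1.0):
                cand = dict(x)
                # Step is relative to the bound width; clipping keeps candidates in bounds.
                cand[name] = min(hi, max(lo, x[name] + direction * step * (hi - lo)))
                f_cand = loss(cand)  # one candidate evaluated at a time, no vectorization
                if f_cand < fx:
                    x, fx, improved = cand, f_cand, True
                    break
        if not improved:
            step /= 2.0
    return x, fx


def random_restart_sketch(loss: Callable[[Params], float], bounds: Bounds,
                          n_restarts: int, seed: int = 0) -> Tuple[Optional[Params], float]:
    """Run a local search from n_restarts random starts and keep the best result."""
    if n_restarts < 1:
        raise ValueError("n_restarts must be >= 1")
    rng = random.Random(seed)
    best_x: Optional[Params] = None
    best_f = float("inf")
    for _ in range(n_restarts):
        start = {name: rng.uniform(lo, hi) for name, (lo, hi) in bounds.items()}
        x, fx = compass_search(loss, start, bounds)
        if fx < best_f:  # keep the best-seen result across restarts
            best_x, best_f = x, fx
    return best_x, best_f


bounds = {"a": (0.0, 3.0), "b": (-1.0, 1.0)}
best, best_f = random_restart_sketch(
    lambda p: (p["a"] - 1.0) ** 2 + (p["b"] + 0.5) ** 2, bounds, n_restarts=3)
print(round(best["a"], 3), round(best["b"], 3))  # 1.0 -0.5
```

Each loss call touches a single candidate, so the cost grows with the number of parameters times the number of probes. This matches the Notes' advice to keep these backends for a handful of scalar parameters.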
- __init__(model, optimizer=None, *, loss_fn=None, objective=None, predict=None, backend='grad', callbacks=None, transient=None, search_space=None)
Methods
__init__(model[, optimizer, loss_fn, ...])
fit([target, n_steps, verbose]) – Run the optimization and return a FitResult.
Attributes
optimizer – The underlying optimizer (constructed lazily for derivative-free backends).
- fit(target=None, n_steps=100, *, verbose=False)
Run the optimization and return a FitResult.
- Parameters:
target (Any, optional) – The data that the objective compares the prediction against. Unused when a loss_fn was supplied, since the loss function handles its own data.
n_steps (int, default 100) – Number of optimization iterations: optax steps (grad), generations (nevergrad), or random restarts (scipy). Must be >= 1.
verbose (bool, default False) – Print per-iteration progress.
- Returns:
A FitResult holding the best-seen loss and parameters, the loss history, and the fitted model.
- Return type:
FitResult
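The class-level example reads result.backend, result.best_params, result.best_loss and result.history. The attribute that holds the fitted model is not named in this section. FitResultSketch below is a hypothetical stand-in dataclass, not the real FitResult, showing how those fields relate. It assumes history records one loss per iteration and adds an illustrative improvement() helper.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class FitResultSketch:
    """Hypothetical stand-in for FitResult; the fitted-model field is deliberately omitted."""
    backend: str
    best_loss: float
    best_params: Dict[str, Any]
    history: List[float] = field(default_factory=list)

    def improvement(self) -> float:
        # Hypothetical helper: how far the best-seen loss improved on the first recorded loss.
        return self.history[0] - self.best_loss


result = FitResultSketch(backend="grad", best_loss=0.25, best_params={"k": 1.5},
                         history=[1.0, 0.5, 0.25])
print(list(result.best_params))               # ['k']
print(result.best_loss <= result.history[0])  # True
print(result.improvement())                   # 0.75
```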
- property optimizer
The underlying optimizer (constructed lazily for derivative-free backends).