brainmass.Simulator

class brainmass.Simulator(model, dt=None)

Drive a Dynamics / Module and collect monitored trajectories.

The simulator wraps the standard brainmass run loop in a single run() call: set dt, initialise states, step the model inside brainstate.environ.context(i=..., t=...) with brainstate.transform.for_loop(), and stack the recorded values. It reimplements none of that machinery; it only composes it.
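The order of operations that run() composes (initialise, step, record the post-update value, stack) can be sketched in plain Python. This is not brainmass code: ToyLeakyNode and run_loop are hypothetical, a Python for loop stands in for brainstate.transform.for_loop(), and dt is passed to update() explicitly instead of being read from brainstate.environ.

```python
class ToyLeakyNode:
    """Hypothetical stand-in for a brainstate Dynamics with a single state, x."""

    def __init__(self, tau=10.0):
        self.tau = tau  # time constant, in the same units as dt
        self.x = 0.0

    def init_all_states(self):
        self.x = 0.0

    def update(self, dt, inp=0.0):
        # Forward-Euler step of dx/dt = (-x + inp) / tau
        self.x += dt * (-self.x + inp) / self.tau
        return self.x


def run_loop(model, dt, n_steps, inputs=None):
    """Sketch of the composed loop: init -> step -> record -> stack."""
    model.init_all_states()
    outputs, ts = [], []
    for i in range(n_steps):  # Simulator.run() uses brainstate.transform.for_loop here
        inp = 0.0 if inputs is None else inputs[i]
        outputs.append(model.update(dt, inp))  # record the post-update value
        ts.append((i + 1) * dt)  # time at the end of step i
    return {'output': outputs, 'ts': ts}
```

With a constant drive of 1.0, dt = 0.1 and tau = 10, each step applies x <- 0.99 x + 0.01, so the recorded trajectory rises monotonically toward 1.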

Parameters:
  • model (brainstate.nn.Dynamics or brainstate.nn.Module) – The model to drive, stored afterwards as the model attribute. It may be a single node or a network whose update() handles its own internal coupling.

  • dt (brainunit.Quantity, optional) – Integration time step. If None (default), the step is read from the global environment (brainstate.environ.get('dt')) at run() time; if that is also unset, run() raises ValueError.
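The lookup order for dt (explicit argument, then the global environment, else ValueError) is sketched below in plain Python. The helper name resolve_dt and the plain dict standing in for brainstate.environ are assumptions for illustration, not brainmass API.

```python
def resolve_dt(explicit_dt=None, environ=None):
    """Return the integration step following the documented lookup order.

    `environ` is a plain dict standing in for brainstate.environ (hypothetical).
    """
    if explicit_dt is not None:
        return explicit_dt  # Simulator(model, dt=...) takes precedence
    env_dt = (environ or {}).get('dt')  # fallback: global environment at run() time
    if env_dt is None:
        raise ValueError("dt is unset: pass dt=... or set it in the global environment")
    return env_dt
```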

See also

brainmass.objectives – composable loss builders over simulated trajectories.

Notes

The result of run() is a plain dict (a valid JAX pytree, so it is safe to return through jit / grad / vmap) mapping each monitor name to its stacked trajectory, plus a 'ts' time axis. ts[k] is the simulation time at the end of the k-th recorded step (the recorded value is the post-update state).
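The meaning of ts can be made concrete with a small helper. This sketch assumes the run starts at t = 0, that the transient is dropped before downsampling, and that each recorded sample is the last step of a complete window of sample_every steps; these choices are consistent with the length formula under Returns but are not confirmed by the source, and time_axis is a hypothetical name.

```python
def time_axis(dt, n_steps, transient_steps=0, sample_every=1):
    """Simulation time at the end of each recorded step (sketch, t starts at 0)."""
    n_recorded = (n_steps - transient_steps) // sample_every
    # Sample k closes the window ending at step transient_steps + (k + 1) * sample_every
    return [(transient_steps + (k + 1) * sample_every) * dt for k in range(n_recorded)]
```

For the first example below (dt = 0.1 ms, 100 steps) this gives ts[0] = 0.1 and ts[-1] = 10.0 in ms, so the axis never contains the initial time 0; with a 10-step transient over 50 steps it starts at 1.1.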

Examples

>>> import brainmass
>>> import brainunit as u
>>> node = brainmass.HopfStep(2, a=-0.2)
>>> sim = brainmass.Simulator(node, dt=0.1 * u.ms)
>>> res = sim.run(10.0 * u.ms, monitors=['x'])
>>> res['x'].shape
(100, 2)
>>> res['ts'].shape
(100,)

A derived observable (here E - I, via eeg()) is monitored with a callable and stored under 'output'. The 1 ms transient drops the first 10 of the 50 steps:

>>> jr = brainmass.JansenRitStep(in_size=1)
>>> sim = brainmass.Simulator(jr, dt=0.1 * u.ms)
>>> res = sim.run(5.0 * u.ms, monitors=lambda m: m.eeg(), transient=1.0 * u.ms)
>>> res['output'].shape
(40, 1)

Methods

  • __init__(model[, dt])

  • run(duration, *[, inputs, monitors, ...]) – Simulate model for duration and return the recorded trajectories.

__init__(model, dt=None)

Create a simulator for model with optional time step dt (see the class Parameters above).

run(duration, *, inputs=None, monitors=None, transient=None, sample_every=None, batch_size=None, init_states=True, jit=True)

Simulate model for duration and return the recorded trajectories.

Parameters:
  • duration (brainunit.Quantity) – Total simulated time. The number of integration steps is int(duration / dt); if duration is not an integer multiple of dt, the count is floored and a warning is issued.

  • inputs (None, array-like or callable, optional) –

    External drive forwarded to model.update:

    • None (default): call model.update() with no arguments.

    • array-like of shape (n_steps, ...): at step i, row i is passed as the single argument, model.update(inputs[i]).

    • callable(i, t): called at each step as r = inputs(i, t); a tuple return is unpacked as model.update(*r), any other return is passed as model.update(r).

  • monitors (None, list of str, callable or dict, optional) –

    What to record each step:

    • None (default): the return value of update(), under key 'output'.

    • list of state-attribute names (e.g. ['rE']): each recorded as getattr(model, name).value under its own name.

    • callable(model) -> value: a derived observable, under key 'output' (e.g. lambda m: m.eeg()).

    • dict mapping output name to a state-attribute name or a callable(model).

  • transient (None, brainunit.Quantity or int, optional) – Leading transient to discard from the outputs, as a duration or a step count. Must be shorter than the run.

  • sample_every (None or int, optional) – Record only every sample_every-th step (output downsampling; this generalises the Jansen-Rit TR loop). None records every step.

  • batch_size (None or int, optional) – If given, initialise states with this batch size (batched initial conditions); outputs gain a leading batch axis.

  • init_states (bool, default True) – Call init_all_states before running. Set False to continue from the model’s current state.

  • jit (bool, default True) – Compile the run with brainstate.transform.jit().
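The dispatch rules for inputs and monitors listed above can be restated as a plain-Python sketch. The helpers call_update, normalise_monitors and read_monitor are hypothetical; they are not part of brainmass and only mirror the documented behaviour.

```python
def call_update(update, inputs, i, t):
    """Forward external drive to update() following the three documented forms."""
    if inputs is None:
        return update()  # no drive
    if callable(inputs):
        r = inputs(i, t)
        return update(*r) if isinstance(r, tuple) else update(r)  # tuples are unpacked
    return update(inputs[i])  # array-like: row i is the single argument


def normalise_monitors(monitors):
    """Turn any documented `monitors` form into {name: attribute-name-or-callable}."""
    if monitors is None:
        return None  # record update()'s return value under 'output'
    if callable(monitors):
        return {'output': monitors}  # single derived observable
    if isinstance(monitors, dict):
        return dict(monitors)  # already name -> spec
    return {name: name for name in monitors}  # list of state-attribute names


def read_monitor(model, spec):
    """Attribute name -> that state's .value; callable -> derived observable."""
    return spec(model) if callable(spec) else getattr(model, spec).value
```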

Returns:

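Maps each monitor name (or 'output') to its stacked trajectory, and 'ts' to the time axis. Each trajectory's time axis has length (n_steps - transient) // sample_every, with transient counted in steps (0 if None) and sample_every taken as 1 if None. This is the leading axis, or the axis after the batch axis when batch_size is given.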

Return type:

dict
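The step arithmetic behind the returned length can be checked with plain floats. This sketch assumes durations are floats in the same unit as dt (standing in for brainunit.Quantity), that an int transient is a step count and a float transient is a duration, and that a duration transient is floored to whole steps; recorded_length is a hypothetical helper.

```python
import warnings


def recorded_length(duration, dt, transient=None, sample_every=None):
    """Length of each trajectory's time axis, per the documented formula."""
    ratio = duration / dt
    n_steps = int(ratio)  # floors for positive values, like int(duration / dt)
    if n_steps != ratio:
        warnings.warn(f"duration is not a multiple of dt; using {n_steps} steps")
    if transient is None:
        transient_steps = 0
    elif isinstance(transient, int):
        transient_steps = transient  # already a step count
    else:
        transient_steps = int(transient / dt)  # duration -> steps (floored: an assumption)
    k = 1 if sample_every is None else sample_every
    return (n_steps - transient_steps) // k
```

With the values from the examples, recorded_length(10.0, 0.1) gives 100 and recorded_length(5.0, 0.1, transient=1.0) gives 40, matching the documented shapes.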

Raises:

ValueError – If dt is unset, transient is not shorter than the run, a monitored name is not a model attribute, sample_every < 1, or the inputs length does not match the step count.
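The conditions listed under Raises amount to a handful of up-front checks. The sketch below restates them in plain Python with step counts already resolved; check_run_args and its argument layout are hypothetical and not the brainmass implementation.

```python
def check_run_args(model, dt, n_steps, transient_steps, sample_every, inputs, monitor_names):
    """Raise ValueError for each misuse listed under Raises (sketch)."""
    if dt is None:
        raise ValueError("dt is unset")
    if transient_steps >= n_steps:
        raise ValueError("transient must be shorter than the run")
    for name in monitor_names:
        if not hasattr(model, name):
            raise ValueError(f"{name!r} is not an attribute of the model")
    if sample_every is not None and sample_every < 1:
        raise ValueError("sample_every must be >= 1")
    # Only an array-like drive has a length; None and callables are exempt
    if inputs is not None and not callable(inputs) and len(inputs) != n_steps:
        raise ValueError(f"inputs has {len(inputs)} rows, expected {n_steps}")
```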