brainmass.Simulator
- class brainmass.Simulator(model, dt=None)
Drive a Dynamics/Module and collect monitored trajectories.
The simulator wraps the standard brainmass run loop (set dt, initialise states, step the model inside environ.context(i=, t=) with brainstate.transform.for_loop(), and stack the recorded values) in a single run() call. It reimplements none of that machinery; it only composes it.
- Parameters:
model (brainstate.nn.Dynamics or brainstate.nn.Module) – The model to drive, available afterwards as the model attribute. May be a single node or a network whose update handles its own internal coupling.
dt (brainunit.Quantity, optional) – Integration time step. If None (default), the step is read from the global environment (brainstate.environ.get('dt')) at run() time; if that is also unset, run() raises ValueError.
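To picture what run() composes, here is a minimal pure-Python sketch of the same loop, under stated assumptions. It does not use brainstate: `ToyDecay`, `resolve_dt`, `simulate` and `GLOBAL_ENV` are hypothetical stand-ins, a plain dict plays the role of the global environment, and an ordinary Python loop with list appends replaces brainstate.transform.for_loop() and the stacking step.

```python
# Hypothetical stand-in for brainstate's global environment; not the real API.
GLOBAL_ENV = {}


class ToyDecay:
    """Hypothetical one-variable model: x relaxes towards 0 with time constant tau."""

    def __init__(self, tau=1.0):
        self.tau = tau
        self.x = None

    def init_all_states(self):
        # Fresh initial condition, analogous to init_states=True.
        self.x = 1.0

    def update(self, dt):
        # Forward-Euler step. The toy takes dt as an argument, whereas a
        # brainmass model would read it from its environment context.
        self.x = self.x + dt * (-self.x / self.tau)
        return self.x


def resolve_dt(dt=None, env=GLOBAL_ENV):
    # An explicit dt wins; otherwise fall back to the environment; else fail.
    if dt is None:
        dt = env.get("dt")
    if dt is None:
        raise ValueError("dt is unset: pass dt or set it in the environment")
    return dt


def simulate(model, duration, dt=None):
    """Hypothetical run loop: initialise, step, record post-update values, stack."""
    dt = resolve_dt(dt)          # resolved at run time, not at construction
    n_steps = int(duration / dt)
    model.init_all_states()
    output, ts = [], []
    for i in range(n_steps):     # a plain loop in place of for_loop()
        output.append(model.update(dt))  # recorded value is the post-update state
        ts.append((i + 1) * dt)          # time at the end of step i
    return {"output": output, "ts": ts}
```

Because the sketch looks dt up inside simulate() rather than storing it at construction, setting GLOBAL_ENV["dt"] before a call is enough, which mirrors the documented fallback.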
See also
brainmass.objectives – composable loss builders over simulated trajectories.
Notes
The result of run() is a plain dict (a valid JAX pytree, so it is safe to return through jit/grad/vmap) mapping each monitor name to its stacked trajectory, plus a 'ts' time axis. ts[k] is the simulation time at the end of the k-th recorded step (the recorded value is the post-update state).
Examples
>>> import brainmass
>>> import brainunit as u
>>> node = brainmass.HopfStep(2, a=-0.2)
>>> sim = brainmass.Simulator(node, dt=0.1 * u.ms)
>>> res = sim.run(10.0 * u.ms, monitors=['x'])
>>> res['x'].shape
(100, 2)
>>> res['ts'].shape
(100,)
A derived observable (here E - I) is monitored with a callable:
>>> jr = brainmass.JansenRitStep(in_size=1)
>>> sim = brainmass.Simulator(jr, dt=0.1 * u.ms)
>>> res = sim.run(5.0 * u.ms, monitors=lambda m: m.eeg(), transient=1.0 * u.ms)
>>> res['output'].shape
(40, 1)
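The shapes in the examples above follow from step arithmetic: 10 ms at dt = 0.1 ms gives 100 steps, and 5 ms with a 1 ms transient leaves 50 - 10 = 40 recorded steps. The stdlib sketch below reproduces those counts and builds a 'ts' axis. `step_counts` and `time_axis` are hypothetical helpers, `Fraction` stands in for brainunit quantities so that duration / dt is exact, and the way recorded rows map to global steps after a transient is an assumption, not taken from the brainmass source.

```python
from fractions import Fraction


def step_counts(duration, dt, transient=0):
    """Hypothetical helper returning (n_steps, n_transient, n_recorded).

    Assumes sample_every is 1 and that transient is given as a duration.
    Fractions keep duration / dt exact, so int() only floors genuine remainders.
    """
    n_steps = int(duration / dt)
    n_transient = int(transient / dt)
    return n_steps, n_transient, n_steps - n_transient


def time_axis(duration, dt, transient=0):
    # Assumption: the k-th kept row is global step n_transient + k, stamped with
    # the time at the end of that step, (n_transient + k + 1) * dt.
    _, n_transient, n_recorded = step_counts(duration, dt, transient)
    return [(n_transient + k + 1) * dt for k in range(n_recorded)]


dt = Fraction("0.1")  # ms
print(step_counts(Fraction(10), dt))               # HopfStep example: (100, 0, 100)
print(step_counts(Fraction(5), dt, Fraction(1)))   # Jansen-Rit example: (50, 10, 40)
print(time_axis(Fraction(5), dt, Fraction(1))[0])  # 11/10, i.e. 1.1 ms
```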
Methods
__init__(model[, dt])
run(duration, *[, inputs, monitors, ...]) – Simulate model for duration and return the recorded trajectories.
- run(duration, *, inputs=None, monitors=None, transient=None, sample_every=None, batch_size=None, init_states=True, jit=True)
Simulate model for duration and return the recorded trajectories.
- Parameters:
duration (brainunit.Quantity) – Total simulated time. The number of integration steps is int(duration / dt); if duration is not an integer multiple of dt, the count is floored and a warning is issued.
inputs (None, array-like or callable, optional) –
External drive forwarded to model.update:
None (default): call model.update() with no arguments.
array-like of shape (n_steps, ...): row i is passed as the single argument, model.update(inputs[i]).
callable(i, t): called with the step index and time; its return value r is splatted, i.e. model.update(*r) if r is a tuple, else model.update(r).
monitors (None, list of str, callable or dict, optional) –
What to record at each step:
None (default): the return value of update(), under key 'output'.
list of state-attribute names (e.g. ['rE']): each getattr(model, name).value, recorded under its own name.
callable(model) -> value: a derived observable, under key 'output' (e.g. lambda m: m.eeg()).
dict mapping an output name to a state-attribute name or a callable(model).
transient (None, brainunit.Quantity or int, optional) – Leading transient to discard from the outputs, as a duration or a step count. Must be shorter than the run.
sample_every (None or int, optional) – Record every k-th step (output downsampling; generalises the Jansen-Rit TR loop). None records every step.
batch_size (None or int, optional) – If given, initialise states with this batch size (batched initial conditions); outputs gain a leading batch axis.
init_states (bool, default True) – Call init_all_states before running. Set False to continue from the model’s current state.
jit (bool, default True) – Compile the run with brainstate.transform.jit().
- Returns:
A dict mapping each monitor name (or 'output') to its stacked trajectory, and 'ts' to the time axis. Each trajectory has a leading axis of length (n_steps - transient) // sample_every, with transient counted in steps and sample_every taken as 1 when None.
- Return type:
dict
- Raises:
ValueError – If dt is unset, transient is not shorter than the run, a monitored name is not a model attribute, sample_every < 1, or the inputs length does not match the step count.
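The three accepted inputs forms reduce to three ways of calling update. The hypothetical dispatcher below sketches that mapping; `call_update` is illustrative and not part of brainmass, and any plain function can stand in for model.update. The step index i and time t are simply forwarded to a callable input.

```python
def call_update(update, inputs, i, t):
    """Hypothetical dispatcher mirroring the three documented `inputs` forms."""
    if inputs is None:
        return update()                 # None: no arguments
    if callable(inputs):
        r = inputs(i, t)
        # A tuple is splatted into positional arguments; anything else is passed whole.
        return update(*r) if isinstance(r, tuple) else update(r)
    return update(inputs[i])            # array-like: row i is the single argument
```

The callable check comes before indexing so that a function is never mistaken for a sequence of rows.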
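All four monitors forms can be normalised into one mapping from output name to a getter. The sketch below is a hypothetical helper, `normalize_monitors`, not the library's code. Its getters take (model, output), where output is what update() returned on that step, so that the None form can record the return value while the other forms read from the model.

```python
def normalize_monitors(monitors):
    """Hypothetical helper: turn every documented `monitors` form into {name: getter}."""

    def state(name):
        # State attribute: record getattr(model, name).value.
        return lambda model, output: getattr(model, name).value

    def derived(fn):
        # Derived observable: record fn(model).
        return lambda model, output: fn(model)

    if monitors is None:
        return {"output": lambda model, output: output}
    if callable(monitors):
        return {"output": derived(monitors)}
    if isinstance(monitors, dict):
        return {key: derived(v) if callable(v) else state(v)
                for key, v in monitors.items()}
    return {name: state(name) for name in monitors}  # list of state names
```

With a single shape for every getter, the recording step can treat all monitors uniformly.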
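The ValueError conditions listed above, together with the flooring warning described under duration, can be gathered into one pre-flight check. `validate_run` is a hypothetical helper. It uses `Fraction` for durations, treats an int transient as a step count and anything else as a duration, and the real method may order or word these checks differently.

```python
import warnings
from fractions import Fraction


def validate_run(duration, dt, transient=0, sample_every=None, n_inputs=None,
                 model=None, monitor_names=()):
    """Hypothetical pre-flight checks mirroring the documented ValueError cases.

    Returns (n_steps, n_transient, k), where k is sample_every with None mapped to 1.
    """
    if dt is None:
        raise ValueError("dt is unset")
    n_steps = int(duration / dt)
    if n_steps * dt != duration:
        warnings.warn("duration is not a multiple of dt; step count floored")
    # An int transient is a step count; anything else is treated as a duration.
    n_transient = transient if isinstance(transient, int) else int(transient / dt)
    if n_transient >= n_steps:
        raise ValueError("transient must be shorter than the run")
    k = 1 if sample_every is None else sample_every
    if k < 1:
        raise ValueError("sample_every must be >= 1")
    if n_inputs is not None and n_inputs != n_steps:
        raise ValueError(f"inputs has {n_inputs} rows, expected {n_steps}")
    for name in monitor_names:
        if not hasattr(model, name):
            raise ValueError(f"{name!r} is not an attribute of the model")
    return n_steps, n_transient, k
```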