VariableDuration
- class braintools.cogtask.VariableDuration(min_duration, max_duration, ctx_key, inputs=None, outputs=None, noise=None, on_enter=None, on_exit=None, name=None)
Declarative phase whose actual length is decided per trial.
The phase reserves a buffer slot of ceil(max_duration / dt) timesteps and writes its content into the first ceil(ctx[ctx_key] / dt) of them. Anything past the trial’s actual length is automatically zeroed by the packed runtime and masked out in the trial mask. Compatible with brainstate.transform.jit and brainstate.transform.vmap2.
ctx_key names the trial state entry holding the sampled duration. The value can be a scalar number of milliseconds or a Quantity; in either case the runtime divides by ctx.dt to obtain a timestep count.
min_duration and max_duration must be Quantities with matching units; min_duration is the static lower bound used to floor the step count (always >= 1 timestep).
Examples
>>> import brainunit as u
>>> from braintools.cogtask import VariableDuration
>>> phase = VariableDuration(
...     min_duration=300 * u.ms,
...     max_duration=1500 * u.ms,
...     ctx_key='delay_duration',
...     inputs={'fixation': 1.0},
...     outputs={'label': 0},
...     name='variable_delay',
... )
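The buffer layout described above can be illustrated with a minimal sketch in plain Python. It is not the library's implementation: `pack_phase` is a hypothetical helper, durations are plain floats in milliseconds rather than Quantities, and the zero-padding and mask are built by hand to mirror what the packed runtime is described as doing.

```python
import math


def pack_phase(value, duration_ms, max_duration_ms, dt_ms):
    """Hypothetical helper: lay out one phase in a fixed-size buffer."""
    # Reserved slot: depends only on the static upper bound.
    max_steps = math.ceil(max_duration_ms / dt_ms)
    # Steps actually written for this trial, never more than the slot.
    n_steps = min(math.ceil(duration_ms / dt_ms), max_steps)
    padding = max_steps - n_steps
    data = [value] * n_steps + [0.0] * padding   # tail is zeroed
    mask = [True] * n_steps + [False] * padding  # tail is masked out
    return data, mask


# A 725 ms delay inside a 1500 ms slot, with dt = 10 ms.
data, mask = pack_phase(1.0, duration_ms=725.0, max_duration_ms=1500.0, dt_ms=10.0)
print(len(data), sum(mask))
```

Every trial yields a buffer of the same length, which is what keeps shapes stable under compilation; only the number of unmasked steps changes from trial to trial.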
- get_duration(ctx)
Eager-mode duration. Reads ctx[self._ctx_key] if available (converting it to a Python int via int(jnp.asarray(...))) and falls back to max_steps otherwise. Not used on the packed JIT path.
- Return type:
int
- max_steps(ctx)
Static upper bound on this phase’s length in timesteps.
Must return a Python int with no dependence on traced values. Used by Task in variable-length mode to size shape-stable buffers. The default delegates to get_duration, which is correct for fixed-duration phases. Variable-duration phases (e.g. those wrapping TruncExp/UniformDuration) override this to return the truncation upper bound divided by ctx.dt.
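The relationship between get_duration and max_steps might be sketched as follows. All class names here (`SketchCtx`, `FixedPhaseSketch`, `VariablePhaseSketch`) are hypothetical stand-ins, not braintools classes, and durations are plain floats in milliseconds; the sketch only illustrates the delegation and override pattern described above.

```python
import math


class SketchCtx:
    """Hypothetical stand-in for the trial context: a dt plus keyed state."""

    def __init__(self, dt_ms, **state):
        self.dt = dt_ms
        self._state = state

    def __contains__(self, key):
        return key in self._state

    def __getitem__(self, key):
        return self._state[key]


class FixedPhaseSketch:
    """Fixed-length phase: the upper bound equals the duration itself."""

    def __init__(self, duration_ms):
        self.duration_ms = duration_ms

    def get_duration(self, ctx):
        return math.ceil(self.duration_ms / ctx.dt)

    def max_steps(self, ctx):
        # Default behaviour: delegate to get_duration.
        return self.get_duration(ctx)


class VariablePhaseSketch(FixedPhaseSketch):
    """Variable-length phase: the bound comes from max_duration only."""

    def __init__(self, max_duration_ms, ctx_key):
        self.max_duration_ms = max_duration_ms
        self.ctx_key = ctx_key

    def max_steps(self, ctx):
        # Static: uses only max_duration and dt, never a sampled value.
        return math.ceil(self.max_duration_ms / ctx.dt)

    def get_duration(self, ctx):
        # Eager path: use the sampled value if present, else the bound.
        if self.ctx_key in ctx:
            return int(math.ceil(ctx[self.ctx_key] / ctx.dt))
        return self.max_steps(ctx)


phase = VariablePhaseSketch(max_duration_ms=1500.0, ctx_key="delay_duration")
empty_ctx = SketchCtx(dt_ms=10.0)
sampled_ctx = SketchCtx(dt_ms=10.0, delay_duration=420.0)

bound = phase.max_steps(sampled_ctx)      # identical for every trial
fallback = phase.get_duration(empty_ctx)  # no sample yet: upper bound
actual = phase.get_duration(sampled_ctx)  # sampled length in steps
```

Because the bound never reads the sampled entry, it can be evaluated once, before any trial state exists, to size the buffers.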
- step_count(ctx)
Traced actual length of this phase in timesteps.
Returns a jax.Array int32 scalar. May depend on ctx[...] values populated by trial_init. Must satisfy 0 <= step_count(ctx) <= max_steps(ctx) for every trial.
The default returns a static value equal to get_duration; that is correct for any phase whose actual length matches its upper bound. Variable-duration phases override this to compute the traced length from ctx state without any int(...) cast.
- Return type:
Array
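The bound 0 <= step_count(ctx) <= max_steps(ctx) can be checked with a small arithmetic sketch. It uses plain Python numbers for readability, whereas the real method is described as returning a traced int32 array, so the same arithmetic would be expressed with array operations there. `clamped_step_count` is a hypothetical helper, and flooring at max(1, ceil(min_duration / dt)) is an assumption based on the class description, not confirmed library behaviour.

```python
import math


def clamped_step_count(duration_ms, min_ms, max_ms, dt_ms):
    """Hypothetical helper: sampled duration -> step count within bounds."""
    max_steps = math.ceil(max_ms / dt_ms)          # static upper bound
    min_steps = max(1, math.ceil(min_ms / dt_ms))  # assumed floor, >= 1 step
    raw = math.ceil(duration_ms / dt_ms)           # unclamped length
    return max(min_steps, min(raw, max_steps))


# Too short, in range, and too long, for a 300-1500 ms phase at dt = 10 ms.
counts = [clamped_step_count(d, 300.0, 1500.0, 10.0) for d in (50.0, 725.0, 9999.0)]
print(counts)  # [30, 73, 150]
```

Clamping at both ends means an out-of-range sample can never write past the reserved slot or produce an empty phase.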