VariableDuration

class braintools.cogtask.VariableDuration(min_duration, max_duration, ctx_key, inputs=None, outputs=None, noise=None, on_enter=None, on_exit=None, name=None)

Declarative phase whose actual length is decided per trial.

The phase reserves a buffer slot of ceil(max_duration / dt) timesteps and writes its content into the first ceil(ctx[ctx_key] / dt) of them. Timesteps past the trial’s actual length are zeroed automatically by the packed runtime and excluded by the trial mask. The phase is compatible with brainstate.transform.jit and brainstate.transform.vmap2.
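The buffer layout described above can be illustrated in plain Python. This is a minimal sketch, not the braintools implementation: the helper `pack_phase` is hypothetical, durations are plain floats in milliseconds, and the phase content is a single constant value per timestep. In JAX the validity mask would typically be built as `jnp.arange(max_steps) < step_count`, which keeps the buffer shape static while the length itself stays a traced value.

```python
import math


def pack_phase(value, max_duration_ms, duration_ms, dt_ms):
    """Hypothetical helper: fill a fixed-size buffer for one variable-length phase.

    The buffer always has ceil(max_duration / dt) slots, so its shape is the
    same for every trial; only the first ceil(duration / dt) slots are valid.
    """
    max_steps = math.ceil(max_duration_ms / dt_ms)  # static buffer size
    step_count = math.ceil(duration_ms / dt_ms)     # per-trial length
    step_count = min(step_count, max_steps)         # assumption: never overflow the slot
    # Valid steps carry the phase content; the tail is zero-padded.
    buffer = [value if t < step_count else 0.0 for t in range(max_steps)]
    # Mask marks timesteps that belong to the trial (1) versus padding (0).
    mask = [1 if t < step_count else 0 for t in range(max_steps)]
    return buffer, mask


buffer, mask = pack_phase(1.0, max_duration_ms=1500.0, duration_ms=735.0, dt_ms=10.0)
print(len(buffer), sum(mask))  # 150 74
```

Because every trial gets a buffer of the same length regardless of its sampled duration, buffers from many trials can be stacked into one array, which is what makes batching with a vmap-style transform possible.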

ctx_key names the trial-state entry holding the sampled duration. The value can be a scalar number of milliseconds or a Quantity; in either case the runtime divides it by ctx.dt to obtain a timestep count. min_duration and max_duration must be Quantities with matching units. min_duration provides the static lower bound on the step count, and that floor is always at least 1 timestep.
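The conversion from a sampled duration to a bounded step count can be sketched as follows. The helper `duration_to_steps` is hypothetical, and several details are assumptions rather than documented behavior: durations are plain floats in milliseconds (a Quantity would first be expressed in milliseconds), the lower bound is rounded up with ceil, and the result is also clipped to max_steps, following the 0 <= step_count <= max_steps contract stated for step_count.

```python
import math


def duration_to_steps(duration_ms, dt_ms, min_duration_ms, max_duration_ms):
    """Hypothetical helper: convert a sampled duration (ms) into a clamped step count.

    Assumptions (not taken from braintools): inputs are plain floats in
    milliseconds, and both bounds are converted to steps with ceil.
    """
    min_steps = max(1, math.ceil(min_duration_ms / dt_ms))  # static floor, at least 1 step
    max_steps = math.ceil(max_duration_ms / dt_ms)          # static ceiling (buffer size)
    steps = math.ceil(duration_ms / dt_ms)                  # per-trial value from ctx[ctx_key]
    return min(max(steps, min_steps), max_steps)


print(duration_to_steps(120.0, 10.0, 300.0, 1500.0))   # below the floor -> 30
print(duration_to_steps(735.0, 10.0, 300.0, 1500.0))   # in range -> 74
print(duration_to_steps(2000.0, 10.0, 300.0, 1500.0))  # above the buffer -> 150
```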

Examples

>>> import brainunit as u
>>> from braintools.cogtask import VariableDuration
>>> phase = VariableDuration(
...     min_duration=300 * u.ms,
...     max_duration=1500 * u.ms,
...     ctx_key='delay_duration',
...     inputs={'fixation': 1.0},
...     outputs={'label': 0},
...     name='variable_delay',
... )

get_duration(ctx)

Eager-mode duration. If ctx[self._ctx_key] is available, its value is converted to a Python int via int(jnp.asarray(...)); otherwise the method falls back to max_steps. This method is not used on the packed JIT path.

Return type:

int

max_steps(ctx)

Static upper bound on this phase’s length in timesteps.

Must return a Python int that does not depend on traced values. Task uses it in variable-length mode to size shape-stable buffers. The default delegates to get_duration, which is correct for fixed-duration phases. Variable-duration phases (e.g., those wrapping TruncExp/UniformDuration) override it to return the truncation upper bound divided by ctx.dt.

Parameters:

ctx (Context) – A stub or trial context providing ctx.dt. The default implementation does not read ctx.rng or trial state.

Returns:

Upper bound on the number of timesteps for this phase.

Return type:

int

step_count(ctx)

Traced actual length of this phase in timesteps.

Returns an int32 scalar jax.Array. The value may depend on ctx[...] entries populated by trial_init, and must satisfy 0 <= step_count(ctx) <= max_steps(ctx) for every trial.

The default returns a static value equal to get_duration; that is correct for any phase whose actual length matches its upper bound. Variable-duration phases override this to compute the traced length from ctx state without any int(...) cast.

Return type:

Array
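The division of labor between max_steps and step_count, with defaults that delegate to get_duration and overrides in variable-duration phases, can be sketched with toy classes. `ToyPhase` and `ToyVariablePhase` are hypothetical stand-ins, not braintools classes; ctx is modeled as a plain dict with a 'dt' entry; and step_count returns a Python int here, whereas the real method returns a traced int32 array computed without an int(...) cast.

```python
import math


class ToyPhase:
    """Hypothetical stand-in for a fixed-duration phase (not the braintools API)."""

    def __init__(self, duration_ms):
        self.duration_ms = duration_ms

    def get_duration(self, ctx):
        # Fixed-length phase: the step count follows from dt alone.
        return math.ceil(self.duration_ms / ctx['dt'])

    def max_steps(self, ctx):
        # Default: the static bound equals the fixed duration.
        return self.get_duration(ctx)

    def step_count(self, ctx):
        # Default: the actual length equals the static bound.
        return self.get_duration(ctx)


class ToyVariablePhase(ToyPhase):
    """Hypothetical variable-length phase that reads its duration from ctx[ctx_key]."""

    def __init__(self, max_duration_ms, ctx_key):
        super().__init__(max_duration_ms)
        self.ctx_key = ctx_key

    def max_steps(self, ctx):
        # Static bound: depends only on the upper limit and dt, never on trial state.
        return math.ceil(self.duration_ms / ctx['dt'])

    def step_count(self, ctx):
        # Per-trial length read from trial state and clipped to the static bound.
        steps = math.ceil(ctx[self.ctx_key] / ctx['dt'])
        return min(steps, self.max_steps(ctx))


ctx = {'dt': 10.0, 'delay_duration': 735.0}
fixed = ToyPhase(500.0)
variable = ToyVariablePhase(1500.0, 'delay_duration')
print(fixed.max_steps(ctx), fixed.step_count(ctx))        # 50 50
print(variable.max_steps(ctx), variable.step_count(ctx))  # 150 74
```

Keeping max_steps free of trial state is what lets buffer shapes be fixed before tracing, while step_count is the only place where per-trial variation enters.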