While
- class braintools.cogtask.While(condition, body, max_iterations=100, name='While')
Repeat the body phase while condition is true, up to max_iterations times.
Useful for tasks with variable numbers of repetitions, such as evidence accumulation until a threshold is reached.
Note: The duration computed by get_duration uses max_iterations as an upper bound, but actual execution may be shorter.
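The relationship between the upper bound and the actual duration can be sketched as follows. This is a minimal illustration, not the library's implementation: it assumes a fixed body duration and that the bound is simply max_iterations times that duration. The helper names (duration_upper_bound_ms, actual_duration_ms) are hypothetical.

```python
def duration_upper_bound_ms(body_ms: float, max_iterations: int) -> float:
    """Assumed upper bound: every allowed iteration runs the full body."""
    return body_ms * max_iterations


def actual_duration_ms(body_ms: float, iterations_run: int, max_iterations: int) -> float:
    """Duration actually consumed when the loop stops after iterations_run passes."""
    if not 0 <= iterations_run <= max_iterations:
        raise ValueError("iterations_run must lie in [0, max_iterations]")
    return body_ms * iterations_run


# Numbers borrowed from the evidence-accumulation example: 50 ms body, 50 iterations max.
upper = duration_upper_bound_ms(50.0, 50)   # 2500.0 ms reserved
actual = actual_duration_ms(50.0, 12, 50)   # 600.0 ms if the condition fails after 12 passes
```

Under these assumptions the actual duration never exceeds the bound, which is why a buffer sized from the bound is always large enough.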
Examples
>>> # Evidence accumulation until threshold
>>> phases = (
...     Fixation(500 * u.ms)
...     >> While(
...         lambda ctx: ctx.get('accumulated_evidence', 0) < threshold,
...         body=EvidenceSample(50 * u.ms),
...         max_iterations=50
...     )
...     >> Response(500 * u.ms)
... )

>>> # Repeated sampling with early termination
>>> phases = While(
...     lambda ctx: ctx.get('sample_count', 0) < ctx['required_samples'],
...     body=Sample(100 * u.ms),
...     max_iterations=20
... )
- Parameters:
- children()
Return the immediate child phases of a compound phase.
Leaf phases return []. Subclasses like Sequence, Repeat, Parallel, If, Switch, While override this so that the Task can traverse the whole tree to bind features.
- encode_inputs(ctx)
Fill ctx.inputs[phase_start:phase_end] with input encoding.
Called once per phase after duration is determined. Must modify ctx.inputs in-place.
- encode_outputs(ctx)
Fill ctx.outputs[phase_start:phase_end] with target encoding.
Called once per phase after duration is determined. Must modify ctx.outputs in-place.
- execute_packed(ctx)
Packed-mode loop. The condition must return a Python bool; tracer-valued conditions are not supported here (they would require lax.while_loop with a state-as-pytree wrapper).
- Return type:
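A Python-level loop of the kind described for packed mode might look like the sketch below. It is a plain-Python stand-in, not the library's code: run_packed_loop, the dict-based ctx, and the choice to raise TypeError on a non-bool condition are all assumptions made for illustration.

```python
def run_packed_loop(condition, body, ctx, max_iterations=100):
    """Run body(ctx) while condition(ctx) is True, capped at max_iterations.

    Returns the number of iterations actually executed.
    """
    iterations = 0
    while iterations < max_iterations:
        keep_going = condition(ctx)
        # A traced (non-bool) value cannot drive a Python while loop,
        # so this sketch rejects anything that is not a plain bool.
        if not isinstance(keep_going, bool):
            raise TypeError("condition must return a Python bool")
        if not keep_going:
            break
        body(ctx)
        iterations += 1
    return iterations


def add_evidence(ctx):
    # Hypothetical body: each pass adds 3 units of evidence.
    ctx["accumulated_evidence"] = ctx.get("accumulated_evidence", 0) + 3


ctx = {}
n_iter = run_packed_loop(
    lambda c: c.get("accumulated_evidence", 0) < 10,
    add_evidence,
    ctx,
    max_iterations=50,
)
```

Here the loop stops well before the cap because the condition turns false once the accumulated evidence reaches the threshold.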
- get_duration(ctx)
Estimate duration using max_iterations.
Note: Actual duration may be less if condition becomes False.
- Return type:
- max_steps(ctx)
Static upper bound on this phase’s length in timesteps.
Must return a Python int with no dependence on traced values. Used by Task in variable-length mode to size shape-stable buffers. The default delegates to get_duration, which is correct for fixed-duration phases. Variable-duration phases (e.g. those wrapping TruncExp/UniformDuration) override this to return the truncation upper bound divided by ctx.dt.
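As a rough illustration of the "truncation upper bound divided by ctx.dt" rule, the sketch below converts a duration bound into a static step count. The function name max_steps_sketch and the use of plain floats in milliseconds are assumptions; rounding up with math.ceil is also an assumption, chosen so that the buffer is never too small.

```python
import math


def max_steps_sketch(upper_bound_ms: float, dt_ms: float) -> int:
    """Static bound on the number of timesteps, returned as a Python int."""
    if dt_ms <= 0:
        raise ValueError("dt_ms must be positive")
    # Round up so a partial final step still fits in the buffer.
    return int(math.ceil(upper_bound_ms / dt_ms))


steps_even = max_steps_sketch(800.0, 2.0)     # bound is an exact multiple of dt
steps_partial = max_steps_sketch(805.0, 2.0)  # bound leaves a partial step
```

Because the result depends only on Python numbers, it can be used to size arrays before any tracing happens.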
- step_count(ctx)
Traced actual length of this phase in timesteps.
Returns a jax.Array int32 scalar. May depend on ctx[...] values populated by trial_init. Must satisfy 0 <= step_count(ctx) <= max_steps(ctx) for every trial.
The default returns a static value equal to get_duration; that is correct for any phase whose actual length matches its upper bound. Variable-duration phases override this to compute the traced length from ctx state without any int(...) cast.
- Return type:
Array
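The contract above (an int32 scalar, computed without int(...), bounded by max_steps) can be sketched in JAX as follows. This is an assumption-laden illustration rather than the library's code: step_count_sketch is a hypothetical free function, durations are passed as plain float arrays in milliseconds instead of being read from ctx, and clipping to the bound is one possible way to enforce the invariant.

```python
import jax
import jax.numpy as jnp


def step_count_sketch(duration_ms, dt_ms, max_steps):
    """Traced step count as an int32 scalar, clipped to [0, max_steps]."""
    # astype keeps the value traced; int(...) would fail under jit.
    steps = jnp.ceil(duration_ms / dt_ms).astype(jnp.int32)
    return jnp.clip(steps, 0, max_steps)


# max_steps is a static Python int; the sampled duration is a traced value.
jitted = jax.jit(step_count_sketch, static_argnames=("max_steps",))

n_short = jitted(jnp.float32(230.0), jnp.float32(1.0), max_steps=500)
n_long = jitted(jnp.float32(900.0), jnp.float32(1.0), max_steps=500)
```

The second call shows the clip taking effect: a sampled duration beyond the bound still yields a count no larger than max_steps.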