If
- class braintools.cogtask.If(condition, then, else_=None, name='If')
Conditional phase selection based on a boolean condition.
Evaluates the condition at runtime and executes the then phase if it is true, otherwise the else_ phase (if provided).
Examples
>>> # Match vs non-match response
>>> phases = (
...     Sample(500 * u.ms)
...     >> Delay(1000 * u.ms)
...     >> Test(500 * u.ms)
...     >> If(
...         lambda ctx: ctx['match'],
...         then=MatchResponse(500 * u.ms),
...         else_=NonMatchResponse(500 * u.ms)
...     )
... )

>>> # Go/NoGo with no else branch
>>> phases = (
...     Stimulus(500 * u.ms)
...     >> If(
...         lambda ctx: ctx['is_go'],
...         then=Response(500 * u.ms)
...     )
... )
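The branch selection described above can be sketched in plain Python. `Leaf`, `IfSketch`, and `run_phases` are hypothetical stand-ins, not part of braintools; they only illustrate that the condition is called with each trial's context and that an `If` without `else_` contributes no phase when the condition is false.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Union


@dataclass
class Leaf:
    """Hypothetical leaf phase, identified only by its name."""
    name: str


@dataclass
class IfSketch:
    """Hypothetical stand-in for a conditional phase."""
    condition: Callable[[dict], bool]
    then: "Phase"
    else_: Optional["Phase"] = None


Phase = Union[Leaf, IfSketch, List[Any]]


def run_phases(phase: Phase, ctx: dict) -> List[str]:
    """Return the names of the leaf phases executed for one trial."""
    if isinstance(phase, Leaf):
        return [phase.name]
    if isinstance(phase, IfSketch):
        # The condition is evaluated per trial, against that trial's context.
        if phase.condition(ctx):
            return run_phases(phase.then, ctx)
        # No else branch: the false case contributes no phase at all.
        return run_phases(phase.else_, ctx) if phase.else_ is not None else []
    # A plain list acts as a sequence of phases.
    names: List[str] = []
    for child in phase:
        names.extend(run_phases(child, ctx))
    return names


trial = [Leaf("stimulus"), IfSketch(lambda ctx: ctx["is_go"], then=Leaf("response"))]
print(run_phases(trial, {"is_go": True}))   # ['stimulus', 'response']
print(run_phases(trial, {"is_go": False}))  # ['stimulus']
```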
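- Parameters:
  - condition: Condition evaluated at runtime. In the examples above, a callable that receives the trial context (e.g. lambda ctx: ctx['match']) and returns a boolean.
  - then: Phase executed when condition is true.
  - else_: Phase executed when condition is false. Optional; defaults to None, in which case no phase is executed on the false branch.
  - name: Name of this phase. Defaults to 'If'.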
- children()
Return the immediate child phases of a compound phase.
Leaf phases return []. Subclasses like Sequence, Repeat, Parallel, If, Switch, and While override this so that the Task can traverse the whole tree to bind features.
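A minimal sketch of how a tree of phases can be walked through `children()`. `Node`, `SeqNode`, `IfNode`, and `walk` are hypothetical, and it is an assumption here that a conditional phase reports both `then` and `else_` (when present) as its children.

```python
from typing import Iterator, List, Optional


class Node:
    """Hypothetical base phase: leaves have no children."""

    def __init__(self, name: str):
        self.name = name

    def children(self) -> List["Node"]:
        return []


class SeqNode(Node):
    """Hypothetical sequence: its children are the phases it chains."""

    def __init__(self, name: str, *phases: Node):
        super().__init__(name)
        self.phases = list(phases)

    def children(self) -> List[Node]:
        return list(self.phases)


class IfNode(Node):
    """Hypothetical conditional: assumed to expose both branches as children."""

    def __init__(self, name: str, then: Node, else_: Optional[Node] = None):
        super().__init__(name)
        self.then = then
        self.else_ = else_

    def children(self) -> List[Node]:
        return [self.then] if self.else_ is None else [self.then, self.else_]


def walk(node: Node) -> Iterator[Node]:
    """Depth-first, pre-order traversal over the whole phase tree."""
    yield node
    for child in node.children():
        yield from walk(child)


tree = SeqNode("trial", Node("sample"), IfNode("branch", Node("match"), Node("nonmatch")))
print([n.name for n in walk(tree)])  # ['trial', 'sample', 'branch', 'match', 'nonmatch']
```

Because `walk` follows both branches regardless of any trial's condition, every phase that could run in some trial is reached, which is the property a one-time pass over the tree, such as binding features, would rely on.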
- encode_inputs(ctx)
Fill ctx.inputs[phase_start:phase_end] with the input encoding.
Called once per phase, after the phase duration is determined. Must modify ctx.inputs in place.
- encode_outputs(ctx)
Fill ctx.outputs[phase_start:phase_end] with the target encoding.
Called once per phase, after the phase duration is determined. Must modify ctx.outputs in place.
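A sketch of the in-place contract shared by `encode_inputs` and `encode_outputs`, using NumPy arrays as stand-ins for the real buffers. The `Ctx` dataclass and the single-channel encodings are hypothetical; the point is that slice assignment writes into the existing arrays and only touches rows `phase_start` to `phase_end - 1`.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class Ctx:
    """Hypothetical trial context; NumPy arrays stand in for the real buffers."""
    inputs: np.ndarray   # shape (T, n_inputs)
    outputs: np.ndarray  # shape (T, n_outputs)
    phase_start: int
    phase_end: int


def encode_inputs(ctx: Ctx) -> None:
    # Hypothetical encoding: turn on input channel 0 for this phase only.
    # Slice assignment writes into the existing array, so nothing is returned.
    ctx.inputs[ctx.phase_start:ctx.phase_end, 0] = 1.0


def encode_outputs(ctx: Ctx) -> None:
    # Hypothetical target: output channel 0 should be active during this phase.
    ctx.outputs[ctx.phase_start:ctx.phase_end, 0] = 1.0


ctx = Ctx(np.zeros((10, 3)), np.zeros((10, 2)), phase_start=2, phase_end=5)
buf = ctx.inputs  # keep a reference to the original buffer
encode_inputs(ctx)
encode_outputs(ctx)
# buf is still ctx.inputs, and rows 2, 3, 4 of channel 0 are now 1.0
```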
- execute_packed(ctx)
Packed-mode branch using jax.lax.cond. Both branches mutate ctx.inputs, ctx.outputs, ctx.mask, and ctx.t_cursor. To make the cond functional, we (a) snapshot ctx._state so that per-branch scratch writes don’t leak across branches, and (b) thread the buffer state through lax.cond as a pytree. Anything not in that pytree (e.g. ctx.phase_history) is best-effort metadata and may contain entries from both branches during tracing; it is not part of the trial’s tensor output.
- Return type:
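The pytree-threading idea from point (b) can be illustrated with `jax.lax.cond` directly. This sketch is not the braintools implementation: `write_then`, `write_else`, `run_if`, and the `history` list are hypothetical. Each branch receives the buffer state as a tuple and returns a new tuple rather than mutating shared objects. The `history` list plays the role of side-channel metadata such as `ctx.phase_history`: because `jit` traces both branches of the cond, it ends up with an entry from each, even though only one branch's writes reach the output.

```python
import jax
import jax.numpy as jnp
from jax import lax

history = []  # Python-side metadata, not part of the returned arrays


def write_then(state):
    # Hypothetical "then" branch: write 1.0 for 3 steps at the cursor, then advance it.
    inputs, t_cursor = state
    history.append("then")  # runs at trace time, not once per trial
    inputs = lax.dynamic_update_slice(inputs, jnp.ones((3,), inputs.dtype), (t_cursor,))
    return inputs, t_cursor + 3


def write_else(state):
    # Hypothetical "else_" branch: write 2.0 for 2 steps at the cursor, then advance it.
    inputs, t_cursor = state
    history.append("else")
    inputs = lax.dynamic_update_slice(inputs, jnp.full((2,), 2.0, inputs.dtype), (t_cursor,))
    return inputs, t_cursor + 2


@jax.jit
def run_if(pred, inputs, t_cursor):
    # The (inputs, t_cursor) tuple is the pytree threaded through lax.cond;
    # both branches must return the same structure, shapes, and dtypes.
    return lax.cond(pred, write_then, write_else, (inputs, t_cursor))


out, t = run_if(True, jnp.zeros(8, jnp.float32), jnp.int32(0))
# out holds [1, 1, 1, 0, 0, 0, 0, 0] and t == 3; history has both "then" and "else"
```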
- max_steps(ctx)
Static upper bound on this phase’s length in timesteps.
Must return a Python int with no dependence on traced values. Used by Task in variable-length mode to size shape-stable buffers. The default delegates to get_duration, which is correct for fixed-duration phases. Variable-duration phases (e.g. those wrapping TruncExp/UniformDuration) override this to return the truncation upper bound divided by ctx.dt.
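A sketch of a static step budget for a variable-duration phase. `max_steps_from_bound` is a hypothetical helper. Rounding before taking the ceiling is an assumption made here, so that floating-point noise in `upper / dt` does not drop a step and a bound that is not a multiple of `dt` still rounds up.

```python
import math


def max_steps_from_bound(upper_ms: float, dt_ms: float) -> int:
    """Static step budget for a phase capped at upper_ms (hypothetical helper).

    Uses only Python floats, never traced values, so the result is a plain int.
    round(..., 9) strips floating-point noise (0.3 / 0.1 == 2.9999999999999996),
    and the ceiling keeps the budget >= the longest possible phase.
    """
    return math.ceil(round(upper_ms / dt_ms, 9))


# Summing per-phase bounds gives a fixed buffer length for a sequence of phases.
buffer_len = sum(max_steps_from_bound(b, 10.0) for b in (500.0, 1000.0, 2000.0))
# buffer_len == 350
```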
- step_count(ctx)
Traced actual length of this phase in timesteps.
Returns a jax.Array int32 scalar. May depend on ctx[...] values populated by trial_init. Must satisfy 0 <= step_count(ctx) <= max_steps(ctx) for every trial.
The default returns a static value equal to get_duration; that is correct for any phase whose actual length matches its upper bound. Variable-duration phases override this to compute the traced length from ctx state without any int(...) cast.
- Return type:
Array
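A sketch of a traced `step_count` paired with a static `max_steps`, and of the fixed-shape mask the two make possible. `step_count`, `phase_mask`, and the 20-step budget are hypothetical, and taking the ceiling of `duration / dt` is an assumption. Clipping is one way, assumed here, to guarantee `0 <= step_count <= max_steps`; everything stays a JAX value, so the function runs under `jax.jit` without an `int(...)` cast.

```python
import jax
import jax.numpy as jnp

MAX_STEPS = 20  # hypothetical static budget, e.g. the value max_steps(ctx) would return


def step_count(duration_ms, dt_ms, max_steps):
    """Traced phase length in timesteps (hypothetical helper).

    No int(...) cast: the result stays a traced int32 scalar under jit.
    Clipping enforces 0 <= step_count <= max_steps.
    """
    steps = jnp.ceil(duration_ms / dt_ms).astype(jnp.int32)
    return jnp.clip(steps, 0, max_steps)


@jax.jit
def phase_mask(duration_ms):
    # The mask always has shape (MAX_STEPS,), whatever the traced duration is;
    # only the number of True entries varies from trial to trial.
    n = step_count(duration_ms, 10.0, MAX_STEPS)
    return n, jnp.arange(MAX_STEPS) < n


n, mask = phase_mask(jnp.float32(123.0))  # 12.3 steps rounds up to n == 13
```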