If

class braintools.cogtask.If(condition, then, else_=None, name='If')

Conditional phase selection based on a boolean condition.

Evaluates condition on the trial context at runtime and executes the then phase when it returns True; otherwise it executes the else_ phase, if one was provided.

Examples

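>>> import brainunit as u
>>> from braintools.cogtask import If
>>> # Sample, Delay, Test, MatchResponse, NonMatchResponse, Stimulus and
>>> # Response are assumed to be phase classes imported or defined elsewhere.
>>> # Match vs non-match response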
>>> phases = (
...     Sample(500 * u.ms)
...     >> Delay(1000 * u.ms)
...     >> Test(500 * u.ms)
...     >> If(
...         lambda ctx: ctx['match'],
...         then=MatchResponse(500 * u.ms),
...         else_=NonMatchResponse(500 * u.ms)
...     )
... )
>>> # Go/NoGo with no else branch
>>> phases = (
...     Stimulus(500 * u.ms)
...     >> If(
...         lambda ctx: ctx['is_go'],
...         then=Response(500 * u.ms)
...     )
... )

Parameters:
  • condition (Callable[[Context], bool]) – Function that receives the trial context and returns True or False.

  • then (Phase) – Phase to execute when condition returns True.

  • else_ (Phase, optional) – Phase to execute when condition returns False. Defaults to None, in which case nothing is executed when the condition is False.

  • name (str) – Name for this conditional phase. Defaults to 'If'.

children()

Return the immediate child phases of a compound phase.

Leaf phases return []. Compound subclasses such as Sequence, Repeat, Parallel, If, Switch and While override this method so that the Task can traverse the whole phase tree to bind features.
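The following is a minimal sketch of the kind of tree walk that children() makes possible. Leaf, Cond and walk are hypothetical stand-ins, not braintools classes, and the sketch assumes that a conditional reports its then branch and, when present, its else_ branch as children, so that a walk reaches every phase rather than only the one a given trial takes.

```python
from typing import Iterator, List, Optional


class Leaf:
    """Hypothetical leaf phase: it has no children."""

    def __init__(self, name: str) -> None:
        self.name = name

    def children(self) -> List["Leaf"]:
        return []


class Cond(Leaf):
    """Hypothetical If-like phase with a then branch and an optional else_ branch."""

    def __init__(self, then: Leaf, else_: Optional[Leaf] = None, name: str = "If") -> None:
        super().__init__(name)
        self.then = then
        self.else_ = else_

    def children(self) -> List[Leaf]:
        # Report every branch, not just the one a particular trial would take.
        return [self.then] if self.else_ is None else [self.then, self.else_]


def walk(phase: Leaf) -> Iterator[Leaf]:
    """Depth-first, pre-order traversal of the phase tree."""
    yield phase
    for child in phase.children():
        yield from walk(child)


tree = Cond(then=Leaf("MatchResponse"), else_=Leaf("NonMatchResponse"))
print([p.name for p in walk(tree)])  # ['If', 'MatchResponse', 'NonMatchResponse']
```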

encode_inputs(ctx)

Fill ctx.inputs[phase_start:phase_end] with input encoding.

Called once per phase, after the phase's duration is determined. Must modify ctx.inputs in place.

Parameters:

ctx (Context) – Context with input buffer and trial state.

Return type:

None
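As an illustration of the in-place contract, here is a sketch that writes a constant value into one input channel over the phase window. SimpleContext and encode_fixation_inputs are hypothetical; the real Context may expose the window bounds differently. NumPy is used because its slice assignment mutates the buffer directly, whereas JAX arrays are immutable.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class SimpleContext:
    """Hypothetical stand-in for Context: a (time, channel) buffer plus the phase window."""
    inputs: np.ndarray
    phase_start: int
    phase_end: int


def encode_fixation_inputs(ctx: SimpleContext, channel: int = 0, value: float = 1.0) -> None:
    """Write `value` into one input channel for every timestep of the current phase."""
    # Slice assignment mutates the shared buffer; nothing is returned.
    ctx.inputs[ctx.phase_start:ctx.phase_end, channel] = value


ctx = SimpleContext(inputs=np.zeros((10, 3)), phase_start=2, phase_end=5)
encode_fixation_inputs(ctx)
print(ctx.inputs[:, 0])  # [0. 0. 1. 1. 1. 0. 0. 0. 0. 0.]
```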

encode_outputs(ctx)

Fill ctx.outputs[phase_start:phase_end] with target encoding.

Called once per phase, after the phase's duration is determined. Must modify ctx.outputs in place.

Parameters:

ctx (Context) – Context with output buffer and trial state.

Return type:

None
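A target encoding often marks the correct choice as a one-hot vector for the duration of a response phase. The sketch below assumes that convention; encode_choice_targets is a hypothetical helper, not part of braintools, and it writes only inside [phase_start, phase_end).

```python
import numpy as np


def encode_choice_targets(outputs: np.ndarray, phase_start: int, phase_end: int, label: int) -> None:
    """Hypothetical target encoder: one-hot `label` over the phase window, in place.

    `outputs` has shape (time, num_classes); rows outside the window are untouched.
    """
    window = outputs[phase_start:phase_end]  # basic slicing returns a view, so writes reach `outputs`
    window[:] = 0.0
    window[:, label] = 1.0


targets = np.zeros((6, 3))
encode_choice_targets(targets, phase_start=3, phase_end=6, label=2)
print(targets.argmax(axis=1))  # [0 0 0 2 2 2]
```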

execute(ctx)

Execute the then branch or the else_ branch, depending on the value returned by condition.

Return type:

None
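In eager (non-packed) mode the branch choice can be an ordinary Python if. The sketch below is a hypothetical reimplementation for illustration only: IfSketch and RecordingPhase are invented names, a plain dict stands in for Context, and a False condition with no else_ branch is assumed to run nothing.

```python
from typing import Any, Callable, Dict

Context = Dict[str, Any]  # hypothetical: a plain dict standing in for braintools' Context


class RecordingPhase:
    """Hypothetical leaf phase that only records that it ran."""

    def __init__(self, name: str) -> None:
        self.name = name

    def execute(self, ctx: Context) -> None:
        ctx.setdefault("history", []).append(self.name)


class IfSketch:
    """Eager conditional: evaluate the condition once, then run at most one branch."""

    def __init__(self, condition: Callable[[Context], bool], then, else_=None) -> None:
        self.condition = condition
        self.then = then
        self.else_ = else_

    def execute(self, ctx: Context) -> None:
        if self.condition(ctx):
            self.then.execute(ctx)
        elif self.else_ is not None:
            self.else_.execute(ctx)
        # With no else_ branch and a False condition, nothing runs.


go_nogo = IfSketch(lambda ctx: ctx["is_go"], then=RecordingPhase("Response"))
trial = {"is_go": False}
go_nogo.execute(trial)
print(trial.get("history", []))  # []
```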

execute_packed(ctx)

Packed-mode counterpart of execute, implemented with jax.lax.cond.

Both branches mutate ctx.inputs, ctx.outputs, ctx.mask and ctx.t_cursor. To make the cond functional, this method:
  • snapshots ctx._state so that per-branch scratch writes don't leak across branches, and
  • threads the buffer state through lax.cond as a pytree.

Anything not in that pytree (e.g. ctx.phase_history) is best-effort metadata: during tracing it may contain entries from both branches, and it is not part of the trial's tensor output.

Return type:

None
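The pattern described above can be sketched with a toy buffer state: an inputs array and a t_cursor carried as a dict pytree that both branches take and return with identical structure, shapes and dtypes. The branch bodies, block lengths and the history list are hypothetical; history plays the role of Python-side metadata such as ctx.phase_history, and because lax.cond traces both branches under jit, it ends up with entries from each of them.

```python
import jax
import jax.numpy as jnp

history = []  # Python-side metadata, like ctx.phase_history: NOT part of the traced state


def then_branch(state):
    history.append("then")  # runs at trace time, not per trial
    block = jnp.zeros((3, 2), dtype=jnp.float32).at[:, 0].set(1.0)
    inputs = jax.lax.dynamic_update_slice(state["inputs"], block, (state["t_cursor"], jnp.int32(0)))
    return {"inputs": inputs, "t_cursor": state["t_cursor"] + 3}


def else_branch(state):
    history.append("else")
    block = jnp.zeros((5, 2), dtype=jnp.float32).at[:, 1].set(1.0)
    inputs = jax.lax.dynamic_update_slice(state["inputs"], block, (state["t_cursor"], jnp.int32(0)))
    return {"inputs": inputs, "t_cursor": state["t_cursor"] + 5}


@jax.jit
def run_branch(pred, state):
    # Both branches receive and return the same pytree structure, shapes and dtypes.
    return jax.lax.cond(pred, then_branch, else_branch, state)


state0 = {"inputs": jnp.zeros((10, 2), dtype=jnp.float32), "t_cursor": jnp.int32(0)}
out = run_branch(True, state0)
print(int(out["t_cursor"]))  # 3
print(sorted(set(history)))  # ['else', 'then']: both branches were traced
```

Only the returned pytree reflects the branch actually selected; the Python list does not, which is why such metadata should be treated as best-effort.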

get_duration(ctx)

Return the duration, in timesteps, of the branch taken for this trial.

Return type:

int
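A sketch of a branch-dependent duration, assuming fixed-length branches and assuming (the page does not say) that a missing else_ branch contributes 0 timesteps. FixedPhase and conditional_duration are hypothetical names used only for illustration.

```python
from typing import Any, Callable, Dict, Optional

Context = Dict[str, Any]  # hypothetical stand-in for braintools' Context


class FixedPhase:
    """Hypothetical fixed-length phase lasting `steps` timesteps."""

    def __init__(self, steps: int) -> None:
        self.steps = steps

    def get_duration(self, ctx: Context) -> int:
        return self.steps


def conditional_duration(condition: Callable[[Context], bool], then: FixedPhase,
                         else_: Optional[FixedPhase], ctx: Context) -> int:
    """Timesteps used by whichever branch this trial takes (0 assumed if there is no else_)."""
    if condition(ctx):
        return then.get_duration(ctx)
    return else_.get_duration(ctx) if else_ is not None else 0


def is_match(ctx: Context) -> bool:
    return ctx["match"]


print(conditional_duration(is_match, FixedPhase(25), FixedPhase(40), {"match": False}))  # 40
print(conditional_duration(is_match, FixedPhase(25), None, {"match": False}))  # 0
```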

max_steps(ctx)

Static upper bound on this phase’s length in timesteps.

Must return a Python int with no dependence on traced values. Used by Task in variable-length mode to size shape-stable buffers. The default delegates to get_duration, which is correct for fixed-duration phases. Variable-duration phases (e.g. those wrapping TruncExp/UniformDuration) override this to return the truncation upper bound divided by ctx.dt.

Parameters:

ctx (Context) – A stub or trial context providing ctx.dt. The default implementation does not read ctx.rng or trial state.

Returns:

Upper bound on number of timesteps for this phase.

Return type:

int
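The arithmetic behind such bounds can be sketched as follows. Two assumptions are made here that the page does not state: the division result is rounded up, so the bound is never too small, and a conditional's static bound is taken as the larger of its branch bounds, because a buffer sized at trace time has to fit whichever branch a trial takes. Both helper functions are hypothetical.

```python
import math
from typing import Optional


def variable_phase_max_steps(upper_bound_ms: float, dt_ms: float) -> int:
    """Bound for a truncated variable-duration phase: truncation upper bound / dt, rounded up."""
    return math.ceil(upper_bound_ms / dt_ms)


def conditional_max_steps(then_steps: int, else_steps: Optional[int]) -> int:
    """Static bound for an If-like phase: it must fit whichever branch is taken."""
    return max(then_steps, else_steps if else_steps is not None else 0)


then_bound = variable_phase_max_steps(upper_bound_ms=1500.0, dt_ms=20.0)  # 75
else_bound = 25  # e.g. a fixed 500 ms phase at dt = 20 ms
print(conditional_max_steps(then_bound, else_bound))  # 75
```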

step_count(ctx)

Traced actual length of this phase in timesteps.

Returns an int32 scalar jax.Array. May depend on ctx[...] values populated by trial_init. Must satisfy 0 <= step_count(ctx) <= max_steps(ctx) for every trial.

The default returns a static value equal to get_duration, which is correct for any phase whose actual length always equals its upper bound. Variable-duration phases override this to compute the traced length from ctx state without any int(...) cast.

Return type:

Array
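For a conditional, a traced length can be selected with jnp.where instead of a Python if, so no concrete value is ever needed. The sketch assumes a dict-like ctx whose "is_go" entry is a boolean scalar filled in before the call; conditional_step_count is a hypothetical helper, and the result respects the 0 <= count <= max_steps contract when MAX_STEPS is the larger branch length.

```python
import jax
import jax.numpy as jnp


@jax.jit
def conditional_step_count(ctx, then_steps, else_steps):
    """Traced length of an If-like phase, chosen without any int(...) cast."""
    # ctx["is_go"] is assumed to be a boolean scalar set during trial initialisation.
    return jnp.where(ctx["is_go"], then_steps, else_steps).astype(jnp.int32)


MAX_STEPS = 40  # static bound: the larger of the two branch lengths
ctx = {"is_go": jnp.bool_(False)}
count = conditional_step_count(ctx, 40, 25)
print(count, count.dtype)  # 25 int32
```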