Switch
- class braintools.cogtask.Switch(selector, cases, default=None, name='Switch')
Multi-way conditional phase selection.
Evaluates the selector function to get a key, then executes the corresponding phase from the cases dictionary.
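The selection rule can be shown with a minimal plain-Python sketch. The helper select_phase and the strings that stand in for phase objects are hypothetical. The sketch makes two assumptions based on the parameter names, not on documented behavior: it falls back to default when the key is missing, and it raises KeyError when there is no default.

```python
from typing import Any, Callable, Dict, Hashable, Optional


def select_phase(
    selector: Callable[[Dict[str, Any]], Hashable],
    cases: Dict[Hashable, Any],
    ctx: Dict[str, Any],
    default: Optional[Any] = None,
) -> Any:
    """Hypothetical sketch of Switch-style branch selection."""
    key = selector(ctx)          # evaluate the selector on the trial context
    if key in cases:
        return cases[key]        # the matching branch
    if default is not None:
        return default           # assumed fallback when no case matches
    # Assumption: with no default, an unknown key is treated as an error.
    raise KeyError(f"no case for key {key!r} and no default")


# Plain strings stand in for phase objects.
cases = {"pro": "ProResponse", "anti": "AntiResponse"}
rule = lambda ctx: ctx["rule"]

print(select_phase(rule, cases, {"rule": "anti"}, default="DefaultResponse"))   # AntiResponse
print(select_phase(rule, cases, {"rule": "other"}, default="DefaultResponse"))  # DefaultResponse
```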
Examples
>>> # Rule-dependent response
>>> phases = (
...     Stimulus(500 * u.ms)
...     >> Delay(1000 * u.ms)
...     >> Switch(
...         lambda ctx: ctx['rule'],
...         cases={
...             'pro': ProResponse(500 * u.ms),
...             'anti': AntiResponse(500 * u.ms),
...         },
...         default=DefaultResponse(500 * u.ms)
...     )
... )
>>> # Multiple choice selection
>>> phases = Switch(
...     lambda ctx: ctx['choice'],
...     cases={
...         0: Choice0Response(100 * u.ms),
...         1: Choice1Response(100 * u.ms),
...         2: Choice2Response(100 * u.ms),
...     }
... )
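- Parameters:
  - selector – Callable that receives the trial context ctx and returns the key used to choose a branch (for example lambda ctx: ctx['rule']).
  - cases – Dictionary mapping each key to the phase executed when the selector returns that key.
  - default – Phase to run when the key is not found in cases (default: None).
  - name – Name of this phase (default: 'Switch').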
- children()
Return the immediate child phases of a compound phase.
Leaf phases return []. Subclasses such as Sequence, Repeat, Parallel, If, Switch and While override this so that the Task can traverse the whole tree to bind features.
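Because every compound phase exposes its branches through children(), a depth-first walk reaches every phase in the tree. The classes Leaf and SwitchLike below are hypothetical stand-ins, not braintools classes. Treating a Switch's default branch as one of its children is an assumption made for this sketch.

```python
from typing import Dict, Iterator, List, Optional


class Leaf:
    """Hypothetical leaf phase: it has no children."""

    def __init__(self, name: str) -> None:
        self.name = name

    def children(self) -> List["Leaf"]:
        return []


class SwitchLike(Leaf):
    """Hypothetical compound phase whose children are its case branches."""

    def __init__(self, name: str, cases: Dict[str, Leaf], default: Optional[Leaf] = None) -> None:
        super().__init__(name)
        self.cases = cases
        self.default = default

    def children(self) -> List[Leaf]:
        kids = list(self.cases.values())
        if self.default is not None:  # assumption: default counts as a child
            kids.append(self.default)
        return kids


def walk(phase: Leaf) -> Iterator[Leaf]:
    """Yield phase and all of its descendants, depth first."""
    yield phase
    for child in phase.children():
        yield from walk(child)


tree = SwitchLike(
    "response",
    cases={"pro": Leaf("pro"), "anti": Leaf("anti")},
    default=Leaf("fallback"),
)
print([p.name for p in walk(tree)])  # ['response', 'pro', 'anti', 'fallback']
```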
- encode_inputs(ctx)
Fill ctx.inputs[phase_start:phase_end] with the input encoding.
Called once per phase, after the phase's duration is determined. Must modify ctx.inputs in place.
- encode_outputs(ctx)
Fill ctx.outputs[phase_start:phase_end] with the target encoding.
Called once per phase, after the phase's duration is determined. Must modify ctx.outputs in place.
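The slice contract shared by encode_inputs and encode_outputs can be sketched with plain Python lists. ToyContext and ToyPhase are hypothetical; the real ctx holds arrays and may store the phase bounds differently. The sketch only illustrates that each phase writes rows phase_start to phase_end - 1 and mutates the context instead of returning new buffers.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyContext:
    """Hypothetical stand-in for the trial context ctx."""
    n_steps: int
    n_in: int
    n_out: int
    phase_start: int = 0
    phase_end: int = 0
    inputs: List[List[float]] = field(init=False, default_factory=list)
    outputs: List[List[float]] = field(init=False, default_factory=list)

    def __post_init__(self) -> None:
        # Zero-initialized buffers covering the whole trial.
        self.inputs = [[0.0] * self.n_in for _ in range(self.n_steps)]
        self.outputs = [[0.0] * self.n_out for _ in range(self.n_steps)]


class ToyPhase:
    """Hypothetical phase: drives one input channel and targets one output unit."""

    def __init__(self, channel: int, target: int) -> None:
        self.channel = channel
        self.target = target

    def encode_inputs(self, ctx: ToyContext) -> None:
        # Touch only this phase's rows; ctx.inputs is mutated in place.
        for t in range(ctx.phase_start, ctx.phase_end):
            ctx.inputs[t][self.channel] = 1.0

    def encode_outputs(self, ctx: ToyContext) -> None:
        # One-hot target on every step of this phase; ctx.outputs is mutated in place.
        for t in range(ctx.phase_start, ctx.phase_end):
            ctx.outputs[t] = [1.0 if k == self.target else 0.0 for k in range(ctx.n_out)]


ctx = ToyContext(n_steps=6, n_in=2, n_out=3, phase_start=2, phase_end=4)
phase = ToyPhase(channel=1, target=2)
phase.encode_inputs(ctx)
phase.encode_outputs(ctx)
print(ctx.inputs)  # [[0.0, 0.0], [0.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]
```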
- execute_packed(ctx)
Packed-mode dispatch.
Selects a branch using self.selector(ctx) and runs it via execute_phase_packed(). The selector must return a hashable Python value (string, int, …); traced selectors would require a lax.switch-based dispatch, which is not currently implemented.
- Return type:
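The hashability requirement follows from dictionary lookup: the selector's return value is used as a key into cases. The sketch below also shows the extra step a lax.switch-based design would need, because jax.lax.switch picks a branch by integer index rather than by arbitrary key. braintools does not currently implement this, and the helper branch_table is hypothetical, so the table is illustration only.

```python
from typing import Dict, Hashable, List, Tuple


def branch_table(cases: Dict[Hashable, str]) -> Tuple[Dict[Hashable, int], List[str]]:
    """Hypothetical: map each case key to a stable integer branch index."""
    keys = list(cases)                        # insertion order fixes the indices
    index = {k: i for i, k in enumerate(keys)}
    branches = [cases[k] for k in keys]
    return index, branches


cases = {"pro": "ProResponse", "anti": "AntiResponse"}
index, branches = branch_table(cases)
print(index)                   # {'pro': 0, 'anti': 1}
print(branches[index["anti"]])  # AntiResponse

# A selector that returns an unhashable value cannot be used as a dict key.
try:
    cases[["anti"]]
except TypeError as err:
    print("unhashable selector result:", err)
```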
- max_steps(ctx)
Static upper bound on this phase’s length in timesteps.
Must return a Python int with no dependence on traced values. Used by Task in variable-length mode to size shape-stable buffers. The default delegates to get_duration, which is correct for fixed-duration phases. Variable-duration phases (e.g. those wrapping TruncExp/UniformDuration) override this to return the truncation upper bound divided by ctx.dt.
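For example, a variable delay truncated at 2000 ms with dt = 10 ms needs a buffer of 2000 / 10 = 200 steps. The helper below is a hypothetical sketch of that computation. Rounding up with math.ceil is an assumption made here so that the bound never falls below the true length when the upper bound is not a multiple of dt.

```python
import math


def max_steps_from_bound(upper_ms: float, dt_ms: float) -> int:
    """Hypothetical: static step bound for a duration truncated at upper_ms."""
    if dt_ms <= 0:
        raise ValueError("dt must be positive")
    # Plain Python arithmetic on static values, so the result is a Python int.
    # The small tolerance guards against float error such as 1.1 / 0.1 == 11.000000000000002.
    return int(math.ceil(upper_ms / dt_ms - 1e-9))


print(max_steps_from_bound(2000.0, 10.0))  # 200
print(max_steps_from_bound(1000.0, 3.0))   # 334
```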
- step_count(ctx)
Traced actual length of this phase in timesteps.
Returns a jax.Array int32 scalar. May depend on ctx[...] values populated by trial_init. Must satisfy 0 <= step_count(ctx) <= max_steps(ctx) for every trial.
The default returns a static value equal to get_duration; that is correct for any phase whose actual length matches its upper bound. Variable-duration phases override this to compute the traced length from ctx state without any int(...) cast.
- Return type:
Array
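The contract 0 <= step_count(ctx) <= max_steps(ctx) is what lets one fixed-size buffer hold every trial. The sketch below uses plain Python numbers in place of traced jax.Array values. Each trial draws its own duration, converts it to steps, and marks only the first step_count entries of a max_steps-long buffer as valid. The functions step_count and valid_mask, the use of round(), and the validity mask itself are hypothetical illustrations, not the braintools implementation.

```python
import math
import random
from typing import List


def step_count(duration_ms: float, dt_ms: float, bound: int) -> int:
    # Per-trial length (rounding is an assumption); clamped so 0 <= n <= bound holds.
    n = int(round(duration_ms / dt_ms))
    return max(0, min(n, bound))


def valid_mask(n: int, bound: int) -> List[bool]:
    # Shape-stable buffer: always bound entries long; the first n are real steps.
    return [t < n for t in range(bound)]


rng = random.Random(0)
lo_ms, hi_ms, dt_ms = 300.0, 700.0, 20.0
bound = int(math.ceil(hi_ms / dt_ms - 1e-9))  # static bound: 35 for every trial

for _ in range(5):
    duration = rng.uniform(lo_ms, hi_ms)      # stand-in for a UniformDuration draw
    n = step_count(duration, dt_ms, bound)
    mask = valid_mask(n, bound)
    assert len(mask) == bound and sum(mask) == n
    print(f"{duration:7.1f} ms -> {n:2d} of {bound} steps")
```

Sizing every buffer by the static bound keeps array shapes identical across trials, while the per-trial count only decides how much of each buffer is meaningful.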