Phase

class braintools.cogtask.Phase(duration, name=None)

Base class for task phases (epochs/periods).

A phase represents a time interval with its own:

- Input encoding rules (how to fill input features)
- Output/target encoding rules (what the expected output should be)
- Duration (fixed or sampled)

Phases are composable via:

- >> operator: sequential concatenation
- * operator: repetition
- | operator: parallel composition
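The operators above map naturally onto Python's special methods `__rshift__`, `__mul__` and `__or__`. The sketch below is a minimal, hypothetical model, not the braintools implementation: `ToyPhase`, `ToySequence`, `ToyRepeat` and `ToyParallel` are illustrative names, durations are plain integer step counts rather than `Quantity` values, and the rule that a parallel block lasts as long as its longest branch is an assumption.

```python
from __future__ import annotations

from typing import List, Optional


class ToyPhase:
    """Hypothetical leaf phase that only models composition, not encoding."""

    def __init__(self, steps: int, name: Optional[str] = None):
        self.steps = steps
        self.name = name or type(self).__name__

    def children(self) -> List["ToyPhase"]:
        return []  # leaf phases have no children

    def __rshift__(self, other: "ToyPhase") -> "ToySequence":
        return ToySequence([self, other])  # a >> b

    def __mul__(self, n: int) -> "ToyRepeat":
        return ToyRepeat(self, n)  # a * n

    def __or__(self, other: "ToyPhase") -> "ToyParallel":
        return ToyParallel([self, other])  # a | b


class ToySequence(ToyPhase):
    def __init__(self, parts: List[ToyPhase]):
        # Flatten nested sequences so that a >> b >> c has three children
        flat: List[ToyPhase] = []
        for p in parts:
            flat.extend(p.children() if isinstance(p, ToySequence) else [p])
        self.parts = flat
        super().__init__(sum(p.steps for p in flat), "Sequence")

    def children(self) -> List[ToyPhase]:
        return list(self.parts)


class ToyRepeat(ToyPhase):
    def __init__(self, phase: ToyPhase, n: int):
        self.phase, self.n = phase, n
        super().__init__(phase.steps * n, "Repeat")

    def children(self) -> List[ToyPhase]:
        return [self.phase]


class ToyParallel(ToyPhase):
    def __init__(self, parts: List[ToyPhase]):
        # Assumption: branches share one time window, so the block
        # lasts as long as its longest branch
        self.parts = list(parts)
        super().__init__(max(p.steps for p in self.parts), "Parallel")

    def children(self) -> List[ToyPhase]:
        return list(self.parts)


trial = ToyPhase(100, "Fixation") >> ToyPhase(500, "Stimulus") >> ToyPhase(100, "Response")
print([c.name for c in trial.children()], trial.steps)
# ['Fixation', 'Stimulus', 'Response'] 700
```

Flattening keeps `a >> b >> c` as one three-child node instead of a chain of nested binary nodes; whether braintools flattens sequences this way is not stated in this reference.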

Examples

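>>> import brainunit as u
>>> from braintools.cogtask import Fixation, Stimulus, Response, concat  # assumed export location
>>> # Sequential composition
>>> phases = Fixation(100 * u.ms) >> Stimulus(500 * u.ms) >> Response(100 * u.ms)
>>> # Sequential composition with the concat function
>>> phases = concat([Fixation(100 * u.ms), Stimulus(500 * u.ms)])
>>> # Repetition
>>> repeated = Stimulus(100 * u.ms) * 5  # 5 repetitions

Parameters: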
  • duration (Quantity) – Phase duration as Quantity (e.g., 100 * u.ms, 1 * u.second).

  • name (str | None) – Phase name. Defaults to class name.

children()

Return the immediate child phases of a compound phase.

Leaf phases return []. Subclasses like Sequence, Repeat, Parallel, If, Switch, While override this so that the Task can traverse the whole tree to bind features.

Return type:

List[Phase]
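Because every node exposes `children()`, a whole phase tree can be visited with a short recursive walk. This is a sketch of the kind of traversal described above, not the Task's actual code; `Node` is a hypothetical minimal stand-in for a phase, and the depth-first pre-order visiting order is an assumption.

```python
from typing import Iterator, List, Sequence


class Node:
    """Hypothetical stand-in for a Phase: a name plus optional child phases."""

    def __init__(self, name: str, kids: Sequence["Node"] = ()):
        self.name = name
        self._kids = list(kids)

    def children(self) -> List["Node"]:
        return list(self._kids)  # [] for leaf phases


def walk(phase: Node) -> Iterator[Node]:
    """Depth-first, pre-order traversal: a parent first, then its children."""
    yield phase
    for child in phase.children():
        yield from walk(child)


def leaves(phase: Node) -> List[Node]:
    """Phases with no children, i.e. the leaves of the tree."""
    return [p for p in walk(phase) if not p.children()]


tree = Node("Sequence", [
    Node("Fixation"),
    Node("Repeat", [Node("Stimulus")]),
    Node("Response"),
])
print([p.name for p in walk(tree)])
# ['Sequence', 'Fixation', 'Repeat', 'Stimulus', 'Response']
print([p.name for p in leaves(tree)])
# ['Fixation', 'Stimulus', 'Response']
```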

abstractmethod encode_inputs(ctx)

Fill ctx.inputs[phase_start:phase_end] with input encoding.

Called once per phase after duration is determined. Must modify ctx.inputs in-place.

Parameters:

ctx (Context) – Context with input buffer and trial state.

Return type:

None
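A concrete subclass typically writes only into its own slice of the shared input buffer. The sketch below assumes a mutable NumPy buffer and a hypothetical `ToyContext` whose `inputs`, `phase_start` and `phase_end` fields mirror the names used above; the real `Context` may store them differently. `ToyFixationInputs` and its channel layout are illustrative only.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class ToyContext:
    # Hypothetical stand-in for Context; field names are assumptions
    inputs: np.ndarray  # shape (total_steps, n_inputs)
    phase_start: int = 0
    phase_end: int = 0


class ToyFixationInputs:
    """Hypothetical leaf phase: fixation cue on one input channel."""

    def __init__(self, channel: int = 0, value: float = 1.0):
        self.channel = channel
        self.value = value

    def encode_inputs(self, ctx: ToyContext) -> None:
        # Write only this phase's time window, mutating the buffer in place
        ctx.inputs[ctx.phase_start:ctx.phase_end, self.channel] = self.value


ctx = ToyContext(inputs=np.zeros((8, 3)), phase_start=2, phase_end=5)
ToyFixationInputs(channel=0).encode_inputs(ctx)
print(ctx.inputs[:, 0])  # [0. 0. 1. 1. 1. 0. 0. 0.]
```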

abstractmethod encode_outputs(ctx)

Fill ctx.outputs[phase_start:phase_end] with target encoding.

Called once per phase after duration is determined. Must modify ctx.outputs in-place.

Parameters:

ctx (Context) – Context with output buffer and trial state.

Return type:

None
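Targets follow the same slice-by-slice pattern. In this hypothetical sketch, a fixation phase requests target class 0 and a response phase requests the trial's correct choice, read from a `trial` dict on the context; the class-index target convention, the `ToyContext` fields and the `trial["label"]` key are all assumptions made for illustration.

```python
from dataclasses import dataclass, field

import numpy as np


@dataclass
class ToyContext:
    # Hypothetical stand-in for Context; field names are assumptions
    outputs: np.ndarray  # shape (total_steps,), integer class targets
    phase_start: int = 0
    phase_end: int = 0
    trial: dict = field(default_factory=dict)  # per-trial state


class ToyFixationTarget:
    def encode_outputs(self, ctx: ToyContext) -> None:
        # Class 0 stands for "keep fixating" in this sketch
        ctx.outputs[ctx.phase_start:ctx.phase_end] = 0


class ToyResponseTarget:
    def encode_outputs(self, ctx: ToyContext) -> None:
        # The correct choice is assumed to be stored in trial state beforehand
        ctx.outputs[ctx.phase_start:ctx.phase_end] = ctx.trial["label"]


ctx = ToyContext(outputs=np.full(6, -1), trial={"label": 2})
for phase, (start, end) in [(ToyFixationTarget(), (0, 4)), (ToyResponseTarget(), (4, 6))]:
    ctx.phase_start, ctx.phase_end = start, end
    phase.encode_outputs(ctx)
print(ctx.outputs)  # [0 0 0 0 2 2]
```

Because each phase touches only `[phase_start:phase_end]`, phases can be encoded one after another into a single buffer without overwriting each other.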

get_duration(ctx)

Resolve duration to integer timesteps.

Parameters:

ctx (Context) – Context with dt and rng for duration sampling.

Returns:

Number of timesteps for this phase.

Return type:

int
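Resolving a duration amounts to dividing it by the time step and turning the result into an integer. The sketch below uses plain floats in milliseconds instead of `Quantity` values and a `SimpleNamespace` in place of `Context`. `duration_to_steps` and `ToyUniformPhase` are hypothetical, and rounding to the nearest step is an assumed conversion rule, chosen because plain truncation is fragile under floating-point error.

```python
from types import SimpleNamespace

import numpy as np


def duration_to_steps(duration_ms: float, dt_ms: float) -> int:
    # Round rather than truncate: 0.3 / 0.1 == 2.9999999999999996 in floating point
    return int(round(duration_ms / dt_ms))


class ToyUniformPhase:
    """Hypothetical sampled-duration phase: uniform on [low_ms, high_ms]."""

    def __init__(self, low_ms: float, high_ms: float):
        self.low_ms, self.high_ms = low_ms, high_ms

    def get_duration(self, ctx) -> int:
        sample = ctx.rng.uniform(self.low_ms, self.high_ms)  # uses ctx.rng
        return duration_to_steps(sample, ctx.dt)             # uses ctx.dt


print(int(0.3 / 0.1), duration_to_steps(0.3, 0.1))  # 2 3
ctx = SimpleNamespace(dt=1.0, rng=np.random.default_rng(0))
steps = ToyUniformPhase(100.0, 300.0).get_duration(ctx)
```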

max_steps(ctx)

Static upper bound on this phase’s length in timesteps.

Must return a Python int with no dependence on traced values. Used by Task in variable-length mode to size shape-stable buffers. The default delegates to get_duration, which is correct for fixed-duration phases. Variable-duration phases (e.g. those wrapping TruncExp/UniformDuration) override this to return the truncation upper bound divided by ctx.dt.

Parameters:

ctx (Context) – A stub or trial context providing ctx.dt. The default implementation does not read ctx.rng or trial state.

Returns:

Upper bound on number of timesteps for this phase.

Return type:

int
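For a variable-duration phase, the bound comes from the configuration rather than from a sample. The sketch below is a hypothetical truncated-exponential phase (`ToyTruncExpPhase` is not the braintools `TruncExp`, whose signature is not shown here): `get_duration` draws from an exponential distribution truncated to `[min_ms, max_ms]` by inverse-CDF sampling, while `max_steps` reads only `max_ms` and `ctx.dt`. Using `math.ceil` for the bound is an assumption; combined with round-to-nearest in `get_duration`, it guarantees that no sampled length exceeds the bound.

```python
import math
from types import SimpleNamespace

import numpy as np


class ToyTruncExpPhase:
    """Hypothetical variable-duration phase with a truncated exponential law."""

    def __init__(self, mean_ms: float, min_ms: float, max_ms: float):
        self.rate = 1.0 / mean_ms
        self.min_ms, self.max_ms = min_ms, max_ms

    def _sample_ms(self, rng: np.random.Generator) -> float:
        # Inverse CDF of an exponential restricted to [min_ms, max_ms]
        lo = math.exp(-self.rate * self.min_ms)
        hi = math.exp(-self.rate * self.max_ms)
        u = rng.random()
        x = -math.log(lo - u * (lo - hi)) / self.rate
        return min(max(x, self.min_ms), self.max_ms)  # guard against rounding error

    def get_duration(self, ctx) -> int:
        return int(round(self._sample_ms(ctx.rng) / ctx.dt))

    def max_steps(self, ctx) -> int:
        # Static: depends only on configuration and ctx.dt, never on ctx.rng
        return math.ceil(self.max_ms / ctx.dt)


ctx = SimpleNamespace(dt=1.0, rng=np.random.default_rng(0))
phase = ToyTruncExpPhase(mean_ms=500.0, min_ms=200.0, max_ms=1500.0)
bound = phase.max_steps(ctx)
samples = [phase.get_duration(ctx) for _ in range(1000)]
print(bound, min(samples) >= 200, max(samples) <= bound)  # 1500 True True
```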

on_enter(ctx)

Hook called when phase begins. Override for setup logic.

Return type:

None

on_exit(ctx)

Hook called when phase ends. Override for cleanup/state updates.

Return type:

None
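One way a runner could drive these hooks is sketched below. The call order (duration, then `on_enter`, encoding, `on_exit`), the `run_phases` driver, the `ToyStimulus` phase and the `trial` dict are assumptions for illustration; the reference above only says that `on_enter` runs when a phase begins, `on_exit` when it ends, and that encoding happens after the duration is determined.

```python
from types import SimpleNamespace

import numpy as np


class ToyStimulus:
    """Hypothetical phase that uses both hooks."""

    name = "Stimulus"

    def __init__(self, steps: int, n_directions: int = 8):
        self.steps, self.n_directions = steps, n_directions

    def get_duration(self, ctx) -> int:
        return self.steps

    def on_enter(self, ctx) -> None:
        # Setup: draw this trial's stimulus direction once, before encoding
        ctx.trial["direction"] = int(ctx.rng.integers(self.n_directions))

    def encode_inputs(self, ctx) -> None:
        # Channel 0 is left free (e.g. for fixation); directions use 1..n
        ctx.inputs[ctx.phase_start:ctx.phase_end, 1 + ctx.trial["direction"]] = 1.0

    def encode_outputs(self, ctx) -> None:
        pass  # no target during the stimulus in this sketch

    def on_exit(self, ctx) -> None:
        # State update: remember where this phase sat in the trial
        ctx.trial.setdefault("epochs", {})[self.name] = (ctx.phase_start, ctx.phase_end)


def run_phases(phases, ctx) -> None:
    t = 0
    for phase in phases:
        ctx.phase_start = t
        ctx.phase_end = t + phase.get_duration(ctx)
        phase.on_enter(ctx)
        phase.encode_inputs(ctx)
        phase.encode_outputs(ctx)
        phase.on_exit(ctx)
        t = ctx.phase_end


ctx = SimpleNamespace(inputs=np.zeros((10, 9)), rng=np.random.default_rng(0), trial={})
run_phases([ToyStimulus(steps=4)], ctx)
print(ctx.trial["epochs"])  # {'Stimulus': (0, 4)}
```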

step_count(ctx)

Traced actual length of this phase in timesteps.

Returns a jax.Array int32 scalar. May depend on ctx[...] values populated by trial_init. Must satisfy 0 <= step_count(ctx) <= max_steps(ctx) for every trial.

The default returns a static value equal to get_duration; that is correct for any phase whose actual length matches its upper bound. Variable-duration phases override this to compute the traced length from ctx state without any int(...) cast.

Return type:

Array
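The split between a static max_steps and a traced step_count is what lets variable-length trials share one array shape: buffers are sized by the bound, and the actual length only selects which entries count. The sketch uses NumPy so it runs anywhere; under `jax.jit` the same `jnp.arange(max_steps) < step_count` comparison works with a traced `step_count` because `max_steps` is a plain Python int. The `valid_mask` helper is hypothetical, and its explicit range check is a NumPy-side convenience that would not work on traced values.

```python
import numpy as np


def valid_mask(step_counts: np.ndarray, max_steps: int) -> np.ndarray:
    """Hypothetical helper: boolean mask of shape (n_trials, max_steps).

    Entry [i, t] is True when t < step_counts[i], i.e. inside trial i's
    actual phase, and False in the padding up to the static bound.
    """
    if not np.all((0 <= step_counts) & (step_counts <= max_steps)):
        raise ValueError("need 0 <= step_count <= max_steps for every trial")
    return np.arange(max_steps)[None, :] < step_counts[:, None]


max_steps = 5                      # static Python int: fixes the buffer shape
step_counts = np.array([3, 5, 2])  # per-trial lengths: vary across trials
mask = valid_mask(step_counts, max_steps)
print(mask.astype(int))
# [[1 1 1 0 0]
#  [1 1 1 1 1]
#  [1 1 0 0 0]]
```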