Task

class braintools.cogtask.Task(phases=None, input_features=None, output_features=None, trial_init=None, name=None, output_mode='categorical', seed=None, **kwargs)

A cognitive task composed of phases.

The Task class orchestrates phase execution and provides a dataset interface for integration with DataLoaders. Supports both instance-based and class-based definition patterns.

Class-Based Definition

Subclasses can define tasks by overriding class attributes and methods:

  • Class attributes: t_fixation, t_sample, t_delay, num_stimuli, etc.

  • define_features(): Return (input_features, output_features)

  • define_phases(): Return the phase structure

  • trial_init(ctx): Initialize trial-level state
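Class attributes act as per-class defaults that the constructor's **kwargs can override per instance. The plain-Python toy below illustrates that behavior only. ToyTask is hypothetical and is not how braintools implements Task, and rejecting unknown attribute names is an extra assumption of this sketch.

```python
class ToyTask:
    """Hypothetical toy (not braintools code): class-level defaults plus kwargs overrides."""

    t_fixation = 300  # ms; class-level default
    num_stimuli = 8

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            # Assumption of this sketch: only declared attributes may be overridden,
            # so a typo such as num_stimulus=16 fails loudly.
            if not hasattr(type(self), name):
                raise TypeError(f"unknown task attribute: {name!r}")
            setattr(self, name, value)  # instance attribute shadows the class default


task = ToyTask(num_stimuli=16)
print(task.num_stimuli, task.t_fixation)  # 16 300
print(ToyTask.num_stimuli)                # 8 (class default unchanged)
```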

Examples

Instance-based (traditional):

>>> task = Task(
...     phases=(
...         Fixation(100 * u.ms)
...         >> Stimulus(2000 * u.ms, feature=stim, encoder=circular_encoder())
...         >> Response(100 * u.ms, output_feature=choice)
...     ),
...     input_features=fix + stim,
...     output_features=fix + choice,
...     trial_init=lambda ctx: ctx.update(
...         ground_truth=ctx.rng.choice(2),
...         direction=ctx.rng.uniform(0, 2*np.pi)
...     )
... )

Class-based (new):

>>> class MyTask(Task):
...     t_fixation = 300 * u.ms
...     num_stimuli = 8
...
...     def define_features(self):
...         fix = Feature(1, 40*u.Hz, 'fixation')
...         stim = Feature(self.num_stimuli, 40*u.Hz, 'stimulus')
...         return fix + stim, fix + Feature(2, 40*u.Hz, 'response')
...
...     def define_phases(self):
...         return FixationPhase(self.t_fixation) >> ResponsePhase()
...
...     def trial_init(self, ctx):
...         ctx['ground_truth'] = ctx.rng.choice(2)
...
>>> task = MyTask(num_stimuli=16, seed=42)

Parameters:
  • phases (Phase, optional) – The phase structure. If None, uses the define_phases() method.

  • input_features (Feature or FeatureSet, optional) – Input feature definitions. If None, uses the define_features() method.

  • output_features (Feature or FeatureSet, optional) – Output feature definitions. If None, uses the define_features() method.

  • trial_init (Callable[[Context], None] | None) – Function called at the start of each trial. If both trial_init and phases are None, the trial_init() method is used.

  • name (str | None) – Task name. Defaults to the class name.

  • seed (optional) – Base seed for per-trial randomness. When given, trial index i uses the key jax.random.fold_in(PRNGKey(seed), i); see sample_trial().

  • **kwargs – Override class attributes (e.g., t_fixation=500*u.ms, num_stimuli=16).

batch_sample(size, /, time_first=True, return_meta=False, start_index=0, return_mask=False)

Sample a batch of trials with indices start_index, start_index + 1, ..., start_index + size - 1 (size trials in total).

When the task was constructed with seed=..., trial i of the batch uses the key jax.random.fold_in(PRNGKey(seed), start_index + i). Calling batch_sample with the same start_index and size is therefore reproducible, and calls whose index ranges do not overlap (for example, advancing start_index by size after each call) produce disjoint sets of trials.
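The seeding rule can be checked directly with JAX. batch_keys is a hypothetical helper, not part of braintools; it only reproduces the stated derivation fold_in(PRNGKey(seed), start_index + i) so that reproducibility and overlap can be inspected.

```python
import jax
import jax.numpy as jnp


def batch_keys(seed: int, start_index: int, size: int) -> jax.Array:
    """Hypothetical helper: one PRNG key per trial index in the batch.

    Trial i of the batch gets fold_in(PRNGKey(seed), start_index + i).
    """
    base = jax.random.PRNGKey(seed)
    return jnp.stack(
        [jax.random.fold_in(base, start_index + i) for i in range(size)]
    )


first = batch_keys(0, start_index=0, size=4)    # trials 0..3
again = batch_keys(0, start_index=0, size=4)    # same range -> same keys
shifted = batch_keys(0, start_index=2, size=4)  # trials 2..5: overlaps first
nxt = batch_keys(0, start_index=4, size=4)      # trials 4..7: disjoint from first

print(bool(jnp.array_equal(first, again)))            # True
print(bool(jnp.array_equal(first[2:], shifted[:2])))  # True: trials 2 and 3 repeat
```

Advancing start_index by size, as with nxt, keeps batches disjoint; any smaller step repeats trials.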

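Parameters:
  • size (int) – Number of trials in the batch.

  • start_index (int) – Index of the first trial in the batch.

  • return_mask (bool) – If True, also return a boolean mask of valid timesteps, shaped (T, B) when time_first is True and (B, T) otherwise. Variable-length tasks need the mask to tell which trailing positions are padding; for fixed-length tasks it is all True.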
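The mask lets a training loop ignore padded timesteps. The NumPy sketch below assumes time-first arrays of shape (T, B, F) and a (T, B) mask; masked_mse and the arrays are hypothetical stand-ins for a task's real outputs.

```python
import numpy as np


def masked_mse(pred, target, mask):
    """Mean squared error over valid timesteps only.

    pred, target: (T, B, F) arrays; mask: (T, B) boolean, True = real timestep.
    """
    err = ((pred - target) ** 2).mean(axis=-1)  # (T, B) per-step error
    m = mask.astype(err.dtype)
    # Guard against an all-False mask to avoid division by zero
    return float((err * m).sum() / np.maximum(m.sum(), 1.0))


pred = np.ones((3, 2, 1))
pred[2, 1, 0] = 100.0  # garbage value in a padded position
target = np.zeros((3, 2, 1))
mask = np.array([[True, True], [True, True], [True, False]])

print(masked_mse(pred, target, mask))          # 1.0
print(float(((pred - target) ** 2).mean()))    # 1667.5 (padding leaks in)
```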

define_features()

Define input and output features.

Override in subclass for class-based task definition.

Return type:

Tuple[Any, Any]

Returns:

  • input_features (Feature or FeatureSet) – Input feature definitions.

  • output_features (Feature or FeatureSet) – Output feature definitions.

define_phases()

Define the phase structure.

Override in subclass for class-based task definition.

Returns:

The task phase structure (single phase or composition).

Return type:

Phase

property is_variable_length: bool

True if any phase in the tree has is_variable = True.

Variable-length tasks allocate trial buffers of length max_trial_duration() and, when batch_sample() is called with return_mask=True, return a per-timestep mask alongside X and Y.
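The buffer-plus-mask layout can be sketched in NumPy: each trial is written into a buffer of the maximum length and the mask marks which steps are real. pad_trials is hypothetical; zero padding and the time-first (T, B, F) layout are assumptions of this sketch, not documented library behavior.

```python
import numpy as np


def pad_trials(trials, max_steps):
    """Hypothetical helper: stack (t_i, F) trials into (T, B, F) plus a (T, B) mask."""
    n_feat = trials[0].shape[1]
    buf = np.zeros((max_steps, len(trials), n_feat), dtype=trials[0].dtype)
    mask = np.zeros((max_steps, len(trials)), dtype=bool)
    for b, x in enumerate(trials):
        t = x.shape[0]
        if t > max_steps:
            raise ValueError(f"trial {b} has {t} steps > max_steps={max_steps}")
        buf[:t, b] = x      # copy the real timesteps
        mask[:t, b] = True  # trailing steps stay zero-padded and masked out
    return buf, mask


shorter = np.ones((3, 2))  # 3 steps, 2 features
longer = np.ones((5, 2))   # 5 steps, 2 features
X, mask = pad_trials([shorter, longer], max_steps=5)
print(X.shape)           # (5, 2, 2)
print(mask.sum(axis=0))  # [3 5]
```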

max_trial_duration(ctx=None)

Static upper bound on the trial’s timestep count.

For fixed-length tasks this equals the sum of the per-phase get_duration outputs. For variable-length tasks, each phase contributes its max_steps (e.g. ceil(max_duration / dt) for VariableDuration). The result is a Python int, so it is safe to use as a static buffer dimension under JIT/vmap.

Return type:

int
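The bound can be worked by hand with a toy model. FixedPhase, VariablePhase, and both functions are hypothetical stand-ins (a flat list instead of the library's phase tree, durations as integer milliseconds); only the ceil(max_duration / dt) rule comes from the description above.

```python
import math
from dataclasses import dataclass


@dataclass
class FixedPhase:  # hypothetical stand-in for a fixed-duration phase
    duration_ms: int
    is_variable = False  # plain class attribute, not a dataclass field

    def steps(self, dt_ms: int) -> int:
        # Toy assumption: fixed durations are exact multiples of dt
        if self.duration_ms % dt_ms:
            raise ValueError("duration must be a multiple of dt in this toy")
        return self.duration_ms // dt_ms


@dataclass
class VariablePhase:  # hypothetical stand-in for a VariableDuration phase
    max_duration_ms: int
    is_variable = True

    def steps(self, dt_ms: int) -> int:
        # Upper bound on this phase's length: ceil(max_duration / dt)
        return math.ceil(self.max_duration_ms / dt_ms)


def is_variable_length(phases) -> bool:
    return any(p.is_variable for p in phases)


def max_trial_duration(phases, dt_ms: int) -> int:
    return sum(p.steps(dt_ms) for p in phases)


phases = [FixedPhase(100), VariablePhase(310), FixedPhase(100)]
print(is_variable_length(phases))      # True
print(max_trial_duration(phases, 20))  # 26 (5 + 16 + 5)
```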

sample_trial(index=0, key=None)

Generate one trial.

Parameters:
  • index (int) – Trial index, made available to trial_init as ctx['trial_index']. If the task was constructed with seed=..., the per-trial RNG key is jax.random.fold_in(PRNGKey(seed), index), so reproducibility is keyed on (seed, index).

  • key (Array | None) – Explicit PRNG key. Overrides the (seed, index) derivation when given.

trial_init(ctx)

Initialize trial-level state.

Override in subclass to set up trial parameters like ground_truth, stimulus indices, etc.

Parameters:

ctx (Context) – Trial context to populate with state.

Return type:

None
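A trial_init function can be unit-tested without building a full Task by passing a small stand-in for Context. StubContext is hypothetical: it assumes, as the examples above suggest, that the context is dict-like and that ctx.rng offers NumPy-Generator-style choice and uniform methods. The real Context may differ.

```python
import numpy as np


class StubContext(dict):
    """Hypothetical Context stand-in: a dict plus a NumPy Generator as ``rng``."""

    def __init__(self, seed):
        super().__init__()
        self.rng = np.random.default_rng(seed)


def trial_init(ctx):
    # Same trial-level state as the instance-based example above
    ctx['ground_truth'] = ctx.rng.choice(2)
    ctx['direction'] = ctx.rng.uniform(0, 2 * np.pi)


ctx_a, ctx_b = StubContext(seed=7), StubContext(seed=7)
trial_init(ctx_a)
trial_init(ctx_b)
assert ctx_a['ground_truth'] in (0, 1)
assert 0.0 <= ctx_a['direction'] < 2 * np.pi
assert ctx_a == ctx_b  # same seed -> same trial state
```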