Task
- class braintools.cogtask.Task(phases=None, input_features=None, output_features=None, trial_init=None, name=None, output_mode='categorical', seed=None, **kwargs)
A cognitive task composed of phases.
The Task class orchestrates phase execution and provides a dataset interface for integration with DataLoaders. Supports both instance-based and class-based definition patterns.
Class-Based Definition
Subclasses can define tasks by overriding class attributes and methods:
Class attributes: t_fixation, t_sample, t_delay, num_stimuli, etc.
define_features(): Return (input_features, output_features)
define_phases(): Return the phase structure
trial_init(ctx): Initialize trial-level state
Examples
Instance-based (traditional):
>>> task = Task(
...     phases=(
...         Fixation(100 * u.ms)
...         >> Stimulus(2000 * u.ms, feature=stim, encoder=circular_encoder())
...         >> Response(100 * u.ms, output_feature=choice)
...     ),
...     input_features=fix + stim,
...     output_features=fix + choice,
...     trial_init=lambda ctx: ctx.update(
...         ground_truth=ctx.rng.choice(2),
...         direction=ctx.rng.uniform(0, 2*np.pi)
...     )
... )
Class-based (new):
>>> class MyTask(Task):
...     t_fixation = 300 * u.ms
...     num_stimuli = 8
...
...     def define_features(self):
...         fix = Feature(1, 40*u.Hz, 'fixation')
...         stim = Feature(self.num_stimuli, 40*u.Hz, 'stimulus')
...         return fix + stim, fix + Feature(2, 40*u.Hz, 'response')
...
...     def define_phases(self):
...         return FixationPhase(self.t_fixation) >> ResponsePhase()
...
...     def trial_init(self, ctx):
...         ctx['ground_truth'] = ctx.rng.choice(2)
...
>>> task = MyTask(num_stimuli=16, seed=42)
- param phases:
The phase structure. If None, the define_phases() method is used.
- param input_features:
Input feature definitions. If None, the define_features() method is used.
- type input_features:
Feature or FeatureSet, optional
- param output_features:
Output feature definitions. If None, the define_features() method is used.
- type output_features:
Feature or FeatureSet, optional
- param trial_init:
Function called at the start of each trial. If None and phases is None, the trial_init() method is used.
- param name:
Task name (defaults to the class name).
- param **kwargs:
Overrides for class attributes (e.g., t_fixation=500*u.ms, num_stimuli=16).
- batch_sample(size, /, time_first=True, return_meta=False, start_index=0, return_mask=False)
Sample a batch of size trials with indices start_index..start_index+size-1.
When the task was constructed with seed=..., trial i of the batch uses jax.random.fold_in(PRNGKey(seed), start_index + i), so calling batch_sample with the same start_index is reproducible, and successive calls with different start_index produce non-overlapping batches as long as their index ranges do not overlap.
- Parameters:
return_mask (bool) – If True, also return a (T, B) (or (B, T)) boolean mask of valid timesteps. Required for variable-length tasks if you want to know which trailing positions are padding. The mask is always True for fixed-length tasks.
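The per-trial key derivation described above can be sketched with plain JAX calls. This is a minimal illustration, not the library's implementation: `batch_trial_keys` is a hypothetical helper name, and it assumes the task derives keys exactly as the formula `fold_in(PRNGKey(seed), start_index + i)` states.

```python
import jax
import jax.numpy as jnp


def batch_trial_keys(seed, size, start_index=0):
    """Hypothetical helper: one PRNG key per trial index in the batch."""
    base = jax.random.PRNGKey(seed)
    # fold_in is deterministic in (base key, trial index)
    keys = [jax.random.fold_in(base, start_index + i) for i in range(size)]
    return jnp.stack(keys)


first = batch_trial_keys(seed=42, size=4, start_index=0)
again = batch_trial_keys(seed=42, size=4, start_index=0)   # same indices 0..3
later = batch_trial_keys(seed=42, size=4, start_index=4)   # indices 4..7

same_batch = bool((first == again).all())      # identical keys for identical indices
different_batch = not bool((first == later).all())
```

Because each key depends only on the seed and the trial index, advancing start_index by size between calls walks through disjoint index ranges.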
- define_features()
Define input and output features.
Override in subclass for class-based task definition.
- define_phases()
Define the phase structure.
Override in subclass for class-based task definition.
- Returns:
The task phase structure (single phase or composition).
- Return type:
- property is_variable_length: bool
True if any phase in the tree has is_variable = True.
Variable-length tasks allocate trial buffers of size max_trial_duration and return a per-timestep mask alongside X/Y from batch_sample() (use return_mask=True).
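One common use of a per-timestep mask is to exclude padded positions from a loss. The sketch below uses NumPy and synthetic arrays rather than real batch_sample() output; the (T, B, F) layout for predictions and targets and the `masked_mse` helper are assumptions made for illustration only.

```python
import numpy as np


def masked_mse(pred, target, mask):
    """Mean squared error over valid timesteps only.

    pred, target: (T, B, F) arrays (assumed layout);
    mask: (T, B) boolean array, True where the timestep is valid.
    """
    sq_err = ((pred - target) ** 2).mean(axis=-1)        # (T, B)
    valid = mask.astype(sq_err.dtype)
    # Guard against an all-padding batch to avoid division by zero
    return (sq_err * valid).sum() / np.maximum(valid.sum(), 1.0)


T, B, F = 5, 2, 3
pred = np.zeros((T, B, F))
target = np.ones((T, B, F))
mask = np.ones((T, B), dtype=bool)
mask[3:, 1] = False        # second trial is padding after 3 steps
target[3:, 1] = 100.0      # values in the padded region must not affect the loss
loss = masked_mse(pred, target, mask)
```

Here the padded positions hold arbitrary values, yet the loss only reflects the eight valid timesteps.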
- max_trial_duration(ctx=None)
Static upper bound on the trial’s timestep count.
For fixed tasks this equals the sum of per-phase get_duration outputs. For variable-length tasks each phase contributes its max_steps (e.g. ceil(max_duration / dt) for VariableDuration). The result is a Python int and is safe to use as a static buffer dimension under JIT/vmap.
- Return type:
int
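The arithmetic behind such a bound can be illustrated in plain Python. This is only a sketch of the ceil(max_duration / dt) rule quoted above: the helper names and the use of bare millisecond floats (instead of the library's unit-carrying quantities) are assumptions.

```python
import math


def max_steps(max_duration_ms, dt_ms):
    """Upper bound on the step count of one phase: ceil(max_duration / dt)."""
    return math.ceil(max_duration_ms / dt_ms)


def static_trial_bound(phase_max_durations_ms, dt_ms):
    """Sum the per-phase bounds into a single Python int."""
    return sum(max_steps(d, dt_ms) for d in phase_max_durations_ms)


# Three hypothetical phases with maximum durations 100 ms, 2000 ms and 350 ms
bound = static_trial_bound([100.0, 2000.0, 350.0], dt_ms=20.0)
# 5 + 100 + 18 = 123 steps
```

Since the result is an ordinary int computed before tracing, it can serve as a fixed array dimension.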
- sample_trial(index=0, key=None)
Generate one trial.
- Parameters:
index (int) – Trial index, made available to trial_init as ctx['trial_index']. If the task was constructed with seed=..., the per-trial RNG key is jax.random.fold_in(PRNGKey(seed), index), so reproducibility is keyed on (seed, index).
key (Array | None) – Explicit PRNG key. Overrides the (seed, index) derivation when given.
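The precedence between an explicit key and the (seed, index) derivation might look like the following. `resolve_trial_key` is a hypothetical helper written for this illustration, not part of the library, and the sketch does not cover the case where neither a seed nor a key is provided.

```python
import jax


def resolve_trial_key(seed, index, key=None):
    """Hypothetical helper: an explicit key wins, else derive from (seed, index)."""
    if key is not None:
        return key
    return jax.random.fold_in(jax.random.PRNGKey(seed), index)


k_a = resolve_trial_key(seed=0, index=7)
k_b = resolve_trial_key(seed=0, index=7)            # same (seed, index)
override = jax.random.PRNGKey(123)
k_c = resolve_trial_key(seed=0, index=7, key=override)

reproducible = bool((k_a == k_b).all())
overridden = bool((k_c == override).all())
```

Passing a key is useful when the caller already manages its own PRNG stream and wants the trial to draw from it instead.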