braintools.cogtask module

Composable Cognitive Task Framework

A modular, composable framework for constructing cognitive tasks for neural network training and neuroscience simulations.

Quick Start

Using pre-built tasks:

>>> from braintools.cogtask import PerceptualDecisionMaking
>>> task = PerceptualDecisionMaking(t_stimulus=2000)
>>> X, Y = task.batch_sample(32)
>>> # train_step(X, Y)

Building custom tasks from phases:

>>> from braintools.cogtask import (
...     Task, Context, concat,
...     Fixation, Stimulus, Delay, Response,
...     Feature, circular, one_hot
... )
>>> import brainunit as u
>>>
>>> # Define features
>>> fix = Feature(1, 'fixation')
>>> stim = Feature(8, 'stimulus')
>>> choice = Feature(2, 'choice')
>>>
>>> # Build task
>>> task = Task(
...     phases=concat([
...         Fixation(100 * u.ms, inputs={'fixation': 1.0}),
...         Stimulus(500 * u.ms, inputs={'stimulus': circular('direction')}),
...         Delay(500 * u.ms, inputs={'fixation': 1.0}),
...         Response(100 * u.ms, outputs={'label': 'ground_truth'})
...     ]),
...     input_features=fix + stim,
...     output_features=fix + choice,
...     trial_init=lambda ctx: ctx.update(
...         ground_truth=ctx.rng.choice(2),
...         direction=ctx.rng.uniform(0, 2*3.14159)
...     )
... )

API Summary

Core:
  • Task, TaskConfig: Main task class and configuration

  • Context: Inter-phase state container

  • Phase, Sequence, Repeat, Parallel: Phase composition

  • If, Switch, While: Conditional phases

  • concat: Helper for sequential composition

Phases:
  • Fixation, Delay, Stimulus, Response, Cue: Basic phases

  • Sample, Test, Recall, Match, Comparison, Blank: Memory phases

  • DeclarativePhase: Base class for creating custom phases

Features:
  • Feature, FeatureSet, CircleFeature: Input/output encoding

Encoders:
  • circular, one_hot, von_mises, cos_sin, scalar, gaussian, identity, ctx_value

Labels:
  • label, match_label, comparison_label: Output label helpers

Pre-built Tasks:
  • Decision Making: PerceptualDecisionMaking, ContextDecisionMaking, etc.

  • Working Memory: DelayMatchSample, GoNoGo, etc.

  • Reasoning: HierarchicalReasoning, ProbabilisticReasoning

  • Motor: AntiReach, Reaching1D, EvidenceAccumulation


Overview

The braintools.cogtask module provides:

  • A phase-based task model that decomposes trials into fixation, stimulus, delay, response, and other epochs with explicit duration, input encoding, and output (target) encoding

  • Composition operators (>>, *, |) and compound phases (Sequence, Repeat, Parallel) for building rich trial structures from simple parts

  • Conditional control flow with If, Switch, and While for trial-by-trial branching and variable-iteration tasks

  • A feature-encoding system that maps trial state into input/output channels via Feature/FeatureSet and value-spec encoders (one_hot, circular, von_mises, gaussian, cos_sin, …)

  • A library of pre-built tasks spanning decision making, working memory, reasoning, and motor control, drawn from systems-neuroscience literature

  • JIT- and vmap-friendly trial generation through Task.sample() and Task.batch_sample(), designed to integrate cleanly with brainstate and JAX training loops

Core Task Framework

The Task class orchestrates phase execution, owns the per-trial random key, and exposes the dataset-style sample/batch_sample API. Context is the mutable trial-level state container shared across phases; it carries the RNG, input/output buffers, timing information, and trial-level user data.

Task

A cognitive task composed of phases.

Context

Mutable trial-level state container shared across phases.

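Two equivalent ways to define a task are supported: pass the phases, features, and a trial_init callable directly to the Task constructor (as in the Quick Start example), or subclass Task, define the features and phase structure in the subclass, and override Task.trial_init() (the approach used by the pre-built tasks).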

Phases and Composition

Phases are the atomic units of a trial. Phase is the abstract base class; concrete declarative phases (Fixation, Stimulus, Delay, Response, …) inherit from DeclarativePhase and describe their inputs/outputs/noise via dictionaries instead of code.

Phases compose with operators:

  • a >> b — sequential composition (yields Sequence)

  • a * n — repeat n times (yields Repeat)

  • a | b — parallel composition (yields Parallel)

  • concat() — sequence from a list
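The operator-to-node mapping above can be sketched with a few hypothetical toy classes (ToyPhase, ToyEpoch, ToySequence, ToyRepeat, ToyParallel, toy_concat); they are illustrations, not the braintools.cogtask classes. The sketch assumes a parallel node lasts as long as its longest branch, which the library may define differently.

```python
# Hypothetical toy phase tree: how >>, * and | could map onto compound nodes.
# None of these classes are the braintools.cogtask implementations.
from dataclasses import dataclass
from typing import Tuple


class ToyPhase:
    def __rshift__(self, other):
        # a >> b: flatten so that a >> b >> c yields one sequence of three children
        left = self.children if isinstance(self, ToySequence) else (self,)
        return ToySequence(left + (other,))

    def __mul__(self, n):
        # a * n: repeat the same phase n times
        return ToyRepeat(self, n)

    def __or__(self, other):
        # a | b: run both phases over the same time window
        return ToyParallel((self, other))


@dataclass(frozen=True)
class ToyEpoch(ToyPhase):
    name: str
    steps: int

    def duration(self):
        return self.steps


@dataclass(frozen=True)
class ToySequence(ToyPhase):
    children: Tuple[ToyPhase, ...]

    def duration(self):
        return sum(c.duration() for c in self.children)


@dataclass(frozen=True)
class ToyRepeat(ToyPhase):
    body: ToyPhase
    n: int

    def duration(self):
        return self.n * self.body.duration()


@dataclass(frozen=True)
class ToyParallel(ToyPhase):
    branches: Tuple[ToyPhase, ...]

    def duration(self):
        # Assumption: simultaneous branches last as long as the longest one
        return max(b.duration() for b in self.branches)


def toy_concat(phases):
    # Sequence built from a list, analogous to concat()
    return ToySequence(tuple(phases))


fix = ToyEpoch('fixation', 10)
stim = ToyEpoch('stimulus', 50)
delay = ToyEpoch('delay', 30)
trial = fix >> stim >> delay
```

Because every operator returns a new node, a trial structure is an immutable tree whose total duration can be computed by walking it.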

Base Class

Phase

Base class for task phases (epochs/periods).

DeclarativePhase

Declarative phase definition with explicit input/output specifications.

Compound Phases

Sequence

Sequential composition of phases.

Repeat

Repeat a phase N times.

Parallel

Parallel composition: phases execute simultaneously.

Declarative Phase Types

These are convenience subclasses of DeclarativePhase that share its interface but provide semantic names so trial structures read naturally. They differ only in identity — use whichever name best describes the epoch.

Basic epochs: Fixation, Delay, Stimulus, Response, Cue.

Working-memory epochs: Sample, Test, Recall, Match, Comparison, Blank.

Variable-length epochs:

VariableDuration

Declarative phase whose actual length is decided per trial.

Composition Helpers

concat

Concatenate phases into a sequence.

execute_phase

Execute a single phase, updating context appropriately.

execute_phase_packed

Variable-length / packed-mode dispatch for a single phase.

phase_tree_is_variable

Walk a phase tree and return True if any node declares is_variable = True.

Conditional Phases

Phases can branch on trial state at runtime. If selects between then / else_; Switch dispatches over many cases; While loops until a condition fails (bounded by max_iterations).

If

Conditional phase selection based on a boolean condition.

Switch

Multi-way conditional phase selection.

While

Loop phase while condition is true.

Because these phases inspect trial state during a Python-level pass over the tree, the branch they take must be derivable from values set in trial_init (or in earlier phases’ on_exit hooks). Their total duration, summed across iterations or branches, contributes to the per-trial buffer size — see Tutorial 3: Variable-length trial sequences for the implications.

Features

A Feature declares one logical input or output channel of the task, with a fixed dimensionality and a name. FeatureSet collects features into a flat vector and tracks per-feature index slices automatically. CircleFeature adds a value range for angular / directional outputs.

Compose features with + (concatenate, immutable), - (remove by name), | (alias for +), and *n (named repetition).
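A minimal sketch of the offset bookkeeping behind +, - and |, using hypothetical ToyFeature and ToyFeatureSet classes (not the library's Feature/FeatureSet). Each composition returns fresh copies with re-packed offsets, mirroring the immutability described above; named repetition (*n) is left out.

```python
# Toy model of Feature / FeatureSet offset bookkeeping.
# ToyFeature and ToyFeatureSet are hypothetical stand-ins, not braintools classes.
from dataclasses import dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class ToyFeature:
    num: int
    name: str
    start: int = 0

    @property
    def i(self) -> slice:
        # Slice of this feature inside the flat input/output vector
        return slice(self.start, self.start + self.num)

    def __add__(self, other):
        return ToyFeatureSet((self,)) + other

    __or__ = __add__  # '|' is an alias for '+'


@dataclass(frozen=True)
class ToyFeatureSet:
    features: Tuple[ToyFeature, ...]

    def __add__(self, other):
        extra = other.features if isinstance(other, ToyFeatureSet) else (other,)
        # Re-pack every feature as a fresh copy with a shifted offset;
        # the operands themselves are never mutated.
        packed, offset = [], 0
        for feat in self.features + extra:
            packed.append(replace(feat, start=offset))
            offset += feat.num
        return ToyFeatureSet(tuple(packed))

    __or__ = __add__

    def __sub__(self, name: str):
        # Remove a feature by name, then re-pack the remaining offsets
        kept = tuple(f for f in self.features if f.name != name)
        return ToyFeatureSet(()) + ToyFeatureSet(kept)

    @property
    def num(self) -> int:
        return sum(f.num for f in self.features)

    def __getitem__(self, name: str) -> ToyFeature:
        return next(f for f in self.features if f.name == name)


fix = ToyFeature(1, 'fixation')
stim = ToyFeature(8, 'stimulus')
inputs = fix + stim
```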

Feature

Individual feature encoder for cognitive task inputs/outputs.

FeatureSet

Collection of features with automatic index management.

CircleFeature

Circular feature for angular/directional data.

Feature predicates:

is_feature

Return True iff x is a Feature instance.

as_feature

Return x if it is a Feature; otherwise raise TypeError.

Encoders

Encoders are value specifications — callables of the form f(ctx, feature) -> jnp.ndarray that DeclarativePhase uses to fill its input slice for one feature. They translate trial-level state (e.g. a direction angle, a discrete class index) into per-timestep input activations.

Discrete / class encoders:

one_hot

Create a one-hot encoding value specification.

identity

Create an identity encoding that passes through values directly.
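The value-spec pattern can be sketched as small factories that return a callable f(ctx, feature). toy_one_hot and toy_identity below are hypothetical stand-ins for one_hot and identity, a plain dict plays the role of Context, and a SimpleNamespace plays the role of a feature; the real encoders' arguments may differ.

```python
# Hypothetical value-spec factories; not the braintools.cogtask encoders.
from types import SimpleNamespace

import numpy as np


def toy_one_hot(key):
    """Spec that one-hot encodes the integer stored at ctx[key]."""
    def spec(ctx, feature):
        vec = np.zeros(feature.num)
        vec[int(ctx[key])] = 1.0
        return vec
    return spec


def toy_identity(key):
    """Spec that passes ctx[key] through, reshaped to (feature.num,)."""
    def spec(ctx, feature):
        return np.asarray(ctx[key], dtype=float).reshape(feature.num)
    return spec


ctx = {'choice': 2, 'raw': [0.1, 0.2, 0.3, 0.4]}     # stand-in for Context
feature = SimpleNamespace(num=4, name='stimulus')     # stand-in for Feature
```

The phase evaluates the spec once per trial, so the same spec object can be reused across phases and tasks.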

Directional / population encoders:

circular

Cosine-tuned directional encoder.

von_mises

Von Mises (circular-normal) directional encoder.

cos_sin

Encode a discrete direction index into repeated [cos(theta), sin(theta)] features.
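The tuning shapes behind these encoders can be sketched with NumPy. The exact formulas below (a cosine curve rescaled to [0, 1], a peak-normalized von Mises bump, evenly spaced preferred angles) are common textbook choices assumed for illustration, not necessarily what circular, von_mises, or cos_sin compute.

```python
import numpy as np


def preferred_angles(n):
    # n units with evenly spaced preferred directions on [0, 2*pi)
    return np.arange(n) * 2 * np.pi / n


def cosine_tuning(theta, n):
    # Assumed form: (1 + cos(theta - theta_i)) / 2, so responses lie in [0, 1]
    return 0.5 * (1.0 + np.cos(theta - preferred_angles(n)))


def von_mises_tuning(theta, n, kappa=2.0):
    # Assumed form: exp(kappa * (cos(theta - theta_i) - 1)); peak value is 1
    return np.exp(kappa * (np.cos(theta - preferred_angles(n)) - 1.0))


def cos_sin_code(index, n_dirs, repeats=1):
    # Discrete direction index -> [cos(theta), sin(theta)] repeated `repeats` times
    theta = 2 * np.pi * index / n_dirs
    return np.tile([np.cos(theta), np.sin(theta)], repeats)
```

Larger kappa narrows the von Mises bump, while the cosine curve always spans the whole circle; cos_sin trades population coding for a compact two-dimensional code.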

Scalar / shape encoders:

scalar

Create a scalar encoding that broadcasts to all feature dimensions.

gaussian

Create a Gaussian bump encoding value specification.

ctx_value

Create a dynamic value that reads directly from context.
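Following the same pattern, hypothetical toy_gaussian, toy_scalar and toy_ctx_value factories show the shapes these specs return. The bump width parameter and the use of unit-spaced positions are assumptions, and a dict stands in for Context.

```python
from types import SimpleNamespace

import numpy as np


def toy_gaussian(key, sigma=1.0):
    """Gaussian bump over positions 0..num-1, centred at ctx[key]."""
    def spec(ctx, feature):
        x = np.arange(feature.num)
        return np.exp(-0.5 * ((x - ctx[key]) / sigma) ** 2)
    return spec


def toy_scalar(value):
    """Same constant for every unit of the feature."""
    def spec(ctx, feature):
        return np.full(feature.num, float(value))
    return spec


def toy_ctx_value(key):
    """Read a value straight out of the trial context."""
    def spec(ctx, feature):
        return ctx[key]
    return spec


ctx = {'target': 3, 'gain': 0.25}
feat = SimpleNamespace(num=7, name='position')
```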

Output Labels

Label helpers build the outputs={'label': ...} spec used by phases in categorical mode. They convert per-trial state into integer labels, time-varying label arrays, or match/comparison codes.

label

Create an output label specification.

match_label

Create a label for match/non-match tasks.

comparison_label

Create a label for comparison tasks.

Pre-built Tasks

The cogtask package ships a library of standard cognitive paradigms, each implemented as a subclass of Task that defines its own features, phase structure, and trial-init logic. Construct them like any other Task, optionally passing seed= for reproducibility — see Tutorial 1: Quickstart with braintools.cogtask for a runnable example.

Decision Making

Two-alternative and multi-modal perceptual decision tasks with motion coherence, context cues, or discrete evidence pulses.

PerceptualDecisionMaking

Perceptual Decision Making (PDM) task.

PerceptualDecisionMakingDelayResponse

PDM task with delay before response.

ContextDecisionMaking

Context-Dependent Decision Making task.

SingleContextDecisionMaking

Single-Context Decision Making task.

PulseDecisionMaking

Pulse-Based Decision Making task.

Working Memory

Delay-bridging tasks that require holding stimulus identity, magnitude, category, direction, or interval information across a memory period.

DelayMatchSample

Delayed Match-to-Sample (DMS) task.

DualDelayMatchSample

Dual Delayed Match-to-Sample task.

DelayComparison

Delayed Comparison task.

DelayMatchCategory

Delayed Match-to-Category task.

DelayPairedAssociation

Delayed Paired Association task.

GoNoGo

Go/No-Go task.

IntervalDiscrimination

Interval Discrimination task.

PostDecisionWager

Post-Decision Wager task.

ReadySetGo

Ready-Set-Go timing task.

DelayDirectionReproduction

Delay Direction Reproduction task.

ImmediateDirectionReproduction

Immediate Direction Reproduction task.

DelayDirectionClassification

Delayed Direction Classification (DDC).

ImmediateDirectionClassification

Immediate Direction Classification (IDC).

Reasoning

Tasks that require integrating multiple cues or rules to arrive at a decision under uncertainty.

HierarchicalReasoning

Hierarchical Reasoning task.

ProbabilisticReasoning

Probabilistic Reasoning task.

Motor

Reaching, anti-reaching, and evidence-accumulation tasks that produce continuous motor outputs.

AntiReach

Anti-Reach (Anti-Saccade) task.

Reaching1D

1D Reaching task.

EvidenceAccumulation

Evidence Accumulation task.

Utilities

Duration distributions for sampling variable-length phases:

TruncExp

Truncated exponential distribution for sampling time durations.

UniformDuration

Uniform distribution for sampling time durations.
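One way to sample such durations is inverse-CDF sampling. The hypothetical trunc_exp below draws from a density proportional to exp(-x / mean) restricted to [vmin, vmax); the library's TruncExp may use other parameter names or a different method (for example rejection sampling).

```python
import numpy as np


def trunc_exp(rng, mean, vmin, vmax, size=None):
    """Inverse-CDF sample from a density proportional to exp(-x / mean) on [vmin, vmax)."""
    rate = 1.0 / mean
    u = rng.uniform(size=size)                    # u in [0, 1)
    mass = 1.0 - np.exp(-rate * (vmax - vmin))    # probability mass inside the window
    return vmin - np.log1p(-u * mass) / rate


def uniform_duration(rng, vmin, vmax, size=None):
    return rng.uniform(vmin, vmax, size=size)


rng = np.random.default_rng(0)
delays = trunc_exp(rng, mean=300.0, vmin=200.0, vmax=1500.0, size=10_000)
flat = uniform_duration(rng, 200.0, 1500.0, size=10_000)
```

A truncated exponential concentrates delays near vmin while keeping a long tail, which makes the timing of the next event harder to predict than with a uniform delay.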

Dataset transforms applied around Task.batch_sample():

Transform

Base class for dataset transformations.

TransformIT

Transformation which transforms input and target separately.

Helper functions for periods, label arrays, and rate conversion:

initialize

Initialize/resolve a parameter value.

initialize2

Initialize/resolve a parameter and convert to timesteps.

interval_of

Get slice for a named period in a sequence of periods.

period_to_arr

Convert period dictionary to label array.

firing_rate

Convert firing rate to spike probability per timestep.
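Hypothetical toy versions of interval_of and period_to_arr, with periods given as an ordered list of (name, n_steps) pairs, show how named periods map to slices and label arrays. The spike-probability helper uses the small-dt approximation p = rate * dt, which is an assumption about what firing_rate computes.

```python
import numpy as np


def toy_interval_of(name, periods):
    """periods: ordered list of (name, n_steps); return the slice covering `name`."""
    start = 0
    for pname, n in periods:
        if pname == name:
            return slice(start, start + n)
        start += n
    raise KeyError(name)


def toy_period_to_arr(periods, labels):
    """Integer label array of length sum(n_steps); unlisted periods get label 0."""
    total = sum(n for _, n in periods)
    arr = np.zeros(total, dtype=int)
    for pname, label in labels.items():
        arr[toy_interval_of(pname, periods)] = label
    return arr


def toy_spike_prob(rate_hz, dt_ms):
    # Small-dt approximation p = rate * dt, clipped to a valid probability
    return float(np.clip(rate_hz * dt_ms * 1e-3, 0.0, 1.0))


periods = [('fixation', 3), ('stimulus', 4), ('response', 2)]
```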

Concepts

Trial generation

A Task produces one trial as follows:

  1. Construct a Context; when the Task was given a seed, its key is jax.random.fold_in(jax.random.PRNGKey(seed), index).

  2. Call trial_init(ctx) (or Task.trial_init() for subclasses) to populate trial-level state — ground truth, stimulus identity, coherence, etc.

  3. Compute total duration with a dry-run pass over the phase tree (so variable-duration phases can read state set by trial_init).

  4. Allocate ctx.inputs of shape (T, num_inputs) and ctx.outputs either (T,) (categorical mode) or (T, num_outputs) (vector mode).

  5. Walk the phase tree a second time, with each phase calling Phase.encode_inputs() and Phase.encode_outputs() to fill its slice of the buffers.
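The five steps can be sketched in plain NumPy. Everything here is hypothetical: phases are (name, duration_fn, input_fn) tuples, a dict stands in for Context, np.random.default_rng([seed, index]) stands in for jax.random.fold_in, and output encoding is omitted for brevity.

```python
import numpy as np


def toy_generate_trial(phases, trial_init, num_inputs, seed, index):
    # 1. Per-trial context with its own random stream (NumPy analogue of fold_in)
    ctx = {'rng': np.random.default_rng([seed, index])}
    # 2. Populate trial-level state
    trial_init(ctx)
    # 3. Dry run: durations may read state set by trial_init
    T = sum(dur(ctx) for _, dur, _ in phases)
    # 4. Allocate the input buffer (outputs omitted in this sketch)
    inputs = np.zeros((T, num_inputs))
    # 5. Second pass: each phase fills its own slice of the timeline
    history, t = [], 0
    for name, dur, enc in phases:
        n = dur(ctx)
        inputs[t:t + n] = enc(ctx, n)
        history.append((name, t, t + n))
        t += n
    return inputs, history


def init(ctx):
    ctx['delay_steps'] = int(ctx['rng'].integers(2, 6))   # 2..5 steps


phases = [
    ('fixation', lambda c: 3, lambda c, n: np.tile([1.0, 0.0], (n, 1))),
    ('delay', lambda c: c['delay_steps'], lambda c, n: np.tile([1.0, 0.0], (n, 1))),
    ('stimulus', lambda c: 2, lambda c, n: np.tile([0.0, 1.0], (n, 1))),
]
X, hist = toy_generate_trial(phases, init, num_inputs=2, seed=0, index=7)
```

The duration functions must be pure functions of the context, because they are evaluated once in the dry run and again in the writing pass.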

Task.batch_sample() vmaps this process over trial indices; each trial's key is derived with fold_in of its trial index, so batches are reproducible.

Sampling APIs and tensor shapes

A configured Task exposes three sampling entry points. The shapes below assume num_inputs == task.num_inputs and num_outputs == task.num_outputs; T is the per-trial timestep count.

Method                                   Returns        Shapes
Task.sample_trial(index)                 (X, Y, info)   X: (T, num_inputs); Y: (T,) or (T, num_outputs)
Task.sample(index) / task[index]         (X, Y)         same as above (JIT-compiled)
Task.batch_sample(B)                     (X, Y)         X: (T, B, num_inputs); Y: (T, B) or (T, B, num_outputs)
Task.batch_sample(B, time_first=False)   (X, Y)         X: (B, T, num_inputs); Y: (B, T) or (B, T, num_outputs)
Task.batch_sample(B, return_meta=True)   (X, Y, meta)   as above; meta is task-defined
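Converting between the two batch layouts in the table is a single axis swap. This NumPy sketch stacks stand-in trials batch-first and swaps to time-first; the arrays are placeholders filled with the trial number, not library output.

```python
import numpy as np

T, B, num_inputs, num_outputs = 50, 4, 3, 2

# Stand-ins for B sampled trials (vector-mode targets); trial b is filled with b
trial_X = [np.full((T, num_inputs), float(b)) for b in range(B)]
trial_Y = [np.full((T, num_outputs), float(b)) for b in range(B)]

# Batch-first layout, as with time_first=False: (B, T, ...)
X_bf = np.stack(trial_X, axis=0)
Y_bf = np.stack(trial_Y, axis=0)

# Time-first layout, the batch_sample default in the table: (T, B, ...)
X_tf = np.swapaxes(X_bf, 0, 1)
Y_tf = np.swapaxes(Y_bf, 0, 1)
```

Time-first arrays suit step-by-step recurrent loops that index X[t]; batch-first arrays suit per-trial inspection.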

The third value returned from Task.sample_trial() is a dictionary with the following keys:

  • phase_history — list of (name, start, end) tuples logging each phase’s contribution to the timeline

  • trial_state — copy of the user state set via trial_init (e.g. ground_truth, coherence)

  • dt — the resolved time step (from brainstate.environ.get_dt())

  • index — the trial index requested

To customize the metadata returned by batch_sample(..., return_meta=True), override Task.get_trial_meta() in your subclass.

Variable-length trials

A phase tree is variable-length when any phase advertises is_variable = True. The built-in cases are VariableDuration (its duration is read from ctx[ctx_key]), If, Switch, and While. phase_tree_is_variable() walks the tree to detect this at construction time; Task records the result as task.is_variable_length and routes such trials through the packed sample path.

Packed-mode semantics:

  • task.max_trial_duration() returns a Python int upper bound on the trial length in timesteps, computed by walking each phase’s max_steps. Because it is a plain Python int, it is safe to use as a static buffer dimension under brainstate.transform.jit and brainstate.transform.vmap2.

  • sample_trial allocates X of shape (max_trial_duration, num_inputs), Y of shape (max_trial_duration,) (categorical) or (max_trial_duration, num_outputs) (vector), and mask of shape (max_trial_duration,).

  • Each phase contributes step_count valid timesteps starting at ctx.t_cursor. The phase’s writes are gated to the active region, trailing positions remain zero, and the mask is set to True only over the actual valid range.

  • batch_sample(B, return_mask=True) returns (X, Y, mask) with the mask in the same time/batch layout as X/Y. Fixed-length tasks also support return_mask=True; their mask is all-True.
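A hypothetical tree walker can compute both the is_variable flag and a static step bound. The toy node classes (Fixed, Variable, Seq, IfNode, WhileNode) are illustrations only. Following the Conditional Phases description of durations summed across iterations or branches, the bound reserves room for every If branch and every While iteration, which may be more conservative than the library's actual bound.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Fixed:            # fixed-length epoch
    steps: int
    is_variable: bool = False


@dataclass(frozen=True)
class Variable:         # like VariableDuration: per-trial length, bounded above
    upper: int
    is_variable: bool = True


@dataclass(frozen=True)
class Seq:
    children: Tuple[object, ...]
    is_variable: bool = False


@dataclass(frozen=True)
class IfNode:
    then: object
    else_: object
    is_variable: bool = True


@dataclass(frozen=True)
class WhileNode:
    body: object
    max_iterations: int
    is_variable: bool = True


def child_nodes(node):
    if isinstance(node, Seq):
        return node.children
    if isinstance(node, IfNode):
        return (node.then, node.else_)
    if isinstance(node, WhileNode):
        return (node.body,)
    return ()


def tree_is_variable(node):
    # True if this node or any descendant declares is_variable
    return node.is_variable or any(tree_is_variable(c) for c in child_nodes(node))


def max_steps(node):
    # Static upper bound on the timesteps this subtree can occupy
    if isinstance(node, Fixed):
        return node.steps
    if isinstance(node, Variable):
        return node.upper
    if isinstance(node, WhileNode):
        return node.max_iterations * max_steps(node.body)
    # Seq and IfNode: reserve room for every child (each If branch gets a slot)
    return sum(max_steps(c) for c in child_nodes(node))


fixed_tree = Seq((Fixed(5), Fixed(10)))
tree = Seq((Fixed(5), Variable(150), IfNode(Fixed(10), Fixed(20)),
            WhileNode(Fixed(4), max_iterations=3)))
```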

Example:

import brainunit as u
import jax.numpy as jnp
from braintools.cogtask import (
    Task, Feature, Fixation, Stimulus, Response,
    VariableDuration, concat,
)

fix = Feature(1, 'fixation')
stim = Feature(2, 'stim')
choice = Feature(2, 'choice')

phases = concat([
    Fixation(50 * u.ms, inputs={'fixation': 1.0}),
    VariableDuration(
        min_duration=200 * u.ms,
        max_duration=1500 * u.ms,
        ctx_key='delay',
        inputs={'fixation': 1.0},
    ),
    Stimulus(100 * u.ms, inputs={'stim': lambda c, f: jnp.ones(f.num)}),
    Response(50 * u.ms, outputs={'label': lambda c, f: c['gt']}),
])

def init(ctx):
    ctx['delay'] = ctx.rng.uniform(200.0, 1500.0)
    ctx['gt'] = ctx.rng.choice(2)

task = Task(
    phases=phases,
    input_features=fix + stim,
    output_features=fix + choice,
    trial_init=init,
    seed=0,
)

assert task.is_variable_length
X, Y, mask = task.batch_sample(64, return_mask=True)
# X, Y, mask all share their time dimension == task.max_trial_duration()

Conditional compounds (If, Switch, While) work in packed mode too. If uses jax.lax.cond so both branches contribute shape-stable output even when the predicate is a tracer. Switch and While use Python-level dispatch: the selector must return a hashable key (not a tracer), and While’s condition must return a Python bool. Branches that don’t run leave the buffer slot at zero and don’t advance t_cursor.
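Shape-stable packed writes can be sketched with masks instead of data-dependent slicing: every array operation below has shape (max_T, ...), a skipped branch contributes zero steps, and t_cursor only advances over active phases. write_packed and its (active, n_steps, value) segment format are hypothetical.

```python
import numpy as np


def write_packed(segments, max_T, num_inputs):
    """segments: (active, n_steps, value) in phase order; value has shape (num_inputs,)."""
    t = np.arange(max_T)
    X = np.zeros((max_T, num_inputs))
    mask = np.zeros(max_T, dtype=bool)
    t_cursor = 0
    for active, n, value in segments:
        n_eff = n * int(active)                   # skipped branch: zero valid steps
        region = (t >= t_cursor) & (t < t_cursor + n_eff)
        # Gate the write to the active region; other positions keep their values
        X = np.where(region[:, None], np.asarray(value, dtype=float), X)
        mask = mask | region
        t_cursor += n_eff
    return X, mask


segs = [(True, 3, [1.0, 0.0]), (False, 4, [0.0, 9.0]), (True, 2, [0.0, 1.0])]
X, mask = write_packed(segs, max_T=10, num_inputs=2)
```

Trailing positions past the last active phase stay zero and are excluded by the mask, so a loss computed as mean over mask-selected steps ignores the padding.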

Output modes

  • 'categorical' (default): ctx.outputs has shape (T,) and holds integer labels. Phases set the 'label' key in their outputs= dict.

  • 'vector': ctx.outputs has shape (T, num_outputs). Phases set each output feature by name (e.g. 'direction', 'fixation_out'). Use this for continuous-report tasks such as DelayDirectionReproduction.

Declarative phase shape conventions

A value spec for inputs= can be a constant or a callable f(ctx, feature) -> array. The encoded value is broadcast into the phase’s slice of ctx.inputs according to its shape:

  • scalar → constant for every timestep and feature unit

  • 1-D, shape (feature.num,) → broadcast along the time axis

  • 2-D, shape (duration, feature.num) → written directly

For outputs= the conventions depend on the output mode:

  • Categorical (ctx.outputs.ndim == 1): use the 'label' key. Accepts a scalar (constant label for the phase) or a 1-D array of shape (duration,) (time-varying labels). Features other than 'label' are ignored.

  • Vector (ctx.outputs.ndim == 2): values are written per output feature and may have shape (feature.num,) (broadcast along time) or (duration, feature.num).
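The broadcasting rules above can be sketched as two hypothetical helpers, write_input and write_label, operating on plain NumPy buffers; they mirror the documented shape conventions but are not the library's code.

```python
import numpy as np


def write_input(buffer, t0, duration, feat_slice, value):
    """Broadcast an encoded value into buffer[t0:t0 + duration, feat_slice]."""
    value = np.asarray(value, dtype=float)
    num = feat_slice.stop - feat_slice.start
    if value.ndim == 0:                        # scalar: constant everywhere
        block = np.full((duration, num), value)
    elif value.shape == (num,):                # 1-D: broadcast along time
        block = np.broadcast_to(value, (duration, num))
    elif value.shape == (duration, num):       # 2-D: written directly
        block = value
    else:
        raise ValueError(f'unsupported value shape {value.shape}')
    buffer[t0:t0 + duration, feat_slice] = block


def write_label(labels, t0, duration, value):
    """Categorical mode: a scalar label or a (duration,) array of labels."""
    value = np.asarray(value)
    if value.ndim > 1 or (value.ndim == 1 and value.shape != (duration,)):
        raise ValueError(f'unsupported label shape {value.shape}')
    labels[t0:t0 + duration] = np.broadcast_to(value, (duration,))


buf = np.zeros((6, 3))
write_input(buf, 0, 2, slice(0, 1), 1.0)
write_input(buf, 2, 2, slice(1, 3), [0.5, -0.5])
write_input(buf, 4, 2, slice(1, 3), np.arange(4.0).reshape(2, 2))

labels = np.zeros(6, dtype=int)
write_label(labels, 4, 2, 1)
write_label(labels, 0, 2, [2, 3])
```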

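The noise= field maps a feature name to a sigma Quantity in units of ms**0.5. Noise is sampled fresh for each phase and scaled by 1 / sqrt(dt), so its per-timestep variance is sigma**2 / dt, while the variance of the noise integrated over a fixed time window (the sum of x_t * dt) equals sigma**2 times the window length, independent of dt.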
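A quick NumPy check of the 1 / sqrt(dt) scaling: phase_noise is a hypothetical helper, and sigma and dt are plain floats in ms**0.5 and ms. The per-step noise grows as dt shrinks, but the variance of the noise summed over a 100 ms window, sum(x_t) * dt, stays near sigma**2 * 100 at both resolutions.

```python
import numpy as np


def phase_noise(rng, sigma, dt, shape):
    """Per-step noise with standard deviation sigma / sqrt(dt)."""
    return sigma / np.sqrt(dt) * rng.standard_normal(shape)


rng = np.random.default_rng(0)
sigma, window_ms, n_trials = 0.5, 100.0, 4000
integrated_var = {}
for dt in (1.0, 0.1):
    n_steps = int(round(window_ms / dt))
    noise = phase_noise(rng, sigma, dt, (n_trials, n_steps))
    integrated = noise.sum(axis=1) * dt        # Riemann sum of the noise over the window
    integrated_var[dt] = integrated.var()
# Both entries are close to sigma**2 * window_ms = 25.0
```

This is the same scaling used in Euler-Maruyama integration of stochastic differential equations, which is why downstream recurrent dynamics see statistically equivalent input noise at any dt.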

Feature index management

Features expose .i, a Python slice into the flat input/output vector. When features are composed via a + b, their _start/_end offsets are shifted automatically so that each feature occupies its own slice, and phase encoders can write ctx.inputs[..., feat.i] without bookkeeping. Composition is immutable: both operands are copied before any offsets are shifted.

Reproducibility

A Task constructed with seed=N derives each trial’s key as jax.random.fold_in(jax.random.PRNGKey(N), trial_index). This makes task.sample(i) deterministic and task.batch_sample(B, start_index=k) reproducible and non-overlapping across calls. If seed is omitted, trials draw fresh randomness from brainstate’s default RNG.
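The fold_in scheme can be mimicked in NumPy by seeding one generator per (seed, trial index) pair; this is an analogue for illustration, and trial_rng and toy_batch are hypothetical. Because each trial's randomness depends only on its index, two consecutive batches reproduce exactly one larger batch.

```python
import numpy as np


def trial_rng(seed, index):
    # NumPy analogue of jax.random.fold_in(jax.random.PRNGKey(seed), index)
    return np.random.default_rng([seed, index])


def toy_batch(seed, batch_size, start_index=0):
    # One scalar draw per trial, standing in for a full trial
    return np.array([trial_rng(seed, start_index + b).uniform()
                     for b in range(batch_size)])


first = toy_batch(seed=0, batch_size=4, start_index=0)
second = toy_batch(seed=0, batch_size=4, start_index=4)
combined = toy_batch(seed=0, batch_size=8, start_index=0)
```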

Time step

All durations are resolved against the currently active time step, brainstate.environ.get_dt(). The same task can be re-sampled at a finer or coarser dt simply by wrapping it in a brainstate.environ.context; see Tutorial 1: Quickstart with braintools.cogtask for a worked example.
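Resolving durations against dt amounts to dividing each duration by the time step. The hypothetical to_steps helper assumes rounding to the nearest step (the library may round differently) and shows a 1200 ms trial growing from 1200 to 12000 timesteps when dt drops from 1 ms to 0.1 ms.

```python
def to_steps(duration_ms, dt_ms):
    # Assumption: round to the nearest whole timestep
    return int(round(duration_ms / dt_ms))


phase_ms = [100.0, 500.0, 500.0, 100.0]     # fixation, stimulus, delay, response
T_coarse = sum(to_steps(d, 1.0) for d in phase_ms)
T_fine = sum(to_steps(d, 0.1) for d in phase_ms)
```

Only the number of timesteps changes; the trial's timing in milliseconds is the same at both resolutions.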