braintools.cogtask module#
Composable Cognitive Task Framework#
A modular, composable framework for constructing cognitive tasks for neural network training and neuroscience simulations.
Quick Start#
Using pre-built tasks:
>>> from braintools.cogtask import PerceptualDecisionMaking
>>> task = PerceptualDecisionMaking(t_stimulus=2000)
>>> X, Y = task.batch_sample(32)
>>> # train_step(X, Y)
Building custom tasks from phases:
>>> import math
>>> from braintools.cogtask import (
...     Task, Context, concat,
...     Fixation, Stimulus, Delay, Response,
...     Feature, circular, one_hot
... )
>>> import brainunit as u
>>>
>>> # Define features
>>> fix = Feature(1, 'fixation')
>>> stim = Feature(8, 'stimulus')
>>> choice = Feature(2, 'choice')
>>>
>>> # Build task
>>> task = Task(
...     phases=concat([
...         Fixation(100 * u.ms, inputs={'fixation': 1.0}),
...         Stimulus(500 * u.ms, inputs={'stimulus': circular('direction')}),
...         Delay(500 * u.ms, inputs={'fixation': 1.0}),
...         Response(100 * u.ms, outputs={'label': 'ground_truth'})
...     ]),
...     input_features=fix + stim,
...     output_features=fix + choice,
...     trial_init=lambda ctx: ctx.update(
...         ground_truth=ctx.rng.choice(2),
...         direction=ctx.rng.uniform(0, 2 * math.pi)
...     )
... )
API Summary#
- Core:
    - Task, TaskConfig: Main task class and configuration
    - Context: Inter-phase state container
    - Phase, Sequence, Repeat, Parallel: Phase composition
    - If, Switch, While: Conditional phases
    - concat: Helper for sequential composition
- Phases:
    - Fixation, Delay, Stimulus, Response, Cue: Basic phases
    - Sample, Test, Recall, Match, Comparison, Blank: Memory phases
    - DeclarativePhase: Base class for creating custom phases
- Features:
    - Feature, FeatureSet, CircleFeature: Input/output encoding
- Encoders:
    - circular, one_hot, von_mises, scalar, gaussian, identity, ctx_value
- Labels:
    - label, match_label, comparison_label: Output label helpers
- Pre-built Tasks:
    - Decision Making: PerceptualDecisionMaking, ContextDecisionMaking, etc.
    - Working Memory: DelayMatchSample, GoNoGo, etc.
    - Reasoning: HierarchicalReasoning, ProbabilisticReasoning
    - Motor: AntiReach, Reaching1D, EvidenceAccumulation
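See also
For runnable, narrative walkthroughs, see the tutorials, including Tutorial 1: Quickstart with braintools.cogtask, Tutorial 2: Building custom tasks, and Tutorial 3: Variable-length trial sequences.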
Overview#
The braintools.cogtask module provides:
- A phase-based task model that decomposes trials into fixation, stimulus, delay, response, and other epochs, each with an explicit duration, input encoding, and output (target) encoding
- Composition operators (>>, *, |) and compound phases (Sequence, Repeat, Parallel) for building rich trial structures from simple parts
- Conditional control flow with If, Switch, and While for trial-by-trial branching and variable-iteration tasks
- A feature-encoding system that maps trial state into input/output channels via Feature/FeatureSet and value-spec encoders (one_hot, circular, von_mises, gaussian, cos_sin, …)
- A library of pre-built tasks spanning decision making, working memory, reasoning, and motor control, drawn from the systems-neuroscience literature
- JIT/vmap-friendly trial generation through Task.sample() and Task.batch_sample(), designed to integrate cleanly with brainstate and JAX training loops
Core Task Framework#
The Task class orchestrates phase execution, owns the per-trial
random key, and exposes the dataset-style sample/batch_sample API.
Context is the mutable trial-level state container shared across
phases; it carries the RNG, input/output buffers, timing information, and
trial-level user data.
- Task: A cognitive task composed of phases.
- Context: Mutable trial-level state container shared across phases.
Two equivalent ways to define a task are supported:
- Instance-based: pass phases=, input_features=, output_features=, and trial_init= directly to Task. Best for one-off tasks or interactive exploration.
- Class-based: subclass Task and override Task.define_features(), Task.define_phases(), and Task.trial_init(). Best for reusable, parameterized tasks; all pre-built tasks follow this pattern.
See Tutorial 2: Building custom tasks for worked examples of both.
Phases and Composition#
Phases are the atomic units of a trial. Phase is the abstract base
class; concrete declarative phases (Fixation, Stimulus,
Delay, Response, …) inherit from DeclarativePhase
and describe their inputs/outputs/noise via dictionaries instead of code.
Phases compose with operators:
- a >> b: sequential composition (yields Sequence)
- a * n: repeat n times (yields Repeat)
- a | b: parallel composition (yields Parallel)
- concat(): sequence from a list
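To make the operator semantics concrete, here is a self-contained toy model of how such overloads can build a composition tree. It does not import braintools: the classes Node, Leaf, Seq, Rep, and Par and the steps() method are hypothetical, and the rule that a parallel composition lasts as long as its longest branch is an assumption, not documented cogtask behavior. Note that Python binds * tighter than >>, and >> tighter than |, so a >> b * 3 repeats only b.

```python
from dataclasses import dataclass
from typing import Tuple


class Node:
    """Hypothetical base class: operator overloads build a composition tree."""

    def __rshift__(self, other: "Node") -> "Seq":
        # a >> b: run a, then b
        return Seq((self, other))

    def __mul__(self, n: int) -> "Rep":
        # a * n: run a, n times in a row
        return Rep(self, n)

    def __or__(self, other: "Node") -> "Par":
        # a | b: run a and b over the same time window
        return Par((self, other))

    def steps(self) -> int:
        raise NotImplementedError


@dataclass(frozen=True)
class Leaf(Node):
    name: str
    n_steps: int

    def steps(self) -> int:
        return self.n_steps


@dataclass(frozen=True)
class Seq(Node):
    children: Tuple[Node, ...]

    def steps(self) -> int:
        return sum(child.steps() for child in self.children)


@dataclass(frozen=True)
class Rep(Node):
    body: Node
    n: int

    def steps(self) -> int:
        return self.n * self.body.steps()


@dataclass(frozen=True)
class Par(Node):
    children: Tuple[Node, ...]

    def steps(self) -> int:
        # Assumption: parallel branches overlap, so the longest one sets the length.
        return max(child.steps() for child in self.children)


fix, stim, delay, resp = (Leaf("fixation", 10), Leaf("stimulus", 50),
                          Leaf("delay", 50), Leaf("response", 10))
# * binds tighter than >>, so only (stim >> delay) is repeated.
trial = fix >> (stim >> delay) * 3 >> resp
print(trial.steps())  # 10 + 3 * (50 + 50) + 10 = 320
```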
Base Class#
- Phase: Base class for task phases (epochs/periods).
- DeclarativePhase: Declarative phase definition with explicit input/output specifications.
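Compound Phases#
Sequence, Repeat, and Parallel are the compound phases produced by a >> b, a * n, and a | b, respectively; concat() builds a Sequence from a list.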
Declarative Phase Types#
These are convenience subclasses of DeclarativePhase that share its
interface but provide semantic names so trial structures read naturally.
They differ only in identity — use whichever name best describes the epoch.
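Basic epochs:
- Fixation, Delay, Stimulus, Response, Cue
Working-memory epochs:
- Sample, Test, Recall, Match, Comparison, Blank
Variable-length epochs:
- VariableDuration: Declarative phase whose actual length is decided per trial.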
Composition Helpers#
- concat: Concatenate phases into a sequence.
- Execute a single phase, updating context appropriately.
- Variable-length / packed-mode dispatch for a single phase.
- phase_tree_is_variable: Walk a phase tree and return True if any node declares is_variable = True.
Conditional Phases#
Phases can branch on trial state at runtime. If selects between
then / else_; Switch dispatches over many cases; While
loops until a condition fails (bounded by max_iterations).
- If: Conditional phase selection based on a boolean condition.
- Switch: Multi-way conditional phase selection.
- While: Loop phase while condition is true.
Because these phases inspect trial state during a Python-level pass over the
tree, the branch they take must be derivable from values set in
trial_init (or in earlier phases’ on_exit hooks). Their total
duration, summed across iterations or branches, contributes to the per-trial
buffer size — see Tutorial 3: Variable-length trial sequences for the
implications.
Features#
A Feature declares one logical input or output channel of the task,
with a fixed dimensionality and a name. FeatureSet collects features
into a flat vector and tracks per-feature index slices automatically.
CircleFeature adds a value range for angular / directional outputs.
Compose features with + (concatenate, immutable), - (remove by name),
| (alias for +), and *n (named repetition).
- Feature: Individual feature encoder for cognitive task inputs/outputs.
- FeatureSet: Collection of features with automatic index management.
- CircleFeature: Circular feature for angular/directional data.
Feature predicates:
Encoders#
Encoders are value specifications — callables of the form
f(ctx, feature) -> jnp.ndarray that DeclarativePhase uses to fill
its input slice for one feature. They translate trial-level state (e.g. a
direction angle, a discrete class index) into per-timestep input activations.
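The callables below sketch what value specs with the f(ctx, feature) signature can look like. They are hypothetical NumPy stand-ins, not the library's one_hot, circular, or von_mises: a plain dict plays the role of Context, an object with a num attribute plays the role of Feature, and the tuning choices (evenly spaced preferred directions, kappa=2.0, peak normalized to 1) are assumptions.

```python
import numpy as np
from types import SimpleNamespace


def one_hot_spec(key):
    """Hypothetical one-hot spec: unit ctx[key] is 1, all others 0."""
    def spec(ctx, feature):
        out = np.zeros(feature.num)
        out[int(ctx[key])] = 1.0
        return out
    return spec


def _preferred_directions(num):
    # Assumption: preferred directions evenly tile [0, 2*pi).
    return np.linspace(0.0, 2.0 * np.pi, num, endpoint=False)


def cosine_spec(key):
    """Hypothetical cosine-tuned population code for the angle ctx[key]."""
    def spec(ctx, feature):
        return np.cos(ctx[key] - _preferred_directions(feature.num))
    return spec


def von_mises_spec(key, kappa=2.0):
    """Hypothetical von Mises tuning, scaled so the best-tuned unit responds with 1."""
    def spec(ctx, feature):
        d = np.cos(ctx[key] - _preferred_directions(feature.num))
        return np.exp(kappa * (d - 1.0))
    return spec


feature = SimpleNamespace(num=8)            # stands in for Feature(8, 'stimulus')
ctx = {'label': 3, 'direction': np.pi / 2}  # stands in for trial-level state
print(one_hot_spec('label')(ctx, feature))
print(np.argmax(cosine_spec('direction')(ctx, feature)))     # unit 2 prefers pi/2
print(np.argmax(von_mises_spec('direction')(ctx, feature)))  # same preferred unit
```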
Discrete / class encoders:
- one_hot: Create a one-hot encoding value specification.
- identity: Create an identity encoding that passes through values directly.
Directional / population encoders:
- circular: Cosine-tuned directional encoder.
- von_mises: Von Mises (circular-normal) directional encoder.
- cos_sin: Encode a discrete direction index into repeated [cos(theta), sin(theta)] features.
Scalar / shape encoders:
Output Labels#
Label helpers build the outputs={'label': ...} spec used by phases in
categorical mode. They convert per-trial state into integer labels,
time-varying label arrays, or match/comparison codes.
- label: Create an output label specification.
- match_label: Create a label for match/non-match tasks.
- comparison_label: Create a label for comparison tasks.
Pre-built Tasks#
The cogtask package ships a library of standard cognitive paradigms,
each implemented as a subclass of Task that defines its own
features, phase structure, and trial-init logic. Construct them like any
other Task, optionally passing seed= for reproducibility — see
Tutorial 1: Quickstart with braintools.cogtask for a runnable example.
Decision Making#
Two-alternative and multi-modal perceptual decision tasks with motion coherence, context cues, or discrete evidence pulses.
- Perceptual Decision Making (PDM) task.
- PDM task with delay before response.
- Context-Dependent Decision Making task.
- Single-Context Decision Making task.
- Pulse-Based Decision Making task.
Working Memory#
Delay-bridging tasks that require holding stimulus identity, magnitude, category, direction, or interval information across a memory period.
- Delayed Match-to-Sample (DMS) task.
- Dual Delayed Match-to-Sample task.
- Delayed Comparison task.
- Delayed Match-to-Category task.
- Delayed Paired Association task.
- Go/No-Go task.
- Interval Discrimination task.
- Post-Decision Wager task.
- Ready-Set-Go timing task.
- Delay Direction Reproduction task.
- Immediate Direction Reproduction task.
- Delayed Direction Classification (DDC).
- Immediate Direction Classification (IDC).
Reasoning#
Tasks that require integrating multiple cues or rules to arrive at a decision under uncertainty.
- Hierarchical Reasoning task.
- Probabilistic Reasoning task.
Motor#
Reaching, anti-reaching, and evidence-accumulation tasks that produce continuous motor outputs.
- Anti-Reach (Anti-Saccade) task.
- 1D Reaching task.
- Evidence Accumulation task.
Utilities#
Duration distributions for sampling variable-length phases:
- Truncated exponential distribution for sampling time durations.
- Uniform distribution for sampling time durations.
Dataset transforms applied around Task.batch_sample():
- Base class for dataset transformations.
- Transformation which transforms input and target separately.
Helper functions for periods, label arrays, and rate conversion:
- Initialize/resolve a parameter value.
- Initialize/resolve a parameter and convert to timesteps.
- Get slice for a named period in a sequence of periods.
- Convert period dictionary to label array.
- Convert firing rate to spike probability per timestep.
Concepts#
Trial generation#
A Task produces one trial as follows:
1. Construct a Context, seeded by jax.random.fold_in(jax.random.PRNGKey(seed), index) when Task was given a seed.
2. Call trial_init(ctx) (or Task.trial_init() for subclasses) to populate trial-level state: ground truth, stimulus identity, coherence, etc.
3. Compute the total duration with a dry-run pass over the phase tree, so variable-duration phases can read state set by trial_init.
4. Allocate ctx.inputs of shape (T, num_inputs) and ctx.outputs of shape either (T,) (categorical mode) or (T, num_outputs) (vector mode).
5. Walk the phase tree a second time, with each phase calling Phase.encode_inputs() and Phase.encode_outputs() to fill its slice of the buffers.
Task.batch_sample() vectorizes this process with vmap. Each trial's key is
derived with fold_in from its trial index, so batches are reproducible.
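The five steps above can be sketched in plain NumPy. Everything here is hypothetical (ToyPhase, generate_trial, the dict-based ctx, and seeding via numpy.random.default_rng([seed, index]) as a stand-in for jax.random.fold_in); it mirrors the dry-run, allocate, fill structure rather than reproducing cogtask's implementation, and it omits vmap.

```python
import numpy as np


class ToyPhase:
    """Hypothetical phase: a step count plus constant inputs and a label."""

    def __init__(self, name, steps, inputs=None, label=0):
        self.name = name
        self._steps = steps            # int, or callable ctx -> int
        self.inputs = inputs or {}     # {input channel index: constant value}
        self.label = label             # int, or callable ctx -> int

    def steps(self, ctx):
        return self._steps(ctx) if callable(self._steps) else self._steps

    def encode(self, ctx, x, y):
        # x and y are views into this phase's slice of the trial buffers.
        for channel, value in self.inputs.items():
            x[:, channel] = value
        y[:] = self.label(ctx) if callable(self.label) else self.label


def generate_trial(phases, num_inputs, trial_init, seed, index):
    # Step 1: per-trial RNG; a NumPy analog of fold_in(PRNGKey(seed), index).
    ctx = {"rng": np.random.default_rng([seed, index])}
    # Step 2: trial-level state.
    trial_init(ctx)
    # Step 3: dry run, so variable durations can read state from trial_init.
    T = sum(p.steps(ctx) for p in phases)
    # Step 4: allocate buffers (categorical mode: one integer label per step).
    inputs = np.zeros((T, num_inputs))
    outputs = np.zeros((T,), dtype=np.int32)
    # Step 5: second pass; each phase fills its own slice.
    history, t = [], 0
    for p in phases:
        n = p.steps(ctx)
        p.encode(ctx, inputs[t:t + n], outputs[t:t + n])
        history.append((p.name, t, t + n))
        t += n
    return inputs, outputs, history


def init(ctx):
    ctx["delay_steps"] = int(ctx["rng"].integers(5, 20))
    ctx["ground_truth"] = int(ctx["rng"].integers(2))


phases = [
    ToyPhase("fixation", 5, inputs={0: 1.0}),
    ToyPhase("delay", lambda ctx: ctx["delay_steps"], inputs={0: 1.0}),
    ToyPhase("stimulus", 10, inputs={1: 1.0}),
    # Hypothetical label convention: 0 = fixate, 1 + ground_truth = choice.
    ToyPhase("response", 3, label=lambda ctx: 1 + ctx["ground_truth"]),
]
X, Y, history = generate_trial(phases, num_inputs=2, trial_init=init, seed=0, index=7)
print(X.shape, Y.shape, history)
```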
Sampling APIs and tensor shapes#
A configured Task exposes three sampling entry points. Shapes quoted in
this document use num_inputs == task.num_inputs and num_outputs ==
task.num_outputs; T is the per-trial timestep count.
The third value returned from Task.sample_trial() is a dictionary
with the following keys:
- phase_history: list of (name, start, end) tuples logging each phase's contribution to the timeline
- trial_state: copy of the user state set via trial_init (e.g. ground_truth, coherence)
- dt: the resolved time step (from brainstate.environ.get_dt())
- index: the trial index requested
To customize the metadata returned by batch_sample(..., return_meta=True),
override Task.get_trial_meta() in your subclass.
Variable-length trials#
A phase tree is variable-length when any phase advertises
is_variable = True. The built-in cases are VariableDuration
(its duration is read from ctx[ctx_key]), If, Switch,
and While. phase_tree_is_variable() walks the tree to
detect this at construction time; Task records the result as
task.is_variable_length and routes such trials through the packed
sample path.
Packed-mode semantics:
- task.max_trial_duration() returns a Python int upper bound, computed by walking each phase's max_steps. This value is safe to use as a static buffer dimension under brainstate.transform.jit and brainstate.transform.vmap2.
- sample_trial allocates X of shape (max_trial_duration, num_inputs), Y of shape (max_trial_duration,) (categorical) or (max_trial_duration, num_outputs) (vector), and mask of shape (max_trial_duration,).
- Each phase contributes step_count valid timesteps starting at ctx.t_cursor. The phase's writes are gated to the active region, trailing positions remain zero, and the mask is set to True only over the actual valid range.
- batch_sample(B, return_mask=True) returns (X, Y, mask) with the mask in the same time/batch layout as X/Y. Fixed-length tasks also support return_mask=True; their mask is all-True.
Example:
```python
import brainunit as u
import jax.numpy as jnp
from braintools.cogtask import (
    Task, Feature, Fixation, Stimulus, Response,
    VariableDuration, concat,
)

# Input/output channels
fix = Feature(1, 'fixation')
stim = Feature(2, 'stim')
choice = Feature(2, 'choice')

phases = concat([
    Fixation(50 * u.ms, inputs={'fixation': 1.0}),
    # Length is read per trial from ctx['delay'] (set in init below).
    VariableDuration(
        min_duration=200 * u.ms,
        max_duration=1500 * u.ms,
        ctx_key='delay',
        inputs={'fixation': 1.0},
    ),
    Stimulus(100 * u.ms, inputs={'stim': lambda c, f: jnp.ones(f.num)}),
    Response(50 * u.ms, outputs={'label': lambda c, f: c['gt']}),
])


def init(ctx):
    # Trial-level state consumed by the phases above
    ctx['delay'] = ctx.rng.uniform(200.0, 1500.0)
    ctx['gt'] = ctx.rng.choice(2)


task = Task(
    phases=phases,
    input_features=fix + stim,
    output_features=fix + choice,
    trial_init=init,
    seed=0,
)
assert task.is_variable_length

X, Y, mask = task.batch_sample(64, return_mask=True)
# X, Y, and mask all share the time dimension task.max_trial_duration()
```
Conditional compounds (If, Switch, While) work
in packed mode too. If uses jax.lax.cond so both branches
contribute shape-stable output even when the predicate is a tracer.
Switch and While use Python-level dispatch: the
selector must return a hashable key (not a tracer), and While’s
condition must return a Python bool. Branches that don’t run leave
the buffer slot at zero and don’t advance t_cursor.
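Downstream, the packed-mode mask is typically used to keep padded timesteps out of the loss. The sketch below is hypothetical NumPy (pack_trials and masked_mse are not cogtask functions); it assumes a batch-first (B, T, ...) layout, which the text above does not pin down, and uses a mean-squared-error loss chosen only for illustration.

```python
import numpy as np


def pack_trials(trials, max_steps):
    """Pad variable-length (T_i, num_inputs) trials into one batch plus a validity mask.

    Assumes a batch-first (B, T, ...) layout.
    """
    num_inputs = trials[0].shape[1]
    X = np.zeros((len(trials), max_steps, num_inputs), dtype=trials[0].dtype)
    mask = np.zeros((len(trials), max_steps), dtype=bool)
    for b, x in enumerate(trials):
        T = x.shape[0]
        if T > max_steps:
            raise ValueError(f"trial {b} has {T} steps > max_steps={max_steps}")
        X[b, :T] = x        # valid region
        mask[b, :T] = True  # trailing positions stay zero / False
    return X, mask


def masked_mse(pred, target, mask):
    """Mean squared error per valid timestep; padded steps contribute nothing."""
    per_step = ((pred - target) ** 2).sum(axis=-1)
    return (per_step * mask).sum() / mask.sum()


trials = [np.ones((3, 2)), np.ones((5, 2))]
X, mask = pack_trials(trials, max_steps=6)
pred = X.copy()
pred[mask] += 1.0     # error of 1 per unit on valid steps
pred[~mask] = 100.0   # garbage on padding is ignored by the mask
print(masked_mse(pred, X, mask))  # 2.0
```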
Output modes#
- 'categorical' (default): ctx.outputs has shape (T,) and holds integer labels. Phases set the 'label' key in their outputs= dict.
- 'vector': ctx.outputs has shape (T, num_outputs). Phases set each output feature by name (e.g. 'direction', 'fixation_out'). Use this for continuous-report tasks such as DelayDirectionReproduction.
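The two modes differ in the shape and dtype of the target buffer. The hypothetical NumPy sketch below fills both kinds of buffer for the same 8-step trial; the feature slices, the label value 2, and the loss pairing (cross-entropy for integer labels, squared error for vector targets) are illustrative assumptions rather than anything cogtask prescribes.

```python
import numpy as np

T = 8
response = slice(6, 8)  # hypothetical: the last two steps are the response phase

# 'categorical': one integer label per timestep, shape (T,)
y_cat = np.zeros((T,), dtype=np.int32)
y_cat[response] = 2     # e.g. outputs={'label': 2} in the response phase

# 'vector': one value per output unit, shape (T, num_outputs)
fixation_out, direction = slice(0, 1), slice(1, 3)  # hypothetical feature slices
y_vec = np.zeros((T, 3))
y_vec[:6, fixation_out] = 1.0                        # hold fixation until response
theta = np.pi / 3
y_vec[response, direction] = [np.cos(theta), np.sin(theta)]  # report a direction

# Assumed loss pairing: cross-entropy on labels, squared error on vectors.
logits = np.zeros((T, 3))                            # stand-in network output
log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
ce = -log_probs[np.arange(T), y_cat].mean()
mse = ((np.zeros((T, 3)) - y_vec) ** 2).mean()
print(y_cat.shape, y_vec.shape, ce, mse)
```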
Declarative phase shape conventions#
A value spec for inputs= can be a constant or a callable
f(ctx, feature) -> array. The encoded value is broadcast into the
phase’s slice of ctx.inputs according to its shape:
- scalar → constant for every timestep and feature unit
- 1-D, shape (feature.num,) → broadcast along the time axis
- 2-D, shape (duration, feature.num) → written directly
For outputs= the conventions depend on the output mode:
- Categorical (ctx.outputs.ndim == 1): use the 'label' key. Accepts a scalar (constant label for the phase) or a 1-D array of shape (duration,) (time-varying labels). Features other than 'label' are ignored.
- Vector (ctx.outputs.ndim == 2): write per output feature; accepts (feature.num,) (broadcast along time) or (duration, feature.num).
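The three input cases amount to NumPy broadcasting into a (duration, feature.num) block of ctx.inputs. write_input_spec below is a hypothetical helper, not part of cogtask; it only illustrates the shape rules listed above and rejects any other shape.

```python
import numpy as np


def write_input_spec(buffer, start, duration, feat_slice, value):
    """Hypothetical helper: broadcast a value spec into one phase/feature block."""
    region = buffer[start:start + duration, feat_slice]  # view, shape (duration, num)
    value = np.asarray(value, dtype=buffer.dtype)
    num = region.shape[1]
    if value.ndim == 0:
        region[...] = value              # scalar: every timestep and unit
    elif value.shape == (num,):
        region[...] = value[None, :]     # (feature.num,): repeated along time
    elif value.shape == (duration, num):
        region[...] = value              # (duration, feature.num): copied as-is
    else:
        raise ValueError(f"value shape {value.shape} does not fit {region.shape}")
    return buffer


inputs = np.zeros((10, 4))  # T = 10 steps, 4 input units
write_input_spec(inputs, 0, 3, slice(0, 1), 1.0)                           # fixation
write_input_spec(inputs, 3, 4, slice(1, 4), [0.1, 0.2, 0.3])               # static
write_input_spec(inputs, 7, 3, slice(1, 4), np.arange(9.0).reshape(3, 3))  # time-varying
```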
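The noise= field maps a feature name to a sigma Quantity in units of
ms**0.5. Noise is sampled fresh per phase and scaled by 1 / sqrt(dt), so
its per-timestep variance is sigma**2 / dt. That per-step variance grows as
dt shrinks; what stays invariant under changes of dt is the variance of the
noise integrated over a fixed time window, sigma**2 times the window length.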
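A quick Monte Carlo check of the 1 / sqrt(dt) scaling, using hypothetical NumPy helpers rather than cogtask's own noise code and assuming unit-variance Gaussian draws: the per-step variance changes with dt, while the variance of the noise integrated over a fixed 100 ms window stays near sigma**2 * 100 ms.

```python
import numpy as np


def phase_noise(rng, sigma, dt, n_steps, num):
    # Per-step standard deviation sigma / sqrt(dt); with sigma in ms**0.5 and
    # dt in ms the result is dimensionless.
    return rng.standard_normal((n_steps, num)) * (sigma / np.sqrt(dt))


def noise_stats(dt, sigma=1.0, window=100.0, n_trials=4000, seed=0):
    """Per-step variance and variance of the noise integrated over `window` ms."""
    rng = np.random.default_rng(seed)
    n_steps = int(round(window / dt))
    noise = phase_noise(rng, sigma, dt, n_steps, n_trials)  # (steps, trials)
    integral = noise.sum(axis=0) * dt                        # Euler time integral
    return noise.var(), integral.var()


for dt in (0.1, 1.0, 10.0):
    per_step_var, integral_var = noise_stats(dt)
    # per-step variance ~ sigma**2 / dt; integrated variance ~ sigma**2 * window
    print(dt, per_step_var, integral_var)
```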
Feature index management#
Features expose .i, a Python slice into the flat input/output vector.
When a feature is composed via a + b, its _start/_end are
shifted automatically, so phase encoders can write ctx.inputs[..., feat.i]
without bookkeeping. Composition is immutable: both operands are copied
before being shifted.
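The index bookkeeping can be modeled with a few frozen dataclasses. ToyFeature and ToyFeatureSet are hypothetical stand-ins for Feature and FeatureSet (a single start field plays the role of _start/_end); the point is that + returns re-indexed copies, so the operands keep their own slices.

```python
from dataclasses import dataclass, replace
from typing import Tuple


class _Composable:
    def __add__(self, other):
        # Concatenate, then re-index copies: operands are never mutated.
        items = _members(self) + _members(other)
        shifted, offset = [], 0
        for f in items:
            shifted.append(replace(f, start=offset))
            offset += f.num
        return ToyFeatureSet(tuple(shifted))


@dataclass(frozen=True)
class ToyFeature(_Composable):
    num: int
    name: str
    start: int = 0

    @property
    def i(self) -> slice:
        # Slice into the flat input/output vector.
        return slice(self.start, self.start + self.num)


@dataclass(frozen=True)
class ToyFeatureSet(_Composable):
    features: Tuple[ToyFeature, ...]

    @property
    def num(self) -> int:
        return sum(f.num for f in self.features)

    def __getitem__(self, name: str) -> ToyFeature:
        return next(f for f in self.features if f.name == name)


def _members(x):
    return x.features if isinstance(x, ToyFeatureSet) else (x,)


fix = ToyFeature(1, 'fixation')
stim = ToyFeature(8, 'stimulus')
inputs = fix + stim
print(inputs['stimulus'].i)  # slice(1, 9, None)
print(stim.i)                # slice(0, 8, None): the original is unchanged
```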
Reproducibility#
A Task constructed with seed=N derives each trial’s key as
jax.random.fold_in(jax.random.PRNGKey(N), trial_index). This makes
task.sample(i) deterministic and task.batch_sample(B, start_index=k)
reproducible and non-overlapping across calls. If seed is omitted, trials
draw fresh randomness from brainstate’s default RNG.
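Because each trial's randomness depends only on (seed, trial index), consecutive calls with an advancing start_index tile one long, reproducible sequence. The sketch below demonstrates that property with NumPy, seeding numpy.random.default_rng with [seed, index] as a stand-in for jax.random.fold_in; trial_stream and toy_batch_sample are hypothetical, and their random streams will not match cogtask's actual JAX keys.

```python
import numpy as np


def trial_stream(seed, index):
    # NumPy analog of jax.random.fold_in(jax.random.PRNGKey(seed), index):
    # an independent, reproducible stream per (seed, trial index) pair.
    return np.random.default_rng([seed, index])


def toy_batch_sample(seed, batch_size, start_index=0):
    """Hypothetical: draw 4 numbers per trial, one stream per trial index."""
    return np.stack([trial_stream(seed, start_index + b).standard_normal(4)
                     for b in range(batch_size)])


first = toy_batch_sample(0, 8, start_index=0)
second = toy_batch_sample(0, 8, start_index=8)
whole = toy_batch_sample(0, 16)
print(np.array_equal(np.concatenate([first, second]), whole))  # True
```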
Time step#
All durations are resolved against the currently active time step,
brainstate.environ.get_dt(). The same task can be re-sampled at a finer
or coarser dt simply by wrapping it in a brainstate.environ.context;
see Tutorial 1: Quickstart with braintools.cogtask for a worked example.
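Since durations are given in physical time, changing dt changes only the number of timesteps per phase. to_steps is a hypothetical helper, and rounding to the nearest step is an assumption; cogtask's own conversion may round differently when a duration is not a multiple of dt. The durations below are those of the Quick Start custom task.

```python
def to_steps(duration_ms, dt_ms):
    """Hypothetical conversion; assumes rounding to the nearest whole step."""
    return int(round(duration_ms / dt_ms))


durations_ms = [100, 500, 500, 100]  # fixation, stimulus, delay, response
for dt in (1.0, 0.5, 10.0):
    steps = [to_steps(d, dt) for d in durations_ms]
    print(dt, steps, sum(steps))
```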