Tutorial 2: Building custom tasks
This tutorial covers the building blocks you assemble into a custom
braintools.cogtask.Task:
- Features: Feature/FeatureSet, composition, and the .i slice.
- Declarative phases: Fixation, Stimulus, Delay, Response, …
- Encoders: value specs like one_hot, circular, von_mises, cos_sin, gaussian, scalar, identity, ctx_value.
- Labels: label, match_label, comparison_label.
- Putting it together: both the instance-based and the class-based patterns.
- Vector outputs: using output_mode='vector' for continuous-report tasks.
- Branching with If/Switch/While.
- Custom encoders.
import brainunit as u
import jax.numpy as jnp
import brainstate
brainstate.environ.set(dt=1.0 * u.ms)
from braintools.cogtask import (
    Task, Context, concat,
    Phase, Sequence, Repeat, Parallel,
    Fixation, Stimulus, Delay, Response, Cue, Blank,
    Sample, Test, Recall, Match, Comparison,
    If, Switch, While,
    Feature, FeatureSet, CircleFeature,
    one_hot, circular, von_mises, cos_sin, gaussian, scalar, identity, ctx_value,
    label, match_label, comparison_label,
    TruncExp, UniformDuration,
)
1. Features
A Feature declares one logical input or output channel: a name plus a
fixed dimensionality. Compose features with + (concatenate, immutable),
- (remove by name), | (alias for +), and *n (named repetition).
fix = Feature(1, 'fixation')
stim = Feature(8, 'stimulus')
choice = Feature(2, 'choice')
inputs = fix + stim
outputs = fix + choice
print('inputs:', inputs, 'num =', inputs.num)
print('stim slice in inputs:', inputs['stimulus'].i)
print('choice slice in outputs:', outputs['choice'].i)
inputs: FeatureSet(names=['fixation', 'stimulus'], nums=[1, 8]) num = 9
stim slice in inputs: slice(1, 9, None)
choice slice in outputs: slice(1, 3, None)
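The .i slices follow from laying features out in concatenation order: each feature starts where the previous one ended. The plain-Python sketch below reproduces that bookkeeping. ToyFeatureSet and its slice_of method are hypothetical stand-ins written for this illustration, not braintools classes; they only reproduce the offsets printed above, and the string operand to "-" is an assumption about how "remove by name" might look.

```python
class ToyFeatureSet:
    """Hypothetical stand-in for FeatureSet: an ordered list of (name, width) pairs."""

    def __init__(self, *pairs):
        self.pairs = list(pairs)

    def __add__(self, other):
        # Concatenation builds a new object; neither operand is modified.
        return ToyFeatureSet(*self.pairs, *other.pairs)

    def __sub__(self, name):
        # Remove a feature by name (later features shift left).
        return ToyFeatureSet(*[(n, k) for n, k in self.pairs if n != name])

    @property
    def num(self):
        return sum(k for _, k in self.pairs)

    def slice_of(self, name):
        # Each feature occupies [offset, offset + width) of the full vector.
        offset = 0
        for n, k in self.pairs:
            if n == name:
                return slice(offset, offset + k)
            offset += k
        raise KeyError(name)


fix = ToyFeatureSet(('fixation', 1))
stim = ToyFeatureSet(('stimulus', 8))
choice = ToyFeatureSet(('choice', 2))

inputs = fix + stim
outputs = fix + choice
print(inputs.num, inputs.slice_of('stimulus'), outputs.slice_of('choice'))
# 9 slice(1, 9, None) slice(1, 3, None)
```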
2. Declarative phases
A DeclarativePhase describes its behavior with three dictionaries:
- inputs={feature_name: value_spec, ...}: fills the input slice for one feature. A value spec is either a constant (broadcast) or a callable f(ctx, feature) -> array.
- outputs={...}: same shape conventions; in categorical mode use the reserved key 'label', in vector mode write one entry per output feature.
- noise={feature_name: sigma}: additive Gaussian noise scaled by 1/sqrt(dt), so that the noise a model accumulates by integrating its input over a fixed time window has a variance that does not depend on dt.
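To see why the 1/sqrt(dt) factor matters, consider a model that Euler-integrates its input, x += dt * input. The NumPy sketch below assumes the per-step noise is drawn as sigma / sqrt(dt) times a standard normal sample; the exact sampling inside braintools may differ. With the scaling, the variance collected over a fixed 1000 ms window stays near window * sigma**2 for both dt = 1 ms and dt = 10 ms; without it, the variance grows roughly in proportion to dt.

```python
import numpy as np

def accumulated_noise_var(sigma, dt, window_ms, scale=True, n_trials=5000, seed=0):
    """Variance of the noise an Euler integrator (x += dt * noise) collects over window_ms."""
    rng = np.random.default_rng(seed)
    n_steps = int(round(window_ms / dt))
    std = sigma / np.sqrt(dt) if scale else sigma   # per-step noise standard deviation
    noise = std * rng.standard_normal((n_trials, n_steps))
    x = (dt * noise).sum(axis=1)                   # integrated noise at the end of the window
    return float(x.var())

sigma, window = 0.5, 1000.0
for dt in (1.0, 10.0):
    scaled = accumulated_noise_var(sigma, dt, window)
    unscaled = accumulated_noise_var(sigma, dt, window, scale=False)
    # scaled stays near window * sigma**2 = 250; unscaled is roughly 250 * dt
    print(f'dt={dt:>4} ms  scaled={scaled:7.1f}  unscaled={unscaled:7.1f}')
```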
Compose phases sequentially with >> or concat([...]), repeat with *n,
and run two phases simultaneously with | (Parallel).
phases = (
    Fixation(100 * u.ms, inputs={'fixation': 1.0}, outputs={'label': 0})
    >> Stimulus(500 * u.ms,
                inputs={'fixation': 1.0,
                        'stimulus': circular('direction', 'coherence')},
                outputs={'label': 0})
    >> Delay(300 * u.ms, inputs={'fixation': 1.0}, outputs={'label': 0})
    >> Response(100 * u.ms,
                inputs={'fixation': 0.0},
                # choices are labels 1 and 2; label 0 is used while fixating
                outputs={'label': lambda ctx, f: ctx['ground_truth'] + 1})
)
phases
Sequence(Fixation >> Stimulus >> Delay >> Response)
3. Encoders
Encoders are helper factories that return callable value specs. The most commonly used ones:
| encoder | trial state → activation |
|---|---|
| one_hot | discrete class index → one-hot vector |
| circular | direction (radians or index) → cosine tuning over preferred directions |
| von_mises | direction → von-Mises tuning curve in … |
| cos_sin | discrete direction → repeated … |
| gaussian | scalar value → Gaussian bumps over evenly spaced centers |
| scalar | scalar → broadcast |
| identity | array stored in … |
| ctx_value | raw value lookup; useful for time-varying inputs computed elsewhere |
Encoders read ctx[key] — typically a value set in trial_init.
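To make the factory pattern concrete, here is a NumPy sketch of three encoders written as closures over a context key, each returning an f(ctx, feature) callable. The tuning formulas (a von Mises bump exp(kappa * (cos(theta - pref) - 1)) and Gaussian bumps of fixed width) and the parameters kappa, low, high and width are assumptions for illustration; the built-in encoders may parameterize or normalize differently. A plain dict stands in for the trial context and a SimpleNamespace for the feature.

```python
import numpy as np
from types import SimpleNamespace

def one_hot_sketch(key):
    def encode(ctx, feature):
        vec = np.zeros(feature.num)
        vec[int(ctx[key])] = 1.0                     # class index -> one-hot vector
        return vec
    return encode

def von_mises_sketch(key, num_dirs, kappa=2.0):
    prefs = 2 * np.pi * np.arange(num_dirs) / num_dirs   # evenly spaced preferred directions
    def encode(ctx, feature):
        theta = 2 * np.pi * ctx[key] / num_dirs          # direction index -> angle
        # Peaks at 1.0 for the unit whose preferred direction equals theta.
        return np.exp(kappa * (np.cos(theta - prefs) - 1.0))
    return encode

def gaussian_sketch(key, low=0.0, high=1.0, width=0.1):
    def encode(ctx, feature):
        centers = np.linspace(low, high, feature.num)    # evenly spaced bump centers
        return np.exp(-0.5 * ((ctx[key] - centers) / width) ** 2)
    return encode

ctx = {'class_idx': 2, 'sample_idx': 3, 'value': 0.5}   # values trial_init might write
print(one_hot_sketch('class_idx')(ctx, SimpleNamespace(num=4)))   # [0. 0. 1. 0.]
r = von_mises_sketch('sample_idx', num_dirs=8)(ctx, SimpleNamespace(num=8))
print(r.argmax(), r.round(3))
```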
4. Label helpers
In categorical output mode, every phase fills ctx.outputs[start:end] with
an integer label. The label helpers are convenient builders for these labels:
- label(value): a static int, a lookup from the context (str), or an arbitrary callable(ctx) -> int.
- match_label(match_key): emits 1 on match trials, 2 otherwise.
- comparison_label(key): emits 1 if greater, 2 otherwise.
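The sketch below mimics these semantics in plain Python and NumPy to show how one label per phase ends up filling consecutive step ranges. resolve_label, match_label_sketch and fill_labels are hypothetical helpers written for this illustration; the dispatch order (int, then str, then callable) is an assumption, not the braintools implementation.

```python
import numpy as np

def resolve_label(value, ctx):
    """Hypothetical: resolve the three forms label(value) is described as accepting."""
    if isinstance(value, int):
        return value                     # static class index
    if isinstance(value, str):
        return int(ctx[value])           # look the label up in the trial context
    if callable(value):
        return int(value(ctx))           # arbitrary function of the context
    raise TypeError(f'unsupported label spec: {value!r}')

def match_label_sketch(match_key):
    # 1 on match trials, 2 otherwise (0 stays free for fixation)
    return lambda ctx: 1 if ctx[match_key] else 2

def fill_labels(phase_specs, ctx):
    """phase_specs: list of (num_steps, label_spec); one label per phase, phases back to back."""
    y = np.zeros(sum(n for n, _ in phase_specs), dtype=np.int32)
    start = 0
    for n, spec in phase_specs:
        y[start:start + n] = resolve_label(spec, ctx)   # same label for the whole phase
        start += n
    return y

ctx = {'is_match': False}
y = fill_labels([(100, 0), (500, 0), (300, 0), (100, match_label_sketch('is_match'))], ctx)
print(y.shape, y[0], y[-1])
```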
5. Putting it together: instance-based
The smallest end-to-end recipe: define features, define phases, define
trial_init, hand them to Task.
def trial_init(ctx):
    ctx['ground_truth'] = ctx.rng.choice(2)  # jax scalar, OK under JIT
    ctx['coherence'] = 51.2
    ctx['direction'] = ctx.rng.uniform(0.0, 2 * jnp.pi)
task = Task(
    phases=phases,
    input_features=inputs,
    output_features=outputs,
    trial_init=trial_init,
    seed=0,
)
X, Y, info = task.sample_trial(0)
print('X.shape =', X.shape, 'Y.shape =', Y.shape)
info['trial_state']
X.shape = (1000, 9) Y.shape = (1000,)
{'trial_index': 0,
'ground_truth': Array(0, dtype=int32),
'coherence': 51.2,
'direction': Array(0.5190151, dtype=float32),
'output_mode': 'categorical'}
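Both shapes follow from the phase durations and feature widths. At dt = 1 ms the four phases of phases span 100, 500, 300 and 100 steps, placed back to back, and the inputs are 1 + 8 = 9 channels wide; in categorical mode Y holds one integer label per step. The arithmetic in NumPy (phase_bounds is a helper written for this illustration, not a braintools function):

```python
import numpy as np

def phase_bounds(durations_ms, dt_ms):
    """Return the [start, end) step range of each phase laid out back to back."""
    steps = [int(round(d / dt_ms)) for d in durations_ms]
    edges = np.concatenate([[0], np.cumsum(steps)])
    return [(int(a), int(b)) for a, b in zip(edges[:-1], edges[1:])]

names = ['Fixation', 'Stimulus', 'Delay', 'Response']
bounds = phase_bounds([100, 500, 300, 100], dt_ms=1.0)
for name, (start, end) in zip(names, bounds):
    print(f'{name:9s} steps {start:4d}:{end:4d}')

num_steps = bounds[-1][1]
num_inputs = 1 + 8                  # fixation + stimulus
print((num_steps, num_inputs))      # (1000, 9), matching X.shape above
```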
6. Putting it together: class-based
For parameterized, reusable tasks, subclass Task and override
define_features, define_phases, and trial_init. Pre-built tasks
(PerceptualDecisionMaking, DelayMatchSample, …) all follow this
pattern.
class MyDMS(Task):
    t_fixation = 200 * u.ms
    t_sample = 400 * u.ms
    t_delay = 600 * u.ms
    t_response = 200 * u.ms
    num_stimuli = 8
    def define_features(self):
        fix = Feature(1, 'fixation')
        stim = Feature(self.num_stimuli, 'stimulus')
        choice = Feature(2, 'choice')  # match / non-match
        return fix + stim, fix + choice
    def define_phases(self):
        return concat([
            Fixation(self.t_fixation,
                     inputs={'fixation': 1.0},
                     outputs={'label': 0}),
            Stimulus(self.t_sample,
                     inputs={'fixation': 1.0,
                             'stimulus': von_mises('sample_idx',
                                                   num_dirs=self.num_stimuli)},
                     outputs={'label': 0},
                     name='sample'),
            Delay(self.t_delay,
                  inputs={'fixation': 1.0},
                  outputs={'label': 0}),
            Response(self.t_response,
                     inputs={'fixation': 0.0,
                             'stimulus': von_mises('test_idx',
                                                   num_dirs=self.num_stimuli)},
                     outputs={'label': label(lambda ctx: jnp.where(ctx['is_match'], 1, 2))},
                     name='response'),
        ])
    def trial_init(self, ctx):
        # All values stay as jax scalars / arrays so trial_init works under JIT+vmap.
        ctx['sample_idx'] = ctx.rng.choice(self.num_stimuli)
        ctx['is_match'] = ctx.rng.uniform() < 0.5
        non_match_idx = (ctx['sample_idx']
                         + 1
                         + ctx.rng.choice(self.num_stimuli - 1)) % self.num_stimuli
        ctx['test_idx'] = jnp.where(ctx['is_match'], ctx['sample_idx'], non_match_idx)
task = MyDMS(num_stimuli=16, seed=42)
X, Y = task.batch_sample(64)
print('X.shape =', X.shape, 'Y.shape =', Y.shape)
X.shape = (1400, 64, 17) Y.shape = (1400, 64)
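In trial_init, the non-match index adds 1 plus an offset k in {0, ..., n - 2} to the sample index and wraps modulo n. Every offset lands on a different index and none lands on the sample itself, so a uniformly drawn offset gives a uniformly drawn non-match stimulus. An exhaustive check for n = 16 in plain Python, assuming ctx.rng.choice(m) draws from 0, ..., m - 1 as NumPy-style choice(int) does (non_match_candidates is a helper written for this check):

```python
def non_match_candidates(sample_idx, n):
    # Every value (sample_idx + 1 + k) % n for the offsets k that choice(n - 1) can return.
    return [(sample_idx + 1 + k) % n for k in range(n - 1)]

n = 16
for s in range(n):
    cands = non_match_candidates(s, n)
    assert s not in cands                                   # never equals the sample
    assert sorted(cands) == [i for i in range(n) if i != s] # each other index exactly once
print('non-match trick verified for n =', n)
```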
7. Vector outputs: continuous-report tasks
Pass output_mode='vector' when the target is a population code rather
than a class index. Each phase then writes into named output features.
Features that should be silent during a phase are written explicitly with a
zero spec (a constant 0.0 or a small callable that returns zeros).
fix_in = Feature(1, 'fixation')
stim_in = Feature(16, 'stimulus')
fix_out = Feature(1, 'fixation_out')
dir_out = Feature(16, 'direction_out')
def silent(ctx, feat):
    return jnp.zeros((feat.num,))
task = Task(
    phases=concat([
        Fixation(200 * u.ms,
                 inputs={'fixation': 1.0},
                 outputs={'fixation_out': 1.0, 'direction_out': silent}),
        Stimulus(500 * u.ms,
                 inputs={'fixation': 1.0,
                         'stimulus': von_mises('sample_idx', num_dirs=16)},
                 outputs={'fixation_out': 1.0, 'direction_out': silent}),
        Delay(800 * u.ms,
              inputs={'fixation': 1.0},
              outputs={'fixation_out': 1.0, 'direction_out': silent}),
        Response(400 * u.ms,
                 inputs={'fixation': 0.0},
                 outputs={'fixation_out': 0.0,
                          'direction_out': von_mises('sample_idx', num_dirs=16)}),
    ]),
    input_features=fix_in + stim_in,
    output_features=fix_out + dir_out,
    output_mode='vector',
    trial_init=lambda ctx: ctx.update(sample_idx=ctx.rng.choice(16)),
    seed=0,
)
X, Y = task.batch_sample(8)
print('X.shape =', X.shape, 'Y.shape =', Y.shape, '(T, B, 1+16)')
X.shape = (1900, 8, 17) Y.shape = (1900, 8, 17) (T, B, 1+16)
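These shapes are time-major, (T, B, features): T = 200 + 500 + 800 + 400 = 1900 steps at dt = 1 ms and B = 8. In vector mode each row of Y is one step's target across all output units, presumably laid out in the same concatenation order as the feature slices in section 1 (fixation_out first, then direction_out). A NumPy sketch of one trial's target under this phase plan; a one-hot vector stands in for the von_mises bump (an assumption made to keep it short), and sample_idx = 5 is arbitrary:

```python
import numpy as np

dt = 1.0                                      # ms, as set via brainstate.environ
durations = {'fixation': 200, 'stimulus': 500, 'delay': 800, 'response': 400}
num_dirs, sample_idx = 16, 5

T = int(sum(durations.values()) / dt)         # 1900 steps
target = np.zeros((T, 1 + num_dirs))          # columns: [fixation_out | direction_out]
fix_cols, dir_cols = slice(0, 1), slice(1, 1 + num_dirs)

resp_start = T - int(durations['response'] / dt)
target[:resp_start, fix_cols] = 1.0           # fixation_out = 1.0 until the response phase
# direction_out stays zero before the response: the role of the silent spec
bump = np.zeros(num_dirs)
bump[sample_idx] = 1.0                        # stand-in for von_mises('sample_idx', num_dirs=16)
target[resp_start:, dir_cols] = bump          # fixation_out = 0.0, report the direction

batch = np.stack([target] * 8, axis=1)        # stack trials along axis 1: time-major like Y
print(target.shape, batch.shape)              # (1900, 17) (1900, 8, 17)
```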
8. Branching with If / Switch
If selects between then / else_ based on a predicate over the trial
context; Switch dispatches over a dict of cases keyed by the selector’s
output. Both must be resolvable from values written in trial_init,
because the framework needs to compute a deterministic total duration
before the actual encoding pass.
branching = (
    Sample(400 * u.ms,
           inputs={'stimulus': von_mises('sample_idx', num_dirs=8)},
           outputs={'label': 0})
    >> Delay(600 * u.ms,
             inputs={'fixation': 1.0},
             outputs={'label': 0})
    >> Test(400 * u.ms,
            inputs={'stimulus': von_mises('test_idx', num_dirs=8)},
            outputs={'label': 0})
    >> If(
        condition=lambda ctx: bool(ctx['is_match']),
        then=Response(200 * u.ms,
                      outputs={'label': 1}),
        else_=Response(200 * u.ms,
                       outputs={'label': 2}),
    )
)
branching
Sequence(Sample >> Delay >> Test >> If)
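Why must the branch be decidable from trial_init alone? A toy two-pass sketch makes the ordering visible: pass one walks the phase tree and resolves every If/Switch against the context to get the total step count, and only then can a buffer of that length be allocated and filled. The tuple-based tree and the function resolve are hypothetical and do not mirror braintools internals; durations are in steps at dt = 1 ms.

```python
# Toy phase tree (hypothetical, not braintools internals):
#   ('phase', name, steps) | ('seq', [children]) | ('if', predicate, then, else_)
#   | ('switch', selector, {case: child})
def resolve(node, ctx):
    """Pass 1: flatten the tree into the (name, steps) list this trial will run."""
    kind = node[0]
    if kind == 'phase':
        return [(node[1], node[2])]
    if kind == 'seq':
        return [leaf for child in node[1] for leaf in resolve(child, ctx)]
    if kind == 'if':
        _, predicate, then, else_ = node
        return resolve(then if predicate(ctx) else else_, ctx)
    if kind == 'switch':
        _, selector, cases = node
        return resolve(cases[selector(ctx)], ctx)
    raise ValueError(f'unknown node kind: {kind!r}')


# Same structure as branching above.
tree = ('seq', [
    ('phase', 'sample', 400),
    ('phase', 'delay', 600),
    ('phase', 'test', 400),
    ('if', lambda ctx: bool(ctx['is_match']),
     ('phase', 'respond_match', 200),
     ('phase', 'respond_nonmatch', 200)),
])

ctx = {'is_match': False}                  # written by trial_init, before any encoding
plan = resolve(tree, ctx)                  # pass 1: branches decided, durations known
total = sum(steps for _, steps in plan)
labels = [0] * total                       # pass 2 can now allocate and fill exactly `total` steps
print(plan[-1], total)                     # ('respond_nonmatch', 200) 1600
```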
9. Looping with While
While(condition, body, max_iterations) runs body repeatedly until condition
returns False, for at most max_iterations iterations. The duration computation
uses max_iterations as the upper bound, so the buffer is large enough for the
worst case; see Tutorial 3 for the JIT/vmap implications.
loop_demo = While(
    condition=lambda ctx: ctx.get('evidence', 0.0) < ctx['threshold'],
    body=Sample(50 * u.ms,
                inputs={'stimulus': lambda ctx, f: ctx['pulse']},
                outputs={'label': 0}),
    max_iterations=20,
)
loop_demo
While(body=Sample, max=20)
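The worst-case sizing is plain arithmetic: 20 iterations of a 50 ms body reserve 1000 ms, that is 1000 steps at dt = 1 ms, no matter when the condition turns False. The plain-Python simulation below also exposes a caveat: as shown, nothing in loop_demo writes ctx['evidence'], so something else must update it, otherwise the loop always runs to max_iterations. The update rule here (each pass adds pulse) and the 0/1 step mask are assumptions for illustration; the mask only loosely mirrors the packed-buffer + mask design that Tutorial 3 covers.

```python
def run_while(ctx, body_steps, max_iterations, pulse):
    """Simulate While(condition, body, max_iterations) on a dict context."""
    capacity = body_steps * max_iterations       # buffer sized for the worst case
    iterations = 0
    while iterations < max_iterations and ctx.get('evidence', 0.0) < ctx['threshold']:
        # Hypothetical update rule: each pass of the body adds one pulse of evidence.
        ctx['evidence'] = ctx.get('evidence', 0.0) + pulse
        iterations += 1
    used = iterations * body_steps
    mask = [1] * used + [0] * (capacity - used)  # 1 = real step, 0 = padding
    return iterations, capacity, mask

iters, capacity, mask = run_while({'threshold': 1.0}, body_steps=50, max_iterations=20, pulse=0.3)
print(iters, capacity, sum(mask))   # 4 1000 200
```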
10. Custom encoders
The built-in encoders cover most use cases, but any callable with
signature f(ctx, feature) -> jnp.ndarray is a valid value spec. The
return value may be:
- a scalar (broadcast to all timesteps and units),
- a 1-D array of shape (feature.num,) (broadcast along time), or
- a 2-D array of shape (duration, feature.num) (written directly).
def ramping_pulse(start_val, stop_val):
    def encode(ctx, feature):
        duration = ctx.phase_duration  # number of time steps in the current phase
        # A 1-D ramp over time, copied across every unit of the feature.
        ramp = jnp.linspace(start_val, stop_val, duration)
        return jnp.broadcast_to(ramp[:, None], (duration, feature.num))
    encode.__name__ = f'ramping_pulse({start_val}, {stop_val})'
    return encode
# Plug it in like any other encoder:
ramp_phase = Stimulus(
    500 * u.ms,
    inputs={'stimulus': ramping_pulse(0.0, 1.0)},
    outputs={'label': 0},
)
ramp_phase
DeclarativePhase('Stimulus', inputs=['stimulus'])
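All three accepted return shapes reduce to a (duration, feature.num) block by ordinary broadcasting. The NumPy sketch below shows that normalization step; to_block is a hypothetical helper, and we are assuming the framework does something equivalent before writing the phase's slice.

```python
import numpy as np

def to_block(value, duration, num):
    """Broadcast a scalar, a (num,) vector, or a (duration, num) array to (duration, num)."""
    arr = np.asarray(value, dtype=float)
    if arr.ndim > 2:
        raise ValueError(f'expected at most 2-D, got shape {arr.shape}')
    # np.broadcast_to raises ValueError if a 1-D array is not (num,)
    # or a 2-D array is not (duration, num).
    return np.broadcast_to(arr, (duration, num))

duration, num = 500, 8
scalar_block = to_block(1.0, duration, num)                      # same value everywhere
vector_block = to_block(np.arange(num), duration, num)           # same vector at every step
ramp = np.linspace(0.0, 1.0, duration)[:, None] * np.ones(num)   # like ramping_pulse(0.0, 1.0)
ramp_block = to_block(ramp, duration, num)                       # written as-is
print(scalar_block.shape, vector_block.shape, ramp_block.shape)  # (500, 8) (500, 8) (500, 8)
```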
Where to next
- Tutorial 3: Variable-length trial sequences covers VariableDuration, the packed-buffer + mask design, and how If/Switch/While participate under batch_sample.
- The API reference (braintools.cogtask under API Reference) lists every pre-built task, encoder, phase, and utility with full parameter documentation.