Tutorial 2: Building custom tasks

This tutorial covers the building blocks you assemble into a custom braintools.cogtask.Task:

  1. Features: Feature / FeatureSet, composition, and the .i slice.

  2. Declarative phases: Fixation, Stimulus, Delay, Response, …

  3. Encoders: value specs like one_hot, circular, von_mises, cos_sin, gaussian, scalar, identity, ctx_value.

  4. Labels: label, match_label, comparison_label.

  5. Putting it together: both the instance-based and the class-based patterns.

  6. Vector outputs: using output_mode='vector' for continuous-report tasks.

  7. Branching with If / Switch / While.

  8. Custom encoders.


import brainunit as u
import jax.numpy as jnp
import brainstate

brainstate.environ.set(dt=1.0 * u.ms)

from braintools.cogtask import (
    Task, Context, concat,
    Phase, Sequence, Repeat, Parallel,
    Fixation, Stimulus, Delay, Response, Cue, Blank,
    Sample, Test, Recall, Match, Comparison,
    If, Switch, While,
    Feature, FeatureSet, CircleFeature,
    one_hot, circular, von_mises, cos_sin, gaussian, scalar, identity, ctx_value,
    label, match_label, comparison_label,
    TruncExp, UniformDuration,
)

1. Features

A Feature declares one logical input or output channel: a name plus a fixed dimensionality. Features compose with the following operators:

  • + concatenates them and returns a new FeatureSet, leaving the operands unchanged.

  • - removes a feature by name.

  • | is an alias for +.

  • *n is named repetition.

fix    = Feature(1, 'fixation')
stim   = Feature(8, 'stimulus')
choice = Feature(2, 'choice')

inputs  = fix + stim
outputs = fix + choice
print('inputs:', inputs, 'num =', inputs.num)
print('stim slice in inputs:', inputs['stimulus'].i)
print('choice slice in outputs:', outputs['choice'].i)
inputs: FeatureSet(names=['fixation', 'stimulus'], nums=[1, 8]) num = 9
stim slice in inputs: slice(1, 9, None)
choice slice in outputs: slice(1, 3, None)
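The offsets in these slices are running sums of the member sizes. The helper below sketches that bookkeeping without depending on any library. It is hypothetical: feature_slices is not a braintools function. The block then uses a slice in the usual way, to pull one feature's channels out of a time-major input array.

```python
from itertools import accumulate

import numpy as np


def feature_slices(names, nums):
    """Hypothetical helper: map each feature name to its slice in the concatenated vector."""
    starts = [0, *accumulate(nums)]  # cumulative offsets: [0, n0, n0 + n1, ...]
    return {name: slice(starts[k], starts[k + 1]) for k, name in enumerate(names)}


slices = feature_slices(['fixation', 'stimulus'], [1, 8])
print(slices['stimulus'])  # slice(1, 9, None), the same as inputs['stimulus'].i above

# A slice selects one feature's channels from an input array shaped (time, 9).
X = np.zeros((1000, 9))
print(X[:, slices['stimulus']].shape)  # (1000, 8)
```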

2. Declarative phases

A DeclarativePhase describes its behavior with three dictionaries:

  • inputs={feature_name: value_spec, ...} — each entry fills the input slice of one feature. A value spec is either a constant (broadcast) or a callable f(ctx, feature) -> array.

  • outputs={...} — the same value-spec conventions, applied to the targets. In categorical mode, use the reserved key 'label'. In vector mode, write one entry per output feature.

  • noise={feature_name: sigma} — additive Gaussian noise with per-step standard deviation sigma/sqrt(dt). The 1/sqrt(dt) factor keeps the variance of the noise, once integrated over a fixed time window, independent of dt.
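The 1/sqrt(dt) factor works as follows. Suppose a network integrates its input with an Euler step of size dt, so each noise sample enters as noise * dt. Each step then contributes variance (sigma^2 / dt) * dt^2 = sigma^2 * dt. Over the T / dt steps of a window of length T, the total is sigma^2 * T, whatever the value of dt.

The NumPy check below is a standalone sketch of that argument, not braintools code. The Euler integrator, sigma, and window length are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma, window = 0.5, 100.0    # noise scale and integration window in ms (illustrative)
n_trials = 10_000

variances = {}
for dt in (0.5, 1.0, 2.0):    # step size in ms
    n_steps = int(window / dt)
    # Per-step noise with standard deviation sigma / sqrt(dt) ...
    noise = rng.normal(0.0, sigma / np.sqrt(dt), size=(n_trials, n_steps))
    # ... integrated with an Euler step of size dt over the whole window.
    variances[dt] = (noise * dt).sum(axis=1).var()

print(variances)  # all three values are close to sigma**2 * window = 25.0
```

Without the 1/sqrt(dt) factor, the same accumulated variance would be sigma^2 * window * dt, so it would change whenever dt changes.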

Compose phases sequentially with >> or concat([...]), repeat with *n, and run two phases simultaneously with | (Parallel).

phases = (
    Fixation(100 * u.ms, inputs={'fixation': 1.0}, outputs={'label': 0})
    >> Stimulus(500 * u.ms,
                inputs={'fixation': 1.0,
                        'stimulus': circular('direction', 'coherence')},
                outputs={'label': 0})
    >> Delay(300 * u.ms, inputs={'fixation': 1.0}, outputs={'label': 0})
    >> Response(100 * u.ms,
                inputs={'fixation': 0.0},
                outputs={'label': lambda ctx, f: ctx['ground_truth'] + 1})
)
phases
Sequence(Fixation >> Stimulus >> Delay >> Response)

3. Encoders

Encoders are helper factories that return callable value specs. The most commonly used ones:

| encoder | trial state → activation |
|---|---|
| one_hot(key) | discrete class index → one-hot vector |
| circular(key, coherence_key) | direction (radians or index) → cosine tuning over preferred directions |
| von_mises(key, …) | direction → von Mises tuning curve in [base_value, 1] |
| cos_sin(key, num_dirs, …) | discrete direction → repeated [cos θ, sin θ] features |
| gaussian(key, sigma=…) | scalar value → Gaussian bumps over evenly spaced centers |
| scalar(key, scale, offset) | scalar → value * scale + offset, broadcast to all units |
| identity(key) | array stored in ctx[key] → written through unchanged |
| ctx_value(key, default=…) | raw value lookup; useful for time-varying inputs computed elsewhere |

Encoders read ctx[key] — typically a value set in trial_init.
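An encoder is just a factory that returns f(ctx, feature), so its behavior can be sketched in a few lines. The two functions below are hypothetical re-implementations, written only to illustrate the idea. Several details are assumptions and need not match braintools' one_hot or von_mises exactly:

  • the von Mises form exp(kappa * (cos(theta - theta_pref) - 1)),
  • the kappa default,
  • the index-to-angle mapping.

In the sketch, a plain dict stands in for the trial context and FakeFeature stands in for a Feature.

```python
import numpy as np


class FakeFeature:
    """Stand-in for a Feature: only the .num attribute is used here."""
    def __init__(self, num):
        self.num = num


def one_hot_sketch(key):
    def encode(ctx, feature):
        vec = np.zeros(feature.num)
        vec[int(ctx[key])] = 1.0                              # class index read from the context
        return vec
    return encode


def von_mises_sketch(key, num_dirs, kappa=2.0, base_value=0.0):
    def encode(ctx, feature):
        prefs = 2 * np.pi * np.arange(num_dirs) / num_dirs    # preferred directions
        theta = 2 * np.pi * ctx[key] / num_dirs               # index -> angle (assumed mapping)
        bump = np.exp(kappa * (np.cos(theta - prefs) - 1.0))  # equals 1 where theta == pref
        return base_value + (1.0 - base_value) * bump         # rescaled into [base_value, 1]
    return encode


ctx = {'class_idx': 1, 'sample_idx': 3}
print(one_hot_sketch('class_idx')(ctx, FakeFeature(4)))       # [0. 1. 0. 0.]
tuning = von_mises_sketch('sample_idx', num_dirs=8)(ctx, FakeFeature(8))
print(tuning.argmax())                                        # 3
```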

4. Label helpers

In categorical output mode, every phase fills its span of the target, ctx.outputs[start:end], with an integer label. The label helpers build these labels:

  • label(value) — accepts a static int, a context key (str) whose value is looked up, or an arbitrary callable (ctx) -> int.

  • match_label(match_key) — emit 1 on match trials, 2 otherwise.

  • comparison_label(key) — emit 1 if greater, 2 otherwise.
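The three forms that label accepts can be pictured as a small dispatch on the argument type. The sketch below is hypothetical:

  • label_sketch and match_label_sketch are illustrative names, not braintools functions.
  • The sketch returns plain ints, whereas the real helpers may produce JAX values.

It also shows that the match rule (1 on match trials, 2 otherwise) is the same rule that the class-based example later in this tutorial writes out explicitly with jnp.where.

```python
def label_sketch(value):
    """Hypothetical: turn an int, a context key, or a callable into a function ctx -> int."""
    if isinstance(value, int):
        return lambda ctx: value            # static label
    if isinstance(value, str):
        return lambda ctx: ctx[value]       # look the label up in the trial context
    if callable(value):
        return value                        # arbitrary (ctx) -> int
    raise TypeError(f'unsupported label spec: {value!r}')


def match_label_sketch(match_key):
    """Hypothetical: 1 on match trials, 2 otherwise."""
    return label_sketch(lambda ctx: 1 if ctx[match_key] else 2)


ctx = {'ground_truth': 1, 'is_match': False}
print(label_sketch(0)(ctx))                 # 0
print(label_sketch('ground_truth')(ctx))    # 1
print(match_label_sketch('is_match')(ctx))  # 2
```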

5. Putting it together: instance-based

The smallest end-to-end recipe: define the features, the phases, and trial_init, then pass them to Task.

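def trial_init(ctx):
    ctx['ground_truth'] = ctx.rng.choice(2)              # jax scalar — OK under JIT
    ctx['coherence']    = 51.2
    # The direction is drawn independently of ground_truth here. That is enough to
    # demonstrate the mechanics, but in a learnable task, derive it from ground_truth.
    ctx['direction']    = ctx.rng.uniform(0.0, 2 * jnp.pi)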

task = Task(
    phases=phases,
    input_features=inputs,
    output_features=outputs,
    trial_init=trial_init,
    seed=0,
)

X, Y, info = task.sample_trial(0)
print('X.shape =', X.shape, 'Y.shape =', Y.shape)
info['trial_state']
X.shape = (1000, 9) Y.shape = (1000,)
{'trial_index': 0,
 'ground_truth': Array(0, dtype=int32),
 'coherence': 51.2,
 'direction': Array(0.5190151, dtype=float32),
 'output_mode': 'categorical'}
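The shapes follow directly from the phase durations. The phases sum to 100 + 500 + 300 + 100 = 1000 ms, which is 1000 steps at dt = 1 ms. In categorical mode, Y holds one integer label per step.

The sketch below reconstructs by hand the target this trial should have. It assumes each phase writes its label to every step it covers. This is a NumPy illustration of the layout, not a call into braintools, so compare it against the real Y rather than treating it as authoritative.

```python
import numpy as np

dt_ms = 1.0
ground_truth = 0  # the value shown in info['trial_state'] above
# (duration in ms, label) for Fixation, Stimulus, Delay, Response
phase_specs = [(100, 0), (500, 0), (300, 0), (100, ground_truth + 1)]

Y_expected = np.concatenate([
    np.full(int(round(ms / dt_ms)), lab, dtype=np.int32)  # one label per time step
    for ms, lab in phase_specs
])
print(Y_expected.shape)              # (1000,)
print(np.unique(Y_expected[-100:]))  # [1]: the response label ground_truth + 1
```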

6. Putting it together: class-based

For parameterized, reusable tasks, subclass Task and override define_features, define_phases, and trial_init. Pre-built tasks (PerceptualDecisionMaking, DelayMatchSample, …) all follow this pattern.

class MyDMS(Task):
    t_fixation = 200 * u.ms
    t_sample   = 400 * u.ms
    t_delay    = 600 * u.ms
    t_response = 200 * u.ms
    num_stimuli = 8

    def define_features(self):
        fix    = Feature(1, 'fixation')
        stim   = Feature(self.num_stimuli, 'stimulus')
        choice = Feature(2, 'choice')   # match / non-match
        return fix + stim, fix + choice

    def define_phases(self):
        return concat([
            Fixation(self.t_fixation,
                     inputs={'fixation': 1.0},
                     outputs={'label': 0}),
            Stimulus(self.t_sample,
                     inputs={'fixation': 1.0,
                             'stimulus': von_mises('sample_idx',
                                                   num_dirs=self.num_stimuli)},
                     outputs={'label': 0},
                     name='sample'),
            Delay(self.t_delay,
                  inputs={'fixation': 1.0},
                  outputs={'label': 0}),
            Response(self.t_response,
                     inputs={'fixation': 0.0,
                             'stimulus': von_mises('test_idx',
                                                   num_dirs=self.num_stimuli)},
                     outputs={'label': label(lambda ctx: jnp.where(ctx['is_match'], 1, 2))},
                     name='response'),
        ])

    def trial_init(self, ctx):
        # All values stay as jax scalars / arrays so trial_init works under JIT+vmap.
        ctx['sample_idx'] = ctx.rng.choice(self.num_stimuli)
        ctx['is_match']   = ctx.rng.uniform() < 0.5
        non_match_idx     = (ctx['sample_idx']
                             + 1
                             + ctx.rng.choice(self.num_stimuli - 1)) % self.num_stimuli
        ctx['test_idx']   = jnp.where(ctx['is_match'], ctx['sample_idx'], non_match_idx)

task = MyDMS(num_stimuli=16, seed=42)
X, Y = task.batch_sample(64)
print('X.shape =', X.shape, 'Y.shape =', Y.shape)
X.shape = (1400, 64, 17) Y.shape = (1400, 64)
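The non-match expression in trial_init deserves a quick check. The check assumes ctx.rng.choice(n) draws an integer from range(n), as NumPy's Generator.choice does for an integer argument. The offsets k then run over 0, 1, ..., num_stimuli - 2. Under that assumption, (sample_idx + 1 + k) % num_stimuli visits every index except sample_idx exactly once, so the non-match test stimulus is uniform over the remaining stimuli.

The plain-Python loop below verifies this exhaustively for the 16-stimulus task. It does not use braintools.

```python
num_stimuli = 16

for sample_idx in range(num_stimuli):
    # Every value the non-match expression can take for this sample.
    candidates = [(sample_idx + 1 + k) % num_stimuli for k in range(num_stimuli - 1)]
    assert sorted(candidates) == [i for i in range(num_stimuli) if i != sample_idx]

print('every non-match index is reachable exactly once')
```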

7. Vector outputs: continuous-report tasks

Pass output_mode='vector' when the target is a population code rather than a class index. Each phase then writes into named output features. Features that should be silent during a phase are written explicitly with a zero spec (a constant 0.0 or a small callable that returns zeros).

fix_in  = Feature(1, 'fixation')
stim_in = Feature(16, 'stimulus')
fix_out = Feature(1, 'fixation_out')
dir_out = Feature(16, 'direction_out')

def silent(ctx, feat):
    return jnp.zeros((feat.num,))

task = Task(
    phases=concat([
        Fixation(200 * u.ms,
                 inputs={'fixation': 1.0},
                 outputs={'fixation_out': 1.0, 'direction_out': silent}),
        Stimulus(500 * u.ms,
                 inputs={'fixation': 1.0,
                         'stimulus': von_mises('sample_idx', num_dirs=16)},
                 outputs={'fixation_out': 1.0, 'direction_out': silent}),
        Delay(800 * u.ms,
              inputs={'fixation': 1.0},
              outputs={'fixation_out': 1.0, 'direction_out': silent}),
        Response(400 * u.ms,
                 inputs={'fixation': 0.0},
                 outputs={'fixation_out': 0.0,
                          'direction_out': von_mises('sample_idx', num_dirs=16)}),
    ]),
    input_features=fix_in + stim_in,
    output_features=fix_out + dir_out,
    output_mode='vector',
    trial_init=lambda ctx: ctx.update(sample_idx=ctx.rng.choice(16)),
    seed=0,
)

X, Y = task.batch_sample(8)
print('X.shape =', X.shape, '  Y.shape =', Y.shape, '(T, B, 1+16)')
X.shape = (1900, 8, 17)   Y.shape = (1900, 8, 17) (T, B, 1+16)
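In vector mode, the report is read back from the direction_out block. That is channels 1:17 of the last axis, since fixation_out occupies channel 0. One common way to turn a 16-unit population into an angle is the population-vector (circular mean) readout sketched below.

This decoder is not part of braintools. The evenly spaced preferred directions and the test bump are assumptions, chosen to resemble von_mises(..., num_dirs=16).

```python
import numpy as np


def decode_direction(activity):
    """Population-vector readout: angle of sum_j r_j * exp(i * theta_j), in [0, 2*pi)."""
    num_dirs = activity.shape[-1]
    prefs = 2 * np.pi * np.arange(num_dirs) / num_dirs  # evenly spaced preferred angles
    z = (activity * np.exp(1j * prefs)).sum(axis=-1)    # weighted sum on the unit circle
    return np.mod(np.angle(z), 2 * np.pi)


# A symmetric bump centred on unit 5 of 16 should decode to 5 * 2*pi/16.
prefs = 2 * np.pi * np.arange(16) / 16
bump = np.exp(2.0 * (np.cos(prefs - prefs[5]) - 1.0))
print(decode_direction(bump), prefs[5])  # both about 1.9635
```

For a batch Y shaped (T, B, 17) as above, decode_direction(Y[-1, :, 1:]) would give one angle per trial at the final response step.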

8. Branching with If / Switch

If selects between then / else_ based on a predicate over the trial context; Switch dispatches over a dict of cases keyed by the selector’s output. Both must be resolvable from values written in trial_init, because the framework needs to compute a deterministic total duration before the actual encoding pass.

branching = (
    Sample(400 * u.ms,
           inputs={'stimulus': von_mises('sample_idx', num_dirs=8)},
           outputs={'label': 0})
    >> Delay(600 * u.ms,
             inputs={'fixation': 1.0},
             outputs={'label': 0})
    >> Test(400 * u.ms,
            inputs={'stimulus': von_mises('test_idx', num_dirs=8)},
            outputs={'label': 0})
    >> If(
        condition=lambda ctx: bool(ctx['is_match']),
        then=Response(200 * u.ms,
                      outputs={'label': 1}),
        else_=Response(200 * u.ms,
                       outputs={'label': 2}),
    )
)
branching
Sequence(Sample >> Delay >> Test >> If)
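The branch has to be resolvable before encoding because the trial length depends on which branch is taken whenever the branches differ in duration. The sketch below models that resolution step with stand-in classes. Leaf, IfNode, SwitchNode and resolve are hypothetical, illustrative only, and not braintools types. The sketch flattens a phase tree into the leaves that a given context selects, then sums their durations.

Unlike the example above, the non-match response here lasts 300 steps instead of 200. This change is purely to make the dependence on the branch visible.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass
class Leaf:                       # a plain phase: name and duration in steps
    name: str
    steps: int


@dataclass
class IfNode:                     # stand-in for If(condition, then, else_)
    condition: Callable[[dict], bool]
    then: Any
    else_: Any


@dataclass
class SwitchNode:                 # stand-in for Switch: selector output -> case
    selector: Callable[[dict], Any]
    cases: Dict[Any, Any]


def resolve(node, ctx):
    """Return the list of leaf phases that this trial context actually runs."""
    if isinstance(node, Leaf):
        return [node]
    if isinstance(node, IfNode):
        return resolve(node.then if node.condition(ctx) else node.else_, ctx)
    if isinstance(node, SwitchNode):
        return resolve(node.cases[node.selector(ctx)], ctx)
    return [leaf for child in node for leaf in resolve(child, ctx)]  # a sequence


tree = [Leaf('sample', 400), Leaf('delay', 600), Leaf('test', 400),
        IfNode(lambda ctx: ctx['is_match'],
               Leaf('response_match', 200), Leaf('response_nonmatch', 300))]

for is_match in (True, False):
    ctx = {'is_match': is_match}  # in a real task, trial_init writes this
    print(is_match, sum(leaf.steps for leaf in resolve(tree, ctx)))  # 1600, then 1700

rule_tree = SwitchNode(lambda ctx: ctx['rule'],
                       {'pro': Leaf('pro', 100), 'anti': Leaf('anti', 250)})
print([leaf.name for leaf in resolve(rule_tree, {'rule': 'anti'})])  # ['anti']
```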

9. Looping with While

While(condition, body, max_iterations) runs body until the predicate returns False, capped by max_iterations. The duration computation uses max_iterations as the upper bound, so the buffer is large enough for the worst case — see Tutorial 3 for the JIT/vmap implications.

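loop_demo = While(
    # 'threshold' and 'pulse' (and optionally 'evidence') are read from ctx,
    # so they must be written there beforehand, e.g. in trial_init.
    condition=lambda ctx: ctx.get('evidence', 0.0) < ctx['threshold'],
    body=Sample(50 * u.ms,
                inputs={'stimulus': lambda ctx, f: ctx['pulse']},
                outputs={'label': 0}),
    max_iterations=20,
)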
loop_demo
While(body=Sample, max=20)
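For loop_demo, the worst case is 20 iterations of a 50 ms body. At dt = 1 ms, the block therefore needs room for 20 * 50 = 1000 steps even when the loop stops early.

A minimal NumPy sketch of that bookkeeping follows. while_buffer is a hypothetical helper. Pairing the fixed buffer with a boolean validity mask is one standard approach, assumed here rather than taken from braintools; Tutorial 3 describes the actual design.

```python
import numpy as np


def while_buffer(body_steps, max_iterations, n_iterations):
    """Hypothetical: worst-case buffer length plus a mask of the steps actually used."""
    n_used = min(n_iterations, max_iterations)     # the loop never exceeds its cap
    t_max = body_steps * max_iterations            # sized for the worst case
    mask = np.arange(t_max) < body_steps * n_used  # True for steps the loop really ran
    return t_max, mask


t_max, mask = while_buffer(body_steps=50, max_iterations=20, n_iterations=7)
print(t_max, int(mask.sum()))  # 1000 350
```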

10. Custom encoders

The built-in encoders cover most use cases, but any callable with signature f(ctx, feature) -> jnp.ndarray is a valid value spec. The return value may be:

  • a scalar (broadcast to all timesteps and units),

  • a 1-D array of shape (feature.num,) (broadcast along time),

  • or a 2-D array of shape (duration, feature.num) (written directly).

def ramping_pulse(start_val, stop_val):
    def encode(ctx, feature):
        duration = ctx.phase_duration
        ramp = jnp.linspace(start_val, stop_val, duration)
        return jnp.broadcast_to(ramp[:, None], (duration, feature.num))
    encode.__name__ = f'ramping_pulse({start_val}, {stop_val})'
    return encode

# Plug it in like any other encoder:
ramp_phase = Stimulus(
    500 * u.ms,
    inputs={'stimulus': ramping_pulse(0.0, 1.0)},
    outputs={'label': 0},
)
ramp_phase
DeclarativePhase('Stimulus', inputs=['stimulus'])
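The three accepted return shapes listed above are exactly the ones that NumPy-style broadcasting expands to a (duration, feature.num) block. The helper below is a hypothetical sketch of that normalization; expand_spec is not a braintools function.

A 1-D result must have length feature.num. A per-timestep ramp therefore has to be returned as a 2-D array, which is why ramping_pulse broadcasts ramp[:, None] explicitly.

```python
import numpy as np


def expand_spec(value, duration, num):
    """Hypothetical: expand a value-spec result to a (duration, num) block."""
    arr = np.asarray(value, dtype=np.float32)
    if arr.ndim > 2:
        raise ValueError(f'value spec returned a {arr.ndim}-D array')
    # () -> every step and unit; (num,) -> same vector at every step; (duration, num) -> as is.
    return np.broadcast_to(arr, (duration, num))


print(expand_spec(0.5, 3, 4).shape)        # (3, 4)
print(expand_spec(np.arange(4), 3, 4)[2])  # [0. 1. 2. 3.]
ramp = np.linspace(0.0, 1.0, 3)
block = expand_spec(np.broadcast_to(ramp[:, None], (3, 4)), 3, 4)  # ramp along time
```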

Where to next

  • Tutorial 3 — Variable-length trial sequences covers VariableDuration, the packed-buffer + mask design, and how If / Switch / While participate under batch_sample.

  • The API reference (braintools.cogtask under API Reference) lists every pre-built task, encoder, phase, and utility with full parameter documentation.