Tutorial 3: Variable-length trial sequences
Real cognitive experiments rarely use a single fixed timeline. Delay
periods are jittered to discourage timing strategies; decisions are made
when evidence reaches a threshold; trials branch on a cue. This tutorial
shows how braintools.cogtask supports such trials end-to-end under
jit and vmap.
The framework uses a packed-buffer + mask design:
- Every trial in a batch is written into a buffer sized to the worst-case length, task.max_trial_duration().
- Each trial reports its actual length via a boolean mask of shape (T_max,). True marks live timesteps, False marks padding.
- Phases write only into their valid slice; trailing positions stay at zero in X/Y and False in mask.
- Buffer shapes are static Python ints, so brainstate.transform.jit and vmap2 work without retracing.
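The packed-buffer idea can be illustrated without the library. The following is a minimal sketch in plain JAX, not braintools code: pack_trial is a hypothetical helper, and the only thing it assumes is that live timesteps are contiguous from t = 0.

```python
import jax.numpy as jnp

def pack_trial(values, length, t_max):
    """Keep the first `length` rows of `values`; zero the rest.

    values : (t_max, n_features) candidate signal for every buffer slot
    length : scalar int, may be a traced JAX value
    t_max  : static Python int, the worst-case trial length
    """
    mask = jnp.arange(t_max) < length                  # (t_max,) bool, True on live steps
    packed = jnp.where(mask[:, None], values, 0.0)     # padding stays at zero
    return packed, mask

t_max = 6
values = jnp.ones((t_max, 2))
packed, mask = pack_trial(values, 4, t_max)
```

Because t_max is a Python int, the output shapes do not depend on length, which is what lets a function like this be traced once and reused for every trial.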
This tutorial covers:
- The VariableDuration phase and how it’s used in trial_init.
- Detecting variable-length tasks (task.is_variable_length, phase_tree_is_variable, task.max_trial_duration()).
- Sampling with masks (sample_trial, batch_sample(return_mask=True)).
- Conditional control flow (If, Switch, While) under packed mode.
- The migrated built-in tasks (HierarchicalReasoning, IntervalDiscrimination, ReadySetGo).
- Consuming the mask in losses and metrics.
- Duration samplers (TruncExp, UniformDuration) and their bounds.
import brainstate
import brainunit as u
import jax
import jax.numpy as jnp
brainstate.environ.set(dt=1.0 * u.ms)
from braintools.cogtask import (
    Task, Feature, concat,
    Fixation, Delay, Response, Stimulus, Sample, Test,
    VariableDuration,
    If, Switch, While,
    TruncExp, UniformDuration,
    phase_tree_is_variable,
)
1. The VariableDuration phase
VariableDuration is the declarative primitive for “this phase lasts a
trial-dependent number of steps.” It looks like any other declarative
phase (it takes inputs=, outputs=, noise=), but its duration is
read from the trial context at sample time:
- min_duration / max_duration are brainunit quantities. They bound the phase: min_duration floors the step count and max_duration is the static upper bound used to size the buffer slot.
- ctx_key names the trial-state entry holding the actual duration for this trial. trial_init writes a scalar (a float in dt units or a Quantity) into ctx[ctx_key].
The framework converts ctx[ctx_key] to a step count, clips it into
[1, ceil(max_duration / dt)], and writes only that many timesteps.
variable_delay = VariableDuration(
    min_duration=200 * u.ms,
    max_duration=1500 * u.ms,
    ctx_key='delay_duration',
    inputs={'fixation': 1.0},
    outputs={'label': 0},
    name='variable_delay',
)
variable_delay
DeclarativePhase('variable_delay', inputs=['fixation'])
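The duration-to-steps conversion described above can be sketched as follows. This is an illustration only: duration_to_steps is a hypothetical helper, durations are plain floats in milliseconds rather than Quantities, and rounding to the nearest step is an assumption (the text specifies the clip range but not the rounding rule).

```python
import math
import jax.numpy as jnp

def duration_to_steps(duration_ms, max_duration_ms, dt_ms=1.0):
    """Convert per-trial durations to step counts clipped into [1, max_steps]."""
    # Static Python int: safe to use as a buffer dimension.
    max_steps = math.ceil(max_duration_ms / dt_ms)
    # Assumption: round to the nearest step before clipping.
    steps = jnp.round(jnp.asarray(duration_ms) / dt_ms).astype(jnp.int32)
    return jnp.clip(steps, 1, max_steps), max_steps

steps, max_steps = duration_to_steps(jnp.array([0.2, 742.6, 9000.0]), 1500.0)
```

The first value is raised to one step and the last is capped at the buffer slot size, so the written slice can never be empty or overflow its slot.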
A minimal variable-delay task
A delay-match-sample task with a delay drawn uniformly per trial:
fix = Feature(1, 'fixation')
stim = Feature(2, 'stim')
choice = Feature(2, 'choice')
phases = concat([
    Fixation(50 * u.ms, inputs={'fixation': 1.0}, outputs={'label': 0}),
    Sample(40 * u.ms,
           inputs={'stim': lambda c, f: jnp.ones(f.num)},
           outputs={'label': 0}),
    VariableDuration(
        min_duration=200 * u.ms,
        max_duration=1500 * u.ms,
        ctx_key='delay_duration',
        inputs={'fixation': 1.0},
        outputs={'label': 0},
        name='delay',
    ),
    Response(50 * u.ms,
             outputs={'label': lambda c, f: c['gt']}),
])

def init(ctx):
    # Both values are JAX scalars; trial_init runs under jit/vmap.
    ctx['delay_duration'] = ctx.rng.uniform(200.0, 1500.0)
    ctx['gt'] = ctx.rng.choice(2).astype(jnp.int32) + 1

task = Task(
    phases=phases,
    input_features=fix + stim,
    output_features=fix + choice,
    trial_init=init,
    seed=0,
)
task
Task(name=Task, inputs=3, outputs=3, output_mode=categorical)
2. Detecting variable-length tasks
Task walks the phase tree once at construction time and records
whether any phase advertises is_variable = True. A property, a method, and
a module-level helper let you reason about the resulting buffers.
- task.is_variable_length: True if the task uses the packed path.
- task.max_trial_duration(): Python int, the worst-case timestep count. This is the static T used by sample_trial and batch_sample. Safe as a buffer dimension under jit/vmap.
- phase_tree_is_variable(phases): module-level helper that walks any phase subtree; useful when composing trees outside a Task.
print('is_variable_length :', task.is_variable_length)
print('max_trial_duration() :', task.max_trial_duration(), 'steps at dt=1ms')
# 50 + 40 + 1500 + 50 = 1640
# The same detection at the phase-tree level:
print('phase_tree_is_variable:', phase_tree_is_variable(phases))
is_variable_length : True
max_trial_duration() : 1640 steps at dt=1ms
phase_tree_is_variable: True
3. Sampling with masks
For variable-length tasks, sample_trial returns the usual (X, Y, info)
triple, but info['mask'] is now a (T_max,) boolean array. For
fixed-length tasks info['mask'] is None.
batch_sample(B, return_mask=True) is the JIT/vmap path: it returns
(X, Y, mask) with mask in the same time/batch layout as X and Y
(default time-first → (T_max, B); pass time_first=False for
(B, T_max)).
X, Y, info = task.sample_trial(0)
mask = info['mask']
print('X.shape =', X.shape, ' (T_max, num_inputs)')
print('Y.shape =', Y.shape, ' (T_max,)')
print('mask.shape =', mask.shape, ' mask.dtype =', mask.dtype)
print('valid steps for trial 0 =', int(jnp.sum(mask)))
print('full T_max =', X.shape[0])
X.shape = (1640, 3) (T_max, num_inputs)
Y.shape = (1640,) (T_max,)
mask.shape = (1640,) mask.dtype = bool
valid steps for trial 0 = 475
full T_max = 1640
# JIT + vmap path: stack 8 trials with heterogeneous delay lengths.
X, Y, mask = task.batch_sample(8, return_mask=True)
print('X.shape =', X.shape, ' (T_max, B, num_inputs)')
print('Y.shape =', Y.shape, ' (T_max, B)')
print('mask.shape =', mask.shape, ' (T_max, B)')
# Each column of mask records that trial's length:
lengths = mask.sum(axis=0)
print('per-trial valid lengths =', lengths)
X.shape = (1640, 8, 3) (T_max, B, num_inputs)
Y.shape = (1640, 8) (T_max, B)
mask.shape = (1640, 8) (T_max, B)
per-trial valid lengths = [ 475 458 1309 1443 1132 1636 827 358]
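A common use of the per-trial lengths is reading out a value at each trial's final live timestep, for example a network's decision at the end of the response phase. The sketch below uses a synthetic mask instead of a real task so that it runs standalone; last_valid_step is a hypothetical helper, and it assumes the time-first layout and live steps that are contiguous from t = 0.

```python
import jax.numpy as jnp

def last_valid_step(values, mask):
    """Gather each trial's entry at its final live timestep.

    values : (T, B, ...) time-first array, e.g. network outputs
    mask   : (T, B) bool, True on live timesteps
    """
    lengths = mask.sum(axis=0)                 # (B,) valid steps per trial
    last = jnp.maximum(lengths - 1, 0)         # guard against an all-False column
    return values[last, jnp.arange(mask.shape[1])]

# Synthetic stand-in for a packed batch: T_max = 5, trials of length 3 and 5.
lengths = jnp.array([3, 5])
mask = jnp.arange(5)[:, None] < lengths[None, :]
values = jnp.arange(10.0).reshape(5, 2)        # values[t, b] = 2 * t + b
final = last_valid_step(values, mask)
```

Indexing with values[-1] instead would read padding for every trial shorter than T_max.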
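Shape contract
| Call | Returns | Shapes |
|---|---|---|
| task.sample_trial(i) | (X, Y, info) | X: (T_max, num_inputs); Y: (T_max,); info['mask']: (T_max,) |
| task.batch_sample(B) | (X, Y) | X: (T_max, B, num_inputs); Y: (T_max, B) |
| task.batch_sample(B, return_mask=True) | (X, Y, mask) | as above, plus mask: (T_max, B) |
| task.batch_sample(B, return_mask=True, time_first=False) | (X, Y, mask) | X: (B, T_max, num_inputs); Y: (B, T_max); mask: (B, T_max) |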
return_mask=True is supported on fixed-length tasks too — the
returned mask is simply all-True. That means a downstream training
loop can call batch_sample(..., return_mask=True) unconditionally
without branching on task.is_variable_length.
Trailing positions are zero
Phases write only into their valid slice. Anything past a trial’s
actual length stays at the buffer default (0 in X, 0 in Y,
False in mask).
# Pull a single trial back out of the batch and look at its tail.
trial_idx = 0
valid = int(mask[:, trial_idx].sum())
print(f'trial {trial_idx}: {valid} valid steps, {X.shape[0] - valid} padded steps')
print('tail of X :', X[valid:valid + 4, trial_idx])
print('tail of Y :', Y[valid:valid + 4, trial_idx])
print('tail mask :', mask[valid:valid + 4, trial_idx])
trial 0: 475 valid steps, 1165 padded steps
tail of X : [[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
tail of Y : [0 0 0 0]
tail mask : [False False False False]
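The zero-padding invariant can be turned into a cheap sanity check for a whole batch. The helper below is a hypothetical debugging aid, not part of braintools, and is demonstrated on synthetic arrays in the time-first layout. It calls bool() on the result, so it is meant for eager use outside jit.

```python
import jax.numpy as jnp

def padding_is_clean(X, Y, mask):
    """Return True if X and Y are exactly zero wherever mask is False.

    X: (T, B, F), Y: (T, B), mask: (T, B) bool.
    """
    pad = ~mask
    x_ok = jnp.all(jnp.where(pad[..., None], X, 0) == 0)
    y_ok = jnp.all(jnp.where(pad, Y, 0) == 0)
    return bool(x_ok & y_ok)

# Synthetic batch: T = 4, B = 2, F = 3, trial lengths 2 and 4.
mask = jnp.arange(4)[:, None] < jnp.array([2, 4])[None, :]
X = jnp.ones((4, 2, 3)) * mask[..., None]
Y = jnp.ones((4, 2), dtype=jnp.int32) * mask
clean = padding_is_clean(X, Y, mask)
# Corrupt one padded entry of trial 0 to show the check failing.
dirty = padding_is_clean(X.at[3, 0, 0].set(1.0), Y, mask)
```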
4. Conditional control flow under packed mode
If, Switch, and While all participate in variable-length trees.
Their max_steps is the static upper bound the framework allocates for;
their step_count reports what actually ran on this trial.
| Phase | Semantics under packed mode | Constraints |
|---|---|---|
| If | Both branches contribute via … | The predicate must read trial state (…) |
| Switch | Python-level dispatch on the selector’s value. | Selector must return a hashable Python key, not a tracer (set it in …) |
| While | Python-level loop bounded by … | Condition must return a Python … |
Branches that don’t run leave their buffer region at zero and do not
advance t_cursor, so the mask stays False for the unused slot.
# If: cued go/no-go where 'go' is sampled per trial.
go_or_nogo = concat([
    Fixation(20 * u.ms, inputs={'fixation': 1.0}),
    If(
        condition=lambda ctx: ctx['go'],
        then=Stimulus(40 * u.ms,
                      inputs={'stim': lambda c, f: jnp.ones(f.num)}),
        else_=Fixation(40 * u.ms,
                       inputs={'fixation': 0.5}),
    ),
    Response(20 * u.ms, outputs={'label': lambda c, f: c['gt']}),
])

def go_init(ctx):
    ctx['go'] = ctx.rng.choice(2).astype(jnp.bool_)
    ctx['gt'] = ctx.rng.choice(2).astype(jnp.int32) + 1

go_task = Task(
    phases=go_or_nogo,
    input_features=fix + stim,
    output_features=fix + choice,
    trial_init=go_init,
    seed=0,
)
print('is_variable_length :', go_task.is_variable_length)
print('max_trial_duration :', go_task.max_trial_duration(),
      '(20 + max(40, 40) + 20)')
X, Y, M = go_task.batch_sample(8, return_mask=True)
print('mask sums per trial :', M.sum(axis=0))
is_variable_length : True
max_trial_duration : 80 (20 + max(40, 40) + 20)
mask sums per trial : [80 80 80 80 80 80 80 80]
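Both branches above last 40 steps, so every trial fills the buffer. With unequal branches the static bound and the per-trial length diverge. The following sketch shows only that arithmetic in plain JAX; it is not how the library implements If, and the 60-step else branch is a made-up value for illustration.

```python
import jax.numpy as jnp

# Hypothetical step counts: the else branch is longer than the then branch.
fix_steps, then_steps, else_steps, resp_steps = 20, 40, 60, 20

# Static bound: a Python int built from the longer branch, used to size the buffer.
t_max = fix_steps + max(then_steps, else_steps) + resp_steps

# Actual per-trial length depends on which branch ran.
go = jnp.array([True, False, True])
lengths = fix_steps + jnp.where(go, then_steps, else_steps) + resp_steps
mask = jnp.arange(t_max)[:, None] < lengths[None, :]    # (t_max, B)
```

Trials that take the shorter branch end 20 steps early, and those trailing slots are the ones the mask marks False.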
5. Migrated built-in tasks
Three pre-built tasks in braintools.cogtask now use VariableDuration
internally and work with batch_sample out of the box:
- HierarchicalReasoning: variable delay between two flash cues.
- IntervalDiscrimination: two stimulus intervals sampled independently per trial.
- ReadySetGo: measurement interval sampled per trial.
from braintools.cogtask import (
    HierarchicalReasoning, IntervalDiscrimination, ReadySetGo,
)
for cls in [HierarchicalReasoning, IntervalDiscrimination, ReadySetGo]:
    task = cls(seed=42)
    X, Y, M = task.batch_sample(4, return_mask=True)
    print(f'{cls.__name__:25s} T_max={task.max_trial_duration():5d}'
          f' variable={task.is_variable_length}'
          f' mask sums={list(M.sum(axis=0))}')
HierarchicalReasoning T_max= 2000 variable=True mask sums=[Array(1649, dtype=int32), Array(1489, dtype=int32), Array(1537, dtype=int32), Array(1862, dtype=int32)]
IntervalDiscrimination T_max= 3100 variable=True mask sums=[Array(2604, dtype=int32), Array(2634, dtype=int32), Array(2646, dtype=int32), Array(2778, dtype=int32)]
ReadySetGo T_max= 2800 variable=True mask sums=[Array(2566, dtype=int32), Array(2459, dtype=int32), Array(2491, dtype=int32), Array(2708, dtype=int32)]
6. Using the mask in losses and metrics
The standard recipe for a masked cross-entropy loss: compute the loss elementwise, multiply by the mask cast to float, then normalize by the mask’s sum rather than the buffer size.
def masked_cross_entropy(logits, labels, mask):
    # logits: (T, B, C)   labels: (T, B) int   mask: (T, B) bool
    # Use log-softmax for numerical stability.
    logp = jax.nn.log_softmax(logits, axis=-1)
    onehot = jax.nn.one_hot(labels, logp.shape[-1])
    nll = -jnp.sum(onehot * logp, axis=-1)  # (T, B)
    mask_f = mask.astype(nll.dtype)
    return jnp.sum(nll * mask_f) / jnp.maximum(jnp.sum(mask_f), 1.0)
# Toy demo: pretend a network produced uniform logits.
X, Y, M = task.batch_sample(4, return_mask=True)
logits = jnp.zeros(X.shape[:2] + (task.num_outputs,))
print('masked loss =', float(masked_cross_entropy(logits, Y, M)))
masked loss = 0.5730118751525879
The same pattern applies to vector-output tasks (use mse,
cos_sin-style population loss, etc. and gate by mask[..., None]) and
to accuracy metrics ((preds == Y) * mask divided by mask.sum()).
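The accuracy and vector-output variants mentioned above might look like the sketch below. Both functions are illustrative helpers written for this tutorial's time-first (T, B) layout and are exercised on synthetic arrays; normalizing the MSE by the number of live entries times the feature dimension is one reasonable convention, not a library requirement.

```python
import jax.numpy as jnp

def masked_accuracy(logits, labels, mask):
    # logits: (T, B, C)   labels: (T, B) int   mask: (T, B) bool
    preds = jnp.argmax(logits, axis=-1)
    correct = (preds == labels) & mask          # padded steps never count as correct
    return jnp.sum(correct) / jnp.maximum(jnp.sum(mask), 1)

def masked_mse(pred, target, mask):
    # pred, target: (T, B, D)   mask: (T, B) bool
    mask_f = mask[..., None].astype(pred.dtype)  # broadcast over the feature axis
    se = (pred - target) ** 2 * mask_f
    denom = jnp.maximum(jnp.sum(mask_f) * pred.shape[-1], 1.0)
    return jnp.sum(se) / denom

# Synthetic check: T = 3, B = 2; the last step of trial 1 is padding.
mask = jnp.array([[True, True], [True, True], [True, False]])
labels = jnp.array([[0, 1], [1, 1], [0, 0]])
preds = jnp.array([[0, 1], [0, 1], [1, 1]])
logits = jnp.eye(2)[preds]                       # (3, 2, 2), argmax recovers preds
acc = masked_accuracy(logits, labels, mask)      # 3 correct out of 5 live steps

pred = jnp.ones((3, 2, 2)).at[2, 1].set(100.0)   # garbage in the padded slot
mse = masked_mse(pred, jnp.zeros((3, 2, 2)), mask)
```

Dividing by the buffer size instead of the mask sum would bias both numbers toward whatever value the padding happens to produce.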
7. Duration samplers
TruncExp and UniformDuration are the canonical helpers for drawing a
per-trial duration in trial_init. They are JIT/vmap-safe (they
consume ctx.rng) and they advertise the static bounds the framework
needs to size buffers:
- sampler.is_variable: class attribute set to True.
- sampler.min_value(), sampler.max_value(): return the bounds as Quantities. VariableDuration.min_duration / max_duration should match (or sit just outside) these bounds.
te = TruncExp(mean=600 * u.ms, min_val=300 * u.ms, max_val=1500 * u.ms)
ud = UniformDuration(200 * u.ms, 800 * u.ms)
print('TruncExp bounds:', te.min_value(), '-', te.max_value(),
      ' is_variable =', te.is_variable)
print('UniformDuration bounds:', ud.min_value(), '-', ud.max_value(),
      ' is_variable =', ud.is_variable)
TruncExp bounds: 300 ms - 1500 ms is_variable = True
UniformDuration bounds: 200 ms - 800 ms is_variable = True
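To see why a bounded sampler is compatible with static buffer shapes, here is one way to draw from a truncated exponential in plain JAX using the inverse CDF. This is a sketch, not the algorithm TruncExp uses: sample_trunc_exp is a hypothetical function, values are bare floats in milliseconds, and scale is the scale of the underlying exponential, which may not correspond to how the library interprets mean.

```python
import jax
import jax.numpy as jnp

def sample_trunc_exp(key, scale, lo, hi, shape=()):
    """Sample x in [lo, hi] with density proportional to exp(-x / scale)."""
    u01 = jax.random.uniform(key, shape)             # uniform in [0, 1)
    span = 1.0 - jnp.exp(-(hi - lo) / scale)         # probability mass kept by truncation
    x = lo - scale * jnp.log1p(-u01 * span)          # inverse CDF of the truncated law
    return jnp.clip(x, lo, hi)                       # guard against float round-off

key = jax.random.PRNGKey(0)
samples = sample_trunc_exp(key, 600.0, 300.0, 1500.0, shape=(1000,))
```

There is no rejection loop, so the function has a fixed cost per sample and traces cleanly under jit and vmap, and every draw is guaranteed to fit a slot sized by the upper bound.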
# Wire a sampler into trial_init -> VariableDuration:
delay_dist = TruncExp(mean=600 * u.ms, min_val=300 * u.ms, max_val=1500 * u.ms)

def init_with_sampler(ctx):
    # Store the sampled Quantity (or its mantissa) in the ctx key the
    # phase reads. Either form works.
    ctx['delay_duration'] = delay_dist(ctx).to(u.ms).mantissa
    ctx['gt'] = ctx.rng.choice(2).astype(jnp.int32) + 1

sampler_task = Task(
    phases=concat([
        Fixation(50 * u.ms, inputs={'fixation': 1.0}),
        VariableDuration(
            min_duration=delay_dist.min_value(),
            max_duration=delay_dist.max_value(),
            ctx_key='delay_duration',
            inputs={'fixation': 1.0},
        ),
        Response(50 * u.ms, outputs={'label': lambda c, f: c['gt']}),
    ]),
    input_features=fix + stim,
    output_features=fix + choice,
    trial_init=init_with_sampler,
    seed=0,
)
X, Y, M = sampler_task.batch_sample(8, return_mask=True)
print('T_max =', sampler_task.max_trial_duration(),
      ' per-trial lengths =', list(M.sum(axis=0)))
T_max = 1600 per-trial lengths = [Array(456, dtype=int32), Array(449, dtype=int32), Array(1020, dtype=int32), Array(1194, dtype=int32), Array(849, dtype=int32), Array(1588, dtype=int32), Array(635, dtype=int32), Array(407, dtype=int32)]
Summary
- Mark variable-length epochs with VariableDuration(min_duration, max_duration, ctx_key=...) and write the per-trial length into ctx[ctx_key] from trial_init.
- The framework auto-detects variable-length trees and switches to a packed-buffer path. task.is_variable_length and task.max_trial_duration() describe the result.
- Use batch_sample(B, return_mask=True) to get aligned (X, Y, mask) buffers under jit + vmap. The mask doubles as a per-step weight for losses and metrics.
- If / Switch / While participate in the same buffers; conditional branches that didn’t run leave their slot at zero and mask=False.
- TruncExp and UniformDuration are sampling helpers whose min_value() / max_value() line up with VariableDuration’s min_duration / max_duration.
For the full API surface — every phase type, encoder, label helper,
duration sampler, and pre-built task — see the
braintools.cogtask API reference.