Tutorial 3: Variable-length trial sequences

Real cognitive experiments rarely use a single fixed timeline. Delay periods are jittered to discourage timing strategies; decisions are made when evidence reaches a threshold; trials branch on a cue. This tutorial shows how braintools.cogtask supports such trials end-to-end under jit and vmap.

The framework uses a packed-buffer + mask design:

  • Every trial in a batch is written into a buffer sized to the worst-case length, task.max_trial_duration().

  • Each trial reports its actual length via a boolean mask of shape (T_max,). True marks live timesteps, False marks padding.

  • Phases write only into their valid slice; trailing positions stay at zero in X/Y and False in mask.

  • Buffer shapes are static Python ints, so brainstate.transform.jit and vmap2 work without retracing.
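To make the design concrete, here is a minimal pure-JAX sketch of the packed-buffer + mask idea, independent of braintools. The helper pack_trial and the constant T_MAX are hypothetical, and the sketch uses plain jax.jit / jax.vmap rather than brainstate's transforms. Every trial gets the same static buffer shape; only the mask differs.

```python
import jax
import jax.numpy as jnp

T_MAX = 8  # static worst-case length: a plain Python int, so shapes never change

def pack_trial(length, value):
    # Hypothetical helper: fill the first `length` steps of a T_MAX-long
    # buffer with `value`; the tail stays 0 in x and False in mask.
    t = jnp.arange(T_MAX)
    mask = t < length
    x = jnp.where(mask, value, 0.0)
    return x, mask

# Three trials of different lengths share one static buffer shape.
lengths_demo = jnp.array([3, 8, 5])
values_demo = jnp.array([1.0, 2.0, 3.0])
X_demo, M_demo = jax.jit(jax.vmap(pack_trial))(lengths_demo, values_demo)
print(X_demo.shape, M_demo.sum(axis=1))  # (3, 8) [3 8 5]
```

Because the trial length only enters through a comparison against jnp.arange(T_MAX), changing the lengths never changes a shape, so the jitted function is traced once.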

This tutorial covers:

  1. The VariableDuration phase and how it’s used in trial_init.

  2. Detecting variable-length tasks (task.is_variable_length, phase_tree_is_variable, task.max_trial_duration()).

  3. Sampling with masks (sample_trial, batch_sample(return_mask=True)).

  4. Conditional control flow (If, Switch, While) under packed mode.

  5. The migrated built-in tasks (HierarchicalReasoning, IntervalDiscrimination, ReadySetGo).

  6. Consuming the mask in losses and metrics.

  7. Duration samplers (TruncExp, UniformDuration) and their bounds.


import brainstate
import brainunit as u
import jax
import jax.numpy as jnp

brainstate.environ.set(dt=1.0 * u.ms)

from braintools.cogtask import (
    Task, Feature, concat,
    Fixation, Delay, Response, Stimulus, Sample, Test,
    VariableDuration,
    If, Switch, While,
    TruncExp, UniformDuration,
    phase_tree_is_variable,
)

1. The VariableDuration phase

VariableDuration is the declarative primitive for “this phase lasts a trial-dependent number of steps.” It looks like any other declarative phase (it takes inputs=, outputs=, noise=), but its duration is read from the trial context at sample time:

  • min_duration / max_duration are brainunit quantities. They bound the phase: min_duration floors the step count and max_duration is the static upper bound used to size the buffer slot.

  • ctx_key names the trial-state entry holding the actual duration for this trial. trial_init writes a scalar (a float in dt units or a Quantity) into ctx[ctx_key].

The framework converts ctx[ctx_key] to a step count, clips it into [1, ceil(max_duration / dt)], and writes only that many timesteps.
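The clipping rule can be written out as a small sketch. duration_to_steps is a hypothetical helper, not part of braintools, and it works on plain numbers that share dt's unit. How braintools rounds a fractional duration is not stated here, so the jnp.round below is an assumption.

```python
import math
import jax.numpy as jnp

def duration_to_steps(duration, dt, max_duration):
    # duration: per-trial value(s); dt, max_duration: plain numbers, same unit.
    max_steps = math.ceil(max_duration / dt)              # static Python int
    steps = jnp.round(duration / dt).astype(jnp.int32)    # rounding rule assumed
    return jnp.clip(steps, 1, max_steps), max_steps

# A too-short, a typical, and a too-long duration (ms, with dt = 1 ms).
steps, max_steps = duration_to_steps(jnp.array([0.2, 335.0, 2000.0]), 1.0, 1500.0)
print(steps.tolist(), max_steps)  # [1, 335, 1500] 1500
```

The upper bound is computed in Python from max_duration, which is what keeps the buffer slot static; only the per-trial count is a traced value.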

variable_delay = VariableDuration(
    min_duration=200 * u.ms,
    max_duration=1500 * u.ms,
    ctx_key='delay_duration',
    inputs={'fixation': 1.0},
    outputs={'label': 0},
    name='variable_delay',
)
variable_delay
DeclarativePhase('variable_delay', inputs=['fixation'])

A minimal variable-delay task

A delay-match-sample task with a delay drawn uniformly per trial:

fix    = Feature(1, 'fixation')
stim   = Feature(2, 'stim')
choice = Feature(2, 'choice')

phases = concat([
    Fixation(50 * u.ms, inputs={'fixation': 1.0}, outputs={'label': 0}),
    Sample(40 * u.ms,
           inputs={'stim': lambda c, f: jnp.ones(f.num)},
           outputs={'label': 0}),
    VariableDuration(
        min_duration=200 * u.ms,
        max_duration=1500 * u.ms,
        ctx_key='delay_duration',
        inputs={'fixation': 1.0},
        outputs={'label': 0},
        name='delay',
    ),
    Response(50 * u.ms,
             outputs={'label': lambda c, f: c['gt']}),
])

def init(ctx):
    # Both values are JAX scalars — trial_init runs under jit/vmap.
    ctx['delay_duration'] = ctx.rng.uniform(200.0, 1500.0)
    ctx['gt']             = ctx.rng.choice(2).astype(jnp.int32) + 1

task = Task(
    phases=phases,
    input_features=fix + stim,
    output_features=fix + choice,
    trial_init=init,
    seed=0,
)
task
Task(name=Task, inputs=3, outputs=3, output_mode=categorical)

2. Detecting variable-length tasks

Task walks the phase tree once at construction time and records whether any phase advertises is_variable = True. Three APIs describe the resulting buffers:

  • task.is_variable_length: True if the task uses the packed path.

  • task.max_trial_duration(): a Python int giving the worst-case timestep count. This is the static T used by sample_trial and batch_sample, so it is safe to use as a buffer dimension under jit/vmap.

  • phase_tree_is_variable(phases): a module-level helper that walks any phase subtree; useful when composing trees outside a Task.
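The detection and sizing walk can be modeled with a few toy node types. This is an illustrative sketch only: Fixed, Variable, IfNode, and Seq are hypothetical stand-ins rather than braintools classes. Treating every conditional as variable-length is an assumption based on the go/no-go example in section 4, which reports is_variable_length = True even though both of its branches last 40 ms.

```python
from dataclasses import dataclass

# Toy phase nodes for illustration only; these are NOT braintools classes.
@dataclass(frozen=True)
class Fixed:
    steps: int
    is_variable: bool = False

@dataclass(frozen=True)
class Variable:
    max_steps: int
    is_variable: bool = True

@dataclass(frozen=True)
class IfNode:
    then: object
    else_: object

@dataclass(frozen=True)
class Seq:
    children: tuple

def tree_is_variable(node) -> bool:
    # Variable-length if any node in the tree advertises it.
    if isinstance(node, Seq):
        return any(tree_is_variable(c) for c in node.children)
    if isinstance(node, IfNode):
        return True  # assumption: conditionals always take the packed path
    return node.is_variable

def tree_max_steps(node) -> int:
    # Worst-case length used to size the static buffer.
    if isinstance(node, Seq):
        return sum(tree_max_steps(c) for c in node.children)
    if isinstance(node, IfNode):
        return max(tree_max_steps(node.then), tree_max_steps(node.else_))
    if isinstance(node, Variable):
        return node.max_steps
    return node.steps

# Same layout as the delay-match-sample task above (dt = 1 ms).
dms = Seq((Fixed(50), Fixed(40), Variable(1500), Fixed(50)))
print(tree_is_variable(dms), tree_max_steps(dms))  # True 1640
```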

print('is_variable_length    :', task.is_variable_length)
print('max_trial_duration()  :', task.max_trial_duration(), 'steps at dt=1ms')
# 50 + 40 + 1500 + 50 = 1640

# The same detection at the phase-tree level:
print('phase_tree_is_variable:', phase_tree_is_variable(phases))
is_variable_length    : True
max_trial_duration()  : 1640 steps at dt=1ms
phase_tree_is_variable: True

3. Sampling with masks

For variable-length tasks, sample_trial returns the usual (X, Y, info) triple, but info['mask'] is now a (T_max,) boolean array. For fixed-length tasks info['mask'] is None.

batch_sample(B, return_mask=True) is the JIT/vmap path: it returns (X, Y, mask) with mask in the same time/batch layout as X and Y (default time-first → (T_max, B); pass time_first=False for (B, T_max)).
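Because info['mask'] is None for fixed-length tasks, code that consumes single trials may want to normalize it. resolve_mask below is a hypothetical helper, a sketch assuming info behaves like a dict; it substitutes an all-True mask so downstream code can always multiply by it.

```python
import jax.numpy as jnp

def resolve_mask(X, info):
    # info['mask'] is a (T,) bool array for variable-length tasks and None
    # for fixed-length ones; fall back to an all-True mask in the latter case.
    mask = info.get('mask')
    if mask is None:
        return jnp.ones(X.shape[0], dtype=bool)
    return mask

X_single = jnp.zeros((5, 3))  # stand-in (T, F) trial
print(resolve_mask(X_single, {'mask': None}).tolist())  # [True, True, True, True, True]
```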

X, Y, info = task.sample_trial(0)
mask = info['mask']
print('X.shape    =', X.shape, '  (T_max, num_inputs)')
print('Y.shape    =', Y.shape, '  (T_max,)')
print('mask.shape =', mask.shape, ' mask.dtype =', mask.dtype)
print('valid steps for trial 0 =', int(jnp.sum(mask)))
print('full T_max              =', X.shape[0])
X.shape    = (1640, 3)   (T_max, num_inputs)
Y.shape    = (1640,)   (T_max,)
mask.shape = (1640,)  mask.dtype = bool
valid steps for trial 0 = 475
full T_max              = 1640
# JIT + vmap path: stack 8 trials with heterogeneous delay lengths.
X, Y, mask = task.batch_sample(8, return_mask=True)
print('X.shape    =', X.shape, '  (T_max, B, num_inputs)')
print('Y.shape    =', Y.shape, '  (T_max, B)')
print('mask.shape =', mask.shape, ' (T_max, B)')

# Each column of mask records that trial's length:
lengths = mask.sum(axis=0)
print('per-trial valid lengths =', lengths)
X.shape    = (1640, 8, 3)   (T_max, B, num_inputs)
Y.shape    = (1640, 8)   (T_max, B)
mask.shape = (1640, 8)  (T_max, B)
per-trial valid lengths = [ 475  458 1309 1443 1132 1636  827  358]

Shape contract

| Call | Returns | Shapes |
|---|---|---|
| task.sample_trial(i) | (X, Y, info) | X: (T_max, F), Y: (T_max,) or (T_max, F_out), info['mask']: (T_max,) bool |
| task.batch_sample(B) | (X, Y) | X: (T_max, B, F), Y: (T_max, B) or (T_max, B, F_out) |
| task.batch_sample(B, return_mask=True) | (X, Y, mask) | as above, plus mask: (T_max, B) bool |
| task.batch_sample(B, time_first=False, return_mask=True) | (X, Y, mask) | X: (B, T_max, F), Y: (B, T_max), mask: (B, T_max) |

return_mask=True is supported on fixed-length tasks too — the returned mask is simply all-True. That means a downstream training loop can call batch_sample(..., return_mask=True) unconditionally without branching on task.is_variable_length.
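Since the mask shares the time/batch layout of X and Y, moving between the two layouts in the table amounts to swapping the two leading axes. The sketch below uses stand-in arrays, and to_batch_first is a hypothetical helper; it assumes time_first=False differs from the default only in axis order, which the table suggests but this snippet does not check against braintools. The all-True mask mimics what a fixed-length task returns.

```python
import jax.numpy as jnp

def to_batch_first(X, Y, mask):
    # Swap the leading (time, batch) axes; any trailing feature axis is untouched.
    return jnp.swapaxes(X, 0, 1), jnp.swapaxes(Y, 0, 1), jnp.swapaxes(mask, 0, 1)

# Stand-in time-first batch: T=6, B=2, F=3.
X_tf = jnp.zeros((6, 2, 3))
Y_tf = jnp.zeros((6, 2), dtype=jnp.int32)
M_tf = jnp.ones((6, 2), dtype=bool)  # all-True, as for a fixed-length task
X_bf, Y_bf, M_bf = to_batch_first(X_tf, Y_tf, M_tf)
print(X_bf.shape, Y_bf.shape, M_bf.shape)  # (2, 6, 3) (2, 6) (2, 6)
```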

Trailing positions are zero

Phases write only into their valid slice. Anything past a trial’s actual length stays at the buffer default (0 in X, 0 in Y, False in mask).
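A quick sanity check on a time-first batch can confirm both properties at once: padded positions are zero, and the live steps of each trial form a contiguous prefix. padding_is_clean is a hypothetical helper, and the prefix check is an assumption inferred from the packed layout (the tail inspection that follows indexes X[valid:] directly, which only makes sense if it holds).

```python
import jax.numpy as jnp

def padding_is_clean(X, Y, mask):
    # X: (T, B, F), Y: (T, B), mask: (T, B) bool, default time-first layout.
    pad = ~mask
    x_ok = jnp.all(jnp.where(pad[..., None], X == 0, True))  # padded X is zero
    y_ok = jnp.all(jnp.where(pad, Y == 0, True))             # padded Y is zero
    # Assumption: live steps form a prefix, i.e. mask == arange(T) < length.
    lengths = mask.sum(axis=0)
    prefix = jnp.arange(mask.shape[0])[:, None] < lengths[None, :]
    return bool(x_ok & y_ok & jnp.all(mask == prefix))

# Stand-in batch: T=4, B=2, one feature; trial lengths 2 and 3.
mask_pad = jnp.arange(4)[:, None] < jnp.array([2, 3])[None, :]
X_pad = jnp.where(mask_pad[..., None], 1.0, 0.0)
Y_pad = jnp.zeros((4, 2), dtype=jnp.int32)
print(padding_is_clean(X_pad, Y_pad, mask_pad))  # True
```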

# Pull a single trial back out of the batch and look at its tail.
trial_idx = 0
valid = int(mask[:, trial_idx].sum())
print(f'trial {trial_idx}: {valid} valid steps, {X.shape[0] - valid} padded steps')
print('tail of X  :', X[valid:valid + 4, trial_idx])
print('tail of Y  :', Y[valid:valid + 4, trial_idx])
print('tail mask  :', mask[valid:valid + 4, trial_idx])
trial 0: 475 valid steps, 1165 padded steps
tail of X  : [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
tail of Y  : [0 0 0 0]
tail mask  : [False False False False]

4. Conditional control flow under packed mode

If, Switch, and While all participate in variable-length trees. Each exposes max_steps, the static upper bound the framework allocates buffer space for, and step_count, the number of steps that actually ran on the current trial.

| Phase | Semantics under packed mode | Constraints |
|---|---|---|
| If | Both branches are traced under jax.lax.cond, and both bound the buffer: the slot is sized to max(then_max, else_max). | The predicate must read trial state (ctx[...]); it may be a JAX tracer. |
| Switch | Python-level dispatch on the selector’s value. | The selector must return a hashable Python key, not a tracer (set it in trial_init as a Python value). |
| While | Python-level loop bounded by max_iterations. | The condition must return a Python bool. Buffer = body.max_steps * max_iterations. |

Branches that don’t run leave their buffer region at zero and do not advance the internal write cursor (t_cursor), so the mask stays False for the unused part of the slot.
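A toy model of If under packed mode makes the cursor bookkeeping explicit. This is a pure-JAX sketch, not braintools' implementation: write_steps and packed_if_trial are hypothetical. Each branch writes a fixed number of steps at a traced cursor with jax.lax.dynamic_update_slice, and jax.lax.cond picks which branch's buffer and cursor survive. Later phases start at the advanced cursor, so a shorter branch leaves zeros and mask=False at the end of the buffer rather than in the middle.

```python
import jax
import jax.numpy as jnp
from jax import lax

IF_T_MAX = 2 + max(3, 5) + 2   # prefix + max(then, else) + suffix = 9 steps

def write_steps(buf, cursor, n, value):
    # Write n (static) copies of value starting at a traced cursor; advance it.
    block = jnp.full((n,), value, dtype=buf.dtype)
    return lax.dynamic_update_slice(buf, block, (cursor,)), cursor + n

def packed_if_trial(go):
    buf = jnp.zeros((IF_T_MAX,), dtype=jnp.float32)
    cursor = jnp.int32(0)
    buf, cursor = write_steps(buf, cursor, 2, 1.0)               # fixation
    buf, cursor = lax.cond(
        go,
        lambda b, c: write_steps(b, c, 3, 2.0),                  # then: 3 steps
        lambda b, c: write_steps(b, c, 5, 3.0),                  # else: 5 steps
        buf, cursor,
    )
    buf, cursor = write_steps(buf, cursor, 2, 4.0)               # response
    mask = jnp.arange(IF_T_MAX) < cursor  # live steps end where the cursor stopped
    return buf, mask

bufs, masks = jax.jit(jax.vmap(packed_if_trial))(jnp.array([True, False]))
print(bufs[0].tolist())            # [1.0, 1.0, 2.0, 2.0, 2.0, 4.0, 4.0, 0.0, 0.0]
print(masks.sum(axis=1).tolist())  # [7, 9]
```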

# If: cued go/no-go where 'go' is sampled per trial.
go_or_nogo = concat([
    Fixation(20 * u.ms, inputs={'fixation': 1.0}),
    If(
        condition=lambda ctx: ctx['go'],
        then=Stimulus(40 * u.ms,
                      inputs={'stim': lambda c, f: jnp.ones(f.num)}),
        else_=Fixation(40 * u.ms,
                       inputs={'fixation': 0.5}),
    ),
    Response(20 * u.ms, outputs={'label': lambda c, f: c['gt']}),
])

def go_init(ctx):
    ctx['go'] = ctx.rng.choice(2).astype(jnp.bool_)
    ctx['gt'] = ctx.rng.choice(2).astype(jnp.int32) + 1

go_task = Task(
    phases=go_or_nogo,
    input_features=fix + stim,
    output_features=fix + choice,
    trial_init=go_init,
    seed=0,
)

print('is_variable_length  :', go_task.is_variable_length)
print('max_trial_duration  :', go_task.max_trial_duration(),
      '(20 + max(40, 40) + 20)')

X, Y, M = go_task.batch_sample(8, return_mask=True)
print('mask sums per trial :', M.sum(axis=0))
is_variable_length  : True
max_trial_duration  : 80 (20 + max(40, 40) + 20)
mask sums per trial : [80 80 80 80 80 80 80 80]

5. Migrated built-in tasks

Three pre-built tasks in braintools.cogtask now use VariableDuration internally and work with batch_sample out of the box:

  • HierarchicalReasoning — variable delay between two flash cues.

  • IntervalDiscrimination — two stimulus intervals sampled independently per trial.

  • ReadySetGo — measurement interval sampled per trial.

from braintools.cogtask import (
    HierarchicalReasoning, IntervalDiscrimination, ReadySetGo,
)

for cls in [HierarchicalReasoning, IntervalDiscrimination, ReadySetGo]:
    task = cls(seed=42)
    X, Y, M = task.batch_sample(4, return_mask=True)
    print(f'{cls.__name__:25s} T_max={task.max_trial_duration():5d}'
          f'  variable={task.is_variable_length}'
          f'  mask sums={M.sum(axis=0).tolist()}')
HierarchicalReasoning     T_max= 2000  variable=True  mask sums=[1649, 1489, 1537, 1862]
IntervalDiscrimination    T_max= 3100  variable=True  mask sums=[2604, 2634, 2646, 2778]
ReadySetGo                T_max= 2800  variable=True  mask sums=[2566, 2459, 2491, 2708]

6. Using the mask in losses and metrics

The standard recipe for a masked cross-entropy loss: compute the loss elementwise, multiply by the mask cast to float, then normalize by the mask’s sum rather than the buffer size.

def masked_cross_entropy(logits, labels, mask):
    # logits: (T, B, C)  labels: (T, B) int  mask: (T, B) bool
    # Use log-softmax for numerical stability.
    logp = jax.nn.log_softmax(logits, axis=-1)
    onehot = jax.nn.one_hot(labels, logp.shape[-1])
    nll = -jnp.sum(onehot * logp, axis=-1)           # (T, B)
    mask_f = mask.astype(nll.dtype)
    return jnp.sum(nll * mask_f) / jnp.maximum(jnp.sum(mask_f), 1.0)


# Toy demo: pretend a network produced uniform logits.
# Note: `task` now refers to the last task built in the loop above (ReadySetGo).
X, Y, M = task.batch_sample(4, return_mask=True)
logits = jnp.zeros(X.shape[:2] + (task.num_outputs,))
print('masked loss =', float(masked_cross_entropy(logits, Y, M)))
masked loss = 0.5730118751525879

The same pattern applies to vector-output tasks (use mse, a cos_sin-style population loss, etc., and gate by mask[..., None]) and to accuracy metrics ((preds == Y) * mask divided by mask.sum()).
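Both variants can be sketched in the same style as masked_cross_entropy. masked_accuracy and masked_mse are hypothetical helpers written for this illustration. The MSE normalizes by the number of live steps times the feature count, which is one reasonable convention among several.

```python
import jax.numpy as jnp

def masked_accuracy(logits, labels, mask):
    # logits: (T, B, C), labels: (T, B) int, mask: (T, B) bool
    preds = jnp.argmax(logits, axis=-1)
    mask_f = mask.astype(jnp.float32)
    correct = (preds == labels).astype(jnp.float32) * mask_f
    return jnp.sum(correct) / jnp.maximum(jnp.sum(mask_f), 1.0)

def masked_mse(pred, target, mask):
    # pred, target: (T, B, F), mask: (T, B) bool; mask[..., None] broadcasts over F.
    mask_f = mask.astype(pred.dtype)[..., None]
    sq_err = (pred - target) ** 2 * mask_f
    n_live = jnp.sum(mask_f) * pred.shape[-1]  # live steps x features
    return jnp.sum(sq_err) / jnp.maximum(n_live, 1.0)

# Two timesteps, one trial: the second step is padding and is ignored.
logits_demo = jnp.array([[[2.0, 0.0]], [[0.0, 2.0]]])  # (T=2, B=1, C=2)
labels_demo = jnp.array([[0], [0]])
mask_demo = jnp.array([[True], [False]])
print(float(masked_accuracy(logits_demo, labels_demo, mask_demo)))  # 1.0
```

Without the mask the second (wrong) prediction would count, and the accuracy would drop to 0.5.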

7. Duration samplers

TruncExp and UniformDuration are the canonical helpers for drawing a per-trial duration in trial_init. They are JIT/vmap-safe (they consume ctx.rng) and they advertise the static bounds the framework needs to size buffers:

  • sampler.is_variable: a class attribute set to True.

  • sampler.min_value(), sampler.max_value(): return the bounds as Quantities. VariableDuration.min_duration / max_duration should match these bounds or enclose them (min_duration <= min_value() and max_duration >= max_value()).
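For intuition about what a truncated-exponential sampler does, here is an inverse-CDF sketch in plain JAX. trunc_exp_sample is hypothetical, and it assumes mean is the mean of the untruncated exponential; whether braintools' TruncExp uses the same parametrization is not confirmed here. Every draw lands in [lo, hi], which is the property that lets a VariableDuration slot sized from max_value() hold any sample.

```python
import jax
import jax.numpy as jnp

def trunc_exp_sample(key, mean, lo, hi, shape=()):
    # Inverse-CDF draw from an exponential with the given mean, restricted to
    # [lo, hi]. mean, lo and hi are plain numbers in the same unit (e.g. ms).
    p = jax.random.uniform(key, shape)                 # p in [0, 1)
    a, b = jnp.exp(-lo / mean), jnp.exp(-hi / mean)    # survival function at lo, hi
    x = -mean * jnp.log(a - p * (a - b))               # p = 0 -> lo, p -> 1 -> hi
    return jnp.clip(x, lo, hi)                         # guard against float rounding

draws = trunc_exp_sample(jax.random.PRNGKey(0), 600.0, 300.0, 1500.0, shape=(1000,))
print(bool(draws.min() >= 300.0), bool(draws.max() <= 1500.0))  # True True
```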

te = TruncExp(mean=600 * u.ms, min_val=300 * u.ms, max_val=1500 * u.ms)
ud = UniformDuration(200 * u.ms, 800 * u.ms)

print('TruncExp        bounds:', te.min_value(), '-', te.max_value(),
      '   is_variable =', te.is_variable)
print('UniformDuration bounds:', ud.min_value(), '-', ud.max_value(),
      '   is_variable =', ud.is_variable)
TruncExp        bounds: 300 ms - 1500 ms    is_variable = True
UniformDuration bounds: 200 ms - 800 ms    is_variable = True
# Wire a sampler into trial_init -> VariableDuration:
delay_dist = TruncExp(mean=600 * u.ms, min_val=300 * u.ms, max_val=1500 * u.ms)

def init_with_sampler(ctx):
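    # Store the sampled Quantity (or its mantissa) in the ctx key the
    # phase reads. Either form works here; note that the ms mantissa equals
    # a count of dt units only because dt = 1 ms in this tutorial.
    ctx['delay_duration'] = delay_dist(ctx).to(u.ms).mantissa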
    ctx['gt']             = ctx.rng.choice(2).astype(jnp.int32) + 1

sampler_task = Task(
    phases=concat([
        Fixation(50 * u.ms, inputs={'fixation': 1.0}),
        VariableDuration(
            min_duration=delay_dist.min_value(),
            max_duration=delay_dist.max_value(),
            ctx_key='delay_duration',
            inputs={'fixation': 1.0},
        ),
        Response(50 * u.ms, outputs={'label': lambda c, f: c['gt']}),
    ]),
    input_features=fix + stim,
    output_features=fix + choice,
    trial_init=init_with_sampler,
    seed=0,
)

X, Y, M = sampler_task.batch_sample(8, return_mask=True)
print('T_max =', sampler_task.max_trial_duration(),
      '  per-trial lengths =', M.sum(axis=0).tolist())
T_max = 1600   per-trial lengths = [456, 449, 1020, 1194, 849, 1588, 635, 407]

Summary

  • Mark variable-length epochs with VariableDuration(min_duration, max_duration, ctx_key=...) and write the per-trial length into ctx[ctx_key] from trial_init.

  • The framework auto-detects variable-length trees and switches to a packed-buffer path. task.is_variable_length and task.max_trial_duration() describe the result.

  • Use batch_sample(B, return_mask=True) to get aligned (X, Y, mask) buffers under jit+vmap. The mask doubles as a per-step weight for losses and metrics.

  • If / Switch / While participate in the same buffers; conditional branches that didn’t run leave their slot at zero and mask=False.

  • TruncExp and UniformDuration are sampling helpers whose min_value() / max_value() line up with VariableDuration’s min_duration / max_duration.

For the full API surface — every phase type, encoder, label helper, duration sampler, and pre-built task — see the braintools.cogtask API reference.