Tutorial 1: Quickstart with braintools.cogtask

braintools.cogtask is a modular, composable framework for constructing cognitive tasks for neural-network training and computational neuroscience simulations. This tutorial walks through the smallest end-to-end usage:

  1. Importing the module and setting the time step.

  2. Sampling a single trial from a pre-built task.

  3. Sampling a batch and inspecting the resulting tensors.

  4. Understanding the data layout (X, Y, info).

  5. Re-sampling the same task at a different time step.


1. Setup

cogtask resolves all durations against the currently active time step (brainstate.environ.get_dt()). Set one before any sampling.

import brainunit as u
import brainstate
import jax

brainstate.environ.set(dt=1.0 * u.ms)

from braintools import cogtask

2. A single trial from a pre-built task

The pre-built tasks live under braintools.cogtask and are all subclasses of Task. Each task constructor exposes the standard parameters of its cognitive paradigm (e.g. stimulus duration, number of choices, coherence levels) and accepts a seed= argument for reproducibility.

task = cogtask.PerceptualDecisionMaking(
    t_stimulus=1500 * u.ms,
    num_choices=2,
    seed=0,
)
task
Task(name=PerceptualDecisionMaking, inputs=9, outputs=3, output_mode=categorical)

sample_trial(index) returns (X, Y, info) for a single trial:

  • X: input tensor of shape (T, num_inputs)

  • Y: target tensor of shape (T,) in categorical mode, or (T, num_outputs) in vector mode

  • info: a dict with phase_history, trial_state, dt, and index

X, Y, info = task.sample_trial(0)
print('X.shape =', X.shape)
print('Y.shape =', Y.shape)
print('num_inputs =', task.num_inputs)
print('num_outputs =', task.num_outputs)
X.shape = (1700, 9)
Y.shape = (1700,)
num_inputs = 9
num_outputs = 3
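Some loss functions expect a dense target rather than integer labels. The sketch below expands a stand-in categorical label array of shape (T,) to shape (T, num_outputs) with plain NumPy one-hot encoding. The helper one_hot and the array Y_demo are hypothetical, and whether this matches what cogtask emits in vector mode is an assumption that this tutorial does not confirm.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Expand integer labels of shape (T,) to a (T, num_classes) float array."""
    labels = np.asarray(labels)
    # Row k of the identity matrix is the one-hot vector for class k.
    return np.eye(num_classes, dtype=np.float32)[labels]

# Stand-in for a categorical Y: 4 time steps, classes in {0, 1, 2}.
Y_demo = np.array([0, 0, 2, 1])
Y_dense = one_hot(Y_demo, num_classes=3)
print(Y_dense.shape)  # (4, 3)
```

Taking argmax along the last axis recovers the original integer labels, so the two layouts carry the same information.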
for name, start, end in info['phase_history']:
    print(f'{name:<12s}  [{start:4d} : {end:4d})')
fixation      [   0 :  100)
stimulus      [ 100 : 1600)
response      [1600 : 1700)
Sequence      [1600 : 1700)
# trial_state captures whatever trial_init wrote into the context.
info['trial_state']
{'trial_index': 0,
 'ground_truth': Array(0, dtype=int32),
 'coherence': Array(25.6, dtype=float32),
 'stimulus_direction': Array(0., dtype=float32),
 'output_mode': 'categorical'}
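Because each phase_history entry is a (name, start, end) triple with a half-open step window, it can be used directly to cut a trial into per-phase segments. The following is a minimal sketch on stand-in data: the phase windows are copied from the printed output above, X_demo is a zero-filled placeholder for a real X, and slice_phase is a hypothetical helper, not part of cogtask.

```python
import numpy as np

# Stand-in values copied from the printed phase history (dt = 1 ms).
phase_history = [
    ("fixation", 0, 100),
    ("stimulus", 100, 1600),
    ("response", 1600, 1700),
]
X_demo = np.zeros((1700, 9), dtype=np.float32)  # placeholder for a real X

def slice_phase(x, history, name):
    """Return the rows of x inside the half-open window of phase `name`."""
    for phase_name, start, end in history:
        if phase_name == name:
            return x[start:end]
    raise KeyError(f"no phase named {name!r}")

stim = slice_phase(X_demo, phase_history, "stimulus")
print(stim.shape)  # (1500, 9)
```

The same slicing applies to Y, which is useful when a loss or accuracy metric should only be evaluated during the response window.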

3. Batched sampling for training loops

batch_sample(B) is the JIT/vmap-compiled entry point used inside a training loop. By default it returns tensors with the time axis first (time_first=True), matching common RNN conventions:

call                                     X shape               Y shape
task.sample_trial(i)                     (T, num_inputs)       (T,) or (T, num_outputs)
task.batch_sample(B)                     (T, B, num_inputs)    (T, B) or (T, B, num_outputs)
task.batch_sample(B, time_first=False)   (B, T, num_inputs)    (B, T) or (B, T, num_outputs)
task.batch_sample(B, return_meta=True)   adds a third meta value

X, Y = task.batch_sample(32)
print('X.shape =', X.shape)
print('Y.shape =', Y.shape)
X.shape = (1700, 32, 9)
Y.shape = (1700, 32)
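If a batch has already been sampled time-first and a downstream model expects batch-first input, the two layouts differ only by a swap of the first two axes. The sketch below shows this on stand-in NumPy arrays with the shapes printed above. It illustrates the layout relationship only and makes no claim about how batch_sample implements time_first=False internally.

```python
import numpy as np

T, B, N = 1700, 32, 9
# Stand-ins for a time-first batch; arange makes every element distinct.
X_tb = np.arange(T * B * N, dtype=np.float32).reshape(T, B, N)
Y_tb = np.arange(T * B, dtype=np.int32).reshape(T, B)

# Swap the time and batch axes; the feature axis is left in place.
X_bt = np.swapaxes(X_tb, 0, 1)
Y_bt = np.swapaxes(Y_tb, 0, 1)
print(X_bt.shape, Y_bt.shape)  # (32, 1700, 9) (32, 1700)
```

np.swapaxes returns a view, so no data is copied until the array is written to or made contiguous.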

Reproducible batches

When a task is constructed with seed=N, each trial in a batch uses jax.random.fold_in(PRNGKey(N), start_index + i) as its per-trial PRNG key, where i is the trial's position in the batch. Two calls with the same start_index and batch size produce bitwise identical batches; two calls whose index ranges [start_index, start_index + B) do not overlap share no per-trial keys, so they contain different trials.

import numpy as np

X1, Y1 = task.batch_sample(8, start_index=0)
X2, Y2 = task.batch_sample(8, start_index=0)
X3, Y3 = task.batch_sample(8, start_index=8)

print('same start_index identical?  ', np.array_equal(X1, X2))
print('next start_index differs?    ', not np.array_equal(X1, X3))
same start_index identical?   True
next start_index differs?     True
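The key-derivation rule described above can be reproduced with public JAX calls alone. The sketch below is an illustration, not cogtask's actual code: trial_keys and demo_batch are hypothetical helpers, and the "trial" is just three random integers drawn from each key. It shows that a trial's content depends only on the seed and its global index, not on which batch it lands in.

```python
import jax
import jax.numpy as jnp
import numpy as np

def trial_keys(seed, start_index, batch_size):
    """One PRNG key per trial, derived from (seed, global trial index)."""
    base = jax.random.PRNGKey(seed)
    indices = start_index + jnp.arange(batch_size)
    return jax.vmap(lambda i: jax.random.fold_in(base, i))(indices)

def demo_batch(seed, start_index, batch_size):
    """Hypothetical stand-in for a sampler: one small random draw per key."""
    keys = trial_keys(seed, start_index, batch_size)
    draw = lambda k: jax.random.randint(k, (3,), 0, 1_000_000)
    return np.asarray(jax.vmap(draw)(keys))

a = demo_batch(0, start_index=0, batch_size=8)
b = demo_batch(0, start_index=0, batch_size=8)
c = demo_batch(0, start_index=4, batch_size=8)

# Same seed and index range: the two batches match element for element.
print(np.array_equal(a, b))
# Global indices 4..7 appear in both a and c and yield the same draws.
print(np.array_equal(a[4:], c[:4]))
```

The second check is why partially overlapping index ranges repeat trials: advancing start_index by less than the batch size re-samples the shared indices.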

Streaming batches through training

A typical training loop just walks start_index forward by batch_size at every step; the underlying vmap and JIT compilation are handled for you.

batch_size = 16
num_steps = 4
for step in range(num_steps):
    X, Y = task.batch_sample(batch_size, start_index=step * batch_size)
    # train_step(model, X, Y)   # plug into your training loop
    print(f'step {step:>2d}  X.shape={X.shape}  Y.shape={Y.shape}')
step  0  X.shape=(1700, 16, 9)  Y.shape=(1700, 16)
step  1  X.shape=(1700, 16, 9)  Y.shape=(1700, 16)
step  2  X.shape=(1700, 16, 9)  Y.shape=(1700, 16)
step  3  X.shape=(1700, 16, 9)  Y.shape=(1700, 16)
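Advancing start_index by exactly batch_size means each step consumes a fresh, contiguous block of trial indices. The bookkeeping can be checked in plain Python; batch_index_ranges is a hypothetical helper written for this illustration.

```python
def batch_index_ranges(batch_size, num_steps):
    """Trial-index range consumed at each step when start_index = step * batch_size."""
    return [
        range(step * batch_size, (step + 1) * batch_size)
        for step in range(num_steps)
    ]

ranges = batch_index_ranges(batch_size=16, num_steps=4)
seen = [i for r in ranges for i in r]
print(seen == list(range(64)))  # True: indices 0..63, each used exactly once
```

To resume training after a checkpoint, it is enough to store the step counter, since the index ranges are a pure function of step and batch_size.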

4. Switching dt at sampling time

The same task can be re-sampled at a finer or coarser time step simply by wrapping the sampling call in a brainstate.environ.context. Trial duration in real time stays fixed; only the number of timesteps T changes.

with brainstate.environ.context(dt=0.5 * u.ms):
    X_fine, _ = task.batch_sample(4)

with brainstate.environ.context(dt=2.0 * u.ms):
    X_coarse, _ = task.batch_sample(4)

print('dt=0.5 ms  ->  T =', X_fine.shape[0])
print('dt=2.0 ms  ->  T =', X_coarse.shape[0])
dt=0.5 ms  ->  T = 3400
dt=2.0 ms  ->  T = 850
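The relation between dt and T is plain arithmetic: the trial lasts 100 + 1500 + 100 = 1700 ms, and each phase contributes duration / dt steps. The sketch below works this out for three step sizes. The phase durations are read off the dt = 1 ms phase history printed earlier, and rounding to the nearest integer is an assumption about how fractional step counts would be handled.

```python
# Phase durations in ms, read off the dt = 1 ms phase history.
PHASES_MS = {"fixation": 100.0, "stimulus": 1500.0, "response": 100.0}

def steps_per_phase(dt_ms):
    """Steps per phase, assuming steps = round(duration / dt)."""
    return {name: int(round(dur / dt_ms)) for name, dur in PHASES_MS.items()}

for dt_ms in (0.5, 1.0, 2.0):
    total = sum(steps_per_phase(dt_ms).values())
    print(f"dt={dt_ms} ms -> T = {total}")
# dt=0.5 ms -> T = 3400
# dt=1.0 ms -> T = 1700
# dt=2.0 ms -> T = 850
```

Halving dt doubles T and doubling dt halves it, while the 1700 ms trial duration stays the same.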

5. Where to next

  • Tutorial 2 — Building custom tasks: phase composition, features, encoders, label helpers, and class-based Task subclasses.

  • Tutorial 3 — Variable-length trial sequences: the current limits of batch_sample on heterogeneous-length trials, today’s workarounds, and the planned padding-plus-mask API.

  • API reference: see braintools.cogtask in the API Reference section for the full list of pre-built tasks, encoders, phases, and utilities.