Tutorial 1: Quickstart with braintools.cogtask
braintools.cogtask is a modular, composable framework for constructing
cognitive tasks for neural-network training and computational neuroscience
simulations. This tutorial walks through the smallest end-to-end usage:
- Importing the module and setting the time step.
- Sampling a single trial from a pre-built task.
- Sampling a batch and inspecting the resulting tensors.
- Understanding the data layout (X, Y, info).
1. Setup
cogtask resolves all durations against the currently active time step
(brainstate.environ.get_dt()). Set one before any sampling.
import brainunit as u
import brainstate
import jax
brainstate.environ.set(dt=1.0 * u.ms)
from braintools import cogtask
2. A single trial from a pre-built task
The pre-built tasks live under braintools.cogtask and are all subclasses of
Task. Each task constructor exposes the standard cognitive-paradigm
parameters (e.g. stimulus duration, number of choices, coherence levels) and
accepts a seed= for reproducibility.
task = cogtask.PerceptualDecisionMaking(
    t_stimulus=1500 * u.ms,
    num_choices=2,
    seed=0,
)
task
Task(name=PerceptualDecisionMaking, inputs=9, outputs=3, output_mode=categorical)
sample_trial(index) returns (X, Y, info) for a single trial:
- X: input tensor of shape (T, num_inputs)
- Y: target tensor of shape (T,) in categorical mode, or (T, num_outputs) in vector mode
- info: a dict with phase_history, trial_state, dt, and index
X, Y, info = task.sample_trial(0)
print('X.shape =', X.shape)
print('Y.shape =', Y.shape)
print('num_inputs =', task.num_inputs)
print('num_outputs =', task.num_outputs)
X.shape = (1700, 9)
Y.shape = (1700,)
num_inputs = 9
num_outputs = 3
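Some loss functions expect a one-hot target rather than integer class labels. The sketch below shows one way to expand a categorical Y of shape (T,) to shape (T, num_outputs) with NumPy. It uses a stand-in label array rather than real task output, and it is an assumption that a one-hot layout is what your loss needs; the layout the library itself produces in vector mode is not confirmed here.

```python
import numpy as np

T, num_outputs = 1700, 3            # shapes taken from the printed output above

# Stand-in for a categorical target: one integer class label per time step.
Y = np.zeros(T, dtype=np.int32)
Y[1600:] = 1                        # hypothetical labels for the last 100 steps

# Indexing an identity matrix with the labels yields one one-hot row per step.
Y_onehot = np.eye(num_outputs, dtype=np.float32)[Y]

print(Y.shape, Y_onehot.shape)      # (1700,) (1700, 3)
```

Taking the argmax over the last axis of Y_onehot recovers the original labels.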
for name, start, end in info['phase_history']:
    print(f'{name:<12s} [{start:4d} : {end:4d})')
fixation     [   0 :  100)
stimulus     [ 100 : 1600)
response     [1600 : 1700)
Sequence     [1600 : 1700)
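Because each phase_history entry is a half-open step range [start, end), it can be used directly to slice X or to build a time mask. The following is a minimal sketch using stand-in arrays with the shapes printed above. The phase list is copied by hand from the output (leaving out the Sequence entry), and dt_ms is assumed to be the 1 ms step set in the Setup section.

```python
import numpy as np

# Stand-ins with the shapes printed above; real values come from sample_trial.
X = np.zeros((1700, 9), dtype=np.float32)
phase_history = [
    ("fixation", 0, 100),
    ("stimulus", 100, 1600),
    ("response", 1600, 1700),
]
dt_ms = 1.0  # assumed: the time step set in the Setup section

segments = {}
for name, start, end in phase_history:
    segments[name] = X[start:end]   # half-open slice [start, end)
    n_steps = end - start
    print(f"{name:<10s} steps={n_steps:5d} duration={n_steps * dt_ms:.0f} ms")

# Boolean mask over time steps, e.g. to restrict a loss to the response phase.
response_mask = np.zeros(X.shape[0], dtype=bool)
_, r_start, r_end = phase_history[-1]
response_mask[r_start:r_end] = True
```

A mask like response_mask is handy when only the decision period should contribute to the training loss.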
# trial_state captures whatever trial_init wrote into the context.
info['trial_state']
{'trial_index': 0,
'ground_truth': Array(0, dtype=int32),
'coherence': Array(25.6, dtype=float32),
'stimulus_direction': Array(0., dtype=float32),
'output_mode': 'categorical'}
3. Batched sampling for training loops
batch_sample(B) is the JIT/vmap-compiled entry point used inside a
training loop. By default it returns tensors with the time axis first
(time_first=True), matching common RNN conventions. In categorical mode,
X then has shape (T, B, num_inputs) and Y has shape (T, B):
X, Y = task.batch_sample(32)
print('X.shape =', X.shape)
print('Y.shape =', Y.shape)
X.shape = (1700, 32, 9)
Y.shape = (1700, 32)
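If downstream code expects the batch axis first, a time-first batch can also be rearranged after sampling. The sketch below uses NumPy stand-ins with the shapes printed above; it illustrates the two layouts and makes no claim about how the library's own time_first option is implemented.

```python
import numpy as np

T, B, F = 1700, 32, 9
X_tbf = np.zeros((T, B, F), dtype=np.float32)   # stand-in for a time-first X
Y_tb = np.zeros((T, B), dtype=np.int32)         # stand-in for a time-first Y

# Swap the time and batch axes; NumPy returns a view, so no data is copied.
X_btf = np.swapaxes(X_tbf, 0, 1)
Y_bt = np.swapaxes(Y_tb, 0, 1)

print(X_btf.shape, Y_bt.shape)   # (32, 1700, 9) (32, 1700)
```

With the time axis first, X_tbf[t] is the whole batch at step t, which is the slice a recurrent update consumes at each step.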
Reproducible batches
When a Task is constructed with seed=N, trial i of a batch uses
jax.random.fold_in(PRNGKey(N), start_index + i) as its PRNG key, so a
trial's key depends only on its global index. Two calls with the same
start_index and batch size produce bitwise-identical batches; two calls
whose index ranges [start_index, start_index + B) do not overlap use
disjoint sets of per-trial keys.
import numpy as np
X1, Y1 = task.batch_sample(8, start_index=0)
X2, Y2 = task.batch_sample(8, start_index=0)
X3, Y3 = task.batch_sample(8, start_index=8)
print('same start_index identical? ', np.array_equal(X1, X2))
print('next start_index differs? ', not np.array_equal(X1, X3))
same start_index identical? True
next start_index differs? True
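The idea of keying each trial on its global index can be illustrated without the library. The sketch below is a NumPy analogy, not JAX's fold_in and not the library's sampler: sample_batch is a hypothetical stand-in that seeds one generator per trial from the pair (seed, start_index + i).

```python
import numpy as np

def sample_batch(seed, batch_size, start_index=0, num_features=3):
    """Hypothetical stand-in for a batched sampler: one row per trial."""
    rows = []
    for i in range(batch_size):
        # Each trial's generator depends only on (seed, global trial index).
        rng = np.random.default_rng([seed, start_index + i])
        rows.append(rng.standard_normal(num_features))
    return np.stack(rows)

a = sample_batch(seed=0, batch_size=8, start_index=0)
b = sample_batch(seed=0, batch_size=8, start_index=0)
c = sample_batch(seed=0, batch_size=8, start_index=8)
big = sample_batch(seed=0, batch_size=16, start_index=0)

print(np.array_equal(a, b))         # True: same indices, same trials
print(np.array_equal(big[:8], a))   # True: trials 0..7
print(np.array_equal(big[8:], c))   # True: trials 8..15
```

Under this scheme a batch is just a window onto one fixed stream of trials, so changing the batch size or resuming from a later start_index does not change which trial sits at a given index.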
Streaming batches through training
A typical training loop just walks start_index forward by batch_size at
every step; the underlying vmap and JIT compilation are handled for you.
batch_size = 16
num_steps = 4
for step in range(num_steps):
    X, Y = task.batch_sample(batch_size, start_index=step * batch_size)
    # train_step(model, X, Y)  # plug into your training loop
    print(f'step {step:>2d} X.shape={X.shape} Y.shape={Y.shape}')
step 0 X.shape=(1700, 16, 9) Y.shape=(1700, 16)
step 1 X.shape=(1700, 16, 9) Y.shape=(1700, 16)
step 2 X.shape=(1700, 16, 9) Y.shape=(1700, 16)
step 3 X.shape=(1700, 16, 9) Y.shape=(1700, 16)
4. Switching dt at sampling time
The same task can be re-sampled at a finer or coarser time step simply by
wrapping the sampling call in a brainstate.environ.context. Trial duration
in real time stays fixed; only the number of timesteps T changes.
with brainstate.environ.context(dt=0.5 * u.ms):
    X_fine, _ = task.batch_sample(4)
with brainstate.environ.context(dt=2.0 * u.ms):
    X_coarse, _ = task.batch_sample(4)
print('dt=0.5 ms -> T =', X_fine.shape[0])
print('dt=2.0 ms -> T =', X_coarse.shape[0])
dt=0.5 ms -> T = 3400
dt=2.0 ms -> T = 850
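The relationship between dt and T is plain arithmetic: each phase lasts a fixed time in milliseconds, and the step count is that duration divided by dt. The sketch below reproduces the numbers in pure Python. The 100 ms fixation and response durations are read off the phase history printed earlier (100 steps at dt = 1 ms), and rounding to the nearest step is an assumption; the library's exact rounding rule is not confirmed here.

```python
PHASES_MS = {"fixation": 100.0, "stimulus": 1500.0, "response": 100.0}

def steps(duration_ms: float, dt_ms: float) -> int:
    # Assumption: durations are rounded to the nearest whole step.
    return int(round(duration_ms / dt_ms))

def total_steps(dt_ms: float) -> int:
    # Discretize each phase separately, then add up the step counts.
    return sum(steps(d, dt_ms) for d in PHASES_MS.values())

for dt_ms in (0.5, 1.0, 2.0):
    print(f"dt={dt_ms} ms -> T = {total_steps(dt_ms)}")
# dt=0.5 ms -> T = 3400
# dt=1.0 ms -> T = 1700
# dt=2.0 ms -> T = 850
```

Halving dt doubles T and doubling dt halves it, while the trial still spans 1700 ms of simulated time.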
5. Where to next
- Tutorial 2, Building custom tasks: phase composition, features, encoders, label helpers, and class-based Task subclasses.
- Tutorial 3, Variable-length trial sequences: the current limits of batch_sample on heterogeneous-length trials, today's workarounds, and the planned padding-plus-mask API.
- API reference: see braintools.cogtask in the API Reference section for the full list of pre-built tasks, encoders, phases, and utilities.