{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f8e9f285",
   "metadata": {},
   "source": [
    "# Tutorial 3: Variable-length trial sequences\n",
    "\n",
    "Real cognitive experiments rarely use a single fixed timeline. Delay\n",
    "periods are jittered to discourage timing strategies; decisions are made\n",
    "when evidence reaches a threshold; trials branch on a cue. This tutorial\n",
    "shows how `braintools.cogtask` supports such trials end-to-end under\n",
    "`jit` and `vmap`.\n",
    "\n",
    "The framework uses a **packed-buffer + mask** design:\n",
    "\n",
    "- Every trial in a batch is written into a buffer sized to the worst-case\n",
    "  length, `task.max_trial_duration()`.\n",
    "- Each trial reports its actual length via a boolean `mask` of shape\n",
    "  `(T_max,)`. `True` marks live timesteps, `False` marks padding.\n",
    "- Phases write only into their valid slice; trailing positions stay at\n",
    "  zero in `X`/`Y` and `False` in `mask`.\n",
    "- Buffer shapes are static Python ints, so `brainstate.transform.jit` and\n",
    "  `vmap2` do not retrace when per-trial lengths change between batches.\n",
    "\n",
    "This tutorial covers:\n",
    "\n",
    "1. The `VariableDuration` phase and how it's used in `trial_init`.\n",
    "2. Detecting variable-length tasks (`task.is_variable_length`,\n",
    "   `phase_tree_is_variable`, `task.max_trial_duration()`).\n",
    "3. Sampling with masks (`sample_trial`, `batch_sample(return_mask=True)`).\n",
    "4. Conditional control flow (`If`, `Switch`, `While`) under packed mode.\n",
    "5. The migrated built-in tasks (`HierarchicalReasoning`,\n",
    "   `IntervalDiscrimination`, `ReadySetGo`).\n",
    "6. Consuming the mask in losses and metrics.\n",
    "7. Duration samplers (`TruncExp`, `UniformDuration`) and their bounds.\n",
    "\n",
    "---\n"
   ]
  },
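  {
   "cell_type": "markdown",
   "id": "3f7a2c91",
   "metadata": {},
   "source": [
    "The packed-buffer idea can be illustrated in a few lines of plain JAX. The\n",
    "sketch below is *not* the `braintools.cogtask` implementation. It assumes a\n",
    "batch of made-up per-trial lengths and shows how a time-first `(T_max, B)`\n",
    "mask and a zero-padded buffer follow from them.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b14d0e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "# Hypothetical per-trial lengths (in steps) for a batch of 4 trials.\n",
    "toy_lengths = jnp.array([5, 8, 3, 6])\n",
    "toy_T_max = 8  # static worst case: a Python int, so shapes never change\n",
    "\n",
    "# mask[t, b] is True while t < lengths[b]; later steps are padding.\n",
    "toy_mask = jnp.arange(toy_T_max)[:, None] < toy_lengths[None, :]  # (T_max, B)\n",
    "\n",
    "# Packed buffer: live steps carry data (all ones here), padding stays zero.\n",
    "toy_X = jnp.where(toy_mask[..., None], 1.0, 0.0)  # (T_max, B, 1)\n",
    "\n",
    "print(toy_mask.astype(jnp.int32))\n",
    "print('lengths recovered from the mask:', toy_mask.sum(axis=0))"
   ]
  },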
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e83eace2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T12:47:57.524063Z",
     "iopub.status.busy": "2026-05-21T12:47:57.523739Z",
     "iopub.status.idle": "2026-05-21T12:48:01.052779Z",
     "shell.execute_reply": "2026-05-21T12:48:01.051799Z"
    }
   },
   "outputs": [],
   "source": [
    "import brainstate\n",
    "import brainunit as u\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "brainstate.environ.set(dt=1.0 * u.ms)\n",
    "\n",
    "from braintools.cogtask import (\n",
    "    Task, Feature, concat,\n",
    "    Fixation, Delay, Response, Stimulus, Sample, Test,\n",
    "    VariableDuration,\n",
    "    If, Switch, While,\n",
    "    TruncExp, UniformDuration,\n",
    "    phase_tree_is_variable,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be27436e",
   "metadata": {},
   "source": [
    "## 1. The `VariableDuration` phase\n",
    "\n",
    "`VariableDuration` is the declarative primitive for \"this phase lasts a\n",
    "trial-dependent number of steps.\" It looks like any other declarative\n",
    "phase (it takes `inputs=`, `outputs=`, `noise=`), but its duration is\n",
    "read from the trial context at sample time:\n",
    "\n",
    "- `min_duration` / `max_duration` are `brainunit` quantities. They bound\n",
    "  the phase: `min_duration` floors the step count and `max_duration` is\n",
    "  the **static upper bound** used to size the buffer slot.\n",
    "- `ctx_key` names the trial-state entry holding the actual duration for\n",
    "  this trial. `trial_init` writes a scalar (a float in `dt` units or a\n",
    "  Quantity) into `ctx[ctx_key]`.\n",
    "\n",
    "The framework converts `ctx[ctx_key]` to a step count, clips it into\n",
    "`[1, ceil(max_duration / dt)]`, and writes only that many timesteps.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ba87acfd",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T12:48:01.055031Z",
     "iopub.status.busy": "2026-05-21T12:48:01.054676Z",
     "iopub.status.idle": "2026-05-21T12:48:01.060457Z",
     "shell.execute_reply": "2026-05-21T12:48:01.059602Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeclarativePhase('variable_delay', inputs=['fixation'])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "variable_delay = VariableDuration(\n",
    "    min_duration=200 * u.ms,\n",
    "    max_duration=1500 * u.ms,\n",
    "    ctx_key='delay_duration',\n",
    "    inputs={'fixation': 1.0},\n",
    "    outputs={'label': 0},\n",
    "    name='variable_delay',\n",
    ")\n",
    "variable_delay"
   ]
  },
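  {
   "cell_type": "markdown",
   "id": "c52e7a10",
   "metadata": {},
   "source": [
    "The conversion from `ctx[ctx_key]` to a step count can be sketched in plain\n",
    "JAX. This is an illustration under stated assumptions, not the library's\n",
    "code: the helper `duration_to_steps` is hypothetical, durations are plain\n",
    "floats in milliseconds, and rounding to the nearest step is assumed (the\n",
    "text above only fixes the `[1, ceil(max_duration / dt)]` clip). The upper\n",
    "bound is a static Python int, while the per-trial count stays a traced\n",
    "value that works under `vmap`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19d4f6b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def duration_to_steps(duration_ms, dt_ms, max_duration_ms):\n",
    "    # Hypothetical helper mirroring the clip described above.\n",
    "    max_steps = math.ceil(max_duration_ms / dt_ms)  # static Python int\n",
    "    # Assumption: round the per-trial duration to the nearest step.\n",
    "    steps = jnp.round(duration_ms / dt_ms).astype(jnp.int32)\n",
    "    return jnp.clip(steps, 1, max_steps)\n",
    "\n",
    "\n",
    "# Per-trial durations in ms; values outside the bounds get clipped.\n",
    "toy_durations = jnp.array([0.2, 437.6, 1499.5, 2000.0])\n",
    "toy_steps = jax.vmap(lambda d: duration_to_steps(d, 1.0, 1500.0))(toy_durations)\n",
    "print(toy_steps)"
   ]
  },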
  {
   "cell_type": "markdown",
   "id": "066e2d26",
   "metadata": {},
   "source": [
    "### A minimal variable-delay task\n",
    "\n",
    "A minimal delayed-response task (fixation, sample, variable delay,\n",
    "response) whose delay is drawn uniformly from 200 to 1500 ms on each trial:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3c0219bf",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T12:48:01.062389Z",
     "iopub.status.busy": "2026-05-21T12:48:01.062121Z",
     "iopub.status.idle": "2026-05-21T12:48:01.164625Z",
     "shell.execute_reply": "2026-05-21T12:48:01.163513Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Task(name=Task, inputs=3, outputs=3, output_mode=categorical)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fix    = Feature(1, 'fixation')\n",
    "stim   = Feature(2, 'stim')\n",
    "choice = Feature(2, 'choice')\n",
    "\n",
    "phases = concat([\n",
    "    Fixation(50 * u.ms, inputs={'fixation': 1.0}, outputs={'label': 0}),\n",
    "    Sample(40 * u.ms,\n",
    "           inputs={'stim': lambda c, f: jnp.ones(f.num)},\n",
    "           outputs={'label': 0}),\n",
    "    VariableDuration(\n",
    "        min_duration=200 * u.ms,\n",
    "        max_duration=1500 * u.ms,\n",
    "        ctx_key='delay_duration',\n",
    "        inputs={'fixation': 1.0},\n",
    "        outputs={'label': 0},\n",
    "        name='delay',\n",
    "    ),\n",
    "    Response(50 * u.ms,\n",
    "             outputs={'label': lambda c, f: c['gt']}),\n",
    "])\n",
    "\n",
    "def init(ctx):\n",
    "    # Both values are JAX scalars — trial_init runs under jit/vmap.\n",
    "    ctx['delay_duration'] = ctx.rng.uniform(200.0, 1500.0)\n",
    "    ctx['gt']             = ctx.rng.choice(2).astype(jnp.int32) + 1\n",
    "\n",
    "task = Task(\n",
    "    phases=phases,\n",
    "    input_features=fix + stim,\n",
    "    output_features=fix + choice,\n",
    "    trial_init=init,\n",
    "    seed=0,\n",
    ")\n",
    "task"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b37b178f",
   "metadata": {},
   "source": [
    "## 2. Detecting variable-length tasks\n",
    "\n",
    "`Task` walks the phase tree once at construction time and records\n",
    "whether *any* phase advertises `is_variable = True`. Three introspection\n",
    "helpers describe the resulting buffers:\n",
    "\n",
    "- `task.is_variable_length` — `True` if the task uses the packed path.\n",
    "- `task.max_trial_duration()` — Python `int`, the worst-case timestep\n",
    "  count. This is the static `T` used by `sample_trial` and\n",
    "  `batch_sample`. Safe as a buffer dimension under `jit`/`vmap`.\n",
    "- `phase_tree_is_variable(phases)` — module-level helper that walks any\n",
    "  phase subtree; useful when composing trees outside a `Task`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "af64ab51",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T12:48:01.166437Z",
     "iopub.status.busy": "2026-05-21T12:48:01.166198Z",
     "iopub.status.idle": "2026-05-21T12:48:01.171169Z",
     "shell.execute_reply": "2026-05-21T12:48:01.170158Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "is_variable_length    : True\n",
      "max_trial_duration()  : 1640 steps at dt=1ms\n",
      "phase_tree_is_variable: True\n"
     ]
    }
   ],
   "source": [
    "print('is_variable_length    :', task.is_variable_length)\n",
    "print('max_trial_duration()  :', task.max_trial_duration(), 'steps at dt=1ms')\n",
    "# 50 + 40 + 1500 + 50 = 1640\n",
    "\n",
    "# The same detection at the phase-tree level:\n",
    "print('phase_tree_is_variable:', phase_tree_is_variable(phases))"
   ]
  },
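  {
   "cell_type": "markdown",
   "id": "6a0e3d58",
   "metadata": {},
   "source": [
    "How a worst-case length such as 1640 can be derived is sketched below with\n",
    "a toy phase tree. The tuple node format and the `worst_case_steps` helper\n",
    "are hypothetical and only illustrate the sizing rules stated in this\n",
    "tutorial: sequential phases add up, a variable phase contributes\n",
    "`ceil(max_duration / dt)`, `If` takes the longer branch, and `While`\n",
    "allocates `body * max_iterations`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2b97c41",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "# Hypothetical node format (not the braintools API):\n",
    "#   ('fixed', ms) | ('variable', max_ms) | ('seq', [children])\n",
    "#   ('if', then_node, else_node) | ('while', body_node, max_iterations)\n",
    "def worst_case_steps(node, dt_ms=1.0):\n",
    "    kind = node[0]\n",
    "    if kind in ('fixed', 'variable'):\n",
    "        return math.ceil(node[1] / dt_ms)\n",
    "    if kind == 'seq':\n",
    "        return sum(worst_case_steps(child, dt_ms) for child in node[1])\n",
    "    if kind == 'if':\n",
    "        return max(worst_case_steps(node[1], dt_ms), worst_case_steps(node[2], dt_ms))\n",
    "    if kind == 'while':\n",
    "        return worst_case_steps(node[1], dt_ms) * node[2]\n",
    "    raise ValueError(f'unknown node kind: {kind!r}')\n",
    "\n",
    "\n",
    "toy_tree = ('seq', [('fixed', 50), ('fixed', 40), ('variable', 1500), ('fixed', 50)])\n",
    "print(worst_case_steps(toy_tree))  # 1640, the same worst case as above\n",
    "\n",
    "toy_loop = ('while', ('seq', [('fixed', 30), ('variable', 70)]), 3)\n",
    "print(worst_case_steps(toy_loop))  # (30 + 70) * 3 = 300"
   ]
  },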
  {
   "cell_type": "markdown",
   "id": "de452399",
   "metadata": {},
   "source": [
    "## 3. Sampling with masks\n",
    "\n",
    "For variable-length tasks, `sample_trial` returns the usual `(X, Y, info)`\n",
    "triple, but `info['mask']` is now a `(T_max,)` boolean array. For\n",
    "fixed-length tasks `info['mask']` is `None`.\n",
    "\n",
    "`batch_sample(B, return_mask=True)` is the JIT/vmap path: it returns\n",
    "`(X, Y, mask)` with `mask` in the same time/batch layout as `X` and `Y`\n",
    "(default time-first → `(T_max, B)`; pass `time_first=False` for\n",
    "`(B, T_max)`).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1d6052af",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T12:48:01.174032Z",
     "iopub.status.busy": "2026-05-21T12:48:01.173777Z",
     "iopub.status.idle": "2026-05-21T12:48:03.905191Z",
     "shell.execute_reply": "2026-05-21T12:48:03.904412Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X.shape    = (1640, 3)   (T_max, num_inputs)\n",
      "Y.shape    = (1640,)   (T_max,)\n",
      "mask.shape = (1640,)  mask.dtype = bool\n",
      "valid steps for trial 0 = 475\n",
      "full T_max              = 1640\n"
     ]
    }
   ],
   "source": [
    "X, Y, info = task.sample_trial(0)\n",
    "mask = info['mask']\n",
    "print('X.shape    =', X.shape, '  (T_max, num_inputs)')\n",
    "print('Y.shape    =', Y.shape, '  (T_max,)')\n",
    "print('mask.shape =', mask.shape, ' mask.dtype =', mask.dtype)\n",
    "print('valid steps for trial 0 =', int(jnp.sum(mask)))\n",
    "print('full T_max              =', X.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bbb53a62",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T12:48:03.907429Z",
     "iopub.status.busy": "2026-05-21T12:48:03.907103Z",
     "iopub.status.idle": "2026-05-21T12:48:04.563097Z",
     "shell.execute_reply": "2026-05-21T12:48:04.562190Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X.shape    = (1640, 8, 3)   (T_max, B, num_inputs)\n",
      "Y.shape    = (1640, 8)   (T_max, B)\n",
      "mask.shape = (1640, 8)  (T_max, B)\n",
      "per-trial valid lengths = [ 475  458 1309 1443 1132 1636  827  358]\n"
     ]
    }
   ],
   "source": [
    "# JIT + vmap path: stack 8 trials with heterogeneous delay lengths.\n",
    "X, Y, mask = task.batch_sample(8, return_mask=True)\n",
    "print('X.shape    =', X.shape, '  (T_max, B, num_inputs)')\n",
    "print('Y.shape    =', Y.shape, '  (T_max, B)')\n",
    "print('mask.shape =', mask.shape, ' (T_max, B)')\n",
    "\n",
    "# Each column of mask records that trial's length:\n",
    "lengths = mask.sum(axis=0)\n",
    "print('per-trial valid lengths =', lengths)"
   ]
  },
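  {
   "cell_type": "markdown",
   "id": "5d81af27",
   "metadata": {},
   "source": [
    "Outside `jit`, where ragged shapes are fine, the mask can cut each trial\n",
    "back to its true length (for plotting or inspection). Because packing\n",
    "leaves the live steps as a contiguous prefix, slicing the first `length`\n",
    "steps is enough. A minimal sketch on synthetic arrays in the time-first\n",
    "`(T_max, B, F)` layout:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c6e94b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "toy_T_max, toy_F = 6, 2\n",
    "toy_mask = jnp.arange(toy_T_max)[:, None] < jnp.array([4, 6, 2])[None, :]\n",
    "toy_X = jnp.tile(toy_mask[..., None], (1, 1, toy_F)).astype(jnp.float32)\n",
    "\n",
    "# Python-level loop: the results have different lengths, so keep this out of jit.\n",
    "toy_lengths = toy_mask.sum(axis=0)\n",
    "toy_trials = [toy_X[: int(n), b] for b, n in enumerate(toy_lengths)]\n",
    "print([trial.shape for trial in toy_trials])  # [(4, 2), (6, 2), (2, 2)]"
   ]
  },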
  {
   "cell_type": "markdown",
   "id": "ac81de1b",
   "metadata": {},
   "source": [
    "### Shape contract\n",
    "\n",
    "| Call                                            | Returns                | Shapes                                                              |\n",
    "|-------------------------------------------------|------------------------|---------------------------------------------------------------------|\n",
    "| `task.sample_trial(i)`                          | `(X, Y, info)`         | `X: (T_max, F)`, `Y: (T_max,)` or `(T_max, F_out)`, `info['mask']: (T_max,)` bool |\n",
    "| `task.batch_sample(B)`                          | `(X, Y)`               | `X: (T_max, B, F)`, `Y: (T_max, B)` or `(T_max, B, F_out)`         |\n",
    "| `task.batch_sample(B, return_mask=True)`        | `(X, Y, mask)`         | as above, plus `mask: (T_max, B)` bool                              |\n",
    "| `task.batch_sample(B, time_first=False, return_mask=True)` | `(X, Y, mask)` | `X: (B, T_max, F)`, `Y: (B, T_max)`, `mask: (B, T_max)` |\n",
    "\n",
    "`return_mask=True` is supported on **fixed-length** tasks too — the\n",
    "returned mask is simply all-`True`. That means a downstream training\n",
    "loop can call `batch_sample(..., return_mask=True)` unconditionally\n",
    "without branching on `task.is_variable_length`.\n"
   ]
  },
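  {
   "cell_type": "markdown",
   "id": "7f2b10d9",
   "metadata": {},
   "source": [
    "Since `time_first` only changes the axis order, a time-first batch can be\n",
    "converted to the batch-first layout with a single axis swap. The sketch\n",
    "below uses synthetic arrays shaped like `batch_sample` output. It also\n",
    "checks the point made above: with an all-`True` mask, a masked mean equals\n",
    "the plain mean, so requesting the mask unconditionally is harmless.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a93c5e64",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "toy_T_max, toy_B, toy_F = 5, 2, 3\n",
    "toy_X = jnp.arange(toy_T_max * toy_B * toy_F, dtype=jnp.float32).reshape(toy_T_max, toy_B, toy_F)\n",
    "toy_mask = jnp.ones((toy_T_max, toy_B), dtype=bool)  # fixed-length task: all True\n",
    "\n",
    "# Time-first (T_max, B, ...) -> batch-first (B, T_max, ...).\n",
    "print(jnp.swapaxes(toy_X, 0, 1).shape, toy_mask.T.shape)  # (2, 5, 3) (2, 5)\n",
    "\n",
    "# With an all-True mask the masked mean reduces to the ordinary mean.\n",
    "toy_per_step = toy_X.mean(axis=-1)  # (T_max, B)\n",
    "toy_masked_mean = jnp.sum(toy_per_step * toy_mask) / jnp.sum(toy_mask)\n",
    "print(bool(jnp.allclose(toy_masked_mean, toy_per_step.mean())))  # True"
   ]
  },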
  {
   "cell_type": "markdown",
   "id": "9f69c8d9",
   "metadata": {},
   "source": [
    "### Trailing positions are zero\n",
    "\n",
    "Phases write only into their valid slice. Anything past a trial's\n",
    "actual length stays at the buffer default (`0` in `X`, `0` in `Y`,\n",
    "`False` in `mask`).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a1e8be10",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T12:48:04.565071Z",
     "iopub.status.busy": "2026-05-21T12:48:04.564893Z",
     "iopub.status.idle": "2026-05-21T12:48:04.741598Z",
     "shell.execute_reply": "2026-05-21T12:48:04.740784Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trial 0: 475 valid steps, 1165 padded steps\n",
      "tail of X  : [[0. 0. 0.]\n",
      " [0. 0. 0.]\n",
      " [0. 0. 0.]\n",
      " [0. 0. 0.]]\n",
      "tail of Y  : [0 0 0 0]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tail mask  : [False False False False]\n"
     ]
    }
   ],
   "source": [
    "# Pull a single trial back out of the batch and look at its tail.\n",
    "trial_idx = 0\n",
    "valid = int(mask[:, trial_idx].sum())\n",
    "print(f'trial {trial_idx}: {valid} valid steps, {X.shape[0] - valid} padded steps')\n",
    "print('tail of X  :', X[valid:valid + 4, trial_idx])\n",
    "print('tail of Y  :', Y[valid:valid + 4, trial_idx])\n",
    "print('tail mask  :', mask[valid:valid + 4, trial_idx])"
   ]
  },
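  {
   "cell_type": "markdown",
   "id": "2e48d7a1",
   "metadata": {},
   "source": [
    "The same property can be checked for a whole batch at once. A minimal\n",
    "sketch, assuming the time-first `(X, Y, mask)` layout returned by\n",
    "`batch_sample(B, return_mask=True)`; the helper `padding_is_zero` is\n",
    "hypothetical. With the arrays from the cell above still in scope,\n",
    "`padding_is_zero(X, Y, mask)` applies the check to the real batch.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6f01b3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def padding_is_zero(X, Y, mask):\n",
    "    # X: (T, B, F); Y: (T, B) or (T, B, F_out); mask: (T, B) bool.\n",
    "    pad = ~mask\n",
    "    x_ok = jnp.all(jnp.where(pad[..., None], X, 0) == 0)\n",
    "    y_pad = pad if Y.ndim == 2 else pad[..., None]\n",
    "    y_ok = jnp.all(jnp.where(y_pad, Y, 0) == 0)\n",
    "    return bool(x_ok & y_ok)\n",
    "\n",
    "\n",
    "# Synthetic batch: a live prefix of ones followed by zero padding.\n",
    "toy_mask = jnp.arange(6)[:, None] < jnp.array([3, 6])[None, :]\n",
    "toy_X = jnp.tile(toy_mask[..., None], (1, 1, 2)).astype(jnp.float32)\n",
    "toy_Y = toy_mask.astype(jnp.int32)\n",
    "print(padding_is_zero(toy_X, toy_Y, toy_mask))  # True"
   ]
  },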
  {
   "cell_type": "markdown",
   "id": "078481b7",
   "metadata": {},
   "source": [
    "## 4. Conditional control flow under packed mode\n",
    "\n",
    "`If`, `Switch`, and `While` all participate in variable-length trees.\n",
    "Their `max_steps` is the static upper bound the framework allocates for;\n",
    "their `step_count` reports what actually ran on this trial.\n",
    "\n",
    "| Phase     | Semantics under packed mode                                                              | Constraints                                                                                             |\n",
    "|-----------|-------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------|\n",
    "| `If`      | Both branches are traced; `jax.lax.cond` selects one per trial. Buffer is sized to `max(then_max, else_max)`. | The predicate must read trial state (`ctx[...]`); it may be a JAX tracer.                              |\n",
    "| `Switch`  | Python-level dispatch on the selector's value.                                            | Selector must return a **hashable Python key**, not a tracer (set it in `trial_init` as a Python value). |\n",
    "| `While`   | Python-level loop bounded by `max_iterations`.                                            | Condition must return a Python `bool`. Buffer = `body.max_steps * max_iterations`.                      |\n",
    "\n",
    "Branches that don't run leave their buffer region at zero and do **not**\n",
    "advance the write cursor (`t_cursor`), so the mask stays `False` for the\n",
    "unused slot.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "91eb2625",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T12:48:04.743292Z",
     "iopub.status.busy": "2026-05-21T12:48:04.743115Z",
     "iopub.status.idle": "2026-05-21T12:48:05.372050Z",
     "shell.execute_reply": "2026-05-21T12:48:05.371198Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "is_variable_length  : True\n",
      "max_trial_duration  : 80 (20 + max(40, 40) + 20)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mask sums per trial : [80 80 80 80 80 80 80 80]\n"
     ]
    }
   ],
   "source": [
    "# If: cued go/no-go where 'go' is sampled per trial.\n",
    "go_or_nogo = concat([\n",
    "    Fixation(20 * u.ms, inputs={'fixation': 1.0}),\n",
    "    If(\n",
    "        condition=lambda ctx: ctx['go'],\n",
    "        then=Stimulus(40 * u.ms,\n",
    "                      inputs={'stim': lambda c, f: jnp.ones(f.num)}),\n",
    "        else_=Fixation(40 * u.ms,\n",
    "                       inputs={'fixation': 0.5}),\n",
    "    ),\n",
    "    Response(20 * u.ms, outputs={'label': lambda c, f: c['gt']}),\n",
    "])\n",
    "\n",
    "def go_init(ctx):\n",
    "    ctx['go'] = ctx.rng.choice(2).astype(jnp.bool_)\n",
    "    ctx['gt'] = ctx.rng.choice(2).astype(jnp.int32) + 1\n",
    "\n",
    "go_task = Task(\n",
    "    phases=go_or_nogo,\n",
    "    input_features=fix + stim,\n",
    "    output_features=fix + choice,\n",
    "    trial_init=go_init,\n",
    "    seed=0,\n",
    ")\n",
    "\n",
    "print('is_variable_length  :', go_task.is_variable_length)\n",
    "print('max_trial_duration  :', go_task.max_trial_duration(),\n",
    "      '(20 + max(40, 40) + 20)')\n",
    "\n",
    "X, Y, M = go_task.batch_sample(8, return_mask=True)\n",
    "print('mask sums per trial :', M.sum(axis=0))"
   ]
  },
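  {
   "cell_type": "markdown",
   "id": "4b9c27e5",
   "metadata": {},
   "source": [
    "Because both branches above last 40 ms, every trial has the same length.\n",
    "The plain-JAX sketch below (not the `cogtask` implementation) shows how\n",
    "lengths diverge when the branches differ. It assumes, as the packed layout\n",
    "implies, that the response starts where the chosen branch ended; the\n",
    "60-step and 30-step branch lengths are made up. `jax.lax.cond` picks a\n",
    "branch length per trial under `vmap`, while the buffer stays sized to\n",
    "`20 + max(then, else) + 20`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91f3a6d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "PRE, THEN, ELSE, POST = 20, 60, 30, 20  # phase lengths in steps (made up)\n",
    "T_MAX = PRE + max(THEN, ELSE) + POST    # static buffer length: 100\n",
    "\n",
    "\n",
    "def toy_if_mask(go):\n",
    "    # lax.cond traces both branches but yields only the selected length.\n",
    "    branch = jax.lax.cond(go, lambda: jnp.int32(THEN), lambda: jnp.int32(ELSE))\n",
    "    length = PRE + branch + POST          # packed: no gap after the branch\n",
    "    return jnp.arange(T_MAX) < length     # (T_MAX,) bool\n",
    "\n",
    "\n",
    "toy_go = jnp.array([True, False, True, False])\n",
    "toy_mask = jax.vmap(toy_if_mask)(toy_go).T  # time-first (T_MAX, B)\n",
    "print(T_MAX, toy_mask.sum(axis=0).tolist())  # 100 [100, 70, 100, 70]"
   ]
  },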
  {
   "cell_type": "markdown",
   "id": "aa4f608d",
   "metadata": {},
   "source": [
    "## 5. Migrated built-in tasks\n",
    "\n",
    "Three pre-built tasks in `braintools.cogtask` now use `VariableDuration`\n",
    "internally and work with `batch_sample` out of the box:\n",
    "\n",
    "- `HierarchicalReasoning` — variable delay between two flash cues.\n",
    "- `IntervalDiscrimination` — two stimulus intervals sampled\n",
    "  independently per trial.\n",
    "- `ReadySetGo` — measurement interval sampled per trial.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b172c03a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T12:48:05.373778Z",
     "iopub.status.busy": "2026-05-21T12:48:05.373610Z",
     "iopub.status.idle": "2026-05-21T12:48:04.782886Z",
     "shell.execute_reply": "2026-05-21T12:48:04.781760Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "HierarchicalReasoning     T_max= 2000  variable=True  mask sums=[1649, 1489, 1537, 1862]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "IntervalDiscrimination    T_max= 3100  variable=True  mask sums=[2604, 2634, 2646, 2778]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ReadySetGo                T_max= 2800  variable=True  mask sums=[2566, 2459, 2491, 2708]\n"
     ]
    }
   ],
   "source": [
    "from braintools.cogtask import (\n",
    "    HierarchicalReasoning, IntervalDiscrimination, ReadySetGo,\n",
    ")\n",
    "\n",
    "for cls in [HierarchicalReasoning, IntervalDiscrimination, ReadySetGo]:\n",
    "    task = cls(seed=42)\n",
    "    X, Y, M = task.batch_sample(4, return_mask=True)\n",
    "    print(f'{cls.__name__:25s} T_max={task.max_trial_duration():5d}'\n",
    "          f'  variable={task.is_variable_length}'\n",
    "          f'  mask sums={M.sum(axis=0).tolist()}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1194521",
   "metadata": {},
   "source": [
    "## 6. Using the mask in losses and metrics\n",
    "\n",
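    "The standard recipe for a masked cross-entropy loss: compute the loss\n",
    "per timestep, multiply it by the mask cast to float, then divide by the\n",
    "mask's sum rather than by the buffer size. The multiplication removes\n",
    "padded steps from the numerator; dividing by the mask's sum keeps them\n",
    "out of the denominator, so heavily padded batches do not shrink the loss.\n"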
   ]
  },
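  {
   "cell_type": "markdown",
   "id": "e7d14c02",
   "metadata": {},
   "source": [
    "Written out, with $\\ell_{t,b}$ the per-step negative log-likelihood,\n",
    "$m_{t,b} \\in \\{0, 1\\}$ the mask cast to float, $z_{t,b}$ the logits and\n",
    "$y_{t,b}$ the integer label, the recipe computes\n",
    "\n",
    "$$\n",
    "\\mathcal{L} = \\frac{\\sum_{t,b} m_{t,b}\\,\\ell_{t,b}}{\\max\\left(\\sum_{t,b} m_{t,b},\\ 1\\right)},\n",
    "\\qquad \\ell_{t,b} = -\\log \\operatorname{softmax}(z_{t,b})_{y_{t,b}}\n",
    "$$\n",
    "\n",
    "The $\\max(\\cdot, 1)$ guard only matters for a batch with no live steps.\n"
   ]
  },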
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1c525e11",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T12:48:04.785001Z",
     "iopub.status.busy": "2026-05-21T12:48:04.784742Z",
     "iopub.status.idle": "2026-05-21T12:48:05.223460Z",
     "shell.execute_reply": "2026-05-21T12:48:05.222017Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "masked loss = 0.5730118751525879\n"
     ]
    }
   ],
   "source": [
    "def masked_cross_entropy(logits, labels, mask):\n",
    "    # logits: (T, B, C)  labels: (T, B) int  mask: (T, B) bool\n",
    "    # Use log-softmax for numerical stability.\n",
    "    logp = jax.nn.log_softmax(logits, axis=-1)\n",
    "    onehot = jax.nn.one_hot(labels, logp.shape[-1])\n",
    "    nll = -jnp.sum(onehot * logp, axis=-1)           # (T, B)\n",
    "    mask_f = mask.astype(nll.dtype)\n",
    "    return jnp.sum(nll * mask_f) / jnp.maximum(jnp.sum(mask_f), 1.0)\n",
    "\n",
    "\n",
    "# Toy demo: pretend a network produced uniform logits.\n",
    "# `task` is still the last built-in created in the loop above (ReadySetGo).\n",
    "X, Y, M = task.batch_sample(4, return_mask=True)\n",
    "logits = jnp.zeros(X.shape[:2] + (task.num_outputs,))\n",
    "print('masked loss =', float(masked_cross_entropy(logits, Y, M)))"
   ]
  },
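  {
   "cell_type": "markdown",
   "id": "0f5a83b7",
   "metadata": {},
   "source": [
    "A quick sanity check of the recipe: logits on padded steps must not affect\n",
    "the loss. The cell below reuses `masked_cross_entropy` from the previous\n",
    "cell on synthetic data, overwrites only the padded positions with large\n",
    "random values, and compares the two losses.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8e2c619",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "toy_k1, toy_k2, toy_k3 = jax.random.split(jax.random.PRNGKey(0), 3)\n",
    "toy_T, toy_B, toy_C = 10, 4, 3\n",
    "toy_mask = jnp.arange(toy_T)[:, None] < jnp.array([3, 10, 6, 8])[None, :]  # (T, B)\n",
    "toy_labels = jax.random.randint(toy_k1, (toy_T, toy_B), 0, toy_C)\n",
    "toy_logits = jax.random.normal(toy_k2, (toy_T, toy_B, toy_C))\n",
    "\n",
    "# Replace only the padded steps with large junk logits.\n",
    "toy_junk = 100.0 * jax.random.normal(toy_k3, (toy_T, toy_B, toy_C))\n",
    "toy_noisy = jnp.where(toy_mask[..., None], toy_logits, toy_junk)\n",
    "\n",
    "loss_clean = masked_cross_entropy(toy_logits, toy_labels, toy_mask)\n",
    "loss_noisy = masked_cross_entropy(toy_noisy, toy_labels, toy_mask)\n",
    "print(bool(jnp.allclose(loss_clean, loss_noisy)))  # True"
   ]
  },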
  {
   "cell_type": "markdown",
   "id": "bd32396c",
   "metadata": {},
   "source": [
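    "The same pattern carries over to vector-output tasks: compute an\n",
    "elementwise loss (MSE, a `cos_sin`-style population loss, etc.) and gate\n",
    "it with `mask[..., None]` so the mask broadcasts over the feature axis.\n",
    "Accuracy metrics work the same way: sum `(preds == Y) * mask` and divide\n",
    "by `mask.sum()`.\n"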
   ]
  },
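  {
   "cell_type": "markdown",
   "id": "6c37f0a4",
   "metadata": {},
   "source": [
    "Both variants can be sketched in the same style on synthetic arrays. The\n",
    "helper names `masked_mse` and `masked_accuracy` are hypothetical; the only\n",
    "assumption about shapes is the time-first layout used throughout this\n",
    "tutorial, with `(T, B, F_out)` targets for vector outputs and `(T, B)`\n",
    "integer labels for categorical ones. Averaging the MSE over live\n",
    "(step, feature) entries is one reasonable normalization; averaging over\n",
    "live steps only is another.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d9be712",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def masked_mse(pred, target, mask):\n",
    "    # pred, target: (T, B, F_out); mask: (T, B) bool, broadcast over features.\n",
    "    m = mask[..., None].astype(pred.dtype)\n",
    "    sq = (pred - target) ** 2 * m\n",
    "    # Average over live (step, feature) entries only.\n",
    "    return jnp.sum(sq) / jnp.maximum(jnp.sum(m) * pred.shape[-1], 1.0)\n",
    "\n",
    "\n",
    "def masked_accuracy(logits, labels, mask):\n",
    "    # logits: (T, B, C); labels: (T, B) int; mask: (T, B) bool.\n",
    "    correct = (jnp.argmax(logits, axis=-1) == labels) & mask\n",
    "    return jnp.sum(correct) / jnp.maximum(jnp.sum(mask), 1)\n",
    "\n",
    "\n",
    "# Two trials of length 2 and 4 in a 4-step buffer.\n",
    "toy_mask = jnp.arange(4)[:, None] < jnp.array([2, 4])[None, :]\n",
    "toy_target = jnp.ones((4, 2, 3))\n",
    "toy_pred = jnp.where(toy_mask[..., None], 1.0, 5.0) * jnp.ones((4, 2, 3))\n",
    "print(float(masked_mse(toy_pred, toy_target, toy_mask)))  # 0.0: padding ignored\n",
    "\n",
    "toy_labels = jnp.zeros((4, 2), dtype=jnp.int32)\n",
    "toy_logits = jnp.zeros((4, 2, 3)).at[..., 0].set(1.0)  # always predicts class 0\n",
    "print(float(masked_accuracy(toy_logits, toy_labels, toy_mask)))  # 1.0"
   ]
  },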
  {
   "cell_type": "markdown",
   "id": "143a1c78",
   "metadata": {},
   "source": [
    "## 7. Duration samplers\n",
    "\n",
    "`TruncExp` and `UniformDuration` are the canonical helpers for drawing a\n",
    "per-trial duration in `trial_init`. They are JIT/`vmap`-safe (they\n",
    "consume `ctx.rng`) and they advertise the static bounds the framework\n",
    "needs to size buffers:\n",
    "\n",
    "- `sampler.is_variable` — class attribute set to `True`.\n",
    "- `sampler.min_value()`, `sampler.max_value()` — return the bounds as\n",
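    "  Quantities. `VariableDuration.min_duration` / `max_duration` should\n",
    "  match these bounds or enclose them (`min_duration <= min_value()`,\n",
    "  `max_duration >= max_value()`); otherwise long draws are clipped to\n",
    "  `max_duration`.\n"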
   ]
  },
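  {
   "cell_type": "markdown",
   "id": "8e01d5c3",
   "metadata": {},
   "source": [
    "For intuition about what a bounded duration sampler does, here is one way\n",
    "to draw a truncated-exponential duration in plain JAX: inverse-CDF sampling\n",
    "of an exponential with scale `mean`, restricted to `[min_val, max_val]`.\n",
    "This is an assumption for illustration only. The tutorial does not say how\n",
    "`TruncExp` is implemented or whether its `mean` is the scale before or\n",
    "after truncation, and `trunc_exp_ms` is a hypothetical name. Either way,\n",
    "every draw lands inside the static bounds, which is what lets\n",
    "`max_value()` size the buffer.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4a62b98",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def trunc_exp_ms(key, mean, lo, hi):\n",
    "    # Hypothetical helper: Exponential(scale=mean) restricted to [lo, hi],\n",
    "    # drawn by inverting the CDF F(x) = 1 - exp(-x / mean).\n",
    "    u = jax.random.uniform(key)\n",
    "    f_lo = 1.0 - jnp.exp(-lo / mean)  # CDF at the lower bound\n",
    "    f_hi = 1.0 - jnp.exp(-hi / mean)  # CDF at the upper bound\n",
    "    p = f_lo + u * (f_hi - f_lo)      # uniform on [F(lo), F(hi))\n",
    "    x = -mean * jnp.log1p(-p)         # inverse CDF\n",
    "    return jnp.clip(x, lo, hi)        # guard against float round-off\n",
    "\n",
    "\n",
    "toy_keys = jax.random.split(jax.random.PRNGKey(0), 10_000)\n",
    "toy_draws = jax.vmap(lambda k: trunc_exp_ms(k, 600.0, 300.0, 1500.0))(toy_keys)\n",
    "print(float(toy_draws.min()), float(toy_draws.max()))  # both inside [300, 1500]"
   ]
  },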
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b7d56516",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T12:48:05.226982Z",
     "iopub.status.busy": "2026-05-21T12:48:05.226616Z",
     "iopub.status.idle": "2026-05-21T12:48:05.235804Z",
     "shell.execute_reply": "2026-05-21T12:48:05.234801Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TruncExp        bounds: 300 ms - 1500 ms    is_variable = True\n",
      "UniformDuration bounds: 200 ms - 800 ms    is_variable = True\n"
     ]
    }
   ],
   "source": [
    "te = TruncExp(mean=600 * u.ms, min_val=300 * u.ms, max_val=1500 * u.ms)\n",
    "ud = UniformDuration(200 * u.ms, 800 * u.ms)\n",
    "\n",
    "print('TruncExp        bounds:', te.min_value(), '-', te.max_value(),\n",
    "      '   is_variable =', te.is_variable)\n",
    "print('UniformDuration bounds:', ud.min_value(), '-', ud.max_value(),\n",
    "      '   is_variable =', ud.is_variable)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "a0c0ac1a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T12:48:05.238253Z",
     "iopub.status.busy": "2026-05-21T12:48:05.237936Z",
     "iopub.status.idle": "2026-05-21T12:48:05.839416Z",
     "shell.execute_reply": "2026-05-21T12:48:05.838172Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "T_max = 1600   per-trial lengths = [456, 449, 1020, 1194, 849, 1588, 635, 407]\n"
     ]
    }
   ],
   "source": [
    "# Wire a sampler into trial_init -> VariableDuration:\n",
    "delay_dist = TruncExp(mean=600 * u.ms, min_val=300 * u.ms, max_val=1500 * u.ms)\n",
    "\n",
    "def init_with_sampler(ctx):\n",
    "    # Store the sampled Quantity (or its mantissa) in the ctx key the\n",
    "    # phase reads. Either form works.\n",
    "    ctx['delay_duration'] = delay_dist(ctx).to(u.ms).mantissa\n",
    "    ctx['gt']             = ctx.rng.choice(2).astype(jnp.int32) + 1\n",
    "\n",
    "sampler_task = Task(\n",
    "    phases=concat([\n",
    "        Fixation(50 * u.ms, inputs={'fixation': 1.0}),\n",
    "        VariableDuration(\n",
    "            min_duration=delay_dist.min_value(),\n",
    "            max_duration=delay_dist.max_value(),\n",
    "            ctx_key='delay_duration',\n",
    "            inputs={'fixation': 1.0},\n",
    "        ),\n",
    "        Response(50 * u.ms, outputs={'label': lambda c, f: c['gt']}),\n",
    "    ]),\n",
    "    input_features=fix + stim,\n",
    "    output_features=fix + choice,\n",
    "    trial_init=init_with_sampler,\n",
    "    seed=0,\n",
    ")\n",
    "\n",
    "X, Y, M = sampler_task.batch_sample(8, return_mask=True)\n",
    "print('T_max =', sampler_task.max_trial_duration(),\n",
    "      '  per-trial lengths =', M.sum(axis=0).tolist())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5fb492a8",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "- Mark variable-length epochs with `VariableDuration(min_duration,\n",
    "  max_duration, ctx_key=...)` and write the per-trial length into\n",
    "  `ctx[ctx_key]` from `trial_init`.\n",
    "- The framework auto-detects variable-length trees and switches to a\n",
    "  packed-buffer path. `task.is_variable_length` and\n",
    "  `task.max_trial_duration()` describe the result.\n",
    "- Use `batch_sample(B, return_mask=True)` to get aligned `(X, Y, mask)`\n",
    "  buffers under `jit`+`vmap`. The mask doubles as a per-step weight for\n",
    "  losses and metrics.\n",
    "- `If` / `Switch` / `While` participate in the same buffers; conditional\n",
    "  branches that didn't run leave their slot at zero and `mask=False`.\n",
    "- `TruncExp` and `UniformDuration` are sampling helpers whose\n",
    "  `min_value()` / `max_value()` line up with `VariableDuration`'s\n",
    "  `min_duration` / `max_duration`.\n",
    "\n",
    "For the full API surface — every phase type, encoder, label helper,\n",
    "duration sampler, and pre-built task — see the\n",
    "`braintools.cogtask` API reference.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
