{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7c3b7204",
   "metadata": {},
   "source": [
    "# Tutorial 1: Quickstart with `braintools.cogtask`\n",
    "\n",
    "`braintools.cogtask` is a modular, composable framework for constructing\n",
    "cognitive tasks for neural-network training and computational neuroscience\n",
    "simulations. This tutorial walks through a minimal end-to-end workflow:\n",
    "\n",
    "1. Importing the module and setting the time step.\n",
    "2. Sampling a single trial from a pre-built task and reading its outputs (`X`, `Y`, `info`).\n",
    "3. Sampling batches and inspecting the resulting tensor shapes.\n",
    "4. Re-sampling the same task at a different time step.\n",
    "\n",
    "---\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd64d002",
   "metadata": {},
   "source": [
    "## 1. Setup\n",
    "\n",
    "`cogtask` converts every duration into a number of timesteps using the *currently active*\n",
    "time step (`brainstate.environ.get_dt()`), so set a `dt` before sampling any trial.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "93730534",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T09:12:01.851161Z",
     "iopub.status.busy": "2026-05-21T09:12:01.851011Z",
     "iopub.status.idle": "2026-05-21T09:12:02.267520Z",
     "shell.execute_reply": "2026-05-21T09:12:02.266430Z"
    }
   },
   "outputs": [],
   "source": [
    "import brainunit as u\n",
    "import brainstate\n",
    "import jax\n",
    "\n",
    "brainstate.environ.set(dt=1.0 * u.ms)\n",
    "\n",
    "from braintools import cogtask\n"
   ]
  },
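  {
   "cell_type": "markdown",
   "id": "a3c10f01",
   "metadata": {},
   "source": [
    "The phrase *currently active time step* describes a scoped-configuration pattern. The toy,\n",
    "pure-Python analogue below is only a sketch of that behaviour, not how `brainstate.environ`\n",
    "is implemented: `get_dt`, `dt_context`, and `num_steps` are hypothetical names, `dt` is a\n",
    "plain float in milliseconds, and the rounding rule in `num_steps` is an assumption. It shows\n",
    "why the same duration can map to different step counts depending on which `dt` is active\n",
    "at the moment the duration is resolved.\n",
    "\n",
    "```python\n",
    "import contextlib\n",
    "\n",
    "# Hypothetical stack of active dt values (ms); the last entry is the active one.\n",
    "_DT_STACK = [1.0]\n",
    "\n",
    "def get_dt():\n",
    "    # The innermost (most recently entered) context wins.\n",
    "    return _DT_STACK[-1]\n",
    "\n",
    "@contextlib.contextmanager\n",
    "def dt_context(dt):\n",
    "    # Push a new dt for the duration of the with-block, then restore the previous one.\n",
    "    _DT_STACK.append(dt)\n",
    "    try:\n",
    "        yield\n",
    "    finally:\n",
    "        _DT_STACK.pop()\n",
    "\n",
    "def num_steps(duration_ms):\n",
    "    # Resolved at call time against whatever dt is active right now.\n",
    "    return int(round(duration_ms / get_dt()))\n",
    "\n",
    "print(num_steps(1500.0))      # 1500 at the default dt of 1.0 ms\n",
    "with dt_context(0.5):\n",
    "    print(num_steps(1500.0))  # 3000 inside the context\n",
    "print(num_steps(1500.0))      # 1500 again once the context exits\n",
    "```\n"
   ]
  },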
  {
   "cell_type": "markdown",
   "id": "30e3bc29",
   "metadata": {},
   "source": [
    "## 2. A single trial from a pre-built task\n",
    "\n",
    "The pre-built tasks live under `braintools.cogtask` and are all subclasses of\n",
    "`Task`. Each task constructor exposes the standard cognitive-paradigm\n",
    "parameters (e.g. stimulus duration, number of choices, coherence levels) and\n",
    "accepts a `seed=` for reproducibility.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9131a21f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T09:12:02.270996Z",
     "iopub.status.busy": "2026-05-21T09:12:02.270553Z",
     "iopub.status.idle": "2026-05-21T09:12:02.373736Z",
     "shell.execute_reply": "2026-05-21T09:12:02.372898Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Task(name=PerceptualDecisionMaking, inputs=9, outputs=3, output_mode=categorical)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "task = cogtask.PerceptualDecisionMaking(\n",
    "    t_stimulus=1500 * u.ms,\n",
    "    num_choices=2,\n",
    "    seed=0,\n",
    ")\n",
    "task\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bca947d5",
   "metadata": {},
   "source": [
    "`sample_trial(index)` returns `(X, Y, info)` for a single trial:\n",
    "\n",
    "- `X`: input tensor of shape `(T, num_inputs)`\n",
    "- `Y`: target tensor of shape `(T,)` in *categorical* mode, or\n",
    "  `(T, num_outputs)` in *vector* mode\n",
    "- `info`: a dict with `phase_history` (a list of `(name, start, end)` phase boundaries in\n",
    "  timesteps), `trial_state`, `dt`, and `index`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f966f8a5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T09:12:02.375923Z",
     "iopub.status.busy": "2026-05-21T09:12:02.375701Z",
     "iopub.status.idle": "2026-05-21T09:12:04.708979Z",
     "shell.execute_reply": "2026-05-21T09:12:04.708170Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X.shape = (1700, 9)\n",
      "Y.shape = (1700,)\n",
      "num_inputs = 9\n",
      "num_outputs = 3\n"
     ]
    }
   ],
   "source": [
    "X, Y, info = task.sample_trial(0)\n",
    "print('X.shape =', X.shape)\n",
    "print('Y.shape =', Y.shape)\n",
    "print('num_inputs =', task.num_inputs)\n",
    "print('num_outputs =', task.num_outputs)\n"
   ]
  },
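  {
   "cell_type": "markdown",
   "id": "a3c10f02",
   "metadata": {},
   "source": [
    "`Y` above is one-dimensional because this task uses *categorical* mode. When each time step\n",
    "has exactly one correct class, the categorical and vector layouts carry the same information.\n",
    "The sketch below converts between them with `jax.nn.one_hot` and `argmax`. It assumes the\n",
    "vector targets are one-hot encodings of the categorical labels, which a given task's vector\n",
    "mode need not use (it may produce soft or multi-hot targets instead); the label values are\n",
    "synthetic.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "num_outputs = 3\n",
    "# Synthetic categorical target: one integer class per time step, shape (T,) with T = 6.\n",
    "Y_cat = jnp.array([0, 0, 0, 2, 2, 2], dtype=jnp.int32)\n",
    "\n",
    "# One-hot encode into the (T, num_outputs) layout used by vector mode.\n",
    "Y_vec = jax.nn.one_hot(Y_cat, num_outputs)\n",
    "print(Y_vec.shape)  # (6, 3)\n",
    "\n",
    "# argmax over the class axis recovers the categorical labels.\n",
    "Y_back = jnp.argmax(Y_vec, axis=-1)\n",
    "print(bool(jnp.array_equal(Y_back, Y_cat)))  # True\n",
    "```\n"
   ]
  },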
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7e0f44cf",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T09:12:04.711030Z",
     "iopub.status.busy": "2026-05-21T09:12:04.710774Z",
     "iopub.status.idle": "2026-05-21T09:12:04.715070Z",
     "shell.execute_reply": "2026-05-21T09:12:04.714216Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fixation      [   0 :  100)\n",
      "stimulus      [ 100 : 1600)\n",
      "response      [1600 : 1700)\n",
      "Sequence      [1600 : 1700)\n"
     ]
    }
   ],
   "source": [
    "for name, start, end in info['phase_history']:\n",
    "    print(f'{name:<12s}  [{start:4d} : {end:4d})')\n"
   ]
  },
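  {
   "cell_type": "markdown",
   "id": "a3c10f03",
   "metadata": {},
   "source": [
    "Because each entry is a half-open `[start, end)` interval in timesteps, `phase_history`\n",
    "converts directly into boolean masks, for example to restrict a loss to the response period.\n",
    "A minimal NumPy sketch: the helper `phase_mask` is hypothetical, and only the `fixation`,\n",
    "`stimulus`, and `response` entries are copied by hand from the output above.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Phase boundaries copied from the output above: (name, start, end), end exclusive.\n",
    "phase_history = [\n",
    "    ('fixation', 0, 100),\n",
    "    ('stimulus', 100, 1600),\n",
    "    ('response', 1600, 1700),\n",
    "]\n",
    "T = 1700\n",
    "\n",
    "def phase_mask(history, name, T):\n",
    "    # Boolean (T,) mask that is True on every step inside the named phase.\n",
    "    mask = np.zeros(T, dtype=bool)\n",
    "    for phase, start, end in history:\n",
    "        if phase == name:\n",
    "            mask[start:end] = True\n",
    "    return mask\n",
    "\n",
    "response = phase_mask(phase_history, 'response', T)\n",
    "print(response.sum())  # 100\n",
    "```\n",
    "\n",
    "When all trials in a time-first batch of shape `(T, B)` share the same phase timing,\n",
    "`response[:, None]` broadcasts the same mask across every trial.\n"
   ]
  },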
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8b9d0d47",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T09:12:04.717181Z",
     "iopub.status.busy": "2026-05-21T09:12:04.716954Z",
     "iopub.status.idle": "2026-05-21T09:12:04.723637Z",
     "shell.execute_reply": "2026-05-21T09:12:04.722832Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'trial_index': 0,\n",
       " 'ground_truth': Array(0, dtype=int32),\n",
       " 'coherence': Array(25.6, dtype=float32),\n",
       " 'stimulus_direction': Array(0., dtype=float32),\n",
       " 'output_mode': 'categorical'}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# trial_state captures whatever trial_init wrote into the context.\n",
    "info['trial_state']\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "504c4b30",
   "metadata": {},
   "source": [
    "## 3. Batched sampling for training loops\n",
    "\n",
    "`batch_sample(B)` is the JIT-compiled, `vmap`-batched entry point intended for training\n",
    "loops. By default it returns tensors with the time axis first (`time_first=True`),\n",
    "matching common RNN conventions:\n",
    "\n",
    "| call                                                | `X` shape                       | `Y` shape                              |\n",
    "|-----------------------------------------------------|---------------------------------|----------------------------------------|\n",
    "| `task.sample_trial(i)`                              | `(T, num_inputs)`               | `(T,)` or `(T, num_outputs)`           |\n",
    "| `task.batch_sample(B)`                              | `(T, B, num_inputs)`            | `(T, B)` or `(T, B, num_outputs)`      |\n",
    "| `task.batch_sample(B, time_first=False)`            | `(B, T, num_inputs)`            | `(B, T)` or `(B, T, num_outputs)`      |\n",
    "| `task.batch_sample(B, return_meta=True)`            | adds a third `meta` value       |                                        |\n"
   ]
  },
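  {
   "cell_type": "markdown",
   "id": "a3c10f04",
   "metadata": {},
   "source": [
    "The two layouts differ only in the order of the first two axes. The sketch below assumes that\n",
    "`time_first=False` returns the same data as the default with axes 0 and 1 exchanged, as the\n",
    "table suggests; it does not call `batch_sample` and uses a synthetic stand-in array.\n",
    "\n",
    "```python\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Synthetic stand-in for a time-first batch: T=5 steps, B=3 trials, 9 inputs.\n",
    "T, B, num_inputs = 5, 3, 9\n",
    "X_time_first = jnp.arange(T * B * num_inputs, dtype=jnp.float32).reshape(T, B, num_inputs)\n",
    "\n",
    "# Swapping the first two axes gives the batch-first layout (B, T, num_inputs).\n",
    "X_batch_first = jnp.swapaxes(X_time_first, 0, 1)\n",
    "print(X_batch_first.shape)  # (3, 5, 9)\n",
    "\n",
    "# Element [t, b] in one layout is element [b, t] in the other.\n",
    "print(bool(jnp.array_equal(X_time_first[2, 1], X_batch_first[1, 2])))  # True\n",
    "```\n",
    "\n",
    "Time-first is convenient for `jax.lax.scan`, which iterates over the leading axis, so a\n",
    "recurrent model can step through `X` one `(B, num_inputs)` time slice at a time.\n"
   ]
  },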
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0cd37409",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T09:12:04.726459Z",
     "iopub.status.busy": "2026-05-21T09:12:04.726159Z",
     "iopub.status.idle": "2026-05-21T09:12:06.472159Z",
     "shell.execute_reply": "2026-05-21T09:12:06.470826Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X.shape = (1700, 32, 9)\n",
      "Y.shape = (1700, 32)\n"
     ]
    }
   ],
   "source": [
    "X, Y = task.batch_sample(32)\n",
    "print('X.shape =', X.shape)\n",
    "print('Y.shape =', Y.shape)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43d9f0e7",
   "metadata": {},
   "source": [
    "### Reproducible batches\n",
    "\n",
    "When a `Task` is constructed with `seed=N`, trial `start_index + i` of a batch uses\n",
    "`jax.random.fold_in(PRNGKey(N), start_index + i)` as its per-trial PRNG key.\n",
    "Two calls with the same `start_index` therefore produce *bitwise identical* batches,\n",
    "while calls whose index ranges `[start_index, start_index + B)` do not overlap share\n",
    "no trials.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7bf0a79d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T09:12:06.475699Z",
     "iopub.status.busy": "2026-05-21T09:12:06.475345Z",
     "iopub.status.idle": "2026-05-21T09:12:07.428476Z",
     "shell.execute_reply": "2026-05-21T09:12:07.427582Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "same start_index identical?   True\n",
      "next start_index differs?     True\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "X1, Y1 = task.batch_sample(8, start_index=0)\n",
    "X2, Y2 = task.batch_sample(8, start_index=0)\n",
    "X3, Y3 = task.batch_sample(8, start_index=8)\n",
    "\n",
    "print('same start_index identical?  ', np.array_equal(X1, X2))\n",
    "print('next start_index differs?    ', not np.array_equal(X1, X3))\n"
   ]
  },
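  {
   "cell_type": "markdown",
   "id": "a3c10f05",
   "metadata": {},
   "source": [
    "The seeding scheme can be sketched directly with `jax.random`. This is a stand-in, not\n",
    "`cogtask` internals: `per_trial_keys` and `toy_batch` are hypothetical helpers, and a single\n",
    "uniform draw replaces a full trial. It shows the property that matters in practice: trial `k`\n",
    "always receives the same key, so batches that share trial indices also share those trials.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def per_trial_keys(seed, start_index, batch_size):\n",
    "    # Key for trial i is fold_in(PRNGKey(seed), start_index + i), as described above.\n",
    "    base = jax.random.PRNGKey(seed)\n",
    "    indices = jnp.arange(start_index, start_index + batch_size)\n",
    "    return jax.vmap(lambda i: jax.random.fold_in(base, i))(indices)\n",
    "\n",
    "def toy_batch(seed, start_index, batch_size):\n",
    "    # Stand-in for a trial sampler: one uniform draw per trial.\n",
    "    keys = per_trial_keys(seed, start_index, batch_size)\n",
    "    return jax.vmap(lambda k: jax.random.uniform(k))(keys)\n",
    "\n",
    "a = toy_batch(0, 0, 8)\n",
    "b = toy_batch(0, 0, 8)\n",
    "c = toy_batch(0, 4, 8)  # trial indices 4..11, overlapping trials 4..7 of `a`\n",
    "\n",
    "print(bool(jnp.array_equal(a, b)))          # True: same start_index\n",
    "print(bool(jnp.array_equal(a[4:], c[:4])))  # True: shared trial indices\n",
    "```\n",
    "\n",
    "In a training loop this means advancing `start_index` by at least the batch size between\n",
    "calls if every batch should contain fresh trials.\n"
   ]
  },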
  {
   "cell_type": "markdown",
   "id": "649334a3",
   "metadata": {},
   "source": [
    "### Streaming batches through training\n",
    "\n",
    "A typical training loop advances `start_index` by `batch_size` at every step, so each\n",
    "batch contains fresh trials; `batch_sample` takes care of the `vmap` and JIT compilation.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "146600b7",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T09:12:07.431381Z",
     "iopub.status.busy": "2026-05-21T09:12:07.431013Z",
     "iopub.status.idle": "2026-05-21T09:12:08.542404Z",
     "shell.execute_reply": "2026-05-21T09:12:08.541418Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step  0  X.shape=(1700, 16, 9)  Y.shape=(1700, 16)\n",
      "step  1  X.shape=(1700, 16, 9)  Y.shape=(1700, 16)\n",
      "step  2  X.shape=(1700, 16, 9)  Y.shape=(1700, 16)\n",
      "step  3  X.shape=(1700, 16, 9)  Y.shape=(1700, 16)\n"
     ]
    }
   ],
   "source": [
    "batch_size = 16\n",
    "num_steps = 4\n",
    "for step in range(num_steps):\n",
    "    X, Y = task.batch_sample(batch_size, start_index=step * batch_size)\n",
    "    # train_step(model, X, Y)   # plug into your training loop\n",
    "    print(f'step {step:>2d}  X.shape={X.shape}  Y.shape={Y.shape}')\n"
   ]
  },
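  {
   "cell_type": "markdown",
   "id": "a3c10f06",
   "metadata": {},
   "source": [
    "To show where `train_step` would plug in, here is a self-contained JAX sketch of the same\n",
    "loop. `fake_batch_sample` is a hypothetical stand-in for `task.batch_sample` that returns\n",
    "random data in the same time-first layout (with a shorter `T`), and the model is a\n",
    "per-time-step linear readout trained with plain SGD on softmax cross-entropy. Neither is part\n",
    "of `cogtask`; a real setup would typically use a recurrent network and may mask the loss to\n",
    "particular phases.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "T, num_inputs, num_outputs = 50, 9, 3\n",
    "LR = 0.1  # SGD learning rate (illustrative value)\n",
    "\n",
    "def fake_batch_sample(batch_size, start_index, seed=0):\n",
    "    # Hypothetical stand-in: time-first X of shape (T, B, I) and integer Y of shape (T, B).\n",
    "    key = jax.random.fold_in(jax.random.PRNGKey(seed), start_index)\n",
    "    kx, ky = jax.random.split(key)\n",
    "    X = jax.random.normal(kx, (T, batch_size, num_inputs))\n",
    "    Y = jax.random.randint(ky, (T, batch_size), 0, num_outputs)\n",
    "    return X, Y\n",
    "\n",
    "def loss_fn(params, X, Y):\n",
    "    # Linear readout at every time step, then mean softmax cross-entropy.\n",
    "    logits = X @ params['W'] + params['b']  # (T, B, num_outputs)\n",
    "    log_probs = jax.nn.log_softmax(logits, axis=-1)\n",
    "    picked = jnp.take_along_axis(log_probs, Y[..., None], axis=-1)\n",
    "    return -jnp.mean(picked)\n",
    "\n",
    "@jax.jit\n",
    "def train_step(params, X, Y):\n",
    "    # One plain SGD update on the batch.\n",
    "    loss, grads = jax.value_and_grad(loss_fn)(params, X, Y)\n",
    "    params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)\n",
    "    return params, loss\n",
    "\n",
    "params = {'W': jnp.zeros((num_inputs, num_outputs)), 'b': jnp.zeros(num_outputs)}\n",
    "batch_size = 16\n",
    "losses = []\n",
    "for step in range(4):\n",
    "    X, Y = fake_batch_sample(batch_size, start_index=step * batch_size)\n",
    "    params, loss = train_step(params, X, Y)\n",
    "    losses.append(float(loss))\n",
    "    print(f'step {step:>2d}  loss={float(loss):.4f}')\n",
    "```\n",
    "\n",
    "With zero-initialised parameters every class gets the same logit, so the first loss equals\n",
    "`log(num_outputs)`, roughly 1.0986 here.\n"
   ]
  },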
  {
   "cell_type": "markdown",
   "id": "3d31321d",
   "metadata": {},
   "source": [
    "## 4. Switching `dt` at sampling time\n",
    "\n",
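    "The same task can be re-sampled at a finer or coarser time step by wrapping the\n",
    "sampling call in a `brainstate.environ.context`. The trial's duration in real time\n",
    "stays fixed; only the number of timesteps `T` changes. For the 1700 ms trial above,\n",
    "`T = 1700 ms / dt`, i.e. 3400 steps at `dt = 0.5 ms` and 850 steps at `dt = 2.0 ms`.\n"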
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5712ffd4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T09:12:08.546026Z",
     "iopub.status.busy": "2026-05-21T09:12:08.545521Z",
     "iopub.status.idle": "2026-05-21T09:12:09.655772Z",
     "shell.execute_reply": "2026-05-21T09:12:09.654499Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dt=0.5 ms  ->  T = 3400\n",
      "dt=2.0 ms  ->  T = 3400\n"
     ]
    }
   ],
   "source": [
    "with brainstate.environ.context(dt=0.5 * u.ms):\n",
    "    X_fine, _ = task.batch_sample(4)\n",
    "\n",
    "with brainstate.environ.context(dt=2.0 * u.ms):\n",
    "    X_coarse, _ = task.batch_sample(4)\n",
    "\n",
    "print('dt=0.5 ms  ->  T =', X_fine.shape[0])\n",
    "print('dt=2.0 ms  ->  T =', X_coarse.shape[0])\n"
   ]
  },
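  {
   "cell_type": "markdown",
   "id": "a3c10f07",
   "metadata": {},
   "source": [
    "A quick sanity check is to recompute the expected step counts from the phase durations.\n",
    "This sketch assumes each phase spans `round(duration / dt)` steps (the exact rounding\n",
    "`cogtask` applies to durations that are not multiples of `dt` is not covered here); the\n",
    "durations are read off the `dt = 1 ms` phase history in Section 2.\n",
    "\n",
    "```python\n",
    "# Phase durations in ms, taken from the dt = 1 ms phase_history output.\n",
    "durations_ms = {'fixation': 100.0, 'stimulus': 1500.0, 'response': 100.0}\n",
    "\n",
    "def expected_T(dt_ms):\n",
    "    # Total steps = sum over phases of round(duration / dt).\n",
    "    return sum(int(round(d / dt_ms)) for d in durations_ms.values())\n",
    "\n",
    "for dt_ms in (0.5, 1.0, 2.0):\n",
    "    print(f'dt={dt_ms} ms  ->  expected T = {expected_T(dt_ms)}')\n",
    "```\n",
    "\n",
    "Comparing these values with the shapes printed by `batch_sample` is a cheap way to confirm\n",
    "that the active `dt` was actually picked up at sampling time.\n"
   ]
  },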
  {
   "cell_type": "markdown",
   "id": "7b353bd1",
   "metadata": {},
   "source": [
    "## 5. Where to next\n",
    "\n",
    "- **Tutorial 2 — Building custom tasks**: phase composition, features,\n",
    "  encoders, label helpers, and class-based `Task` subclasses.\n",
    "- **Tutorial 3 — Variable-length trial sequences**: the current limits of\n",
    "  `batch_sample` on heterogeneous-length trials, today's workarounds, and\n",
    "  the planned padding-plus-mask API.\n",
    "- **API reference**: see `braintools.cogtask` in the API Reference section\n",
    "  for the full list of pre-built tasks, encoders, phases, and utilities.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
