{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fd43a5a1",
   "metadata": {},
   "source": [
    "# Tutorial 2: Building custom tasks\n",
    "\n",
    "This tutorial covers the building blocks you assemble into a custom\n",
    "`braintools.cogtask.Task`:\n",
    "\n",
    "1. **Features**: `Feature` / `FeatureSet`, composition, and the `.i` slice.\n",
    "2. **Declarative phases**: `Fixation`, `Stimulus`, `Delay`, `Response`, ...\n",
    "3. **Encoders**: value specs such as `one_hot`, `circular`, `von_mises`,\n",
    "   `cos_sin`, `gaussian`, `scalar`, `identity`, `ctx_value`.\n",
    "4. **Label helpers**: `label`, `match_label`, `comparison_label`.\n",
    "5. **Putting it together, instance-based**: pass features, phases, and\n",
    "   `trial_init` straight to `Task`.\n",
    "6. **Putting it together, class-based**: subclass `Task` for reusable,\n",
    "   parameterized tasks.\n",
    "7. **Vector outputs**: `output_mode='vector'` for continuous-report tasks.\n",
    "8. **Branching** with `If` / `Switch`.\n",
    "9. **Looping** with `While`.\n",
    "10. **Custom encoders**.\n",
    "\n",
    "---\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bd2661c5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T09:12:11.755417Z",
     "iopub.status.busy": "2026-05-21T09:12:11.755163Z",
     "iopub.status.idle": "2026-05-21T09:12:14.586002Z",
     "shell.execute_reply": "2026-05-21T09:12:14.584847Z"
    }
   },
   "outputs": [],
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "\n",
    "brainstate.environ.set(dt=1.0 * u.ms)\n",
    "\n",
    "from braintools.cogtask import (\n",
    "    Task, Context, concat,\n",
    "    Phase, Sequence, Repeat, Parallel,\n",
    "    Fixation, Stimulus, Delay, Response, Cue, Blank,\n",
    "    Sample, Test, Recall, Match, Comparison,\n",
    "    If, Switch, While,\n",
    "    Feature, FeatureSet, CircleFeature,\n",
    "    one_hot, circular, von_mises, cos_sin, gaussian, scalar, identity, ctx_value,\n",
    "    label, match_label, comparison_label,\n",
    "    TruncExp, UniformDuration,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ee63be9",
   "metadata": {},
   "source": [
    "## 1. Features\n",
    "\n",
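    "A `Feature` declares one logical input or output channel: a name plus a\n",
    "fixed dimensionality. Compose features with `+` (concatenate into a new\n",
    "set; the operands are left unchanged), `-` (remove by name), `|` (alias for\n",
    "`+`), and `*n` (named repetition). Indexing a feature set by name returns\n",
    "that feature, and its `.i` attribute is its slice in the concatenated\n",
    "vector, as printed below.\n"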
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7805c7f4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T09:12:14.588622Z",
     "iopub.status.busy": "2026-05-21T09:12:14.588176Z",
     "iopub.status.idle": "2026-05-21T09:12:14.593249Z",
     "shell.execute_reply": "2026-05-21T09:12:14.592383Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "inputs: FeatureSet(names=['fixation', 'stimulus'], nums=[1, 8]) num = 9\n",
      "stim slice in inputs: slice(1, 9, None)\n",
      "choice slice in outputs: slice(1, 3, None)\n"
     ]
    }
   ],
   "source": [
    "fix    = Feature(1, 'fixation')\n",
    "stim   = Feature(8, 'stimulus')\n",
    "choice = Feature(2, 'choice')\n",
    "\n",
    "inputs  = fix + stim\n",
    "outputs = fix + choice\n",
    "print('inputs:', inputs, 'num =', inputs.num)\n",
    "print('stim slice in inputs:', inputs['stimulus'].i)\n",
    "print('choice slice in outputs:', outputs['choice'].i)\n"
   ]
  },
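  {
   "cell_type": "markdown",
   "id": "sketch-feature-slices-md",
   "metadata": {},
   "source": [
    "To make the `.i` slices concrete, here is a minimal pure-Python sketch of\n",
    "the bookkeeping a feature set has to do: each feature's slice starts where\n",
    "the previous one ends, and removing a feature shifts the later slices left.\n",
    "`feature_slices` and `remove_feature` are hypothetical helpers that treat\n",
    "features as `(name, num)` pairs; they are not part of `braintools.cogtask`,\n",
    "whose `FeatureSet` may store this information differently.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-feature-slices-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sketch: features as (name, num) pairs, not the braintools API.\n",
    "def feature_slices(features):\n",
    "    # Map each feature name to its slice in the concatenated vector.\n",
    "    slices, start = {}, 0\n",
    "    for name, num in features:\n",
    "        slices[name] = slice(start, start + num)\n",
    "        start += num\n",
    "    return slices, start  # start is now the total width\n",
    "\n",
    "def remove_feature(features, name):\n",
    "    # Drop one feature by name; later slices shift left when recomputed.\n",
    "    return [(n, k) for n, k in features if n != name]\n",
    "\n",
    "pairs = [('fixation', 1), ('stimulus', 8)]\n",
    "slices, total = feature_slices(pairs)\n",
    "print(slices['stimulus'], total)  # matches inputs['stimulus'].i and inputs.num above\n",
    "print(feature_slices(remove_feature(pairs, 'fixation'))[0])\n"
   ]
  },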
  {
   "cell_type": "markdown",
   "id": "7e81f68e",
   "metadata": {},
   "source": [
    "## 2. Declarative phases\n",
    "\n",
    "A `DeclarativePhase` describes its behavior with three dictionaries:\n",
    "\n",
    "- `inputs={feature_name: value_spec, ...}`: fill the input slice for one\n",
    "  feature. A value spec is either a constant (broadcast) or a callable\n",
    "  `f(ctx, feature) -> array`.\n",
    "- `outputs={...}`: the same value-spec conventions; in *categorical* mode\n",
    "  use the reserved key `'label'`, in *vector* mode write one entry per\n",
    "  output feature.\n",
    "- `noise={feature_name: sigma}`: additive Gaussian noise scaled by\n",
    "  `1/sqrt(dt)`. The per-step variance therefore grows as `dt` shrinks, but\n",
    "  the variance of the noise integrated over a fixed time window (summing\n",
    "  `dt * noise` per step) does not depend on `dt`.\n",
    "\n",
    "Compose phases sequentially with `>>` or `concat([...])`, repeat with `*n`,\n",
    "and run two phases simultaneously with `|` (`Parallel`).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e37edc03",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T09:12:14.595712Z",
     "iopub.status.busy": "2026-05-21T09:12:14.595298Z",
     "iopub.status.idle": "2026-05-21T09:12:14.601516Z",
     "shell.execute_reply": "2026-05-21T09:12:14.600928Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequence(Fixation >> Stimulus >> Delay >> Response)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "phases = (\n",
    "    Fixation(100 * u.ms, inputs={'fixation': 1.0}, outputs={'label': 0})\n",
    "    >> Stimulus(500 * u.ms,\n",
    "                inputs={'fixation': 1.0,\n",
    "                        'stimulus': circular('direction', 'coherence')},\n",
    "                outputs={'label': 0})\n",
    "    >> Delay(300 * u.ms, inputs={'fixation': 1.0}, outputs={'label': 0})\n",
    "    >> Response(100 * u.ms,\n",
    "                inputs={'fixation': 0.0},\n",
    "                outputs={'label': lambda ctx, f: ctx['ground_truth'] + 1})\n",
    ")\n",
    "phases\n"
   ]
  },
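  {
   "cell_type": "markdown",
   "id": "sketch-noise-scaling-md",
   "metadata": {},
   "source": [
    "The `1/sqrt(dt)` noise scaling can be checked with a small Monte Carlo\n",
    "sketch. It assumes the noise enters an Euler-integrated variable as\n",
    "`dt * noise`, the usual convention for leaky-integrator RNNs (an assumption\n",
    "here; this tutorial does not say how the network consumes its inputs).\n",
    "Under that assumption the variance of the noise accumulated over a fixed\n",
    "window `T` stays near `sigma**2 * T` for every `dt`, even though the\n",
    "per-step variance `sigma**2 / dt` changes. `accumulated_noise_variance` is\n",
    "a hypothetical helper that uses only the standard library.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-noise-scaling-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import random\n",
    "\n",
    "def accumulated_noise_variance(sigma, dt, window, n_trials=4000, seed=0):\n",
    "    # Empirical variance of sum(dt * eps_t) over `window` time units, where\n",
    "    # each eps_t ~ N(0, (sigma / sqrt(dt))**2): the 1/sqrt(dt) scaling.\n",
    "    rng = random.Random(seed)\n",
    "    n_steps = round(window / dt)\n",
    "    step_std = sigma / math.sqrt(dt)\n",
    "    totals = [sum(dt * rng.gauss(0.0, step_std) for _ in range(n_steps))\n",
    "              for _ in range(n_trials)]\n",
    "    mean = sum(totals) / n_trials\n",
    "    return sum((x - mean) ** 2 for x in totals) / (n_trials - 1)\n",
    "\n",
    "sigma_demo, window_ms = 0.1, 100.0\n",
    "for dt_ms in (1.0, 5.0, 20.0):\n",
    "    var = accumulated_noise_variance(sigma_demo, dt_ms, window_ms)\n",
    "    print(f'dt = {dt_ms:4.1f} ms: variance {var:.3f} (theory {sigma_demo**2 * window_ms:.3f})')\n"
   ]
  },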
  {
   "cell_type": "markdown",
   "id": "4bb5e733",
   "metadata": {},
   "source": [
    "## 3. Encoders\n",
    "\n",
    "Encoders are helper factories that return callable value specs. The most\n",
    "commonly used ones:\n",
    "\n",
    "| encoder                        | trial state → activation                                                       |\n",
    "|--------------------------------|--------------------------------------------------------------------------------|\n",
    "| `one_hot(key)`                 | discrete class index → one-hot vector                                          |\n",
    "| `circular(key, coherence_key)` | direction (radians or index) → cosine tuning over preferred directions         |\n",
    "| `von_mises(key, …)`            | direction → von-Mises tuning curve in `[base_value, 1]`                        |\n",
    "| `cos_sin(key, num_dirs, …)`    | discrete direction → repeated `[cos θ, sin θ]` features                        |\n",
    "| `gaussian(key, sigma=…)`       | scalar value → Gaussian bumps over evenly-spaced centers                       |\n",
    "| `scalar(key, scale, offset)`   | scalar → broadcast `value * scale + offset` to all units                       |\n",
    "| `identity(key)`                | array stored in `ctx[key]` → written through unchanged                         |\n",
    "| `ctx_value(key, default=…)`    | raw value lookup; useful for time-varying inputs computed elsewhere            |\n",
    "\n",
    "Encoders read `ctx[key]` — typically a value set in `trial_init`.\n"
   ]
  },
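  {
   "cell_type": "markdown",
   "id": "sketch-encoders-md",
   "metadata": {},
   "source": [
    "Plain-Python versions of four of these encodings make the table concrete.\n",
    "They are illustrative sketches with assumed formulas (preferred directions\n",
    "evenly spaced on `[0, 2*pi)`, a concentration `kappa`, a `base_value` floor,\n",
    "Gaussian centers spread evenly over `[lo, hi]`); the parameterization inside\n",
    "`braintools.cogtask` may differ, so check the API reference for the real\n",
    "defaults. The `_sketch` suffix keeps them from shadowing the imported\n",
    "encoders.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-encoders-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Hypothetical re-implementations for illustration only.\n",
    "def one_hot_sketch(index, num):\n",
    "    # discrete class index -> one-hot vector of length num\n",
    "    return [1.0 if i == index else 0.0 for i in range(num)]\n",
    "\n",
    "def preferred_dirs(num):\n",
    "    # evenly spaced preferred directions on [0, 2*pi)\n",
    "    return [2 * math.pi * i / num for i in range(num)]\n",
    "\n",
    "def cosine_tuning_sketch(theta, num, gain=1.0):\n",
    "    # one common cosine-tuning form: gain * (1 + cos(theta - pref)) / 2\n",
    "    return [gain * (1 + math.cos(theta - p)) / 2 for p in preferred_dirs(num)]\n",
    "\n",
    "def von_mises_sketch(theta, num, kappa=2.0, base_value=0.0):\n",
    "    # peak-normalized von Mises bump: 1 where pref == theta, never below base_value\n",
    "    return [base_value + (1 - base_value) * math.exp(kappa * (math.cos(theta - p) - 1))\n",
    "            for p in preferred_dirs(num)]\n",
    "\n",
    "def gaussian_sketch(value, num, sigma, lo=0.0, hi=1.0):\n",
    "    # Gaussian bumps over num evenly spaced centers in [lo, hi]\n",
    "    centers = [lo + (hi - lo) * i / (num - 1) for i in range(num)]\n",
    "    return [math.exp(-(value - c) ** 2 / (2 * sigma ** 2)) for c in centers]\n",
    "\n",
    "print(one_hot_sketch(2, 4))\n",
    "print([round(v, 3) for v in von_mises_sketch(0.0, 8)])\n",
    "print([round(v, 3) for v in gaussian_sketch(0.5, 5, sigma=0.25)])\n"
   ]
  },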
  {
   "cell_type": "markdown",
   "id": "f46290eb",
   "metadata": {},
   "source": [
    "## 4. Label helpers\n",
    "\n",
    "In *categorical* output mode every phase fills `ctx.outputs[start:end]` with\n",
    "an integer label. In this tutorial's examples, `0` means *keep fixating* and\n",
    "the choice classes start at `1`. The `label` helpers are convenient\n",
    "builders:\n",
    "\n",
    "- `label(value)`: a static int, a context lookup (pass the key as a `str`),\n",
    "  or an arbitrary callable `(ctx) -> int`.\n",
    "- `match_label(match_key)`: emit `1` on match trials, `2` otherwise.\n",
    "- `comparison_label(key)`: emit `1` if greater, `2` otherwise.\n"
   ]
  },
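  {
   "cell_type": "markdown",
   "id": "sketch-labels-md",
   "metadata": {},
   "source": [
    "The three ways of specifying a label can be mimicked with a small\n",
    "dispatcher. This is a hypothetical pure-Python sketch of the behavior\n",
    "described above: the context is a plain `dict`, and `make_label_sketch` /\n",
    "`match_label_sketch` are illustrative names, not library functions. The\n",
    "real helpers return value specs that the framework evaluates per trial.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-labels-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sketch of label-spec dispatch; ctx is a plain dict here.\n",
    "def make_label_sketch(value):\n",
    "    if isinstance(value, int):        # static label\n",
    "        return lambda ctx: value\n",
    "    if isinstance(value, str):        # look the label up in the trial context\n",
    "        return lambda ctx: ctx[value]\n",
    "    if callable(value):               # arbitrary (ctx) -> int\n",
    "        return value\n",
    "    raise TypeError(f'unsupported label spec: {value!r}')\n",
    "\n",
    "def match_label_sketch(match_key):\n",
    "    # 1 on match trials, 2 otherwise (0 stays reserved for fixation)\n",
    "    return make_label_sketch(lambda ctx: 1 if ctx[match_key] else 2)\n",
    "\n",
    "ctx_demo = {'ground_truth': 1, 'is_match': False}\n",
    "print(make_label_sketch(0)(ctx_demo),\n",
    "      make_label_sketch('ground_truth')(ctx_demo),\n",
    "      match_label_sketch('is_match')(ctx_demo))\n"
   ]
  },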
  {
   "cell_type": "markdown",
   "id": "b466a24f",
   "metadata": {},
   "source": [
    "## 5. Putting it together — instance-based\n",
    "\n",
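    "The smallest end-to-end recipe: define the features, the phases, and\n",
    "`trial_init`, then hand them to `Task`. `task.sample_trial(0)` returns the\n",
    "inputs `X` and targets `Y` for trial index 0, plus an `info` dict whose\n",
    "`'trial_state'` entry holds the trial context, including the values\n",
    "written by `trial_init`.\n"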
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a715a32f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T09:12:14.603676Z",
     "iopub.status.busy": "2026-05-21T09:12:14.603382Z",
     "iopub.status.idle": "2026-05-21T09:12:16.125079Z",
     "shell.execute_reply": "2026-05-21T09:12:16.124255Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X.shape = (1000, 9) Y.shape = (1000,)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'trial_index': 0,\n",
       " 'ground_truth': Array(0, dtype=int32),\n",
       " 'coherence': 51.2,\n",
       " 'direction': Array(0.5190151, dtype=float32),\n",
       " 'output_mode': 'categorical'}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
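    "def trial_init(ctx):\n",
    "    ctx['ground_truth'] = ctx.rng.choice(2)              # 0 or 1; a jax scalar, OK under JIT\n",
    "    ctx['coherence']    = 51.2\n",
    "    # Toy example: direction is drawn independently of ground_truth, so the\n",
    "    # stimulus carries no information about the correct choice. In a real\n",
    "    # task, derive it from ground_truth.\n",
    "    ctx['direction']    = ctx.rng.uniform(0.0, 2 * jnp.pi)\n",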
    "\n",
    "task = Task(\n",
    "    phases=phases,\n",
    "    input_features=inputs,\n",
    "    output_features=outputs,\n",
    "    trial_init=trial_init,\n",
    "    seed=0,\n",
    ")\n",
    "\n",
    "X, Y, info = task.sample_trial(0)\n",
    "print('X.shape =', X.shape, 'Y.shape =', Y.shape)\n",
    "info['trial_state']\n"
   ]
  },
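  {
   "cell_type": "markdown",
   "id": "sketch-phase-bounds-md",
   "metadata": {},
   "source": [
    "The 1000 rows of `X` follow directly from the phase durations: at\n",
    "`dt = 1 ms`, the 100, 500, 300, and 100 ms phases occupy consecutive\n",
    "half-open index ranges. The hypothetical `phase_bounds` helper below (plain\n",
    "Python, not a library function) computes those ranges, assuming every\n",
    "duration is an exact multiple of `dt`; how the library rounds other\n",
    "durations is not covered here.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-phase-bounds-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def phase_bounds(durations_ms, dt_ms):\n",
    "    # Consecutive [start, end) step ranges for phases composed with >>.\n",
    "    bounds, start = [], 0\n",
    "    for phase_name, dur in durations_ms:\n",
    "        n_steps = round(dur / dt_ms)   # assumes dur is a multiple of dt\n",
    "        bounds.append((phase_name, start, start + n_steps))\n",
    "        start += n_steps\n",
    "    return bounds, start\n",
    "\n",
    "bounds, n_total = phase_bounds(\n",
    "    [('fixation', 100), ('stimulus', 500), ('delay', 300), ('response', 100)], dt_ms=1.0)\n",
    "for phase_name, start, end in bounds:\n",
    "    print(f'{phase_name:9s} [{start:4d}, {end:4d})')\n",
    "print('total steps:', n_total)  # equals X.shape[0] above\n"
   ]
  },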
  {
   "cell_type": "markdown",
   "id": "09c4af12",
   "metadata": {},
   "source": [
    "## 6. Putting it together — class-based\n",
    "\n",
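    "For parameterized, reusable tasks, subclass `Task` and override\n",
    "`define_features`, `define_phases`, and `trial_init`. Class attributes such\n",
    "as `num_stimuli` act as defaults that constructor keyword arguments can\n",
    "override, as in `MyDMS(num_stimuli=16, seed=42)` below. `batch_sample(B)`\n",
    "returns time-major arrays: `X` has shape `(T, B, num_inputs)` and, in\n",
    "categorical mode, `Y` has shape `(T, B)`. Pre-built tasks\n",
    "(`PerceptualDecisionMaking`, `DelayMatchSample`, ...) all follow this\n",
    "pattern.\n"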
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "59c410ff",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T09:12:16.127748Z",
     "iopub.status.busy": "2026-05-21T09:12:16.127500Z",
     "iopub.status.idle": "2026-05-21T09:12:16.587794Z",
     "shell.execute_reply": "2026-05-21T09:12:16.586991Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X.shape = (1400, 64, 17) Y.shape = (1400, 64)\n"
     ]
    }
   ],
   "source": [
    "class MyDMS(Task):\n",
    "    t_fixation = 200 * u.ms\n",
    "    t_sample   = 400 * u.ms\n",
    "    t_delay    = 600 * u.ms\n",
    "    t_response = 200 * u.ms\n",
    "    num_stimuli = 8\n",
    "\n",
    "    def define_features(self):\n",
    "        fix    = Feature(1, 'fixation')\n",
    "        stim   = Feature(self.num_stimuli, 'stimulus')\n",
    "        choice = Feature(2, 'choice')   # match / non-match\n",
    "        return fix + stim, fix + choice\n",
    "\n",
    "    def define_phases(self):\n",
    "        return concat([\n",
    "            Fixation(self.t_fixation,\n",
    "                     inputs={'fixation': 1.0},\n",
    "                     outputs={'label': 0}),\n",
    "            Stimulus(self.t_sample,\n",
    "                     inputs={'fixation': 1.0,\n",
    "                             'stimulus': von_mises('sample_idx',\n",
    "                                                   num_dirs=self.num_stimuli)},\n",
    "                     outputs={'label': 0},\n",
    "                     name='sample'),\n",
    "            Delay(self.t_delay,\n",
    "                  inputs={'fixation': 1.0},\n",
    "                  outputs={'label': 0}),\n",
    "            Response(self.t_response,\n",
    "                     inputs={'fixation': 0.0,\n",
    "                             'stimulus': von_mises('test_idx',\n",
    "                                                   num_dirs=self.num_stimuli)},\n",
    "                     outputs={'label': label(lambda ctx: jnp.where(ctx['is_match'], 1, 2))},\n",
    "                     name='response'),\n",
    "        ])\n",
    "\n",
    "    def trial_init(self, ctx):\n",
    "        # All values stay as jax scalars / arrays so trial_init works under JIT+vmap.\n",
    "        ctx['sample_idx'] = ctx.rng.choice(self.num_stimuli)\n",
    "        ctx['is_match']   = ctx.rng.uniform() < 0.5\n",
    "        non_match_idx     = (ctx['sample_idx']\n",
    "                             + 1\n",
    "                             + ctx.rng.choice(self.num_stimuli - 1)) % self.num_stimuli\n",
    "        ctx['test_idx']   = jnp.where(ctx['is_match'], ctx['sample_idx'], non_match_idx)\n",
    "\n",
    "task = MyDMS(num_stimuli=16, seed=42)\n",
    "X, Y = task.batch_sample(64)\n",
    "print('X.shape =', X.shape, 'Y.shape =', Y.shape)\n"
   ]
  },
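  {
   "cell_type": "markdown",
   "id": "sketch-non-match-md",
   "metadata": {},
   "source": [
    "The `non_match_idx` line in `trial_init` relies on a small modular-arithmetic\n",
    "trick. Assuming `ctx.rng.choice(m)` draws uniformly from `0..m-1` (as NumPy's\n",
    "`choice` does), adding `1 + k` with `k` in `0..n-2` and wrapping modulo `n`\n",
    "reaches every index except the sample exactly once, so the non-match\n",
    "stimulus is uniform without a rejection loop. The exhaustive check below\n",
    "(plain Python, no RNG) confirms this for the `num_stimuli=16` used above.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-non-match-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def non_match_candidates(sample_idx, n):\n",
    "    # Every value (sample_idx + 1 + k) % n can take for k in 0..n-2.\n",
    "    return [(sample_idx + 1 + k) % n for k in range(n - 1)]\n",
    "\n",
    "n_stim = 16\n",
    "for s in range(n_stim):\n",
    "    cands = non_match_candidates(s, n_stim)\n",
    "    assert s not in cands                                          # never the sample\n",
    "    assert sorted(cands) == [i for i in range(n_stim) if i != s]   # each other index once\n",
    "print('all', n_stim, 'sample indices pass')\n"
   ]
  },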
  {
   "cell_type": "markdown",
   "id": "d01571a4",
   "metadata": {},
   "source": [
    "## 7. Vector outputs — continuous-report tasks\n",
    "\n",
    "Pass `output_mode='vector'` when the target is a population code rather\n",
    "than a class index. Each phase then writes into named output features.\n",
    "Features that should be silent during a phase are written explicitly with a\n",
    "zero spec (a constant `0.0` or a small callable that returns zeros).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8605d477",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T09:12:16.590265Z",
     "iopub.status.busy": "2026-05-21T09:12:16.589996Z",
     "iopub.status.idle": "2026-05-21T09:12:16.919172Z",
     "shell.execute_reply": "2026-05-21T09:12:16.918355Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X.shape = (1900, 8, 17)   Y.shape = (1900, 8, 17) (T, B, 1+16)\n"
     ]
    }
   ],
   "source": [
    "fix_in  = Feature(1, 'fixation')\n",
    "stim_in = Feature(16, 'stimulus')\n",
    "fix_out = Feature(1, 'fixation_out')\n",
    "dir_out = Feature(16, 'direction_out')\n",
    "\n",
    "# Zero vector for output features that stay silent during a phase.\n",
    "def silent(ctx, feat):\n",
    "    return jnp.zeros((feat.num,))\n",
    "\n",
    "task = Task(\n",
    "    phases=concat([\n",
    "        Fixation(200 * u.ms,\n",
    "                 inputs={'fixation': 1.0},\n",
    "                 outputs={'fixation_out': 1.0, 'direction_out': silent}),\n",
    "        Stimulus(500 * u.ms,\n",
    "                 inputs={'fixation': 1.0,\n",
    "                         'stimulus': von_mises('sample_idx', num_dirs=16)},\n",
    "                 outputs={'fixation_out': 1.0, 'direction_out': silent}),\n",
    "        Delay(800 * u.ms,\n",
    "              inputs={'fixation': 1.0},\n",
    "              outputs={'fixation_out': 1.0, 'direction_out': silent}),\n",
    "        Response(400 * u.ms,\n",
    "                 inputs={'fixation': 0.0},\n",
    "                 outputs={'fixation_out': 0.0,\n",
    "                          'direction_out': von_mises('sample_idx', num_dirs=16)}),\n",
    "    ]),\n",
    "    input_features=fix_in + stim_in,\n",
    "    output_features=fix_out + dir_out,\n",
    "    output_mode='vector',\n",
    "    trial_init=lambda ctx: ctx.update(sample_idx=ctx.rng.choice(16)),\n",
    "    seed=0,\n",
    ")\n",
    "\n",
    "X, Y = task.batch_sample(8)\n",
    "print('X.shape =', X.shape, '  Y.shape =', Y.shape, '(T, B, 1+16)')\n"
   ]
  },
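  {
   "cell_type": "markdown",
   "id": "sketch-vector-target-md",
   "metadata": {},
   "source": [
    "In vector mode each row of `Y` concatenates the output features, here\n",
    "`[fixation_out | direction_out]`. The sketch below builds one trial's\n",
    "`(T, 1 + 16)` target in plain Python to show that layout: the fixation\n",
    "column is 1 until the response phase, and the 16 direction units stay at\n",
    "zero until then. The bump uses an assumed von Mises parameterization\n",
    "(`kappa=2`, peak 1, no baseline), so its values will differ from the\n",
    "library's `von_mises` encoder; `vm_bump` and `vector_target_sketch` are\n",
    "hypothetical helpers.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-vector-target-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def vm_bump(idx, num, kappa=2.0):\n",
    "    # Assumed von Mises bump (peak 1) centred on preferred direction idx;\n",
    "    # the library's von_mises encoder may use other parameters.\n",
    "    theta = 2 * math.pi * idx / num\n",
    "    return [math.exp(kappa * (math.cos(theta - 2 * math.pi * i / num) - 1))\n",
    "            for i in range(num)]\n",
    "\n",
    "def vector_target_sketch(sample_idx, num=16, dt_ms=1.0):\n",
    "    # (phase name, duration in ms), mirroring the concat([...]) above\n",
    "    phases_ms = [('fixation', 200), ('stimulus', 500), ('delay', 800), ('response', 400)]\n",
    "    rows = []\n",
    "    for phase_name, dur in phases_ms:\n",
    "        responding = phase_name == 'response'\n",
    "        fix_col = [0.0 if responding else 1.0]                              # fixation_out\n",
    "        dir_cols = vm_bump(sample_idx, num) if responding else [0.0] * num  # direction_out\n",
    "        rows.extend(fix_col + dir_cols for _ in range(round(dur / dt_ms)))\n",
    "    return rows\n",
    "\n",
    "target = vector_target_sketch(sample_idx=3)\n",
    "print(len(target), len(target[0]))  # T steps, 1 + 16 output units\n"
   ]
  },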
  {
   "cell_type": "markdown",
   "id": "c2136db7",
   "metadata": {},
   "source": [
    "## 8. Branching with `If` / `Switch`\n",
    "\n",
    "`If` selects between `then` / `else_` based on a predicate over the trial\n",
    "context; `Switch` dispatches over a dict of cases keyed by the selector's\n",
    "output. Both must be resolvable from values written in `trial_init`,\n",
    "because the framework needs to compute a deterministic total duration\n",
    "before the actual encoding pass.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "df9365a0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T09:12:16.921451Z",
     "iopub.status.busy": "2026-05-21T09:12:16.921195Z",
     "iopub.status.idle": "2026-05-21T09:12:16.926483Z",
     "shell.execute_reply": "2026-05-21T09:12:16.925807Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequence(Sample >> Delay >> Test >> If)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "branching = (\n",
    "    Sample(400 * u.ms,\n",
    "           inputs={'fixation': 1.0,\n",
    "                   'stimulus': von_mises('sample_idx', num_dirs=8)},\n",
    "           outputs={'label': 0})\n",
    "    >> Delay(600 * u.ms,\n",
    "             inputs={'fixation': 1.0},\n",
    "             outputs={'label': 0})\n",
    "    >> Test(400 * u.ms,\n",
    "            inputs={'fixation': 1.0,\n",
    "                    'stimulus': von_mises('test_idx', num_dirs=8)},\n",
    "            outputs={'label': 0})\n",
    "    >> If(\n",
    "        condition=lambda ctx: bool(ctx['is_match']),\n",
    "        then=Response(200 * u.ms,\n",
    "                      inputs={'fixation': 0.0},\n",
    "                      outputs={'label': 1}),\n",
    "        else_=Response(200 * u.ms,\n",
    "                       inputs={'fixation': 0.0},\n",
    "                       outputs={'label': 2}),\n",
    "    )\n",
    ")\n",
    "branching\n"
   ]
  },
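  {
   "cell_type": "markdown",
   "id": "sketch-branch-duration-md",
   "metadata": {},
   "source": [
    "Why branches must be resolvable from `trial_init` values is easiest to see\n",
    "in a toy model of the duration pass. The classes below are hypothetical\n",
    "stand-ins (plain Python, not the `braintools.cogtask` classes): a leaf has a\n",
    "fixed number of steps, and `IfSketch` / `SwitchSketch` pick a child by\n",
    "reading the trial context. Because the choice depends only on the context,\n",
    "the total length is known before anything is encoded. Unlike the example\n",
    "above, where both branches last 200 ms, the toy gives its branches\n",
    "different lengths to show that the total can vary per trial.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-branch-duration-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from typing import Callable, Dict, List\n",
    "\n",
    "# Hypothetical stand-ins for phase-tree nodes; durations are in steps.\n",
    "@dataclass\n",
    "class LeafSketch:\n",
    "    steps: int\n",
    "\n",
    "@dataclass\n",
    "class SeqSketch:\n",
    "    children: List[object]\n",
    "\n",
    "@dataclass\n",
    "class IfSketch:\n",
    "    condition: Callable[[dict], bool]\n",
    "    then: object\n",
    "    else_: object\n",
    "\n",
    "@dataclass\n",
    "class SwitchSketch:\n",
    "    selector: Callable[[dict], object]\n",
    "    cases: Dict[object, object]\n",
    "\n",
    "def total_steps(node, ctx):\n",
    "    # Resolve every branch from ctx, then add up the leaf lengths.\n",
    "    if isinstance(node, LeafSketch):\n",
    "        return node.steps\n",
    "    if isinstance(node, SeqSketch):\n",
    "        return sum(total_steps(c, ctx) for c in node.children)\n",
    "    if isinstance(node, IfSketch):\n",
    "        return total_steps(node.then if node.condition(ctx) else node.else_, ctx)\n",
    "    if isinstance(node, SwitchSketch):\n",
    "        return total_steps(node.cases[node.selector(ctx)], ctx)\n",
    "    raise TypeError(type(node).__name__)\n",
    "\n",
    "tree = SeqSketch([\n",
    "    LeafSketch(400), LeafSketch(600), LeafSketch(400),\n",
    "    IfSketch(lambda ctx: ctx['is_match'], then=LeafSketch(200), else_=LeafSketch(300)),\n",
    "    SwitchSketch(lambda ctx: ctx['rule'], {'short': LeafSketch(50), 'long': LeafSketch(150)}),\n",
    "])\n",
    "print(total_steps(tree, {'is_match': True, 'rule': 'short'}),\n",
    "      total_steps(tree, {'is_match': False, 'rule': 'long'}))\n"
   ]
  },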
  {
   "cell_type": "markdown",
   "id": "f0b4e370",
   "metadata": {},
   "source": [
    "## 9. Looping with `While`\n",
    "\n",
    "`While(condition, body, max_iterations)` repeats `body` while\n",
    "`condition(ctx)` returns `True`, for at most `max_iterations` iterations.\n",
    "The duration computation uses `max_iterations` as the upper bound, so the\n",
    "buffer is large enough for the worst case; see Tutorial 3 for the\n",
    "JIT/`vmap` implications.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f054899d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T09:12:16.928737Z",
     "iopub.status.busy": "2026-05-21T09:12:16.928410Z",
     "iopub.status.idle": "2026-05-21T09:12:16.934502Z",
     "shell.execute_reply": "2026-05-21T09:12:16.933825Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "While(body=Sample, max=20)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
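    "# 'threshold' and 'pulse' are expected to be set in trial_init. Nothing in\n",
    "# this snippet updates 'evidence', so unless another step does, the loop\n",
    "# runs all max_iterations repetitions.\n",
    "loop_demo = While(\n",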
    "    condition=lambda ctx: ctx.get('evidence', 0.0) < ctx['threshold'],\n",
    "    body=Sample(50 * u.ms,\n",
    "                inputs={'stimulus': lambda ctx, f: ctx['pulse']},\n",
    "                outputs={'label': 0}),\n",
    "    max_iterations=20,\n",
    ")\n",
    "loop_demo\n"
   ]
  },
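  {
   "cell_type": "markdown",
   "id": "sketch-while-buffer-md",
   "metadata": {},
   "source": [
    "The `max_iterations` bound can be illustrated with a plain-Python sketch:\n",
    "the buffer is sized for `max_iterations` repetitions of the body, the loop\n",
    "stops early once the predicate turns false, and a boolean mask marks the\n",
    "steps actually used. Everything here is hypothetical (the evidence update\n",
    "rule, the 50-step body, the `while_sketch` helper); it mirrors the\n",
    "description above, not the library's implementation, which Tutorial 3\n",
    "discusses.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-while-buffer-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def while_sketch(condition, update, ctx, body_steps, max_iterations):\n",
    "    # Buffer sized for the worst case; mask marks the steps actually used.\n",
    "    buffer_len = body_steps * max_iterations\n",
    "    iterations = 0\n",
    "    while iterations < max_iterations and condition(ctx):\n",
    "        update(ctx)                   # stands in for whatever the body does to ctx\n",
    "        iterations += 1\n",
    "    used = body_steps * iterations\n",
    "    mask = [t < used for t in range(buffer_len)]\n",
    "    return iterations, buffer_len, mask\n",
    "\n",
    "def add_pulse(ctx):\n",
    "    # Hypothetical evidence update: accumulate one pulse per iteration.\n",
    "    ctx['evidence'] = ctx.get('evidence', 0.0) + ctx['pulse']\n",
    "\n",
    "state = {'threshold': 1.0, 'pulse': 0.25}\n",
    "n_iter, buf, mask = while_sketch(\n",
    "    lambda ctx: ctx.get('evidence', 0.0) < ctx['threshold'], add_pulse,\n",
    "    state, body_steps=50, max_iterations=20)\n",
    "print(n_iter, buf, sum(mask))  # iterations run, buffer length, used steps\n"
   ]
  },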
  {
   "cell_type": "markdown",
   "id": "fbdfa0e4",
   "metadata": {},
   "source": [
    "## 10. Custom encoders\n",
    "\n",
    "The built-in encoders cover most use cases, but any callable with\n",
    "signature `f(ctx, feature) -> jnp.ndarray` is a valid value spec. The\n",
    "return value may be:\n",
    "\n",
    "- a scalar (broadcast to all timesteps and units),\n",
    "- a 1-D array of shape `(feature.num,)` (broadcast along time),\n",
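    "- or a 2-D array of shape `(duration, feature.num)` (written directly).\n",
    "\n",
    "The example below reads `ctx.phase_duration` and uses it as the number of\n",
    "time steps in the current phase when building its ramp.\n"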
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6cdd2289",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-21T09:12:16.936281Z",
     "iopub.status.busy": "2026-05-21T09:12:16.936047Z",
     "iopub.status.idle": "2026-05-21T09:12:16.941943Z",
     "shell.execute_reply": "2026-05-21T09:12:16.941136Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeclarativePhase('Stimulus', inputs=['stimulus'])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def ramping_pulse(start_val, stop_val):\n",
    "    def encode(ctx, feature):\n",
    "        duration = ctx.phase_duration\n",
    "        ramp = jnp.linspace(start_val, stop_val, duration)\n",
    "        return jnp.broadcast_to(ramp[:, None], (duration, feature.num))\n",
    "    encode.__name__ = f'ramping_pulse({start_val}, {stop_val})'\n",
    "    return encode\n",
    "\n",
    "# Plug it in like any other encoder:\n",
    "ramp_phase = Stimulus(\n",
    "    500 * u.ms,\n",
    "    inputs={'stimulus': ramping_pulse(0.0, 1.0)},\n",
    "    outputs={'label': 0},\n",
    ")\n",
    "ramp_phase\n"
   ]
  },
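  {
   "cell_type": "markdown",
   "id": "sketch-expand-block-md",
   "metadata": {},
   "source": [
    "The three accepted return shapes can be summarized by a small normalization\n",
    "routine. The hypothetical `expand_to_block` helper below (plain Python)\n",
    "shows how a scalar, a length-`num` vector, or a `(duration, num)` block\n",
    "would each become a full `(duration, num)` block, and rejects anything\n",
    "else. It illustrates the rules listed above and assumes nothing about how\n",
    "the library implements them internally.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-expand-block-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def expand_to_block(value, duration, num):\n",
    "    # scalar -> fill everything\n",
    "    if isinstance(value, (int, float)):\n",
    "        return [[float(value)] * num for _ in range(duration)]\n",
    "    # 1-D (num,) -> repeat the same vector at every step\n",
    "    if len(value) == num and all(isinstance(v, (int, float)) for v in value):\n",
    "        return [[float(v) for v in value] for _ in range(duration)]\n",
    "    # 2-D (duration, num) -> copy row by row after checking the shape\n",
    "    if len(value) == duration and all(\n",
    "            isinstance(row, (list, tuple)) and len(row) == num for row in value):\n",
    "        return [[float(v) for v in row] for row in value]\n",
    "    raise ValueError(f'expected scalar, ({num},) or ({duration}, {num})')\n",
    "\n",
    "print(expand_to_block(0.5, 2, 3))\n",
    "print(expand_to_block([0.0, 1.0, 0.0], 2, 3))\n",
    "print(expand_to_block([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], 2, 3))\n"
   ]
  },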
  {
   "cell_type": "markdown",
   "id": "4379db47",
   "metadata": {},
   "source": [
    "## Where to next\n",
    "\n",
    "- **Tutorial 3 — Variable-length trial sequences** covers\n",
    "  `VariableDuration`, the packed-buffer + mask design, and how\n",
    "  `If` / `Switch` / `While` participate under `batch_sample`.\n",
    "- The API reference (`braintools.cogtask` under *API Reference*) lists every\n",
    "  pre-built task, encoder, phase, and utility with full parameter\n",
    "  documentation."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
