{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f8c71012dc176298",
   "metadata": {},
   "source": [
    "# Mixin System\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb4657616a309500",
   "metadata": {},
   "source": [
    "This tutorial explains the mixin utilities that ship with `brainstate`. After working through the examples you will:\n",
    "\n",
    "- Understand what a mixin is and when to use one.\n",
    "- Reuse behaviors by inheriting from `brainstate.mixin.Mixin`.\n",
    "- Capture reusable constructor presets with `ParamDesc` and `ParamDescriber`.\n",
    "- Express rich type expectations with `JointTypes` and `OneOfTypes`.\n",
    "- Control runtime behaviour with the built-in mode mixins such as `Training`, `Batching`, and `JointMode`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c82deb74753ef79c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:54:27.764527Z",
     "iopub.status.busy": "2026-05-30T16:54:27.764308Z",
     "iopub.status.idle": "2026-05-30T16:54:29.927994Z",
     "shell.execute_reply": "2026-05-30T16:54:29.926993Z"
    }
   },
   "outputs": [],
   "source": [
    "import datetime\n",
    "from dataclasses import dataclass\n",
    "\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import brainstate\n",
    "from brainstate import mixin\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f438237a5a9d8838",
   "metadata": {},
   "source": [
    "## What is a mixin?\n",
    "\n",
    "A *mixin* is a lightweight class that contributes behaviour (methods or attributes) without forcing a rigid inheritance hierarchy.\n",
    "In BrainState every mixin inherits from `brainstate.mixin.Mixin`, signalling that the class\n",
    "provides optional behaviour and should not define its own `__init__`.\n",
    "Mixins are usually paired with core components such as `brainstate.nn.Module` to keep reusable code close to the consumer.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "301aef66cee37807",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:54:29.930454Z",
     "iopub.status.busy": "2026-05-30T16:54:29.930036Z",
     "iopub.status.idle": "2026-05-30T16:54:29.935316Z",
     "shell.execute_reply": "2026-05-30T16:54:29.934654Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[LOG 00:54:29] Accumulator: updated running total to 1.25\n",
      "[LOG 00:54:29] Accumulator: updated running total to 4.00\n"
     ]
    }
   ],
   "source": [
    "class LoggingMixin(mixin.Mixin):\n",
    "    \"\"\"Attach timestamped logging to any class without touching its constructor.\"\"\"\n",
    "\n",
    "    def log(self, message: str) -> None:\n",
    "        stamp = datetime.datetime.now().strftime('%H:%M:%S')\n",
    "        print(f'[LOG {stamp}] {self.__class__.__name__}: {message}')\n",
    "\n",
    "\n",
    "class Accumulator(brainstate.nn.Module, LoggingMixin):\n",
    "    \"\"\"Simple module that reuses the logging helper.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.total = 0.0\n",
    "\n",
    "    def add(self, value):\n",
    "        self.total += float(value)\n",
    "        self.log(f'updated running total to {self.total:.2f}')\n",
    "        return self.total\n",
    "\n",
    "\n",
    "acc = Accumulator()\n",
    "_ = acc.add(1.25)\n",
    "_ = acc.add(2.75)\n"
   ]
  },
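  {
   "cell_type": "markdown",
   "id": "0d4e6a1b9c3f7e25",
   "metadata": {},
   "source": [
    "Because a mixin defines no `__init__` of its own, it never intercepts the host's `super().__init__()` call.\n",
    "The lookup skips the mixin along the method resolution order (MRO) and reaches the core base class.\n",
    "The plain-Python sketch below illustrates this. `Base`, `DescribeMixin`, and `Host` are hypothetical names,\n",
    "and `Base` stands in for `brainstate.nn.Module` so the cell runs without BrainState.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a7f3c9e5b2d8046",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Base:\n",
    "    \"\"\"Hypothetical stand-in for a core component such as brainstate.nn.Module.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.initialised = True\n",
    "\n",
    "\n",
    "class DescribeMixin:\n",
    "    \"\"\"Hypothetical mixin: adds a method but deliberately defines no __init__.\"\"\"\n",
    "\n",
    "    def describe(self) -> str:\n",
    "        # Reads an attribute that the host is expected to set.\n",
    "        return f'{type(self).__name__} (initialised={self.initialised})'\n",
    "\n",
    "\n",
    "class Host(DescribeMixin, Base):\n",
    "    def __init__(self):\n",
    "        # super() walks Host.__mro__; DescribeMixin has no __init__, so the lookup\n",
    "        # continues past it and Base.__init__ runs.\n",
    "        super().__init__()\n",
    "\n",
    "\n",
    "host = Host()\n",
    "print([cls.__name__ for cls in Host.__mro__])\n",
    "print(host.describe())\n"
   ]
  },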
  {
   "cell_type": "markdown",
   "id": "727c34e7e4c1c893",
   "metadata": {},
   "source": [
    "### Design tips\n",
    "\n",
    "- A mixin should only provide behaviour; avoid introducing new required constructor arguments.\n",
    "- Keep mixins focused. Several small mixins compose better than a single, opinionated base class.\n",
    "- Document expectations about host classes (e.g. attributes a mixin reads or writes).\n"
   ]
  },
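  {
   "cell_type": "markdown",
   "id": "2c8e4b6a0f1d9375",
   "metadata": {},
   "source": [
    "One way to follow the last tip is to state the required host attributes in the mixin's docstring\n",
    "and to fail early with a clear message when they are missing.\n",
    "The sketch below is plain Python. `RequiresTotalMixin` and its `reset_total` method are hypothetical names\n",
    "and are not part of BrainState. In a real project, the mixin would also inherit from `brainstate.mixin.Mixin`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b9d5f7c1e2a0486",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RequiresTotalMixin:\n",
    "    \"\"\"Hypothetical mixin that resets a running total.\n",
    "\n",
    "    Host contract: the host must define a numeric ``total`` attribute,\n",
    "    which this mixin both reads and writes.\n",
    "    \"\"\"\n",
    "\n",
    "    def reset_total(self) -> float:\n",
    "        # Raise a descriptive error naming the missing attribute and the mixin.\n",
    "        if not hasattr(self, 'total'):\n",
    "            raise AttributeError(\n",
    "                f'{type(self).__name__} must define `total` to use RequiresTotalMixin'\n",
    "            )\n",
    "        previous = self.total\n",
    "        self.total = 0.0\n",
    "        return previous\n",
    "\n",
    "\n",
    "class Counter(RequiresTotalMixin):\n",
    "    def __init__(self):\n",
    "        self.total = 5.0\n",
    "\n",
    "\n",
    "class Broken(RequiresTotalMixin):\n",
    "    pass  # violates the contract: no `total` attribute\n",
    "\n",
    "\n",
    "counter = Counter()\n",
    "print(counter.reset_total(), counter.total)\n",
    "\n",
    "try:\n",
    "    Broken().reset_total()\n",
    "except AttributeError as err:\n",
    "    print('error:', err)\n"
   ]
  },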
  {
   "cell_type": "markdown",
   "id": "a6fb03d4aac88c0e",
   "metadata": {},
   "source": [
    "## Parameter descriptors with `ParamDesc`\n",
    "\n",
    "`ParamDesc` helps you capture reusable constructor presets.\n",
    "The `desc()` class method stores the provided arguments inside a `ParamDescriber`, which you can later call\n",
    "to instantiate new objects while still overriding any argument on demand.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ac25ce10c9b4db35",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:54:29.937815Z",
     "iopub.status.busy": "2026-05-30T16:54:29.937509Z",
     "iopub.status.idle": "2026-05-30T16:54:29.942457Z",
     "shell.execute_reply": "2026-05-30T16:54:29.941635Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gelu dense 256 → 128\n",
      "relu dense 256 → 128\n",
      "relu dense 128 → 64\n"
     ]
    }
   ],
   "source": [
    "class DenseBlock(mixin.ParamDesc):\n",
    "    \"\"\"Toy layer that records its configuration for inspection.\"\"\"\n",
    "\n",
    "    def __init__(self, in_features: int, out_features: int, *, activation: str = 'relu'):\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "        self.activation = activation\n",
    "\n",
    "    def summary(self) -> str:\n",
    "        return f'{self.activation} dense {self.in_features} → {self.out_features}'\n",
    "\n",
    "\n",
    "encoder_block = DenseBlock.desc(256, 128, activation='gelu')\n",
    "decoder_block = DenseBlock.desc(128, 64, activation='relu')\n",
    "\n",
    "print(encoder_block().summary())\n",
    "print(encoder_block(activation='relu').summary())  # override at call time\n",
    "print(decoder_block().summary())\n"
   ]
  },
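  {
   "cell_type": "markdown",
   "id": "4a0c6e8b2d3f1597",
   "metadata": {},
   "source": [
    "Conceptually, a describer is a deferred constructor call. It remembers a class plus its arguments and merges\n",
    "any keyword overrides when it is called. The standard library's `functools.partial` follows the same pattern,\n",
    "as the sketch below shows. This is an analogy for intuition only and does not describe how `ParamDescriber`\n",
    "is implemented. `Layer` is a hypothetical class.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f1b7d9e3c4a2608",
   "metadata": {},
   "outputs": [],
   "source": [
    "import functools\n",
    "\n",
    "\n",
    "class Layer:\n",
    "    \"\"\"Hypothetical layer used only to illustrate deferred construction.\"\"\"\n",
    "\n",
    "    def __init__(self, in_features, out_features, *, activation='relu'):\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "        self.activation = activation\n",
    "\n",
    "\n",
    "# Store the class and its arguments now; build objects later.\n",
    "template = functools.partial(Layer, 256, 128, activation='gelu')\n",
    "\n",
    "default_layer = template()                # uses the stored arguments\n",
    "relu_layer = template(activation='relu')  # a call-time keyword replaces the stored one\n",
    "\n",
    "print(default_layer.activation, relu_layer.activation)\n",
    "print(template.func.__name__, template.args, template.keywords)\n"
   ]
  },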
  {
   "cell_type": "markdown",
   "id": "49989f448b045d6c",
   "metadata": {},
   "source": [
    "`ParamDesc` stores descriptors in a hashable structure. This plays nicely with caching systems because\n",
    "`descriptor.identifier` is safe to use as a dictionary key.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "481e8f9676b55783",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:54:29.944591Z",
     "iopub.status.busy": "2026-05-30T16:54:29.944415Z",
     "iopub.status.idle": "2026-05-30T16:54:29.948380Z",
     "shell.execute_reply": "2026-05-30T16:54:29.947760Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(<class '__main__.DenseBlock'>, (256, 128), {'activation': 'gelu'})\n"
     ]
    }
   ],
   "source": [
    "print(encoder_block.identifier)\n"
   ]
  },
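  {
   "cell_type": "markdown",
   "id": "6e2a8c0f4b5d3719",
   "metadata": {},
   "source": [
    "To illustrate why a hashable identifier is useful, the sketch below memoises constructed objects in a dictionary\n",
    "keyed by `(class, positional args, keyword args)`. It is written in plain Python so it runs on its own.\n",
    "`make_identifier`, `build_cached`, and `Block` are hypothetical helpers. Keyword arguments are frozen into a sorted\n",
    "tuple because a plain `dict` cannot be hashed. `ParamDescriber.identifier` may use a different internal representation.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d3b9e1a5c6f4820",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, Dict, Hashable, Tuple\n",
    "\n",
    "_cache: Dict[Hashable, Any] = {}\n",
    "\n",
    "\n",
    "def make_identifier(cls: type, *args, **kwargs) -> Tuple:\n",
    "    \"\"\"Hypothetical helper: build a hashable key from a class and its arguments.\"\"\"\n",
    "    # Sorting makes the key independent of keyword order; the tuple makes it hashable.\n",
    "    return (cls, tuple(args), tuple(sorted(kwargs.items())))\n",
    "\n",
    "\n",
    "def build_cached(cls: type, *args, **kwargs):\n",
    "    \"\"\"Hypothetical helper: construct each distinct configuration only once.\"\"\"\n",
    "    key = make_identifier(cls, *args, **kwargs)\n",
    "    if key not in _cache:\n",
    "        _cache[key] = cls(*args, **kwargs)\n",
    "    return _cache[key]\n",
    "\n",
    "\n",
    "class Block:\n",
    "    def __init__(self, width, *, activation='relu'):\n",
    "        self.width = width\n",
    "        self.activation = activation\n",
    "\n",
    "\n",
    "a = build_cached(Block, 64, activation='gelu')\n",
    "b = build_cached(Block, 64, activation='gelu')  # same key: cached object is reused\n",
    "c = build_cached(Block, 64, activation='relu')  # different key: a new object is built\n",
    "print(a is b, a is c, len(_cache))\n"
   ]
  },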
  {
   "cell_type": "markdown",
   "id": "3507b43c46362102",
   "metadata": {},
   "source": [
    "### Using `ParamDescriber` directly\n",
    "\n",
    "If you want to describe classes that do not inherit from `ParamDesc`, you can work with\n",
    "`ParamDescriber` manually.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "67595e939894f441",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:54:29.950333Z",
     "iopub.status.busy": "2026-05-30T16:54:29.950089Z",
     "iopub.status.idle": "2026-05-30T16:54:29.954578Z",
     "shell.execute_reply": "2026-05-30T16:54:29.953941Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OptimConfig(lr=0.001, beta1=0.95, beta2=0.999)\n",
      "OptimConfig(lr=0.0005, beta1=0.95, beta2=0.999)\n"
     ]
    }
   ],
   "source": [
    "@dataclass\n",
    "class OptimConfig:\n",
    "    lr: float\n",
    "    beta1: float = 0.9\n",
    "    beta2: float = 0.999\n",
    "\n",
    "\n",
    "adam_template = mixin.ParamDescriber(OptimConfig, lr=1e-3, beta1=0.95)\n",
    "opt_a = adam_template()\n",
    "opt_b = adam_template(lr=5e-4)  # override a keyword\n",
    "\n",
    "print(opt_a)\n",
    "print(opt_b)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5fb6921991d41a8e",
   "metadata": {},
   "source": [
    "## Type combinators: `JointTypes` and `OneOfTypes`\n",
    "\n",
    "BrainState ships two helpers that make intent explicit when a value must satisfy multiple interfaces\n",
    "or just one of several options:\n",
    "\n",
    "- `JointTypes[A, B, ...]` behaves like an intersection — an instance must satisfy *all* listed types.\n",
    "- `OneOfTypes[A, B, ...]` behaves like a union — an instance may satisfy *any* listed type.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5bc20e4a339a967b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:54:29.956439Z",
     "iopub.status.busy": "2026-05-30T16:54:29.956275Z",
     "iopub.status.idle": "2026-05-30T16:54:29.960617Z",
     "shell.execute_reply": "2026-05-30T16:54:29.959681Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "True True\n"
     ]
    }
   ],
   "source": [
    "class Persistable:\n",
    "    def save(self):\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "class Visualisable:\n",
    "    def plot(self):\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "class Report(Persistable, Visualisable):\n",
    "    def save(self):\n",
    "        return 'saved to disk'\n",
    "\n",
    "    def plot(self):\n",
    "        return 'rendering preview'\n",
    "\n",
    "\n",
    "FullFeatureType = mixin.JointTypes[Persistable, Visualisable]\n",
    "OptionalNumber = mixin.OneOfTypes[int, float, type(None)]\n",
    "\n",
    "report = Report()\n",
    "print(isinstance(report, FullFeatureType))\n",
    "print(isinstance(3.14, OptionalNumber), isinstance(None, OptionalNumber))\n"
   ]
  },
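  {
   "cell_type": "markdown",
   "id": "8c4f0a2e6d7b5931",
   "metadata": {},
   "source": [
    "The `isinstance` checks above can be read as an all-of test and an any-of test over the listed types.\n",
    "The plain-Python sketch below spells out those semantics with the hypothetical helpers `satisfies_all` and\n",
    "`satisfies_any`. It mirrors the behaviour described above but is not how `JointTypes` or `OneOfTypes` are implemented.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b5e1d3f7a8c6042",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, Tuple\n",
    "\n",
    "\n",
    "def satisfies_all(obj: Any, types: Tuple[type, ...]) -> bool:\n",
    "    \"\"\"Intersection semantics: obj must be an instance of every listed type.\"\"\"\n",
    "    return all(isinstance(obj, t) for t in types)\n",
    "\n",
    "\n",
    "def satisfies_any(obj: Any, types: Tuple[type, ...]) -> bool:\n",
    "    \"\"\"Union semantics: obj must be an instance of at least one listed type.\"\"\"\n",
    "    # isinstance accepts a tuple of types and returns True if any of them matches.\n",
    "    return isinstance(obj, types)\n",
    "\n",
    "\n",
    "class Saveable:\n",
    "    pass\n",
    "\n",
    "\n",
    "class Plottable:\n",
    "    pass\n",
    "\n",
    "\n",
    "class FullReport(Saveable, Plottable):\n",
    "    pass\n",
    "\n",
    "\n",
    "class SaveOnlyReport(Saveable):\n",
    "    pass\n",
    "\n",
    "\n",
    "print(satisfies_all(FullReport(), (Saveable, Plottable)))\n",
    "print(satisfies_all(SaveOnlyReport(), (Saveable, Plottable)))\n",
    "print(satisfies_any(None, (int, float, type(None))), satisfies_any('3.14', (int, float, type(None))))\n"
   ]
  },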
  {
   "cell_type": "markdown",
   "id": "97f844bf366bc35a",
   "metadata": {},
   "source": [
    "## Mode mixins for runtime behaviour\n",
    "\n",
    "Mode objects capture the *context* in which computation happens.\n",
    "The base `Mode` class is lightweight, and the built-ins `Training`, `Batching`, and `JointMode` cover\n",
    "common runtime switches.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6bb2bd78f1823b7b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:54:29.962732Z",
     "iopub.status.busy": "2026-05-30T16:54:29.962517Z",
     "iopub.status.idle": "2026-05-30T16:54:30.109608Z",
     "shell.execute_reply": "2026-05-30T16:54:30.108435Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "default [0. 1. 2. 3.]\n",
      "training [0.1 1.1 2.1 3.1]\n",
      "joint [0.6 2.6]\n",
      "joint exposes batch size: 2\n"
     ]
    }
   ],
   "source": [
    "class ToyPipeline:\n",
    "    \"\"\"A tiny module that responds to different mode configurations.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.mode: mixin.Mode = mixin.Mode()\n",
    "\n",
    "    def set_mode(self, *modes: mixin.Mode):\n",
    "        if not modes:\n",
    "            self.mode = mixin.Mode()\n",
    "        elif len(modes) == 1:\n",
    "            self.mode = modes[0]\n",
    "        else:\n",
    "            self.mode = mixin.JointMode(*modes)\n",
    "\n",
    "    def forward(self, values):\n",
    "        x = jnp.asarray(values, dtype=jnp.float32)\n",
    "        if self.mode.has(mixin.Training):\n",
    "            x = x + 0.1  # emulate noise or dropout\n",
    "        if self.mode.has(mixin.Batching):\n",
    "            batch = self.mode.batch_size\n",
    "            x = x.reshape((batch, -1)).mean(axis=1)\n",
    "        return x\n",
    "\n",
    "\n",
    "pipeline = ToyPipeline()\n",
    "print('default', pipeline.forward(jnp.arange(4.0)))\n",
    "\n",
    "pipeline.set_mode(mixin.Training())\n",
    "print('training', pipeline.forward(jnp.arange(4.0)))\n",
    "\n",
    "pipeline.set_mode(mixin.Training(), mixin.Batching(batch_size=2))\n",
    "print('joint', pipeline.forward(jnp.arange(4.0)))\n",
    "print('joint exposes batch size:', pipeline.mode.batch_size)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fafcd5163ba9f7dc",
   "metadata": {},
   "source": [
    "The joint mode exposes the attributes of its members, so accessing `pipeline.mode.batch_size` works even\n",
    "though the current mode is a `JointMode` instance.\n"
   ]
  },
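  {
   "cell_type": "markdown",
   "id": "a06d2f4c8e9b7153",
   "metadata": {},
   "source": [
    "A plain-Python sketch makes two behaviours concrete: `has()` as a type check over the active modes, and attribute\n",
    "forwarding from a joint mode to its members through `__getattr__`. The classes `SimpleMode`, `SimpleTraining`,\n",
    "`SimpleBatching`, and `SimpleJointMode` are hypothetical stand-ins written for this illustration.\n",
    "BrainState's own `Mode` classes may be implemented differently.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b17e3a5d9f0c8264",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleMode:\n",
    "    def has(self, mode_type: type) -> bool:\n",
    "        # A single mode \"has\" a type if it is an instance of it.\n",
    "        return isinstance(self, mode_type)\n",
    "\n",
    "\n",
    "class SimpleTraining(SimpleMode):\n",
    "    pass\n",
    "\n",
    "\n",
    "class SimpleBatching(SimpleMode):\n",
    "    def __init__(self, batch_size: int = 1):\n",
    "        self.batch_size = batch_size\n",
    "\n",
    "\n",
    "class SimpleJointMode(SimpleMode):\n",
    "    def __init__(self, *modes: SimpleMode):\n",
    "        self.modes = modes\n",
    "\n",
    "    def has(self, mode_type: type) -> bool:\n",
    "        # A joint mode has a type if any member mode has it.\n",
    "        return any(m.has(mode_type) for m in self.modes)\n",
    "\n",
    "    def __getattr__(self, name: str):\n",
    "        # Only called when normal lookup fails. Read `modes` via __dict__ so this\n",
    "        # method cannot recurse into itself if `modes` has not been set yet.\n",
    "        for m in self.__dict__.get('modes', ()):\n",
    "            if hasattr(m, name):\n",
    "                return getattr(m, name)\n",
    "        raise AttributeError(f'no member mode defines {name!r}')\n",
    "\n",
    "\n",
    "joint = SimpleJointMode(SimpleTraining(), SimpleBatching(batch_size=2))\n",
    "print(joint.has(SimpleTraining), joint.has(SimpleBatching))\n",
    "print(joint.batch_size)\n"
   ]
  },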
  {
   "cell_type": "markdown",
   "id": "727dfb44365e34c3",
   "metadata": {},
   "source": [
    "## Putting it together\n",
    "\n",
    "When you combine these mixin tools you can:\n",
    "\n",
    "1. Add reusable behaviour (logging, validation, metrics) without disturbing core module hierarchies.\n",
    "2. Parameterise component templates and reuse them safely through descriptors.\n",
    "3. Encode clear expectations about inputs or collaborators via `JointTypes`/`OneOfTypes`.\n",
    "4. Toggle runtime semantics with mode objects instead of ad-hoc boolean flags.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81ba9bac76c80409",
   "metadata": {},
   "source": [
    "### Next steps\n",
    "\n",
    "- Audit your own modules for behaviours that could live in a mixin.\n",
    "- Wrap frequently reused constructor arguments with `ParamDesc`.\n",
    "- Adopt mode objects in your training scripts to centralise feature flags (e.g. evaluation vs training).\n",
    "- Explore `brainstate.mixin.not_implemented` to clearly mark unsupported operations.\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
