{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Mixin System\n",
   "id": "f8c71012dc176298"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "This tutorial explains the mixin utilities that ship with `brainstate`. After working through the examples you will:\n",
    "\n",
    "- Understand what a mixin is and when to use one.\n",
    "- Reuse behaviors by inheriting from `brainstate.mixin.Mixin`.\n",
    "- Capture reusable constructor presets with `ParamDesc` and `ParamDescriber`.\n",
    "- Express rich type expectations with `JointTypes` and `OneOfTypes`.\n",
    "- Control runtime behaviour with the built-in mode mixins such as `Training`, `Batching`, and `JointMode`.\n"
   ],
   "id": "fb4657616a309500"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import datetime\n",
    "from dataclasses import dataclass\n",
    "\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import brainstate\n",
    "from brainstate import mixin\n"
   ],
   "id": "c82deb74753ef79c"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## What is a mixin?\n",
    "\n",
    "A *mixin* is a lightweight class that contributes behaviour (methods or attributes) without forcing a rigid inheritance hierarchy.\n",
    "In BrainState every mixin inherits from `brainstate.mixin.Mixin`, signalling that the class\n",
    "provides optional behaviour and should not define its own `__init__`.\n",
    "Mixins are usually paired with core components such as `brainstate.nn.Module` to keep reusable code close to the consumer.\n"
   ],
   "id": "f438237a5a9d8838"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "class LoggingMixin(mixin.Mixin):\n",
    "    \"\"\"Attach timestamped logging to any class without touching its constructor.\"\"\"\n",
    "\n",
    "    def log(self, message: str) -> None:\n",
    "        stamp = datetime.datetime.now().strftime('%H:%M:%S')\n",
    "        print(f'[LOG {stamp}] {self.__class__.__name__}: {message}')\n",
    "\n",
    "\n",
    "class Accumulator(brainstate.nn.Module, LoggingMixin):\n",
    "    \"\"\"Simple module that reuses the logging helper.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.total = 0.0\n",
    "\n",
    "    def add(self, value):\n",
    "        self.total += float(value)\n",
    "        self.log(f'updated running total to {self.total:.2f}')\n",
    "        return self.total\n",
    "\n",
    "\n",
    "acc = Accumulator()\n",
    "_ = acc.add(1.25)\n",
    "_ = acc.add(2.75)\n"
   ],
   "id": "301aef66cee37807"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "### Design tips\n",
    "\n",
    "- A mixin should only provide behaviour; avoid introducing new required constructor arguments.\n",
    "- Keep mixins focused. Several small mixins compose better than a single, opinionated base class.\n",
    "- Document expectations about host classes (e.g. attributes a mixin reads or writes).\n"
   ],
   "id": "727c34e7e4c1c893"
  },
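  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "The design tips above can be illustrated in plain Python. The sketch below does not use BrainState; `CounterMixin`, `DescribeMixin`, and `Sensor` are hypothetical names chosen for this example.\n",
    "Each mixin stays small and neither defines `__init__`. `DescribeMixin` states and checks the one host attribute (`name`) it depends on.\n",
    "The last line prints the method resolution order (MRO), which is how Python decides which base class supplies each method.\n"
   ],
   "id": "3c1e9a7b52d04f86"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "class CounterMixin:\n",
    "    # Hypothetical mixin: needs nothing from the host and creates its private\n",
    "    # counter dictionary lazily, so it never defines __init__.\n",
    "    def bump(self, key: str) -> int:\n",
    "        counts = self.__dict__.setdefault('_counts', {})\n",
    "        counts[key] = counts.get(key, 0) + 1\n",
    "        return counts[key]\n",
    "\n",
    "\n",
    "class DescribeMixin:\n",
    "    # Hypothetical mixin with a documented requirement: the host must define `name`.\n",
    "    def describe(self) -> str:\n",
    "        if not hasattr(self, 'name'):\n",
    "            raise AttributeError(f'{type(self).__name__} must define `name` to use DescribeMixin')\n",
    "        return f'{type(self).__name__}({self.name})'\n",
    "\n",
    "\n",
    "class Sensor(CounterMixin, DescribeMixin):\n",
    "    # The host class owns the constructor; the mixins only contribute methods.\n",
    "    def __init__(self, name: str):\n",
    "        self.name = name\n",
    "\n",
    "\n",
    "sensor = Sensor('thermo')\n",
    "sensor.bump('reads')\n",
    "print(sensor.bump('reads'))  # 2\n",
    "print(sensor.describe())  # Sensor(thermo)\n",
    "print([cls.__name__ for cls in Sensor.__mro__])  # ['Sensor', 'CounterMixin', 'DescribeMixin', 'object']\n"
   ],
   "id": "9d2f4b6a1c8e7305"
  },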
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Parameter descriptors with `ParamDesc`\n",
    "\n",
    "`ParamDesc` helps you capture reusable constructor presets.\n",
    "The `desc()` class method stores the provided arguments inside a `ParamDescriber`, which you can later call\n",
    "to instantiate new objects while still overriding any argument on demand.\n"
   ],
   "id": "a6fb03d4aac88c0e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "class DenseBlock(mixin.ParamDesc):\n",
    "    \"\"\"Toy layer that records its configuration for inspection.\"\"\"\n",
    "\n",
    "    def __init__(self, in_features: int, out_features: int, *, activation: str = 'relu'):\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "        self.activation = activation\n",
    "\n",
    "    def summary(self) -> str:\n",
    "        return f'{self.activation} dense {self.in_features} → {self.out_features}'\n",
    "\n",
    "\n",
    "encoder_block = DenseBlock.desc(256, 128, activation='gelu')\n",
    "decoder_block = DenseBlock.desc(128, 64, activation='relu')\n",
    "\n",
    "print(encoder_block().summary())\n",
    "print(encoder_block(activation='relu').summary())  # override at call time\n",
    "print(decoder_block().summary())\n"
   ],
   "id": "ac25ce10c9b4db35"
  },
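  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Conceptually, a describer is a deferred constructor call. The following is a minimal sketch under that assumption. `MiniDescriber` is a hypothetical stand-in written for this tutorial, not BrainState's `ParamDescriber`, whose internals may differ.\n",
    "It stores the class and arguments, appends call-time positional arguments after the stored ones, and lets call-time keywords take precedence over stored keywords.\n"
   ],
   "id": "b74e20c9a1f3d658"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "class MiniDescriber:\n",
    "    # Hypothetical, simplified describer: a deferred constructor call.\n",
    "    def __init__(self, cls, *args, **kwargs):\n",
    "        self.cls = cls        # class to instantiate later\n",
    "        self.args = args      # stored positional arguments\n",
    "        self.kwargs = kwargs  # stored keyword arguments\n",
    "\n",
    "    def __call__(self, *args, **kwargs):\n",
    "        # Call-time keywords override stored ones; call-time positionals are appended.\n",
    "        merged = {**self.kwargs, **kwargs}\n",
    "        return self.cls(*self.args, *args, **merged)\n",
    "\n",
    "\n",
    "class Dense:\n",
    "    def __init__(self, in_features, out_features, *, activation='relu'):\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "        self.activation = activation\n",
    "\n",
    "\n",
    "preset = MiniDescriber(Dense, 256, 128, activation='gelu')\n",
    "layer_a = preset()\n",
    "layer_b = preset(activation='relu')  # override a stored keyword\n",
    "print(layer_a.activation, layer_b.activation)  # gelu relu\n",
    "print(preset.kwargs)  # {'activation': 'gelu'}: the stored preset is unchanged\n"
   ],
   "id": "e05a9c3d7b1f2846"
  },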
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "`ParamDesc` stores descriptors in a hashable structure. This plays nicely with caching systems because\n",
    "`descriptor.identifier` is safe to use as a dictionary key.\n"
   ],
   "id": "49989f448b045d6c"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "print(encoder_block.identifier)\n",
   "id": "481e8f9676b55783"
  },
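  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "To illustrate the caching use case, the sketch below builds each distinct configuration only once by keying a dictionary on a hashable tuple.\n",
    "The helpers `make_key` and `build_cached` are hypothetical, and the key layout `(class, positional args, sorted keyword items)` is an assumption for this example rather than the exact structure of `ParamDescriber.identifier`.\n",
    "All argument values must be hashable for this approach to work.\n"
   ],
   "id": "5a8d1e3f6c2b9074"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "_cache = {}\n",
    "build_count = 0\n",
    "\n",
    "\n",
    "def make_key(cls, args, kwargs):\n",
    "    # Hashable key: class object, positional tuple, keyword items in sorted order.\n",
    "    # Every argument value must itself be hashable.\n",
    "    return (cls, tuple(args), tuple(sorted(kwargs.items())))\n",
    "\n",
    "\n",
    "def build_cached(cls, *args, **kwargs):\n",
    "    # Construct each distinct configuration once and reuse it afterwards.\n",
    "    global build_count\n",
    "    key = make_key(cls, args, kwargs)\n",
    "    if key not in _cache:\n",
    "        build_count += 1\n",
    "        _cache[key] = cls(*args, **kwargs)\n",
    "    return _cache[key]\n",
    "\n",
    "\n",
    "class Block:\n",
    "    def __init__(self, width, *, activation='relu'):\n",
    "        self.width = width\n",
    "        self.activation = activation\n",
    "\n",
    "\n",
    "first = build_cached(Block, 128, activation='gelu')\n",
    "again = build_cached(Block, 128, activation='gelu')\n",
    "other = build_cached(Block, 64)\n",
    "print(first is again, first is other, build_count)  # True False 2\n"
   ],
   "id": "c61f8b2e4d9a7053"
  },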
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "### Using `ParamDescriber` directly\n",
    "\n",
    "If you want to describe classes that do not inherit from `ParamDesc`, you can work with\n",
    "`ParamDescriber` manually.\n"
   ],
   "id": "3507b43c46362102"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "@dataclass\n",
    "class OptimConfig:\n",
    "    lr: float\n",
    "    beta1: float = 0.9\n",
    "    beta2: float = 0.999\n",
    "\n",
    "\n",
    "adam_template = mixin.ParamDescriber(OptimConfig, lr=1e-3, beta1=0.95)\n",
    "opt_a = adam_template()\n",
    "opt_b = adam_template(lr=5e-4)  # override a keyword\n",
    "\n",
    "print(opt_a)\n",
    "print(opt_b)\n"
   ],
   "id": "67595e939894f441"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Type combinators: `JointTypes` and `OneOfTypes`\n",
    "\n",
    "BrainState ships two helpers that make intent explicit when a value must satisfy multiple interfaces\n",
    "or just one of several options:\n",
    "\n",
    "- `JointTypes[A, B, ...]` behaves like an intersection — an instance must satisfy *all* listed types.\n",
    "- `OneOfTypes[A, B, ...]` behaves like a union — an instance may satisfy *any* listed type.\n"
   ],
   "id": "5fb6921991d41a8e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "class Persistable:\n",
    "    def save(self):\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "class Visualisable:\n",
    "    def plot(self):\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "class Report(Persistable, Visualisable):\n",
    "    def save(self):\n",
    "        return 'saved to disk'\n",
    "\n",
    "    def plot(self):\n",
    "        return 'rendering preview'\n",
    "\n",
    "\n",
    "FullFeatureType = mixin.JointTypes[Persistable, Visualisable]\n",
    "OptionalNumber = mixin.OneOfTypes[int, float, type(None)]\n",
    "\n",
    "report = Report()\n",
    "print(isinstance(report, FullFeatureType))\n",
    "print(isinstance(3.14, OptionalNumber), isinstance(None, OptionalNumber))\n"
   ],
   "id": "5bc20e4a339a967b"
  },
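  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "One common way to build such combinators in Python is a metaclass that overrides `__instancecheck__`, the hook that `isinstance` consults.\n",
    "The sketch below is a simplified, hypothetical illustration (`AllOf` and `AnyOf` are invented names, not BrainState's implementation): `AllOf[...]` passes only when every listed type matches, and `AnyOf[...]` passes when at least one does.\n"
   ],
   "id": "2e9c7a4f1b5d8360"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "class _CombinatorMeta(type):\n",
    "    # isinstance(obj, C) consults type(C).__instancecheck__(C, obj).\n",
    "    def __instancecheck__(cls, instance):\n",
    "        results = (isinstance(instance, t) for t in cls.members)\n",
    "        return all(results) if cls.require_all else any(results)\n",
    "\n",
    "\n",
    "def _combine(require_all, types):\n",
    "    if not isinstance(types, tuple):\n",
    "        types = (types,)\n",
    "    names = ', '.join(t.__name__ for t in types)\n",
    "    label = 'AllOf' if require_all else 'AnyOf'\n",
    "    # Create a fresh class whose metaclass performs the membership test.\n",
    "    return _CombinatorMeta(f'{label}[{names}]', (), {'members': types, 'require_all': require_all})\n",
    "\n",
    "\n",
    "class AllOf:\n",
    "    # AllOf[A, B] matches only objects that are instances of every listed type.\n",
    "    def __class_getitem__(cls, types):\n",
    "        return _combine(True, types)\n",
    "\n",
    "\n",
    "class AnyOf:\n",
    "    # AnyOf[A, B] matches objects that are instances of at least one listed type.\n",
    "    def __class_getitem__(cls, types):\n",
    "        return _combine(False, types)\n",
    "\n",
    "\n",
    "class Saver:\n",
    "    pass\n",
    "\n",
    "\n",
    "class Plotter:\n",
    "    pass\n",
    "\n",
    "\n",
    "class SavablePlot(Saver, Plotter):\n",
    "    pass\n",
    "\n",
    "\n",
    "SaveAndPlot = AllOf[Saver, Plotter]\n",
    "NumberOrNone = AnyOf[int, float, type(None)]\n",
    "\n",
    "print(isinstance(SavablePlot(), SaveAndPlot), isinstance(Saver(), SaveAndPlot))  # True False\n",
    "print(isinstance(None, NumberOrNone), isinstance('x', NumberOrNone))  # True False\n",
    "print(SaveAndPlot.__name__)  # AllOf[Saver, Plotter]\n"
   ],
   "id": "7f3b0d6e9a2c4185"
  },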
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Mode mixins for runtime behaviour\n",
    "\n",
    "Mode objects capture the *context* in which computation happens.\n",
    "The base `Mode` class is lightweight, and the built-ins `Training`, `Batching`, and `JointMode` cover\n",
    "common runtime switches.\n"
   ],
   "id": "97f844bf366bc35a"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "class ToyPipeline:\n",
    "    \"\"\"A tiny module that responds to different mode configurations.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.mode: mixin.Mode = mixin.Mode()\n",
    "\n",
    "    def set_mode(self, *modes: mixin.Mode):\n",
    "        if not modes:\n",
    "            self.mode = mixin.Mode()\n",
    "        elif len(modes) == 1:\n",
    "            self.mode = modes[0]\n",
    "        else:\n",
    "            self.mode = mixin.JointMode(*modes)\n",
    "\n",
    "    def forward(self, values):\n",
    "        x = jnp.asarray(values, dtype=jnp.float32)\n",
    "        if self.mode.has(mixin.Training):\n",
    "            x = x + 0.1  # emulate noise or dropout\n",
    "        if self.mode.has(mixin.Batching):\n",
    "            batch = self.mode.batch_size\n",
    "            x = x.reshape((batch, -1)).mean(axis=1)\n",
    "        return x\n",
    "\n",
    "\n",
    "pipeline = ToyPipeline()\n",
    "print('default', pipeline.forward(jnp.arange(4.0)))\n",
    "\n",
    "pipeline.set_mode(mixin.Training())\n",
    "print('training', pipeline.forward(jnp.arange(4.0)))\n",
    "\n",
    "pipeline.set_mode(mixin.Training(), mixin.Batching(batch_size=2))\n",
    "print('joint', pipeline.forward(jnp.arange(4.0)))\n",
    "print('joint exposes batch size:', pipeline.mode.batch_size)\n"
   ],
   "id": "6bb2bd78f1823b7b"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "The joint mode exposes the attributes of its members, so accessing `pipeline.mode.batch_size` works even\n",
    "though the current mode is a `JointMode` instance.\n"
   ],
   "id": "fafcd5163ba9f7dc"
  },
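  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Attribute forwarding like this is typically implemented with `__getattr__`, which Python calls only after normal attribute lookup fails.\n",
    "The sketch below is a hypothetical, simplified model (`SketchMode`, `SketchTraining`, `SketchBatching`, and `SketchJointMode` are invented for illustration and are not BrainState classes):\n",
    "the joint object stores its members, answers `has()` when any member is an instance of the requested class, and forwards unknown attributes to the first member that defines them.\n"
   ],
   "id": "d48a6c1e3f9b2075"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "class SketchMode:\n",
    "    def has(self, mode_cls):\n",
    "        return isinstance(self, mode_cls)\n",
    "\n",
    "\n",
    "class SketchTraining(SketchMode):\n",
    "    pass\n",
    "\n",
    "\n",
    "class SketchBatching(SketchMode):\n",
    "    def __init__(self, batch_size=1):\n",
    "        self.batch_size = batch_size\n",
    "\n",
    "\n",
    "class SketchJointMode(SketchMode):\n",
    "    def __init__(self, *modes):\n",
    "        self.modes = tuple(modes)\n",
    "\n",
    "    def has(self, mode_cls):\n",
    "        # True if any member mode is an instance of mode_cls.\n",
    "        return any(isinstance(m, mode_cls) for m in self.modes)\n",
    "\n",
    "    def __getattr__(self, item):\n",
    "        # Only reached when normal attribute lookup fails.\n",
    "        if item == 'modes':\n",
    "            raise AttributeError(item)  # guard against recursion before __init__ runs\n",
    "        for m in self.modes:\n",
    "            if hasattr(m, item):\n",
    "                return getattr(m, item)\n",
    "        raise AttributeError(f'{type(self).__name__} has no attribute {item!r}')\n",
    "\n",
    "\n",
    "joint = SketchJointMode(SketchTraining(), SketchBatching(batch_size=8))\n",
    "print(joint.has(SketchTraining), joint.has(SketchBatching))  # True True\n",
    "print(joint.batch_size)  # 8\n"
   ],
   "id": "a93e5b7d2c0f4618"
  },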
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Putting it together\n",
    "\n",
    "When you combine these mixin tools you can:\n",
    "\n",
    "1. Add reusable behaviour (logging, validation, metrics) without disturbing core module hierarchies.\n",
    "2. Parameterise component templates and reuse them safely through descriptors.\n",
    "3. Encode clear expectations about inputs or collaborators via `JointTypes`/`OneOfTypes`.\n",
    "4. Toggle runtime semantics with mode objects instead of ad-hoc boolean flags.\n"
   ],
   "id": "727dfb44365e34c3"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "### Next steps\n",
    "\n",
    "- Audit your own modules for behaviours that could live in a mixin.\n",
    "- Wrap frequently reused constructor arguments with `ParamDesc`.\n",
    "- Adopt mode objects in your training scripts to centralise feature flags (e.g. evaluation vs training).\n",
    "- Explore `brainstate.mixin.not_implemented` to clearly mark unsupported operations.\n"
   ],
   "id": "81ba9bac76c80409"
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
