{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "73e08dc8",
   "metadata": {},
   "source": [
    "# Parameters, Transforms, and Regularization"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cad03e8f",
   "metadata": {},
   "source": [
    "A `ParamState` is a bare trainable array. That is all most layers need. But some parameters\n",
    "carry *constraints* \u2014 a time constant must be positive, a mixing weight must lie in `[0, 1]`, a\n",
    "probability vector must sum to one \u2014 and some training objectives want a *penalty* on the\n",
    "parameters themselves.\n",
    "\n",
    "`brainstate.nn.Param` is a richer container that adds two orthogonal capabilities on top of a\n",
    "`ParamState`:\n",
    "\n",
    "- a **bijective transform** that maps an unconstrained array (what the optimizer sees) to a\n",
    "  constrained value (what the model uses);\n",
    "- a **regularization** term that contributes a penalty to the loss.\n",
    "\n",
    "This tutorial covers `Param`, its fixed counterpart `Const`, the transform catalog, and the\n",
    "regularization catalog, then ties them together in a single trained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e2d75156",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:28.993405Z",
     "iopub.status.busy": "2026-05-30T16:20:28.993077Z",
     "iopub.status.idle": "2026-05-30T16:20:33.622381Z",
     "shell.execute_reply": "2026-05-30T16:20:33.621434Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0.4.0'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "import brainstate\n",
    "import braintools\n",
    "import brainstate.nn as nn\n",
    "\n",
    "brainstate.random.seed(0)\n",
    "brainstate.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33976d9b",
   "metadata": {},
   "source": [
    "## From `ParamState` to `Param`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff981b23",
   "metadata": {},
   "source": [
    "Construct a `Param` from an initial value. Two attributes matter:\n",
    "\n",
    "- `param.value()` returns the **constrained** value \u2014 the number your model should use. It is a\n",
    "  *method*, because it may run a transform each time it is read.\n",
    "- `param.val` is the underlying `ParamState` holding the **unconstrained** array that the\n",
    "  optimizer updates.\n",
    "\n",
    "With the default `IdentityT` transform, `value()` returns the stored array unchanged."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "78aa8540",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:33.624933Z",
     "iopub.status.busy": "2026-05-30T16:20:33.624497Z",
     "iopub.status.idle": "2026-05-30T16:20:33.640454Z",
     "shell.execute_reply": "2026-05-30T16:20:33.639471Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "value() : [0.5 1.  2. ]\n",
      "val     : ParamState(\n",
      "  value=ShapedArray(float32[3])\n",
      ")\n",
      "trainable: True\n"
     ]
    }
   ],
   "source": [
    "w = nn.Param(jnp.array([0.5, 1.0, 2.0]))\n",
    "print('value() :', w.value())\n",
    "print('val     :', w.val)\n",
    "print('trainable:', w.fit)"
   ]
  },
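  {
   "cell_type": "markdown",
   "id": "5f3c1a90",
   "metadata": {},
   "source": [
    "The two views can be modelled in a few lines of plain Python and JAX. The sketch below is a\n",
    "hypothetical `MiniParam` container, not brainstate's implementation. It assumes the constructor\n",
    "receives a *constrained* value and stores its inverse (check the `Param` docstring for the actual\n",
    "convention), and its `fit` flag mimics how a fixed `Const` is left out when trainable storage is\n",
    "collected. `IdentityTransform` and `trainable_raw_arrays` are likewise illustrative names.\n",
    "\n",
    "```python\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "class IdentityTransform:\n",
    "    # Hypothetical identity bijector: constrained and unconstrained spaces coincide.\n",
    "    def forward(self, x):\n",
    "        return x\n",
    "\n",
    "    def inverse(self, y):\n",
    "        return y\n",
    "\n",
    "\n",
    "class MiniParam:\n",
    "    # Hypothetical Param-like container, NOT brainstate's implementation.\n",
    "    def __init__(self, value, transform=None, fit=True):\n",
    "        self.t = transform if transform is not None else IdentityTransform()\n",
    "        self.raw = self.t.inverse(jnp.asarray(value))  # unconstrained storage, cf. `val`\n",
    "        self.fit = fit  # fit=False plays the role of Const\n",
    "\n",
    "    def value(self):\n",
    "        # Re-run the forward map on every read, which is why this is a method.\n",
    "        return self.t.forward(self.raw)\n",
    "\n",
    "    def set_value(self, new_value):\n",
    "        # Map a constrained value back to unconstrained space before storing it.\n",
    "        self.raw = self.t.inverse(jnp.asarray(new_value))\n",
    "\n",
    "\n",
    "def trainable_raw_arrays(params):\n",
    "    # Collect only the storage an optimizer should update.\n",
    "    return {name: p.raw for name, p in params.items() if p.fit}\n",
    "\n",
    "\n",
    "w = MiniParam(jnp.array([0.5, 1.0, 2.0]))\n",
    "scale = MiniParam(jnp.array(10.0), fit=False)\n",
    "print(w.value())                                             # [0.5 1.  2. ]\n",
    "print(list(trainable_raw_arrays({'w': w, 'scale': scale})))  # ['w']\n",
    "```"
   ]
  },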
  {
   "cell_type": "markdown",
   "id": "dd17a6ab",
   "metadata": {},
   "source": [
    "## Fixed values with `Const`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98ee8976",
   "metadata": {},
   "source": [
    "`Const` is a `Param` that is never trained (`fit=False`). It is *not* collected when you gather\n",
    "`ParamState`s, so optimizers and `grad` leave it alone \u2014 ideal for buffers, lookup tables, or\n",
    "hyperparameters you want to keep inside the module tree."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6fa84203",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:33.643147Z",
     "iopub.status.busy": "2026-05-30T16:20:33.642885Z",
     "iopub.status.idle": "2026-05-30T16:20:33.681495Z",
     "shell.execute_reply": "2026-05-30T16:20:33.680778Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable ParamStates: [('weight', 'val')]\n",
      "weight.fit = True | scale.fit = False\n"
     ]
    }
   ],
   "source": [
    "class Mix(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.weight = nn.Param(jnp.array([1.0, -1.0]))   # trainable\n",
    "        self.scale = nn.Const(jnp.array(10.0))           # fixed\n",
    "\n",
    "    def __call__(self, x):\n",
    "        return self.scale.value() * (x @ self.weight.value())\n",
    "\n",
    "m = Mix()\n",
    "print('trainable ParamStates:', list(m.states(brainstate.ParamState).keys()))\n",
    "print('weight.fit =', m.weight.fit, '| scale.fit =', m.scale.fit)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09b1a331",
   "metadata": {},
   "source": [
    "Only `weight` appears among the trainable states \u2014 `scale` is held constant."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7443942",
   "metadata": {},
   "source": [
    "## Constrained parameters: Transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a553990",
   "metadata": {},
   "source": [
    "A transform is a **bijector**: an invertible map between an unconstrained space (all of \u211d, where\n",
    "gradient descent is well behaved) and a constrained space (positives, an interval, a simplex).\n",
    "The optimizer updates the unconstrained array; `value()` applies the forward map to produce the\n",
    "constrained value; `set_value()` applies the inverse to store a constrained value back.\n",
    "\n",
    "`SoftplusT(lower=L)` maps \u211d to `(L, \u221e)`, so the parameter is guaranteed to stay above `L` no\n",
    "matter what the optimizer does."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "db17e890",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:33.683129Z",
     "iopub.status.busy": "2026-05-30T16:20:33.682982Z",
     "iopub.status.idle": "2026-05-30T16:20:37.132051Z",
     "shell.execute_reply": "2026-05-30T16:20:37.131246Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "forward (R -> positive): [0.126928  0.69314718 3.04858732]\n",
      "inverse (round-trip)   : [-1.9999996  0.         3.       ]\n"
     ]
    }
   ],
   "source": [
    "t = nn.SoftplusT(lower=0.0)\n",
    "raw = jnp.array([-2.0, 0.0, 3.0])\n",
    "constrained = t.forward(raw)\n",
    "print('forward (R -> positive):', constrained)\n",
    "print('inverse (round-trip)   :', t.inverse(constrained))"
   ]
  },
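  {
   "cell_type": "markdown",
   "id": "8e2d4b17",
   "metadata": {},
   "source": [
    "For intuition, the same bijector can be written directly in JAX. This is a minimal sketch of a\n",
    "lower-bounded softplus, not the source of `SoftplusT`, and `LowerBoundedSoftplus` is a hypothetical\n",
    "name. The forward map is `lower + log(1 + exp(x))`; the inverse solves that equation for `x`, using\n",
    "`log(expm1(z)) = z + log(-expm1(-z))` so that large constrained values do not overflow.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "class LowerBoundedSoftplus:\n",
    "    # Hypothetical softplus bijector from R to (lower, inf), not SoftplusT itself.\n",
    "    def __init__(self, lower=0.0):\n",
    "        self.lower = lower\n",
    "\n",
    "    def forward(self, x):\n",
    "        # softplus(x) = log(1 + exp(x)) is positive, so the result lies above lower.\n",
    "        return self.lower + jax.nn.softplus(x)\n",
    "\n",
    "    def inverse(self, y):\n",
    "        # Solve y = lower + log(1 + exp(x)) for x, in an overflow-safe form.\n",
    "        z = y - self.lower\n",
    "        return z + jnp.log(-jnp.expm1(-z))\n",
    "\n",
    "\n",
    "t = LowerBoundedSoftplus(lower=0.0)\n",
    "raw = jnp.array([-2.0, 0.0, 3.0])\n",
    "constrained = t.forward(raw)  # approx [0.1269, 0.6931, 3.0486]\n",
    "print(constrained, t.inverse(constrained))\n",
    "```\n",
    "\n",
    "One caveat applies to this sketch in float32: `lower + softplus(x)` rounds to exactly `lower` once\n",
    "`softplus(x)` drops below half a unit in the last place of `lower` (around `x = -20` for\n",
    "`lower = 0.1`), so in floating point the bound is `>=` rather than `>`."
   ]
  },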
  {
   "cell_type": "markdown",
   "id": "2b13f79b",
   "metadata": {},
   "source": [
    "The catalog of built-in transforms covers the constraints that arise in practice:\n",
    "\n",
    "| Transform | Constrained space |\n",
    "| --- | --- |\n",
    "| `IdentityT` | \u211d (no constraint) |\n",
    "| `SoftplusT(lower)`, `ExpT`, `PositiveT` | strictly greater than a lower bound |\n",
    "| `SigmoidT(lower, upper)`, `ClipT(lower, upper)` | a bounded interval |\n",
    "| `TanhT`, `SoftsignT`, `ScaledSigmoidT` | a symmetric bounded range |\n",
    "| `SimplexT` | non-negative entries summing to one |\n",
    "| `UnitVectorT` | unit L2 norm |\n",
    "| `OrderedT` | monotonically increasing entries |\n",
    "| `AffineT(scale, shift)`, `PowerT`, `MaskedT` | reparameterisations |\n",
    "\n",
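    "Compose several with `ChainT`. The forward map applies the transforms left to right, so the\n",
    "first one listed acts on the raw array first. In the next cell `AffineT` sends `0.0` to `1.0`,\n",
    "and `SoftplusT` then sends `1.0` to `log(1 + e) \u2248 1.3133`."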
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b8b1d96d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:37.134771Z",
     "iopub.status.busy": "2026-05-30T16:20:37.134300Z",
     "iopub.status.idle": "2026-05-30T16:20:37.309593Z",
     "shell.execute_reply": "2026-05-30T16:20:37.309041Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity(1.3132616)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chained = nn.ChainT(nn.AffineT(scale=2.0, shift=1.0), nn.SoftplusT(lower=0.0))\n",
    "chained.forward(jnp.array(0.0))"
   ]
  },
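  {
   "cell_type": "markdown",
   "id": "a6c9f032",
   "metadata": {},
   "source": [
    "The ordering can be checked by hand with a hypothetical `Chain` of two hand-written bijectors\n",
    "(`Affine` and `Softplus` are illustrative classes, not brainstate's). Its forward pass applies the\n",
    "transforms left to right and its inverse undoes them right to left, the usual rule for inverting a\n",
    "composition. With `scale=2.0, shift=1.0` it reproduces the `1.3132616` printed by `ChainT` above.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "class Affine:\n",
    "    # Hypothetical affine bijector y = scale * x + shift (requires scale != 0).\n",
    "    def __init__(self, scale, shift):\n",
    "        self.scale, self.shift = scale, shift\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.scale * x + self.shift\n",
    "\n",
    "    def inverse(self, y):\n",
    "        return (y - self.shift) / self.scale\n",
    "\n",
    "\n",
    "class Softplus:\n",
    "    # Hypothetical softplus bijector from R to (0, inf).\n",
    "    def forward(self, x):\n",
    "        return jax.nn.softplus(x)\n",
    "\n",
    "    def inverse(self, y):\n",
    "        return y + jnp.log(-jnp.expm1(-y))\n",
    "\n",
    "\n",
    "class Chain:\n",
    "    # Forward: first transform first. Inverse: undo in reverse order.\n",
    "    def __init__(self, *transforms):\n",
    "        self.transforms = transforms\n",
    "\n",
    "    def forward(self, x):\n",
    "        for t in self.transforms:\n",
    "            x = t.forward(x)\n",
    "        return x\n",
    "\n",
    "    def inverse(self, y):\n",
    "        for t in reversed(self.transforms):\n",
    "            y = t.inverse(y)\n",
    "        return y\n",
    "\n",
    "\n",
    "chained = Chain(Affine(scale=2.0, shift=1.0), Softplus())\n",
    "y = chained.forward(jnp.array(0.0))  # softplus(2 * 0 + 1) = log(1 + e), approx 1.3133\n",
    "print(float(y), float(chained.inverse(y)))\n",
    "```\n",
    "\n",
    "Swapping the order, `Chain(Softplus(), Affine(2.0, 1.0))`, would instead give\n",
    "`2 * log(2) + 1`, roughly `2.386`, which is why the argument order of a chain matters."
   ]
  },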
  {
   "cell_type": "markdown",
   "id": "fa196060",
   "metadata": {},
   "source": [
    "Attach a transform when constructing a `Param`. Here a time constant with a floor of `0.1` is\n",
    "learned: the stored array roams over \u211d, but `value()` always stays above `0.1`, even when the\n",
    "loss pulls it toward `0.05`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7bba0059",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:37.311613Z",
     "iopub.status.busy": "2026-05-30T16:20:37.311392Z",
     "iopub.status.idle": "2026-05-30T16:20:37.592060Z",
     "shell.execute_reply": "2026-05-30T16:20:37.591125Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "constrained tau: 0.2543546259403229 (never drops below the 0.1 floor)\n"
     ]
    }
   ],
   "source": [
    "tau = nn.Param(jnp.array(2.0), t=nn.SoftplusT(lower=0.1))\n",
    "\n",
    "params = {'tau': tau.val}\n",
    "opt = braintools.optim.Adam(lr=1e-1)\n",
    "opt.register_trainable_weights(params)\n",
    "\n",
    "@brainstate.transform.jit\n",
    "def step(target):\n",
    "    def loss_fn():\n",
    "        return (tau.value() - target) ** 2\n",
    "    grads, loss = brainstate.transform.grad(loss_fn, params, return_value=True)()\n",
    "    opt.update(grads)\n",
    "    return loss\n",
    "\n",
    "for _ in range(100):\n",
    "    step(0.05)   # push tau toward 0.05, below the floor of 0.1\n",
    "\n",
    "print('constrained tau:', float(tau.value()), '(never drops below the 0.1 floor)')"
   ]
  },
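  {
   "cell_type": "markdown",
   "id": "3b7e5d48",
   "metadata": {},
   "source": [
    "The same mechanism can be sketched without any library state. The code below keeps an\n",
    "unconstrained scalar `raw`, maps it through `0.1 + softplus(raw)`, and runs plain gradient descent\n",
    "with `jax.grad` (swapped in for the Adam optimizer used above) toward the unreachable target `0.05`.\n",
    "The softplus gradient shrinks as `raw` decreases, so the constrained value creeps toward the floor\n",
    "from above without crossing it. `LOWER`, `constrained`, and the learning rate are illustrative.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "LOWER = 0.1  # floor, as in the tau example above\n",
    "\n",
    "\n",
    "def constrained(raw):\n",
    "    # Any real raw maps above LOWER.\n",
    "    return LOWER + jax.nn.softplus(raw)\n",
    "\n",
    "\n",
    "def loss(raw, target):\n",
    "    return (constrained(raw) - target) ** 2\n",
    "\n",
    "\n",
    "grad_fn = jax.jit(jax.grad(loss))  # gradient with respect to raw\n",
    "\n",
    "raw = jnp.array(2.0)  # unconstrained storage: what the optimizer updates\n",
    "lr = 0.5\n",
    "for _ in range(200):\n",
    "    raw = raw - lr * grad_fn(raw, 0.05)  # target 0.05 lies below the floor\n",
    "\n",
    "print('raw:', float(raw), '| constrained:', float(constrained(raw)))\n",
    "```"
   ]
  },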
  {
   "cell_type": "markdown",
   "id": "0d9b00a5",
   "metadata": {},
   "source": [
    "## Penalising parameters: Regularization"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef889e32",
   "metadata": {},
   "source": [
    "A regularization object contributes a scalar penalty derived from a parameter's value. Read it\n",
    "directly with `reg.loss(value)`, or \u2014 once attached to a `Param` \u2014 with `param.reg_loss()`,\n",
    "which applies the regularizer to that parameter's current value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "34096170",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:37.594085Z",
     "iopub.status.busy": "2026-05-30T16:20:37.593810Z",
     "iopub.status.idle": "2026-05-30T16:20:37.816330Z",
     "shell.execute_reply": "2026-05-30T16:20:37.815424Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "L1 penalty : 0.699999988079071\n",
      "L2 penalty : 2.5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "elastic-net via reg_loss(): 16.0\n"
     ]
    }
   ],
   "source": [
    "weights = jnp.array([3.0, -4.0])\n",
    "print('L1 penalty :', float(nn.L1Reg(0.1).loss(weights)))   # 0.1 * (|3| + |-4|)\n",
    "print('L2 penalty :', float(nn.L2Reg(0.1).loss(weights)))   # 0.1 * (3^2 + 4^2)\n",
    "\n",
    "p = nn.Param(weights, reg=nn.ElasticNetReg(l1_weight=1.0, l2_weight=1.0, alpha=0.5))\n",
    "print('elastic-net via reg_loss():', float(p.reg_loss()))   # 0.5 * (|3| + |-4|) + 0.5 * (3^2 + 4^2)"
   ]
  },
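  {
   "cell_type": "markdown",
   "id": "e91f0c26",
   "metadata": {},
   "source": [
    "Written out in `jax.numpy`, the three penalties above are one-liners. This sketch reproduces the\n",
    "printed numbers, but it is not the library's code: the L2 form (no factor of `1/2`) is read off\n",
    "the `2.5` printed above, and the elastic-net form is an assumption inferred from\n",
    "`16.0 = 0.5 * 7 + 0.5 * 25`. With `alpha=0.5` that number cannot reveal which of the two terms\n",
    "`alpha` multiplies, so consult the `ElasticNetReg` docstring before relying on other `alpha` values.\n",
    "The function names are hypothetical.\n",
    "\n",
    "```python\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def l1_penalty(w, weight):\n",
    "    # weight * sum(|w_i|)\n",
    "    return weight * jnp.sum(jnp.abs(w))\n",
    "\n",
    "\n",
    "def l2_penalty(w, weight):\n",
    "    # weight * sum(w_i ** 2), with no factor of 1/2\n",
    "    return weight * jnp.sum(w ** 2)\n",
    "\n",
    "\n",
    "def elastic_net_penalty(w, l1_weight, l2_weight, alpha):\n",
    "    # ASSUMED mix: alpha * L1 + (1 - alpha) * L2 (16.0 alone cannot fix which term alpha scales)\n",
    "    return alpha * l1_penalty(w, l1_weight) + (1.0 - alpha) * l2_penalty(w, l2_weight)\n",
    "\n",
    "\n",
    "weights = jnp.array([3.0, -4.0])\n",
    "print(float(l1_penalty(weights, 0.1)))                     # approx 0.7\n",
    "print(float(l2_penalty(weights, 0.1)))                     # 2.5\n",
    "print(float(elastic_net_penalty(weights, 1.0, 1.0, 0.5)))  # 16.0\n",
    "```"
   ]
  },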
  {
   "cell_type": "markdown",
   "id": "26f2566e",
   "metadata": {},
   "source": [
    "The catalog spans classical penalties and Bayesian priors (a prior contributes its negative\n",
    "log-density as the penalty):\n",
    "\n",
    "| Family | Members |\n",
    "| --- | --- |\n",
    "| Sparsity / shrinkage | `L1Reg`, `L2Reg`, `ElasticNetReg`, `GroupLassoReg`, `MaxNormReg` |\n",
    "| Structure | `OrthogonalReg`, `SpectralNormReg`, `TotalVariationReg`, `EntropyReg` |\n",
    "| Bayesian priors | `GaussianReg`, `LogNormalReg`, `StudentTReg`, `CauchyReg`, `HorseshoeReg`, `SpikeAndSlabReg`, `DirichletReg`, \u2026 |\n",
    "\n",
    "Combine penalties with `ChainedReg`. Parameters carrying a prior can be re-drawn from it with\n",
    "`param.reset_to_prior()`."
   ]
  },
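  {
   "cell_type": "markdown",
   "id": "4d8a2b6e",
   "metadata": {},
   "source": [
    "The link between priors and classical penalties is easy to verify for a zero-mean Gaussian. Its\n",
    "negative log-density is `sum(w**2) / (2 * std**2)` plus a constant, so it matches an L2 penalty\n",
    "`lam * sum(w**2)` when `std**2 = 1 / (2 * lam)`, and the two give identical gradients. The sketch\n",
    "below checks this numerically. `gaussian_neg_log_prior` and `l2_penalty` are hypothetical helpers,\n",
    "and `GaussianReg` may parameterize its prior differently.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def gaussian_neg_log_prior(w, mean=0.0, std=1.0):\n",
    "    # -sum_i log N(w_i; mean, std^2) = sum_i (w_i - mean)^2 / (2 std^2) + n * log(std * sqrt(2 pi))\n",
    "    z = (w - mean) / std\n",
    "    return jnp.sum(0.5 * z ** 2 + jnp.log(std) + 0.5 * jnp.log(2.0 * jnp.pi))\n",
    "\n",
    "\n",
    "def l2_penalty(w, weight):\n",
    "    return weight * jnp.sum(w ** 2)\n",
    "\n",
    "\n",
    "weights = jnp.array([3.0, -4.0])\n",
    "lam = 0.1\n",
    "std = (1.0 / (2.0 * lam)) ** 0.5  # std^2 = 1 / (2 * lam) matches the quadratic terms\n",
    "\n",
    "# The two penalties differ only by a constant, so their gradients coincide.\n",
    "g_prior = jax.grad(gaussian_neg_log_prior)(weights, 0.0, std)\n",
    "g_l2 = jax.grad(l2_penalty)(weights, lam)\n",
    "print(g_prior, g_l2)  # both approx [0.6, -0.8]\n",
    "```"
   ]
  },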
  {
   "cell_type": "markdown",
   "id": "e072267c",
   "metadata": {},
   "source": [
    "## A constrained, regularized model end-to-end"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34a7abda",
   "metadata": {},
   "source": [
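    "This linear model uses an L2-penalised weight vector `w`, an unconstrained bias `b`, and an output\n",
    "`scale` kept above `1e-3` by `SoftplusT`. The training objective is the data loss plus the summed\n",
    "regularization penalties, collected by walking the module tree with `model.nodes(nn.Param)`; since\n",
    "only `w` carries a regularizer, it is the only parameter that adds to the penalty."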
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ea067cc1",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:37.818602Z",
     "iopub.status.busy": "2026-05-30T16:20:37.818352Z",
     "iopub.status.idle": "2026-05-30T16:20:38.293995Z",
     "shell.execute_reply": "2026-05-30T16:20:38.293109Z"
    }
   },
   "outputs": [],
   "source": [
    "class RegLinear(nn.Module):\n",
    "    def __init__(self, din):\n",
    "        super().__init__()\n",
    "        self.w = nn.Param(brainstate.random.randn(din) * 0.1, reg=nn.L2Reg(1e-2))\n",
    "        self.b = nn.Param(jnp.zeros(()))\n",
    "        self.scale = nn.Param(jnp.array(1.0), t=nn.SoftplusT(lower=1e-3))\n",
    "\n",
    "    def __call__(self, x):\n",
    "        return self.scale.value() * (x @ self.w.value() + self.b.value())\n",
    "\n",
    "    def reg_penalty(self):\n",
    "        return sum(p.reg_loss() for p in self.nodes(nn.Param).values())\n",
    "\n",
    "model = RegLinear(4)\n",
    "x = brainstate.random.randn(128, 4)\n",
    "y = x @ jnp.array([1.0, -2.0, 0.5, 3.0]) + 0.3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "196dbd50",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:38.296504Z",
     "iopub.status.busy": "2026-05-30T16:20:38.296279Z",
     "iopub.status.idle": "2026-05-30T16:20:38.600207Z",
     "shell.execute_reply": "2026-05-30T16:20:38.599227Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "final loss : 0.0378\n",
      "scale > 0  : 1.9410\n",
      "reg penalty: 0.037739\n"
     ]
    }
   ],
   "source": [
    "params = model.states(brainstate.ParamState)\n",
    "opt = braintools.optim.Adam(lr=5e-2)\n",
    "opt.register_trainable_weights(params)\n",
    "\n",
    "@brainstate.transform.jit\n",
    "def train_step():\n",
    "    def loss_fn():\n",
    "        data_loss = jnp.mean((model(x) - y) ** 2)\n",
    "        return data_loss + model.reg_penalty()\n",
    "    grads, loss = brainstate.transform.grad(loss_fn, params, return_value=True)()\n",
    "    opt.update(grads)\n",
    "    return loss\n",
    "\n",
    "for epoch in range(300):\n",
    "    loss = train_step()\n",
    "\n",
    "print(f'final loss : {float(loss):.4f}')\n",
    "print(f'scale > 0  : {float(model.scale.value()):.4f}')\n",
    "print(f'reg penalty: {float(model.reg_penalty()):.6f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "417dd9ce",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "- **`Param`** wraps a `ParamState` with an optional transform and regularizer; read the\n",
    "  constrained value with `value()` and the underlying `ParamState`, which holds the unconstrained\n",
    "  array the optimizer updates, with `val`.\n",
    "- **`Const`** is a non-trainable `Param`, excluded from `ParamState` collection.\n",
    "- **Transforms** are bijectors that keep a parameter inside its constrained space while the\n",
    "  optimizer works in unconstrained \u211d. Compose them with `ChainT`.\n",
    "- **Regularization** adds a penalty via `reg.loss(value)` or `param.reg_loss()`; sum penalties\n",
    "  across a model by iterating `model.nodes(nn.Param)`.\n",
    "\n",
    "### See also\n",
    "\n",
    "- [Training and metrics](07_training_and_metrics.ipynb) \u2014 the optimizer and loss machinery used here.\n",
    "- [Constrain and regularize parameters](../../how_to/constrain_and_regularize_parameters.ipynb) \u2014 a focused how-to recipe.\n",
    "- [The parameter model](../../concepts/the_parameter_model.md) \u2014 the design rationale behind `Param`."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}