{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "64ea46b8",
   "metadata": {},
   "source": [
    "# Constrain and Regularize Parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0cf10832",
   "metadata": {},
   "source": [
    "A bare `brainstate.ParamState` holds an unconstrained array. Many models need more: a rate that\n",
    "must stay positive, a mixing weight bounded to `[0, 1]`, a probability vector that sums to one,\n",
    "or an L2 penalty pulling weights toward zero. `brainstate.nn.Param` adds two declarative\n",
    "facilities on top of `ParamState`:\n",
    "\n",
    "- a **constraint transform** (`t=`) — the optimizer updates the parameter in an *unconstrained*\n",
    "  space, while `.value()` returns it mapped into the valid domain. Gradients flow cleanly\n",
    "  because the mapping is a smooth bijection, not a hard clip.\n",
    "- a **regularization** prior (`reg=`) — `.reg_loss()` returns a penalty you add to the loss.\n",
    "\n",
    "`brainstate.nn.Const` is the companion for values that should *not* be trained."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0e597b6a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:14:13.056103Z",
     "iopub.status.busy": "2026-05-30T17:14:13.055019Z",
     "iopub.status.idle": "2026-05-30T17:14:15.212465Z",
     "shell.execute_reply": "2026-05-30T17:14:15.211525Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0.4.0'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax.numpy as jnp\n",
    "import brainunit as u\n",
    "\n",
    "import brainstate\n",
    "from brainstate import nn\n",
    "\n",
    "brainstate.random.seed(0)\n",
    "brainstate.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "547ea40e",
   "metadata": {},
   "source": [
    "`.value()` returns a `brainunit.Quantity` (dimensionless in these examples) that behaves like a\n",
    "JAX array in arithmetic. `u.get_magnitude(...)` extracts the plain array when we want to print\n",
    "or compare it."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5abc0b0",
   "metadata": {},
   "source": [
    "## A positive-only parameter"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "737dba99",
   "metadata": {},
   "source": [
    "`SoftplusT(lower)` maps the whole real line onto `(lower, ∞)`. We store the parameter\n",
    "unconstrained and read it back through the transform, so it is positive no matter what value the\n",
    "optimizer lands on — even a large negative one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9c699718",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:14:15.215002Z",
     "iopub.status.busy": "2026-05-30T17:14:15.214581Z",
     "iopub.status.idle": "2026-05-30T17:14:17.908923Z",
     "shell.execute_reply": "2026-05-30T17:14:17.907932Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "initial value : 0.4999999701976776\n",
      "after a large negative update: 4.5398901420412585e-05\n"
     ]
    }
   ],
   "source": [
    "rate = nn.Param(jnp.array(0.5), t=nn.SoftplusT(lower=0.0))\n",
    "print('initial value :', float(u.get_magnitude(rate.value())))\n",
    "\n",
    "# Simulate an aggressive optimizer step that drives the *unconstrained* value negative.\n",
    "rate.val.value = jnp.array(-10.0)\n",
    "print('after a large negative update:', float(u.get_magnitude(rate.value())))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d0c8065",
   "metadata": {},
   "source": [
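    "`rate.val` is the underlying `ParamState` the optimizer mutates; `rate.value()` is the\n",
    "constrained view the model should use in its forward pass. The constructor argument is the\n",
    "*constrained* value: it is stored through the inverse transform, which is why the initial\n",
    "read-back is `0.49999997` (a floating-point round trip) rather than exactly `0.5`. After the\n",
    "update, the constrained value stays just above the `lower` bound instead of going negative."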
   ]
  },
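  {
   "cell_type": "markdown",
   "id": "c1a7e001",
   "metadata": {},
   "source": [
    "To make the mapping concrete, here is a minimal plain-JAX sketch of a softplus constraint. It\n",
    "assumes the transform computes `softplus(x) + lower`, which is consistent with the numbers\n",
    "printed above but is not taken from the `SoftplusT` source. The helper names\n",
    "`softplus_forward` and `softplus_inverse` are hypothetical and not part of `brainstate`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a7e002",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def softplus_forward(x, lower=0.0):\n",
    "    # Unconstrained -> constrained: softplus(x) + lower lies in (lower, inf).\n",
    "    return jax.nn.softplus(x) + lower\n",
    "\n",
    "def softplus_inverse(y, lower=0.0):\n",
    "    # Constrained -> unconstrained: log(exp(y - lower) - 1), the inverse of softplus.\n",
    "    return jnp.log(jnp.expm1(y - lower))\n",
    "\n",
    "# Assumption: an initial constrained value of 0.5 is stored as its inverse image.\n",
    "x0 = softplus_inverse(jnp.array(0.5))\n",
    "round_trip = softplus_forward(x0)\n",
    "far_negative = softplus_forward(jnp.array(-10.0))\n",
    "\n",
    "print('stored unconstrained value:', float(x0))\n",
    "print('round trip                :', float(round_trip))\n",
    "print('forward(-10)              :', float(far_negative))"
   ]
  },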
  {
   "cell_type": "markdown",
   "id": "495f76e2",
   "metadata": {},
   "source": [
    "## Bounding to an interval"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3fa28076",
   "metadata": {},
   "source": [
    "`SigmoidT(lower, upper)` constrains a parameter to the open interval `(lower, upper)`. The\n",
    "transform catalogue covers the common cases; a few of the most useful are:\n",
    "\n",
    "| Transform | Domain of `.value()` |\n",
    "|---|---|\n",
    "| `SoftplusT(lower)` / `ExpT(lower)` | `(lower, ∞)` — positive quantities |\n",
    "| `SigmoidT(lower, upper)` | `(lower, upper)` — bounded scalars |\n",
    "| `SimplexT()` | non-negative vector summing to 1 — probabilities |\n",
    "| `AffineT(scale, shift)` | linear rescale |\n",
    "| `ChainT(t1, t2, ...)` | compose transforms |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0c52c695",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:14:17.911615Z",
     "iopub.status.busy": "2026-05-30T17:14:17.911149Z",
     "iopub.status.idle": "2026-05-30T17:14:18.069098Z",
     "shell.execute_reply": "2026-05-30T17:14:18.068182Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "unconstrained     0.0 -> value 0.5000\n",
      "unconstrained   100.0 -> value 1.0000\n",
      "unconstrained  -100.0 -> value 0.0000\n"
     ]
    }
   ],
   "source": [
    "mix = nn.Param(jnp.array(0.5), t=nn.SigmoidT(lower=0.0, upper=1.0))\n",
    "\n",
    "for unconstrained in (0.0, 100.0, -100.0):\n",
    "    mix.val.value = jnp.array(unconstrained)\n",
    "    print(f'unconstrained {unconstrained:>7} -> value {float(u.get_magnitude(mix.value())):.4f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8188bf84",
   "metadata": {},
   "source": [
    "An unconstrained value of 0 maps to the midpoint of the interval, and large magnitudes saturate\n",
    "toward the bounds. Mathematically the bounds are never reached; the `1.0000` and `0.0000` above\n",
    "are within rounding of the bounds, not outside them."
   ]
  },
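  {
   "cell_type": "markdown",
   "id": "c1a7e003",
   "metadata": {},
   "source": [
    "Saturation has a practical cost worth seeing once: the gradient with respect to the\n",
    "unconstrained value shrinks as the constrained value approaches a bound. The plain-JAX sketch\n",
    "below assumes the transform is `lower + (upper - lower) * sigmoid(x)`, which matches the values\n",
    "printed above; `scaled_sigmoid` is a hypothetical helper, not a `brainstate` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a7e004",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def scaled_sigmoid(x, lower=0.0, upper=1.0):\n",
    "    # Map the real line into (lower, upper) with a rescaled logistic function.\n",
    "    return lower + (upper - lower) * jax.nn.sigmoid(x)\n",
    "\n",
    "# The same formula works for any interval, here (2, 5).\n",
    "mid_value = scaled_sigmoid(jnp.array(0.0), lower=2.0, upper=5.0)\n",
    "print('x = 0 maps to', float(mid_value))\n",
    "\n",
    "# Gradient of the constrained value with respect to the unconstrained one.\n",
    "slope = jax.grad(scaled_sigmoid)\n",
    "grad_center = slope(0.0)       # steepest point of the sigmoid\n",
    "grad_saturated = slope(100.0)  # deep in the saturated region\n",
    "print('slope at x = 0  :', float(grad_center))\n",
    "print('slope at x = 100:', float(grad_saturated))"
   ]
  },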
  {
   "cell_type": "markdown",
   "id": "151f6d97",
   "metadata": {},
   "source": [
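    "A `SimplexT` parameter is handy for a learned categorical distribution: whatever the optimizer\n",
    "does to the unconstrained values, `.value()` is always a valid probability vector. As the output\n",
    "below shows, the constrained vector has one more entry than the unconstrained one: three\n",
    "unconstrained values yield four probabilities."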
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8525e0d6",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:14:18.071645Z",
     "iopub.status.busy": "2026-05-30T17:14:18.071329Z",
     "iopub.status.idle": "2026-05-30T17:14:18.605636Z",
     "shell.execute_reply": "2026-05-30T17:14:18.604728Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "probabilities [0.5, 0.25, 0.125, 0.125] sum 1.0\n",
      "probabilities [0.8808, 0.0321, 0.0542, 0.0329] sum 1.0\n"
     ]
    }
   ],
   "source": [
    "probs = nn.Param(jnp.zeros(3), t=nn.SimplexT())\n",
    "\n",
    "for unconstrained in ([0.0, 0.0, 0.0], [2.0, -1.0, 0.5]):\n",
    "    probs.val.value = jnp.array(unconstrained)\n",
    "    p = u.get_magnitude(probs.value())\n",
    "    print('probabilities', [round(float(x), 4) for x in p], 'sum', round(float(p.sum()), 6))"
   ]
  },
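  {
   "cell_type": "markdown",
   "id": "c1a7e005",
   "metadata": {},
   "source": [
    "The printed probabilities are what a stick-breaking construction produces: each unconstrained\n",
    "entry decides, through a sigmoid, what fraction of the remaining probability mass to take, and\n",
    "the last entry receives whatever is left. The sketch below is an assumption inferred from those\n",
    "numbers, not a copy of the `SimplexT` implementation, and `stick_breaking` is a hypothetical\n",
    "helper name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a7e006",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def stick_breaking(x):\n",
    "    # x has K entries; the result has K + 1 non-negative entries that sum to 1.\n",
    "    z = jax.nn.sigmoid(x)  # fraction of the remaining stick taken at each step\n",
    "    # Stick length left before each of the K + 1 pieces is cut.\n",
    "    remaining = jnp.concatenate([jnp.ones(1), jnp.cumprod(1.0 - z)])\n",
    "    # The final piece takes everything that is left.\n",
    "    fractions = jnp.concatenate([z, jnp.ones(1)])\n",
    "    return fractions * remaining\n",
    "\n",
    "p_uniform_input = stick_breaking(jnp.zeros(3))\n",
    "p_skewed_input = stick_breaking(jnp.array([2.0, -1.0, 0.5]))\n",
    "print([round(float(v), 4) for v in p_uniform_input], float(p_uniform_input.sum()))\n",
    "print([round(float(v), 4) for v in p_skewed_input], float(p_skewed_input.sum()))"
   ]
  },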
  {
   "cell_type": "markdown",
   "id": "1414cbb8",
   "metadata": {},
   "source": [
    "## Adding a regularization prior"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b12c4e12",
   "metadata": {},
   "source": [
    "Pass `reg=` to attach a penalty. `.reg_loss()` returns the scalar contribution for that\n",
    "parameter, which you add to the data loss. The built-in choices include `L1Reg` (sparsity),\n",
    "`L2Reg` (weight decay), and `ElasticNetReg` (a blend)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "caa65871",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:14:18.607232Z",
     "iopub.status.busy": "2026-05-30T17:14:18.607079Z",
     "iopub.status.idle": "2026-05-30T17:14:18.760275Z",
     "shell.execute_reply": "2026-05-30T17:14:18.759705Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "L2 penalty: 2.5\n",
      "L1 penalty: 0.699999988079071\n"
     ]
    }
   ],
   "source": [
    "weights = nn.Param(jnp.array([3.0, -4.0]), reg=nn.L2Reg(weight=0.1))\n",
    "print('L2 penalty:', float(u.get_magnitude(weights.reg_loss())))   # 0.1 * sum(w**2)\n",
    "\n",
    "sparse = nn.Param(jnp.array([3.0, -4.0]), reg=nn.L1Reg(weight=0.1))\n",
    "print('L1 penalty:', float(u.get_magnitude(sparse.reg_loss())))    # 0.1 * sum(|w|)"
   ]
  },
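  {
   "cell_type": "markdown",
   "id": "c1a7e007",
   "metadata": {},
   "source": [
    "The labels \"weight decay\" and \"sparsity\" come from the gradients of these penalties. The\n",
    "plain-JAX sketch below uses the formulas from the comments above (`0.1 * sum(w**2)` and\n",
    "`0.1 * sum(|w|)`); `l2_penalty` and `l1_penalty` are hypothetical stand-ins for illustration, not\n",
    "the `brainstate` classes. The L2 gradient is proportional to the weight itself, while the L1\n",
    "gradient has the same magnitude for every non-zero weight."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a7e008",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "w = jnp.array([3.0, -4.0])\n",
    "strength = 0.1  # plays the role of weight=0.1 in the cell above\n",
    "\n",
    "def l2_penalty(w):\n",
    "    return strength * jnp.sum(w ** 2)\n",
    "\n",
    "def l1_penalty(w):\n",
    "    return strength * jnp.sum(jnp.abs(w))\n",
    "\n",
    "l2_grad = jax.grad(l2_penalty)(w)  # equals 2 * strength * w\n",
    "l1_grad = jax.grad(l1_penalty)(w)  # equals strength * sign(w)\n",
    "\n",
    "print('L2 penalty:', float(l2_penalty(w)), 'gradient:', l2_grad)\n",
    "print('L1 penalty:', float(l1_penalty(w)), 'gradient:', l1_grad)"
   ]
  },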
  {
   "cell_type": "markdown",
   "id": "34c16558",
   "metadata": {},
   "source": [
    "## Marking a value constant with `Const`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f11170b1",
   "metadata": {},
   "source": [
    "`Const` wraps a value that participates in the forward pass but is never trained. It is\n",
    "deliberately excluded from the `ParamState` collection, so optimizers and `grad` skip it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "21c17270",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:14:18.762467Z",
     "iopub.status.busy": "2026-05-30T17:14:18.762236Z",
     "iopub.status.idle": "2026-05-30T17:14:18.768050Z",
     "shell.execute_reply": "2026-05-30T17:14:18.767559Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable parameters: [('weight', 'val')]\n"
     ]
    }
   ],
   "source": [
    "class Scaler(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.weight = nn.Param(jnp.ones(3))      # trainable\n",
    "        self.gain = nn.Const(jnp.array(2.0))     # fixed\n",
    "\n",
    "    def __call__(self, x):\n",
    "        return x * self.weight.value() * self.gain.value()\n",
    "\n",
    "model = Scaler()\n",
    "trainable = model.states(brainstate.ParamState)\n",
    "print('trainable parameters:', list(trainable.keys()))   # gain is absent"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "918a295c",
   "metadata": {},
   "source": [
    "## Putting it together in a training step"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0cf331e4",
   "metadata": {},
   "source": [
    "Constrained parameters and regularization compose with the ordinary\n",
    "`brainstate.transform.grad` workflow. The gradient is taken with respect to the unconstrained\n",
    "`ParamState`s, so updates can be applied directly: the transform is re-applied inside `.value()`\n",
    "on every forward pass, and the penalty enters through the `reg_loss()` term added to the loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fe96f110",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:14:18.769833Z",
     "iopub.status.busy": "2026-05-30T17:14:18.769612Z",
     "iopub.status.idle": "2026-05-30T17:14:19.491775Z",
     "shell.execute_reply": "2026-05-30T17:14:19.490660Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss trajectory: [1.2443, 1.109, 1.0126, 0.9417, 0.8884]\n",
      "gain stays positive: True\n"
     ]
    }
   ],
   "source": [
    "class ConstrainedLinear(nn.Module):\n",
    "    def __init__(self, din, dout):\n",
    "        super().__init__()\n",
    "        self.w = nn.Param(brainstate.random.randn(din, dout) * 0.1, reg=nn.L2Reg(1e-3))\n",
    "        self.gain = nn.Param(jnp.array(1.0), t=nn.SoftplusT(lower=0.0))\n",
    "\n",
    "    def __call__(self, x):\n",
    "        return (x @ self.w.value()) * self.gain.value()\n",
    "\n",
    "model = ConstrainedLinear(4, 2)\n",
    "params = model.states(brainstate.ParamState)\n",
    "x = brainstate.random.randn(16, 4)\n",
    "y = brainstate.random.randn(16, 2)\n",
    "\n",
    "def loss_fn():\n",
    "    mse = jnp.mean((model(x) - y) ** 2)\n",
    "    penalty = model.w.reg_loss()\n",
    "    return mse + u.get_magnitude(penalty)\n",
    "\n",
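    "@brainstate.transform.jit\n",
    "def train_step():\n",
    "    # Differentiate w.r.t. the unconstrained ParamStates and also return the loss value.\n",
    "    grads, loss = brainstate.transform.grad(loss_fn, params, return_value=True)()\n",
    "    # Plain gradient descent with a fixed step size of 0.1.\n",
    "    for key in params:\n",
    "        params[key].value -= 0.1 * grads[key]\n",
    "    return loss\n",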
    "\n",
    "losses = [float(train_step()) for _ in range(5)]\n",
    "print('loss trajectory:', [round(v, 4) for v in losses])\n",
    "print('gain stays positive:', float(u.get_magnitude(model.gain.value())) > 0)"
   ]
  },
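  {
   "cell_type": "markdown",
   "id": "c1a7e009",
   "metadata": {},
   "source": [
    "Why unconstrained updates are safe can be seen in isolation with plain JAX. The sketch below\n",
    "assumes a softplus constraint with `lower=0` and deliberately asks for an unreachable negative\n",
    "target: gradient descent pushes the unconstrained value down, and the constrained value shrinks\n",
    "but stays positive. The names `loss` and `grad_fn` are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a7e00a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def loss(x):\n",
    "    gain = jax.nn.softplus(x)     # constrained value, always positive\n",
    "    return (gain - (-1.0)) ** 2   # the target -1 lies outside the valid domain\n",
    "\n",
    "grad_fn = jax.grad(loss)          # chain rule through the transform\n",
    "\n",
    "x = jnp.array(0.5)                # unconstrained starting point\n",
    "initial_gain = jax.nn.softplus(x)\n",
    "for _ in range(50):\n",
    "    x = x - 0.1 * grad_fn(x)      # plain gradient descent in unconstrained space\n",
    "final_gain = jax.nn.softplus(x)\n",
    "\n",
    "print('unconstrained value after training:', float(x))\n",
    "print('constrained gain after training   :', float(final_gain))"
   ]
  },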
  {
   "cell_type": "markdown",
   "id": "315b040a",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "- `nn.Param(value, t=..., reg=...)` wraps a `ParamState` and adds a constraint transform and a\n",
    "  regularization prior.\n",
    "- `.value()` returns the **constrained** value (use this in the forward pass); `.val` is the\n",
    "  underlying `ParamState` the optimizer updates in **unconstrained** space.\n",
    "- Transforms (`SoftplusT`, `SigmoidT`, `SimplexT`, `ChainT`, …) keep a parameter in its valid\n",
    "  domain through a smooth bijection, so gradients flow.\n",
    "- `.reg_loss()` returns the penalty for a regularized parameter (`L1Reg`, `L2Reg`,\n",
    "  `ElasticNetReg`, …); add it to the data loss.\n",
    "- `nn.Const` marks a value as non-trainable — it is excluded from the `ParamState` collection.\n",
    "\n",
    "### See also\n",
    "\n",
    "- [Observe and intercept state access with hooks](state_hooks.ipynb) — an imperative alternative for enforcing invariants on writes.\n",
    "- [Training and metrics](../tutorials/core/07_training_and_metrics.ipynb) — the full optimization loop these parameters slot into."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
