{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Typing System\n",
   "id": "d08e7c814308be13"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "This notebook introduces the type utilities in `brainstate.typing`.\n",
    "You will learn how to annotate arrays, PyTrees, random seeds, and helper\n",
    "structures so that static checkers and collaborators can understand your\n",
    "code more easily.\n",
    "\n",
    "Topics covered:\n",
    "\n",
    "- Size/shape/axis aliases used in array APIs.\n",
    "- `Array` / `ArrayLike` for expressing tensor expectations.\n",
    "- `PyTree` annotations and path filters for tree utilities.\n",
    "- Data type helpers (`DType`, `DTypeLike`, `SupportsDType`).\n",
    "- Random key, sentinel, and filter helper types.\n"
   ],
   "id": "9cf7e2f2a1a6ad6"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from typing import Any\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "\n",
    "from brainstate.typing import (\n",
    "    Array,\n",
    "    ArrayLike,\n",
    "    Axes,\n",
    "    DType,\n",
    "    DTypeLike,\n",
    "    Filter,\n",
    "    Key,\n",
    "    Missing,\n",
    "    PathParts,\n",
    "    Predicate,\n",
    "    PyTree,\n",
    "    SeedOrKey,\n",
    "    Shape,\n",
    "    Size,\n",
    "    SupportsDType,\n",
    ")\n"
   ],
   "id": "fe2a8bcfbd579c4c"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Shapes, sizes, and axes\n",
    "\n",
    "`Size`, `Shape`, and `Axes` help you document functions that expect\n",
    "specific tensor dimensions. They are thin aliases around Python sequences\n",
    "but communicating intent through annotations is valuable to readers and\n",
    "tooling.\n"
   ],
   "id": "444845299cca73ad"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def normalise_batch(batch: ArrayLike, shape: Shape, along: Axes = 0) -> jax.Array:\n",
    "    \"\"\"Reshape `batch` then standardise along the given axes.\"\"\"\n",
    "    array = jnp.asarray(batch).reshape(tuple(shape))\n",
    "    mean = jnp.mean(array, axis=along, keepdims=True)\n",
    "    std = jnp.maximum(jnp.std(array, axis=along, keepdims=True), 1e-6)\n",
    "    return (array - mean) / std\n",
    "\n",
    "example = normalise_batch(jnp.arange(12.0), shape=(3, 4), along=0)\n",
    "example\n"
   ],
   "id": "3f49b960a17f80e4"
  },
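  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Assuming `Size` and `Axes` admit either a bare `int` or a sequence of ints, a\n",
    "function usually normalises them once before use. The sketch below covers\n",
    "exactly those two forms; `as_tuple`, `zeros_of_size`, and `total_over` are\n",
    "hypothetical helpers, not part of `brainstate`.\n"
   ],
   "id": "5b1e0c7a9d3f4e21"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "from brainstate.typing import Axes, Size\n",
    "\n",
    "\n",
    "def as_tuple(value: Size) -> tuple[int, ...]:\n",
    "    # Hypothetical helper: a bare int becomes a 1-tuple; any sequence is copied into a tuple.\n",
    "    if isinstance(value, int):\n",
    "        return (value,)\n",
    "    return tuple(value)\n",
    "\n",
    "\n",
    "def zeros_of_size(size: Size) -> jax.Array:\n",
    "    # Normalising first means ints, lists, and tuples all yield the same kind of shape.\n",
    "    return jnp.zeros(as_tuple(size))\n",
    "\n",
    "\n",
    "def total_over(x: jax.Array, axes: Axes = 0) -> jax.Array:\n",
    "    # JAX reductions accept an int or a tuple of ints for `axis`.\n",
    "    return jnp.sum(x, axis=as_tuple(axes))\n",
    "\n",
    "\n",
    "print(zeros_of_size(3).shape)       # (3,)\n",
    "print(zeros_of_size([2, 3]).shape)  # (2, 3)\n",
    "print(total_over(jnp.ones((2, 3)), axes=[0, 1]))\n"
   ],
   "id": "8c4d2f6e1a7b3c90"
  },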
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Array annotations\n",
    "\n",
    "Use `Array[...]` to describe shape expectations and `ArrayLike` when a\n",
    "function accepts anything convertible to a JAX array. These annotations are\n",
    "informative for readers and static type checkers alike.\n"
   ],
   "id": "e39233ab282cd0ad"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "Matrix = Array[\"rows, cols\"]\n",
    "Vector = Array[\"cols\"]\n",
    "\n",
    "def affine_transform(x: Matrix, weight: Array[\"cols, features\"], bias: Vector) -> Array[\"rows, features\"]:\n",
    "    return x @ weight + bias\n",
    "\n",
    "x = jnp.ones((2, 3))\n",
    "w = jnp.arange(6.0).reshape(3, 2)\n",
    "b = jnp.array([0.5, -0.5])\n",
    "affine_transform(x, w, b)\n"
   ],
   "id": "f136e2cd8ad3d697"
  },
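  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Annotations on a plain Python function are not checked when it is called, so a\n",
    "shape mismatch only surfaces later as a broadcasting or matmul error. Where that\n",
    "matters, a short explicit check can back up the names in `Array[...]`. The\n",
    "`batched_dot` function below is a hypothetical example, not a `brainstate` API.\n"
   ],
   "id": "e2a9f17c4b6d0835"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "from brainstate.typing import Array\n",
    "\n",
    "\n",
    "def batched_dot(a: Array['batch, dim'], b: Array['batch, dim']) -> Array['batch']:\n",
    "    # The annotations name the axes; this check enforces them explicitly.\n",
    "    if a.ndim != 2 or a.shape != b.shape:\n",
    "        raise ValueError(f'expected two (batch, dim) arrays of equal shape, got {a.shape} and {b.shape}')\n",
    "    # Row-wise dot product: multiply elementwise, then reduce over `dim`.\n",
    "    return jnp.sum(a * b, axis=-1)\n",
    "\n",
    "\n",
    "out = batched_dot(jnp.ones((4, 3)), jnp.full((4, 3), 2.0))\n",
    "print(out)  # [6. 6. 6. 6.]\n"
   ],
   "id": "71f3b8d05e2c9a46"
  },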
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "You can still accept flexible data by annotating parameters as `ArrayLike`.\n",
    "The conversion to `jnp.asarray` happens inside the function, keeping the\n",
    "        signature expressive yet ergonomic.\n"
   ],
   "id": "6e00c996f03c5724"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def sum_energy(signal: ArrayLike) -> float:\n",
    "    arr = jnp.asarray(signal)\n",
    "    return float(jnp.sum(arr ** 2))\n",
    "\n",
    "print(sum_energy([1, 2, 3]))\n",
    "print(sum_energy(np.float32(1.5)))\n"
   ],
   "id": "922e82d4bc43ec9d"
  },
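  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Because `sum_energy` calls `float()` on its result, it cannot run inside\n",
    "`jax.jit`: a traced value has no concrete Python number. A jit-friendly sketch\n",
    "keeps the `ArrayLike` parameter but returns a `jax.Array`; `rms` is a\n",
    "hypothetical name used for illustration.\n"
   ],
   "id": "c90d4e2b7f1a6358"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "\n",
    "from brainstate.typing import ArrayLike\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def rms(signal: ArrayLike) -> jax.Array:\n",
    "    # Convert once at the boundary, then stay in jax.numpy so the function is traceable.\n",
    "    arr = jnp.asarray(signal, dtype=jnp.float32)\n",
    "    return jnp.sqrt(jnp.mean(arr ** 2))\n",
    "\n",
    "\n",
    "print(rms(np.array([3.0, 4.0])))   # NumPy input\n",
    "print(rms(jnp.array([1.0, 1.0])))  # JAX input\n",
    "print(rms(2.0))                    # Python scalar\n"
   ],
   "id": "3fa6b1d8c2e04977"
  },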
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
 `ArrayLike` also covers">
    "> `ArrayLike` also covers `brainunit.Quantity` objects, so unit-aware\n",
    "> tensors can pass through the same APIs without losing type information.\n"
   ],
   "id": "1a600a9839021e33"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Annotating PyTrees\n",
    "\n",
    "`PyTree` acts like `typing.Any`, but it documents the expected leaf type\n",
    "        (and optionally structure). That improves readability when writing\n",
    "        utilities that operate on nested containers.\n"
   ],
   "id": "2a4d66125c562151"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def tree_l2_norm(tree: PyTree[jax.Array]) -> float:\n",
    "    leaves, _ = jax.tree_util.tree_flatten(tree)\n",
    "    total = sum(float(jnp.sum(jnp.square(jnp.asarray(leaf)))) for leaf in leaves)\n",
    "    return float(total)\n",
    "\n",
    "nested = {\"encoder\": jnp.ones((2, 2)), \"decoder\": [jnp.arange(3.0)]}\n",
    "tree_l2_norm(nested)\n"
   ],
   "id": "4f34c2c22d970efc"
  },
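  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "The same `PyTree[...]` annotation suits functions that return a tree with the\n",
    "same structure as their input. As a sketch, global-norm clipping rescales every\n",
    "leaf by one shared factor; `clip_by_global_norm` is a hypothetical helper, not a\n",
    "`brainstate` API.\n"
   ],
   "id": "a4e7c3015d9b2f68"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "from brainstate.typing import PyTree\n",
    "\n",
    "\n",
    "def clip_by_global_norm(tree: PyTree[jax.Array], max_norm: float) -> PyTree[jax.Array]:\n",
    "    leaves = jax.tree_util.tree_leaves(tree)\n",
    "    # Global L2 norm across every leaf of the tree.\n",
    "    norm = jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves))\n",
    "    # Shrink only when the norm exceeds the budget; guard against a zero norm.\n",
    "    scale = jnp.minimum(1.0, max_norm / jnp.maximum(norm, 1e-12))\n",
    "    # tree_map keeps the container structure, so the output matches the input tree.\n",
    "    return jax.tree_util.tree_map(lambda leaf: leaf * scale, tree)\n",
    "\n",
    "\n",
    "grads = {'encoder': jnp.ones((2, 2)), 'decoder': [jnp.arange(3.0)]}\n",
    "clipped = clip_by_global_norm(grads, max_norm=1.5)\n",
    "print(clipped)\n"
   ],
   "id": "d68b0f29e4a1c573"
  },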
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "### Working with paths and filters\n",
    "\n",
    "`PathParts`, `Predicate`, and `Filter` describe how to select parts of a\n",
    "        PyTree. The snippet below collects leaves whose path ends with `\"weight\"`.\n"
   ],
   "id": "48d42c0e07db9329"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def walk(tree: Any, predicate: Predicate, path: PathParts = ()) -> list[tuple[PathParts, Any]]:\n",
    "    matches: list[tuple[PathParts, Any]] = []\n",
    "    if predicate(path, tree):\n",
    "        matches.append((path, tree))\n",
    "    if isinstance(tree, dict):\n",
    "        for key, value in tree.items():\n",
    "            matches.extend(walk(value, predicate, path + (key,)))\n",
    "    elif isinstance(tree, (list, tuple)):\n",
    "        for idx, value in enumerate(tree):\n",
    "            matches.extend(walk(value, predicate, path + (idx,)))\n",
    "    return matches\n",
    "\n",
    "model = {\n",
    "    \"dense1\": {\"weight\": jnp.ones((3, 3)), \"bias\": jnp.zeros(3)},\n",
    "    \"dense2\": {\"weight\": jnp.eye(3), \"bias\": jnp.ones(3)},\n",
    "}\n",
    "\n",
    "weight_filter: Predicate = lambda path, value: path and path[-1] == \"weight\"\n",
    "for found_path, value in walk(model, weight_filter):\n",
    "    print(found_path, value.shape)\n"
   ],
   "id": "6d69902355720c67"
  },
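  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "One way to use `Filter` is as a shorthand that is resolved into a `Predicate`.\n",
    "The conversion rules below (a string matches the last path component, a type\n",
    "matches leaf instances, a tuple or list matches when any member does, and a bool\n",
    "matches everything or nothing) are assumptions chosen for illustration, not\n",
    "`brainstate`'s own semantics. `to_predicate` and `select` are hypothetical helpers.\n"
   ],
   "id": "1b5f9e3a7c0d4826"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from typing import Any\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "from brainstate.typing import Filter, PathParts, Predicate\n",
    "\n",
    "\n",
    "def to_predicate(filter_: Filter) -> Predicate:\n",
    "    # Illustrative rules only, not brainstate's semantics. Check bool before\n",
    "    # type, and type before callable, because classes are callable too.\n",
    "    if isinstance(filter_, bool):\n",
    "        return lambda path, value: filter_\n",
    "    if isinstance(filter_, str):\n",
    "        return lambda path, value: len(path) > 0 and path[-1] == filter_\n",
    "    if isinstance(filter_, type):\n",
    "        return lambda path, value: isinstance(value, filter_)\n",
    "    if isinstance(filter_, (tuple, list)):\n",
    "        predicates = [to_predicate(f) for f in filter_]\n",
    "        return lambda path, value: any(p(path, value) for p in predicates)\n",
    "    if callable(filter_):\n",
    "        return filter_\n",
    "    raise TypeError(f'unsupported filter: {filter_!r}')\n",
    "\n",
    "\n",
    "def select(tree: dict[str, dict[str, Any]], filter_: Filter) -> list[PathParts]:\n",
    "    # Two-level walk, which is enough for the layer -> parameter model below.\n",
    "    predicate = to_predicate(filter_)\n",
    "    return [\n",
    "        (layer, name)\n",
    "        for layer, params in tree.items()\n",
    "        for name, value in params.items()\n",
    "        if predicate((layer, name), value)\n",
    "    ]\n",
    "\n",
    "\n",
    "model = {\n",
    "    'dense1': {'weight': jnp.ones((3, 3)), 'bias': jnp.zeros(3)},\n",
    "    'dense2': {'weight': jnp.eye(3), 'bias': jnp.ones(3)},\n",
    "}\n",
    "print(select(model, 'bias'))  # [('dense1', 'bias'), ('dense2', 'bias')]\n",
    "print(select(model, ('weight', 'bias')))\n",
    "print(select(model, jax.Array))\n",
    "print(select(model, lambda path, value: value.ndim == 2))\n"
   ],
   "id": "f07c2a6b9e3d1854"
  },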
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Data type helpers\n",
    "\n",
    "`DType` names a concrete NumPy dtype, while `DTypeLike` accepts any object\n",
    "        that can be coerced into one. Implementing the `SupportsDType` protocol\n",
    "        lets custom containers participate too.\n"
   ],
   "id": "50a26421b0b329ca"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "class TensorView(SupportsDType):\n",
    "    def __init__(self, array: jax.Array):\n",
    "        self._array = array\n",
    "\n",
    "    @property\n",
    "    def dtype(self) -> DType:\n",
    "        return self._array.dtype\n",
    "\n",
    "def zeros_like(shape: Shape, dtype: DTypeLike) -> jax.Array:\n",
    "    return jnp.zeros(shape, dtype=dtype)\n",
    "\n",
    "print(zeros_like((2, 2), np.float32))\n",
    "print(zeros_like((1, 3), TensorView(jnp.ones(3))))\n"
   ],
   "id": "3c2765c762b26370"
  },
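  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Strings, scalar types, dtype objects, and objects exposing a `dtype` attribute\n",
    "can all be resolved to a concrete `DType` with `np.dtype`. The small sketch below\n",
    "relies on that; `resolve_dtype` is a hypothetical helper name.\n"
   ],
   "id": "6e3d8a1f5b2c7094"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "\n",
    "from brainstate.typing import DType, DTypeLike\n",
    "\n",
    "\n",
    "def resolve_dtype(dtype: DTypeLike) -> DType:\n",
    "    # np.dtype reads `.dtype` from array-like objects, which covers SupportsDType.\n",
    "    return np.dtype(dtype)\n",
    "\n",
    "\n",
    "print(resolve_dtype('float16'))                        # float16\n",
    "print(resolve_dtype(np.int32))                         # int32\n",
    "print(resolve_dtype(np.dtype('bool')))                 # bool\n",
    "print(resolve_dtype(jnp.ones(2, dtype=jnp.bfloat16)))  # bfloat16\n"
   ],
   "id": "b81f4c6d0a9e2357"
  },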
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Random seeds and keys\n",
    "\n",
    "`SeedOrKey` lists the accepted random sources (`int`, JAX key, or NumPy key).\n",
    "Normalising the input inside your function keeps call sites ergonomic.\n"
   ],
   "id": "132d022f9f820126"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def sample_normal(key: SeedOrKey, shape: Shape) -> jax.Array:\n",
    "    if isinstance(key, int):\n",
    "        key = jax.random.PRNGKey(key)\n",
    "    elif isinstance(key, np.ndarray):\n",
    "        key = jnp.asarray(key, dtype=jnp.uint32)\n",
    "    return jax.random.normal(key, shape)\n",
    "\n",
    "print(sample_normal(0, (2,)))\n",
    "print(sample_normal(jax.random.PRNGKey(1), (2,)))\n"
   ],
   "id": "efbf785076c79d0f"
  },
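  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "A common next step is drawing several independent samples from one `SeedOrKey`.\n",
    "JAX keys should not be reused across draws, so the key is split into one subkey\n",
    "per draw. `as_key` and `init_weights` are hypothetical helpers, and `as_key`\n",
    "assumes array inputs already hold raw `uint32` key data, as returned by\n",
    "`jax.random.PRNGKey`.\n"
   ],
   "id": "0c7e5b2a9f4d3861"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "\n",
    "from brainstate.typing import SeedOrKey\n",
    "\n",
    "\n",
    "def as_key(seed_or_key: SeedOrKey) -> jax.Array:\n",
    "    # Integer seeds become fresh keys; arrays are treated as raw uint32 key data.\n",
    "    if isinstance(seed_or_key, (int, np.integer)):\n",
    "        return jax.random.PRNGKey(int(seed_or_key))\n",
    "    return jnp.asarray(seed_or_key, dtype=jnp.uint32)\n",
    "\n",
    "\n",
    "def init_weights(seed_or_key: SeedOrKey, shapes: dict[str, tuple[int, ...]]) -> dict[str, jax.Array]:\n",
    "    # One independent subkey per parameter, so no key is used twice.\n",
    "    subkeys = jax.random.split(as_key(seed_or_key), len(shapes))\n",
    "    return {\n",
    "        name: jax.random.normal(subkey, shape)\n",
    "        for (name, shape), subkey in zip(shapes.items(), subkeys)\n",
    "    }\n",
    "\n",
    "\n",
    "shapes = {'w': (2, 3), 'b': (3,)}\n",
    "from_seed = init_weights(0, shapes)\n",
    "from_key = init_weights(jax.random.PRNGKey(0), shapes)\n",
    "print(bool(jnp.all(from_seed['w'] == from_key['w'])))  # True\n"
   ],
   "id": "9a2d6f0e3b8c1745"
  },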
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Keys and sentinels\n",
    "\n",
    "`Key` is a protocol for path components. `Missing` is a sentinel object you can\n",
    "        use when `None` is a meaningful value.\n"
   ],
   "id": "8c9bc915b373de9c"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "_MISSING = Missing()\n",
    "\n",
    "def resolve_config(name: Key, *, output_dir: str | Missing = _MISSING) -> str:\n",
    "    if output_dir is _MISSING:\n",
    "        return f'/tmp/{name}'\n",
    "    return str(output_dir)\n",
    "\n",
    "print(resolve_config('experiment-A'))\n",
    "print(resolve_config('experiment-B', output_dir=None))\n"
   ],
   "id": "f9a018f8983d33a2"
  },
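  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Dictionary keys (`str`) and sequence indices (`int`) are hashable, so whole\n",
    "`PathParts` tuples can key a flat mapping from path to leaf. The sketch below\n",
    "(`flatten_with_paths` is a hypothetical helper) also sorts the paths; that works\n",
    "here because every component is a string, whereas mixing `str` and `int` at the\n",
    "same position would make the comparison raise `TypeError`.\n"
   ],
   "id": "4d8b1e7c2a5f0936"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from typing import Any\n",
    "\n",
    "import jax.numpy as jnp\n",
    "\n",
    "from brainstate.typing import PathParts\n",
    "\n",
    "\n",
    "def flatten_with_paths(tree: Any, prefix: PathParts = ()) -> dict[PathParts, Any]:\n",
    "    # Recurse into dicts; anything else is a leaf stored under its full path.\n",
    "    if isinstance(tree, dict):\n",
    "        flat: dict[PathParts, Any] = {}\n",
    "        for key, value in tree.items():\n",
    "            flat.update(flatten_with_paths(value, prefix + (key,)))\n",
    "        return flat\n",
    "    return {prefix: tree}\n",
    "\n",
    "\n",
    "params = {\n",
    "    'dense2': {'weight': jnp.eye(2)},\n",
    "    'dense1': {'weight': jnp.ones((2, 2)), 'bias': jnp.zeros(2)},\n",
    "}\n",
    "flat = flatten_with_paths(params)\n",
    "for path in sorted(flat):\n",
    "    print(path, flat[path].shape)\n"
   ],
   "id": "e5c0a3f8d1b79246"
  },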
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Summary\n",
    "\n",
    "BrainState's typing helpers build on standard Python typing to describe arrays,\n",
    "        PyTrees, dtypes, random keys, and structural filters. Applying them consistently\n",
    "        makes complex scientific code easier to navigate and verify.\n"
   ],
   "id": "2306a4978deb8c84"
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
