{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Unit-aware Computation with ``CustomArray``\n\n[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chaobrain/brainunit/blob/master/docs/advanced_tutorials/custom_array.ipynb)\n[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/chaobrain/brainunit/blob/master/docs/advanced_tutorials/custom_array.ipynb)\n\n## Introduction\n\nThe ``CustomArray`` class in brainunit is a base class for building array types that carry physical units and stay dimensionally consistent through a calculation. This tutorial shows how to subclass ``CustomArray`` so that units are handled automatically, making scientific code safer and easier to maintain.\n\n### What Is Unit-aware Computation?\n\nUnit-aware computation keeps physical quantities dimensionally correct across operations. Typical rules:\n- Adding meters to meters gives meters\n- Multiplying meters by meters gives square meters\n- Dividing a distance by a time gives a velocity\n- Invalid operations (e.g., meters + seconds) are detected and raise errors\n\n### Why Use CustomArray?\n\n- Dimensional safety: unit mismatches are caught at runtime instead of silently producing wrong numbers\n- Automatic unit propagation through operations\n- Works with NumPy and JAX arrays (and supports PyTorch-like methods)\n- Extensible: create domain-specific array types (physics, neuroscience, etc.)\n- Thin wrapper: operations are forwarded to the wrapped ``.data`` array",
   "id": "6b5e751aee5901e2"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:09.779098700Z",
     "start_time": "2026-03-04T15:10:07.852550800Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Imports\n",
    "import brainunit as u\n",
    "import brainstate\n",
    "\n",
    "print(\"brainunit version:\", getattr(u, '__version__', 'unknown'))\n",
    "print('Sample units:', 'm, s, Hz, V, A, kg, N, Pa, J')"
   ],
   "id": "2644f9d8a3d558c7",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "brainunit version: 0.2.0\n",
      "Sample units: m, s, Hz, V, A, kg, N, Pa, J\n"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## CustomArray Architecture\n",
    "\n",
    "``CustomArray`` is a base class. Any subclass that stores its array in a ``.data`` attribute gains NumPy-style operators and methods, and its instances can be passed directly to the unit-aware functions in ``brainunit.math``.\n",
    "\n",
    "Core requirements:\n",
    "1. Inherit from ``u.CustomArray``\n",
    "2. Store the underlying array in ``self.data`` (usually a ``Quantity``; a plain array also works for unitless data)\n",
    "\n",
    "Benefits:\n",
    "- Separation of concerns: you focus on data/state, ``CustomArray`` handles array ops\n",
    "- Unit propagation: math operations keep correct units\n",
    "- Backend flexibility: ``self.data`` can be NumPy, JAX, or other array-likes"
   ],
   "id": "b0619fe71606014"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:10.071614800Z",
     "start_time": "2026-03-04T15:10:09.780098100Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# A minimal, practical CustomArray\n",
    "class MyArray(u.CustomArray):\n",
    "    \"\"\"Minimal unit-aware array: just store a `.data`.\"\"\"\n",
    "    def __init__(self, data):\n",
    "        self.data = data  # typically a brainunit Quantity or plain array\n",
    "    def __repr__(self):\n",
    "        return f'MyArray({self.data})'\n",
    "\n",
    "# Create an instance with units\n",
    "length = MyArray([1, 2, 3] * u.meter)\n",
    "length, length.shape, getattr(length.data, 'unit', 'unitless')"
   ],
   "id": "6e32bcbf88e04c37",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(MyArray([1 2 3] m), (3,), Unit(\"m\"))"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 2
  },
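  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Conceptually, ``CustomArray`` forwards operators to the wrapped ``.data``. The pure-Python sketch below illustrates that delegation idea with a hypothetical ``DelegatingArray`` class. It is a simplified illustration, not brainunit's actual implementation: it forwards only a few operators, and plain floats stand in for quantities."
   ],
   "id": "d3f1a2b4c5e60001"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Hypothetical toy class illustrating delegation to `.data` (not brainunit's code)\n",
    "class DelegatingArray:\n",
    "    def __init__(self, data):\n",
    "        self.data = data\n",
    "\n",
    "    @staticmethod\n",
    "    def _unwrap(x):\n",
    "        # Use the wrapped payload if `x` is another DelegatingArray\n",
    "        return x.data if isinstance(x, DelegatingArray) else x\n",
    "\n",
    "    def __add__(self, other):\n",
    "        return self.data + self._unwrap(other)\n",
    "\n",
    "    def __radd__(self, other):\n",
    "        # Called for `other + self` when `other` does not know this type\n",
    "        return self._unwrap(other) + self.data\n",
    "\n",
    "    def __mul__(self, other):\n",
    "        return self.data * self._unwrap(other)\n",
    "\n",
    "    def __truediv__(self, other):\n",
    "        return self.data / self._unwrap(other)\n",
    "\n",
    "\n",
    "a = DelegatingArray(6.0)\n",
    "b = DelegatingArray(3.0)\n",
    "print(a + b, a * 2, a / b, 1 + a)  # 9.0 12.0 2.0 7.0"
   ],
   "outputs": [],
   "execution_count": null,
   "id": "d3f1a2b4c5e60002"
  },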
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Unit Propagation with Operators\n",
    "\n",
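    "Arithmetic operators on a ``CustomArray`` act on its ``.data``. When ``.data`` is a ``Quantity``, the results carry the correct units: addition and subtraction require compatible units (converting between scales such as cm and m), while multiplication and division combine units."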
   ],
   "id": "cab5d3d499d5743d"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:10.212397900Z",
     "start_time": "2026-03-04T15:10:10.094272100Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Compatible addition keeps units\n",
    "length_cm = MyArray([100, 200, 300] * u.cmeter)\n",
    "total_length = length + length_cm  # meters + centimeters -> meters\n",
    "print('total_length:', total_length)"
   ],
   "id": "f86a48be140280e",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total_length: [2. 4. 6.] m\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:10.253059900Z",
     "start_time": "2026-03-04T15:10:10.213452400Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Multiplication changes units (area)\n",
    "area = length * length  # m * m -> m^2\n",
    "print('area:', area)"
   ],
   "id": "e7d47dbb7d4f231b",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "area: [1 4 9] m^2\n"
     ]
    }
   ],
   "execution_count": 4
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:10.338118100Z",
     "start_time": "2026-03-04T15:10:10.253059900Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Division changes units (velocity)\n",
    "time = MyArray([1, 2, 3] * u.second)\n",
    "velocity = length / time  # m / s\n",
    "print('velocity:', velocity)"
   ],
   "id": "394639dab7897a0f",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "velocity: [1. 1. 1.] m / s\n"
     ]
    }
   ],
   "execution_count": 5
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:10.352501600Z",
     "start_time": "2026-03-04T15:10:10.339122900Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Incompatible addition raises an error\n",
    "try:\n",
    "    bad = length + time\n",
    "except Exception as e:\n",
    "    print('Expected error:', e)"
   ],
   "id": "48e1b9f368598e54",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Expected error: Cannot calculate \n",
      "[1 2 3] m + [1 2 3] s, because units do not match: m != s\n"
     ]
    }
   ],
   "execution_count": 6
  },
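  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "To report a mismatch before attempting an operation, you can compare physical dimensions explicitly. The sketch below reuses ``length``, ``length_cm`` and ``time`` from the cells above and calls ``u.get_dim`` on the wrapped quantities; the helper ``same_dimension`` is hypothetical, introduced only for this example."
   ],
   "id": "d3f1a2b4c5e60003"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "def same_dimension(a, b):\n",
    "    \"\"\"Hypothetical helper: True if two CustomArrays wrap quantities of the same dimension.\"\"\"\n",
    "    return u.get_dim(a.data) == u.get_dim(b.data)\n",
    "\n",
    "# meters vs centimeters: same dimension (length), only the scale differs\n",
    "print('length vs length_cm:', same_dimension(length, length_cm))\n",
    "# meters vs seconds: different dimensions, so addition would fail\n",
    "print('length vs time:', same_dimension(length, time))"
   ],
   "outputs": [],
   "execution_count": null,
   "id": "d3f1a2b4c5e60004"
  },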
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Using ``brainunit.math`` with CustomArray\n",
    "\n",
    "The ``brainunit.math`` module mirrors the NumPy/JAX APIs and is unit-aware. Its functions accept ``CustomArray`` instances directly: brainunit unwraps ``.data`` internally and returns results with the correct units.\n",
    "\n",
    "Categories (simplified):\n",
    "- Keep-unit functions (e.g., ``mean``, ``sum``, ``concatenate``, ``stack``) return the same unit\n",
    "- Change-unit functions (e.g., ``square``, ``sqrt``, ``multiply``, ``divide``, ``var``) transform units according to math rules\n",
    "- Some functions require unitless inputs (e.g., ``exp``, ``log``, ``sin``)"
   ],
   "id": "4fe375e2884c8e50"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:10.693638800Z",
     "start_time": "2026-03-04T15:10:10.353493Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Keep-unit examples\n",
    "print('mean(length):', u.math.mean(length))\n",
    "print('sum(length):', u.math.sum(length))\n",
    "\n",
    "# Change-unit examples\n",
    "print('square(length):', u.math.square(length))  # m^2\n",
    "print('sqrt(square(length)):', u.math.sqrt(u.math.square(length)))  # back to m\n",
    "print('var(length):', u.math.var(length))  # m^2\n",
    "\n",
    "# Broadcasting and stacking\n",
    "stacked = u.math.stack([length, length_cm])\n",
    "print('stacked shape:', getattr(stacked, 'shape', None))\n",
    "\n",
    "# Linear algebra with units\n",
    "force = MyArray([10, 20, 30] * u.newton)\n",
    "displacement = MyArray([0.5, 1.0, 1.5] * u.meter)\n",
    "work = u.math.dot(force, displacement)  # N·m -> J (joule)\n",
    "print('work (dot):', work)"
   ],
   "id": "fdc4078c0b87d410",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mean(length): 2. m\n",
      "sum(length): 6 m\n",
      "square(length): [1 4 9] m^2\n",
      "sqrt(square(length)): [1. 2. 3.] m\n",
      "var(length): 0.6666667 m^2\n",
      "stacked shape: (2, 3)\n",
      "work (dot): 70. J\n"
     ]
    }
   ],
   "execution_count": 7
  },
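  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Functions such as ``u.math.exp`` are only defined for dimensionless arguments. A common pattern is to divide by a reference quantity first so that the argument becomes a pure number. The sketch below reuses ``length`` from above; the reference length ``scale`` is an arbitrary choice for illustration, and since the exact exception type for the dimensional call is not assumed here, a generic ``Exception`` is caught."
   ],
   "id": "d3f1a2b4c5e60005"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# exp of a length has no physical meaning, so a dimensional argument is rejected\n",
    "try:\n",
    "    u.math.exp(length)\n",
    "except Exception as e:\n",
    "    print('Expected error:', type(e).__name__)\n",
    "\n",
    "# Dividing by a reference length gives a dimensionless ratio that exp accepts\n",
    "scale = 1.0 * u.meter  # arbitrary reference length for this example\n",
    "ratio = length.data / scale\n",
    "print('exp(length / scale):', u.math.exp(ratio))"
   ],
   "outputs": [],
   "execution_count": null,
   "id": "d3f1a2b4c5e60006"
  },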
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Converting Units for Display or Interop\n\nUse ``Quantity.to_decimal(target_unit)`` to get the numeric values expressed in ``target_unit`` as a plain, unitless array, for example for display, logging, or plotting.",
   "id": "ecf7702616b75edc"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:10.715065900Z",
     "start_time": "2026-03-04T15:10:10.698650100Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Convert quantity values inside your CustomArray for display\n",
    "meters = MyArray([1, 2, 3] * u.meter)\n",
    "print('as meters:', meters.data.to_decimal(u.meter))\n",
    "print('as centimeters:', meters.data.to_decimal(u.cmeter))"
   ],
   "id": "f343d39d5f66e4ca",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "as meters: [1 2 3]\n",
      "as centimeters: [100. 200. 300.]\n"
     ]
    }
   ],
   "execution_count": 8
  },
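  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "``to_decimal`` drops the unit after rescaling. To keep a ``Quantity`` but express it in another unit, ``Quantity.in_unit`` can be used instead. Asking for a unit of a different dimension fails, which catches mistakes such as reading a length as a time. A short sketch reusing ``meters`` from the cell above; the exact exception type is not assumed, so a generic ``Exception`` is caught."
   ],
   "id": "d3f1a2b4c5e60007"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Keep the unit attached but express the values in centimeters\n",
    "print('in_unit(cm):', meters.data.in_unit(u.cmeter))\n",
    "\n",
    "# Requesting a unit with a different dimension raises an error\n",
    "try:\n",
    "    meters.data.to_decimal(u.second)\n",
    "except Exception as e:\n",
    "    print('Expected error:', type(e).__name__)"
   ],
   "outputs": [],
   "execution_count": null,
   "id": "d3f1a2b4c5e60008"
  },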
  {
   "metadata": {},
   "cell_type": "markdown",
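   "source": "## Stateful and Learnable Arrays (BrainState)\n\nFor stateful workflows, combine ``brainstate.State`` with ``CustomArray`` to create unit-aware state variables whose ``data`` property reads the current state value. For trainable parameters, ``brainstate.ParamState`` (a subclass of ``State``) can serve as the base class instead.",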
   "id": "7adf443c420db922"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:10.744666500Z",
     "start_time": "2026-03-04T15:10:10.716066300Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class StatefulArray(brainstate.State, u.CustomArray):\n",
    "    @property\n",
    "    def data(self):\n",
    "        return self.value\n",
    "\n",
    "# Example: a unit-aware state value (the data property exposes State.value)\n",
    "param = StatefulArray(0.1 * u.second)\n",
    "print('stateful param:', param)"
   ],
   "id": "3c7638daebffaee1",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "stateful param: 0.1 s\n"
     ]
    }
   ],
   "execution_count": 9
  },
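  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "When the value should be trained, the same pattern can be applied to ``brainstate.ParamState``, the ``State`` subclass BrainState uses for parameters. This is a minimal sketch, assuming (as with ``StatefulArray`` above) that the inherited constructor takes the initial value; ``ParamArray`` and ``tau`` are hypothetical names. Assigning a new quantity of the same shape to ``.value`` keeps the unit attached."
   ],
   "id": "d3f1a2b4c5e60009"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "class ParamArray(brainstate.ParamState, u.CustomArray):\n",
    "    \"\"\"Hypothetical trainable, unit-aware parameter.\"\"\"\n",
    "    @property\n",
    "    def data(self):\n",
    "        # CustomArray operations read the current parameter value\n",
    "        return self.value\n",
    "\n",
    "tau = ParamArray(10.0 * u.ms)\n",
    "print('initial tau:', tau)\n",
    "\n",
    "# Assign a new quantity of the same shape; the unit travels with the value\n",
    "tau.value = tau.value * 2\n",
    "print('updated tau:', tau)\n",
    "print('tau in seconds:', tau.data.to_decimal(u.second))"
   ],
   "outputs": [],
   "execution_count": null,
   "id": "d3f1a2b4c5e6000a"
  },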
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Robust Patterns and Error Handling\n\nTips:\n- Document expected units for each array (e.g., meters for length)\n- Validate inputs when building domain-specific types\n- Catch and surface unit mismatch errors with clear messages\n- Prefer ``brainunit.math`` over raw NumPy for unit-aware operations",
   "id": "2fee55616276d420"
  },
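  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "As an example of the validation tip, here is a sketch of a domain-specific type that rejects inputs with the wrong dimension at construction time. ``MembranePotential`` and its ``is_above`` method are hypothetical names chosen for illustration; the check compares ``u.get_dim`` of the input with that of ``u.volt``, and the comparison method relies on quantities of the same dimension being comparable element-wise."
   ],
   "id": "d3f1a2b4c5e6000b"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import brainunit as u\n",
    "\n",
    "class MembranePotential(u.CustomArray):\n",
    "    \"\"\"Hypothetical domain type that only accepts voltage quantities.\"\"\"\n",
    "    def __init__(self, data):\n",
    "        # Validate the physical dimension up front and fail with a clear message\n",
    "        if u.get_dim(data) != u.get_dim(u.volt):\n",
    "            raise ValueError(f'MembranePotential expects a voltage, got {data}')\n",
    "        self.data = data\n",
    "\n",
    "    def is_above(self, threshold):\n",
    "        # Element-wise comparison; the threshold must also be a voltage\n",
    "        return self.data > threshold\n",
    "\n",
    "\n",
    "v = MembranePotential([-70.0, -50.0, 10.0] * u.mvolt)\n",
    "print('above -55 mV:', v.is_above(-55.0 * u.mvolt))\n",
    "\n",
    "# A time quantity is rejected at construction\n",
    "try:\n",
    "    MembranePotential([1.0, 2.0] * u.second)\n",
    "except ValueError as e:\n",
    "    print('Expected error:', e)"
   ],
   "outputs": [],
   "execution_count": null,
   "id": "d3f1a2b4c5e6000c"
  },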
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Summary\n",
    "\n",
    "- Inherit from ``u.CustomArray`` and set ``self.data`` (often a ``Quantity``)\n",
    "- Use operators and ``brainunit.math`` to get automatic unit propagation\n",
    "- Convert units with ``Quantity.to_decimal`` for display or interop\n",
    "- Combine with BrainState to build stateful, unit-aware components\n",
    "\n",
    "With these patterns, you can build reliable, unit-safe computational workflows across NumPy and JAX backends."
   ],
   "id": "afe05227e45698b8"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
