{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9df30c963df1",
   "metadata": {},
   "source": [
    "# Dask backend\n",
    "\n",
    "[Dask](https://docs.dask.org/) provides parallel, out-of-core arrays.\n",
    "brainunit accepts a `dask.array.Array` as a quantity's mantissa and keeps operations lazy:\n",
    "building a quantity, arithmetic, and most `brainunit.math` / `brainunit.linalg`\n",
    "operations only extend the dask task graph and do not trigger computation. Use this\n",
    "backend for arrays that do not fit in memory, or for embarrassingly parallel array work on a cluster.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ae7be41e9c5",
   "metadata": {},
   "source": [
    "## Installation\n",
    "\n",
    "```bash\n",
    "pip install brainunit[dask]\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "774379ee8b31",
   "metadata": {},
   "source": [
    "## Graceful import\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "d879ec439bff",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:11:45.295993200Z",
     "start_time": "2026-05-22T03:11:42.363242400Z"
    }
   },
   "source": [
    "import brainunit as u\n",
    "\n",
    "try:\n",
    "    import dask.array as da\n",
    "    HAVE_DASK = True\n",
    "except ImportError:\n",
    "    HAVE_DASK = False\n",
    "    print('dask not installed; install with: pip install brainunit[dask]')\n"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "id": "f18e931fc71b",
   "metadata": {},
   "source": [
    "## Quick start — lazy by default\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "755149f1b0f7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:11:47.956000300Z",
     "start_time": "2026-05-22T03:11:45.656910500Z"
    }
   },
   "source": [
    "if HAVE_DASK:\n",
    "    import numpy as np\n",
    "    big = da.from_array(np.arange(1_000_000.0), chunks=100_000)\n",
    "    q = u.Quantity(big, unit=u.meter)\n",
    "    print('backend =', q.backend)\n",
    "    print('shape   =', q.shape)        # no compute\n",
    "    print('lazy add:', (q + q).backend)\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "backend = dask\n",
      "shape   = (1000000,)\n",
      "lazy add: dask\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "cell_type": "markdown",
   "id": "4fdd9d63653c",
   "metadata": {},
   "source": [
    "## What requires compute\n",
    "\n",
    "Operations that need concrete values, such as `float(q)`, `int(q)`, `q.tolist()`,\n",
    "`np.asarray(q)`, `hash(q)`, and `operator.index(q)`, raise `BackendError` instead of\n",
    "computing implicitly. Call `q.mantissa.compute()` first, then build an eager quantity from the result.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "f90e1701197a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:11:47.993161100Z",
     "start_time": "2026-05-22T03:11:47.965287500Z"
    }
   },
   "source": [
    "if HAVE_DASK:\n",
    "    import numpy as np\n",
    "    from brainunit import BackendError\n",
    "\n",
    "    single = u.Quantity(da.from_array(np.array([42.0]), chunks=1), unit=u.meter)\n",
    "    try:\n",
    "        float(single)\n",
    "    except BackendError as exc:\n",
    "        print('expected:', exc)\n",
    "\n",
    "    # materialize first\n",
    "    eager_mantissa = single.mantissa.compute()\n",
    "    print('after compute:', u.Quantity(eager_mantissa, unit=u.meter) / u.meter)\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "expected: float(Quantity) would materialize a dask-backed Quantity. Call `q.mantissa.compute()` first.\n",
      "after compute: [42.]\n"
     ]
    }
   ],
   "execution_count": 3
  },
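  {
   "cell_type": "markdown",
   "id": "3b8e0c2d71a4",
   "metadata": {},
   "source": [
    "Because only the final step needs concrete values, a useful pattern for large arrays is to reduce lazily and compute just the small result.\n",
    "The sketch below assumes `q.mantissa` is a plain `dask.array.Array` (as the `.compute()` call above suggests) and uses dask's own\n",
    "`dask.compute` to evaluate several reductions together, so the shared input chunks are read once rather than once per reduction.\n",
    "The unit is reattached by hand after computing.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "5f1a9d6e28c0",
   "metadata": {},
   "source": [
    "import numpy as np\n",
    "import brainunit as u\n",
    "\n",
    "try:\n",
    "    import dask\n",
    "    import dask.array as da\n",
    "    HAVE_DASK = True\n",
    "except ImportError:\n",
    "    HAVE_DASK = False\n",
    "\n",
    "if HAVE_DASK:\n",
    "    # A lazy, chunked quantity: nothing is computed yet.\n",
    "    q = u.Quantity(da.from_array(np.arange(1_000_000.0), chunks=100_000), unit=u.meter)\n",
    "    m = q.mantissa  # the underlying dask.array.Array (assumption, see above)\n",
    "\n",
    "    # Build three lazy reductions, then evaluate them together in one dask.compute call.\n",
    "    lo, hi, total = dask.compute(m.min(), m.max(), m.sum())\n",
    "\n",
    "    # Only three small scalars were materialized; reattach the unit explicitly.\n",
    "    print('min   =', u.Quantity(lo, unit=u.meter))\n",
    "    print('max   =', u.Quantity(hi, unit=u.meter))\n",
    "    print('total =', u.Quantity(total, unit=u.meter))\n"
   ],
   "outputs": [],
   "execution_count": null
  },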
  {
   "cell_type": "markdown",
   "id": "dc4534d44d02",
   "metadata": {},
   "source": [
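    "## Conversion\n",
    "\n",
    "`Quantity.to_dask(chunks=...)` converts an eager quantity into a dask-backed one, splitting the mantissa into chunks of the given size.\n"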
   ]
  },
  {
   "cell_type": "code",
   "id": "d67a8a7ada94",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:11:48.032728600Z",
     "start_time": "2026-05-22T03:11:48.011354200Z"
    }
   },
   "source": [
    "if HAVE_DASK:\n",
    "    import numpy as np\n",
    "    q_np = u.Quantity(np.arange(1_000_000.0), unit=u.meter)\n",
    "    q_da = q_np.to_dask(chunks=100_000)\n",
    "    print(q_da.backend, q_da.mantissa.chunks)\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dask ((100000, 100000, 100000, 100000, 100000, 100000, 100000, 100000, 100000, 100000),)\n"
     ]
    }
   ],
   "execution_count": 4
  },
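  {
   "cell_type": "markdown",
   "id": "8c2f47a1e6b3",
   "metadata": {},
   "source": [
    "The reverse direction has no dedicated call shown in this notebook; a minimal sketch is to compute the mantissa and rebuild the quantity\n",
    "with the original unit. It assumes the unit is exposed as `q_da.unit`, and it materializes the whole array, so it is only appropriate\n",
    "when the data fits in memory.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "a94d0e7f15c2",
   "metadata": {},
   "source": [
    "import numpy as np\n",
    "import brainunit as u\n",
    "\n",
    "try:\n",
    "    import dask.array  # only checks that dask is installed\n",
    "    HAVE_DASK = True\n",
    "except ImportError:\n",
    "    HAVE_DASK = False\n",
    "\n",
    "if HAVE_DASK:\n",
    "    # Eager -> dask: 10 elements split into two chunks of 5.\n",
    "    q_da = u.Quantity(np.arange(10.0), unit=u.meter).to_dask(chunks=5)\n",
    "    print('lazy :', q_da.backend, q_da.mantissa.chunks)\n",
    "\n",
    "    # Dask -> eager: compute() pulls every chunk into memory.\n",
    "    eager = q_da.mantissa.compute()\n",
    "    q_back = u.Quantity(eager, unit=q_da.unit)  # assumes the unit is exposed as `.unit`\n",
    "    print('eager:', q_back.backend, q_back.shape)\n"
   ],
   "outputs": [],
   "execution_count": null
  },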
  {
   "cell_type": "markdown",
   "id": "0547af9edb26",
   "metadata": {},
   "source": [
    "## Mixed-backend arithmetic\n",
    "\n",
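    "When a dask-backed and a non-dask quantity meet in arithmetic, the default backend\n",
    "acts as the tiebreaker for the result. If that resolves to dask (below, via\n",
    "`u.using_backend('dask')`), the non-dask operand is lifted to dask automatically and the result stays lazy.\n"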
   ]
  },
  {
   "cell_type": "code",
   "id": "021569e42dcc",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:11:48.084716800Z",
     "start_time": "2026-05-22T03:11:48.051726800Z"
    }
   },
   "source": [
    "if HAVE_DASK:\n",
    "    import numpy as np\n",
    "    q_da = u.Quantity(da.from_array(np.array([1.0, 2.0]), chunks=1), unit=u.meter)\n",
    "    q_np = u.Quantity(np.array([3.0, 4.0]), unit=u.meter)\n",
    "    with u.using_backend('dask'):\n",
    "        result = q_da + q_np\n",
    "        print(result.backend)            # 'dask'\n",
    "        print(result.mantissa.compute())\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dask\n",
      "[4. 6.]\n"
     ]
    }
   ],
   "execution_count": 5
  },
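  {
   "cell_type": "markdown",
   "id": "c06b3e9a4d71",
   "metadata": {},
   "source": [
    "To avoid depending on the default backend at all, lift the eager operand explicitly with `to_dask` before mixing.\n",
    "A minimal sketch using only calls shown earlier in this notebook; `chunks=1` mirrors the chunking of `q_da` and is an arbitrary choice here.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "e71f5a2c8b90",
   "metadata": {},
   "source": [
    "import numpy as np\n",
    "import brainunit as u\n",
    "\n",
    "try:\n",
    "    import dask.array as da\n",
    "    HAVE_DASK = True\n",
    "except ImportError:\n",
    "    HAVE_DASK = False\n",
    "\n",
    "if HAVE_DASK:\n",
    "    q_da = u.Quantity(da.from_array(np.array([1.0, 2.0]), chunks=1), unit=u.meter)\n",
    "    q_np = u.Quantity(np.array([3.0, 4.0]), unit=u.meter)\n",
    "\n",
    "    # Both operands are dask-backed after to_dask, so no tiebreaker is involved.\n",
    "    result = q_da + q_np.to_dask(chunks=1)\n",
    "    print(result.backend)\n",
    "    print(result.mantissa.compute())\n"
   ],
   "outputs": [],
   "execution_count": null
  },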
  {
   "cell_type": "markdown",
   "id": "96f47fd8defb",
   "metadata": {},
   "source": [
    "## Limitations\n",
    "\n",
    "- `brainunit.autograd`, `brainunit.lax`, and `brainunit.sparse` are JAX-only and do not work with dask-backed quantities.\n",
    "- Operations that need a concrete value require an explicit `q.mantissa.compute()`.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
