{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Activation Functions\n",
    "\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chaobrain/brainunit/blob/master/docs/mathematical_functions/activation_functions.ipynb)\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/chaobrain/brainunit/blob/master/docs/mathematical_functions/activation_functions.ipynb)\n",
    "\n",
    "`brainunit.math` provides 20+ unit-aware activation functions for neural networks.\n",
    "These functions fall into two categories:\n",
    "\n",
    "- **Piecewise-linear** (keep unit): `relu`, `leaky_relu`. These work directly on Quantity inputs and return a Quantity in the same unit.\n",
    "- **Nonlinear** (require unitless): `sigmoid`, `gelu`, `softplus`, `tanh`, etc. These accept only dimensionless input, but take an optional `unit_to_scale` argument that converts a Quantity to dimensionless numbers automatically."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Activations That Keep Unit\n",
    "\n",
    "Piecewise-linear activations preserve the input unit because they only compare the input\n",
    "with zero and multiply it by a dimensionless slope; no transcendental functions are involved."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `relu` — Rectified Linear Unit"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "x = jnp.array([-2., -1., 0., 1., 2.]) * u.volt\n",
    "\n",
    "print('relu:', u.math.relu(x))  # [0, 0, 0, 1, 2] V"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `leaky_relu` — Leaky ReLU"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Negative inputs are scaled by a small dimensionless slope (0.01 by default) instead of\n",
    "# being zeroed, so the output stays in volts\n",
    "print('leaky_relu:', u.math.leaky_relu(x))"
   ],
   "outputs": [],
   "execution_count": null
  },
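  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Converting between units multiplies the numbers by a positive constant (1 V = 1000 mV), and\n",
    "`relu` and `leaky_relu` commute with positive scaling: `f(c * x) = c * f(x)` for `c > 0`.\n",
    "The sketch below checks this by feeding the same voltages in V and in mV. It assumes\n",
    "`Quantity.to_decimal(unit)` returns the plain numbers expressed in the requested unit."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# The same five voltages, written once in volts and once in millivolts\n",
    "x_V = jnp.array([-2., -1., 0., 1., 2.]) * u.volt\n",
    "x_mV = jnp.array([-2000., -1000., 0., 1000., 2000.]) * u.mV\n",
    "\n",
    "results = {}\n",
    "for name, fn in [('relu', u.math.relu), ('leaky_relu', u.math.leaky_relu)]:\n",
    "    # Read both outputs in mV; positive homogeneity means they should agree\n",
    "    out_V = fn(x_V).to_decimal(u.mV)\n",
    "    out_mV = fn(x_mV).to_decimal(u.mV)\n",
    "    results[name] = (out_V, out_mV)\n",
    "    print(name, out_V, 'matches:', bool(jnp.allclose(out_V, out_mV)))"
   ],
   "outputs": [],
   "execution_count": null
  },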
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Activations Requiring Dimensionless Input\n",
    "\n",
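    "Most activation functions (sigmoid, tanh, gelu, etc.) involve exponentials or other\n",
    "transcendental functions, which are only defined for dimensionless arguments: the series\n",
    "`exp(x) = 1 + x + x**2/2 + ...` would otherwise add terms with different units (1, V, V², ...).\n",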
    "\n",
    "There are two ways to use them:\n",
    "\n",
    "1. **Dimensionless input**: Pass plain arrays or convert to dimensionless first\n",
    "2. **`unit_to_scale` parameter**: Automatically converts the Quantity to dimensionless using the given unit"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `sigmoid`"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Method 1: Dimensionless input\n",
    "x_unitless = jnp.array([-2., -1., 0., 1., 2.])\n",
    "print('sigmoid (unitless):', u.math.sigmoid(x_unitless))"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Method 2: unit_to_scale — auto-converts Quantity to dimensionless\n",
    "x_volts = jnp.array([-2., -1., 0., 1., 2.]) * u.volt\n",
    "print('sigmoid (with unit_to_scale):', u.math.sigmoid(x_volts, unit_to_scale=u.volt))"
   ],
   "outputs": [],
   "execution_count": null
  },
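  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`unit_to_scale` can be read as shorthand for converting the Quantity to plain numbers in the\n",
    "given unit and then applying the function. A minimal check of that reading, assuming\n",
    "`Quantity.to_decimal(unit)` performs the conversion:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "x_volts = jnp.array([-2., -1., 0., 1., 2.]) * u.volt\n",
    "\n",
    "# Method 1: strip the unit by hand, expressing x as a number of volts\n",
    "manual = u.math.sigmoid(x_volts.to_decimal(u.volt))\n",
    "# Method 2: let unit_to_scale do the same conversion\n",
    "auto = u.math.sigmoid(x_volts, unit_to_scale=u.volt)\n",
    "print('manual:', manual)\n",
    "print('auto:  ', auto)\n",
    "print('equal: ', bool(jnp.allclose(manual, auto)))"
   ],
   "outputs": [],
   "execution_count": null
  },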
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Without unit_to_scale, passing a Quantity raises an error\n",
    "try:\n",
    "    u.math.sigmoid(x_volts)\n",
    "except Exception as e:\n",
    "    print(type(e).__name__, ':', str(e)[:120])"
   ],
   "outputs": [],
   "execution_count": null
  },
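  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The unit passed to `unit_to_scale` is more than a formality: it fixes the numbers the\n",
    "nonlinearity actually sees. Read in mV, the same voltages become 1000 times larger, so\n",
    "`sigmoid` saturates almost everywhere:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "x_volts = jnp.array([-2., -1., 0., 1., 2.]) * u.volt\n",
    "\n",
    "s_V = u.math.sigmoid(x_volts, unit_to_scale=u.volt)   # sees -2 ... 2\n",
    "s_mV = u.math.sigmoid(x_volts, unit_to_scale=u.mV)    # sees -2000 ... 2000\n",
    "print('scaled by V: ', s_V)\n",
    "print('scaled by mV:', s_mV)"
   ],
   "outputs": [],
   "execution_count": null
  },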
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `tanh`"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "print('tanh (unitless):', u.math.tanh(x_unitless))\n",
    "print('tanh (unit_to_scale):', u.math.tanh(x_volts, unit_to_scale=u.volt))"
   ],
   "outputs": [],
   "execution_count": null
  },
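  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`tanh` is a rescaled sigmoid, `tanh(x) = 2 * sigmoid(2 * x) - 1`, which is why the two curves\n",
    "have the same shape on different ranges, (-1, 1) versus (0, 1). A quick numerical check on\n",
    "dimensionless input:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "x_unitless = jnp.array([-2., -1., 0., 1., 2.])\n",
    "\n",
    "lhs = u.math.tanh(x_unitless)\n",
    "rhs = 2 * u.math.sigmoid(2 * x_unitless) - 1  # rescaled and shifted sigmoid\n",
    "print('tanh(x):             ', lhs)\n",
    "print('2*sigmoid(2x) - 1:   ', rhs)"
   ],
   "outputs": [],
   "execution_count": null
  },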
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `gelu` — Gaussian Error Linear Unit"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "print('gelu:', u.math.gelu(x_unitless))\n",
    "print('gelu (unit_to_scale):', u.math.gelu(x_volts, unit_to_scale=u.volt))"
   ],
   "outputs": [],
   "execution_count": null
  },
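  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "GELU weights its input by the standard normal CDF: `gelu(x) = x * Φ(x) = 0.5 * x * (1 + erf(x / sqrt(2)))`.\n",
    "`jax.nn.gelu` uses a tanh approximation by default; which form `u.math.gelu` uses by default is\n",
    "not stated here, so this sketch only checks that it stays within 1e-3 of the exact erf form on\n",
    "this input range."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "from jax.scipy.special import erf\n",
    "\n",
    "x_unitless = jnp.array([-2., -1., 0., 1., 2.])\n",
    "\n",
    "# Exact GELU: x times the standard normal CDF\n",
    "gelu_exact = 0.5 * x_unitless * (1.0 + erf(x_unitless / jnp.sqrt(2.0)))\n",
    "gelu_bu = u.math.gelu(x_unitless)\n",
    "print('exact erf form:', gelu_exact)\n",
    "print('u.math.gelu:   ', gelu_bu)\n",
    "print('max abs diff:  ', float(jnp.max(jnp.abs(gelu_exact - gelu_bu))))"
   ],
   "outputs": [],
   "execution_count": null
  },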
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `softplus`"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "print('softplus:', u.math.softplus(x_unitless))\n",
    "print('softplus (unit_to_scale):', u.math.softplus(x_volts, unit_to_scale=u.volt))"
   ],
   "outputs": [],
   "execution_count": null
  },
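  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`softplus(x) = log(1 + exp(x))` is a smooth version of `relu`, and `log_sigmoid(x) = -softplus(-x)`.\n",
    "Both identities can be checked on dimensionless input:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "x_unitless = jnp.array([-2., -1., 0., 1., 2.])\n",
    "\n",
    "sp = u.math.softplus(x_unitless)\n",
    "sp_ref = jnp.log1p(jnp.exp(x_unitless))  # log(1 + exp(x))\n",
    "ls = u.math.log_sigmoid(x_unitless)\n",
    "ls_ref = -u.math.softplus(-x_unitless)   # -softplus(-x)\n",
    "print('softplus:   ', sp, 'reference:', sp_ref)\n",
    "print('log_sigmoid:', ls, 'reference:', ls_ref)"
   ],
   "outputs": [],
   "execution_count": null
  },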
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `silu` / `swish` — Sigmoid Linear Unit"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "print('silu:', u.math.silu(x_unitless))\n",
    "print('silu (unit_to_scale):', u.math.silu(x_volts, unit_to_scale=u.volt))"
   ],
   "outputs": [],
   "execution_count": null
  },
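  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SiLU multiplies each input by its own sigmoid, `silu(x) = x * sigmoid(x)`. The sketch compares\n",
    "that formula with `u.math.silu` on plain numbers and on volts scaled with `unit_to_scale`:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "x_unitless = jnp.array([-2., -1., 0., 1., 2.])\n",
    "x_volts = x_unitless * u.volt\n",
    "\n",
    "silu_ref = x_unitless * u.math.sigmoid(x_unitless)  # x * sigmoid(x)\n",
    "silu_plain = u.math.silu(x_unitless)\n",
    "silu_scaled = u.math.silu(x_volts, unit_to_scale=u.volt)\n",
    "print('x * sigmoid(x):', silu_ref)\n",
    "print('u.math.silu:   ', silu_plain)\n",
    "print('with scaling:  ', silu_scaled)"
   ],
   "outputs": [],
   "execution_count": null
  },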
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `elu` — Exponential Linear Unit"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "print('elu:', u.math.elu(x_unitless))\n",
    "print('elu (unit_to_scale):', u.math.elu(x_volts, unit_to_scale=u.volt))"
   ],
   "outputs": [],
   "execution_count": null
  },
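  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ELU is the identity for `x > 0` and `alpha * (exp(x) - 1)` for `x <= 0`, so negative outputs\n",
    "approach `-alpha`. A reference version, assuming the default `alpha = 1.0` used by `jax.nn.elu`:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "x_unitless = jnp.array([-2., -1., 0., 1., 2.])\n",
    "alpha = 1.0  # assumed default, matching jax.nn.elu\n",
    "\n",
    "# Piecewise definition: identity above zero, saturating exponential below\n",
    "elu_ref = jnp.where(x_unitless > 0, x_unitless, alpha * (jnp.exp(x_unitless) - 1))\n",
    "elu_bu = u.math.elu(x_unitless)\n",
    "print('reference:  ', elu_ref)\n",
    "print('u.math.elu: ', elu_bu)"
   ],
   "outputs": [],
   "execution_count": null
  },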
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### More Activations"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Shown on dimensionless input; each of these also accepts a Quantity together with unit_to_scale\n",
    "print('celu:', u.math.celu(x_unitless))\n",
    "print('selu:', u.math.selu(x_unitless))\n",
    "print('mish:', u.math.mish(x_unitless))\n",
    "print('hard_sigmoid:', u.math.hard_sigmoid(x_unitless))\n",
    "print('hard_tanh:', u.math.hard_tanh(x_unitless))\n",
    "print('squareplus:', u.math.squareplus(x_unitless))\n",
    "print('soft_sign:', u.math.soft_sign(x_unitless))\n",
    "print('log_sigmoid:', u.math.log_sigmoid(x_unitless))"
   ],
   "outputs": [],
   "execution_count": null
  },
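  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Following the summary table at the end of this notebook, several of these functions can also be\n",
    "given a Quantity plus `unit_to_scale`. This sketch applies a few of them to plain numbers and to\n",
    "the same values in volts, and checks that the results agree:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "x_unitless = jnp.array([-2., -1., 0., 1., 2.])\n",
    "x_volts = x_unitless * u.volt\n",
    "\n",
    "matches = {}\n",
    "for name, fn in [('celu', u.math.celu), ('selu', u.math.selu),\n",
    "                 ('mish', u.math.mish), ('log_sigmoid', u.math.log_sigmoid)]:\n",
    "    plain = fn(x_unitless)                        # dimensionless input\n",
    "    scaled = fn(x_volts, unit_to_scale=u.volt)    # Quantity read in volts\n",
    "    matches[name] = bool(jnp.allclose(plain, scaled))\n",
    "    print(name, 'matches:', matches[name])"
   ],
   "outputs": [],
   "execution_count": null
  },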
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Practical Example: Activations on Membrane Potentials\n",
    "\n",
    "In scientific models, activations are often applied directly to physical quantities.\n",
    "The example below passes membrane potentials in mV through a unit-preserving activation\n",
    "(`relu`) and a dimensionless one (`sigmoid` with `unit_to_scale`)."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Simulate a simple neural activation\n",
    "# Membrane potentials in millivolts\n",
    "V_membrane = jnp.array([-80., -65., -50., -30., 0., 20.]) * u.mV\n",
    "\n",
    "# ReLU keeps the unit: the output is a rectified voltage in mV, not a rate in Hz\n",
    "# Only positive membrane potentials pass through; all others become 0 mV\n",
    "V_rectified = u.math.relu(V_membrane)\n",
    "print('ReLU output:', V_rectified)\n",
    "\n",
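    "# Sigmoid-based firing probability (dimensionless, between 0 and 1)\n",
    "# unit_to_scale=u.mV reads each value as a number of millivolts, so the curve is\n",
    "# centered at 0 mV with a 1 mV slope and saturates for most of these inputs\n",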
    "firing_prob = u.math.sigmoid(V_membrane, unit_to_scale=u.mV)\n",
    "print('Sigmoid probability:', firing_prob)"
   ],
   "outputs": [],
   "execution_count": null
  },
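  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With a 1 mV scale the sigmoid above is nearly a step function. A common alternative is a\n",
    "Boltzmann-style curve `sigmoid((V - V_half) / k)`, where the half-activation voltage `V_half`\n",
    "and slope `k` both carry units, so their ratio is dimensionless. The values `V_half = -50 mV`\n",
    "and `k = 5 mV` below are hypothetical, chosen only for illustration, and the sketch assumes\n",
    "`Quantity.to_decimal(unit)` returns plain numbers in the requested unit."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "V_membrane = jnp.array([-80., -65., -50., -30., 0., 20.]) * u.mV\n",
    "\n",
    "# Hypothetical parameters: half-activation voltage and slope\n",
    "V_half = -50. * u.mV\n",
    "k = 5. * u.mV\n",
    "\n",
    "# (V - V_half) / k is dimensionless; express both in mV before dividing\n",
    "z = (V_membrane - V_half).to_decimal(u.mV) / k.to_decimal(u.mV)\n",
    "p_fire = u.math.sigmoid(z)\n",
    "print('Boltzmann probability:', p_fire)"
   ],
   "outputs": [],
   "execution_count": null
  },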
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "| Category | Functions | Accepts a Quantity directly? | Accepts a Quantity with `unit_to_scale`? |\n",
    "|----------|-----------|------------------------------|------------------------------------------|\n",
    "| **Keep unit** | `relu`, `leaky_relu` | Yes (output keeps the input unit) | N/A |\n",
    "| **Require unitless** | `sigmoid`, `tanh`, `gelu`, `softplus`, `silu`, `elu`, `celu`, `selu`, `mish`, `hard_sigmoid`, `hard_tanh`, `soft_sign`, `squareplus`, `log_sigmoid`, `relu6` | No (dimensionless input only) | Yes (output is dimensionless) |"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
