{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fourier Transform Functions\n",
    "\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chaobrain/brainunit/blob/master/docs/mathematical_functions/fft_functions.ipynb)\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/chaobrain/brainunit/blob/master/docs/mathematical_functions/fft_functions.ipynb)\n",
    "\n",
    "`brainunit.fft` provides unit-aware Fast Fourier Transform functions.\n",
    "The forward transforms change the unit because the discrete sum approximates an integral over the transformed axis, so the result gains a factor of the sample-spacing unit (seconds in this notebook). The functions fall into two groups:\n",
    "\n",
    "- **Changing unit**: `fft`, `ifft`, `rfft`, `irfft`, `fft2`, `ifft2`, `fftn`, `ifftn`, `rfft2`, `irfft2`, `rfftn`, `irfftn`, `fftfreq`, `rfftfreq`\n",
    "- **Keeping unit**: `fftshift`, `ifftshift`"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:42.023903800Z",
     "start_time": "2026-03-04T15:10:41.020670600Z"
    }
   },
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp"
   ],
   "outputs": [],
   "execution_count": 1
  },
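  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1D FFT: `fft` and `ifft`\n",
    "\n",
    "The forward FFT of a signal with unit `u` produces a spectrum with unit `u * s`, that is, the signal unit multiplied by the sample-spacing unit (here `u` stands for the signal's unit, not the `brainunit` alias). For a voltage signal this is `V * s`, which is displayed as `Wb` (weber).\n",
    "The inverse FFT reverses the operation and recovers the original unit."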
   ]
  },
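  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A short sketch of where the extra factor of `s` comes from, using the standard textbook link between the continuous Fourier transform and the DFT. This is background reasoning, not a description of how `brainunit` is implemented; $\\Delta t$ (sample spacing) and $f_k$ (frequency of bin $k$) are notation introduced here. For samples $x_n = x(n\\,\\Delta t)$:\n",
    "\n",
    "$$\n",
    "X(f_k) = \\int x(t)\\, e^{-2\\pi i f_k t}\\, dt \\;\\approx\\; \\Delta t \\sum_{n=0}^{N-1} x_n\\, e^{-2\\pi i k n / N}, \\qquad f_k = \\frac{k}{N\\,\\Delta t}\n",
    "$$\n",
    "\n",
    "The sum carries the unit of $x_n$ and $\\Delta t$ carries seconds, so the product has unit `u * s`. In the outputs of this notebook the numerical values equal the plain sum (the zero-frequency bin in the example that follows is 16, the sum of the samples), so the time factor shows up only in the unit."
   ]
  },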
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:42.311550300Z",
     "start_time": "2026-03-04T15:10:42.055382100Z"
    }
   },
   "source": [
    "# A simple voltage signal\n",
    "signal = jnp.array([1., 2., 3., 4., 3., 2., 1., 0.]) * u.volt\n",
    "print('Signal:', signal)\n",
    "print('Signal unit:', signal.unit)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Signal: [1. 2. 3. 4. 3. 2. 1. 0.] V\n",
      "Signal unit: V\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:42.519369800Z",
     "start_time": "2026-03-04T15:10:42.312546900Z"
    }
   },
   "source": [
    "# Forward FFT\n",
    "spectrum = u.fft.fft(signal)\n",
    "print('Spectrum:', spectrum)\n",
    "print('Spectrum unit:', spectrum.unit)  # volt * second, displayed as Wb (weber)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Spectrum: [16.       +0.j        -4.82842731-4.82842731j  0.       +0.j\n",
      "  0.82842708-0.82842708j  0.       +0.j         0.82842708+0.82842708j\n",
      "  0.       +0.j        -4.82842731+4.82842731j] Wb\n",
      "Spectrum unit: Wb\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:42.582203700Z",
     "start_time": "2026-03-04T15:10:42.519369800Z"
    }
   },
   "source": [
    "# Inverse FFT recovers the original signal (as complex values) and its unit\n",
    "recovered = u.fft.ifft(spectrum)\n",
    "print('Recovered:', recovered)\n",
    "print('Recovered unit:', recovered.unit)  # back to volt"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Recovered: [1.+0.j 2.+0.j 3.+0.j 4.+0.j 3.+0.j 2.+0.j 1.+0.j 0.+0.j] V\n",
      "Recovered unit: V\n"
     ]
    }
   ],
   "execution_count": 4
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Real FFT: `rfft` and `irfft`\n",
    "\n",
    "For real-valued signals, `rfft` computes only the non-negative-frequency half of the spectrum,\n",
    "because the negative-frequency components are the complex conjugates of the positive-frequency ones.\n",
    "`irfft` inverts it and returns a real-valued signal."
   ]
  },
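  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why half the spectrum suffices can be written out directly from the DFT definition. This is a standard identity; $X_k$ (bin $k$ of an $N$-point transform) is notation introduced here. For real samples $x_n$:\n",
    "\n",
    "$$\n",
    "X_{N-k} = \\sum_{n=0}^{N-1} x_n\\, e^{-2\\pi i (N-k) n / N} = \\sum_{n=0}^{N-1} x_n\\, e^{+2\\pi i k n / N} = \\overline{X_k}\n",
    "$$\n",
    "\n",
    "The middle step uses $e^{-2\\pi i n} = 1$, and the last step uses that $x_n$ is real. Bins $k = 0, \\dots, \\lfloor N/2 \\rfloor$ therefore determine all the others, which gives $\\lfloor N/2 \\rfloor + 1$ values: 5 for the 8-sample signal used in the next cell."
   ]
  },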
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:42.689635600Z",
     "start_time": "2026-03-04T15:10:42.583442900Z"
    }
   },
   "source": [
    "signal_real = jnp.array([1., 0., -1., 0., 1., 0., -1., 0.]) * u.ampere\n",
    "\n",
    "# rfft returns only the non-negative frequencies (N//2 + 1 components)\n",
    "spec_real = u.fft.rfft(signal_real)\n",
    "print('rfft result:', spec_real)  # ampere * second, displayed as C (coulomb)\n",
    "print('Length:', len(spec_real.mantissa), '(vs', len(signal_real.mantissa), 'input samples)')"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rfft result: [0.+0.j 0.+0.j 4.-0.j 0.+0.j 0.+0.j] C\n",
      "Length: 5 (vs 8 input samples)\n"
     ]
    }
   ],
   "execution_count": 5
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:42.883347300Z",
     "start_time": "2026-03-04T15:10:42.691753900Z"
    }
   },
   "source": [
    "# irfft recovers the original signal\n",
    "recovered_real = u.fft.irfft(spec_real)\n",
    "print('Recovered:', recovered_real)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Recovered: [ 1.  0. -1.  0.  1.  0. -1.  0.] A\n"
     ]
    }
   ],
   "execution_count": 6
  },
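  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Frequency Axes: `fftfreq` and `rfftfreq`\n",
    "\n",
    "These functions generate the frequency bin values corresponding to the FFT output.\n",
    "The `d` parameter is the sample spacing; the bins are spaced `1 / (n * d)` apart.\n",
    "In the examples below, `d` is passed as a plain number (in seconds), so the result is a plain array of values in hertz."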
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:43.318378700Z",
     "start_time": "2026-03-04T15:10:42.920001900Z"
    }
   },
   "source": [
    "n_samples = 8\n",
    "sample_spacing = 0.001  # 1 ms between samples (1000 Hz sampling rate)\n",
    "\n",
    "freqs = u.fft.fftfreq(n_samples, d=sample_spacing)\n",
    "print('FFT frequencies (Hz):', freqs)\n",
    "\n",
    "rfreqs = u.fft.rfftfreq(n_samples, d=sample_spacing)\n",
    "print('Real FFT frequencies (Hz):', rfreqs)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FFT frequencies (Hz): [   0.       124.99999  249.99998  374.99997 -499.99997 -374.99997\n",
      " -249.99998 -124.99999]\n",
      "Real FFT frequencies (Hz): [  0.      124.99999 249.99998 374.99997 499.99997]\n"
     ]
    }
   ],
   "execution_count": 7
  },
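  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The printed values can be checked by hand against the usual bin formula. This is the NumPy-style convention, assumed here to be what `fftfreq` follows because the output above matches it; $f_k$ is notation introduced for this check:\n",
    "\n",
    "$$\n",
    "f_k = \\begin{cases}\n",
    "\\dfrac{k}{n\\,d}, & 0 \\le k < n/2 \\\\\n",
    "\\dfrac{k - n}{n\\,d}, & n/2 \\le k < n\n",
    "\\end{cases}\n",
    "$$\n",
    "\n",
    "With $n = 8$ and $d = 0.001$ s the spacing is $1/(n\\,d) = 125$ Hz, giving $0, 125, 250, 375, -500, -375, -250, -125$ Hz. The small deviations in the printed output (for example `124.99999`) are consistent with single-precision rounding. `rfftfreq` keeps bins $0$ to $n/2$ and reports the last one as $+500$ Hz."
   ]
  },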
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Shifting: `fftshift` and `ifftshift`\n",
    "\n",
    "`fftshift` reorders the FFT output so that the zero-frequency component is in the center; `ifftshift` undoes that reordering.\n",
    "Both functions only rearrange elements, so they **keep the unit** unchanged."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:43.495661500Z",
     "start_time": "2026-03-04T15:10:43.348812500Z"
    }
   },
   "source": [
    "freqs = u.fft.fftfreq(8, d=0.1)\n",
    "print('Original order:', freqs)\n",
    "print('Shifted (zero-centered):', u.fft.fftshift(freqs))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original order: [ 0.    1.25  2.5   3.75 -5.   -3.75 -2.5  -1.25]\n",
      "Shifted (zero-centered): [-5.   -3.75 -2.5  -1.25  0.    1.25  2.5   3.75]\n"
     ]
    }
   ],
   "execution_count": 8
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:43.688213100Z",
     "start_time": "2026-03-04T15:10:43.522227Z"
    }
   },
   "source": [
    "# fftshift works on spectra with units too\n",
    "spec = u.fft.fft(jnp.array([1., 2., 3., 4.]) * u.volt)\n",
    "print('Spectrum:', spec)\n",
    "print('Shifted:', u.fft.fftshift(spec))\n",
    "print('Unit preserved:', u.fft.fftshift(spec).unit)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Spectrum: [10.+0.j -2.+2.j -2.+0.j -2.-2.j] Wb\n",
      "Shifted: [-2.+0.j -2.-2.j 10.+0.j -2.+2.j] Wb\n",
      "Unit preserved: Wb\n"
     ]
    }
   ],
   "execution_count": 9
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2D FFT: `fft2` and `ifft2`\n",
    "\n",
    "The 2D FFT applies the transform along two axes. The unit is multiplied\n",
    "by `s^2` (one factor of time per transformed axis)."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:43.947971900Z",
     "start_time": "2026-03-04T15:10:43.726791300Z"
    }
   },
   "source": [
    "# A 2D signal (e.g., a small image or spatial field)\n",
    "field = jnp.array([\n",
    "    [1., 2., 3., 4.],\n",
    "    [5., 6., 7., 8.],\n",
    "    [9., 10., 11., 12.],\n",
    "    [13., 14., 15., 16.]\n",
    "]) * u.pascal\n",
    "\n",
    "spec_2d = u.fft.fft2(field)\n",
    "print('2D FFT result:')\n",
    "print(spec_2d)\n",
    "print('Unit:', spec_2d.unit)  # pascal * s^2"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2D FFT result:\n",
      "[[136. +0.j  -8. +8.j  -8. +0.j  -8. -8.j]\n",
      " [-32.+32.j   0. +0.j   0. +0.j   0. +0.j]\n",
      " [-32. +0.j   0. +0.j   0. +0.j   0. +0.j]\n",
      " [-32.-32.j   0. +0.j   0. +0.j   0. +0.j]] Pa * s^2\n",
      "Unit: Pa * s^2\n"
     ]
    }
   ],
   "execution_count": 10
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:43.998439100Z",
     "start_time": "2026-03-04T15:10:43.948980300Z"
    }
   },
   "source": [
    "# Inverse 2D FFT\n",
    "recovered_2d = u.fft.ifft2(spec_2d)\n",
    "print('Recovered 2D signal:')\n",
    "print(recovered_2d)\n",
    "print('Unit:', recovered_2d.unit)  # back to pascal"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Recovered 2D signal:\n",
      "[[ 1.+0.j  2.+0.j  3.+0.j  4.+0.j]\n",
      " [ 5.+0.j  6.+0.j  7.+0.j  8.+0.j]\n",
      " [ 9.+0.j 10.+0.j 11.+0.j 12.+0.j]\n",
      " [13.+0.j 14.+0.j 15.+0.j 16.+0.j]] Pa\n",
      "Unit: Pa\n"
     ]
    }
   ],
   "execution_count": 11
  },
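  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## N-D FFT: `fftn` and `ifftn`\n",
    "\n",
    "`fftn` and `ifftn` generalize the transform to any number of axes. Each transformed axis contributes one factor of `s`, and the `axes` argument restricts the transform to a subset of axes."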
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:44.079896800Z",
     "start_time": "2026-03-04T15:10:43.998439100Z"
    }
   },
   "source": [
    "# 3D data\n",
    "data_3d = jnp.ones((2, 3, 4)) * u.meter\n",
    "spec_3d = u.fft.fftn(data_3d)\n",
    "print('3D FFT shape:', spec_3d.shape)\n",
    "print('3D FFT unit:', spec_3d.unit)  # meter * s^3"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3D FFT shape: (2, 3, 4)\n",
      "3D FFT unit: m * s^3\n"
     ]
    }
   ],
   "execution_count": 12
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:44.196768900Z",
     "start_time": "2026-03-04T15:10:44.081408Z"
    }
   },
   "source": [
    "# Transform along specific axes only\n",
    "spec_partial = u.fft.fftn(data_3d, axes=(0, 1))  # transform first 2 axes only\n",
    "print('Partial FFT unit:', spec_partial.unit)  # meter * s^2"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Partial FFT unit: m * s^2\n"
     ]
    }
   ],
   "execution_count": 13
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Practical Example: Spectral Analysis of a Signal\n",
    "\n",
    "Analyze the frequency content of a composite voltage signal."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:44.404646700Z",
     "start_time": "2026-03-04T15:10:44.215259Z"
    }
   },
   "source": [
    "# Generate a signal: sum of two sinusoids\n",
    "n = 256\n",
    "dt = 0.001  # 1 ms sample spacing -> 1000 Hz sampling rate\n",
    "t = jnp.arange(n) * dt  # time in seconds\n",
    "\n",
    "# 50 Hz and 120 Hz components\n",
    "signal_composed = (1.0 * jnp.sin(2 * jnp.pi * 50 * t) +\n",
    "                   0.5 * jnp.sin(2 * jnp.pi * 120 * t)) * u.volt\n",
    "\n",
    "print('Signal shape:', signal_composed.shape)\n",
    "print('Signal unit:', signal_composed.unit)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Signal shape: (256,)\n",
      "Signal unit: V\n"
     ]
    }
   ],
   "execution_count": 14
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:44.655409400Z",
     "start_time": "2026-03-04T15:10:44.405644800Z"
    }
   },
   "source": [
    "# Compute spectrum\n",
    "spectrum_composed = u.fft.rfft(signal_composed)\n",
    "freqs_composed = u.fft.rfftfreq(n, d=dt)\n",
    "\n",
    "# Magnitude spectrum (absolute value of each complex bin; square it to get power)\n",
    "magnitude = u.math.abs(spectrum_composed)\n",
    "print('Frequency bins:', freqs_composed[:5], '...')\n",
    "print('Magnitude at DC:', magnitude.mantissa[0])\n",
    "print('Number of frequency bins:', len(freqs_composed))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Frequency bins: [ 0.         3.9062498  7.8124995 11.718749  15.624999 ] ...\n",
      "Magnitude at DC: 3.6521716\n",
      "Number of frequency bins: 129\n"
     ]
    }
   ],
   "execution_count": 15
  },
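  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick hand check of where the two components should appear, using only `n` and `dt` from the cells above ($\\Delta f$, the bin spacing, is notation introduced here):\n",
    "\n",
    "$$\n",
    "\\Delta f = \\frac{1}{256 \\times 0.001\\ \\text{s}} = 3.90625\\ \\text{Hz}, \\qquad \\frac{50\\ \\text{Hz}}{\\Delta f} = 12.8, \\qquad \\frac{120\\ \\text{Hz}}{\\Delta f} = 30.72\n",
    "$$\n",
    "\n",
    "Neither frequency lands on an integer bin, so the peaks are expected near bins 13 and 31, with some energy leaking into neighbouring bins. The same leakage is why the zero-frequency value printed above is not exactly zero even though the signal is a sum of pure sinusoids."
   ]
  },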
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "| Function | Unit Change | Description |\n",
    "|----------|------------|-------------|\n",
    "| `fft(x)` | `u → u*s` | Forward 1D FFT |\n",
    "| `ifft(X)` | `u*s → u` | Inverse 1D FFT |\n",
    "| `rfft(x)` | `u → u*s` | Real 1D FFT (non-negative frequencies only) |\n",
    "| `irfft(X)` | `u*s → u` | Inverse real 1D FFT |\n",
    "| `fft2(x)` | `u → u*s^2` | Forward 2D FFT |\n",
    "| `ifft2(X)` | `u*s^2 → u` | Inverse 2D FFT |\n",
    "| `fftn(x)` | `u → u*s^N` | Forward N-D FFT (`N` = number of transformed axes) |\n",
    "| `fftfreq(n, d)` | dimensionless for a plain-number `d` | Frequency bin values |\n",
    "| `fftshift(x)` | keeps unit | Zero-center the spectrum |"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
