Fourier Transform Functions


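brainunit.fft provides unit-aware Fast Fourier Transform functions. The transform functions change the unit of their input because the FFT approximates an integral over the sampled variable (treated as time), so the result picks up a factor of the sample-spacing unit. The functions fall into two groups: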

  • Changing unit: fft, ifft, rfft, irfft, fft2, ifft2, fftn, ifftn, rfft2, irfft2, rfftn, irfftn, fftfreq, rfftfreq

  • Keeping unit: fftshift, ifftshift
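The unit change can be read off from the continuous Fourier transform that the FFT approximates. The following is a sketch under the standard sign and normalisation convention (an assumption, since this page does not state its convention), with \Delta t the sample spacing and N the number of samples:

```latex
X(f) = \int_{-\infty}^{\infty} x(t)\, e^{-2\pi i f t}\, \mathrm{d}t
\quad\Longrightarrow\quad
X(f_k) \approx \Delta t \sum_{n=0}^{N-1} x_n\, e^{-2\pi i k n / N},
\qquad f_k = \frac{k}{N\,\Delta t}

[x_n] = \mathrm{u}, \quad [\Delta t] = \mathrm{s}
\quad\Longrightarrow\quad
[X] = \mathrm{u}\cdot\mathrm{s}
```

Judging by the outputs on this page, the numeric values appear to be the plain sum without the \Delta t factor (the DC bin of the first example is 16, the sum of the samples), so only the unit carries the extra factor of s. Reordering functions such as fftshift merely permute elements, which is why they leave the unit alone.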

import brainunit as u
import jax.numpy as jnp

1D FFT: fft and ifft

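The forward FFT of a signal with unit u produces a spectrum with unit u * s, that is, the input unit multiplied by the sample-spacing unit (seconds). In the output below, V * s is displayed under its SI name Wb (weber). The inverse FFT removes the factor of s again and recovers the original unit.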

# A simple voltage signal
signal = jnp.array([1., 2., 3., 4., 3., 2., 1., 0.]) * u.volt
print('Signal:', signal)
print('Signal unit:', signal.unit)
Signal: [1. 2. 3. 4. 3. 2. 1. 0.] V
Signal unit: V
# Forward FFT
spectrum = u.fft.fft(signal)
print('Spectrum:', spectrum)
print('Spectrum unit:', spectrum.unit)  # volt * second
Spectrum: [16.       +0.j        -4.82842731-4.82842731j  0.       +0.j
  0.82842708-0.82842708j  0.       +0.j         0.82842708+0.82842708j
  0.       +0.j        -4.82842731+4.82842731j] Wb
Spectrum unit: Wb
# Inverse FFT recovers original signal and unit
recovered = u.fft.ifft(spectrum)
print('Recovered:', recovered)
print('Recovered unit:', recovered.unit)  # back to volt
Recovered: [1.+0.j 2.+0.j 3.+0.j 4.+0.j 3.+0.j 2.+0.j 1.+0.j 0.+0.j] V
Recovered unit: V
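To make the numbers above concrete, here is a minimal pure-Python sketch of the discrete transform pair applied to the mantissa of the voltage signal. The helpers dft and idft are hypothetical names written for this illustration; they are a naive O(N^2) implementation, not how brainunit or JAX compute the FFT, and they ignore units entirely.

```python
import cmath


def dft(x):
    """Naive forward DFT: X[k] = sum_n x[n] * exp(-2*pi*i*k*n/N)."""
    n_samples = len(x)
    return [
        sum(x[n] * cmath.exp(-2j * cmath.pi * k * n / n_samples)
            for n in range(n_samples))
        for k in range(n_samples)
    ]


def idft(X):
    """Naive inverse DFT with the usual 1/N normalisation."""
    n_samples = len(X)
    return [
        sum(X[k] * cmath.exp(2j * cmath.pi * k * n / n_samples)
            for k in range(n_samples)) / n_samples
        for n in range(n_samples)
    ]


samples = [1.0, 2.0, 3.0, 4.0, 3.0, 2.0, 1.0, 0.0]  # mantissa of the signal above
spectrum = dft(samples)
recovered = idft(spectrum)

print("DC bin:", spectrum[0])  # the plain sum of the samples
print("Round trip:", [round(v.real, 6) for v in recovered])
```

The DC bin equals the sum of the samples, and the inverse transform returns the original values up to floating-point error, which mirrors the fft/ifft round trip shown above.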

Real FFT: rfft and irfft

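For real-valued signals, rfft computes only the non-negative-frequency half of the spectrum: N//2 + 1 bins for N input samples. The negative-frequency bins are the complex conjugates of the positive ones, so they carry no extra information. irfft inverts the transform and returns a real signal.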

signal_real = jnp.array([1., 0., -1., 0., 1., 0., -1., 0.]) * u.ampere

# rfft returns only the non-negative frequencies (N//2 + 1 components)
spec_real = u.fft.rfft(signal_real)
print('rfft result:', spec_real)
print('Length:', len(spec_real.mantissa), '(vs', len(signal_real.mantissa), 'input samples)')
rfft result: 
[0.+0.j 0.+0.j 4.-0.j 0.+0.j 0.+0.j] C
Length: 5 (vs 8 input samples)
# irfft recovers the original signal
recovered_real = u.fft.irfft(spec_real)
print('Recovered:', recovered_real)
Recovered: [ 1.  0. -1.  0.  1.  0. -1.  0.] A
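The conjugate symmetry that lets rfft drop half the spectrum can be checked directly. The sketch below uses a hypothetical naive dft helper in pure Python (not brainunit's implementation) on the mantissa of the current signal above.

```python
import cmath


def dft(x):
    """Naive forward DFT: X[k] = sum_n x[n] * exp(-2*pi*i*k*n/N)."""
    n_samples = len(x)
    return [
        sum(x[n] * cmath.exp(-2j * cmath.pi * k * n / n_samples)
            for n in range(n_samples))
        for k in range(n_samples)
    ]


samples = [1.0, 0.0, -1.0, 0.0, 1.0, 0.0, -1.0, 0.0]  # mantissa of the signal above
n_samples = len(samples)
full = dft(samples)

# For real input, bin N-k is the complex conjugate of bin k.
symmetric = all(
    abs(full[n_samples - k] - full[k].conjugate()) < 1e-9
    for k in range(1, n_samples)
)

# So the first N//2 + 1 bins carry all the information.
half = full[: n_samples // 2 + 1]

print("Conjugate symmetric:", symmetric)
print("Bins kept:", len(half), "of", n_samples)
```

The only non-zero bin in the kept half is index 2, matching the rfft result above: the signal completes two cycles over the eight samples.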

Frequency Axes: fftfreq and rfftfreq

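These functions generate the frequency bin values corresponding to the FFT output. The d parameter is the sample spacing. In the examples below d is a plain float in seconds, so the frequencies come back as plain arrays in Hz. Values such as 124.99999 instead of 125 reflect single-precision (float32) rounding.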

n_samples = 8
sample_spacing = 0.001  # 1 ms between samples (1000 Hz sampling rate)

freqs = u.fft.fftfreq(n_samples, d=sample_spacing)
print('FFT frequencies (Hz):', freqs)

rfreqs = u.fft.rfftfreq(n_samples, d=sample_spacing)
print('Real FFT frequencies (Hz):', rfreqs)
FFT frequencies (Hz): [   0.       124.99999  249.99998  374.99997 -499.99997 -374.99997
 -249.99998 -124.99999]
Real FFT frequencies (Hz): [  0.      124.99999 249.99998 374.99997 499.99997]
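The bin layout follows a simple rule: bin k sits at k / (n * d), with the upper half of the full spectrum wrapped to negative frequencies. The sketch below reproduces that layout in pure Python, following the NumPy ordering convention. The names fftfreq_sketch and rfftfreq_sketch are hypothetical and handle plain floats only, with no unit support.

```python
def fftfreq_sketch(n, d=1.0):
    """Frequency of each full-FFT bin, in cycles per unit of d."""
    n_nonneg = (n - 1) // 2 + 1  # bins 0 .. n_nonneg-1 are non-negative
    bins = list(range(n_nonneg)) + list(range(-(n // 2), 0))
    return [k / (n * d) for k in bins]


def rfftfreq_sketch(n, d=1.0):
    """Frequency of each real-FFT bin: 0 up to the Nyquist frequency."""
    return [k / (n * d) for k in range(n // 2 + 1)]


n_samples = 8
sample_spacing = 0.001  # seconds, as in the example above

freqs = fftfreq_sketch(n_samples, sample_spacing)
rfreqs = rfftfreq_sketch(n_samples, sample_spacing)

print("Full:", [round(f, 3) for f in freqs])
print("Real:", [round(f, 3) for f in rfreqs])
```

Note that the Nyquist bin appears as -500 Hz in the full layout but as +500 Hz in the real layout, which matches the two outputs above.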

Shifting: fftshift and ifftshift

fftshift reorders the FFT output so that the zero-frequency component sits in the center, and ifftshift undoes that reordering. Both functions only permute elements, so the unit is unchanged.

freqs = u.fft.fftfreq(8, d=0.1)
print('Original order:', freqs)
print('Shifted (zero-centered):', u.fft.fftshift(freqs))
Original order: [ 0.    1.25  2.5   3.75 -5.   -3.75 -2.5  -1.25]
Shifted (zero-centered): [-5.   -3.75 -2.5  -1.25  0.    1.25  2.5   3.75]
# fftshift works on spectra with units too
spec = u.fft.fft(jnp.array([1., 2., 3., 4.]) * u.volt)
print('Spectrum:', spec)
print('Shifted:', u.fft.fftshift(spec))
print('Unit preserved:', u.fft.fftshift(spec).unit)
Spectrum: [10.+0.j -2.+2.j -2.+0.j -2.-2.j] Wb
Shifted: [-2.+0.j -2.-2.j 10.+0.j -2.+2.j] Wb
Unit preserved: Wb
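For a 1D array both shifts are plain rotations, which is why no unit can change. The sketch below, with hypothetical helpers fftshift_1d and ifftshift_1d operating on Python lists, also shows why two separate functions exist: for odd lengths, applying fftshift twice does not restore the input, whereas ifftshift does.

```python
def fftshift_1d(x):
    """Rotate right by len(x)//2 so the zero-frequency bin lands in the center."""
    half = len(x) // 2
    return x[-half:] + x[:-half] if half else list(x)


def ifftshift_1d(x):
    """Rotate left by len(x)//2, undoing fftshift_1d for any length."""
    half = len(x) // 2
    return x[half:] + x[:half]


even = [0.0, 1.25, 2.5, 3.75, -5.0, -3.75, -2.5, -1.25]  # fftfreq(8, d=0.1) from above
odd = [0, 1, 2, -2, -1]                                   # bin order for n = 5

print(fftshift_1d(even))
print(fftshift_1d(odd))
print(ifftshift_1d(fftshift_1d(odd)) == odd)   # the inverse restores the input
print(fftshift_1d(fftshift_1d(odd)) == odd)    # shifting twice does not, for odd n
```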

2D FFT: fft2 and ifft2

The 2D FFT applies the transform along two axes. Each transformed axis contributes one factor of s, so the unit is multiplied by s^2.

# A 2D signal (e.g., a small image or spatial field)
field = jnp.array([
    [1., 2., 3., 4.],
    [5., 6., 7., 8.],
    [9., 10., 11., 12.],
    [13., 14., 15., 16.]
]) * u.pascal

spec_2d = u.fft.fft2(field)
print('2D FFT result:')
print(spec_2d)
print('Unit:', spec_2d.unit)  # pascal * s^2
2D FFT result:
[[136. +0.j  -8. +8.j  -8. +0.j  -8. -8.j]
 [-32.+32.j   0. +0.j   0. +0.j   0. +0.j]
 [-32. +0.j   0. +0.j   0. +0.j   0. +0.j]
 [-32.-32.j   0. +0.j   0. +0.j   0. +0.j]] Pa * s^2
Unit: Pa * s^2
# Inverse 2D FFT
recovered_2d = u.fft.ifft2(spec_2d)
print('Recovered 2D signal:')
print(recovered_2d)
print('Unit:', recovered_2d.unit)  # back to pascal
Recovered 2D signal:
[[ 1.+0.j  2.+0.j  3.+0.j  4.+0.j]
 [ 5.+0.j  6.+0.j  7.+0.j  8.+0.j]
 [ 9.+0.j 10.+0.j 11.+0.j 12.+0.j]
 [13.+0.j 14.+0.j 15.+0.j 16.+0.j]] Pa
Unit: Pa
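A 2D DFT is separable: transforming every row and then every column gives the same result as the direct double sum, which is also why each axis contributes its own factor of s. The sketch below uses hypothetical pure-Python helpers dft and dft2 (not brainunit's implementation) on the mantissa of the field above.

```python
import cmath


def dft(x):
    """Naive forward DFT: X[k] = sum_n x[n] * exp(-2*pi*i*k*n/N)."""
    n_samples = len(x)
    return [
        sum(x[n] * cmath.exp(-2j * cmath.pi * k * n / n_samples)
            for n in range(n_samples))
        for k in range(n_samples)
    ]


def dft2(rows):
    """2D DFT built from 1D DFTs: transform every row, then every column."""
    after_rows = [dft(row) for row in rows]
    columns = [dft(list(col)) for col in zip(*after_rows)]
    return [list(row) for row in zip(*columns)]  # transpose back to row-major


field = [
    [1.0, 2.0, 3.0, 4.0],
    [5.0, 6.0, 7.0, 8.0],
    [9.0, 10.0, 11.0, 12.0],
    [13.0, 14.0, 15.0, 16.0],
]
spec = dft2(field)

print("DC bin:", spec[0][0])  # sum of all 16 entries
print("First row, bin 1:", spec[0][1])
print("First column, bin 1:", spec[1][0])
```

Because the field is a sum of a row ramp and a column ramp, all energy sits in the first row and first column of the spectrum, as in the fft2 output above.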

N-D FFT: fftn and ifftn

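fftn and ifftn generalize the transform to any number of axes. By default every axis is transformed, so an N-dimensional input picks up a factor of s^N. The axes argument restricts the transform to the listed axes, each of which contributes one factor of s.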

# 3D data
data_3d = jnp.ones((2, 3, 4)) * u.meter
spec_3d = u.fft.fftn(data_3d)
print('3D FFT shape:', spec_3d.shape)
print('3D FFT unit:', spec_3d.unit)  # meter * s^3
3D FFT shape: (2, 3, 4)
3D FFT unit: m * s^3
# Transform along specific axes only
spec_partial = u.fft.fftn(data_3d, axes=(0, 1))  # transform first 2 axes only
print('Partial FFT unit:', spec_partial.unit)  # meter * s^2
Partial FFT unit: m * s^2

Practical Example: Spectral Analysis of a Signal

Analyze the frequency content of a composite voltage signal.

# Generate a signal: sum of two sinusoids
n = 256
dt = 0.001  # 1 ms sample spacing -> 1000 Hz sampling rate
t = jnp.arange(n) * dt  # time in seconds

# 50 Hz and 120 Hz components
signal_composed = (1.0 * jnp.sin(2 * jnp.pi * 50 * t) +
                   0.5 * jnp.sin(2 * jnp.pi * 120 * t)) * u.volt

print('Signal shape:', signal_composed.shape)
print('Signal unit:', signal_composed.unit)
Signal shape: (256,)
Signal unit: V
# Compute spectrum
spectrum_composed = u.fft.rfft(signal_composed)
freqs_composed = u.fft.rfftfreq(n, d=dt)

# Magnitude spectrum (absolute value of each complex bin)
magnitude = u.math.abs(spectrum_composed)
print('Frequency bins:', freqs_composed[:5], '...')
print('Magnitude at DC:', magnitude.mantissa[0])
print('Number of frequency bins:', len(freqs_composed))
Frequency bins: [ 0.         3.9062498  7.8124995 11.718749  15.624999 ] ...
Magnitude at DC: 3.6521716
Number of frequency bins: 129
Number of frequency bins: 129
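To read the dominant frequencies off a spectrum like this one, pick the bins with the largest magnitude. The sketch below redoes the analysis in pure Python with a naive one-sided DFT, so it runs without brainunit or JAX; it works on plain floats and is only an illustration of the idea, not the library's method.

```python
import cmath
import math

n = 256
dt = 0.001  # sample spacing in seconds, as in the example above

# Same composite signal: 50 Hz at amplitude 1.0 plus 120 Hz at amplitude 0.5
signal = [
    1.0 * math.sin(2 * math.pi * 50 * m * dt)
    + 0.5 * math.sin(2 * math.pi * 120 * m * dt)
    for m in range(n)
]

# Naive one-sided DFT: bins 0 .. n//2, the same bins rfft returns
spectrum = [
    sum(signal[m] * cmath.exp(-2j * cmath.pi * k * m / n) for m in range(n))
    for k in range(n // 2 + 1)
]
magnitude = [abs(v) for v in spectrum]
freqs = [k / (n * dt) for k in range(n // 2 + 1)]

# Indices of the two strongest bins, in ascending frequency order
top_two = sorted(sorted(range(len(magnitude)), key=magnitude.__getitem__)[-2:])
for k in top_two:
    print(f"bin {k}: {freqs[k]:.2f} Hz, |X| = {magnitude[k]:.1f}")
```

The bin spacing is 1 / (n * dt), about 3.9 Hz, so the peaks land on the grid points nearest 50 Hz and 120 Hz rather than exactly on them. Because neither frequency completes a whole number of cycles in 256 samples, some energy leaks into neighbouring bins and into the DC bin, which explains the non-zero DC magnitude printed above.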

Summary

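| Function | Unit Change | Description |
|---|---|---|
| fft(x) | u → u*s | Forward 1D FFT |
| ifft(X) | u*s → u | Inverse 1D FFT |
| rfft(x) | u → u*s | Real 1D FFT (non-negative frequencies only) |
| irfft(X) | u*s → u | Inverse real 1D FFT |
| fft2(x) | u → u*s^2 | Forward 2D FFT |
| fftn(x) | u → u*s^N | Forward N-D FFT (N = number of transformed axes) |
| fftfreq(n, d) | dimensionless (for a plain-float d, as in the examples above) | Frequency bin values |
| fftshift(x) | keeps unit | Zero-center the spectrum |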