Fourier Transform Functions
brainunit.fft provides unit-aware Fast Fourier Transform functions.
Most of these functions change the unit, because a Fourier transform implicitly integrates over the transform variable, which brainunit takes to be time in seconds:
Changing unit: fft, ifft, rfft, irfft, fft2, ifft2, fftn, ifftn, rfft2, irfft2, rfftn, irfftn, fftfreq, rfftfreq
Keeping unit: fftshift, ifftshift
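The "implicit integration" can be made concrete with a unit-free sketch. It uses plain NumPy as a stand-in for jax.numpy, which is an assumption of this example and not part of brainunit. The continuous transform G(f) = ∫ g(t) e^(-2πift) dt of a signal with unit u has unit u·s, and a Riemann sum approximates it as dt times the DFT. A Gaussian is used because its transform is known exactly. Judging from the outputs on this page (the DC bin of the first spectrum below is just the sum of the samples, 16), brainunit attaches the unit but does not multiply by dt, so applying the real sample spacing appears to be left to the caller.

```python
import numpy as np

# Gaussian pulse g(t) = exp(-pi t^2); its continuous Fourier transform
# G(f) = integral of g(t) * exp(-2 pi i f t) dt is exactly exp(-pi f^2).
dt = 0.01                                # sample spacing in seconds
n = 2048                                 # number of samples
t = (np.arange(n) - n // 2) * dt         # time grid centred on t = 0
g = np.exp(-np.pi * t**2)

# A Riemann sum of the integral is dt times the DFT. ifftshift moves the
# t = 0 sample to index 0 so the DFT phase matches the integral's origin.
G_approx = dt * np.fft.fft(np.fft.ifftshift(g))
f = np.fft.fftfreq(n, d=dt)              # matching bin frequencies in Hz
G_exact = np.exp(-np.pi * f**2)

max_error = np.max(np.abs(G_approx - G_exact))
print(f"max |dt * DFT - G| = {max_error:.2e}")
```

Without the dt factor the values would be 100 times too large here, even though the unit attached by brainunit would be the same.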
import brainunit as u
import jax.numpy as jnp
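1D FFT: fft and ifft
The forward FFT of a signal with unit u returns a spectrum with unit u * s: the signal unit multiplied by seconds, the assumed unit of the transform variable. Only the unit changes; the numerical values are the ordinary DFT sums. For a voltage signal, V * s is printed as Wb (1 Wb = 1 V·s).
The inverse FFT divides by s again, recovering the original unit.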
# A simple voltage signal
signal = jnp.array([1., 2., 3., 4., 3., 2., 1., 0.]) * u.volt
print('Signal:', signal)
print('Signal unit:', signal.unit)
Signal: [1. 2. 3. 4. 3. 2. 1. 0.] V
Signal unit: V
# Forward FFT
spectrum = u.fft.fft(signal)
print('Spectrum:', spectrum)
print('Spectrum unit:', spectrum.unit) # volt * second
Spectrum: [16. +0.j -4.82842731-4.82842731j 0. +0.j
0.82842708-0.82842708j 0. +0.j 0.82842708+0.82842708j
0. +0.j -4.82842731+4.82842731j] Wb
Spectrum unit: Wb
# Inverse FFT recovers original signal and unit
recovered = u.fft.ifft(spectrum)
print('Recovered:', recovered)
print('Recovered unit:', recovered.unit) # back to volt
Recovered: [1.+0.j 2.+0.j 3.+0.j 4.+0.j 3.+0.j 2.+0.j 1.+0.j 0.+0.j] V
Recovered unit: V
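The numbers in the spectrum above are the standard DFT. As a unit-free cross-check (plain NumPy, assumed here only for illustration), the sketch below evaluates the forward and inverse DFT sums directly and compares them with the library FFT.

```python
import numpy as np

x = np.array([1., 2., 3., 4., 3., 2., 1., 0.])   # same samples as above, unit stripped
N = len(x)
n = np.arange(N)
k = n.reshape(-1, 1)

# Forward DFT matrix: W[k, n] = exp(-2*pi*i*k*n/N), so X[k] = sum_n W[k, n] * x[n]
W = np.exp(-2j * np.pi * k * n / N)
X = W @ x

# Inverse DFT: x[n] = (1/N) * sum_k X[k] * exp(+2*pi*i*k*n/N)
x_back = (np.conj(W) @ X) / N

# The DC bin X[0] is the plain sum of the samples (16 here)
print(X[0], np.allclose(X, np.fft.fft(x)))
```

The 1/N factor lives in the inverse transform, which is why ifft returns the original values rather than N times them.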
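Real FFT: rfft and irfft
For real-valued signals, rfft computes only the non-negative-frequency half of the spectrum (N//2 + 1 bins), since the negative-frequency bins are the complex conjugates of the positive ones.
The unit changes as for fft; for an ampere signal, A * s is printed as C (coulomb).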
signal_real = jnp.array([1., 0., -1., 0., 1., 0., -1., 0.]) * u.ampere
# rfft returns only positive frequencies (N//2 + 1 components)
spec_real = u.fft.rfft(signal_real)
print('rfft result:', spec_real)
print('Length:', len(spec_real.mantissa), '(vs', len(signal_real.mantissa), 'input samples)')
rfft result:
[0.+0.j 0.+0.j 4.-0.j 0.+0.j 0.+0.j] C
Length: 5 (vs 8 input samples)
# irfft recovers the original signal
recovered_real = u.fft.irfft(spec_real)
print('Recovered:', recovered_real)
Recovered: [ 1. 0. -1. 0. 1. 0. -1. 0.] A
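The following NumPy sketch (unit-free, used as a stand-in for jax.numpy) checks that rfft is the first N//2 + 1 bins of fft and that the dropped bins are conjugates. It also shows a pitfall: for odd-length input, irfft cannot infer the original length and must be given n. jax.numpy.fft.irfft accepts the same n parameter; that u.fft.irfft forwards it is an assumption worth verifying.

```python
import numpy as np

x = np.array([1., 0., -1., 0., 1., 0., -1., 0.])   # same samples as above, unit stripped
N = len(x)
X_full = np.fft.fft(x)
X_half = np.fft.rfft(x)

# rfft keeps bins 0..N//2 of the full FFT ...
half_matches = np.allclose(X_half, X_full[: N // 2 + 1])
# ... because for real input the remaining bins are conjugates: X[N-k] = conj(X[k])
k = np.arange(1, N)
hermitian = np.allclose(X_full[N - k], np.conj(X_full[k]))

# For odd-length input, irfft cannot infer the original length on its own
y = np.array([1., 2., 3., 4., 5.])
Y = np.fft.rfft(y)                       # 5 // 2 + 1 = 3 bins
y_default = np.fft.irfft(Y)              # default n = 2 * (3 - 1) = 4 samples
y_exact = np.fft.irfft(Y, n=len(y))      # pass n to recover all 5 samples
print(half_matches, hermitian, y_default.shape, y_exact)
```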
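Frequency Axes: fftfreq and rfftfreq
These functions return the frequency of each FFT output bin.
The d parameter is the sample spacing; frequencies come out in cycles per unit of d, so a spacing given in seconds yields values in Hz. With a plain float d, as below, the result carries no unit.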
n_samples = 8
sample_spacing = 0.001 # 1 ms between samples (1000 Hz sampling rate)
freqs = u.fft.fftfreq(n_samples, d=sample_spacing)
print('FFT frequencies (Hz):', freqs)
rfreqs = u.fft.rfftfreq(n_samples, d=sample_spacing)
print('Real FFT frequencies (Hz):', rfreqs)
FFT frequencies (Hz): [ 0. 124.99999 249.99998 374.99997 -499.99997 -374.99997
-249.99998 -124.99999]
Real FFT frequencies (Hz): [ 0. 124.99999 249.99998 374.99997 499.99997]
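The bin frequencies follow a simple rule: bin k sits at k / (n * d), and fftfreq wraps the upper half of the bins to negative frequencies. With n = 8 and d = 1 ms the spacing is 125 Hz, and the Nyquist bin is reported as -500 by fftfreq but +500 by rfftfreq, as in the output above. The helpers manual_fftfreq and manual_rfftfreq below are hypothetical, written only to spell out the rule in NumPy.

```python
import numpy as np

def manual_fftfreq(n, d):
    """Hypothetical helper: bin frequencies in cycles per unit of d, in FFT output order."""
    k = np.arange(n)
    k = np.where(k < (n + 1) // 2, k, k - n)   # upper half wraps to negative frequencies
    return k / (n * d)

def manual_rfftfreq(n, d):
    """Hypothetical helper: non-negative bin frequencies 0 .. n//2, matching rfft output."""
    return np.arange(n // 2 + 1) / (n * d)

d = 0.001                                      # 1 ms spacing, as above
for n in (7, 8):                               # odd and even lengths differ at the Nyquist bin
    print(n, manual_fftfreq(n, d), manual_rfftfreq(n, d))
```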
Shifting: fftshift and ifftshift
fftshift reorders the FFT output so that the zero-frequency component is in the center.
These functions keep the unit unchanged.
freqs = u.fft.fftfreq(8, d=0.1)
print('Original order:', freqs)
print('Shifted (zero-centered):', u.fft.fftshift(freqs))
Original order: [ 0. 1.25 2.5 3.75 -5. -3.75 -2.5 -1.25]
Shifted (zero-centered): [-5. -3.75 -2.5 -1.25 0. 1.25 2.5 3.75]
# fftshift works on spectra with units too
spec = u.fft.fft(jnp.array([1., 2., 3., 4.]) * u.volt)
print('Spectrum:', spec)
print('Shifted:', u.fft.fftshift(spec))
print('Unit preserved:', u.fft.fftshift(spec).unit)
Spectrum: [10.+0.j -2.+2.j -2.+0.j -2.-2.j] Wb
Shifted: [-2.+0.j -2.-2.j 10.+0.j -2.+2.j] Wb
Unit preserved: Wb
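fftshift only reorders elements (a circular roll by n // 2), which is why the unit is untouched. For odd lengths the roll is not its own inverse, so a shift should be undone with ifftshift rather than a second fftshift. A unit-free NumPy illustration:

```python
import numpy as np

x = np.arange(5)                          # odd length makes the difference visible
shifted = np.fft.fftshift(x)              # [3 4 0 1 2]: roll right by n // 2
same_as_roll = np.array_equal(shifted, np.roll(x, len(x) // 2))

undone = np.fft.ifftshift(shifted)        # [0 1 2 3 4]: the exact inverse
shifted_twice = np.fft.fftshift(shifted)  # [1 2 3 4 0]: NOT the original for odd n
print(shifted, undone, shifted_twice)
```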
2D FFT: fft2 and ifft2
The 2D FFT applies the transform along two axes, so the unit is multiplied by s^2 (one factor of seconds per transformed axis).
The factor is seconds even for spatial data, such as the pressure field below.
# A 2D signal (e.g., a small image or spatial field)
field = jnp.array([
[1., 2., 3., 4.],
[5., 6., 7., 8.],
[9., 10., 11., 12.],
[13., 14., 15., 16.]
]) * u.pascal
spec_2d = u.fft.fft2(field)
print('2D FFT result:')
print(spec_2d)
print('Unit:', spec_2d.unit) # pascal * s^2
2D FFT result:
[[136. +0.j -8. +8.j -8. +0.j -8. -8.j]
[-32.+32.j 0. +0.j 0. +0.j 0. +0.j]
[-32. +0.j 0. +0.j 0. +0.j 0. +0.j]
[-32.-32.j 0. +0.j 0. +0.j 0. +0.j]] Pa * s^2
Unit: Pa * s^2
# Inverse 2D FFT
recovered_2d = u.fft.ifft2(spec_2d)
print('Recovered 2D signal:')
print(recovered_2d)
print('Unit:', recovered_2d.unit) # back to pascal
Recovered 2D signal:
[[ 1.+0.j 2.+0.j 3.+0.j 4.+0.j]
[ 5.+0.j 6.+0.j 7.+0.j 8.+0.j]
[ 9.+0.j 10.+0.j 11.+0.j 12.+0.j]
[13.+0.j 14.+0.j 15.+0.j 16.+0.j]] Pa
Unit: Pa
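The s^2 comes from the 2D FFT being separable: it is a 1D FFT along one axis followed by a 1D FFT along the other, each contributing one factor of s. A unit-free NumPy check on the same 4x4 values:

```python
import numpy as np

field = np.arange(1., 17.).reshape(4, 4)   # same values as above, unit stripped

# 2D FFT = 1D FFT along axis 0, then 1D FFT along axis 1 (the order does not matter)
step_by_step = np.fft.fft(np.fft.fft(field, axis=0), axis=1)
direct = np.fft.fft2(field)

print(np.allclose(step_by_step, direct))   # True
print(direct[0, 0])                        # DC bin: the sum of all 16 entries, 136
```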
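N-D FFT: fftn and ifftn
fftn and ifftn generalize the transform to any number of axes. The unit gains one factor of s per transformed axis: all axes by default, or only those listed in axes.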
# 3D data
data_3d = jnp.ones((2, 3, 4)) * u.meter
spec_3d = u.fft.fftn(data_3d)
print('3D FFT shape:', spec_3d.shape)
print('3D FFT unit:', spec_3d.unit) # meter * s^3
3D FFT shape: (2, 3, 4)
3D FFT unit: m * s^3
# Transform along specific axes only
spec_partial = u.fft.fftn(data_3d, axes=(0, 1)) # transform first 2 axes only
print('Partial FFT unit:', spec_partial.unit) # meter * s^2
Partial FFT unit: m * s^2
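Since data_3d above is constant, its transform is easy to predict, which makes it a useful sanity check on the values (not only the units). The NumPy sketch below is unit-free: transforming all three axes leaves a single nonzero bin equal to 2 * 3 * 4 = 24, while axes=(0, 1) leaves axis 2 untransformed, giving 6 in every bin of the [0, 0, :] row. The number of transformed axes (3 versus 2) is also the exponent of s seen in the units above.

```python
import numpy as np

data = np.ones((2, 3, 4))                   # same shape as data_3d above, unit stripped

full = np.fft.fftn(data)                    # transform all three axes
partial = np.fft.fftn(data, axes=(0, 1))    # transform axes 0 and 1 only

# A constant array puts everything in the zero-frequency bin(s):
# full transform    -> one nonzero bin, equal to 2 * 3 * 4 = 24
# partial transform -> bins [0, 0, :], each equal to 2 * 3 = 6 (axis 2 untouched)
print(full[0, 0, 0], np.count_nonzero(np.abs(full) > 1e-12))
print(partial[0, 0, :])
```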
Practical Example: Spectral Analysis of a Signal
Analyze the frequency content of a composite voltage signal.
# Generate a signal: sum of two sinusoids
n = 256
dt = 0.001 # 1 ms sample spacing -> 1000 Hz sampling rate
t = jnp.arange(n) * dt # time in seconds
# 50 Hz and 120 Hz components
signal_composed = (1.0 * jnp.sin(2 * jnp.pi * 50 * t) +
0.5 * jnp.sin(2 * jnp.pi * 120 * t)) * u.volt
print('Signal shape:', signal_composed.shape)
print('Signal unit:', signal_composed.unit)
Signal shape: (256,)
Signal unit: V
# Compute spectrum
spectrum_composed = u.fft.rfft(signal_composed)
freqs_composed = u.fft.rfftfreq(n, d=dt)
# Magnitude spectrum (absolute value of each complex bin)
magnitude = u.math.abs(spectrum_composed)
print('Frequency bins:', freqs_composed[:5], '...')
print('Magnitude at DC:', magnitude.mantissa[0])
print('Number of frequency bins:', len(freqs_composed))
Frequency bins: [ 0. 3.9062498 7.8124995 11.718749 15.624999 ] ...
Magnitude at DC: 3.6521716
Number of frequency bins: 129
Number of frequency bins: 129
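To finish the analysis, the two components can be located from the largest magnitude bins, and a power spectrum is the magnitude squared. The sketch below repeats the same signal in unit-free NumPy (an assumption for illustration; with brainunit the same steps could be applied to the .mantissa of the spectrum). The bin spacing is 1 / (256 * 0.001) = 3.90625 Hz, so the peaks land on the bins nearest 50 Hz and 120 Hz rather than exactly on them. Parseval's theorem gives a quick consistency check on the power.

```python
import numpy as np

n, dt = 256, 0.001                          # same sampling as above
t = np.arange(n) * dt
x = 1.0 * np.sin(2 * np.pi * 50 * t) + 0.5 * np.sin(2 * np.pi * 120 * t)

X = np.fft.rfft(x)
freqs = np.fft.rfftfreq(n, d=dt)
magnitude = np.abs(X)
power = magnitude**2                        # power spectrum = magnitude squared

# The two largest-magnitude bins should be the ones nearest 50 Hz and 120 Hz
top_two = np.sort(freqs[np.argsort(magnitude)[-2:]])
print("peaks near (Hz):", top_two)

# Parseval's theorem on the full FFT: sum |x|^2 == sum |X_full|^2 / n
X_full = np.fft.fft(x)
print(np.isclose(np.sum(x**2), np.sum(np.abs(X_full)**2) / n))
```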
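Summary
| Function | Unit Change | Description |
|---|---|---|
| fft | × s | Forward 1D FFT |
| ifft | ÷ s | Inverse 1D FFT |
| rfft | × s | Real 1D FFT (non-negative frequencies only) |
| irfft | ÷ s | Inverse real 1D FFT |
| fft2 | × s^2 | Forward 2D FFT |
| fftn | × s^k (k = number of transformed axes) | Forward N-D FFT |
| fftfreq, rfftfreq | dimensionless | Frequency bin values |
| fftshift, ifftshift | keeps unit | Zero-center the spectrum (ifftshift undoes it) |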