JAX Transforms: JIT and vmap with Units


brainunit Quantity objects work with JAX's core transformations, and their units are carried through each one:

  • jax.jit — Just-in-time compilation for faster execution

  • jax.vmap — Automatic vectorization (batching)

  • Composing transforms — Combine jit, vmap, and grad

import brainunit as u
import jax
import jax.numpy as jnp

jax.jit: JIT Compilation

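JIT compilation traces a function and compiles it with XLA. Later calls with the same input shapes and dtypes reuse the compiled program, which makes repeated execution faster. Quantity objects are fully supported: units are tracked while the function is traced, so the compiled result still carries its unit.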

# A physics computation
def kinetic_energy(m, v):
    return 0.5 * m * v**2

# JIT-compiled version
jit_ke = jax.jit(kinetic_energy)

m = 5.0 * u.kilogram
v = 10.0 * u.meter / u.second

result = jit_ke(m, v)
print('KE:', result)  # 250 J
KE: 250. J
# Decorator syntax also works
@jax.jit
def coulomb_force(q1, q2, r):
    k = 8.9875e9 * u.newton * u.meter**2 / u.coulomb**2
    return k * q1 * q2 / r**2

q1 = 1.6e-19 * u.coulomb  # elementary charge (magnitude)
q2 = 1.6e-19 * u.coulomb
r = 1e-10 * u.meter       # 1 angstrom

print('Coulomb force:', coulomb_force(q1, q2, r))
Coulomb force: 2.3007996e-08 N

JIT with array computations

@jax.jit
def rms_voltage(v_samples):
    """Root-mean-square voltage."""
    return u.math.sqrt(u.math.mean(v_samples**2))

samples = jnp.array([1.0, -2.0, 3.0, -1.5, 2.5]) * u.volt
print('RMS voltage:', rms_voltage(samples))
RMS voltage: 2.1213202 V

jax.vmap: Automatic Vectorization

vmap transforms a function written for a single input into one that operates on a whole batch. By default it maps over the leading axis (axis 0) of each argument, so no explicit Python loop is needed.

# Compute kinetic energy for a batch of velocities
m = 2.0 * u.kilogram
velocities = jnp.array([1., 2., 3., 4., 5.]) * u.meter / u.second

# Without vmap: would need a loop
# With vmap: automatic batching
batch_ke = jax.vmap(lambda v: 0.5 * m * v**2)
energies = batch_ke(velocities)
print('Batch KE:', energies)
Batch KE: [ 1.  4.  9. 16. 25.] J
# vmap over vector norms
def vector_norm(v):
    return u.math.sqrt(u.math.sum(v**2))

# Batch of 3D vectors
vectors = jnp.array([
    [1., 0., 0.],
    [0., 1., 0.],
    [3., 4., 0.],
    [1., 1., 1.]
]) * u.meter

norms = jax.vmap(vector_norm)(vectors)
print('Norms:', norms)
Norms: [1.        1.        5.        1.73205078] m

vmap with multiple arguments

# Ohm's law for multiple resistors: V = I * R
def ohm_law(I, R):
    return I * R

currents = jnp.array([0.1, 0.2, 0.5, 1.0]) * u.ampere
resistances = jnp.array([100., 50., 20., 10.]) * u.ohm

# vmap over both arguments
voltages = jax.vmap(ohm_law)(currents, resistances)
print('Voltages:', voltages)  # all 10 V
Voltages: [10. 10. 10. 10.] V
# vmap with in_axes: batch over currents only, same resistance for all
R_fixed = 100.0 * u.ohm
voltages_fixed_R = jax.vmap(ohm_law, in_axes=(0, None))(currents, R_fixed)
print('Voltages (fixed R):', voltages_fixed_R)
Voltages (fixed R): [ 10.  20.  50. 100.] V

vmap for matrix operations

# Apply a transformation matrix to a batch of vectors
rotation_90 = jnp.array([[0., -1.], [1., 0.]])  # dimensionless matrix: 90-degree counterclockwise rotation

def rotate(v):
    return rotation_90 @ v

points = jnp.array([[1., 0.], [0., 1.], [1., 1.], [2., 3.]]) * u.meter

rotated = jax.vmap(rotate)(points)
print('Original points:')
print(points)
print('Rotated 90 degrees:')
print(rotated)
Original points:
[[1. 0.]
 [0. 1.]
 [1. 1.]
 [2. 3.]] m
Rotated 90 degrees:
[[ 0.  1.]
 [-1.  0.]
 [-1.  1.]
 [-3.  2.]] m

Composing Transforms

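JAX transforms compose, so jit, vmap, and gradients can be nested. For gradients of functions that take or return Quantity objects, the examples below use brainunit's unit-aware u.autograd.grad.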

# jit + vmap: fast batched computation
fast_batch_ke = jax.jit(jax.vmap(lambda v: 0.5 * (2.0 * u.kilogram) * v**2))

vs = jnp.linspace(0., 10., 5) * u.meter / u.second
print('Fast batch KE:', fast_batch_ke(vs))
Fast batch KE: [  0.     6.25  25.    56.25 100.  ] J
# vmap + grad: batch of gradients
def neg_spring_energy(x):
    # Returns -U(x) = -(1/2) k x**2, so its gradient is the spring force F = -k x
    k = 100.0 * u.newton / u.meter
    return -0.5 * k * x**2

# Gradient of -U at each position gives the force F = -dU/dx
batch_grad = jax.vmap(u.autograd.grad(neg_spring_energy))

positions = jnp.array([0.0, 0.1, 0.2, 0.3, 0.4]) * u.meter
forces = batch_grad(positions)
print('Positions:', positions)
print('Forces:', forces)  # F = -kx
Positions: [0.  0.1 0.2 0.30000001 0.40000001] m
Forces: [ -0.       -10.       -20.       -30.00000191 -40.      ] N
# jit + vmap + grad: the batched gradient compiled into one program
fast_batch_grad = jax.jit(jax.vmap(u.autograd.grad(neg_spring_energy)))
Fast batch forces: [ -0.       -10.       -20.       -30.00000191 -40.      ] N

Summary

Transform                                 Purpose                                   Example
jax.jit(f)                                Compile for speed                         jit_f(5.0 * u.meter)
jax.vmap(f)                               Automatic batching                        jax.vmap(f)(batch_of_quantities)
jax.vmap(f, in_axes=(0, None))            Batch some args; None marks fixed args    jax.vmap(ohm_law, in_axes=(0, None))(currents, R_fixed)
jax.jit(jax.vmap(u.autograd.grad(f)))     Compose transforms: fast batched grads    fast_batch_grad(positions)