# JAX Transforms: JIT and vmap with Units
`brainunit` `Quantity` objects work seamlessly with JAX's core transformations:

- `jax.jit`: just-in-time compilation for faster execution
- `jax.vmap`: automatic vectorization (batching)
- Composing transforms: combining `jit`, `vmap`, and `grad`
```python
import brainunit as u
import jax
import jax.numpy as jnp
```
## `jax.jit`: JIT Compilation
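JIT compilation traces a function into an XLA computation, compiles it, and caches the compiled version, so later calls with inputs of the same shapes and dtypes skip Python-level tracing and run faster.
`Quantity` objects are fully supported: units are tracked through compilation, and the compiled function returns a `Quantity` with the correct unit.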
```python
# A physics computation
def kinetic_energy(m, v):
    return 0.5 * m * v**2

# JIT-compiled version
jit_ke = jax.jit(kinetic_energy)

m = 5.0 * u.kilogram
v = 10.0 * u.meter / u.second
result = jit_ke(m, v)
print('KE:', result)  # 250 J
```

```text
KE: 250. J
```
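The caching behaviour described above can be observed directly. The sketch below uses plain floats instead of `Quantity` objects so it runs with JAX alone, and adds a hypothetical `trace_count` counter: a Python side effect that runs only while JAX traces the function. A second call with the same input shapes and dtypes reuses the compiled code, while a new input shape triggers a retrace.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, for illustration only

def kinetic_energy(m, v):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing
    return 0.5 * m * v**2

jit_ke = jax.jit(kinetic_energy)

print(jit_ke(5.0, 10.0))                 # first call: trace + compile
print(jit_ke(2.0, 3.0))                  # same shapes/dtypes: cached code
print(jit_ke(jnp.ones(3), jnp.ones(3)))  # new input shape: retraced
print('traces:', trace_count)            # traces: 2
```

Because side effects like the counter only run at trace time, they are a handy debugging aid but should not be relied on for program logic inside jitted functions.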
```python
# Decorator syntax also works
@jax.jit
def coulomb_force(q1, q2, r):
    k = 8.9875e9 * u.newton * u.meter**2 / u.coulomb**2  # Coulomb constant
    return k * q1 * q2 / r**2

q1 = 1.6e-19 * u.coulomb  # elementary charge (magnitude), approx.
q2 = 1.6e-19 * u.coulomb
r = 1e-10 * u.meter       # 1 angstrom
print('Coulomb force:', coulomb_force(q1, q2, r))
```

```text
Coulomb force: 2.3007996e-08 N
```
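One way a library can carry units through `jit` is to register its quantity type as a JAX pytree whose numeric value is a traced leaf and whose unit is static auxiliary data. The class below, `ToyQuantity`, is a hypothetical minimal sketch of that mechanism, not brainunit's actual `Quantity` implementation: inside the jitted function the value is a tracer, while the unit stays an ordinary Python string.

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class ToyQuantity:
    """Hypothetical unit-carrying value (illustration only, NOT brainunit's Quantity)."""

    def __init__(self, value, unit):
        self.value = value  # numeric data: becomes a traced leaf under jit
        self.unit = unit    # plain string: kept as static metadata

    def tree_flatten(self):
        # children are traced; aux data (the unit) is static
        return (self.value,), self.unit

    @classmethod
    def tree_unflatten(cls, unit, children):
        return cls(children[0], unit)

@jax.jit
def double(q):
    # q.value is a tracer here; q.unit is still the string "N"
    return ToyQuantity(2 * q.value, q.unit)

out = double(ToyQuantity(jnp.array(1.5), "N"))
print(out.value, out.unit)  # 3.0 N
```

Since aux data is part of the pytree structure, it is also part of `jit`'s cache key: calling `double` with a different unit string would trigger a retrace rather than reuse the compiled code.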
### JIT with array computations
```python
@jax.jit
def rms_voltage(v_samples):
    """Root-mean-square voltage."""
    return u.math.sqrt(u.math.mean(v_samples**2))

samples = jnp.array([1.0, -2.0, 3.0, -1.5, 2.5]) * u.volt
print('RMS voltage:', rms_voltage(samples))
```

```text
RMS voltage: 2.1213202 V
```
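Working the RMS out by hand for these five samples confirms the printed value. Squaring turns volts into volts squared, and the square root brings the result back to volts:

```latex
V_{\mathrm{rms}} = \sqrt{\frac{1}{N}\sum_{i=1}^{N} v_i^2}
  = \sqrt{\frac{1^2 + (-2)^2 + 3^2 + (-1.5)^2 + 2.5^2}{5}}\ \mathrm{V}
  = \sqrt{\frac{22.5}{5}}\ \mathrm{V}
  = \sqrt{4.5}\ \mathrm{V} \approx 2.1213\ \mathrm{V}
```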
## `jax.vmap`: Automatic Vectorization
`vmap` turns a function written for a single example into one that operates on a batch, mapping over the leading axis (axis 0) of its inputs by default, with no explicit Python loop.
```python
# Compute kinetic energy for a batch of velocities
m = 2.0 * u.kilogram
velocities = jnp.array([1., 2., 3., 4., 5.]) * u.meter / u.second

# vmap maps the single-velocity function over axis 0 of `velocities`.
# (This elementwise formula would also broadcast on its own; vmap matters
# most for functions written with a single example in mind.)
batch_ke = jax.vmap(lambda v: 0.5 * m * v**2)
energies = batch_ke(velocities)
print('Batch KE:', energies)
```

```text
Batch KE: [ 1. 4. 9. 16. 25.] J
```
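To make the "no explicit loop" point concrete, the sketch below compares `jax.vmap` with a hand-written Python loop that calls the per-example function once per element and stacks the results. It uses plain JAX arrays (mass in kg, velocities in m/s, units implied rather than tracked) so it runs without brainunit; the name `ke_single` is illustrative.

```python
import jax
import jax.numpy as jnp

m = 2.0  # kg (unit implied, not tracked)

def ke_single(v):
    """Kinetic energy for one scalar velocity."""
    return 0.5 * m * v**2

velocities = jnp.array([1., 2., 3., 4., 5.])  # m/s

# Explicit loop: one call per element, then stack into one array
looped = jnp.stack([ke_single(v) for v in velocities])

# vmap: the same result from a single vectorized call
vmapped = jax.vmap(ke_single)(velocities)

print(looped)                         # [ 1.  4.  9. 16. 25.]
print(jnp.allclose(looped, vmapped))  # True
```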
```python
# vmap over vector norms
def vector_norm(v):
    return u.math.sqrt(u.math.sum(v**2))

# Batch of 3D vectors
vectors = jnp.array([
    [1., 0., 0.],
    [0., 1., 0.],
    [3., 4., 0.],
    [1., 1., 1.]
]) * u.meter

norms = jax.vmap(vector_norm)(vectors)
print('Norms:', norms)
```

```text
Norms: [1. 1. 5. 1.73205078] m
```
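As a sanity check of the arithmetic (not of unit handling), the batched norms can be compared with `jnp.linalg.norm(..., axis=1)`, which reduces over each row directly. Plain JAX arrays are used here, with values in metres implied.

```python
import jax
import jax.numpy as jnp

def vector_norm(v):
    return jnp.sqrt(jnp.sum(v**2))

vectors = jnp.array([[1., 0., 0.],
                     [0., 1., 0.],
                     [3., 4., 0.],
                     [1., 1., 1.]])  # metres (implied)

batched = jax.vmap(vector_norm)(vectors)   # maps over rows (axis 0)
direct = jnp.linalg.norm(vectors, axis=1)  # row-wise Euclidean norm

print(batched)
print(jnp.allclose(batched, direct))  # True
```

`jnp.linalg.norm` is the natural choice when a vectorized library routine already exists; `vmap` is useful when the per-row logic is custom.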
### vmap with multiple arguments
```python
# Ohm's law for multiple resistors: V = I * R
def ohm_law(I, R):
    return I * R

currents = jnp.array([0.1, 0.2, 0.5, 1.0]) * u.ampere
resistances = jnp.array([100., 50., 20., 10.]) * u.ohm

# vmap over both arguments
voltages = jax.vmap(ohm_law)(currents, resistances)
print('Voltages:', voltages)  # all 10 V
```

```text
Voltages: [10. 10. 10. 10.] V
```

```python
# vmap with in_axes: batch over currents only, same resistance for all
R_fixed = 100.0 * u.ohm
voltages_fixed_R = jax.vmap(ohm_law, in_axes=(0, None))(currents, R_fixed)
print('Voltages (fixed R):', voltages_fixed_R)
```

```text
Voltages (fixed R): [ 10. 20. 50. 100.] V
```
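`vmap` calls with different `in_axes` can also be nested to evaluate every combination of two batches. In the sketch below (plain JAX, with amperes and ohms implied), the inner `vmap` holds the current fixed and maps over resistances; the outer `vmap` maps over currents and passes the resistance array whole. Entry `[i, j]` of the result is `currents[i] * resistances[j]`. The names `row` and `voltage_table` are illustrative.

```python
import jax
import jax.numpy as jnp

def ohm_law(I, R):
    return I * R

currents = jnp.array([0.1, 0.2, 0.5, 1.0])      # A (implied)
resistances = jnp.array([100., 50., 20., 10.])  # ohm (implied)

# Inner vmap: one current, all resistances -> shape (4,)
row = jax.vmap(ohm_law, in_axes=(None, 0))
# Outer vmap: all currents, resistances passed whole -> shape (4, 4)
voltage_table = jax.vmap(row, in_axes=(0, None))(currents, resistances)

print(voltage_table.shape)  # (4, 4)
# The diagonal pairs each current with its matching resistor (about 10 V each)
print(jnp.diag(voltage_table))
```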
### vmap for matrix operations
```python
# Apply a transformation matrix to a batch of vectors
# Dimensionless matrix for a 90-degree counterclockwise rotation
rotation_90 = jnp.array([[0., -1.], [1., 0.]])

def rotate(v):
    return rotation_90 @ v

points = jnp.array([[1., 0.], [0., 1.], [1., 1.], [2., 3.]]) * u.meter
rotated = jax.vmap(rotate)(points)
print('Original points:')
print(points)
print('Rotated 90 degrees:')
print(rotated)
```

```text
Original points:
[[1. 0.]
 [0. 1.]
 [1. 1.]
 [2. 3.]] m
Rotated 90 degrees:
[[ 0. 1.]
 [-1. 0.]
 [-1. 1.]
 [-3. 2.]] m
```
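Because `rotate` is linear, mapping it over rows is equivalent to a single matrix product with the transposed matrix, `points @ rotation_90.T`. The plain-JAX sketch below (coordinates in metres, unit implied) checks that equivalence. The single matmul is the more direct formulation; `vmap` lets you keep the per-point function exactly as written.

```python
import jax
import jax.numpy as jnp

rotation_90 = jnp.array([[0., -1.], [1., 0.]])  # 90 degrees counterclockwise

def rotate(v):
    return rotation_90 @ v

points = jnp.array([[1., 0.], [0., 1.], [1., 1.], [2., 3.]])  # metres (implied)

via_vmap = jax.vmap(rotate)(points)
via_matmul = points @ rotation_90.T  # row-vector form of the same linear map

print(jnp.allclose(via_vmap, via_matmul))  # True
```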
## Composing Transforms
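JAX transforms compose: `jit`, `vmap`, and gradient transforms can be nested inside one another. For gradients of functions that return a `Quantity`, the examples below use `u.autograd.grad`, brainunit's unit-aware counterpart to `jax.grad`.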
```python
# jit + vmap: fast batched computation
fast_batch_ke = jax.jit(jax.vmap(lambda v: 0.5 * (2.0 * u.kilogram) * v**2))
vs = jnp.linspace(0., 10., 5) * u.meter / u.second
print('Fast batch KE:', fast_batch_ke(vs))
```

```text
Fast batch KE: [ 0. 6.25 25. 56.25 100. ] J
```
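The nesting order of `jit` and `vmap` does not change the numbers. `jax.jit(jax.vmap(f))` compiles the whole batched computation at once, while `jax.vmap(jax.jit(f))` applies `vmap` to an already-jitted function. The plain-JAX sketch below (mass 2 kg and velocities in m/s, units implied) checks that both orders agree. Putting `jit` outermost is a common choice because XLA then sees the full batched program.

```python
import jax
import jax.numpy as jnp

def ke(v):
    return 0.5 * 2.0 * v**2  # mass 2 kg, unit implied

vs = jnp.linspace(0., 10., 5)  # [0, 2.5, 5, 7.5, 10] m/s

outer_jit = jax.jit(jax.vmap(ke))   # jit wraps the batched function
inner_jit = jax.vmap(jax.jit(ke))   # vmap wraps the jitted function

print(outer_jit(vs))
print(jnp.allclose(outer_jit(vs), inner_jit(vs)))  # True
```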
```python
# vmap + grad: batch of gradients
def spring_force(x):
    # Returns -U(x), the negative of the spring potential energy U = 0.5*k*x**2,
    # so its gradient is the restoring force F = -dU/dx = -k*x.
    k = 100.0 * u.newton / u.meter
    return -0.5 * k * x**2

# Gradient evaluated at each position
batch_grad = jax.vmap(u.autograd.grad(spring_force))
positions = jnp.array([0.0, 0.1, 0.2, 0.3, 0.4]) * u.meter
forces = batch_grad(positions)
print('Positions:', positions)
print('Forces:', forces)  # F = -kx
```

```text
Positions: [0. 0.1 0.2 0.30000001 0.40000001] m
Forces: [ -0. -10. -20. -30.00000191 -40. ] N
```
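For comparison, the same batched gradient can be written with plain `jax.grad` on raw floats (k = 100 N/m and positions in metres, units implied). `jax.grad` expects the function to return a scalar array; for functions that return a `Quantity`, the examples above use `u.autograd.grad` instead. The name `neg_spring_energy` is illustrative.

```python
import jax
import jax.numpy as jnp

k = 100.0  # N/m, unit implied

def neg_spring_energy(x):
    # -U(x) with U = 0.5*k*x**2, so d(-U)/dx = -k*x (the restoring force)
    return -0.5 * k * x**2

positions = jnp.array([0.0, 0.1, 0.2, 0.3, 0.4])  # metres (implied)
forces = jax.vmap(jax.grad(neg_spring_energy))(positions)

print(jnp.allclose(forces, -k * positions))  # True
```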
```python
# jit + vmap + grad: compiled, batched gradients
fast_batch_grad = jax.jit(jax.vmap(u.autograd.grad(spring_force)))
print('Fast batch forces:', fast_batch_grad(positions))
```

```text
Fast batch forces: [ -0. -10. -20. -30.00000191 -40. ] N
```
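Transforms nest further. Applying `grad` twice gives the second derivative of -U(x), which for a linear spring is the constant -k at every position. The plain-JAX sketch below (k = 100 N/m, unit implied) wraps the nested gradient in `vmap` and then `jit`; the names `neg_spring_energy` and `curvature` are illustrative.

```python
import jax
import jax.numpy as jnp

k = 100.0  # N/m, unit implied

def neg_spring_energy(x):
    return -0.5 * k * x**2

# d^2(-U)/dx^2 = -k at each position, batched and compiled
curvature = jax.jit(jax.vmap(jax.grad(jax.grad(neg_spring_energy))))

positions = jnp.array([0.0, 0.1, 0.2, 0.3, 0.4])
print(curvature(positions))  # [-100. -100. -100. -100. -100.]
```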
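## Summary

| Transform | Purpose | Example |
|---|---|---|
| `jax.jit` | Compile for speed | `jax.jit(kinetic_energy)` |
| `jax.vmap` | Automatic batching | `jax.vmap(vector_norm)(vectors)` |
| `jax.vmap(..., in_axes=...)` | Batch some args | Fixed args use `None` |
| `jax.jit(jax.vmap(u.autograd.grad(f)))` | Compose transforms | Fast batched gradients |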