# Automatic Differentiation with Units
brainunit.autograd provides unit-aware automatic differentiation. When you differentiate a
function that operates on physical quantities, the derivatives automatically carry the correct
units.
Available functions:

- `grad`: gradient (first derivative)
- `value_and_grad`: compute the value and gradient simultaneously
- `jacobian` / `jacrev` / `jacfwd`: Jacobian matrix
- `hessian`: Hessian matrix (second derivatives)
- `vector_grad`: element-wise gradient for vector-valued functions
```python
import brainunit as u
import jax.numpy as jnp
```
## `grad`: First Derivative
The gradient of a function f(x) with respect to x has unit unit(f) / unit(x).
For example, if f(v) = 0.5 * m * v^2 is a kinetic energy in joules and v is in m/s,
then df/dv has unit J / (m/s) = kg * m / s, which is the unit of momentum.
```python
# Kinetic energy: KE = 0.5 * m * v^2
mass = 2.0 * u.kilogram

def kinetic_energy(v):
    return 0.5 * mass * v**2

v = 3.0 * u.meter / u.second
print('KE:', kinetic_energy(v))
```

```text
KE: 9. J
```
```python
# dKE/dv = m * v (momentum)
dKE_dv = u.autograd.grad(kinetic_energy)
print('dKE/dv:', dKE_dv(v))  # 2 * 3 = 6 kg*m/s
```

```text
dKE/dv: 6. kg * m / s
```
```python
# Gravitational potential energy: PE = m * g * h
g = 9.81 * u.meter / u.second**2

def potential_energy(h):
    return mass * g * h

# dPE/dh = m * g (force, in newtons)
dPE_dh = u.autograd.grad(potential_energy)
print('dPE/dh:', dPE_dh(10.0 * u.meter))  # 2 * 9.81 = 19.62 N
```

```text
dPE/dh: 19.62 N
```
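To see where these units come from, the rule unit(f) / unit(x) can be checked by hand by writing each unit as exponents of the SI base units. The sketch below is plain Python and purely illustrative: the dict representation and the `unit_mul` / `unit_div` helpers are hypothetical and are not how brainunit stores units internally. It confirms that J / (m/s) reduces to kg * m / s (momentum) and J / m to kg * m / s^2 (the newton).

```python
# Units as dicts of SI base-unit exponents, e.g. {'kg': 1, 'm': 2, 's': -2} for J.
# Hypothetical bookkeeping sketch, not brainunit's internal representation.

def unit_mul(a, b):
    """Multiply two units by adding their base-unit exponents."""
    out = dict(a)
    for base, exp in b.items():
        out[base] = out.get(base, 0) + exp
    return {base: exp for base, exp in out.items() if exp != 0}


def unit_div(a, b):
    """Divide unit a by unit b by subtracting b's exponents."""
    return unit_mul(a, {base: -exp for base, exp in b.items()})


JOULE = {'kg': 1, 'm': 2, 's': -2}     # kg * m^2 / s^2
METRE = {'m': 1}
VELOCITY = {'m': 1, 's': -1}           # m / s
MOMENTUM = {'kg': 1, 'm': 1, 's': -1}  # kg * m / s
NEWTON = {'kg': 1, 'm': 1, 's': -2}    # kg * m / s^2

# unit(dKE/dv) = unit(KE) / unit(v)
print(unit_div(JOULE, VELOCITY))  # {'kg': 1, 'm': 1, 's': -1}
# unit(dPE/dh) = unit(PE) / unit(h)
print(unit_div(JOULE, METRE))     # {'kg': 1, 'm': 1, 's': -2}
```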
### Differentiating with respect to specific arguments
Use argnums to select which argument to differentiate with respect to.
```python
def power_dissipated(V, R):
    """P = V^2 / R"""
    return V**2 / R

V = 12.0 * u.volt
R = 4.0 * u.ohm

# dP/dV = 2V/R
dP_dV = u.autograd.grad(power_dissipated, argnums=0)
print('dP/dV:', dP_dV(V, R))  # 2*12/4 = 6 V/ohm = 6 A (= W/V)

# dP/dR = -V^2/R^2
dP_dR = u.autograd.grad(power_dissipated, argnums=1)
print('dP/dR:', dP_dR(V, R))  # -144/16 = -9 V^2/ohm^2 = -9 A^2 (= W/ohm)
```

```text
dP/dV: 6. A
dP/dR: -9. A^2
```
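A quick numerical cross-check of the two partial derivatives, assuming nothing beyond plain Python floats: the magnitudes are taken in volts and ohms with the units stripped, and each partial is estimated by central finite differences. The `partial` helper is a hypothetical illustration, not part of brainunit.

```python
def power_dissipated(V, R):
    """P = V^2 / R, with V in volts and R in ohms as plain floats."""
    return V**2 / R


def partial(f, args, i, h=1e-6):
    """Central-difference estimate of the partial derivative of f w.r.t. args[i]."""
    plus = list(args)
    minus = list(args)
    plus[i] += h
    minus[i] -= h
    return (f(*plus) - f(*minus)) / (2 * h)


V, R = 12.0, 4.0
dP_dV = partial(power_dissipated, (V, R), 0)  # analytical: 2V/R = 6, unit V/ohm = A
dP_dR = partial(power_dissipated, (V, R), 1)  # analytical: -V^2/R^2 = -9, unit V^2/ohm^2 = A^2
print(dP_dV, dP_dR)
```

Reattaching the units by hand gives 6 A and -9 A^2, in line with the unit-aware results above.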
## `value_and_grad`: Compute Both at Once
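When you need both the function value and its gradient, value_and_grad returns both from a
single call. This is more efficient than calling the function and grad separately, because the
forward computation is evaluated only once.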
```python
val_grad_fn = u.autograd.value_and_grad(kinetic_energy)
value, gradient = val_grad_fn(3.0 * u.meter / u.second)
print('Value (KE):', value)              # 9.0 J
print('Gradient (momentum):', gradient)  # 6.0 kg*m/s
```

```text
Value (KE): 9. J
Gradient (momentum): 6. kg * m / s
```
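One way to see why a combined call can be cheaper is a forward-mode dual number, which carries a (value, derivative) pair through every arithmetic operation so that a single evaluation of `kinetic_energy` yields both results. This is a minimal sketch with a hypothetical `Dual` class that supports only the operations used here, on unitless floats (mass in kg, speed in m/s). JAX's value_and_grad is built on reverse mode instead, but the point that one pass produces both the value and the derivative carries over.

```python
class Dual:
    """Hypothetical forward-mode dual number: a value plus its derivative."""

    def __init__(self, val, der):
        self.val = val
        self.der = der

    def __mul__(self, other):
        if isinstance(other, Dual):
            # Product rule: (uv)' = u'v + uv'
            return Dual(self.val * other.val,
                        self.der * other.val + self.val * other.der)
        return Dual(self.val * other, self.der * other)  # scaling by a constant

    __rmul__ = __mul__

    def __pow__(self, n):
        # Power rule for a constant exponent: (u^n)' = n * u^(n-1) * u'
        return Dual(self.val ** n, n * self.val ** (n - 1) * self.der)


MASS = 2.0  # kg


def kinetic_energy(v):
    return 0.5 * MASS * v**2


def value_and_derivative(f, x):
    """Evaluate f once on a dual number seeded with dx/dx = 1."""
    out = f(Dual(x, 1.0))
    return out.val, out.der


value, gradient = value_and_derivative(kinetic_energy, 3.0)
print(value, gradient)  # 9.0 6.0
```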
### With auxiliary outputs
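Use has_aux=True when the function returns a pair (output, aux). Only the first element, which must be a scalar, is differentiated; the second is passed through unchanged, and the call returns ((value, aux), gradient).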
```python
def energy_with_info(v):
    ke = 0.5 * mass * v**2
    momentum = mass * v
    return ke, momentum  # (scalar to differentiate, auxiliary output)

val_grad_aux = u.autograd.value_and_grad(energy_with_info, has_aux=True)
(ke, momentum), grad = val_grad_aux(3.0 * u.meter / u.second)
print('KE:', ke)
print('Momentum (aux):', momentum)
print('Gradient:', grad)
```

```text
KE: 9. J
Momentum (aux): 6. kg * m / s
Gradient: 6. kg * m / s
```
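The has_aux contract can be mimicked in a few lines of plain Python. In this sketch, `value_and_grad_fd` is a hypothetical helper that estimates the derivative by central differences on unitless floats; only the first returned value is differentiated, and the auxiliary output comes back as-is in the same ((value, aux), gradient) layout.

```python
def value_and_grad_fd(f, has_aux=False, h=1e-6):
    """Hypothetical helper: value and central-difference derivative of a scalar f.

    With has_aux=True, f must return (output, aux); only output is
    differentiated and the result is ((output, aux), derivative).
    """
    def primal(z):
        return f(z)[0] if has_aux else f(z)

    def wrapped(x):
        derivative = (primal(x + h) - primal(x - h)) / (2 * h)
        if has_aux:
            value, aux = f(x)
            return (value, aux), derivative
        return f(x), derivative

    return wrapped


MASS = 2.0  # kg


def energy_with_info(v):
    ke = 0.5 * MASS * v**2  # J
    momentum = MASS * v     # kg * m / s, returned as the auxiliary output
    return ke, momentum


(ke, momentum), g = value_and_grad_fd(energy_with_info, has_aux=True)(3.0)
print(ke, momentum, g)
```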
## `jacobian`: Jacobian Matrix
For a function f: R^n -> R^m, the Jacobian is an m x n matrix where
J[i,j] = df_i/dx_j. The unit of each entry is unit(f_i) / unit(x_j).
```python
# Simple scalar-to-scalar: the Jacobian reduces to a scalar (the derivative)
def f(x):
    return x**3

J = u.autograd.jacobian(f)(2.0 * u.meter)
print('Jacobian of x^3 at x=2m:', J)  # 3 * 4 = 12 m^2
```

```text
Jacobian of x^3 at x=2m: 12. m^2
```

```python
# Forward-mode Jacobian (more efficient when input dim < output dim)
J_fwd = u.autograd.jacfwd(f)(2.0 * u.meter)
print('jacfwd:', J_fwd)
```

```text
jacfwd: 12. m^2
```

```python
# Reverse-mode Jacobian (more efficient when output dim < input dim)
J_rev = u.autograd.jacrev(f)(2.0 * u.meter)
print('jacrev:', J_rev)
```

```text
jacrev: 12. m^2
```
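The scalar examples hide the matrix structure, so here is a 2 x 2 case estimated with central finite differences in plain Python. The function, with x in metres and t in seconds, is hypothetical and chosen so that every Jacobian entry has a different unit; the comments list the unit unit(f_i) / unit(x_j) that a unit-aware Jacobian would be expected to attach to each entry.

```python
def f(x, t):
    """Hypothetical R^2 -> R^2 map: x in metres, t in seconds."""
    return [x * t,   # unit m * s
            x / t]   # unit m / s


def jacobian_fd(func, args, h=1e-6):
    """J[i][j] ~= d func_i / d args_j, estimated by central differences."""
    n = len(args)
    m = len(func(*args))
    J = [[0.0] * n for _ in range(m)]
    for j in range(n):
        plus = list(args)
        minus = list(args)
        plus[j] += h
        minus[j] -= h
        f_plus, f_minus = func(*plus), func(*minus)
        for i in range(m):
            J[i][j] = (f_plus[i] - f_minus[i]) / (2 * h)
    return J


J = jacobian_fd(f, (2.0, 4.0))
# Analytical values: [[t, x], [1/t, -x/t^2]] = [[4, 2], [0.25, -0.125]]
# Expected units:    [[s, m], [1/s, m/s^2]]
print(J)
```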
## `hessian`: Second Derivatives
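The Hessian matrix contains all second partial derivatives:
H[i,j] = d^2f / (dx_i * dx_j).
Unit: unit(f) / (unit(x_i) * unit(x_j)), which reduces to unit(f) / (unit(x))^2 when every input has the same unit, as in the single-input examples below.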
```python
# f(x) = x^3  ->  f''(x) = 6x
H = u.autograd.hessian(f)(2.0 * u.meter)
print('Hessian of x^3 at x=2m:', H)  # 6*2 = 12 m
```

```text
Hessian of x^3 at x=2m: 12. m
```
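With more than one input, the Hessian entries can carry different units. The plain-Python sketch below uses a hypothetical f(x, t) = x^2 * t (x in metres, t in seconds) and estimates the Hessian with a four-point central-difference stencil; the comments give each entry's unit unit(f) / (unit(x_i) * unit(x_j)).

```python
def f(x, t):
    """Hypothetical scalar function: x in metres, t in seconds, f in m^2 * s."""
    return x**2 * t


def hessian_fd(func, args, h=1e-3):
    """H[i][j] ~= d^2 func / (d args_i d args_j) via a four-point central stencil.

    On the diagonal (i == j) both shifts hit the same argument, so the stencil
    reduces to the ordinary second difference with step 2h.
    """
    n = len(args)

    def shifted(i, di, j, dj):
        a = list(args)
        a[i] += di * h
        a[j] += dj * h
        return func(*a)

    return [
        [
            (shifted(i, 1, j, 1) - shifted(i, 1, j, -1)
             - shifted(i, -1, j, 1) + shifted(i, -1, j, -1)) / (4 * h * h)
            for j in range(n)
        ]
        for i in range(n)
    ]


H = hessian_fd(f, (3.0, 2.0))
# Analytical values: [[2t, 2x], [2x, 0]] = [[4, 6], [6, 0]]
# Expected units:    [[s, m], [m, m^2/s]]
print(H)
```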
```python
# Quadratic function: U(x) = 0.5 * k * x^2 (spring potential)
k = 100.0 * u.newton / u.meter  # spring constant

def spring_energy(x):
    return 0.5 * k * x**2

# First derivative: dU/dx = k*x (magnitude of the restoring force F = -dU/dx)
print('Force:', u.autograd.grad(spring_energy)(0.1 * u.meter))
# Hessian (second derivative): d^2U/dx^2 = k (stiffness; J/m^2 is the same as N/m)
print('Stiffness:', u.autograd.hessian(spring_energy)(0.1 * u.meter))
```

```text
Force: 10. N
Stiffness: 100. J / m^2
```
## `vector_grad`: Element-wise Gradient
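For a function that maps a vector to a vector element-wise, vector_grad computes the derivative
of each output element with respect to its own input element. The result has the same shape as x,
and each entry has unit unit(f) / unit(x).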
```python
def square(x):
    return x**2

x = jnp.array([1., 2., 3., 4.]) * u.meter
vg = u.autograd.vector_grad(square)(x)
print('vector_grad of x^2:', vg)  # [2, 4, 6, 8] m
```

```text
vector_grad of x^2: [2. 4. 6. 8.] m
```
```python
# With return_value=True, get both the gradient and the function value
vg, val = u.autograd.vector_grad(square, return_value=True)(x)
print('Gradients:', vg)
print('Values:', val)
```

```text
Gradients: [2. 4. 6. 8.] m
Values: [ 1. 4. 9. 16.] m^2
```
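For an element-wise function, each output depends only on its own input, so the Jacobian is diagonal and the gradient of sum(f(x)) equals that diagonal. The plain-Python sketch below (hypothetical finite-difference helpers, metres stripped) checks this for f(x) = x^2 and recovers the same [2, 4, 6, 8] that vector_grad returned above.

```python
def square(xs):
    """Element-wise f(x_i) = x_i^2 on a list of floats (metres stripped)."""
    return [xi**2 for xi in xs]


def _bump(xs, i, h):
    """Copy of xs with element i shifted by h."""
    out = list(xs)
    out[i] += h
    return out


def grad_of_sum_fd(f, xs, h=1e-6):
    """d/dx_i of sum(f(xs)), by central differences."""
    return [(sum(f(_bump(xs, i, h))) - sum(f(_bump(xs, i, -h)))) / (2 * h)
            for i in range(len(xs))]


def jacobian_diagonal_fd(f, xs, h=1e-6):
    """d f_i / d x_i (the Jacobian diagonal), by central differences."""
    return [(f(_bump(xs, i, h))[i] - f(_bump(xs, i, -h))[i]) / (2 * h)
            for i in range(len(xs))]


xs = [1.0, 2.0, 3.0, 4.0]
via_sum = grad_of_sum_fd(square, xs)
via_diag = jacobian_diagonal_fd(square, xs)
print(via_sum, via_diag)
```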
## Practical Example: Optimization with Units
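Gradient descent on a physical objective function. Because the gradient carries units, the learning rate needs units as well: it must convert the gradient's unit back into the unit of x.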
```python
# Find the position where total energy is minimized
# E(x) = 0.5*k*x^2 + m*g*x (spring + gravity along the x-axis)
k_spring = 50.0 * u.newton / u.meter
m_obj = 1.0 * u.kilogram
g_acc = 9.81 * u.meter / u.second**2

def total_energy(x):
    return 0.5 * k_spring * x**2 + m_obj * g_acc * x

grad_E = u.autograd.grad(total_energy)

# Simple gradient descent
x = 0.0 * u.meter
lr = 0.01 * u.meter / u.newton  # m/N, so lr * grad_E(x) (in N) is a length in metres
for _ in range(100):
    x = x - lr * grad_E(x)

print('Equilibrium position:', x)  # should be ~ -m*g/k = -0.1962 m
print('Analytical solution:', -m_obj * g_acc / k_spring)
```

```text
Equilibrium position: -0.19620001 m
Analytical solution: -0.1962 m
```
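The loop converges quickly for a reason that can be worked out by hand. Since k * x* + m * g = 0 at the minimum x* = -m*g/k, one update gives x_{n+1} - x* = (1 - lr*k) * (x_n - x*). The factor lr*k has unit (m/N) * (N/m), so it is dimensionless, which is exactly why lr needs units. Here lr*k = 0.5, so every step halves the distance to the minimum; the iteration would fail to converge if lr*k were 2 or more. The plain-float replica below (SI magnitudes, units stripped, names chosen for this sketch) checks the halving.

```python
# Plain-float replica of the loop above, in SI magnitudes:
# k in N/m, m in kg, g in m/s^2, lr in m/N, x in m.
k, m, g = 50.0, 1.0, 9.81
lr = 0.01
x_star = -m * g / k  # analytical minimum: -0.1962 m

x = 0.0
errors = []
for _ in range(10):
    x = x - lr * (k * x + m * g)  # dE/dx = k*x + m*g (in N)
    errors.append(abs(x - x_star))

# Ratio of consecutive errors; the analysis predicts 1 - lr*k = 0.5.
ratios = [later / earlier for earlier, later in zip(errors, errors[1:])]
print(ratios)
```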
## Summary
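| Function | Description | Output Unit |
|---|---|---|
| `grad` | First derivative | unit(f) / unit(x) |
| `value_and_grad` | Value + first derivative | unit(f), and unit(f) / unit(x) |
| `jacobian` / `jacrev` / `jacfwd` | Jacobian matrix | unit(f_i) / unit(x_j) |
| `hessian` | Hessian matrix | unit(f) / (unit(x))^2 |
| `vector_grad` | Element-wise gradient | unit(f) / unit(x) |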