Automatic Differentiation with Units

brainunit.autograd provides unit-aware automatic differentiation. When you differentiate a function that operates on physical quantities, the derivatives automatically carry the correct units.

Available functions:

  • grad — Gradient (first derivative)

  • value_and_grad — Compute value and gradient simultaneously

  • jacobian / jacrev / jacfwd — Jacobian matrix

  • hessian — Hessian matrix (second derivatives)

  • vector_grad — Element-wise gradient for vector-valued functions

import brainunit as u
import jax.numpy as jnp

grad — First Derivative

The gradient of a function f(x) with respect to x has unit unit(f) / unit(x).

For example, if f(v) = 0.5 * m * v^2 is kinetic energy in joules (J) and v is in m/s, then df/dv has unit J / (m/s) = kg * m/s, which is the unit of momentum.

# Kinetic energy: KE = 0.5 * m * v^2
mass = 2.0 * u.kilogram

def kinetic_energy(v):
    return 0.5 * mass * v**2

v = 3.0 * u.meter / u.second
print('KE:', kinetic_energy(v))
KE: 9. J
# dKE/dv = m * v  (momentum)
dKE_dv = u.autograd.grad(kinetic_energy)
print('dKE/dv:', dKE_dv(v))  # 2 * 3 = 6 kg*m/s
dKE/dv: 6. kg * m / s
# Gravitational potential energy: PE = m * g * h
g = 9.81 * u.meter / u.second**2

def potential_energy(h):
    return mass * g * h

# dPE/dh = m * g (force, in newtons)
dPE_dh = u.autograd.grad(potential_energy)
print('dPE/dh:', dPE_dh(10.0 * u.meter))  # 2 * 9.81 = 19.62 N
dPE/dh: 19.62 N

Differentiating with respect to specific arguments

Use argnums to choose, by position, which argument to differentiate with respect to. The default is argnums=0.

def power_dissipated(V, R):
    """P = V^2 / R"""
    return V**2 / R

V = 12.0 * u.volt
R = 4.0 * u.ohm

# dP/dV = 2V/R
dP_dV = u.autograd.grad(power_dissipated, argnums=0)
print('dP/dV:', dP_dV(V, R))  # 2*12/4 = 6 V/ohm = 6 A (equivalently W/V)

# dP/dR = -V^2/R^2
dP_dR = u.autograd.grad(power_dissipated, argnums=1)
print('dP/dR:', dP_dR(V, R))  # -144/16 = -9 V^2/ohm^2 = -9 W/ohm = -9 A^2
dP/dV: 6. A
dP/dR: -9. A^2

value_and_grad — Compute Both at Once

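When you need both the function value and its gradient, value_and_grad is more efficient than calling them separately. It obtains both from a single evaluation of the function, whereas calling f and then grad(f) evaluates f twice.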

val_grad_fn = u.autograd.value_and_grad(kinetic_energy)

value, gradient = val_grad_fn(3.0 * u.meter / u.second)
print('Value (KE):', value)     # 9.0 J
print('Gradient (momentum):', gradient)  # 6.0 kg*m/s
Value (KE): 9. J
Gradient (momentum): 6. kg * m / s

With auxiliary outputs

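Use has_aux=True when the function returns a pair (output, aux). Only output, which must be a scalar, is differentiated; aux is passed through unchanged. value_and_grad then returns ((value, aux), gradient).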

def energy_with_info(v):
    ke = 0.5 * mass * v**2
    momentum = mass * v
    return ke, momentum  # (scalar to diff, auxiliary)

val_grad_aux = u.autograd.value_and_grad(energy_with_info, has_aux=True)
(ke, momentum), grad = val_grad_aux(3.0 * u.meter / u.second)

print('KE:', ke)
print('Momentum (aux):', momentum)
print('Gradient:', grad)
KE: 9. J
Momentum (aux): 6. kg * m / s
Gradient: 6. kg * m / s

jacobian — Jacobian Matrix

For a function f: R^n -> R^m, the Jacobian is an m x n matrix where J[i,j] = df_i/dx_j. The unit of each entry is unit(f_i) / unit(x_j).

# Simple scalar-to-scalar: Jacobian reduces to a scalar (the derivative)
def f(x):
    return x**3

J = u.autograd.jacobian(f)(2.0 * u.meter)
print('Jacobian of x^3 at x=2m:', J)  # 3 * 4 = 12 m^2
Jacobian of x^3 at x=2m: 12. m^2
# Forward-mode Jacobian (more efficient when input dim < output dim)
J_fwd = u.autograd.jacfwd(f)(2.0 * u.meter)
print('jacfwd:', J_fwd)
jacfwd: 12. m^2
# Reverse-mode Jacobian (more efficient when output dim < input dim)
J_rev = u.autograd.jacrev(f)(2.0 * u.meter)
print('jacrev:', J_rev)
jacrev: 12. m^2

hessian — Second Derivatives

The Hessian matrix contains all second partial derivatives: H[i,j] = d^2f / (dx_i * dx_j).

Unit of each entry: unit(f) / (unit(x_i) * unit(x_j)), which reduces to unit(f) / (unit(x))^2 when all inputs share one unit.

# f(x) = x^3  ->  f''(x) = 6x
H = u.autograd.hessian(f)(2.0 * u.meter)
print('Hessian of x^3 at x=2m:', H)  # 6*2 = 12 m
Hessian of x^3 at x=2m: 12. m
# Quadratic function: f(x) = 0.5 * k * x^2  (spring potential)
k = 100.0 * u.newton / u.meter  # spring constant

def spring_energy(x):
    return 0.5 * k * x**2

# First derivative: dU/dx = k*x (force)
print('Force:', u.autograd.grad(spring_energy)(0.1 * u.meter))

# Hessian (second derivative): d^2U/dx^2 = k (stiffness)
print('Stiffness:', u.autograd.hessian(spring_energy)(0.1 * u.meter))
Force: 10. N
Stiffness: 100. J / m^2

vector_grad — Element-wise Gradient

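For a function that maps a vector to a vector element-wise, vector_grad computes the derivative of each output element with respect to its own input element, df_i/dx_i, which is the diagonal of the Jacobian. Each entry has unit unit(f) / unit(x).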

def square(x):
    return x**2

x = jnp.array([1., 2., 3., 4.]) * u.meter
vg = u.autograd.vector_grad(square)(x)
print('vector_grad of x^2:', vg)  # [2, 4, 6, 8] m
vector_grad of x^2: [2. 4. 6. 8.] m
# With return_value=True, get both the gradient and the function value
vg, val = u.autograd.vector_grad(square, return_value=True)(x)
print('Gradients:', vg)
print('Values:', val)
Gradients: [2. 4. 6. 8.] m
Values: [ 1.  4.  9. 16.] m^2

Practical Example: Optimization with Units

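Gradient descent on a physical objective function. The gradient carries units (newtons here, since E is in joules and x in meters), so the learning rate must carry units too: lr * grad_E(x) has to come out in meters so it can be subtracted from x.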

# Find the position where total energy is minimized
# E(x) = 0.5*k*x^2 + m*g*x  (spring + gravity along x-axis)
k_spring = 50.0 * u.newton / u.meter
m_obj = 1.0 * u.kilogram
g_acc = 9.81 * u.meter / u.second**2

def total_energy(x):
    return 0.5 * k_spring * x**2 + m_obj * g_acc * x

grad_E = u.autograd.grad(total_energy)

# Simple gradient descent
x = 0.0 * u.meter
lr = 0.01 * u.meter / u.newton  # learning rate in m/N, so lr * grad_E(x) is in meters

for i in range(100):
    x = x - lr * grad_E(x)

print('Equilibrium position:', x)  # should be ~ -m*g/k = -0.1962 m
print('Analytical solution:', -m_obj * g_acc / k_spring)
Equilibrium position: -0.19620001 m
Analytical solution: -0.1962 m
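Writing the update out explains both the unit of the learning rate and why 100 steps are more than enough. Here eta denotes lr and x* the analytical minimum:

```latex
\begin{aligned}
E'(x) &= kx + mg, \qquad x^{*} = -\frac{mg}{k} \\
x_{n+1} &= x_n - \eta\,E'(x_n) = x_n - \eta\,(k x_n + m g) \\
x_{n+1} - x^{*} &= (1 - \eta k)\,(x_n - x^{*}) \\
\eta k &= 0.01\,\mathrm{m/N} \times 50\,\mathrm{N/m} = 0.5
\end{aligned}
```

The factor 1 - eta*k must be dimensionless, which is exactly why lr is given in m/N. The iteration converges whenever 0 < eta*k < 2. Here eta*k = 0.5, so the distance to x* halves on every step, and after 100 steps the remaining error is far below floating-point precision. The small discrepancy in the last printed digit comes from rounding, not from unfinished convergence.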

Summary

| Function          | Description              | Output unit                   |
|-------------------|--------------------------|-------------------------------|
| grad(f)           | First derivative         | unit(f) / unit(x)             |
| value_and_grad(f) | Value + first derivative | (unit(f), unit(f) / unit(x))  |
| jacobian(f)       | Jacobian matrix          | unit(f_i) / unit(x_j)         |
| hessian(f)        | Hessian matrix           | unit(f) / unit(x)^2           |
| vector_grad(f)    | Element-wise gradient    | unit(f) / unit(x)             |