Linear Algebra Functions
brainunit.linalg provides unit-aware linear algebra functions. These functions automatically
track how units propagate through linear algebra operations:
Keeping unit: trace, diagonal, norm, matrix_transpose
Changing unit: dot, matmul, cross, det, inv, solve, cholesky, kron, matrix_power, lstsq, pinv
Removing unit: eig, eigh, svd, qr, matrix_rank, slogdet
import brainunit as u
import jax.numpy as jnp
Functions That Keep Unit
These functions preserve the input unit in the output.
trace: Sum of diagonal elements
A = jnp.array([[1., 2.], [3., 4.]]) * u.volt
print('A:')
print(A)
print('trace(A):', u.linalg.trace(A)) # 1 + 4 = 5 V
A:
[[1. 2.]
[3. 4.]] V
trace(A): 5. V
diagonal: Extract diagonal elements
print('diagonal(A):', u.linalg.diagonal(A)) # [1, 4] V
diagonal(A): [1. 4.] V
norm: Vector and matrix norms
v = jnp.array([3., 4.]) * u.meter
print('v:', v)
print('L2 norm:', u.linalg.norm(v)) # sqrt(9+16) = 5 m
print('L1 norm:', u.linalg.norm(v, ord=1)) # 3+4 = 7 m
print('Inf norm:', u.linalg.norm(v, ord=jnp.inf)) # max(3,4) = 4 m
v: [3. 4.] m
L2 norm: 5. m
L1 norm: 7. m
Inf norm: 4. m
# Matrix norms
print('Frobenius norm:', u.linalg.norm(A)) # sqrt(1+4+9+16) V
print('Matrix L2 norm:', u.linalg.matrix_norm(A, ord=2)) # largest singular value
Frobenius norm: 5.477226 V
Matrix L2 norm: 5.464986 V
matrix_transpose: Transpose
B = jnp.array([[1., 2., 3.], [4., 5., 6.]]) * u.ampere
print('B shape:', B.shape)
print('B^T shape:', u.linalg.matrix_transpose(B).shape)
print('B^T:')
print(u.linalg.matrix_transpose(B))
B shape: (2, 3)
B^T shape: (3, 2)
B^T:
[[1. 4.]
[2. 5.]
[3. 6.]] A
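The four functions above keep the unit because each one is homogeneous of degree one. The minimal check below assumes that a unit change simply multiplies every mantissa by one positive constant c (volts to millivolts here). It uses plain jax.numpy on unitless arrays rather than brainunit. If f(c * A) equals c * f(A), the output is measured in the same unit as the input.

```python
import jax.numpy as jnp

# The same 2x2 matrix written in volts and in millivolts: a unit change
# multiplies every mantissa by one constant c (1 V = 1000 mV).
A_volt = jnp.array([[1., 2.], [3., 4.]])
c = 1000.0
A_millivolt = c * A_volt

# For each function f, compare f(c*A) with c*f(A). Equality means f is
# homogeneous of degree one, so its output is in the input's unit.
pairs = {
    "trace": (jnp.trace(A_millivolt), c * jnp.trace(A_volt)),
    "diagonal": (jnp.diagonal(A_millivolt), c * jnp.diagonal(A_volt)),
    "Frobenius norm": (jnp.linalg.norm(A_millivolt), c * jnp.linalg.norm(A_volt)),
    "matrix 2-norm": (jnp.linalg.norm(A_millivolt, ord=2),
                      c * jnp.linalg.norm(A_volt, ord=2)),
    "transpose": (jnp.transpose(A_millivolt), c * jnp.transpose(A_volt)),
}

for name, (scaled, expected) in pairs.items():
    same = bool(jnp.allclose(scaled, expected, rtol=1e-5))
    print(f"{name}: f(c*A) == c*f(A)? {same}")
```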
Functions That Change Unit
These functions produce outputs with different units than the inputs, following the mathematical rules of the operation.
dot, matmul: Dot and matrix products
When multiplying quantities, units multiply: [meter] @ [second] → [meter * second]
# Vector dot product
a = jnp.array([1., 2., 3.]) * u.meter
b = jnp.array([4., 5., 6.]) * u.newton
print('a . b:', u.linalg.dot(a, b)) # 1*4 + 2*5 + 3*6 = 32 m*N (= 32 J)
a . b: 32. J
# Matrix-vector multiplication
M = jnp.array([[1., 0.], [0., 2.]]) * u.ohm
i = jnp.array([3., 4.]) * u.ampere
print('M @ i:', u.linalg.matmul(M, i)) # [3, 8] V (Ohm's law: V = IR)
M @ i: [3. 8.] V
# Matrix-matrix multiplication: units multiply
P = jnp.array([[1., 2.], [3., 4.]]) * u.meter
Q = jnp.array([[5., 6.], [7., 8.]]) * u.meter
print('P @ Q:')
print(u.linalg.matmul(P, Q)) # m * m = m^2
P @ Q:
[[19. 22.]
[43. 50.]] m^2
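The rule that units multiply follows from bilinearity. dot and matmul, and also cross and kron further down, are each linear in each operand separately. The sketch below illustrates this in plain jax.numpy with hypothetical unit changes: m to cm (c1 = 100) and N to mN (c2 = 1000). It checks that rescaling the operands rescales the product by c1 * c2.

```python
import jax.numpy as jnp

# Mantissas: a in meters, b in newtons; P and Q are the matrices from above.
a = jnp.array([1., 2., 3.])
b = jnp.array([4., 5., 6.])
P = jnp.array([[1., 2.], [3., 4.]])
Q = jnp.array([[5., 6.], [7., 8.]])

# Independent unit changes for the two operands, e.g. m -> cm and N -> mN.
c1, c2 = 100.0, 1000.0

# Each product is linear in each operand, so the scale factors multiply:
# f(c1*x, c2*y) == c1*c2*f(x, y). This is the "units multiply" rule.
results = {
    "dot": jnp.allclose(jnp.dot(c1 * a, c2 * b), c1 * c2 * jnp.dot(a, b)),
    "cross": jnp.allclose(jnp.cross(c1 * a, c2 * b), c1 * c2 * jnp.cross(a, b)),
    "matmul": jnp.allclose(jnp.matmul(c1 * P, c2 * Q), c1 * c2 * jnp.matmul(P, Q)),
    "kron": jnp.allclose(jnp.kron(c1 * P, c2 * Q), c1 * c2 * jnp.kron(P, Q)),
}
for name, ok in results.items():
    print(name, bool(ok))
```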
cross: Cross product
# Force = charge * (velocity x B_field)
velocity = jnp.array([1., 0., 0.]) * u.meter / u.second
b_field = jnp.array([0., 0., 1.]) * u.tesla
print('v x B:', u.linalg.cross(velocity, b_field)) # [0, -1, 0] m*T/s
v x B:
[ 0. -1. 0.] m * T / s
det: Determinant
For an NxN matrix with unit u, the determinant has unit u^N.
M2 = jnp.array([[2., 1.], [1., 3.]]) * u.meter
print('det(M2):', u.linalg.det(M2)) # 2*3 - 1*1 = 5 m^2
det(M2): 5. m^2
# 3x3 determinant: unit^3
M3 = jnp.array([[1., 0., 2.], [0., 1., 0.], [2., 0., 1.]]) * u.second
print('det(M3):', u.linalg.det(M3)) # s^3
det(M3): -3. s^3
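The u^N rule can be checked the same way. The sketch below assumes a unit change from s to ms (c = 1000) applied to the mantissa of M3. It compares det(c * M) with c**N * det(M) in plain jax.numpy. Because the determinant is linear in each of the N rows, the factor c appears N times.

```python
import jax.numpy as jnp

# Mantissa of M3 from above (unit: seconds) and a unit change s -> ms.
M3 = jnp.array([[1., 0., 2.], [0., 1., 0.], [2., 0., 1.]])
c = 1000.0
N = M3.shape[0]

# det is linear in each of the N rows, so scaling every row by c
# scales det by c**N: the unit of det is unit**N.
det_s = jnp.linalg.det(M3)        # in s^3
det_ms = jnp.linalg.det(c * M3)   # in ms^3
print(det_s, det_ms, c**N * det_s)
```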
inv: Matrix inverse
The inverse of a matrix with unit u has unit 1/u.
R = jnp.array([[2., 1.], [1., 3.]]) * u.ohm
print('R:')
print(R)
print('inv(R):')
print(u.linalg.inv(R)) # unit: 1/ohm = siemens
R:
[[2. 1.]
[1. 3.]] ohm
inv(R):
[[ 0.60000002 -0.2]
[-0.2 0.40000001]] S
# Verify: R @ inv(R) should be identity (dimensionless)
R_inv = u.linalg.inv(R)
print('R @ R_inv:')
print(u.linalg.matmul(R, R_inv)) # approximately identity
R @ R_inv:
[[1.0000000e+00 0.0000000e+00]
[1.4901161e-08 1.0000000e+00]]
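A scaling check with jax.numpy shows where the reciprocal unit comes from: inverting c * R gives inv(R) / c. The same sketch applies the check to jnp.linalg.pinv on the rectangular matrix from the pinv example further down. The factor c = 1000 (ohm to milliohm) is an arbitrary illustration.

```python
import jax.numpy as jnp

R = jnp.array([[2., 1.], [1., 3.]])                  # ohm
M_rect = jnp.array([[1., 2.], [3., 4.], [5., 6.]])   # volt
c = 1000.0                                           # e.g. ohm -> milliohm

# inv(c*R) == inv(R) / c: the scale factor is inverted, and so is the unit.
inv_ok = jnp.allclose(jnp.linalg.inv(c * R), jnp.linalg.inv(R) / c, rtol=1e-4)

# The pseudoinverse of a rectangular matrix follows the same reciprocal rule.
pinv_ok = jnp.allclose(jnp.linalg.pinv(c * M_rect), jnp.linalg.pinv(M_rect) / c,
                       rtol=1e-4, atol=1e-7)
print(bool(inv_ok), bool(pinv_ok))
```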
solve: Solve linear system Ax = b
If A has unit u_A and b has unit u_b, then x has unit u_b / u_A.
# Physical example: Ohm's law in a circuit network
# R * I = V --> I = solve(R, V)
R_matrix = jnp.array([[10., 2.], [2., 8.]]) * u.ohm
V_vector = jnp.array([12., 6.]) * u.volt
I_solution = u.linalg.solve(R_matrix, V_vector)
print('Current solution:', I_solution) # volt / ohm = ampere
Current solution: [1.10526311 0.47368422] A
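The u_b / u_A rule can be verified by changing the units of A and b independently. This plain jax.numpy sketch assumes ohm to kiloohm (mantissa times 1e-3) and volt to millivolt (mantissa times 1e3). The solution should then scale by 1e3 / 1e-3 = 1e6, which is exactly the factor from ampere to microampere.

```python
import jax.numpy as jnp

R_matrix = jnp.array([[10., 2.], [2., 8.]])   # ohm
V_vector = jnp.array([12., 6.])               # volt

# Independent unit changes: ohm -> kiloohm (x 1e-3), volt -> millivolt (x 1e3).
c_A, c_b = 1e-3, 1e3

I_amp = jnp.linalg.solve(R_matrix, V_vector)
I_rescaled = jnp.linalg.solve(c_A * R_matrix, c_b * V_vector)

# x scales by c_b / c_A = 1e6: mV / kOhm = microampere, and 1 A = 1e6 uA.
print(I_amp, I_rescaled, (c_b / c_A) * I_amp)
```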
cholesky: Cholesky decomposition
For a matrix with unit u, the Cholesky factor has unit sqrt(u),
since L @ L^T = A and sqrt(u) * sqrt(u) = u.
# Positive definite matrix
S = jnp.array([[4., 2.], [2., 5.]]) * u.meter2
L = u.linalg.cholesky(S)
print('S:')
print(S)
print('cholesky(S):')
print(L) # unit: m (sqrt of m^2)
S:
[[4. 2.]
[2. 5.]] m^2
cholesky(S):
[[2. 0.]
[1. 2.]] m
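A quick jax.numpy check of the sqrt(u) rule: expressing S in cm^2 instead of m^2 multiplies the mantissa by c = 1e4. The Cholesky factor should then scale by sqrt(c) = 100, which is the m to cm factor. The sketch also confirms that the lower-triangular factor reproduces S.

```python
import jax.numpy as jnp

S = jnp.array([[4., 2.], [2., 5.]])   # m^2
c = 1e4                                # m^2 -> cm^2, since 1 m^2 = 1e4 cm^2

L_m = jnp.linalg.cholesky(S)           # lower-triangular factor, in m
L_cm = jnp.linalg.cholesky(c * S)      # in cm

# The factor scales by sqrt(c) = 100, the m -> cm factor: sqrt(m^2) = m.
print(L_m)
print(L_cm)
# L @ L^T reproduces the original matrix.
print(L_m @ L_m.T)
```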
kron: Kronecker product
X = jnp.array([[1., 0.], [0., 1.]]) * u.meter
Y = jnp.array([[2., 3.], [4., 5.]]) * u.second
print('kron(X, Y):')
print(u.linalg.kron(X, Y)) # m * s
kron(X, Y):
[[2. 3. 0. 0.]
[4. 5. 0. 0.]
[0. 0. 2. 3.]
[0. 0. 4. 5.]] m * s
matrix_power: Raise matrix to integer power
T = jnp.array([[1., 1.], [0., 1.]]) * u.meter
print('T^2:', u.linalg.matrix_power(T, 2)) # m^2
print('T^3:', u.linalg.matrix_power(T, 3)) # m^3
T^2: [[1. 2.]
[0. 1.]] m^2
T^3: [[1. 3.]
[0. 1.]] m^3
lstsq: Least-squares solution
# Fitting y = ax + b where x is in seconds and y in meters
# Design matrix has columns [x, 1]
x_data = jnp.array([0., 1., 2., 3., 4.]) * u.second
y_data = jnp.array([1.1, 2.9, 5.2, 6.8, 9.1]) * u.meter
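# A Quantity has a single unit, so both design-matrix columns get unit s here
# (the ones column is tagged with s only to satisfy this requirement)
A_design = jnp.stack([x_data.mantissa, jnp.ones(5)], axis=1) * u.second
result = u.linalg.lstsq(A_design, y_data)
print('Coefficients:', result[0])  # [slope, intercept], both reported in m / s;
# the slope's unit is right, but the intercept is physically about 1.04 m, not m / s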
Coefficients: [1.99000037 1.03999949] m / s
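Because the single design matrix above cannot hold seconds in one column and a dimensionless 1 in the other, both fitted coefficients come back in m / s. One workaround is sketched here in plain jax.numpy; it is not a brainunit feature. Fit the mantissas with a dimensionless design matrix, then read each coefficient's unit off the model y = slope * x + intercept.

```python
import jax.numpy as jnp

# Mantissas of the data above: x in seconds, y in meters.
x_s = jnp.array([0., 1., 2., 3., 4.])
y_m = jnp.array([1.1, 2.9, 5.2, 6.8, 9.1])

# Dimensionless design matrix: columns x / (1 s) and 1.
A = jnp.stack([x_s, jnp.ones_like(x_s)], axis=1)
coef, residuals, rank, sing_vals = jnp.linalg.lstsq(A, y_m)

# Read the units off the model y = slope * x + intercept:
slope = float(coef[0])       # multiplies x in s, so it is in m/s
intercept = float(coef[1])   # multiplies the dimensionless 1, so it is in m
print(f"slope = {slope:.3f} m/s, intercept = {intercept:.3f} m")
```

With brainunit, the two numbers could then be wrapped as coef[0] * u.meter / u.second and coef[1] * u.meter. This uses the same multiplication syntax as elsewhere on this page.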
pinv: Pseudoinverse
M_rect = jnp.array([[1., 2.], [3., 4.], [5., 6.]]) * u.volt
print('M shape:', M_rect.shape)
print('pinv(M) shape:', u.linalg.pinv(M_rect).shape)
print('pinv(M):')
print(u.linalg.pinv(M_rect)) # unit: 1/V
M shape: (3, 2)
pinv(M) shape: (2, 3)
pinv(M):
[[-1.33333182 -0.3333317 0.6666652 ]
[ 1.08333194 0.33333197 -0.41666541]] 1 / V
Functions That Remove Unit (Decompositions)
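Decompositions split a matrix into dimensionless factors (orthogonal matrices such as U, Vt and Q, and the eigenvectors) plus one factor that carries the unit: the singular values in svd, the eigenvalues in eig/eigh, and the triangular factor R in qr. Scalar summaries such as matrix_rank and slogdet return plain numbers.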
svd: Singular Value Decomposition
M_svd = jnp.array([[1., 2.], [3., 4.]]) * u.meter
U, S, Vt = u.linalg.svd(M_svd)
print('U (dimensionless):', U)
print('S (with unit):', S) # singular values carry the unit
print('Vt (dimensionless):', Vt)
U (dimensionless): [[-0.40455365 -0.91451436]
[-0.9145144 0.4045536 ]]
S (with unit): [5.46498537 0.36596611] m
Vt (dimensionless): [[-0.5760485 -0.81741554]
[ 0.81741554 -0.5760485 ]]
eig / eigh: Eigenvalue decomposition
# Symmetric matrix (use eigh for better numerical stability)
H = jnp.array([[2., 1.], [1., 3.]]) * u.joule
eigenvalues, eigenvectors = u.linalg.eigh(H)
print('Eigenvalues:', eigenvalues) # J
print('Eigenvectors (dimensionless):')
print(eigenvectors)
Eigenvalues: [1.38196611 3.61803389] J
Eigenvectors (dimensionless):
[[-0.85065085 0.52573115]
[ 0.52573115 0.85065085]]
qr: QR decomposition
M_qr = jnp.array([[1., 2.], [3., 4.], [5., 6.]]) * u.meter
Q_mat, R_mat = u.linalg.qr(M_qr)
print('Q (dimensionless):')
print(Q_mat)
print('R (with unit):')
print(R_mat) # upper triangular, carries the unit
Q (dimensionless):
[[-0.1690309 0.89708555]
[-0.5070926 0.27602556]
[-0.84515435 -0.34503233]]
R (with unit):
[[-5.91607952 -7.43735886]
[ 0. 0.82808006]] m
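In the svd, eigh and qr examples, the factor that carries the unit scales with the input, while the orthogonal factors stay orthonormal at any scale. The sketch below checks this in plain jax.numpy with a hypothetical scale factor c = 100. It compares absolute values of R so that a sign convention in the QR routine cannot affect the result.

```python
import jax.numpy as jnp

c = 100.0   # e.g. m -> cm; any positive scale factor works

# SVD: singular values scale with c.
M = jnp.array([[1., 2.], [3., 4.]])
s = jnp.linalg.svd(M, compute_uv=False)
s_c = jnp.linalg.svd(c * M, compute_uv=False)

# Symmetric eigendecomposition: eigenvalues scale with c.
H = jnp.array([[2., 1.], [1., 3.]])
w = jnp.linalg.eigvalsh(H)
w_c = jnp.linalg.eigvalsh(c * H)

# QR: Q keeps orthonormal columns at any scale, so R absorbs the factor.
M_qr = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
Q, R = jnp.linalg.qr(M_qr)
Q_c, R_c = jnp.linalg.qr(c * M_qr)

print("singular values scale by c:", bool(jnp.allclose(s_c, c * s, rtol=1e-4)))
print("eigenvalues scale by c:", bool(jnp.allclose(w_c, c * w, rtol=1e-4)))
# Compare magnitudes so that a sign convention in QR cannot matter.
print("|R| scales by c:", bool(jnp.allclose(jnp.abs(R_c), c * jnp.abs(R),
                                            rtol=1e-4, atol=1e-3)))
print("Q_c is orthonormal:", bool(jnp.allclose(Q_c.T @ Q_c, jnp.eye(2), atol=1e-5)))
```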
cond: Condition number
brainunit.linalg has no cond function, at least in the brainunit version used for this page. Calling u.linalg.cond raises AttributeError: module 'brainunit.linalg' has no attribute 'cond'. The condition number is a ratio of singular values and is therefore dimensionless, so it can be computed from the mantissa with jax.numpy:
well_cond = jnp.array([[1., 0.], [0., 1.]]) * u.meter
ill_cond = jnp.array([[1., 1.], [1., 1.0001]]) * u.meter
print('Well-conditioned:', jnp.linalg.cond(well_cond.mantissa))  # 1 for the identity
print('Ill-conditioned:', jnp.linalg.cond(ill_cond.mantissa))    # roughly 4e4: nearly singular
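However it is computed, the condition number is the ratio of the largest to the smallest singular value, so any common scale factor cancels out. This jax.numpy sketch computes it both from the singular values and with jnp.linalg.cond, before and after a hypothetical m to mm conversion (c = 1000). The matrix is the one from the svd example.

```python
import jax.numpy as jnp

M = jnp.array([[1., 2.], [3., 4.]])   # mantissa, e.g. in meters
c = 1000.0                             # m -> mm

s = jnp.linalg.svd(M, compute_uv=False)   # singular values, descending
cond_from_sv = s[0] / s[-1]               # largest / smallest
cond_m = jnp.linalg.cond(M)
cond_mm = jnp.linalg.cond(c * M)          # same matrix expressed in mm
print(cond_from_sv, cond_m, cond_mm)
```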
matrix_rank
full_rank = jnp.array([[1., 2.], [3., 4.]]) * u.meter
rank_def = jnp.array([[1., 2.], [2., 4.]]) * u.meter # row 2 = 2 * row 1
print('Full rank matrix rank:', u.linalg.matrix_rank(full_rank))
print('Rank-deficient matrix rank:', u.linalg.matrix_rank(rank_def))
Full rank matrix rank: 2
Rank-deficient matrix rank: 1
slogdet: Sign and log of determinant
M_det = jnp.array([[2., 1.], [1., 3.]]) * u.meter
sign, logabsdet = u.linalg.slogdet(M_det)
print('Sign:', sign)
print('Log |det|:', logabsdet) # dimensionless
Sign: 1.0
Log |det|: 1.609438
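The log |det| printed above is log(5) ≈ 1.609, the log of the determinant's mantissa in m^2. It has no unit, but its value still depends on the unit chosen. Rewriting the matrix in cm multiplies det by c^N, which adds N * log(c) to log |det| and leaves the sign unchanged. A short jax.numpy sketch of that shift, with c = 100:

```python
import jax.numpy as jnp

M_det = jnp.array([[2., 1.], [1., 3.]])   # meters; det = 5 m^2
c = 100.0                                  # m -> cm
N = M_det.shape[0]

sign_m, logabsdet_m = jnp.linalg.slogdet(M_det)
sign_cm, logabsdet_cm = jnp.linalg.slogdet(c * M_det)

# The sign does not depend on the unit, but log|det| shifts by N * log(c):
# it is the log of 5 (det in m^2) versus the log of 50000 (det in cm^2).
shift = logabsdet_cm - logabsdet_m
print(sign_m, sign_cm)
print(logabsdet_m, logabsdet_cm, shift, N * jnp.log(c))
```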
Practical Example: Solving a Physical System
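Applying Kirchhoff's current law at each node of a resistor network (nodal analysis) gives a linear system G @ V = I. Here G is the conductance matrix, V holds the node voltages and I holds the currents injected at each node.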
# Three-node resistor network
# Conductance matrix G (siemens) and current sources I (ampere)
G = jnp.array([
[0.3, -0.1, -0.1],
[-0.1, 0.4, -0.2],
[-0.1, -0.2, 0.5]
]) * u.siemens
I_source = jnp.array([1.0, 0.0, -0.5]) * u.ampere
# Solve for node voltages: G @ V = I --> V = solve(G, I)
V_nodes = u.linalg.solve(G, I_source)
print('Node voltages:', V_nodes) # ampere / siemens = volt
# Verify: G @ V should equal I
print('Verification (G @ V):', u.linalg.matmul(G, V_nodes))
Node voltages: [3.71428561 1.00000012 0.14285721] V
Verification (G @ V): [ 9.9999994e-01 3.3101866e-08 -5.0000000e-01] A
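The conductance matrix G can be assembled from the network itself. In nodal analysis, a resistor with conductance g between nodes i and j adds g to G[i, i] and G[j, j] and subtracts g from G[i, j] and G[j, i]. A resistor from node i to ground only adds g to G[i, i]. The original does not describe the circuit, so the branch list in this sketch is inferred from G. The off-diagonal entries give 0.1, 0.1 and 0.2 S between node pairs 1-2, 1-3 and 2-3. The row sums give 0.1, 0.1 and 0.2 S from each node to ground. The sketch uses plain jax.numpy, without units.

```python
import jax.numpy as jnp

# Hypothetical branch list inferred from G above (0-based node indices).
# Node-to-node conductances, in siemens:
branches = [(0, 1, 0.1), (0, 2, 0.1), (1, 2, 0.2)]
# Conductance from each node to ground, in siemens:
to_ground = [0.1, 0.1, 0.2]

n = len(to_ground)
G = jnp.zeros((n, n))
for i, j, g in branches:
    # "Stamp" each resistor: +g on both diagonal entries, -g off the diagonal.
    G = G.at[i, i].add(g).at[j, j].add(g)
    G = G.at[i, j].add(-g).at[j, i].add(-g)
for k, g in enumerate(to_ground):
    G = G.at[k, k].add(g)

I_source = jnp.array([1.0, 0.0, -0.5])   # amperes injected at each node
V_nodes = jnp.linalg.solve(G, I_source)  # volts
print(G)
print(V_nodes)
```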
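Summary
| Function | Unit Behavior | Example |
|---|---|---|
| trace, diagonal, norm, matrix_transpose | Keeps input unit | trace of a matrix in V is in V |
| dot, matmul, cross, kron, matrix_power | Multiplies units | dot of a vector in m and one in N is in m * N = J |
| det | Unit^N for NxN | det of a 2x2 matrix in m is in m^2 |
| inv, pinv | Reciprocal unit | inv of a matrix in ohm is in S |
| solve, lstsq | b_unit / A_unit | solve with A in ohm and b in V gives A |
| cholesky | sqrt(unit) | cholesky of a matrix in m^2 is in m |
| svd, eig / eigh, qr | Factors carry units | svd of a matrix in m: U and Vt dimensionless, S in m |
| matrix_rank, slogdet | Dimensionless | Always unitless output |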