Linear Algebra Functions#


brainunit.linalg provides unit-aware linear algebra functions. These functions automatically track how units propagate through linear algebra operations:

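  • Keeping unit: trace, diagonal, norm, matrix_transpose

  • Changing unit: dot, matmul, cross, det, inv, pinv, solve, lstsq, cholesky, kron, matrix_power

  • Removing unit: matrix_rank, slogdet, the condition number, and the orthogonal or eigenvector factors returned by svd, eig/eigh, and qr (the singular values, eigenvalues, and R factor keep the unit)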

import brainunit as u
import jax.numpy as jnp

Functions That Keep Unit#

These functions preserve the input unit in the output.
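One way to see why these functions keep the unit: each of them scales linearly with the matrix entries, so re-expressing the input in another unit (say volts to millivolts) rescales the output by the same factor. The sketch below checks this with plain NumPy on bare mantissas rather than with brainunit; the factor `scale` and the variable names are illustrative assumptions.

```python
import numpy as np

# Mantissa of a matrix expressed in volts (illustrative values).
A_volts = np.array([[1.0, 2.0], [3.0, 4.0]])
scale = 1000.0  # volts -> millivolts
A_millivolts = scale * A_volts

# trace, diagonal and norm all scale linearly with the entries,
# which is why their output carries the same unit as the input.
trace_ok = np.isclose(np.trace(A_millivolts), scale * np.trace(A_volts))
diag_ok = np.allclose(np.diagonal(A_millivolts), scale * np.diagonal(A_volts))
norm_ok = np.isclose(np.linalg.norm(A_millivolts), scale * np.linalg.norm(A_volts))
print(trace_ok, diag_ok, norm_ok)
```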

trace — Sum of diagonal elements#

A = jnp.array([[1., 2.], [3., 4.]]) * u.volt
print('A:')
print(A)
print('trace(A):', u.linalg.trace(A))  # 1 + 4 = 5 V
A:
[[1. 2.]
 [3. 4.]] V
trace(A): 5. V

diagonal — Extract diagonal elements#

print('diagonal(A):', u.linalg.diagonal(A))  # [1, 4] V
diagonal(A): [1. 4.] V

norm — Vector and matrix norms#

v = jnp.array([3., 4.]) * u.meter
print('v:', v)
print('L2 norm:', u.linalg.norm(v))          # sqrt(9+16) = 5 m
print('L1 norm:', u.linalg.norm(v, ord=1))   # 3+4 = 7 m
print('Inf norm:', u.linalg.norm(v, ord=jnp.inf))  # max(3,4) = 4 m
v: [3. 4.] m
L2 norm: 5. m
L1 norm: 7. m
Inf norm: 4. m
# Matrix norms
print('Frobenius norm:', u.linalg.norm(A))               # sqrt(1+4+9+16) V
print('Matrix L2 norm:', u.linalg.matrix_norm(A, ord=2)) # largest singular value
Frobenius norm: 5.477226 V
Matrix L2 norm: 5.464986 V

matrix_transpose — Transpose#

B = jnp.array([[1., 2., 3.], [4., 5., 6.]]) * u.ampere
print('B shape:', B.shape)
print('B^T shape:', u.linalg.matrix_transpose(B).shape)
print('B^T:')
print(u.linalg.matrix_transpose(B))
B shape: (2, 3)
B^T shape: (3, 2)
B^T:
[[1. 4.]
 [2. 5.]
 [3. 6.]] A

Functions That Change Unit#

These functions produce outputs with different units than the inputs, following the mathematical rules of the operation.

dot, matmul — Dot and matrix products#

When multiplying quantities, units multiply: [meter] @ [second] -> [meter * second]
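The multiplication rule can be checked on bare mantissas: rescaling each operand by its own factor rescales the product by the product of the two factors. This is a minimal NumPy sketch, not brainunit code; the conversion factors `a` and `b` are assumptions chosen for illustration.

```python
import numpy as np

P = np.array([[1.0, 2.0], [3.0, 4.0]])  # mantissa in meters
Q = np.array([[5.0, 6.0], [7.0, 8.0]])  # mantissa in seconds

a, b = 100.0, 1000.0  # meters -> centimeters, seconds -> milliseconds
product_si = P @ Q                   # mantissa in m*s
product_scaled = (a * P) @ (b * Q)   # mantissa in cm*ms

# The product picks up one factor from each operand: unit(P) * unit(Q).
print(np.allclose(product_scaled, a * b * product_si))
```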

# Vector dot product
a = jnp.array([1., 2., 3.]) * u.meter
b = jnp.array([4., 5., 6.]) * u.newton
print('a . b:', u.linalg.dot(a, b))  # 1*4 + 2*5 + 3*6 = 32 m*N (= 32 J)
a . b: 32. J
# Matrix-vector multiplication
M = jnp.array([[1., 0.], [0., 2.]]) * u.ohm
i = jnp.array([3., 4.]) * u.ampere
print('M @ i:', u.linalg.matmul(M, i))  # [3, 8] V  (Ohm's law: V = IR)
M @ i: [3. 8.] V
# Matrix-matrix multiplication: units multiply
P = jnp.array([[1., 2.], [3., 4.]]) * u.meter
Q = jnp.array([[5., 6.], [7., 8.]]) * u.meter
print('P @ Q:')
print(u.linalg.matmul(P, Q))  # m * m = m^2
P @ Q:
[[19. 22.]
 [43. 50.]] m^2

cross — Cross product#

# Force = charge * (velocity x B_field)
velocity = jnp.array([1., 0., 0.]) * u.meter / u.second
b_field = jnp.array([0., 0., 1.]) * u.tesla
print('v x B:', u.linalg.cross(velocity, b_field))  # [0, -1, 0] m*T/s
v x B: 
[ 0. -1.  0.] m * T / s

det — Determinant#

For an NxN matrix with unit u, the determinant has unit u^N.
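The exponent N follows from det(cA) = c^N det(A): every term of the determinant is a product of N entries. A small NumPy sketch, with the hypothetical helper `det_scaling_exponent` introduced only for this illustration, recovers that exponent numerically.

```python
import numpy as np

def det_scaling_exponent(matrix, scale=10.0):
    """Return k such that det(scale * matrix) == scale**k * det(matrix)."""
    ratio = np.linalg.det(scale * matrix) / np.linalg.det(matrix)
    return np.log(ratio) / np.log(scale)

M2 = np.array([[2.0, 1.0], [1.0, 3.0]])
M3 = np.array([[1.0, 0.0, 2.0], [0.0, 1.0, 0.0], [2.0, 0.0, 1.0]])

print(det_scaling_exponent(M2))  # close to 2 for a 2x2 matrix
print(det_scaling_exponent(M3))  # close to 3 for a 3x3 matrix
```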

M2 = jnp.array([[2., 1.], [1., 3.]]) * u.meter
print('det(M2):', u.linalg.det(M2))  # 2*3 - 1*1 = 5 m^2
det(M2): 5. m^2
# 3x3 determinant: unit^3
M3 = jnp.array([[1., 0., 2.], [0., 1., 0.], [2., 0., 1.]]) * u.second
print('det(M3):', u.linalg.det(M3))  # s^3
det(M3): -3. s^3

inv — Matrix inverse#

The inverse of a matrix with unit u has unit 1/u.
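The reciprocal rule can be sanity-checked without any unit library: expressing the same resistances in a unit 1000 times smaller multiplies the mantissa by 1000 and divides the inverse by 1000. The following NumPy sketch assumes illustrative variable names and tracks units only in comments.

```python
import numpy as np

R_ohm = np.array([[2.0, 1.0], [1.0, 3.0]])  # mantissa in ohm
R_milliohm = 1000.0 * R_ohm                 # same resistances in milliohm

G_siemens = np.linalg.inv(R_ohm)            # mantissa in 1/ohm = siemens
G_per_milliohm = np.linalg.inv(R_milliohm)  # mantissa in 1/milliohm

# Inverting the matrix inverts the scale factor as well.
print(np.allclose(G_per_milliohm, G_siemens / 1000.0))
# R @ inv(R) is the identity, so the units cancel to dimensionless.
print(np.allclose(R_ohm @ G_siemens, np.eye(2)))
```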

R = jnp.array([[2., 1.], [1., 3.]]) * u.ohm
print('R:')
print(R)
print('inv(R):')
print(u.linalg.inv(R))  # unit: 1/ohm = siemens
R:
[[2. 1.]
 [1. 3.]] ohm
inv(R):
[[ 0.60000002 -0.2]
 [-0.2  0.40000001]] S
# Verify: R @ inv(R) should be identity (dimensionless)
R_inv = u.linalg.inv(R)
print('R @ R_inv:')
print(u.linalg.matmul(R, R_inv))  # approximately identity
R @ R_inv:
[[1.0000000e+00 0.0000000e+00]
 [1.4901161e-08 1.0000000e+00]]

solve — Solve linear system Ax = b#

If A has unit u_A and b has unit u_b, then x has unit u_b / u_A.
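As a rough check of the u_b / u_A rule, the sketch below solves the same system twice with plain NumPy: once in ohm and volt, once in kilo-ohm and millivolt. The unit labels live only in the comments and the conversion factors are assumptions of this example.

```python
import numpy as np

A = np.array([[10.0, 2.0], [2.0, 8.0]])  # mantissa in ohm
b = np.array([12.0, 6.0])                # mantissa in volt
x = np.linalg.solve(A, b)                # mantissa in volt / ohm = ampere

# Same system with A in kilo-ohm (mantissa / 1000) and b in millivolt (mantissa * 1000).
x_rescaled = np.linalg.solve(A / 1000.0, b * 1000.0)  # mantissa in mV / kOhm = microampere

# 1 A = 1e6 microampere, so the two solutions describe the same currents.
print(np.allclose(x_rescaled, x * 1e6))
```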

# Physical example: Ohm's law in a circuit network
# R * I = V  -->  I = solve(R, V)
R_matrix = jnp.array([[10., 2.], [2., 8.]]) * u.ohm
V_vector = jnp.array([12., 6.]) * u.volt

I_solution = u.linalg.solve(R_matrix, V_vector)
print('Current solution:', I_solution)  # volt / ohm = ampere
Current solution: [1.10526311 0.47368422] A

cholesky — Cholesky decomposition#

For a matrix with unit u, the Cholesky factor has unit sqrt(u), since L @ L^T = A and sqrt(u) * sqrt(u) = u.
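The square-root rule shows up numerically as well: converting m^2 to cm^2 multiplies the mantissa by 10^4, while the Cholesky factor only grows by 10^2. This is a NumPy sketch on bare mantissas, with units tracked in comments as an assumption of the example.

```python
import numpy as np

S_m2 = np.array([[4.0, 2.0], [2.0, 5.0]])  # mantissa in m^2
S_cm2 = 1e4 * S_m2                          # same quantities in cm^2

L_m = np.linalg.cholesky(S_m2)    # lower-triangular factor, mantissa in m
L_cm = np.linalg.cholesky(S_cm2)  # lower-triangular factor, mantissa in cm

print(np.allclose(L_cm, 100.0 * L_m))   # factor scales with sqrt(1e4) = 100
print(np.allclose(L_m @ L_m.T, S_m2))   # L @ L^T reconstructs the input
```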

# Positive definite matrix
S = jnp.array([[4., 2.], [2., 5.]]) * u.meter2
L = u.linalg.cholesky(S)
print('S:')
print(S)
print('cholesky(S):')
print(L)  # unit: m (sqrt of m^2)
S:
[[4. 2.]
 [2. 5.]] m^2
cholesky(S):
[[2. 0.]
 [1. 2.]] m

kron — Kronecker product#

X = jnp.array([[1., 0.], [0., 1.]]) * u.meter
Y = jnp.array([[2., 3.], [4., 5.]]) * u.second
print('kron(X, Y):')
print(u.linalg.kron(X, Y))  # m * s
kron(X, Y):
[[2. 3. 0. 0.]
 [4. 5. 0. 0.]
 [0. 0. 2. 3.]
 [0. 0. 4. 5.]] m * s

matrix_power — Raise matrix to integer power#

T = jnp.array([[1., 1.], [0., 1.]]) * u.meter
print('T^2:', u.linalg.matrix_power(T, 2))   # m^2
print('T^3:', u.linalg.matrix_power(T, 3))   # m^3
T^2: [[1. 2.]
 [0. 1.]] m^2
T^3: [[1. 3.]
 [0. 1.]] m^3
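The pattern in the output, unit^n for the n-th power, matches the scaling identity (cT)^n = c^n T^n. A quick NumPy check on mantissas, with an assumed conversion factor `scale`:

```python
import numpy as np

T = np.array([[1.0, 1.0], [0.0, 1.0]])  # mantissa in meters
scale = 100.0                            # meters -> centimeters

checks = {}
for n in (2, 3):
    lhs = np.linalg.matrix_power(scale * T, n)          # mantissa in cm^n
    rhs = scale**n * np.linalg.matrix_power(T, n)       # m^n converted to cm^n
    checks[n] = np.allclose(lhs, rhs)
print(checks)
```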

lstsq — Least-squares solution#

# Fitting y = ax + b  where x is in seconds and y in meters
# Design matrix has columns [x, 1]
x_data = jnp.array([0., 1., 2., 3., 4.]) * u.second
y_data = jnp.array([1.1, 2.9, 5.2, 6.8, 9.1]) * u.meter

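# Build the design matrix. A Quantity array has a single unit, so both columns
# (x and the constant 1) are tagged with seconds here.
A_design = jnp.stack([x_data.mantissa, jnp.ones(5)], axis=1) * u.second
result = u.linalg.lstsq(A_design, y_data)
# Both coefficients come out in m/s; the intercept multiplies the 1 s column,
# so 1.04 m/s corresponds to an offset of 1.04 m.
print('Coefficients:', result[0])  # [slope, intercept] in m/s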
Coefficients: [1.99000037 1.03999949] m / s
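For comparison, here is the same fit in plain NumPy with the units tracked by hand. Treating the constant column as dimensionless is an assumption of this sketch, and it gives the usual reading of the coefficients: slope in m/s, intercept in m.

```python
import numpy as np

x_seconds = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
y_meters = np.array([1.1, 2.9, 5.2, 6.8, 9.1])

# Columns: x (seconds) and a constant 1 (dimensionless).
design = np.stack([x_seconds, np.ones_like(x_seconds)], axis=1)
coeffs, residuals, rank, singular_values = np.linalg.lstsq(design, y_meters, rcond=None)

slope_m_per_s, intercept_m = coeffs
print(f"slope = {slope_m_per_s:.2f} m/s, intercept = {intercept_m:.2f} m")
```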

pinv — Pseudoinverse#

M_rect = jnp.array([[1., 2.], [3., 4.], [5., 6.]]) * u.volt
print('M shape:', M_rect.shape)
print('pinv(M) shape:', u.linalg.pinv(M_rect).shape)
print('pinv(M):')
print(u.linalg.pinv(M_rect))  # unit: 1/V
M shape: (3, 2)
pinv(M) shape: (2, 3)
pinv(M):
[[-1.33333182 -0.3333317   0.6666652 ]
 [ 1.08333194  0.33333197 -0.41666541]] 1 / V
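Two properties of the pseudoinverse are easy to confirm on mantissas with NumPy: for a matrix with full column rank, pinv(M) @ M is the dimensionless identity, and rescaling the unit of M divides pinv(M) by the same factor. The variable names below are illustrative.

```python
import numpy as np

M_volt = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # mantissa in volt
M_pinv = np.linalg.pinv(M_volt)                           # mantissa in 1/volt

# (1/V) * V cancels: the product is a 2x2 identity.
print(np.allclose(M_pinv @ M_volt, np.eye(2)))
# volt -> millivolt multiplies M by 1000 and divides pinv(M) by 1000.
print(np.allclose(np.linalg.pinv(1000.0 * M_volt), M_pinv / 1000.0))
```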

Functions That Remove Unit (Decompositions)#

Decompositions split a matrix into dimensionless factors (orthogonal matrices, eigenvectors) and factors that carry the unit (singular values, eigenvalues, the R factor of QR). Scalar diagnostics such as matrix_rank and slogdet return plain, unitless values.
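Which factor carries the unit can be probed numerically: if the input is re-expressed in another unit, only the unit-carrying factor should change. The NumPy sketch below does this for the SVD; the orthogonal factors are compared up to sign, since the sign convention of singular vectors is not unique.

```python
import numpy as np

M_m = np.array([[1.0, 2.0], [3.0, 4.0]])  # mantissa in meters
scale = 1000.0                             # meters -> millimeters

U, S, Vt = np.linalg.svd(M_m)
U_mm, S_mm, Vt_mm = np.linalg.svd(scale * M_m)

# Only the singular values change with the unit ...
print(np.allclose(S_mm, scale * S))
# ... while the orthogonal factors match up to a sign convention.
print(np.allclose(np.abs(U_mm), np.abs(U)), np.allclose(np.abs(Vt_mm), np.abs(Vt)))
```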

svd — Singular Value Decomposition#

M_svd = jnp.array([[1., 2.], [3., 4.]]) * u.meter
U, S, Vt = u.linalg.svd(M_svd)
print('U (dimensionless):', U)  
print('S (with unit):', S)  # singular values carry the unit
print('Vt (dimensionless):', Vt)
U (dimensionless): [[-0.40455365 -0.91451436]
 [-0.9145144   0.4045536 ]]
S (with unit): [5.46498537 0.36596611] m
Vt (dimensionless): [[-0.5760485  -0.81741554]
 [ 0.81741554 -0.5760485 ]]

eig / eigh — Eigenvalue decomposition#

# Symmetric matrix (use eigh for better numerical stability)
H = jnp.array([[2., 1.], [1., 3.]]) * u.joule
eigenvalues, eigenvectors = u.linalg.eigh(H)
print('Eigenvalues:', eigenvalues)    # J
print('Eigenvectors (dimensionless):')
print(eigenvectors)
Eigenvalues: [1.38196611 3.61803389] J
Eigenvectors (dimensionless):
[[-0.85065085  0.52573115]
 [ 0.52573115  0.85065085]]

qr — QR decomposition#

M_qr = jnp.array([[1., 2.], [3., 4.], [5., 6.]]) * u.meter
Q_mat, R_mat = u.linalg.qr(M_qr)
print('Q (dimensionless):')
print(Q_mat)
print('R (with unit):')
print(R_mat)  # upper triangular, carries the unit
Q (dimensionless):
[[-0.1690309   0.89708555]
 [-0.5070926   0.27602556]
 [-0.84515435 -0.34503233]]
R (with unit):
[[-5.91607952 -7.43735886]
 [ 0.          0.82808006]] m

cond — Condition number#

# Condition number: ratio of largest to smallest singular value (dimensionless).
# u.linalg.cond raised "AttributeError: module 'brainunit.linalg' has no attribute 'cond'"
# when this page was built, so the ratio is taken from the svd singular values instead.
well_cond = jnp.array([[1., 0.], [0., 1.]]) * u.meter
ill_cond = jnp.array([[1., 1.], [1., 1.0001]]) * u.meter
S_well = u.linalg.svd(well_cond)[1]
S_ill = u.linalg.svd(ill_cond)[1]
print('Well-conditioned:', S_well[0] / S_well[-1])
print('Ill-conditioned:', S_ill[0] / S_ill[-1])
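Because the 2-norm condition number is a ratio of singular values, any common unit cancels. The following NumPy sketch (plain mantissas, float64, names chosen for illustration) computes it from the singular values and confirms that rescaling the unit leaves it unchanged.

```python
import numpy as np

ill_cond = np.array([[1.0, 1.0], [1.0, 1.0001]])  # mantissa in meters
singular_values = np.linalg.svd(ill_cond, compute_uv=False)  # sorted descending

cond_from_svd = singular_values[0] / singular_values[-1]  # m / m = dimensionless
cond_rescaled = np.linalg.cond(1000.0 * ill_cond)         # same matrix in millimeters

print(np.isclose(cond_from_svd, np.linalg.cond(ill_cond)))
print(np.isclose(cond_from_svd, cond_rescaled))
```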

matrix_rank#

full_rank = jnp.array([[1., 2.], [3., 4.]]) * u.meter
rank_def = jnp.array([[1., 2.], [2., 4.]]) * u.meter  # row 2 = 2 * row 1
print('Full rank matrix rank:', u.linalg.matrix_rank(full_rank))
print('Rank-deficient matrix rank:', u.linalg.matrix_rank(rank_def))
Full rank matrix rank: 2
Rank-deficient matrix rank: 1

slogdet — Sign and log of determinant#

M_det = jnp.array([[2., 1.], [1., 3.]]) * u.meter
sign, logabsdet = u.linalg.slogdet(M_det)
print('Sign:', sign)
print('Log |det|:', logabsdet)  # dimensionless
Sign: 1.0
Log |det|: 1.609438
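The log-determinant is returned as a plain number, but its value still depends on the unit the mantissa is expressed in: rescaling an NxN matrix by c shifts log|det| by N log(c). The NumPy sketch below illustrates this on mantissas; it says nothing about how brainunit normalizes units internally.

```python
import numpy as np

M_m = np.array([[2.0, 1.0], [1.0, 3.0]])  # mantissa in meters
scale = 100.0                              # meters -> centimeters
n = M_m.shape[0]

sign_m, logdet_m = np.linalg.slogdet(M_m)
sign_cm, logdet_cm = np.linalg.slogdet(scale * M_m)

# The sign is unit-independent; log|det| shifts by n * log(scale).
print(sign_m == sign_cm)
print(np.isclose(logdet_cm, logdet_m + n * np.log(scale)))
```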

Practical Example: Solving a Physical System#

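Applying Kirchhoff's current law to a simple resistor network gives a linear system G V = I, where G is the conductance matrix, V the node voltages, and I the injected currents.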

# Three-node resistor network
# Conductance matrix G (siemens) and current sources I (ampere)
G = jnp.array([
    [0.3, -0.1, -0.1],
    [-0.1, 0.4, -0.2],
    [-0.1, -0.2, 0.5]
]) * u.siemens

I_source = jnp.array([1.0, 0.0, -0.5]) * u.ampere

# Solve for node voltages: G @ V = I  -->  V = solve(G, I)
V_nodes = u.linalg.solve(G, I_source)
print('Node voltages:', V_nodes)  # ampere / siemens = volt

# Verify: G @ V should equal I
print('Verification (G @ V):', u.linalg.matmul(G, V_nodes))
Node voltages: [3.71428561 1.00000012 0.14285721] V
Verification (G @ V): [ 9.9999994e-01  3.3101866e-08 -5.0000000e-01] A
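The small nonzero entry in the verification output (about 3e-08 A) is at the level one would expect from single-precision round-off. As a cross-check, here is the same system in double precision with plain NumPy, with units tracked only in comments; this is a sketch for comparison, not part of the brainunit workflow.

```python
import numpy as np

G = np.array([[0.3, -0.1, -0.1],
              [-0.1, 0.4, -0.2],
              [-0.1, -0.2, 0.5]])       # mantissa in siemens
I_source = np.array([1.0, 0.0, -0.5])   # mantissa in ampere

V_nodes = np.linalg.solve(G, I_source)  # mantissa in ampere / siemens = volt
residual = G @ V_nodes - I_source       # mantissa in ampere, ideally close to 0

print(V_nodes)
print(np.max(np.abs(residual)))
```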

Summary#

| Function | Unit Behavior | Example |
|---|---|---|
| trace, diagonal, norm | Keeps input unit | trace([m]) -> m |
| dot, matmul, kron | Multiplies units | [m] @ [s] -> m*s |
| det | Unit^N for NxN | det([m]_{2x2}) -> m^2 |
| inv, pinv | Reciprocal unit | inv([ohm]) -> 1/ohm |
| solve(A, b) | b_unit / A_unit | solve([ohm], [V]) -> A |
| cholesky | sqrt(unit) | chol([m^2]) -> m |
| svd, eig, qr | Factors carry units | S has unit, U/V dimensionless |
| cond, matrix_rank | Dimensionless | Always unitless output |