JAX LAX Functions#
brainunit.lax provides unit-aware wrappers around JAX’s low-level jax.lax primitives.
These are the building blocks for higher-level modules such as brainunit.math.
The functions are grouped by how they handle units:
- Keeping unit: slicing, sorting, cumulative ops, padding, broadcasting
- Changing unit: arithmetic (mul, div, integer_pow), rsqrt, dot_general, batch_matmul, conv
- Removing unit: comparisons (eq, lt, gt, …)
- Accepting unitless: trig, special functions (erf, logistic, bessel)
- Linear algebra: cholesky, eig, qr, svd, triangular_solve
import brainunit as u
import jax.numpy as jnp
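To make the four unit behaviors concrete, here is a minimal plain-Python sketch. It is a toy model, not brainunit's implementation: the `Qty` class, the `Dim` tuple, and every helper name are hypothetical and exist only for illustration. A dimension is modelled as a tuple of exponents over (metre, kilogram, second, ampere), so the volt is `m^2 kg s^-3 A^-1`.

```python
import math
from dataclasses import dataclass
from typing import Tuple

# Hypothetical toy model: exponents over (metre, kilogram, second, ampere).
Dim = Tuple[int, int, int, int]

VOLT: Dim = (2, 1, -3, -1)   # m^2 kg s^-3 A^-1
AMPERE: Dim = (0, 0, 0, 1)
UNITLESS: Dim = (0, 0, 0, 0)


@dataclass(frozen=True)
class Qty:
    value: float
    dim: Dim


def keep_unit_neg(x: Qty) -> Qty:
    # "Keeping unit": the value changes, the dimension passes through.
    return Qty(-x.value, x.dim)


def change_unit_mul(x: Qty, y: Qty) -> Qty:
    # "Changing unit": multiplying values adds the dimension exponents.
    return Qty(x.value * y.value, tuple(a + b for a, b in zip(x.dim, y.dim)))


def remove_unit_lt(x: Qty, y: Qty) -> bool:
    # "Removing unit": operands must share a dimension; the result is a bool.
    if x.dim != y.dim:
        raise TypeError("cannot compare quantities with different dimensions")
    return x.value < y.value


def unitless_only_exp(x: Qty) -> float:
    # "Accepting unitless": reject anything that still carries a dimension.
    if x.dim != UNITLESS:
        raise TypeError("expected a dimensionless input")
    return math.exp(x.value)


power = change_unit_mul(Qty(2.0, VOLT), Qty(3.0, AMPERE))
print(power)  # dimension (2, 1, -3, 0), i.e. the watt
```

The real library also tracks unit scale (for example mV versus V), which this sketch leaves out.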
Functions That Keep Unit#
These operations rearrange, slice, or accumulate values without changing units.
Slicing Operations#
x = jnp.array([10., 20., 30., 40., 50.]) * u.volt
# Static slice: elements from index 1 to 4
print('slice [1:4]:', u.lax.slice(x, (1,), (4,)))
# Dynamic slice: start at index 2, take 3 elements
print('dynamic_slice:', u.lax.dynamic_slice(x, (2,), (3,)))
# Slice in a specific dimension
M = jnp.arange(12.).reshape(3, 4) * u.ampere
print('slice_in_dim (rows 0:2):')
print(u.lax.slice_in_dim(M, 0, 2, axis=0))
slice [1:4]: [20. 30. 40.] V
dynamic_slice: [30. 40. 50.] V
slice_in_dim (rows 0:2):
[[0. 1. 2. 3.]
[4. 5. 6. 7.]] A
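The two 1-D calls above take different arguments: slice takes a start and an exclusive limit, while dynamic_slice takes a start and a size. The plain-Python sketch below mirrors that difference on a list of numbers. The helper names are hypothetical, the unit is ignored, and non-negative start indices are assumed. It also models the start-index clamping that jax.lax.dynamic_slice applies to keep the window in bounds.

```python
from typing import List, Sequence


def static_slice(xs: Sequence[float], start: int, limit: int) -> List[float]:
    # slice-style arguments: the half-open index range [start, limit).
    return list(xs[start:limit])


def dynamic_slice(xs: Sequence[float], start: int, size: int) -> List[float]:
    # dynamic_slice-style arguments: a start index and a fixed window size.
    # The start is clamped so that the whole window fits inside the array.
    start = max(0, min(start, len(xs) - size))
    return list(xs[start:start + size])


volts = [10.0, 20.0, 30.0, 40.0, 50.0]
print(static_slice(volts, 1, 4))   # indices 1, 2, 3
print(dynamic_slice(volts, 2, 3))  # indices 2, 3, 4
print(dynamic_slice(volts, 4, 3))  # start clamped from 4 down to 2
```

Because the window size is fixed, the result shape does not depend on the start index. That is what allows the start index to be a traced value under jax.jit.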
Dynamic Updates#
arr = jnp.array([1., 2., 3., 4., 5.]) * u.meter
update = jnp.array([99., 88.]) * u.meter
# Update a slice starting at index 1
result = u.lax.dynamic_update_slice(arr, update, (1,))
print('dynamic_update_slice:', result) # [1, 99, 88, 4, 5] m
dynamic_update_slice: [ 1. 99. 88. 4. 5.] m
Sorting#
unsorted = jnp.array([3., 1., 4., 1., 5., 9., 2., 6.]) * u.newton
print('sort:', u.lax.sort(unsorted))
sort: [1. 1. 2. 3. 4. 5. 6. 9.] N
# Top-k: largest k elements
values, indices = u.lax.top_k(unsorted, 3)
print('top_k values:', values)
print('top_k indices:', indices)
top_k values: [9. 6. 5.] N
top_k indices: [5 7 4]
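As a cross-check on the output above, top_k can be reproduced in plain Python by ranking the indices by value. This is only a reference sketch on bare numbers, with a hypothetical helper name. The values would carry the input unit, while the indices are plain integers.

```python
from typing import List, Sequence, Tuple


def top_k_reference(xs: Sequence[float], k: int) -> Tuple[List[float], List[int]]:
    # Rank indices by descending value; sorted() is stable, so among equal
    # values the lower index comes first.
    order = sorted(range(len(xs)), key=lambda i: -xs[i])[:k]
    return [xs[i] for i in order], order


forces = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]
values, indices = top_k_reference(forces, 3)
print(values, indices)
```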
Cumulative Operations#
vals = jnp.array([1., 3., 2., 5., 4.]) * u.watt
print('cumsum:', u.lax.cumsum(vals, axis=0))
print('cummin:', u.lax.cummin(vals, axis=0))
print('cummax:', u.lax.cummax(vals, axis=0))
cumsum: [ 1. 4. 6. 11. 15.] W
cummin: [1. 1. 1. 1. 1.] W
cummax: [1. 3. 3. 5. 5.] W
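All three cumulative operations are running reductions that differ only in the combining function. A short sketch with the standard library's itertools.accumulate reproduces the numbers above, leaving the unit out:

```python
from itertools import accumulate

watts = [1.0, 3.0, 2.0, 5.0, 4.0]

# accumulate() folds the sequence from the left and keeps every partial result.
running_sum = list(accumulate(watts))        # default combiner is addition
running_min = list(accumulate(watts, min))
running_max = list(accumulate(watts, max))

print(running_sum)
print(running_min)
print(running_max)
```

Addition, min, and max each return a value of the same dimension as their operands, which is why the unit is kept.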
Padding#
signal = jnp.array([1., 2., 3.]) * u.volt
# Pad with 2 zeros on the left and 1 on the right
padded = u.lax.pad(signal, 0.0 * u.volt, [(2, 1, 0)])
print('padded:', padded) # [0, 0, 1, 2, 3, 0] V
padded:
[0. 0. 1. 2. 3. 0.] V
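Each entry of the padding configuration is a (low, high, interior) triple for one axis. The following 1-D sketch shows what the three numbers do. The helper name is hypothetical and it assumes non-negative amounts, although jax.lax.pad also accepts negative low and high values to crop.

```python
from typing import List, Sequence


def pad_1d(xs: Sequence[float], fill: float,
           low: int, high: int, interior: int) -> List[float]:
    # `low` fill values go before the data and `high` after it;
    # `interior` fill values go between each pair of neighbours.
    out = [fill] * low
    for i, x in enumerate(xs):
        if i > 0:
            out.extend([fill] * interior)
        out.append(x)
    out.extend([fill] * high)
    return out


signal = [1.0, 2.0, 3.0]
print(pad_1d(signal, 0.0, 2, 1, 0))  # same configuration as the example above
print(pad_1d(signal, 0.0, 0, 0, 1))  # interior padding only
```

In the unit-aware call above, the fill value is written as 0.0 * u.volt, so it carries the same unit as the signal.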
Clamping#
x = jnp.array([0.5, 1.5, 2.5, 3.5, 4.5]) * u.volt
clamped = u.lax.clamp(1.0 * u.volt, x, 3.0 * u.volt)
print('clamp [1V, 3V]:', clamped)
clamp [1V, 3V]: [1. 1.5 2.5 3. 3. ] V
Broadcasting#
v = jnp.array([1., 2., 3.]) * u.meter
print('broadcast to (4, 3):')
print(u.lax.broadcast(v, (4,))) # adds a leading dimension of size 4
broadcast to (4, 3):
[[1. 2. 3.]
[1. 2. 3.]
[1. 2. 3.]
[1. 2. 3.]] m
Negation#
print('neg:', u.lax.neg(jnp.array([1., -2., 3.]) * u.pascal))
neg: [-1. 2. -3.] Pa
Type Conversion#
x_int = jnp.array([1, 2, 3]) * u.meter
x_float = u.lax.convert_element_type(x_int, jnp.float32)
print('int to float:', x_float)
print('dtype:', x_float.dtype)
int to float: [1. 2. 3.] m
dtype: float32
Functions That Change Unit#
These arithmetic and algebraic operations produce results whose unit is computed from the input units, for example as their product, quotient, or power.
Arithmetic: mul, div#
current = jnp.array([1., 2., 3.]) * u.ampere
resistance = jnp.array([10., 20., 30.]) * u.ohm
# Ohm's law: V = I * R
voltage = u.lax.mul(current, resistance)
print('V = I * R:', voltage) # ampere * ohm = volt
V = I * R: [10. 40. 90.] V
# Division
power = jnp.array([100., 200.]) * u.watt
v = jnp.array([10., 20.]) * u.volt
print('P / V:', u.lax.div(power, v)) # watt / volt = ampere
P / V:
[10. 10.] A
integer_pow — Power with integer exponent#
lengths = jnp.array([2., 3., 4.]) * u.meter
print('squared:', u.lax.integer_pow(lengths, 2)) # m^2
print('cubed:', u.lax.integer_pow(lengths, 3)) # m^3
squared:
[ 4. 9. 16.] m^2
cubed: [ 8. 27. 64.] m^3
rsqrt — Reciprocal square root#
For input with unit u, result has unit 1/sqrt(u).
areas = jnp.array([4., 9., 16.]) * u.meter2
print('rsqrt:', u.lax.rsqrt(areas)) # 1/m
rsqrt: [0.5 0.33333334 0.25 ] 1 / m
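The unit rule for rsqrt is the rule for integer_pow with exponent -1/2: every base-unit exponent is multiplied by the power. The sketch below works on a hypothetical dictionary of exponents (not a brainunit data structure) and uses Fraction so that half-integer exponents stay exact.

```python
from fractions import Fraction
from typing import Dict


def scale_exponents(dim: Dict[str, Fraction], power: Fraction) -> Dict[str, Fraction]:
    # Raising a quantity to `power` multiplies each unit exponent by `power`;
    # exponents that become zero are dropped.
    scaled = {base: exp * power for base, exp in dim.items()}
    return {base: exp for base, exp in scaled.items() if exp != 0}


area = {"m": Fraction(2)}
rsqrt_dim = scale_exponents(area, Fraction(-1, 2))       # m^2 -> m^-1
cubed_dim = scale_exponents({"m": Fraction(1)}, Fraction(3))  # m -> m^3

print(rsqrt_dim, cubed_dim)
print(4.0 ** -0.5)  # numeric part of rsqrt for 4 m^2
```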
dot_general — Generalized dot product#
# Matrix multiplication via dot_general
a = jnp.array([[1., 2.], [3., 4.]]) * u.meter
b = jnp.array([[5., 6.], [7., 8.]]) * u.second
# Contract over last axis of a and first axis of b
result = u.lax.dot_general(a, b, (((1,), (0,)), ((), ())))
print('dot_general (matmul):')
print(result) # m * s
dot_general (matmul):
[[19. 22.]
[43. 50.]] m * s
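The dimension_numbers argument `(((1,), (0,)), ((), ()))` reads as: contract axis 1 of a with axis 0 of b, with no batch axes. A plain-Python sketch for 2-D inputs makes the role of the contracted axes explicit. The helper name is hypothetical and it covers only this one-axis, no-batch case.

```python
from typing import List

Matrix = List[List[float]]


def contract_2d(a: Matrix, b: Matrix, a_axis: int, b_axis: int) -> Matrix:
    # Rearrange so the contracted axis is last in `a` and first in `b`,
    # then sum the products over that shared axis.
    if a_axis == 0:
        a = [list(col) for col in zip(*a)]
    if b_axis == 1:
        b = [list(col) for col in zip(*b)]
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
print(contract_2d(a, b, 1, 0))  # ordinary matrix product a @ b
print(contract_2d(a, b, 1, 1))  # contracts with the last axis of b: a @ b.T
```

Every term in the sum is a product of one element of a and one element of b, so the result unit is the product of the two input units (m * s above).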
batch_matmul — Batched matrix multiplication#
# Batch of 2 matrices: (batch=2, rows=3, cols=4) @ (batch=2, rows=4, cols=2)
A = jnp.ones((2, 3, 4)) * u.volt
B = jnp.ones((2, 4, 2)) * u.ampere
C = u.lax.batch_matmul(A, B)
print('batch_matmul shape:', C.shape) # (2, 3, 2)
print('batch_matmul unit:', C.unit) # V * A = W
batch_matmul shape: (2, 3, 2)
batch_matmul unit: W
rem — Remainder#
a = jnp.array([7., 10., 15.]) * u.meter
b = jnp.array([3., 4., 6.]) * u.meter
print('remainder:', u.lax.rem(a, b)) # [1, 2, 3] m
remainder: [1. 2. 3.] m
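One detail worth knowing for negative operands: jax.lax.rem takes the sign of its result from the dividend, like C's fmod, whereas Python's % operator follows the divisor. The comparison below uses only the standard library and bare numbers.

```python
import math

pairs = [(7.0, 3.0), (-7.0, 3.0), (7.0, -3.0)]

# math.fmod truncates toward zero, so the result has the dividend's sign.
truncated = [math.fmod(a, b) for a, b in pairs]
# Python's % floors, so the result has the divisor's sign.
floored = [a % b for a, b in pairs]

print(truncated)
print(floored)
```

For the all-positive inputs used above, the two conventions agree.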
Functions That Remove Unit (Comparisons)#
Comparison operations take quantities as input and return plain boolean arrays with no unit attached.
a = jnp.array([1., 3., 5.]) * u.volt
b = jnp.array([2., 3., 4.]) * u.volt
print('eq:', u.lax.eq(a, b)) # [F, T, F]
print('ne:', u.lax.ne(a, b)) # [T, F, T]
print('lt:', u.lax.lt(a, b)) # [T, F, F]
print('le:', u.lax.le(a, b)) # [T, T, F]
print('gt:', u.lax.gt(a, b)) # [F, F, T]
print('ge:', u.lax.ge(a, b)) # [F, T, T]
eq: [False True False]
ne: [ True False True]
lt: [ True False False]
le: [ True True False]
gt: [False False True]
ge: [False True True]
Functions Accepting Unitless Input#
These functions (inverse trigonometric functions and special functions such as logistic, erf, and the Bessel functions) require dimensionless inputs.
# Trigonometric functions require unitless input
angles = jnp.array([0.0, 0.5, 1.0])
print('asin:', u.lax.asin(angles))
print('acos:', u.lax.acos(angles))
print('atan:', u.lax.atan(angles))
asin: [0. 0.5235988 1.5707964]
acos: [1.5707964 1.0471976 0. ]
atan: [0. 0.4636476 0.7853982]
# Special functions
x = jnp.array([0.0, 0.5, 1.0, 2.0])
print('logistic:', u.lax.logistic(x))
print('erf:', u.lax.erf(x))
logistic: [0.5 0.62245935 0.7310586 0.880797 ]
erf: [0. 0.5204999 0.84270084 0.9953222 ]
# Bessel functions
print('bessel_i0e:', u.lax.bessel_i0e(x))
print('bessel_i1e:', u.lax.bessel_i1e(x))
bessel_i0e: [1.0000001 0.6450353 0.46575963 0.3085083 ]
bessel_i1e: [0. 0.1564208 0.20791042 0.21526928]
# Passing a quantity with units raises an error
try:
u.lax.logistic(jnp.array([1.0, 2.0]) * u.volt)
except Exception as e:
print(type(e).__name__, ':', e)
TypeError : logistic requires a dimensionless "x" when "unit_to_scale" is not provided. Got Quantity(unit=V, dim=m^2 kg s^-3 A^-1). Pass "unit_to_scale=<Unit>" to scale before applying logistic, or convert explicitly to a dimensionless value first.
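The error message points to two ways out: pass unit_to_scale, or convert to a dimensionless value first. Both amount to dividing the quantity by a reference unit before applying the function. The sketch below models only that arithmetic in plain Python. The helper `logistic_scaled` is hypothetical, and how brainunit applies unit_to_scale internally is an assumption here.

```python
import math


def logistic(x: float) -> float:
    # Standard logistic function, defined for a pure number.
    return 1.0 / (1.0 + math.exp(-x))


def logistic_scaled(volts: float, scale_volts: float) -> float:
    # Dividing a voltage by a reference voltage gives a dimensionless ratio,
    # which is then a valid input for logistic().
    return logistic(volts / scale_volts)


print(logistic_scaled(1.0, 1.0))  # 1 V measured in units of 1 V
print(logistic_scaled(1.0, 0.5))  # 1 V measured in units of 0.5 V
```

The choice of reference unit changes the numeric result, which is a good reason for the library to require it explicitly rather than guess.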
LAX Linear Algebra#
brainunit.lax also provides low-level linear algebra primitives.
# Cholesky decomposition (positive definite matrix)
M = jnp.array([[4., 2.], [2., 3.]]) * u.meter2
L = u.lax.cholesky(M)
print('cholesky:')
print(L) # unit: m (sqrt of m^2)
cholesky:
[[2. 0. ]
[1. 1.41421354]] m
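The Cholesky factor satisfies M = L Lᵀ, so each entry of M is a sum of products of two entries of L. That is why a matrix in m^2 yields a factor in m. A textbook plain-Python version, offered as a reference sketch and not as the routine the library calls, reproduces the numbers above:

```python
import math
from typing import List

Matrix = List[List[float]]


def cholesky_lower(m: Matrix) -> Matrix:
    # Row-by-row Cholesky for a symmetric positive definite matrix.
    n = len(m)
    low = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(low[i][k] * low[j][k] for k in range(j))
            if i == j:
                low[i][j] = math.sqrt(m[i][i] - s)     # diagonal entry
            else:
                low[i][j] = (m[i][j] - s) / low[j][j]  # below the diagonal
    return low


M = [[4.0, 2.0], [2.0, 3.0]]
L = cholesky_lower(M)
print(L)

# Rebuild M as L @ L.T to confirm the factorization.
rebuilt = [[sum(L[i][k] * L[j][k] for k in range(2)) for j in range(2)]
           for i in range(2)]
print(rebuilt)
```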
# QR decomposition
A = jnp.array([[1., 2.], [3., 4.], [5., 6.]]) * u.newton
Q, R = u.lax.qr(A)
print('Q (orthogonal, dimensionless):')
print(Q)
print('R (upper triangular, with unit):')
print(R)
Q (orthogonal, dimensionless):
[[-0.1690309 0.89708555 0.40824738]
[-0.5070926 0.27602556 -0.8164968 ]
[-0.84515435 -0.34503233 0.40824866]]
R (upper triangular, with unit):
[[-5.91607952 -7.43735886]
[ 0. 0.82808006]
[ 0. 0. ]] N
# Triangular solve: solve L @ x = b where L is lower triangular
L = jnp.array([[2., 0.], [1., 3.]]) * u.ohm
b = jnp.array([4., 7.]) * u.volt
x = u.lax.triangular_solve(L, b, left_side=True, lower=True)
print('triangular_solve:', x) # dimensionally V / ohm = A; the recorded output below is labeled V
triangular_solve: [2. 1.66666675] V
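For a lower-triangular system, the solve is forward substitution: each unknown is found by subtracting the already-known terms from the right-hand side and dividing by the diagonal entry. The plain-Python sketch below (hypothetical helper name, bare numbers) reproduces the values above. Dimensionally, each step divides a voltage by a resistance.

```python
from typing import List

Matrix = List[List[float]]


def forward_substitution(lower: Matrix, b: List[float]) -> List[float]:
    # Solve lower @ x = b from the top row down.
    x: List[float] = []
    for i, row in enumerate(lower):
        known = sum(row[k] * x[k] for k in range(i))  # terms already solved
        x.append((b[i] - known) / row[i])             # divide by the diagonal
    return x


L_ohm = [[2.0, 0.0], [1.0, 3.0]]
b_volt = [4.0, 7.0]
x = forward_substitution(L_ohm, b_volt)
print(x)
```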
Array Creation#
# Create an index array (always dimensionless)
idx = u.lax.iota(jnp.int32, 5)
print('iota:', idx)
# zeros_like_array
template = jnp.array([1., 2., 3.]) * u.volt
print('zeros_like:', u.lax.zeros_like_array(template))
iota: [0 1 2 3 4]
zeros_like: [0. 0. 0.] V
Summary#
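| Category | Functions | Unit Behavior |
|---|---|---|
| Slicing | slice, dynamic_slice, slice_in_dim, dynamic_update_slice | Keep unit |
| Sorting | sort, top_k | Keep unit |
| Cumulative | cumsum, cummin, cummax | Keep unit |
| Layout | pad, broadcast | Keep unit |
| Arithmetic | mul, div, integer_pow, rsqrt | Change unit |
| Products | dot_general, batch_matmul | Change unit |
| Comparisons | eq, ne, lt, le, gt, ge | Remove unit |
| Trig/Special | asin, acos, atan, logistic, erf, bessel_i0e, bessel_i1e | Require unitless |
| Linalg | cholesky, eig, qr, svd, triangular_solve | Varies |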