JAX LAX Functions#


brainunit.lax provides unit-aware wrappers around JAX’s low-level jax.lax primitives. These are the building blocks on which higher-level functions, such as those in brainunit.math, are built.

The functions are grouped by how they handle units:

  • Keeping unit: slicing, sorting, cumulative ops, padding, broadcasting

  • Changing unit: arithmetic (mul, div, integer_pow), rsqrt, dot_general, batch_matmul, conv

  • Removing unit: comparisons (eq, lt, gt, …)

  • Accepting unitless: trig, special functions (erf, logistic, bessel)

  • Linear algebra: cholesky, eig, qr, svd, triangular_solve

import brainunit as u
import jax.numpy as jnp

Functions That Keep Unit#

These operations rearrange, slice, or accumulate values without changing units.

Slicing Operations#

x = jnp.array([10., 20., 30., 40., 50.]) * u.volt

# Static slice: elements from index 1 to 4
print('slice [1:4]:', u.lax.slice(x, (1,), (4,)))

# Dynamic slice: start at index 2, take 3 elements
print('dynamic_slice:', u.lax.dynamic_slice(x, (2,), (3,)))

# Slice in a specific dimension
M = jnp.arange(12.).reshape(3, 4) * u.ampere
print('slice_in_dim (rows 0:2):')
print(u.lax.slice_in_dim(M, 0, 2, axis=0))
slice [1:4]: [20. 30. 40.] V
dynamic_slice: [30. 40. 50.] V
slice_in_dim (rows 0:2):
[[0. 1. 2. 3.]
 [4. 5. 6. 7.]] A
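One behavior worth knowing when choosing between the static and dynamic variants: the underlying jax.lax.dynamic_slice clamps an out-of-range start index so that the slice still fits, rather than raising an error. The sketch below uses plain jax.lax without units; that the brainunit wrapper inherits this behavior is an assumption, not something shown above.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([10., 20., 30., 40., 50.])

# In-range start: same elements as x[2:5].
in_range = lax.dynamic_slice(x, (2,), (3,))

# Out-of-range start: a 3-element slice starting at index 4 would run past
# the end, so the start index is clamped to 2 and no error is raised.
clamped = lax.dynamic_slice(x, (4,), (3,))

print(in_range)
print(clamped)
```

Both calls return the last three elements, which can hide off-by-one bugs if the start index is computed at run time.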

Dynamic Updates#

arr = jnp.array([1., 2., 3., 4., 5.]) * u.meter
update = jnp.array([99., 88.]) * u.meter

# Update a slice starting at index 1
result = u.lax.dynamic_update_slice(arr, update, (1,))
print('dynamic_update_slice:', result)  # [1, 99, 88, 4, 5] m
dynamic_update_slice: [ 1. 99. 88.  4.  5.] m

Sorting#

unsorted = jnp.array([3., 1., 4., 1., 5., 9., 2., 6.]) * u.newton
print('sort:', u.lax.sort(unsorted))
sort: [1. 1. 2. 3. 4. 5. 6. 9.] N

# Top-k: largest k elements
values, indices = u.lax.top_k(unsorted, 3)
print('top_k values:', values)
print('top_k indices:', indices)
top_k values: [9. 6. 5.] N
top_k indices: [5 7 4]
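The indices returned by top_k are positions in the original array, which is why they carry no unit. A minimal sketch with the plain jax.lax primitive (no units) shows that indexing the input with those positions reproduces the returned values:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([3., 1., 4., 1., 5., 9., 2., 6.])
values, indices = lax.top_k(x, 3)

# Gather from the original array using the returned positions.
gathered = x[indices]

print(values)
print(indices)
print(gathered)
```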

Cumulative Operations#

vals = jnp.array([1., 3., 2., 5., 4.]) * u.watt

print('cumsum:', u.lax.cumsum(vals, axis=0))
print('cummin:', u.lax.cummin(vals, axis=0))
print('cummax:', u.lax.cummax(vals, axis=0))
cumsum: [ 1.  4.  6. 11. 15.] W
cummin: [1. 1. 1. 1. 1.] W
cummax: [1. 3. 3. 5. 5.] W
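The underlying jax.lax cumulative primitives also accept a reverse flag that accumulates from the end of the axis. The sketch below uses plain jax.lax; whether the brainunit wrappers forward this keyword is an assumption to check against the installed version.

```python
import jax.numpy as jnp
from jax import lax

vals = jnp.array([1., 3., 2., 5., 4.])

# Running total from the first element forward.
forward = lax.cumsum(vals, axis=0)

# Running total from the last element backward.
backward = lax.cumsum(vals, axis=0, reverse=True)

print(forward)
print(backward)
```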

Padding#

signal = jnp.array([1., 2., 3.]) * u.volt

# Pad with 2 zeros on the left and 1 on the right
padded = u.lax.pad(signal, 0.0 * u.volt, [(2, 1, 0)])
print('padded:', padded)  # [0, 0, 1, 2, 3, 0] V
padded: [0. 0. 1. 2. 3. 0.] V
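Each entry of the padding configuration is a (low, high, interior) triple for one axis. The example above leaves the third value at 0; the sketch below, written against plain jax.lax.pad without units, shows what a non-zero interior value does.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1., 2., 3.])
zero = jnp.zeros((), dtype=x.dtype)  # padding value must match x's dtype

# (low=2, high=1, interior=0): pad only at the edges.
edge_only = lax.pad(x, zero, [(2, 1, 0)])

# (low=0, high=0, interior=1): insert one padding value between elements.
with_interior = lax.pad(x, zero, [(0, 0, 1)])

print(edge_only)
print(with_interior)
```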

Clamping#

x = jnp.array([0.5, 1.5, 2.5, 3.5, 4.5]) * u.volt
clamped = u.lax.clamp(1.0 * u.volt, x, 3.0 * u.volt)
print('clamp [1V, 3V]:', clamped)
clamp [1V, 3V]: [1.  1.5 2.5 3.  3. ] V
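Note the argument order: the lower bound comes first, the operand second, and the upper bound last. As a rough check, the plain jax.lax sketch below (no units) compares clamp with an explicit maximum-then-minimum composition.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([0.5, 1.5, 2.5, 3.5, 4.5])
lo, hi = jnp.float32(1.0), jnp.float32(3.0)

# clamp(min, x, max): bounds surround the operand.
clamped = lax.clamp(lo, x, hi)

# Same result built from elementwise maximum and minimum.
composed = jnp.minimum(jnp.maximum(x, lo), hi)

print(clamped)
print(composed)
```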

Broadcasting#

v = jnp.array([1., 2., 3.]) * u.meter
print('broadcast to (4, 3):')
print(u.lax.broadcast(v, (4,)))  # adds a leading dimension of size 4
broadcast to (4, 3):
[[1. 2. 3.]
 [1. 2. 3.]
 [1. 2. 3.]
 [1. 2. 3.]] m
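Because broadcast only prepends dimensions, placing the new axis anywhere else needs a different primitive. In plain jax.lax that is broadcast_in_dim, sketched below without units; whether brainunit.lax exposes an equivalent wrapper is not shown in this tutorial and is left as an assumption.

```python
import jax.numpy as jnp
from jax import lax

v = jnp.array([1., 2., 3.])

# broadcast: new dimensions are added in front, giving shape (4, 3).
leading = lax.broadcast(v, (4,))

# broadcast_in_dim: map v's only axis to output axis 0, giving shape (3, 4).
trailing = lax.broadcast_in_dim(v, (3, 4), (0,))

print(leading.shape)
print(trailing.shape)
```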

Negation#

print('neg:', u.lax.neg(jnp.array([1., -2., 3.]) * u.pascal))
neg: [-1.  2. -3.] Pa

Type Conversion#

x_int = jnp.array([1, 2, 3]) * u.meter
x_float = u.lax.convert_element_type(x_int, jnp.float32)
print('int to float:', x_float)
print('dtype:', x_float.dtype)
int to float: [1. 2. 3.] m
dtype: float32

Functions That Change Unit#

These arithmetic and algebraic operations produce results with different units.
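The unit rules in this section follow ordinary dimensional analysis: multiplication adds the exponents of the base dimensions, division subtracts them, and powers scale them. The sketch below is a hypothetical toy model in plain Python. The helper names dim_mul, dim_div, and dim_pow are invented for illustration and are not brainunit's implementation; the model tracks only exponents, not values.

```python
from fractions import Fraction

def dim_mul(a, b):
    """Multiply two dimensions by adding exponents; drop zero exponents."""
    out = dict(a)
    for base, exp in b.items():
        out[base] = out.get(base, 0) + exp
    return {base: exp for base, exp in out.items() if exp != 0}

def dim_pow(a, p):
    """Raise a dimension to a (possibly fractional) power."""
    p = Fraction(p)
    return {base: exp * p for base, exp in a.items() if exp * p != 0}

def dim_div(a, b):
    """Divide two dimensions by multiplying with the inverse."""
    return dim_mul(a, dim_pow(b, -1))

# Base dimensions as {symbol: exponent}; volt is m^2 kg s^-3 A^-1.
METER = {"m": 1}
AMPERE = {"A": 1}
VOLT = {"m": 2, "kg": 1, "s": -3, "A": -1}
OHM = dim_div(VOLT, AMPERE)
WATT = dim_mul(VOLT, AMPERE)

ohms_law = dim_mul(AMPERE, OHM)               # mul: A * ohm
current = dim_div(WATT, VOLT)                 # div: W / V
area = dim_pow(METER, 2)                      # integer_pow: m^2
inv_length = dim_pow(area, Fraction(-1, 2))   # rsqrt: 1 / sqrt(m^2)

print(ohms_law == VOLT, current == AMPERE, area, inv_length)
```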

Arithmetic: mul, div#

current = jnp.array([1., 2., 3.]) * u.ampere
resistance = jnp.array([10., 20., 30.]) * u.ohm

# Ohm's law: V = I * R
voltage = u.lax.mul(current, resistance)
print('V = I * R:', voltage)  # ampere * ohm = volt
V = I * R: [10. 40. 90.] V

# Division
power = jnp.array([100., 200.]) * u.watt
v = jnp.array([10., 20.]) * u.volt
print('P / V:', u.lax.div(power, v))  # watt / volt = ampere
P / V: [10. 10.] A

integer_pow — Power with integer exponent#

lengths = jnp.array([2., 3., 4.]) * u.meter
print('squared:', u.lax.integer_pow(lengths, 2))  # m^2
print('cubed:', u.lax.integer_pow(lengths, 3))    # m^3
squared: [ 4.  9. 16.] m^2
cubed: [ 8. 27. 64.] m^3

rsqrt — Reciprocal square root#

For an input with unit U, the result has unit 1/sqrt(U). For example, an input in m^2 gives a result in 1/m.

areas = jnp.array([4., 9., 16.]) * u.meter2
print('rsqrt:', u.lax.rsqrt(areas))  # 1/m
rsqrt: [0.5        0.33333334 0.25      ] 1 / m

dot_general — Generalized dot product#

# Matrix multiplication via dot_general
a = jnp.array([[1., 2.], [3., 4.]]) * u.meter
b = jnp.array([[5., 6.], [7., 8.]]) * u.second

# Contract over last axis of a and first axis of b
result = u.lax.dot_general(a, b, (((1,), (0,)), ((), ())))
print('dot_general (matmul):')
print(result)  # m * s
dot_general (matmul):
[[19. 22.]
 [43. 50.]] m * s

batch_matmul — Batched matrix multiplication#

# Batch of 2 matrices: (batch=2, rows=3, cols=4) @ (batch=2, rows=4, cols=2)
A = jnp.ones((2, 3, 4)) * u.volt
B = jnp.ones((2, 4, 2)) * u.ampere

C = u.lax.batch_matmul(A, B)
print('batch_matmul shape:', C.shape)  # (2, 3, 2)
print('batch_matmul unit:', C.unit)    # V * A = W
batch_matmul shape: (2, 3, 2)
batch_matmul unit: W

rem — Remainder#

a = jnp.array([7., 10., 15.]) * u.meter
b = jnp.array([3., 4., 6.]) * u.meter
print('remainder:', u.lax.rem(a, b))  # [1, 2, 3] m (same unit as the operands here)
remainder: [1. 2. 3.] m
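For negative operands, the underlying jax.lax.rem takes the sign of the dividend, which differs from Python's % and from jnp.remainder. The sketch below uses plain JAX without units; that the brainunit wrapper behaves the same way is an assumption.

```python
import jax.numpy as jnp
from jax import lax

a = jnp.array([7., -7.])
b = jnp.array([3., 3.])

# lax.rem: result has the sign of the dividend (C-style remainder).
lax_result = lax.rem(a, b)

# jnp.remainder: result has the sign of the divisor (Python-style modulo).
numpy_style = jnp.remainder(a, b)

print(lax_result)
print(numpy_style)
```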

Functions That Remove Unit (Comparisons)#

Comparison operations return plain boolean arrays that carry no unit.

a = jnp.array([1., 3., 5.]) * u.volt
b = jnp.array([2., 3., 4.]) * u.volt

print('eq:', u.lax.eq(a, b))    # [F, T, F]
print('ne:', u.lax.ne(a, b))    # [T, F, T]
print('lt:', u.lax.lt(a, b))    # [T, F, F]
print('le:', u.lax.le(a, b))    # [T, T, F]
print('gt:', u.lax.gt(a, b))    # [F, F, T]
print('ge:', u.lax.ge(a, b))    # [F, T, T]
eq: [False  True False]
ne: [ True False  True]
lt: [ True False False]
le: [ True  True False]
gt: [False False  True]
ge: [False  True  True]
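A typical use of such a boolean mask is choosing elementwise between two arrays. The sketch below uses the plain jax.lax primitives gt and select without units; it illustrates the pattern only and says nothing about how brainunit handles units in select.

```python
import jax.numpy as jnp
from jax import lax

a = jnp.array([1., 3., 5.])
b = jnp.array([2., 3., 4.])

# Boolean mask: True where a is strictly greater than b.
mask = lax.gt(a, b)

# select(pred, on_true, on_false): take a where mask is True, else b.
larger = lax.select(mask, a, b)

print(mask)
print(larger)
```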

Functions Accepting Unitless Input#

These functions (trigonometric, special functions) require dimensionless inputs.

# Inverse trigonometric functions require unitless input
angles = jnp.array([0.0, 0.5, 1.0])
print('asin:', u.lax.asin(angles))
print('acos:', u.lax.acos(angles))
print('atan:', u.lax.atan(angles))
asin: [0.        0.5235988 1.5707964]
acos: [1.5707964 1.0471976 0.       ]
atan: [0.        0.4636476 0.7853982]

# Special functions
x = jnp.array([0.0, 0.5, 1.0, 2.0])
print('logistic:', u.lax.logistic(x))
print('erf:', u.lax.erf(x))
logistic: [0.5        0.62245935 0.7310586  0.880797  ]
erf: [0.         0.5204999  0.84270084 0.9953222 ]

# Bessel functions
print('bessel_i0e:', u.lax.bessel_i0e(x))
print('bessel_i1e:', u.lax.bessel_i1e(x))
bessel_i0e: [1.0000001  0.6450353  0.46575963 0.3085083 ]
bessel_i1e: [0.         0.1564208  0.20791042 0.21526928]

# Passing a quantity with units raises an error
try:
    u.lax.logistic(jnp.array([1.0, 2.0]) * u.volt)
except Exception as e:
    print(type(e).__name__, ':', e)
TypeError : logistic requires a dimensionless "x" when "unit_to_scale" is not provided. Got Quantity(unit=V, dim=m^2 kg s^-3 A^-1). Pass "unit_to_scale=<Unit>" to scale before applying logistic, or convert explicitly to a dimensionless value first.

LAX Linear Algebra#

brainunit.lax also provides low-level linear algebra primitives.

# Cholesky decomposition (positive definite matrix)
M = jnp.array([[4., 2.], [2., 3.]]) * u.meter2
L = u.lax.cholesky(M)
print('cholesky:')
print(L)  # unit: m (sqrt of m^2)
cholesky:
[[2.        0.       ]
 [1.        1.41421354]] m

# QR decomposition
A = jnp.array([[1., 2.], [3., 4.], [5., 6.]]) * u.newton
Q, R = u.lax.qr(A)
print('Q (orthogonal, dimensionless):')
print(Q)
print('R (upper triangular, with unit):')
print(R)
Q (orthogonal, dimensionless):
[[-0.1690309   0.89708555  0.40824738]
 [-0.5070926   0.27602556 -0.8164968 ]
 [-0.84515435 -0.34503233  0.40824866]]
R (upper triangular, with unit):
[[-5.91607952 -7.43735886]
 [ 0.          0.82808006]
 [ 0.          0.        ]] N

# Triangular solve: solve L @ x = b  where L is lower triangular
L = jnp.array([[2., 0.], [1., 3.]]) * u.ohm
b = jnp.array([4., 7.]) * u.volt

x = u.lax.triangular_solve(L, b, left_side=True, lower=True)
print('triangular_solve:', x)  # dimensionally V / ohm = A, but the output below reports V
triangular_solve: [2.        1.66666675] V
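The numeric parts of these factorizations can be checked by multiplying the factors back together. The sketch below does so with the plain jax.lax.linalg primitives, using the same matrices as above without units; the tolerance is an arbitrary choice for float32.

```python
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

# Cholesky: M should equal L @ L.T.
M = jnp.array([[4., 2.], [2., 3.]])
L = lax_linalg.cholesky(M)
chol_ok = bool(jnp.allclose(L @ L.T, M, atol=1e-5))

# QR: A should equal Q @ R (Q is 3x3 and R is 3x2 here).
A = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
Q, R = lax_linalg.qr(A)
qr_ok = bool(jnp.allclose(Q @ R, A, atol=1e-5))

# Triangular solve: T @ x should equal the right-hand side.
T = jnp.array([[2., 0.], [1., 3.]])
rhs = jnp.array([[4.], [7.]])  # column vector, shape (2, 1)
x = lax_linalg.triangular_solve(T, rhs, left_side=True, lower=True)
solve_ok = bool(jnp.allclose(T @ x, rhs, atol=1e-5))

print(chol_ok, qr_ok, solve_ok)
```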

Array Creation#

# Create an index array (always dimensionless)
idx = u.lax.iota(jnp.int32, 5)
print('iota:', idx)

# zeros_like_array
template = jnp.array([1., 2., 3.]) * u.volt
print('zeros_like:', u.lax.zeros_like_array(template))
iota: [0 1 2 3 4]
zeros_like: [0. 0. 0.] V

Summary#

| Category     | Functions                           | Unit Behavior    |
|--------------|-------------------------------------|------------------|
| Slicing      | slice, dynamic_slice, slice_in_dim  | Keep unit        |
| Sorting      | sort, top_k                         | Keep unit        |
| Cumulative   | cumsum, cummin, cummax              | Keep unit        |
| Layout       | pad, clamp, broadcast, neg          | Keep unit        |
| Arithmetic   | mul, div, integer_pow, rsqrt        | Change unit      |
| Products     | dot_general, batch_matmul, conv     | Change unit      |
| Comparisons  | eq, ne, lt, le, gt, ge              | Remove unit      |
| Trig/Special | asin, logistic, erf, bessel_*       | Require unitless |
| Linalg       | cholesky, qr, svd, triangular_solve | Varies           |