Math Operations
import brainunit as u
import jax.numpy as jnp
As in NumPy and JAX NumPy, arithmetic operators on quantities apply elementwise.
a = jnp.array([20, 30, 40, 50]) * u.mV
b = jnp.arange(4) * u.mV
b
Quantity([0 1 2 3], "mV")
Addition and Subtraction
Quantities that are added or subtracted must have the same units, and the result keeps that unit.
c = a - b
c
Quantity([20 29 38 47], "mV")
c + b
Quantity([20 30 40 50], "mV")
Multiplication and Division
Multiplying or dividing quantities multiplies or divides their values and adds or subtracts the exponents of their units (for example, mV × mV = mV^2).
A = jnp.array([[1, 2], [3, 4]]) * u.mV
B = jnp.array([[5, 6], [7, 8]]) * u.mV
A, B
(Quantity([[1 2]
[3 4]], "mV"),
Quantity([[5 6]
[7 8]], "mV"))
A * B # element-wise multiplication
Quantity([[ 5 12]
[21 32]], "mV^2")
A @ B # matrix multiplication
Quantity([[19 22]
[43 50]], "mV^2")
A.dot(B) # matrix multiplication
Quantity([[19 22]
[43 50]], "mV^2")
A / 2 # divide by a scalar
Quantity([[0.5 1. ]
[1.5 2. ]], "mV")
If the result is dimensionless, the unit is dropped and a plain jax.Array is returned:
A / (2 * u.mV) # divide by a quantity, return jax array
Array([[0.5, 1. ],
[1.5, 2. ]], dtype=float32)
A / (2 * u.mA) # divide by a quantity, return quantity
Quantity([[0.5 1. ]
[1.5 2. ]], "ohm")
Power
The power operator raises both the value and the unit of the quantity to the given scalar power; the exponents of the unit are multiplied by the scalar.
A
Quantity([[1 2]
[3 4]], "mV")
A ** 2 # element-wise power
Quantity([[ 1 4]
[ 9 16]], "mV^2")
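Raising to a power multiplies every unit exponent by the scalar, so a fractional power produces a fractional exponent. This is why the cube root of a quantity in mV is displayed with the unit mV^0.3333333333333333. The sketch below uses a hypothetical toy unit representation, not brainunit's internals, and uses Fraction to keep the exponents exact.

```python
from fractions import Fraction

# Hypothetical toy unit: (power-of-ten prefix exponent, {base symbol: exponent}).
# Not brainunit's internal representation; used only to show the exponent rule.
mV = (-3, {"V": 1})  # 1 mV = 10**-3 V


def power(unit, n):
    """Raise a toy unit to the power n by multiplying every exponent by n."""
    prefix_exp, dims = unit
    return (prefix_exp * n, {name: exp * n for name, exp in dims.items()})


print(power(mV, 2))               # (-6, {'V': 2}), i.e. mV^2
print(power(mV, Fraction(1, 3)))  # (Fraction(-1, 1), {'V': Fraction(1, 3)})
```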
Built-in Functions
The Quantity class provides a number of built-in methods for operating on quantities:
Unary operations: positive (+), negative (-), absolute (abs), invert (~)
Logical operations: all, any
Shape operations: reshape, resize, squeeze, unsqueeze, split, swapaxes, transpose, ravel, take, repeat, diagonal, trace
Mathematical functions: nonzero, argmax, argmin, argsort, var, round, std, sum, cumsum, cumprod, max, mean, min, ptp, clip, conj, dot, fill, item, prod, clamp, sort
For more details on these functions, refer to the documentation.
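Dimensional analysis determines the unit that the result of a reduction should carry. Averages and standard deviations keep the unit, a variance squares it, and an index such as argmax has no unit. The plain-Python sketch below only illustrates these rules with the standard statistics module. It is not brainunit code. The unit labels are strings chosen for illustration, and population variance (pvariance, matching NumPy's default ddof=0) is an assumption.

```python
import statistics

# Toy illustration (not brainunit code): magnitudes in mV and the unit label
# each reduction's result should carry by dimensional analysis.
values = [1.0, 2.0, 3.0, 4.0]
unit_label = "mV"

results = {
    "mean": (statistics.fmean(values), unit_label),             # average keeps the unit
    "std": (statistics.pstdev(values), unit_label),             # sqrt of mean squared deviation keeps the unit
    "var": (statistics.pvariance(values), unit_label + "^2"),   # squared deviations: unit squared
    "argmax": (values.index(max(values)), None),                 # a position has no unit
}

for name, (value, label) in results.items():
    print(name, value, label or "(unitless)")
```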
Indexing, Slicing and Iterating
A one-dimensional Quantity can be indexed, sliced, and iterated over, much like lists and other Python sequences.
a = jnp.arange(10) ** 3 * u.mV
a
Quantity([ 0 1 8 27 64 125 216 343 512 729], "mV")
a[2]
Quantity(8, "mV")
a[2:5]
Quantity([ 8 27 64], "mV")
Only a Quantity with the same dimension can be assigned to a slice of a Quantity.
# equivalent to a[0:6:2] = 1000 * u.mV
# from the start up to position 6 (exclusive), set every 2nd element to 1000 mV
a[:6:2] = 1000 * u.mV
a
Quantity([1000 1 1000 27 1000 125 216 343 512 729], "mV")
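Python lists follow the same start:stop:step rules, so plain Python can confirm which positions the slice :6:2 touches. The sketch below leaves out the unit and works on the magnitudes only. For lists, extended-slice assignment needs a sequence of matching length, while the Quantity example above broadcasts a single value.

```python
# Positions selected by the slice ":6:2" (start defaults to 0, stop 6 is exclusive).
values = [0, 1, 8, 27, 64, 125, 216, 343, 512, 729]
selected = list(range(len(values)))[:6:2]
print(selected)  # [0, 2, 4]

# Mimic the assignment on a copy of the magnitudes (units omitted).
updated = values.copy()
updated[:6:2] = [1000] * len(selected)
print(updated)   # [1000, 1, 1000, 27, 1000, 125, 216, 343, 512, 729]
```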
a[::-1] # reversed a
Quantity([ 729 512 343 216 125 1000 27 1000 1 1000], "mV")
for i in a:
    print(i**(1 / 3.))
10.000001 mV^0.3333333333333333
1. mV^0.3333333333333333
10.000001 mV^0.3333333333333333
3. mV^0.3333333333333333
10.000001 mV^0.3333333333333333
5.0000005 mV^0.3333333333333333
6.0000005 mV^0.3333333333333333
7.0000005 mV^0.3333333333333333
8.000001 mV^0.3333333333333333
9.000001 mV^0.3333333333333333
A multidimensional Quantity can have one index per axis. These indices are given in a tuple separated by commas:
def f(x, y):
    return 10 * x + y

b = jnp.fromfunction(f, (5, 4), dtype=jnp.int32) * u.mV
b
Quantity([[ 0 1 2 3]
[10 11 12 13]
[20 21 22 23]
[30 31 32 33]
[40 41 42 43]], "mV")
b[2, 3]
Quantity(23, "mV")
b[0:5, 1] # each row in the second column of b
Quantity([ 1 11 21 31 41], "mV")
b[:, 1] # equivalent to the previous example
Quantity([ 1 11 21 31 41], "mV")
b[1:3, :] # each column in the second and third row of b
Quantity([[10 11 12 13]
[20 21 22 23]], "mV")
When fewer indices are provided than the number of axes, the missing indices are considered complete slices:
b[-1]
Quantity([40 41 42 43], "mV")
The expression within brackets in b[i] is treated as i followed by as many instances of : as needed to represent the remaining axes. As in NumPy, this can also be written with dots as b[i, ...].
The dots (...) represent as many colons as needed to produce a complete indexing tuple. For example, if x is a Quantity with 5 axes, then
x[1, 2, ...] is equivalent to x[1, 2, :, :, :],
x[..., 3] to x[:, :, :, :, 3] and
x[4, ..., 5, :] to x[4, :, :, 5, :].
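The expansion rule for the dots can be sketched as a small function that rewrites an index tuple into an explicit one. expand_ellipsis is a hypothetical helper, not part of brainunit or NumPy. It handles only integers, slices and a single Ellipsis (not None/newaxis or array indices), and it assumes the index does not have more entries than the array has axes.

```python
def expand_ellipsis(index, ndim):
    """Hypothetical helper: rewrite an index so it has one entry per axis."""
    if not isinstance(index, tuple):
        index = (index,)
    n_ellipsis = sum(1 for item in index if item is Ellipsis)
    if n_ellipsis > 1:
        raise IndexError("an index can only have a single ellipsis ('...')")
    if n_ellipsis == 0:
        # Missing trailing indices are treated as complete slices.
        return index + (slice(None),) * (ndim - len(index))
    pos = index.index(Ellipsis)
    # The ellipsis stands for as many full slices as there are unindexed axes.
    fill = (slice(None),) * (ndim - (len(index) - 1))
    return index[:pos] + fill + index[pos + 1:]


print(expand_ellipsis((4, ..., 5, slice(None)), 5))
# (4, slice(None, None, None), slice(None, None, None), 5, slice(None, None, None))
```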
c = jnp.array([[[0, 1, 2], [10, 12, 13]], [[100, 101, 102], [110, 112, 113]]]) * u.mV # a 3D array (two stacked 2D arrays)
c.shape
(2, 2, 3)
c[1, ...] # same as c[1, :, :] or c[1]
Quantity([[100 101 102]
[110 112 113]], "mV")
c[..., 2] # same as c[:, :, 2]
Quantity([[ 2 13]
[102 113]], "mV")
Iterating over multidimensional Quantity is done with respect to the first axis:
for row in b:
    print(row)
[0 1 2 3] mV
[10 11 12 13] mV
[20 21 22 23] mV
[30 31 32 33] mV
[40 41 42 43] mV
Operating on Subsets
The .at attribute can be used to apply operations to a subset of a Quantity. As with JAX arrays, these operations return a new Quantity and leave the original unchanged, as the repeated output of q below shows:
q = jnp.arange(5.0) * u.mV
q
Quantity([0. 1. 2. 3. 4.], "mV")
q.at[2].add(10 * u.mV)
Quantity([ 0. 1. 12. 3. 4.], "mV")
q.at[10].add(10 * u.mV) # out-of-bounds indices are ignored
Quantity([0. 1. 2. 3. 4.], "mV")
q.at[20].add(10 * u.mV, mode='clip') # out-of-bounds indices are clipped
Quantity([ 0. 1. 2. 3. 14.], "mV")
q.at[2].get()
Quantity(2., "mV")
q.at[20].get() # out-of-bounds indices clipped
Quantity(4., "mV")
q.at[20].get(mode='fill') # out-of-bounds indices filled with NaN
Quantity(nan, "mV")
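The out-of-bounds behavior shown above can be summarized in a plain-Python toy. Updates return a new sequence, an out-of-bounds update is dropped unless mode="clip", and an out-of-bounds get is clamped unless mode="fill". The functions at_add and at_get are hypothetical stand-ins for q.at[i].add(...) and q.at[i].get(...). They work on bare magnitudes, accept only non-negative integer indices, and copy their default modes from the outputs above rather than from JAX's full specification.

```python
import math


def at_add(values, index, delta, mode="drop"):
    """Hypothetical toy of q.at[index].add(delta); returns a new tuple."""
    n = len(values)
    if index >= n:
        if mode == "drop":
            return tuple(values)  # out-of-bounds update is skipped
        if mode == "clip":
            index = n - 1         # clamp to the last valid position
    out = list(values)
    out[index] += delta
    return tuple(out)


def at_get(values, index, mode="clip", fill_value=math.nan):
    """Hypothetical toy of q.at[index].get(...)."""
    n = len(values)
    if index < n:
        return values[index]
    if mode == "fill":
        return fill_value         # out-of-bounds read returns the fill value
    return values[n - 1]          # otherwise clamp into range


q = (0.0, 1.0, 2.0, 3.0, 4.0)
print(at_add(q, 2, 10.0))               # (0.0, 1.0, 12.0, 3.0, 4.0)
print(at_add(q, 10, 10.0))              # (0.0, 1.0, 2.0, 3.0, 4.0)
print(at_add(q, 20, 10.0, mode="clip")) # (0.0, 1.0, 2.0, 3.0, 14.0)
print(at_get(q, 20))                    # 4.0
print(at_get(q, 20, mode="fill"))       # nan
print(q)                                # (0.0, 1.0, 2.0, 3.0, 4.0), unchanged
```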
brainunit checks unit consistency in these operations and raises an error when the dimensions do not match:
try:
    q.at[2].add(10)
except Exception as e:
    print(e)
Cannot convert to a unit with different dimensions. (units are 1 and mV).
The fill value used with mode='fill' can also be customized. It must be a quantity with the same dimension as q:
q.at[20].get(mode='fill', fill_value=-1 * u.mV) # custom fill value
Quantity(-1., "mV")
try:
    q.at[20].get(mode='fill', fill_value=-1)
except Exception as e:
    print(e)
Cannot convert to a unit with different dimensions. (units are 1 and mV).