Math Operations


import brainunit as u
import jax.numpy as jnp

As in NumPy and JAX NumPy, arithmetic operators on Quantity arrays apply elementwise.

a = [20, 30, 40, 50] * u.mV
b = jnp.arange(4) * u.mV
b
Quantity([0 1 2 3], "mV")

Addition and Subtraction

Quantities can only be added or subtracted when they have the same units, and the result keeps those units.

c = a - b
c
Quantity([20 29 38 47], "mV")
c + b
Quantity([20 30 40 50], "mV")

Multiplication and Division

Multiplying or dividing quantities multiplies or divides their values and combines their units. Dimension exponents add under multiplication and subtract under division, so mV times mV gives mV^2.

A = jnp.array([[1, 2], [3, 4]]) * u.mV
B = jnp.array([[5, 6], [7, 8]]) * u.mV

A, B
(Quantity([[1 2]
           [3 4]], "mV"),
 Quantity([[5 6]
           [7 8]], "mV"))
A * B # element-wise multiplication
Quantity([[ 5 12]
          [21 32]], "mV^2")
A @ B # matrix multiplication
Quantity([[19 22]
          [43 50]], "mV^2")
A.dot(B) # matrix multiplication
Quantity([[19 22]
          [43 50]], "mV^2")
A / 2 # divide by a scalar
Quantity([[0.5 1. ]
          [1.5 2. ]], "mV")

If the result is dimensionless, the unit is dropped and a plain jax.Array is returned:

A / (2 * u.mV) # divide by a quantity, return jax array
Array([[0.5, 1. ],
       [1.5, 2. ]], dtype=float32)
A / (2 * u.mA) # mV / mA is a resistance, so a Quantity in ohm is returned
Quantity([[0.5 1. ]
          [1.5 2. ]], "ohm")

Power

The power operator raises the value of the quantity to a scalar exponent and raises the unit to the same power, so every dimension exponent is multiplied by that scalar.

A
Quantity([[1 2]
          [3 4]], "mV")
A ** 2 # element-wise power
Quantity([[ 1  4]
          [ 9 16]], "mV^2")

Built-in Functions

brainunit provides a number of built-in methods on the Quantity class for operating on quantities. These methods are:

  • unary operations

    • positive(+)

    • negative(-)

    • absolute(abs)

    • invert(~)

  • logical operations

    • all

    • any

  • shape operations

    • reshape

    • resize

    • squeeze

    • unsqueeze

    • split

    • swapaxes

    • transpose

    • ravel

    • take

    • repeat

    • diagonal

    • trace

  • mathematical functions

    • nonzero

    • argmax

    • argmin

    • argsort

    • var

    • round

    • std

    • sum

    • cumsum

    • cumprod

    • max

    • mean

    • min

    • ptp

    • clip

    • conj

    • dot

    • fill

    • item

    • prod

    • clamp

    • sort

For more details on these functions, refer to the documentation.

Indexing, Slicing and Iterating

One-dimensional Quantity arrays can be indexed, sliced, and iterated over, much like lists and other Python sequences.

a = jnp.arange(10) ** 3 * u.mV
a
Quantity([  0   1   8  27  64 125 216 343 512 729], "mV")
a[2]
Quantity(8, "mV")
a[2:5]
Quantity([ 8 27 64], "mV")

Only a Quantity with the same dimensions can be assigned to a slice of a Quantity:

# equivalent to a[0:6:2] = 1000 * u.mV
# from the start up to index 6 (exclusive), set every 2nd element to 1000 mV
a[:6:2] = 1000 * u.mV
a
Quantity([1000    1 1000   27 1000  125  216  343  512  729], "mV")
a[::-1] # reversed a
Quantity([ 729  512  343  216  125 1000   27 1000    1 1000], "mV")
for i in a:
    print(i**(1 / 3.))
10.000001 mV^0.3333333333333333
1. mV^0.3333333333333333
10.000001 mV^0.3333333333333333
3. mV^0.3333333333333333
10.000001 mV^0.3333333333333333
5.0000005 mV^0.3333333333333333
6.0000005 mV^0.3333333333333333
7.0000005 mV^0.3333333333333333
8.000001 mV^0.3333333333333333
9.000001 mV^0.3333333333333333

A multidimensional Quantity can have one index per axis. These indices are given as a comma-separated tuple:

def f(x, y):
    return 10 * x + y
b = jnp.fromfunction(f, (5, 4), dtype=jnp.int32) * u.mV
b
Quantity([[ 0  1  2  3]
          [10 11 12 13]
          [20 21 22 23]
          [30 31 32 33]
          [40 41 42 43]], "mV")
b[2, 3]
Quantity(23, "mV")
b[0:5, 1]  # each row in the second column of b
Quantity([ 1 11 21 31 41], "mV")
b[:, 1]  # equivalent to the previous example
Quantity([ 1 11 21 31 41], "mV")
b[1:3, :]  # each column in the second and third row of b
Quantity([[10 11 12 13]
          [20 21 22 23]], "mV")

When fewer indices are provided than the number of axes, the missing indices are considered complete slices:

b[-1]
Quantity([40 41 42 43], "mV")

The expression within brackets in b[i] is treated as i followed by as many instances of : as needed to represent the remaining axes. As in NumPy, this can also be written with dots as b[i, ...].

The dots (...) represent as many colons as needed to produce a complete indexing tuple. For example, if x is a Quantity with 5 axes, then

  • x[1, 2, ...] is equivalent to x[1, 2, :, :, :],

  • x[..., 3] to x[:, :, :, :, 3] and

  • x[4, ..., 5, :] to x[4, :, :, 5, :].

c = jnp.array([[[0, 1, 2], [10, 12, 13]], [[100, 101, 102], [110, 112, 113]]]) * u.mV # a 3D array (two stacked 2D arrays)
c.shape
(2, 2, 3)
c[1, ...] # same as c[1, :, :] or c[1]
Quantity([[100 101 102]
          [110 112 113]], "mV")
c[..., 2] # same as c[:, :, 2]
Quantity([[  2  13]
          [102 113]], "mV")

Iterating over a multidimensional Quantity is done with respect to the first axis:

for row in b:
    print(row)
[0 1 2 3] mV
[10 11 12 13] mV
[20 21 22 23] mV
[30 31 32 33] mV
[40 41 42 43] mV

Operating on Subsets

The .at property can be used to read or update a subset of a Quantity. As with JAX arrays, each update returns a new Quantity and leaves the original unchanged. The following are examples of operating on subsets of a Quantity:

q = jnp.arange(5.0) * u.mV
q
Quantity([0. 1. 2. 3. 4.], "mV")
q.at[2].add(10 * u.mV)
Quantity([ 0.  1. 12.  3.  4.], "mV")
q.at[10].add(10 * u.mV)  # out-of-bounds indices are ignored
Quantity([0. 1. 2. 3. 4.], "mV")
q.at[20].add(10 * u.mV, mode='clip') # out-of-bounds indices are clipped
Quantity([ 0.  1.  2.  3. 14.], "mV")
q.at[2].get()
Quantity(2., "mV")
q.at[20].get()  # out-of-bounds indices clipped
Quantity(4., "mV")
q.at[20].get(mode='fill')  # out-of-bounds indices filled with NaN
Quantity(nan, "mV")

brainunit checks that the units in these operations are consistent and raises an error on a dimension mismatch:

try:
    q.at[2].add(10)
except Exception as e:
    print(e)
Cannot convert to a unit with different dimensions. (units are 1 and mV).

The fill value used by .at[...].get(mode='fill') can also be customized, but it must have the same dimensions as the Quantity:

q.at[20].get(mode='fill', fill_value=-1 * u.mV)  # custom fill value
Quantity(-1., "mV")
try:
    q.at[20].get(mode='fill', fill_value=-1)
except Exception as e:
    print(e)
Cannot convert to a unit with different dimensions. (units are 1 and mV).