Array Creation#
The functions listed below create arrays or Quantities with specific properties, such as arrays filled with a given value, identity matrices, or arrays with ones on the diagonal. They are part of the brainunit.math module and handle both plain numerical arrays and Quantities with units.
import brainunit as u
import jax.numpy as jnp
brainunit.math.array & brainunit.math.asarray#
Convert the input to a quantity or array.
If unit is provided, the input is checked to have the same unit as the provided unit; if the two share the same dimension but differ in scale, the input is converted to the provided unit. If unit is not provided, the input is converted to an array.
u.math.asarray([1, 2, 3]) # return a jax.Array
Array([1, 2, 3], dtype=int32)
u.math.asarray([1, 2, 3], unit=u.second) # return a Quantity
Array([1, 2, 3], dtype=int32)
# check if the input has the same unit as the provided unit
u.math.asarray([1 * u.second, 2 * u.second], unit=u.second)
Quantity([1 2], "s")
# fails because the input has a different unit
try:
    u.math.asarray([1 * u.second, 2 * u.second], unit=u.ampere)
except Exception as e:
    print(e)
Cannot convert to a unit with different dimensions. (units are s and A).
brainunit.math.arange#
Return evenly spaced values within a given interval.
u.math.arange(5) # return a jax.Array
Array([0, 1, 2, 3, 4], dtype=int32)
u.math.arange(5 * u.second, step=1 * u.second) # return a Quantity
Quantity([0 1 2 3 4], "s")
u.math.arange(3, 9, 1) # return a jax.Array
Array([3, 4, 5, 6, 7, 8], dtype=int32)
u.math.arange(3 * u.second, 9 * u.second, 1 * u.second) # return a Quantity
Quantity([3 4 5 6 7 8], "s")
brainunit.math.array_split#
Split an array into multiple sub-arrays.
a = jnp.arange(9)
u.math.array_split(a, 3) # return a list of jax.Array
[Array([0, 1, 2], dtype=int32),
Array([3, 4, 5], dtype=int32),
Array([6, 7, 8], dtype=int32)]
q = jnp.arange(9) * u.second
u.math.array_split(q, 3) # return a list of Quantity
[Quantity([0 1 2], "s"), Quantity([3 4 5], "s"), Quantity([6 7 8], "s")]
brainunit.math.linspace#
Return evenly spaced numbers over a specified interval.
u.math.linspace(0, 10, 5) # return a jax.Array
Array([ 0. , 2.5, 5. , 7.5, 10. ], dtype=float32)
u.math.linspace(0 * u.second, 10 * u.second, 5) # return a Quantity
Quantity([ 0. 2.5 5. 7.5 10. ], "s")
brainunit.math.logspace#
Return numbers spaced evenly on a log scale.
u.math.logspace(0, 10, 5) # return a jax.Array
Array([1.0000000e+00, 3.1622775e+02, 1.0000000e+05, 3.1622776e+07,
1.0000000e+10], dtype=float32)
u.math.logspace(0 * u.second, 10 * u.second, 5) # return a Quantity
Quantity([1.0000000e+00 3.1622775e+02 1.0000000e+05 3.1622776e+07 1.0000000e+10], "s")
brainunit.math.meshgrid#
Return coordinate matrices from coordinate vectors.
x = jnp.array([1, 2, 3])
y = jnp.array([4, 5])
u.math.meshgrid(x, y) # return a list of jax.Array
[Array([[1, 2, 3],
[1, 2, 3]], dtype=int32),
Array([[4, 4, 4],
[5, 5, 5]], dtype=int32)]
x_q = jnp.array([1, 2, 3]) * u.second
y_q = jnp.array([4, 5]) * u.second
u.math.meshgrid(x_q, y_q) # return a list of Quantity
[Quantity([[1 2 3]
[1 2 3]], "s"),
Quantity([[4 4 4]
[5 5 5]], "s")]
brainunit.math.vander#
Generate a Vandermonde matrix.
The Vandermonde matrix is a matrix with the terms of a geometric progression in each row.
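The geometric progression is defined by the input vector (a in the example below) and the number of columns N (a square matrix when N is not given); by default the powers decrease from left to right.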
a = jnp.array([1, 2, 3])
u.math.vander(a) # return a jax.Array
Array([[1, 1, 1],
[4, 2, 1],
[9, 3, 1]], dtype=int32)
Can use with Quantity#
The functions below can be used with Quantity with units.
brainunit.math.full#
Returns a quantity or array filled with a specific value.
u.math.full(3, 4) # return a jax.Array
Array([4, 4, 4], dtype=int32, weak_type=True)
u.math.full(3, 4 * u.second) # return a Quantity
Quantity([4 4 4], "s")
brainunit.math.empty#
Return a new quantity or array of the given shape and type, without initializing entries.
u.math.empty((2, 2)) # return a jax.Array
Array([[0., 0.],
[0., 0.]], dtype=float32)
u.math.empty((2, 2), unit=u.second) # return a Quantity
Quantity([[0. 0.]
[0. 0.]], "s")
brainunit.math.ones#
Returns a new quantity or array of the given shape and type, filled with ones.
u.math.ones((2, 2)) # return a jax.Array
Array([[1., 1.],
[1., 1.]], dtype=float32)
u.math.ones((2, 2), unit=u.second) # return a Quantity
Quantity([[1. 1.]
[1. 1.]], "s")
brainunit.math.zeros#
Returns a new quantity or array of the given shape and type, filled with zeros.
u.math.zeros((2, 2)) # return a jax.Array
Array([[0., 0.],
[0., 0.]], dtype=float32)
u.math.zeros((2, 2), unit=u.second) # return a Quantity
Quantity([[0. 0.]
[0. 0.]], "s")
brainunit.math.full_like#
Return a new quantity or array with the same shape and type as a given array or quantity, filled with fill_value.
a = jnp.array([1, 2, 3])
u.math.full_like(a, 4) # return a jax.Array
Array([4, 4, 4], dtype=int32)
try:
    u.math.full_like(a, 4 * u.second) # raises: a is a plain array but fill_value has a unit
except Exception as e:
    print(e)
full_like requires "fill_value" to be dimensionless when "a" is a plain array, but got fill_value with unit=s. Either pass a plain number as fill_value or wrap "a" as a Quantity.
brainunit.math.empty_like#
Return a new quantity or array with the same shape and type as a given array.
a = jnp.array([1, 2, 3])
u.math.empty_like(a) # return a jax.Array
Array([0, 0, 0], dtype=int32)
q = jnp.array([1, 2, 3]) * u.second
u.math.empty_like(q) # return a Quantity
Quantity([0 0 0], "s")
brainunit.math.ones_like#
Return a new quantity or array with the same shape and type as a given array, filled with ones.
a = jnp.array([1, 2, 3])
u.math.ones_like(a) # return a jax.Array
Array([1, 1, 1], dtype=int32)
q = jnp.array([1, 2, 3]) * u.second
u.math.ones_like(q) # return a Quantity
Quantity([1 1 1], "s")
brainunit.math.zeros_like#
Return a new quantity or array with the same shape and type as a given array, filled with zeros.
a = jnp.array([1, 2, 3])
u.math.zeros_like(a) # return a jax.Array
Array([0, 0, 0], dtype=int32)
q = jnp.array([1, 2, 3]) * u.second
u.math.zeros_like(q) # return a Quantity
Quantity([0 0 0], "s")
brainunit.math.fill_diagonal#
Fill the main diagonal of the given array of any dimensionality.
a = jnp.zeros((3, 3))
u.math.fill_diagonal(a, 4) # return a jax.Array
Array([[4., 0., 0.],
[0., 4., 0.],
[0., 0., 4.]], dtype=float32)
q = jnp.zeros((3, 3)) * u.second
u.math.fill_diagonal(q, 4 * u.second) # return a Quantity
Quantity([[4. 0. 0.]
[0. 4. 0.]
[0. 0. 4.]], "s")
Can use with unit keyword#
The functions below can be used with the unit keyword.
brainunit.math.eye#
Returns a 2-D quantity or array with ones on the diagonal and zeros elsewhere.
u.math.eye(3) # return a jax.Array
Array([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]], dtype=float32)
u.math.eye(3, unit=u.second) # return a Quantity
Quantity([[1. 0. 0.]
[0. 1. 0.]
[0. 0. 1.]], "s")
brainunit.math.identity#
Return the identity Quantity or array.
u.math.identity(3) # return a jax.Array
Array([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]], dtype=float32)
u.math.identity(3, unit=u.second) # return a Quantity
Quantity([[1. 0. 0.]
[0. 1. 0.]
[0. 0. 1.]], "s")
brainunit.math.tri#
Returns a quantity or an array with ones at and below the given diagonal and zeros elsewhere.
u.math.tri(3) # return a jax.Array
Array([[1., 0., 0.],
[1., 1., 0.],
[1., 1., 1.]], dtype=float32)
u.math.tri(3, unit=u.second) # return a Quantity
Quantity([[1. 0. 0.]
[1. 1. 0.]
[1. 1. 1.]], "s")
brainunit.math.diag#
Extract a diagonal or construct a diagonal array.
a = jnp.array([1, 2, 3])
u.math.diag(a) # return a jax.Array
Array([[1, 0, 0],
[0, 2, 0],
[0, 0, 3]], dtype=int32)
u.math.diag(a, unit=u.second) # return a Quantity
Quantity([[1 0 0]
[0 2 0]
[0 0 3]], "s")
brainunit.math.tril#
Lower triangle of an array.
Return a copy of a matrix with the elements above the k-th diagonal zeroed.
For quantities or arrays with ndim exceeding 2, tril will apply to the final two axes.
a = jnp.ones((3, 3))
u.math.tril(a) # return a jax.Array
Array([[1., 0., 0.],
[1., 1., 0.],
[1., 1., 1.]], dtype=float32)
u.math.tril(a, unit=u.second) # return a Quantity
Quantity([[1. 0. 0.]
[1. 1. 0.]
[1. 1. 1.]], "s")
brainunit.math.triu#
Upper triangle of an array.
Return a copy of a matrix with the elements below the k-th diagonal zeroed.
For quantities or arrays with ndim exceeding 2, triu will apply to the final two axes.
a = jnp.ones((3, 3))
u.math.triu(a) # return a jax.Array
Array([[1., 1., 1.],
[0., 1., 1.],
[0., 0., 1.]], dtype=float32)
u.math.triu(a, unit=u.second) # return a Quantity
Quantity([[1. 1. 1.]
[0. 1. 1.]
[0. 0. 1.]], "s")