brainunit.lax module#

Functions that Accept Unitless Inputs#

acos

Elementwise arc cosine: \(\mathrm{acos}(x)\).

acosh

Elementwise inverse hyperbolic cosine: \(\mathrm{acosh}(x)\).

asin

Elementwise arc sine: \(\mathrm{asin}(x)\).

asinh

Elementwise inverse hyperbolic sine: \(\mathrm{asinh}(x)\).

atan

Elementwise arc tangent: \(\mathrm{atan}(x)\).

atanh

Elementwise inverse hyperbolic tangent: \(\mathrm{atanh}(x)\).

collapse

Collapses dimensions of an array into a single dimension.

cumlogsumexp

Compute a cumulative logsumexp along the given axis.
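At position i, cumlogsumexp holds log(exp(x_0) + ... + exp(x_i)). Evaluating that literally overflows for large inputs, so a stable version folds in one element at a time with a log-add-exp step that factors out the larger exponent. The following is a plain-Python sketch of that recurrence on a 1-D list; it is not brainunit's or XLA's implementation, and the helper names are illustrative.

```python
import math

def logaddexp(a, b):
    """Stable log(exp(a) + exp(b)): factor out the larger exponent."""
    hi, lo = max(a, b), min(a, b)
    if hi == -math.inf:          # both terms are log(0)
        return -math.inf
    return hi + math.log1p(math.exp(lo - hi))

def cumlogsumexp(xs):
    """Illustrative cumulative log-sum-exp of a 1-D list, left to right."""
    out = []
    acc = -math.inf              # log(0) is the identity for log-add-exp
    for x in xs:
        acc = logaddexp(acc, x)
        out.append(acc)
    return out

# exp(1000.0) overflows a float64, but the recurrence never forms it.
print(cumlogsumexp([1000.0, 1000.0, 1000.0]))
```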

bessel_i0e

Exponentially scaled modified Bessel function of order 0: \(\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)\).

bessel_i1e

Exponentially scaled modified Bessel function of order 1: \(\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)\).
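The scaled form exists because I0(x) grows like exp(|x|) and overflows long before i0e does. A minimal plain-Python sketch, assuming the textbook power series I0(x) = sum_k (x/2)^(2k) / (k!)^2, folds the exp(-|x|) factor into the first term so no partial sum exceeds 1. It is an illustration only, valid for moderate |x| (exp(-|x|) underflows to zero near |x| of about 745), and not how brainunit or XLA evaluate the function.

```python
import math

def i0e(x, tol=1e-16):
    """Illustrative i0e(x) = exp(-|x|) * I0(x) from the power series of I0.
    The exp(-|x|) factor is applied to the k = 0 term, so every partial
    sum stays at or below 1 instead of growing like exp(|x|)."""
    ax = abs(x)                  # i0e is even in x
    term = math.exp(-ax)         # k = 0 term, already scaled
    total = term
    q = (ax / 2.0) ** 2
    k = 0
    while True:
        k += 1
        term *= q / (k * k)      # ratio of consecutive series terms
        total += term
        # terms peak near k ~ |x|/2; stop once past the peak and negligible
        if k > ax and term <= tol * total:
            break
    return total

print(i0e(1.0), i0e(50.0))
```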

digamma

Elementwise digamma: \(\psi(x)\).

lgamma

Elementwise log gamma: \(\mathrm{log}(\Gamma(x))\).

erf

Elementwise error function: \(\mathrm{erf}(x)\).

erfc

Elementwise complementary error function: \(\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)\).
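The identity erfc(x) = 1 - erf(x) holds mathematically, but computing it by subtraction cancels catastrophically once erf(x) rounds to 1. Python's standard math module, used here only as a scalar stand-in for the array function, shows the effect:

```python
import math

x = 10.0
via_erf = 1.0 - math.erf(x)   # erf(10) rounds to exactly 1.0 in float64
direct = math.erfc(x)         # about 2.09e-45, computed without cancellation

print(via_erf, direct)
```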

erf_inv

Elementwise inverse error function: \(\mathrm{erf}^{-1}(x)\).

logistic

Elementwise logistic (sigmoid) function: \(\frac{1}{1 + e^{-x}}\).
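Written directly, 1 / (1 + e^{-x}) overflows for large negative x. A common stable evaluation switches to the algebraically equal form e^{x} / (1 + e^{x}) when x < 0. The scalar plain-Python sketch below illustrates that technique; it is not the library's implementation.

```python
import math

def logistic(x):
    """Illustrative 1 / (1 + exp(-x)) evaluated without overflow."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))   # exp(-x) <= 1, safe
    z = math.exp(x)                         # x < 0, so 0 < z < 1
    return z / (1.0 + z)                    # equal to 1 / (1 + exp(-x))

print(logistic(-1000.0), logistic(0.0), logistic(1000.0))
```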

atan2

Elementwise arc tangent of two variables: \(\mathrm{atan}({x \over y})\), with the quadrant chosen from the signs of x and y.
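The quotient formula only matches a two-argument arc tangent when the denominator is positive; otherwise the signs of both inputs decide the quadrant. Python's math.atan2, which also takes the numerator first, serves as a scalar stand-in for the comparison:

```python
import math

num, den = 1.0, -1.0                  # a point in the second quadrant
quotient_only = math.atan(num / den)  # -pi/4: the sign of den is lost
two_arg = math.atan2(num, den)        # 3*pi/4: numerator first, quadrant-aware

print(quotient_only, two_arg)
```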

polygamma

Elementwise polygamma: \(\psi^{(m)}(x)\).

igamma

Elementwise regularized incomplete gamma function.

igammac

Elementwise complementary regularized incomplete gamma function.

igamma_grad_a

Elementwise derivative of the regularized incomplete gamma function with respect to a.

random_gamma_grad

Elementwise derivative of samples from Gamma(a, 1).

zeta

Elementwise Hurwitz zeta function: \(\zeta(x, q)\).

betainc

Elementwise regularized incomplete beta integral.

shift_left

Elementwise left shift: \(x \ll y\).

shift_right_arithmetic

Elementwise arithmetic right shift: \(x \gg y\), filling vacated high bits with copies of the sign bit.

shift_right_logical

Elementwise logical right shift: \(x \gg y\), filling vacated high bits with zeros.
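The shift functions act on fixed-width bit patterns, so the two right shifts only disagree on negative signed inputs. The plain-Python sketch below emulates 32-bit two's-complement integers (the 32-bit width is an assumption chosen for illustration) and assumes shift amounts in the range [0, 32); the function names mirror the API but these are stand-ins, not brainunit.lax.

```python
MASK = 0xFFFFFFFF  # 32-bit word (illustrative width)

def to_int32(u):
    """Reinterpret the low 32 bits of u as a signed int32."""
    u &= MASK
    return u - (1 << 32) if u & 0x80000000 else u

def shift_left(x, n):
    return to_int32(x << n)            # bits pushed past bit 31 are dropped

def shift_right_arithmetic(x, n):
    return to_int32(x) >> n            # Python's >> on ints copies the sign bit

def shift_right_logical(x, n):
    return to_int32((x & MASK) >> n)   # shift the raw pattern, fill with zeros

print(shift_right_arithmetic(-8, 1), shift_right_logical(-8, 1))
```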

fft

Compute a fast Fourier transform.

Array Creation Functions#

zeros_like_array

Create a zero-filled array with the same shape and dtype as x.

iota

Create an iota array (integer sequence) with an optional unit.

broadcasted_iota

Broadcast an iota array into the given shape along one dimension.
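Ignoring dtype and the optional unit, the index pattern these two functions produce can be sketched with nested Python lists: iota is the sequence 0, 1, ..., n - 1, and broadcasted_iota fills a whole shape so each entry equals its own index along the chosen dimension. These are plain-Python illustrations, not the library functions.

```python
def iota(size):
    """Illustrative 1-D integer sequence 0, 1, ..., size - 1."""
    return list(range(size))

def broadcasted_iota(shape, dimension):
    """Illustrative nested lists of the given shape whose entry at index
    (i_0, ..., i_{n-1}) equals i_dimension."""
    def build(index):
        if len(index) == len(shape):
            return index[dimension]
        return [build(index + (i,)) for i in range(shape[len(index)])]
    return build(())

print(broadcasted_iota((2, 3), 1))
```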

Functions that Change Units#

rsqrt

Elementwise reciprocal square root: \(1 \over \sqrt{x}\).

div

Elementwise division: \(x \over y\).

pow

Elementwise power: \(x^y\).

integer_pow

Elementwise integer power: \(x^y\), where \(y\) is a fixed integer.

mul

Elementwise multiplication: \(x \times y\).

rem

Elementwise remainder: \(x \bmod y\).
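The sign convention of the remainder matters for negative operands. jax.lax.rem, which brainunit.lax.rem is assumed here to wrap, documents that the sign of the result is taken from the dividend, as with C's fmod. Python's % operator follows the divisor instead, as this stdlib comparison shows:

```python
import math

# C-style remainder: the sign follows the dividend (first argument).
print(math.fmod(-7.0, 3.0))   # -1.0
# Python's % operator: the sign follows the divisor.
print(-7.0 % 3.0)             # 2.0
```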

batch_matmul

Batch matrix multiplication.

conv

Convenience wrapper around conv_general_dilated.

conv_transpose

Convenience wrapper for calculating the N-d convolution "transpose".

dot_general

General dot product/contraction operator.
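For the elementwise members of this group, the unit bookkeeping follows ordinary dimensional analysis: multiplication adds unit exponents, division subtracts them, pow and integer_pow scale them by the exponent, and rsqrt scales them by -1/2. The toy model below represents a unit as a dict from base-unit name to exponent; that representation and all helper names are hypothetical and are not brainunit's internals.

```python
from fractions import Fraction

# Hypothetical toy representation: {base_unit_name: exponent}.

def _clean(u):
    """Drop base units whose exponent cancelled to zero."""
    return {k: v for k, v in u.items() if v != 0}

def mul_unit(a, b):
    """x * y: exponents add."""
    out = dict(a)
    for k, v in b.items():
        out[k] = out.get(k, 0) + v
    return _clean(out)

def div_unit(a, b):
    """x / y: exponents subtract."""
    return mul_unit(a, {k: -v for k, v in b.items()})

def pow_unit(a, y):
    """x ** y: exponents scale by y (kept exact with Fraction)."""
    return _clean({k: v * Fraction(y) for k, v in a.items()})

def rsqrt_unit(a):
    """1 / sqrt(x): exponents scale by -1/2."""
    return pow_unit(a, Fraction(-1, 2))

volt, amp = {"V": 1}, {"A": 1}
print(mul_unit(volt, amp), div_unit(volt, amp), rsqrt_unit({"m": 2}))
```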

Functions that Keep Units#

slice

Wraps XLA's Slice operator.

dynamic_slice

Wraps XLA's DynamicSlice operator.

dynamic_update_slice

Wraps XLA's DynamicUpdateSlice operator.

gather

Gather operator.

index_take

Take elements from an array at the given indices along the given axes.

slice_in_dim

Convenience wrapper around lax.slice() applied to a single dimension.

index_in_dim

Convenience wrapper around lax.slice() to perform integer indexing.

dynamic_slice_ind_dim

Convenience wrapper around lax.dynamic_slice() applied to one dimension.

dynamic_index_in_dim

Convenience wrapper around dynamic_slice() to perform integer indexing.

dynamic_update_slice_in_dim

Convenience wrapper around dynamic_update_slice() to update a slice in a single axis.

dynamic_update_index_in_dim

Convenience wrapper around dynamic_update_slice() to update a slice of size 1 in a single axis.

sort

Wraps XLA's Sort operator.

sort_key_val

Sort keys along the given dimension and apply the same permutation to the values.

neg

Elementwise negation: \(-x\).

cummax

Compute a cumulative maximum along the given axis.

cummin

Compute a cumulative minimum along the given axis.

cumsum

Compute a cumulative sum along the given axis.

scatter

Scatter-update operator.

scatter_add

Scatter-add operator.

scatter_sub

Scatter-sub operator.

scatter_mul

Scatter-multiply operator.

scatter_min

Scatter-min operator.

scatter_max

Scatter-max operator.

scatter_apply

Scatter-apply operator.
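The combining scatter variants differ from plain scatter when several indices point at the same location: each update is folded into the current value with the corresponding operation, whereas a plain overwrite leaves the winner unspecified in JAX. The 1-D plain-Python sketch below uses a hypothetical helper, scatter_combine, to show the combining semantics; it ignores dimension numbers and units.

```python
import operator

def scatter_combine(operand, indices, updates, combine):
    """Hypothetical 1-D helper: out[indices[j]] = combine(out[indices[j]], updates[j])."""
    out = list(operand)
    for i, u in zip(indices, updates):
        out[i] = combine(out[i], u)
    return out

base = [0, 0, 0, 0]
idx = [1, 1, 3]        # index 1 appears twice
upd = [1, 2, 5]
added = scatter_combine(base, idx, upd, operator.add)   # like scatter_add
largest = scatter_combine(base, idx, upd, max)          # like scatter_max
print(added, largest)
```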

sub

Elementwise subtraction: \(x - y\).

complex

Elementwise make complex number: \(x + jy\).

pad

Applies low, high, and/or interior padding to an array.

clamp

Elementwise clamp: restricts x to the closed interval [min, max].

convert_element_type

Elementwise cast.

bitcast_convert_type

Elementwise bitcast.

approx_max_k

Returns the k largest values of the operand and their indices, computed approximately.

approx_min_k

Returns the k smallest values of the operand and their indices, computed approximately.

top_k

Returns the k largest values and their indices along the last axis of the operand.
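For a single row, the exact top-k selection can be sketched with the standard heapq module by ranking indices by their values; the approximate variants trade some of this exactness for speed. The following is a plain-Python illustration on a 1-D list, not the library function, and it does not specify tie-breaking.

```python
import heapq

def top_k(xs, k):
    """Illustrative exact top-k of a 1-D list: values (descending) and their indices."""
    idx = heapq.nlargest(k, range(len(xs)), key=xs.__getitem__)
    return [xs[i] for i in idx], idx

print(top_k([1, 5, 3, 4, 2], 2))
```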

broadcast

Broadcast an array by adding new leading dimensions.

broadcast_in_dim

Broadcast an array into a target shape (XLA BroadcastInDim).

broadcast_to_rank

Add leading dimensions of size 1 so that x has the requested rank.
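The three broadcast functions differ in how the operand's dimensions are placed in the result. The shape-only plain-Python sketch below illustrates the arithmetic; the helper names are hypothetical, and broadcast_in_dim is simplified to the size-compatibility check (other constraints XLA may impose are not modeled).

```python
def broadcast_shape(shape, sizes):
    """broadcast: the new dimensions `sizes` are prepended to the operand shape."""
    return tuple(sizes) + tuple(shape)

def broadcast_to_rank_shape(shape, rank):
    """broadcast_to_rank: prepend size-1 dimensions up to `rank`.
    In this sketch a rank <= len(shape) leaves the shape unchanged."""
    return (1,) * (rank - len(shape)) + tuple(shape)

def broadcast_in_dim_shape(shape, target, broadcast_dimensions):
    """broadcast_in_dim: operand dim i maps to target dim broadcast_dimensions[i];
    each mapped operand size must be 1 or equal to the target size."""
    if len(broadcast_dimensions) != len(shape):
        raise ValueError("need one target dimension per operand dimension")
    for size, d in zip(shape, broadcast_dimensions):
        if size not in (1, target[d]):
            raise ValueError(f"cannot broadcast size {size} to {target[d]}")
    return tuple(target)

print(broadcast_shape((3,), (2, 4)), broadcast_to_rank_shape((3,), 3))
```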

Functions that Remove Units#

population_count

Elementwise popcount: count the number of set bits in each element.

clz

Elementwise count of leading zeros.
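Both bit-counting functions return plain integer counts that depend on the integer width. The plain-Python sketch below assumes 32-bit two's-complement integers (an illustrative choice) and mirrors the two operations on single values; it is not the library function.

```python
MASK32 = 0xFFFFFFFF  # assume 32-bit integers

def population_count(x):
    """Illustrative count of 1 bits in the 32-bit pattern of x."""
    return bin(x & MASK32).count("1")

def clz(x):
    """Illustrative count of leading zero bits in the 32-bit pattern of x
    (32 for x == 0 in this sketch)."""
    return 32 - (x & MASK32).bit_length()

print(population_count(0b1011), clz(1))
```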

eq

Elementwise equals: \(x = y\).

ne

Elementwise not-equals: \(x \neq y\).

ge

Elementwise greater-than-or-equals: \(x \geq y\).

gt

Elementwise greater-than: \(x > y\).

le

Elementwise less-than-or-equals: \(x \leq y\).

lt

Elementwise less-than: \(x < y\).

Linalg Functions#

cholesky

Cholesky decomposition.
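A Cholesky decomposition factors a symmetric positive-definite matrix A as L L^T with L lower triangular. The textbook row-by-row algorithm is sketched below in plain Python on lists of lists, for real matrices only; it illustrates the factorization and is not the library's implementation.

```python
import math

def cholesky(a):
    """Illustrative lower-triangular L with L @ L^T == a, for a real symmetric
    positive-definite matrix given as a list of lists."""
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                d = a[i][i] - s
                if d <= 0.0:
                    raise ValueError("matrix is not positive definite")
                L[i][j] = math.sqrt(d)           # diagonal entry
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]  # below-diagonal entry
    return L

print(cholesky([[25, 15, -5], [15, 18, 0], [-5, 0, 11]]))
```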

eig

Eigendecomposition of a general matrix.

eigh

Eigendecomposition of a Hermitian matrix.

hessenberg

Reduce a square matrix to upper Hessenberg form.

lu

LU decomposition with partial pivoting.

qdwh

Polar decomposition via QR-based dynamically weighted Halley iteration.

qr

QR decomposition.

schur

Schur decomposition.

svd

Singular value decomposition.

tridiagonal

Reduce a symmetric/Hermitian matrix to tridiagonal form.

householder_product

Product of elementary Householder reflectors.

triangular_solve

Triangular solve.

tridiagonal_solve

Solve a tridiagonal linear system.
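A tridiagonal system can be solved in linear time with the Thomas algorithm (forward elimination followed by back substitution). The plain-Python sketch below illustrates that technique for one right-hand side; the dl/d/du layout with dl[0] and du[-1] unused is an assumption modeled on jax.lax.linalg.tridiagonal_solve, and the sketch does no pivoting, so it suits diagonally dominant matrices. It is not the library's implementation.

```python
def tridiagonal_solve(dl, d, du, b):
    """Illustrative Thomas-algorithm solve of a tridiagonal system.
    dl: sub-diagonal (dl[0] unused), d: main diagonal,
    du: super-diagonal (du[-1] unused), b: right-hand side."""
    n = len(d)
    c = [0.0] * n                # modified super-diagonal
    y = [0.0] * n                # modified right-hand side
    c[0] = du[0] / d[0]
    y[0] = b[0] / d[0]
    for i in range(1, n):        # forward elimination
        denom = d[i] - dl[i] * c[i - 1]
        c[i] = du[i] / denom if i < n - 1 else 0.0
        y[i] = (b[i] - dl[i] * y[i - 1]) / denom
    x = [0.0] * n
    x[-1] = y[-1]
    for i in range(n - 2, -1, -1):   # back substitution
        x[i] = y[i] - c[i] * x[i + 1]
    return x

# A = [[2, 1, 0], [1, 2, 1], [0, 1, 2]], b = A @ [1, 2, 3]
print(tridiagonal_solve([0.0, 1.0, 1.0], [2.0, 2.0, 2.0], [1.0, 1.0, 0.0], [4.0, 8.0, 8.0]))
```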

Other Functions#

reduce

Reduce an array along the given dimensions using a reduction computation.

reduce_precision

Reduce the precision of array elements.

broadcast_shapes

Return the shape that results from NumPy broadcasting of shapes.
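NumPy broadcasting aligns shapes on the right, treats missing leading dimensions as size 1, and requires each aligned pair of sizes to be equal or to include a 1; the result takes the larger size. A plain-Python sketch of that rule, not the library function:

```python
def broadcast_shapes(*shapes):
    """Illustrative NumPy-style broadcast of shape tuples."""
    ndim = max((len(s) for s in shapes), default=0)
    result = []
    for axis in range(ndim):
        size = 1
        for s in shapes:
            offset = ndim - len(s)       # shorter shapes get leading 1s
            dim = s[axis - offset] if axis >= offset else 1
            if dim == 1:
                continue
            if size not in (1, dim):
                raise ValueError(f"incompatible shapes: {shapes}")
            size = dim
        result.append(size)
    return tuple(result)

print(broadcast_shapes((2, 1, 3), (4, 3)))
```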