CustomArray

class saiunit.CustomArray

A custom array wrapper providing comprehensive array operations and cross-framework compatibility.

CustomArray is a mix-in class that delegates every operation to its data attribute. Subclasses must provide a data property (or attribute) that returns the underlying array. The class exposes NumPy-style methods, PyTorch-style convenience methods, and full interoperability with JAX transformations (jit, grad, vmap).

data

The underlying array storage. Subclasses are responsible for providing this attribute (e.g. via a property backed by some internal state).

Type:

array-like
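The delegation pattern described above can be sketched as follows. `DelegatingArray` and `Tracked` are hypothetical names, and the three forwarded members are a tiny subset chosen for illustration. This is not saiunit's implementation. The sketch shows a subclass whose data is a property backed by private state, with a setter so that operations which replace data have somewhere to store their result.

```python
import numpy as np


class DelegatingArray:
    """Hypothetical, heavily reduced stand-in for a CustomArray-style mix-in.

    Each member reads ``self.data`` and forwards to the underlying array;
    subclasses decide where ``data`` comes from.
    """

    @property
    def shape(self):
        return np.shape(self.data)

    def sum(self, axis=None):
        return np.sum(self.data, axis=axis)

    def __add__(self, other):
        # Returns the raw underlying array type, not a new wrapper.
        return self.data + other


class Tracked(DelegatingArray):
    """Hypothetical subclass whose ``data`` is a property over private state."""

    def __init__(self, value):
        self._value = np.asarray(value)

    @property
    def data(self):
        return self._value

    @data.setter
    def data(self, new_value):
        # A setter gives replacing operations somewhere to store results.
        self._value = np.asarray(new_value)


t = Tracked([1.0, 2.0, 3.0])
print(t.shape)       # (3,)
print(t.sum())       # 6.0
print(type(t + 1))   # <class 'numpy.ndarray'>
```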

NumPy-style methods

all, any, argmax, argmin, argsort, astype, clip, copy, cumsum, cumprod, diagonal, dot, flatten, max, mean, min, nonzero, prod, ravel, repeat, reshape, round, squeeze, std, sum, swapaxes, take, tolist, trace, transpose, var

PyTorch-style methods

unsqueeze, expand, expand_as, clamp, clone, zero_, bool, int, long, half, float, double

Conversion methods

to_numpy, to_jax, numpy

Trigonometric methods

sin, cos, tan, sinh, cosh, tanh, arcsin, arccos, arctan (and in-place _ variants)

Examples

CustomArray is designed to be used as a mix-in. A minimal standalone subclass needs only a data attribute and JAX pytree registration:

import jax
import numpy as np
from saiunit import CustomArray

@jax.tree_util.register_pytree_node_class
class SimpleArray(CustomArray):
    def __init__(self, value):
        self.data = value

    def tree_flatten(self):
        return (self.data,), None

    @classmethod
    def tree_unflatten(cls, aux_data, flat_contents):
        return cls(*flat_contents)

Basic properties and arithmetic:

arr = SimpleArray(np.array([1.0, 2.0, 3.0]))
arr.shape   # (3,)
arr.ndim    # 1
arr + 10    # array([11., 12., 13.])
arr ** 2    # array([1., 4., 9.])

Statistical operations:

arr = SimpleArray(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
arr.mean()  # 3.0
arr.std()   # ~1.414
arr.sum()   # 15.0

Array manipulation:

matrix = SimpleArray(np.array([[1, 2, 3], [4, 5, 6]]))
matrix.T           # transposed (3, 2) array
matrix.reshape(6)  # array([1, 2, 3, 4, 5, 6])
matrix.flatten()   # array([1, 2, 3, 4, 5, 6])

JAX compatibility (jit, grad):

import jax
import jax.numpy as jnp

arr = SimpleArray(jnp.array([1.0, 2.0, 3.0]))

@jax.jit
def square(x):
    return x * x

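square(arr)  # Array([1., 4., 9.], ...)

# grad returns a pytree with the same structure as its input,
# so the gradient comes back as a SimpleArray
grads = jax.grad(lambda x: (x * x).sum())(arr)
grads.data  # Array([2., 4., 6.], ...)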

Notes

  • This class uses duck typing and delegates operations to the underlying data attribute.

  • In-place operations (+=, -=, etc.) modify data directly and return self.

  • Most methods return the raw underlying array type, not a new CustomArray instance.

  • Thread safety depends on the underlying array implementation.

  • JAX transformations (jit, grad, vmap) accept CustomArray instances once the subclass is registered as a JAX pytree, as SimpleArray is in the examples above.
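The in-place note can be illustrated with a hypothetical `InPlaceDemo` class, which is not part of saiunit. It assumes the operator rebinds data to a new array rather than mutating a buffer (the only option for immutable JAX arrays) and returns self, so the name on the left-hand side stays bound to the same wrapper object.

```python
import numpy as np


class InPlaceDemo:
    """Hypothetical class illustrating the in-place operator contract."""

    def __init__(self, value):
        self.data = np.asarray(value)

    def __iadd__(self, other):
        # Assumption: the operator rebinds ``data`` to a new array. JAX arrays
        # are immutable, so rebinding is the only option for JAX-backed data.
        self.data = self.data + other
        # Returning self keeps ``x`` bound to the same wrapper after ``x += ...``.
        return self


x = InPlaceDemo([1, 2, 3])
before = x
x += 10
print(x is before)      # True
print(x.data.tolist())  # [11, 12, 13]
```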

See also

numpy.ndarray

NumPy’s N-dimensional array.

jax.Array

JAX’s array implementation.


absolute()

Alias of Array.abs.

Return type:

Array | ndarray | numpy.bool_ | number | bool | int | float | complex | None

all(axis=None, keepdims=False)

Return True if all elements evaluate to True.

any(axis=None, keepdims=False)

Return True if any element evaluates to True.

argmax(axis=None)

Return the indices of the maximum values along the given axis.

argmin(axis=None)

Return the indices of the minimum values along the given axis.

argpartition(kth, axis=-1, kind='introselect', order=None)

Returns the indices that would partition this array.

argsort(axis=-1, kind=None, order=None)

Returns the indices that would sort this array.

astype(dtype)

Copy of the array, cast to a specified type.

Parameters:

dtype (str or dtype) – Typecode or data-type to which the array is cast.

byteswap(inplace=False)

Swap the bytes of the array elements.

Toggle between little-endian and big-endian data representation by returning a byteswapped array, optionally swapped in-place. Arrays of byte strings are not swapped. The real and imaginary parts of a complex number are swapped individually.
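The description above matches NumPy's `ndarray.byteswap`. This NumPy-only sketch shows the effect on raw data. It is an assumption that CustomArray forwards the call unchanged, so behavior for non-NumPy data (such as JAX arrays) may differ.

```python
import numpy as np

# Two 16-bit integers: 1 is stored as bytes 0x0001, 256 as 0x0100.
a = np.array([1, 256], dtype=np.int16)

# inplace=False (the default) returns a new array and leaves `a` untouched.
swapped = a.byteswap()
swapped.tolist()  # [256, 1]

# Reinterpreting the swapped bytes with the opposite byte order recovers the values.
swapped.view(swapped.dtype.newbyteorder()).tolist()  # [1, 256]
```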

choose(choices, mode='raise')

Use an index array to construct a new array from a set of choices.

clamp(min_data=None, max_data=None)

Clamp the data to the range [min_data, max_data]. If min_data is None, there is no lower bound; if max_data is None, there is no upper bound.

Return type:

Array | ndarray | numpy.bool_ | number | bool | int | float | complex | None

clamp_(min_data=None, max_data=None)

In-place variant of clamp(): clamp the data to [min_data, max_data] and store the result back in data. A bound given as None is not applied.
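The bound semantics of clamp can be sketched with a hypothetical plain function `clamp` (not the method itself), assuming the method reduces to a clip-style call on data in which a None bound is simply skipped. For clamp_, the same result would be stored back into data instead of returned.

```python
import numpy as np


def clamp(data, min_data=None, max_data=None):
    """Hypothetical helper mirroring the documented clamp() bounds."""
    if min_data is None and max_data is None:
        # No bounds at all: nothing to clip, return the input unchanged.
        return data
    # np.clip accepts None for one of the two bounds.
    return np.clip(data, min_data, max_data)


x = np.array([-2.0, 0.5, 3.0])
clamp(x, 0.0, 1.0).tolist()       # [0.0, 0.5, 1.0]
clamp(x, min_data=0.0).tolist()   # [0.0, 0.5, 3.0]
clamp(x, max_data=1.0).tolist()   # [-2.0, 0.5, 1.0]
```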

clip(min=None, max=None)

Return an array whose values are limited to [min, max]. At least one of min or max must be given.

compress(condition, axis=None)

Return selected slices of this array along the given axis.

conj()

Complex-conjugate all elements.

conjugate()

Return the complex conjugate, element-wise.

copy()

Return a copy of the array.

cumprod(axis=None, dtype=None)

Return the cumulative product of the elements along the given axis.

cumsum(axis=None, dtype=None)

Return the cumulative sum of the elements along the given axis.
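Assuming cumsum and cumprod mirror NumPy's functions of the same name on data, the default axis=None operates on the flattened array, while an explicit axis accumulates along that axis only. A NumPy sketch on raw data:

```python
import numpy as np

m = np.array([[1, 2], [3, 4]])

np.cumsum(m).tolist()           # [1, 3, 6, 10]  (axis=None flattens first)
np.cumsum(m, axis=0).tolist()   # [[1, 2], [4, 6]]  (running sum down each column)
np.cumprod(m, axis=1).tolist()  # [[1, 2], [3, 12]]  (running product along each row)
```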

diagonal(offset=0, axis1=0, axis2=1)

Return specified diagonals.

dot(b)

Dot product of two arrays.

property dtype

Data type of the underlying array.

expand(*sizes)

Expand an array to a new shape.

Parameters:

sizes (tuple or int) – The shape of the desired array. A single integer i is interpreted as (i,).

Returns:

expanded – A read-only view on the original array with the given shape. It is typically not contiguous, and more than one element of an expanded array may refer to a single memory location.

Return type:

Array | ndarray | numpy.bool_ | number | bool | int | float | complex
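The documented contract of expand (sizes given as separate ints or one tuple, and a read-only result that may share memory) matches NumPy broadcasting. A hypothetical helper `expand` sketches it with `np.broadcast_to`. This is an illustration, not saiunit's code, and it deliberately omits PyTorch's -1 convention because the entry above does not mention it.

```python
import numpy as np


def expand(data, *sizes):
    """Hypothetical helper mirroring the documented expand() contract."""
    # expand(2, 3) and expand((2, 3)) mean the same thing.
    if len(sizes) == 1 and isinstance(sizes[0], (tuple, list)):
        sizes = tuple(sizes[0])
    # broadcast_to returns a read-only view: size-1 axes are repeated
    # without copying, so several elements share one memory location.
    return np.broadcast_to(data, sizes)


col = np.array([[1], [2], [3]])   # shape (3, 1)
out = expand(col, 3, 4)
out.shape                          # (3, 4)
out.flags.writeable                # False
expand(np.array([1, 2]), (2, 2)).tolist()  # [[1, 2], [1, 2]]
```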

fill(data)

Fill the array with the scalar data.

item(*args)

Copy an element of an array to a standard Python scalar and return it.

property itemsize

Length of one array element in bytes.

property mT

Transpose the last two dimensions (for batched matrix operations).
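For a stack of matrices, transposing the last two dimensions is the same as swapping axes -1 and -2 while leaving the batch axes alone. A NumPy sketch of the result mT is documented to produce, using `np.swapaxes` on raw data:

```python
import numpy as np

batch = np.arange(24).reshape(2, 3, 4)   # a batch of two 3x4 matrices
mt = np.swapaxes(batch, -1, -2)          # transpose each matrix, keep the batch axis
mt.shape                                  # (2, 4, 3)
```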

max(axis=None, keepdims=False, *args, **kwargs)

Return the maximum along a given axis.

mean(axis=None, dtype=None, keepdims=False, *args, **kwargs)

Returns the average of the array elements along the given axis.

min(axis=None, keepdims=False, *args, **kwargs)

Return the minimum along a given axis.

multiply(data)

Element-wise multiplication by data. See torch.multiply().

property nbytes

Total bytes consumed by the array elements.

nonzero()

Return the indices of the elements that are non-zero.

numpy(dtype=None)

Convert to numpy.ndarray.

prod(axis=None, dtype=None, keepdims=False, initial=1, where=True)

Return the product of the array elements over the given axis.

ptp(axis=None, keepdims=False)

Peak-to-peak range (maximum minus minimum) of the values along a given axis.
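Peak-to-peak is simply max minus min along the chosen axis, with axis=None covering the flattened array. A NumPy sketch on raw data, assuming ptp follows NumPy's definition:

```python
import numpy as np

x = np.array([[4, 9, 2], [3, 5, 7]])
np.ptp(x).item()             # 7  (9 - 2 over the flattened array)
np.ptp(x, axis=0).tolist()   # [1, 4, 5]  (per column)
(x.max(axis=1) - x.min(axis=1)).tolist()  # [7, 4], the same as np.ptp(x, axis=1)
```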

put(indices, datas)

Replace specified elements of the array with the given values.

Parameters:
  • indices (array_like) – Target indices, interpreted as integers.

  • datas (array_like) – Values to place in the array at target indices.
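The parameter descriptions above follow NumPy's `put`, in which indices address the flattened array. A NumPy sketch on raw data, assuming the same indexing rule applies to CustomArray.put:

```python
import numpy as np

a = np.arange(6).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]
np.put(a, [0, 4], [-1, -5])      # flat index 4 is row 1, column 1
a.tolist()                        # [[-1, 1, 2], [3, -5, 5]]
```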

ravel(order=None)

Return a flattened array.

repeat(repeats, axis=None)

Repeat elements of an array.

reshape(*shape, order='C')

Returns an array containing the same data with a new shape.

resize(new_shape)

Change shape and size of array in-place.

round(decimals=0)

Return the array with each element rounded to the given number of decimals.

property shape

Shape of the underlying array.

sort(axis=-1, stable=True, order=None)

Sort an array in-place.

Parameters:
  • axis (int, optional) – Axis along which to sort. Default is -1, which means sort along the last axis.

  • stable (bool, optional) – Whether to use a stable sorting algorithm. The default is True.

  • order (str or list of str, optional) – When the array has fields defined, this argument specifies which fields to compare first, second, etc.

squeeze(axis=None)

Remove axes of length one from the array.

sum(axis=None, dtype=None, keepdims=False, initial=0, where=True)

Return the sum of the array elements over the given axis.

swapaxes(axis1, axis2)

Return a view of the array with axis1 and axis2 interchanged.

take(indices, axis=None, mode=None)

Return an array formed from the elements of this array at the given indices.

to_jax(dtype=None)

Convert to jax.numpy.ndarray.

to_numpy(dtype=None)

Convert to numpy.ndarray.

trace(offset=0, axis1=0, axis2=1, dtype=None)

Return the sum along diagonals of the array.
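diagonal() and trace() take NumPy's offset/axis1/axis2 parameters. Assuming they forward to the NumPy functions of the same name, this sketch on raw data shows how offset shifts the diagonal and that trace is the sum of what diagonal returns.

```python
import numpy as np

m = np.arange(9).reshape(3, 3)      # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
np.diagonal(m).tolist()              # [0, 4, 8]
np.trace(m).item()                   # 12
np.diagonal(m, offset=1).tolist()    # [1, 5]  (one step above the main diagonal)
np.trace(m, offset=1).item()         # 6
```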

unsqueeze(dim)

Insert a new axis of length one at position dim; equivalent to Array.expand_dims(dim).

See brainpy.math.unsqueeze().

Return type:

Array | ndarray | numpy.bool_ | number | bool | int | float | complex
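Since unsqueeze is documented as equivalent to expand_dims, its effect on shape can be sketched with `np.expand_dims` on raw data. A negative dim counts from the end of the new shape.

```python
import numpy as np

x = np.array([1, 2, 3])          # shape (3,)
np.expand_dims(x, 0).shape        # (1, 3)
np.expand_dims(x, 1).shape        # (3, 1)
np.expand_dims(x, -1).shape       # (3, 1)
```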

var(axis=None, dtype=None, ddof=0, keepdims=False)

Returns the variance of the array elements along the given axis.
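The ddof parameter sets the divisor to N - ddof, so the default ddof=0 gives the population variance. A NumPy sketch on the same data as the statistical example in the class overview, assuming var follows NumPy's definition:

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
x.var().item()            # 2.0  (sum of squared deviations 10, divided by N = 5)
x.var(ddof=1).item()      # 2.5  (divided by N - 1 = 4)
np.sqrt(x.var()).item()   # about 1.414, matching std() in the class examples
```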