# Source code for saiunit.custom_array

# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


import operator
from typing import Any, Optional, Union, Sequence

import jax.numpy as jnp
import jax.typing
import numpy as np

from saiunit import math

# Type alias re-exported from ``jax.typing``; used in the annotations of
# ``CustomArray`` below.
ArrayLike = jax.typing.ArrayLike

# Public API of this module (what ``from ... import *`` exposes).
__all__ = ["CustomArray"]


class CustomArray:
    """
    A custom array wrapper providing comprehensive array operations and
    cross-framework compatibility.

    ``CustomArray`` is a mix-in class that delegates every operation to its
    ``data`` attribute. Subclasses must provide a ``data`` property (or
    attribute) that returns the underlying array.

    The class exposes NumPy-style methods, PyTorch-style convenience methods,
    and full interoperability with JAX transformations (``jit``, ``grad``,
    ``vmap``).

    Attributes
    ----------
    data : array-like
        The underlying array storage. Subclasses are responsible for
        providing this attribute (e.g. via a property backed by some internal
        state).

    Methods
    -------
    NumPy-style methods
        ``all``, ``any``, ``argmax``, ``argmin``, ``argsort``, ``astype``,
        ``clip``, ``copy``, ``cumsum``, ``cumprod``, ``diagonal``, ``dot``,
        ``flatten``, ``max``, ``mean``, ``min``, ``nonzero``, ``prod``,
        ``ravel``, ``repeat``, ``reshape``, ``round``, ``squeeze``, ``std``,
        ``sum``, ``swapaxes``, ``take``, ``tolist``, ``trace``,
        ``transpose``, ``var``
    PyTorch-style methods
        ``unsqueeze``, ``expand``, ``expand_as``, ``clamp``, ``clone``,
        ``zero_``, ``bool``, ``int``, ``long``, ``half``, ``float``,
        ``double``
    Conversion methods
        ``to_numpy``, ``to_jax``, ``numpy``
    Trigonometric methods
        ``sin``, ``cos``, ``tan``, ``sinh``, ``cosh``, ``tanh``,
        ``arcsin``, ``arccos``, ``arctan`` (and in-place ``_`` variants)

    Examples
    --------
    ``CustomArray`` is designed to be used as a mix-in. A minimal standalone
    subclass needs only a ``data`` attribute and JAX pytree registration:

    .. code-block:: python

        import jax
        import numpy as np
        from saiunit import CustomArray

        @jax.tree_util.register_pytree_node_class
        class SimpleArray(CustomArray):
            def __init__(self, value):
                self.data = value

            def tree_flatten(self):
                return (self.data,), None

            @classmethod
            def tree_unflatten(cls, aux_data, flat_contents):
                return cls(*flat_contents)

    Basic properties and arithmetic:

    .. code-block:: python

        arr = SimpleArray(np.array([1.0, 2.0, 3.0]))
        arr.shape   # (3,)
        arr.ndim    # 1
        arr + 10    # array([11., 12., 13.])
        arr ** 2    # array([1., 4., 9.])

    Statistical operations:

    .. code-block:: python

        arr = SimpleArray(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
        arr.mean()  # 3.0
        arr.std()   # ~1.414
        arr.sum()   # 15.0

    Array manipulation:

    .. code-block:: python

        matrix = SimpleArray(np.array([[1, 2, 3], [4, 5, 6]]))
        matrix.T            # transposed (3, 2) array
        matrix.reshape(6)   # array([1, 2, 3, 4, 5, 6])
        matrix.flatten()    # array([1, 2, 3, 4, 5, 6])

    JAX compatibility (``jit``, ``grad``):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        arr = SimpleArray(jnp.array([1.0, 2.0, 3.0]))

        @jax.jit
        def square(x):
            return x * x

        square(arr)  # Array([1., 4., 9.], ...)

    Notes
    -----
    - This class uses duck typing and delegates operations to the underlying
      ``data`` attribute.
    - In-place operations (``+=``, ``-=``, etc.) rebind ``data`` to the
      result and return ``self``.
    - Most methods return the raw underlying array type, **not** a new
      ``CustomArray`` instance.
    - Thread safety depends on the underlying array implementation.
    - JAX transformations (``jit``, ``grad``, ``vmap``) work seamlessly when
      the subclass is registered as a JAX pytree.

    See Also
    --------
    numpy.ndarray : NumPy's N-dimensional array.
    jax.Array : JAX's array implementation.

    References
    ----------
    - NumPy documentation: https://numpy.org/doc/
    - JAX documentation: https://jax.readthedocs.io/
    """

    data: Any

    def __hash__(self):
        """Delegate to ``hash(self.data)``; raises TypeError if ``data`` is unhashable."""
        return hash(self.data)

    # ---------- #
    # properties #
    # ---------- #

    @property
    def dtype(self):
        """Data type of ``data``, as reported by ``math.get_dtype``."""
        return math.get_dtype(self.data)

    @property
    def shape(self):
        """Shape of ``data``."""
        return self.data.shape

    @property
    def ndim(self):
        """Number of dimensions of ``data``."""
        return self.data.ndim

    @property
    def imag(self):
        """Imaginary part of ``data``."""
        return self.data.imag

    @property
    def real(self):
        """Real part of ``data``."""
        return self.data.real

    @property
    def size(self):
        """Total number of elements in ``data``."""
        return self.data.size

    @property
    def T(self):
        """Transpose of ``data`` (``data.T``)."""
        return self.data.T

    @property
    def mT(self):
        """Transpose the last two dimensions (for batched matrix operations)."""
        return jnp.swapaxes(self.data, -1, -2)

    @property
    def nbytes(self):
        """Total bytes consumed by the array elements."""
        return self.data.nbytes

    @property
    def itemsize(self):
        """Length of one array element in bytes."""
        return self.data.itemsize

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.data})"

    def __str__(self) -> str:
        return str(self.data)

    def __format__(self, format_spec: str) -> str:
        return format(self.data, format_spec)

    def __iter__(self):
        """Iterate over the first axis of ``data``.

        Implemented explicitly to work around ``DeviceArray.__iter__``; see
        JAX issues:

        - https://github.com/google/jax/issues/7713
        - https://github.com/google/jax/pull/3821

        Raises
        ------
        TypeError
            When iteration starts on 0-dimensional ``data`` (mirrors NumPy's
            ``iteration over a 0-d array``).
        """
        if len(self.data.shape) == 0:
            # shape[0] would otherwise raise a confusing IndexError.
            raise TypeError('iteration over a 0-d array')
        for i in range(self.data.shape[0]):
            yield self.data[i]

    def __getitem__(self, index):
        """Index into ``data``; ``self[:]`` returns ``data`` itself (no copy)."""
        if isinstance(index, slice) and (index == slice(None)):
            return self.data
        return self.data[index]

    def __setitem__(self, index, data: ArrayLike):
        """Set ``self.data[index] = data`` functionally.

        Both operands are converted with ``math.asarray`` (the value only when
        it is a NumPy array), and ``self.data`` is rebound to the result of
        ``.at[index].set(data)``.
        """
        if isinstance(data, np.ndarray):
            data = math.asarray(data)
        self_data = math.asarray(self.data)
        self.data = self_data.at[index].set(data)

    # ---------- #
    # operations #
    # ---------- #

    def __len__(self) -> int:
        return len(self.data)

    # Unary operators delegate straight to ``data``.
    def __neg__(self):
        return self.data.__neg__()

    def __pos__(self):
        return self.data.__pos__()

    def __abs__(self):
        return self.data.__abs__()

    def __invert__(self):
        return self.data.__invert__()

    # Comparisons are element-wise and return the raw array result.
    def __eq__(self, oc):
        return self.data == oc

    def __ne__(self, oc):
        return self.data != oc

    def __lt__(self, oc):
        return self.data < oc

    def __le__(self, oc):
        return self.data <= oc

    def __gt__(self, oc):
        return self.data > oc

    def __ge__(self, oc):
        return self.data >= oc

    # Binary operators: forward ops put ``data`` on the left; reflected ops
    # (``__r*__``) put ``oc`` on the left so operand order is preserved for
    # non-commutative operators; in-place ops rebind ``self.data`` and
    # return ``self``.
    def __add__(self, oc):
        return self.data + oc

    def __radd__(self, oc):
        return oc + self.data

    def __iadd__(self, oc):
        # a += b
        self.data = self.data + oc
        return self

    def __sub__(self, oc):
        return self.data - oc

    def __rsub__(self, oc):
        return oc - self.data

    def __isub__(self, oc):
        # a -= b
        self.data = self.data - oc
        return self

    def __mul__(self, oc):
        return self.data * oc

    def __rmul__(self, oc):
        return oc * self.data

    def __imul__(self, oc):
        # a *= b
        self.data = self.data * oc
        return self

    def __truediv__(self, oc):
        return self.data / oc

    def __rtruediv__(self, oc):
        return oc / self.data

    def __itruediv__(self, oc):
        # a /= b
        self.data = self.data / oc
        return self

    def __floordiv__(self, oc):
        return self.data // oc

    def __rfloordiv__(self, oc):
        return oc // self.data

    def __ifloordiv__(self, oc):
        # a //= b
        self.data = self.data // oc
        return self

    def __divmod__(self, oc):
        return self.data.__divmod__(oc)

    def __rdivmod__(self, oc):
        return self.data.__rdivmod__(oc)

    def __mod__(self, oc):
        return self.data % oc

    def __rmod__(self, oc):
        return oc % self.data

    def __imod__(self, oc):
        # a %= b
        self.data = self.data % oc
        return self

    def __pow__(self, oc):
        return self.data ** oc

    def __rpow__(self, oc):
        return oc ** self.data

    def __ipow__(self, oc):
        # a **= b
        self.data = self.data ** oc
        return self

    def __matmul__(self, oc):
        return self.data @ oc

    def __rmatmul__(self, oc):
        return oc @ self.data

    def __imatmul__(self, oc):
        # a @= b
        self.data = self.data @ oc
        return self

    def __and__(self, oc):
        return self.data & oc

    def __rand__(self, oc):
        return oc & self.data

    def __iand__(self, oc):
        # a &= b
        self.data = self.data & oc
        return self

    def __or__(self, oc):
        return self.data | oc

    def __ror__(self, oc):
        return oc | self.data

    def __ior__(self, oc):
        # a |= b
        self.data = self.data | oc
        return self

    def __xor__(self, oc):
        return self.data ^ oc

    def __rxor__(self, oc):
        return oc ^ self.data

    def __ixor__(self, oc):
        # a ^= b
        self.data = self.data ^ oc
        return self

    def __lshift__(self, oc):
        return self.data << oc

    def __rlshift__(self, oc):
        return oc << self.data

    def __ilshift__(self, oc):
        # a <<= b
        self.data = self.data << oc
        return self

    def __rshift__(self, oc):
        return self.data >> oc

    def __rrshift__(self, oc):
        return oc >> self.data

    def __irshift__(self, oc):
        # a >>= b
        self.data = self.data >> oc
        return self

    def __round__(self, ndigits=None):
        return self.data.__round__(ndigits)

    # ------------- #
    # NumPy methods #
    # ------------- #

    def all(self, axis=None, keepdims=False):
        """Return True if all elements (along ``axis``) evaluate to True."""
        return self.data.all(axis=axis, keepdims=keepdims)

    def any(self, axis=None, keepdims=False):
        """Return True if any element (along ``axis``) evaluates to True."""
        return self.data.any(axis=axis, keepdims=keepdims)

    def argmax(self, axis=None):
        """Return indices of the maximum values along the given axis."""
        return self.data.argmax(axis=axis)

    def argmin(self, axis=None):
        """Return indices of the minimum values along the given axis."""
        return self.data.argmin(axis=axis)

    def argpartition(self, kth, axis: int = -1, kind: str = 'introselect', order=None):
        """Return the indices that would partition this array around ``kth``.

        ``kind`` and ``order`` are forwarded only when they differ from their
        defaults, because ``jax.Array.argpartition`` accepts just ``kth`` and
        ``axis``.
        """
        if kind == 'introselect' and order is None:
            return self.data.argpartition(kth, axis=axis)
        return self.data.argpartition(kth=kth, axis=axis, kind=kind, order=order)

    def argsort(self, axis=-1, kind=None, order=None):
        """Return the indices that would sort this array."""
        return self.data.argsort(axis=axis, kind=kind, order=order)

    def astype(self, dtype):
        """Copy of the array, cast to a specified type.

        Parameters
        ----------
        dtype : str or dtype
            Typecode or data-type to which the array is cast. ``None`` returns
            ``data`` unchanged.
        """
        if dtype is None:
            return self.data
        return self.data.astype(dtype)

    def byteswap(self, inplace=False):
        """Swap the bytes of the array elements.

        Toggle between low-endian and big-endian data representation by
        returning a byteswapped array, optionally swapped in-place. Arrays of
        byte-strings are not swapped. The real and imaginary parts of a
        complex number are swapped individually.
        """
        return self.data.byteswap(inplace=inplace)

    def choose(self, choices, mode='raise'):
        """Use an index array to construct a new array from a set of choices."""
        return self.data.choose(choices=choices, mode=mode)

    def clip(self, min=None, max=None):
        """Return an array whose values are limited to ``[min, max]``.

        One of ``max`` or ``min`` must be given.
        """
        return self.data.clip(min=min, max=max)

    def compress(self, condition, axis=None):
        """Return selected slices of this array along the given axis."""
        return self.data.compress(condition=condition, axis=axis)

    def conj(self):
        """Complex-conjugate all elements."""
        return self.data.conj()

    def conjugate(self):
        """Return the complex conjugate, element-wise."""
        return self.data.conjugate()

    def copy(self):
        """Return a copy of the array."""
        return self.data.copy()

    def cumprod(self, axis=None, dtype=None):
        """Return the cumulative product of the elements along the given axis."""
        return self.data.cumprod(axis=axis, dtype=dtype)

    def cumsum(self, axis=None, dtype=None):
        """Return the cumulative sum of the elements along the given axis."""
        return self.data.cumsum(axis=axis, dtype=dtype)

    def diagonal(self, offset=0, axis1=0, axis2=1):
        """Return specified diagonals."""
        return self.data.diagonal(offset=offset, axis1=axis1, axis2=axis2)

    def dot(self, b):
        """Dot product of ``data`` and ``b``."""
        return self.data.dot(b)

    def fill(self, data: ArrayLike):
        """Rebind ``self.data`` to an array of ``data`` with the same shape.

        Computed as ``math.ones_like(self.data) * data``, so the resulting
        dtype is whatever that multiplication produces.
        """
        self.data = math.ones_like(self.data) * data

    def flatten(self):
        """Return a flattened copy of ``data``."""
        return self.data.flatten()

    def item(self, *args):
        """Copy an element of an array to a standard Python scalar and return it."""
        return self.data.item(*args)

    def max(self, axis=None, keepdims=False, *args, **kwargs):
        """Return the maximum along a given axis."""
        return self.data.max(axis=axis, keepdims=keepdims, *args, **kwargs)

    def mean(self, axis=None, dtype=None, keepdims=False, *args, **kwargs):
        """Return the average of the array elements along the given axis."""
        return self.data.mean(axis=axis, dtype=dtype, keepdims=keepdims, *args, **kwargs)

    def min(self, axis=None, keepdims=False, *args, **kwargs):
        """Return the minimum along a given axis."""
        return self.data.min(axis=axis, keepdims=keepdims, *args, **kwargs)

    def nonzero(self):
        """Return a tuple of index arrays for the non-zero elements."""
        return tuple(self.data.nonzero())

    def prod(self, axis=None, dtype=None, keepdims=False, initial=1, where=True):
        """Return the product of the array elements over the given axis."""
        return self.data.prod(axis=axis, dtype=dtype, keepdims=keepdims,
                              initial=initial, where=where)

    def ptp(self, axis=None, keepdims=False):
        """Peak-to-peak (maximum - minimum) value along a given axis.

        Computed as ``data.max(...) - data.min(...)`` because the
        ``ndarray.ptp`` method no longer exists in NumPy 2.0.
        """
        return (self.data.max(axis=axis, keepdims=keepdims)
                - self.data.min(axis=axis, keepdims=keepdims))

    def put(self, indices, datas):
        """Replace specified elements of the array with given values.

        Parameters
        ----------
        indices : array_like
            Target indices, interpreted as integers.
        datas : array_like
            Values to place in the array at target indices.
        """
        self.__setitem__(indices, datas)

    def ravel(self, order=None):
        """Return a flattened array.

        ``order=None`` means row-major (``'C'``); it is mapped explicitly
        because ``jax.Array.ravel`` rejects ``None``.
        """
        if order is None:
            order = 'C'
        return self.data.ravel(order=order)

    def repeat(self, repeats, axis=None):
        """Repeat elements of an array."""
        return self.data.repeat(repeats=repeats, axis=axis)

    def reshape(self, *shape, order='C'):
        """Return an array containing the same data with a new shape."""
        return self.data.reshape(*shape, order=order)

    def resize(self, new_shape):
        """Rebind ``self.data`` to ``data.reshape(new_shape)``.

        Unlike ``numpy.ndarray.resize`` this cannot change the total size,
        since it is implemented with ``reshape``.
        """
        self.data = self.data.reshape(new_shape)

    def round(self, decimals=0):
        """Return ``data`` with each element rounded to the given number of decimals."""
        return self.data.round(decimals=decimals)

    def searchsorted(self, v, side='left', sorter=None):
        """Find indices where elements of ``v`` should be inserted to keep order."""
        return self.data.searchsorted(v=v, side=side, sorter=sorter)

    def sort(self, axis=-1, stable=True, order=None):
        """Sort the array in-place (``self.data`` is rebound to the sorted array).

        Parameters
        ----------
        axis : int, optional
            Axis along which to sort. Default is -1, which means sort along the
            last axis.
        stable : bool, optional
            Whether to use a stable sorting algorithm. The default is True.
        order : str or list of str, optional
            When ``a`` is an array with fields defined, this argument specifies
            which fields to compare first, second, etc.
        """
        if isinstance(self.data, np.ndarray):
            # ndarray.sort() sorts in place and returns None (and only accepts
            # ``stable=`` on NumPy >= 2), so use the functional np.sort.
            self.data = np.sort(self.data, axis=axis,
                                kind='stable' if stable else None, order=order)
        else:
            self.data = self.data.sort(axis=axis, stable=stable, order=order)

    def squeeze(self, axis=None):
        """Remove axes of length one from ``data``."""
        return self.data.squeeze(axis=axis)

    def std(self, axis=None, dtype=None, ddof=0, keepdims=False):
        """Return the standard deviation along the given axis."""
        return self.data.std(axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims)

    def sum(self, axis=None, dtype=None, keepdims=False, initial=0, where=True):
        """Return the sum of the array elements over the given axis."""
        return self.data.sum(axis=axis, dtype=dtype, keepdims=keepdims,
                             initial=initial, where=where)

    def swapaxes(self, axis1, axis2):
        """Return the array with ``axis1`` and ``axis2`` interchanged."""
        return self.data.swapaxes(axis1, axis2)

    def split(self, indices_or_sections, axis=0):
        """Split ``data`` into a list of sub-arrays via ``math.split``."""
        return list(math.split(self.data, indices_or_sections, axis=axis))

    def take(self, indices, axis=None, mode=None):
        """Return an array formed from the elements of ``data`` at the given indices."""
        return self.data.take(indices=indices, axis=axis, mode=mode)

    def tolist(self):
        """Return ``data`` as a (possibly nested) Python list."""
        return self.data.tolist()

    def trace(self, offset=0, axis1=0, axis2=1, dtype=None):
        """Return the sum along diagonals of the array."""
        return self.data.trace(offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)

    def transpose(self, *axes):
        """Return ``data`` with axes permuted (reversed when ``axes`` is empty)."""
        return self.data.transpose(*axes)

    def tile(self, reps):
        """Construct an array by repeating ``data`` ``reps`` times (``math.tile``)."""
        return math.tile(self.data, reps)

    def var(self, axis=None, dtype=None, ddof=0, keepdims=False):
        """Return the variance of the array elements along the given axis."""
        return self.data.var(axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims)

    def view(self, *args, dtype=None):
        """Return ``data`` with a new shape (PyTorch style) or a new dtype (NumPy style).

        Parameters
        ----------
        *args
            Either integers, or a single tuple/list, giving the new shape
            (implemented with ``reshape``), or a single dtype.
        dtype : optional
            Target dtype. This is mutually exclusive with a shape.

        Raises
        ------
        ValueError
            If neither a dtype nor a shape is given, or if both are given.
        """
        if len(args) == 0:
            if dtype is None:
                raise ValueError('Provide dtype or shape.')
            return self.data.view(dtype)
        first = args[0]
        is_shape = isinstance(first, int) or (
            len(args) == 1 and isinstance(first, (tuple, list))
        )
        if dtype is not None:
            # Either a shape plus a dtype, or a dtype given twice.
            raise ValueError('Provide one of dtype or shape. Not both.')
        if is_shape:
            return self.data.reshape(*args)
        # The first positional argument is a dtype.
        return self.data.view(first)

    # ------------- #
    # NumPy support #
    # ------------- #

    def numpy(self, dtype=None):
        """Convert to numpy.ndarray (alias of :meth:`to_numpy`)."""
        return self.to_numpy(dtype=dtype)

    def to_numpy(self, dtype=None):
        """Convert to numpy.ndarray."""
        return np.asarray(self.data, dtype=dtype)

    def to_jax(self, dtype=None):
        """Convert to a JAX array, optionally casting to ``dtype``.

        NumPy storage is converted with ``math.asarray``. Any other storage is
        returned unchanged when ``dtype`` is None.
        """
        if dtype is None:
            if isinstance(self.data, (np.ndarray, np.generic)):
                return math.asarray(self.data)
            return self.data
        return math.asarray(self.data, dtype=dtype)

    def __array__(self, dtype=None, copy=None):
        """Support ``numpy.array()`` and ``numpy.asarray()``.

        ``copy`` follows the NumPy 2 ``__array__`` protocol keyword. ``True``
        forces a fresh copy. ``None``/``False`` fall back to
        ``np.asarray(data)``, which may or may not copy.
        """
        if copy:
            return np.array(self.data, dtype=dtype, copy=True)
        return np.asarray(self.data, dtype=dtype)

    def __jax_array__(self):
        return self.data

    def __bool__(self) -> bool:
        return bool(self.data)

    def __float__(self):
        return self.data.__float__()

    def __int__(self):
        return self.data.__int__()

    def __complex__(self):
        return self.data.__complex__()

    def __hex__(self):
        assert self.ndim == 0, 'hex only works on scalar datas'
        return hex(self.data)  # type: ignore

    def __oct__(self):
        assert self.ndim == 0, 'oct only works on scalar datas'
        return oct(self.data)  # type: ignore

    def __index__(self):
        return operator.index(self.data)

    # --------------------- #
    # PyTorch compatibility #
    # --------------------- #

    def unsqueeze(self, dim: int) -> ArrayLike:
        """Insert a length-one axis at position ``dim`` (same as :meth:`expand_dims`)."""
        return math.expand_dims(self.data, dim)

    def expand_dims(self, axis: Union[int, Sequence[int]]) -> ArrayLike:
        """Insert length-one axes at the given position(s)."""
        return math.expand_dims(self.data, axis)

    def expand_as(self, array: ArrayLike) -> ArrayLike:
        """Broadcast ``data`` to the shape of ``array``."""
        return math.broadcast_to(self.data, jnp.asarray(array).shape)

    def pow(self, index: int):
        """Return ``data ** index``."""
        return self.data ** index

    def addr(
        self,
        vec1: ArrayLike,
        vec2: ArrayLike,
        *,
        beta: float = 1.0,
        alpha: float = 1.0,
    ) -> Optional[ArrayLike]:
        """Return ``alpha * outer(vec1, vec2) + beta * data`` (PyTorch ``addr``)."""
        return alpha * math.outer(vec1, vec2) + beta * self.data

    def outer(self, other: ArrayLike) -> ArrayLike:
        """Outer product of ``data`` and ``other``.

        ``other`` may be a ``CustomArray`` (its ``data`` is used) or any plain
        array-like. Plain inputs are passed through unchanged rather than
        reading ``.data``, which for a NumPy array is a raw memory buffer.
        """
        if isinstance(other, CustomArray):
            other = other.data
        return math.outer(self.data, other)

    def abs(self) -> Optional[ArrayLike]:
        """Element-wise absolute value."""
        return math.abs(self.data)

    def absolute(self) -> Optional[ArrayLike]:
        """Alias of :meth:`abs`."""
        return self.abs()

    def mul(self, data: ArrayLike):
        """Return ``self.data * data``."""
        return self.data * data

    def multiply(self, data: ArrayLike):
        """Return ``self.data * data`` (PyTorch ``multiply``)."""
        return self.data * data

    # Trigonometric functions. The plain variants return a new array; the
    # ``_`` variants rebind ``self.data`` and return ``self``.
    def sin(self) -> Optional[ArrayLike]:
        return math.sin(self.data)

    def sin_(self):
        self.data = math.sin(self.data)
        return self

    def cos_(self):
        self.data = math.cos(self.data)
        return self

    def cos(self) -> Optional[ArrayLike]:
        return math.cos(self.data)

    def tan_(self):
        self.data = math.tan(self.data)
        return self

    def tan(self) -> Optional[ArrayLike]:
        return math.tan(self.data)

    def sinh_(self):
        self.data = math.sinh(self.data)
        return self

    def sinh(self) -> Optional[ArrayLike]:
        return math.sinh(self.data)

    def cosh_(self):
        self.data = math.cosh(self.data)
        return self

    def cosh(self) -> Optional[ArrayLike]:
        return math.cosh(self.data)

    def tanh_(self):
        self.data = math.tanh(self.data)
        return self

    def tanh(self) -> Optional[ArrayLike]:
        return math.tanh(self.data)

    def arcsin_(self):
        self.data = math.arcsin(self.data)
        return self

    def arcsin(self) -> Optional[ArrayLike]:
        return math.arcsin(self.data)

    def arccos_(self):
        self.data = math.arccos(self.data)
        return self

    def arccos(self) -> Optional[ArrayLike]:
        return math.arccos(self.data)

    def arctan_(self):
        self.data = math.arctan(self.data)
        return self

    def arctan(self) -> Optional[ArrayLike]:
        return math.arctan(self.data)

    def clamp(
        self,
        min_data: Optional[ArrayLike] = None,
        max_data: Optional[ArrayLike] = None,
    ) -> Optional[ArrayLike]:
        """Return ``data`` clipped to ``[min_data, max_data]``.

        A ``None`` bound means no lower/upper bound.
        """
        return math.clip(self.data, min_data, max_data)

    def clamp_(
        self,
        min_data: Optional[ArrayLike] = None,
        max_data: Optional[ArrayLike] = None
    ):
        """Clip ``data`` to ``[min_data, max_data]`` in place and return ``self``.

        A ``None`` bound means no lower/upper bound.
        """
        self.data = math.clip(self.data, min_data, max_data)
        return self

    def clone(self) -> ArrayLike:
        """Return a copy of ``data`` (PyTorch-style alias of :meth:`copy`)."""
        return self.data.copy()

    def expand(self, *sizes) -> ArrayLike:
        """
        Expand an array to a new shape.

        Parameters
        ----------
        sizes : tuple or int
            The shape of the desired array. A single integer ``i`` is
            interpreted as ``(i,)``. ``-1`` keeps the existing size of that
            dimension.

        Returns
        -------
        expanded : Array
            ``data`` broadcast (via ``math.broadcast_to``) to the target shape.

        Raises
        ------
        ValueError
            If fewer sizes than dimensions are given, a new leading dimension
            is negative, or a non-singleton dimension would change size.
        """
        l_ori = len(self.shape)
        l_tar = len(sizes)
        base = l_tar - l_ori  # number of new leading dimensions
        sizes_list = list(sizes)
        if base < 0:
            raise ValueError(
                f'the number of sizes provided ({len(sizes)}) '
                f'must be greater or equal to the number of '
                f'dimensions in the tensor ({len(self.shape)})'
            )
        for i, v in enumerate(sizes[:base]):
            if v < 0:
                raise ValueError(
                    f'The expanded size of the tensor ({v}) '
                    f'isn\'t allowed in a leading, non-existing dimension {i + 1}'
                )
        for i, v in enumerate(self.shape):
            # -1 means "keep the existing size" for this trailing dimension.
            sizes_list[base + i] = v if sizes_list[base + i] == -1 else sizes_list[base + i]
            if v != 1 and sizes_list[base + i] != v:
                raise ValueError(
                    f'The expanded size of the tensor ({sizes_list[base + i]}) '
                    f'must match the existing size ({v}) at non-singleton '
                    f'dimension {i}. Target sizes: {sizes}. Tensor sizes: {self.shape}'
                )
        return math.broadcast_to(self.data, tuple(sizes_list))

    def zero_(self):
        """Rebind ``data`` to zeros of the same shape/dtype and return ``self``."""
        self.data = math.zeros_like(self.data)
        return self

    # dtype conversions (defined last so that the ``int``/``float``/``bool``
    # names in this class namespace do not shadow the builtins used in
    # annotations above).
    def bool(self):
        """Return ``data`` converted to ``jnp.bool_``."""
        return math.asarray(self.data, dtype=jnp.bool_)

    def int(self):
        """Return ``data`` converted to ``jnp.int32``."""
        return math.asarray(self.data, dtype=jnp.int32)

    def long(self):
        """Return ``data`` converted to ``jnp.int64``."""
        return math.asarray(self.data, dtype=jnp.int64)

    def half(self):
        """Return ``data`` converted to ``jnp.float16``."""
        return math.asarray(self.data, dtype=jnp.float16)

    def float(self):
        """Return ``data`` converted to ``jnp.float32``."""
        return math.asarray(self.data, dtype=jnp.float32)

    def double(self):
        """Return ``data`` converted to ``jnp.float64``."""
        return math.asarray(self.data, dtype=jnp.float64)

    # JAX pytree hooks: ``data`` is the only child and there is no aux data,
    # so subclasses must accept ``data`` as their sole positional
    # constructor argument.
    def tree_flatten(self):
        return (self.data,), None

    @classmethod
    def tree_unflatten(cls, aux_data, flat_contents):
        return cls(*flat_contents)