CustomArray
- class brainunit.CustomArray
A custom array wrapper providing comprehensive array operations and cross-framework compatibility.
CustomArray is a mix-in class that delegates every operation to its data attribute. Subclasses must provide a data property (or attribute) that returns the underlying array. The class exposes NumPy-style methods, PyTorch-style convenience methods, and full interoperability with JAX transformations (jit, grad, vmap).
- data
The underlying array storage. Subclasses are responsible for providing this attribute (e.g. via a property backed by some internal state).
- Type:
array-like
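A small self-contained sketch can make the delegation contract concrete. DelegatingMixin below is hypothetical and implements only three members. It is not CustomArray's source, only an assumption about the general pattern: each member reads self.data and forwards the work to NumPy. The subclass Wrapped backs data with a private attribute through a read-only property, as described above.

```python
import numpy as np


class DelegatingMixin:
    """Hypothetical, minimal stand-in for a data-delegating mix-in."""

    @property
    def shape(self):
        # Attribute access is forwarded to the wrapped array.
        return np.shape(self.data)

    def mean(self, axis=None):
        # Method calls are forwarded; the result is a plain NumPy value.
        return np.mean(self.data, axis=axis)

    def __add__(self, other):
        # Unwrap the other operand if it is also a wrapper.
        other = other.data if isinstance(other, DelegatingMixin) else other
        return self.data + other


class Wrapped(DelegatingMixin):
    def __init__(self, value):
        self._value = np.asarray(value)

    @property
    def data(self):
        # The mix-in only ever touches this property.
        return self._value


w = Wrapped([[1.0, 2.0], [3.0, 4.0]])
shape = w.shape      # (2, 2)
avg = w.mean()       # 2.5
doubled = w + w      # plain ndarray [[2., 4.], [6., 8.]]
```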
- NumPy-style methods
all, any, argmax, argmin, argsort, astype, clip, copy, cumsum, cumprod, diagonal, dot, flatten, max, mean, min, nonzero, prod, ravel, repeat, reshape, round, squeeze, std, sum, swapaxes, take, tolist, trace, transpose, var
- PyTorch-style methods
unsqueeze, expand, expand_as, clamp, clone, zero_, bool, int, long, half, float, double
- Conversion methods
to_numpy, to_jax, numpy
- Trigonometric methods
sin, cos, tan, sinh, cosh, tanh, arcsin, arccos, arctan (and their in-place variants with a trailing underscore, such as sin_)
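For readers coming from NumPy, the sketch below pairs several of the PyTorch-style names with plain NumPy operations. The pairing is an assumption based on the usual PyTorch meaning of these names (unsqueeze inserts an axis, expand broadcasts without copying, clamp bounds values, clone copies, zero_ fills with zeros in place, long and half cast to int64 and float16). Each method's own entry remains the authority for this class.

```python
import numpy as np

x = np.array([[1.0, -2.0], [3.0, 4.0]])

unsqueezed = np.expand_dims(x, 0)          # like x.unsqueeze(0): shape (1, 2, 2)
expanded = np.broadcast_to(x[0], (3, 2))   # like x[0].expand(3, 2): readonly view
clamped = np.clip(x, 0.0, 3.0)             # like x.clamp(0.0, 3.0)
cloned = x.copy()                          # like x.clone()
as_long = x.astype(np.int64)               # like x.long()
as_half = x.astype(np.float16)             # like x.half()

zeroed = x.copy()
zeroed[...] = 0                            # like x.zero_(): fill in place
```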
Examples
CustomArray is designed to be used as a mix-in. A minimal standalone subclass needs only a data attribute and JAX pytree registration:

```python
import jax
import numpy as np
from saiunit import CustomArray


@jax.tree_util.register_pytree_node_class
class SimpleArray(CustomArray):
    def __init__(self, value):
        self.data = value

    def tree_flatten(self):
        # The wrapped array is the only pytree leaf; no static aux data.
        return (self.data,), None

    @classmethod
    def tree_unflatten(cls, aux_data, flat_contents):
        return cls(*flat_contents)
```
Basic properties and arithmetic:
```python
arr = SimpleArray(np.array([1.0, 2.0, 3.0]))
arr.shape   # (3,)
arr.ndim    # 1
arr + 10    # array([11., 12., 13.])
arr ** 2    # array([1., 4., 9.])
```
Statistical operations:
```python
arr = SimpleArray(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
arr.mean()  # 3.0
arr.std()   # ~1.414
arr.sum()   # 15.0
```
Array manipulation:
```python
matrix = SimpleArray(np.array([[1, 2, 3], [4, 5, 6]]))
matrix.T            # transposed (3, 2) array
matrix.reshape(6)   # array([1, 2, 3, 4, 5, 6])
matrix.flatten()    # array([1, 2, 3, 4, 5, 6])
```
JAX compatibility, shown here with jit:

```python
import jax
import jax.numpy as jnp

arr = SimpleArray(jnp.array([1.0, 2.0, 3.0]))


@jax.jit
def square(x):
    return x * x


square(arr)  # Array([1., 4., 9.], ...)
```
Notes
- This class uses duck typing and delegates operations to the underlying data attribute.
- In-place operations (+=, -=, etc.) modify data directly and return self.
- Most methods return the raw underlying array type, not a new CustomArray instance.
- Thread safety depends on the underlying array implementation.
- JAX transformations (jit, grad, vmap) work seamlessly when the subclass is registered as a JAX pytree.
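The contrast between out-of-place operators (which return the raw array) and in-place operators (which update data and return self) can be sketched with a hypothetical class. Box and its two operators are assumptions written to match the notes above, not CustomArray's actual code. Here the in-place operator stores a new array in data, which is one plausible approach since JAX arrays are immutable; a subclass that exposes data as a property would then also need a setter.

```python
import numpy as np


class Box:
    """Hypothetical wrapper mimicking the documented operator semantics."""

    def __init__(self, value):
        self.data = np.asarray(value)

    def __add__(self, other):
        # Out-of-place: return the raw array, not a new Box.
        return self.data + other

    def __iadd__(self, other):
        # In-place: replace the stored array, then return the same wrapper.
        self.data = self.data + other
        return self


b = Box([1.0, 2.0])
original = b
result = b + 1   # plain ndarray holding [2., 3.]
b += 10          # b is still the same Box object
```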
See also
- numpy.ndarray: NumPy’s N-dimensional array.
- jax.Array: JAX’s array implementation.
- argpartition(kth, axis=-1, kind='introselect', order=None)
Returns the indices that would partition this array.
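Because CustomArray delegates to data, arr.argpartition(k) is assumed to behave like NumPy's np.argpartition on the underlying array, so the sketch uses NumPy directly. After partitioning at kth=2, the element at position 2 is the third-smallest value, everything before it is no larger, and everything after it is no smaller; order within each side is unspecified.

```python
import numpy as np

a = np.array([7, 1, 5, 3, 9, 2])
idx = np.argpartition(a, 2)

pivot = a[idx[2]]    # third-smallest value: 3
left = a[idx[:2]]    # the two smallest values, in unspecified order
right = a[idx[3:]]   # the remaining values, in unspecified order
```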
- byteswap(inplace=False)
Swap the bytes of the array elements.
Toggle between little-endian and big-endian data representation by returning a byteswapped array, optionally swapped in place. Arrays of byte strings are not swapped. The real and imaginary parts of a complex number are swapped individually.
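Swapping the stored bytes changes the values read back. The sketch uses np.ndarray.byteswap directly, which the delegated method is assumed to mirror: the 16-bit value 1 (0x0001) becomes 0x0100, that is 256.

```python
import numpy as np

a = np.array([1, 256, 0x1234], dtype=np.int16)

swapped = a.byteswap()      # new array: [256, 1, 0x3412]
a.byteswap(inplace=True)    # swaps the bytes of a itself
```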
- choose(choices, mode='raise')
Use an index array to construct a new array from a set of choices.
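Here the array itself holds the indices. In the NumPy sketch below (assuming the delegated method follows np.ndarray.choose), each entry selects which row of choices supplies the value at that position; mode="clip" clamps out-of-range indices instead of raising.

```python
import numpy as np

choices = np.array([
    [0, 1, 2, 3],
    [10, 11, 12, 13],
    [20, 21, 22, 23],
])
index = np.array([2, 0, 1, 0])

# Element i of the result is choices[index[i], i].
picked = index.choose(choices)

# With mode="clip", 5 becomes 2 and -1 becomes 0.
clipped = np.array([5, -1, 2, 2]).choose(choices, mode="clip")
```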
- clamp(min_data=None, max_data=None)
Return the data clamped to the interval [min_data, max_data]. If min_data is None, there is no lower bound; if max_data is None, there is no upper bound.
- clamp_(min_data=None, max_data=None)
In-place variant of clamp: clamp data to [min_data, max_data], where None again means no lower or upper bound.
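The None-means-unbounded rule can be sketched in NumPy. clamp_sketch is a hypothetical helper, not part of the library; it applies each bound only when given, which also covers the case where both bounds are None. The last line shows the kind of in-place write that a trailing-underscore variant implies.

```python
import numpy as np


def clamp_sketch(data, min_data=None, max_data=None):
    """Hypothetical helper: clamp with optional bounds."""
    out = np.array(data)                  # copy so the input is never aliased
    if min_data is not None:
        out = np.maximum(out, min_data)   # apply the lower bound
    if max_data is not None:
        out = np.minimum(out, max_data)   # apply the upper bound
    return out


x = np.array([-3.0, 0.5, 4.0])
both = clamp_sketch(x, -1.0, 1.0)          # [-1.0, 0.5, 1.0]
upper_only = clamp_sketch(x, max_data=1.0) # [-3.0, 0.5, 1.0]
unbounded = clamp_sketch(x)                # values unchanged

# In-place variant: write the clamped values back into x's own buffer.
np.clip(x, -1.0, 1.0, out=x)
```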
- clip(min=None, max=None)
Return an array whose values are limited to [min, max]. At least one of min or max must be given.
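A NumPy sketch of the same operation, assuming the delegated method mirrors np.ndarray.clip; passing None for one bound leaves that side unbounded.

```python
import numpy as np

a = np.array([1, 5, 10])

both = np.clip(a, 2, 8)       # [2, 5, 8]
upper = np.clip(a, None, 8)   # [1, 5, 8]
lower = a.clip(min=4)         # [4, 5, 10]
```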
- cumprod(axis=None, dtype=None)
Return the cumulative product of the elements along the given axis.
- cumsum(axis=None, dtype=None)
Return the cumulative sum of the elements along the given axis.
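With axis=None (the default) the array is flattened first; with an explicit axis the running sum or product is taken along that axis only. The NumPy sketch assumes the delegated methods match np.ndarray.cumsum and np.ndarray.cumprod.

```python
import numpy as np

m = np.array([[1, 2, 3],
              [4, 5, 6]])

flat_sum = m.cumsum()          # flattened: [1, 3, 6, 10, 15, 21]
col_sum = m.cumsum(axis=0)     # down each column: [[1, 2, 3], [5, 7, 9]]
row_prod = m.cumprod(axis=1)   # along each row: [[1, 2, 6], [4, 20, 120]]
```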
- property dtype
Data type of the underlying array.
- expand(*sizes)
Expand an array to a new shape.
- Parameters:
sizes (tuple or int) – The shape of the desired array. A single integer i is interpreted as (i,).
- Returns:
expanded – A readonly view on the original array with the given shape. It is typically not contiguous. Furthermore, more than one element of an expanded array may refer to a single memory location.
- Return type:
Array | ndarray | bool | number | int | float | complex
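The readonly-view description matches NumPy's np.broadcast_to, so the sketch uses it as a stand-in; treating expand as equivalent is an assumption. Size-1 axes are stretched without copying, which is why the result is readonly and several of its elements share one memory location.

```python
import numpy as np

col = np.array([[1], [2]])             # shape (2, 1)
wide = np.broadcast_to(col, (2, 3))    # shape (2, 3), no data copied

# Every entry in a row refers to the same stored element.
same_memory = np.shares_memory(wide, col)
writable = wide.flags.writeable        # False: writes would hit shared cells
```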
- property itemsize
Length of one array element in bytes.
- property mT
Transpose the last two dimensions (for batched matrix operations).
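Unlike T, which reverses every axis, mT swaps only the last two, so leading batch axes stay in place. NumPy's np.swapaxes gives that result and is used here; that CustomArray.mT matches it is an assumption based on the description.

```python
import numpy as np

batch = np.arange(24).reshape(2, 3, 4)   # two 3x4 matrices
batch_t = np.swapaxes(batch, -1, -2)     # two 4x3 matrices
full_t = batch.T                         # reverses every axis: shape (4, 3, 2)
```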
- mean(axis=None, dtype=None, keepdims=False, *args, **kwargs)
Return the average of the array elements along the given axis.
- property nbytes
Total bytes consumed by the array elements.
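The two byte-size properties are linked by nbytes = size * itemsize. A NumPy illustration, assuming both properties are read from the underlying array:

```python
import numpy as np

a = np.zeros((3, 4), dtype=np.float32)

item = a.itemsize    # 4 bytes per float32 element
total = a.nbytes     # 12 elements * 4 bytes = 48
```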
- prod(axis=None, dtype=None, keepdims=False, initial=1, where=True)
Return the product of the array elements over the given axis.
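A NumPy sketch of the keyword arguments, assuming they are forwarded unchanged to np.ndarray.prod: initial is multiplied into the result, keepdims keeps the reduced axis with length 1, and positions where the where mask is False are skipped.

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])

total = a.prod()                        # 24
per_col = a.prod(axis=0)                # [3, 8]
kept = a.prod(axis=1, keepdims=True)    # [[2], [12]], shape (2, 1)
scaled = a.prod(initial=10)             # 10 * 24 = 240
masked = a.prod(where=np.array([[True, False],
                                [True, True]]))   # 1 * 3 * 4 = 12
```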
- put(indices, datas)
Replace the specified elements of the array with the given values.
- Parameters:
indices (array_like) – Target indices, interpreted as integers.
datas (array_like) – Values to place in the array at target indices.
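Indices refer to positions in the flattened (C-order) array, and the write happens in place. The NumPy sketch assumes the delegated put mirrors np.ndarray.put, which also repeats the values when there are fewer values than indices.

```python
import numpy as np

a = np.zeros(6, dtype=int)
a.put([0, 3], [7, 9])      # a -> [7, 0, 0, 9, 0, 0]
a.put([1, 2, 5], [1])      # the single value is repeated for every index

m = np.zeros((2, 3), dtype=int)
m.put([4], [8])            # flat index 4 is row 1, column 1
```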
- property shape
Shape of the underlying array.
- sort(axis=-1, stable=True, order=None)
Sort the array in place.
- Parameters:
axis (int, optional) – Axis along which to sort. The default, -1, sorts along the last axis.
stable (bool, optional) – Whether to use a stable sorting algorithm. The default is True.
order (str or list of str, optional) – When the array has fields defined (a structured array), this argument specifies which fields to compare first, second, etc.
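A NumPy illustration of in-place sorting, including the order argument on a structured array (one with named fields). NumPy's ndarray.sort spells the stable option as kind="stable"; that stable=True maps onto it is an assumption about how the wrapper forwards the flag.

```python
import numpy as np

m = np.array([[3, 1, 2],
              [9, 7, 8]])
m.sort(axis=-1)            # in place: each row sorted

people = np.array(
    [("bob", 30), ("amy", 25), ("cat", 30)],
    dtype=[("name", "U3"), ("age", "i4")],
)
# Compare age first, then name to break ties.
people.sort(order=["age", "name"], kind="stable")
```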
- sum(axis=None, dtype=None, keepdims=False, initial=0, where=True)
Return the sum of the array elements over the given axis.
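The sum keywords work like those of prod, with initial added instead of multiplied. A NumPy sketch, assuming the arguments are forwarded to np.ndarray.sum:

```python
import numpy as np

a = np.array([[1, 2, 3],
              [4, 5, 6]])

total = a.sum()                       # 21
per_row = a.sum(axis=1)               # [6, 15]
kept = a.sum(axis=0, keepdims=True)   # [[5, 7, 9]], shape (1, 3)
offset = a.sum(initial=100)           # 121
over_3 = a.sum(where=a > 3)           # 4 + 5 + 6 = 15
```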
- take(indices, axis=None, mode=None)
Return an array formed from the elements of this array at the given indices.
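With axis=None the indices address the flattened array; with an axis they select slices along it. The NumPy sketch assumes NumPy-backed data is forwarded to np.ndarray.take; mode="wrap" reduces out-of-range indices modulo the length.

```python
import numpy as np

m = np.array([[10, 20, 30],
              [40, 50, 60]])

flat = m.take([0, 4])                   # flattened: [10, 50]
cols = m.take([2, 0], axis=1)           # columns 2 and 0: [[30, 10], [60, 40]]
wrapped = m.take([5, 6], mode="wrap")   # 6 wraps to 0: [60, 10]
```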