# Source code for saiunit._sparse_base

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

import math
import numbers
from typing import Sequence, Union

import jax
import numpy as np

# Public API of this module: only the sparse-matrix base class is exported
# by ``from <module> import *``.
__all__ = [
    "SparseMatrix"
]


class SparseMatrix:
    """
    Base class for sparse matrices in ``saiunit``.

    This base class defines the interface that all sparse matrix
    implementations in the ``saiunit`` package should follow. Concrete
    subclasses must implement the abstract methods defined here.

    Attributes
    ----------
    data : jax.Array
        The non-zero values in the sparse matrix.

    Notes
    -----
    This class provides ``NotImplementedError`` for most operations,
    requiring concrete subclasses to implement them according to their
    specific sparse format.

    Examples
    --------
    ``SparseMatrix`` is not instantiated directly. Use a concrete subclass
    such as :class:`~saiunit.sparse.CSR`, :class:`~saiunit.sparse.CSC`, or
    :class:`~saiunit.sparse.COO`.

    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit as u
        >>> import saiunit.sparse as susparse
        >>> dense = jnp.array([[1., 0.], [0., 2.]])
        >>> csr = susparse.CSR.fromdense(dense)
        >>> isinstance(csr, susparse.SparseMatrix)
        True
    """

    # Declared for type checkers only; concrete subclasses provide the values.
    data: jax.Array
    shape: tuple[int, ...]
    nse: property
    dtype: property

    # Instances are explicitly unhashable.
    __hash__ = None

    def __init__(
        self,
        args: tuple[jax.Array, ...],
        *,
        shape: Sequence[int]
    ):
        """
        Record the shape of the matrix.

        Parameters
        ----------
        args : tuple of jax.Array
            Accepted for interface compatibility; this base initializer does
            not store it.
        shape : sequence of int
            The dense shape. Each entry is converted with ``int`` and the
            result is stored as a tuple in ``self.shape``.
        """
        self.shape = tuple(int(s) for s in shape)

    def __len__(self):
        """Return the size of the first dimension."""
        return self.shape[0]

    @property
    def size(self) -> int:
        """Product of all entries of ``shape``."""
        return math.prod(self.shape)

    @property
    def ndim(self) -> int:
        """Number of dimensions, i.e. ``len(shape)``."""
        return len(self.shape)

    def __repr__(self):
        name = self.__class__.__name__
        try:
            nse = self.nse
            dtype = self.dtype
            shape = list(self.shape)
        except Exception:
            # Best-effort: repr must not raise, even on a half-built instance
            # whose ``nse``/``dtype``/``shape`` are missing or failing.
            repr_ = f"{name}(<invalid>)"
        else:
            repr_ = f"{name}({dtype}{shape}, {nse=})"
        return repr_

    @property
    def T(self):
        """Shorthand for ``self.transpose()``."""
        return self.transpose()

    def block_until_ready(self):
        """
        Call ``block_until_ready()`` on every leaf returned by
        ``tree_flatten()[0]`` and return ``self``.
        """
        for arg in self.tree_flatten()[0]:
            arg.block_until_ready()
        return self

    def tree_flatten(self):
        """Abstract; always raises ``NotImplementedError`` in the base class."""
        raise NotImplementedError(f"{self.__class__}.tree_flatten")

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Abstract; always raises ``NotImplementedError`` in the base class."""
        raise NotImplementedError(f"{cls}.tree_unflatten")

    def transpose(self, axes=None):
        """Abstract; always raises ``NotImplementedError`` in the base class."""
        raise NotImplementedError(f"{self.__class__}.transpose")

    def todense(self):
        """Abstract; always raises ``NotImplementedError`` in the base class."""
        raise NotImplementedError(f"{self.__class__}.todense")

    def with_data(
        self,
        data: Union[jax.Array, np.ndarray, numbers.Number, 'Quantity']
    ):
        """
        Create a new sparse matrix with the same sparsity structure but
        different data.

        Parameters
        ----------
        data : jax.Array, numpy.ndarray, numbers.Number, or Quantity
            The new non-zero values. Must have the same shape, dtype, and
            unit as the current ``self.data``.

        Returns
        -------
        SparseMatrix
            A new sparse matrix of the same type with the provided data.

        Raises
        ------
        NotImplementedError
            If called on the abstract base class directly.

        Examples
        --------
        .. code-block:: python

            >>> import jax.numpy as jnp
            >>> import saiunit as u
            >>> import saiunit.sparse as susparse
            >>> dense = jnp.array([[1., 0.], [0., 2.]])
            >>> csr = susparse.CSR.fromdense(dense)
            >>> new_csr = csr.with_data(csr.data * 3)
            >>> new_csr.todense()
            Array([[3., 0.],
                   [0., 6.]], dtype=float32)
        """
        # The message names this method (it previously said ``assign_data``).
        raise NotImplementedError(f"{self.__class__}.with_data")

    def sum(self, axis: Union[int, Sequence[int], None] = None):
        """
        Sum of the elements of the sparse matrix.

        Parameters
        ----------
        axis : int, sequence of int, or None, optional
            Axis or axes along which the sum is computed. The default
            (``None``) computes the sum of the flattened array. Currently
            only ``None`` is supported.

        Returns
        -------
        jax.Array or Quantity
            The result of ``self.data.sum()``.

        Raises
        ------
        NotImplementedError
            If ``axis`` is not ``None``.
        """
        if axis is not None:
            # Report the concrete class instead of a hard-coded "CSR".
            raise NotImplementedError(
                f"{self.__class__.__name__}.sum with axis is not implemented."
            )
        return self.data.sum()

    def yw_to_w(
        self,
        y_dim_arr: Union[jax.Array, np.ndarray, 'Quantity'],
        w_dim_arr: Union[jax.Array, np.ndarray, 'Quantity']
    ) -> Union[jax.Array, 'Quantity']:
        """
        The protocol method to convert the product of the sparse matrix and
        a vector to the sparse matrix data.

        This protocol method is primarily used in
        `brainscale <https://github.com/chaobrain/brainscale>`_.

        Parameters
        ----------
        y_dim_arr : jax.Array, numpy.ndarray, or Quantity
            The first vector.
        w_dim_arr : jax.Array, numpy.ndarray, or Quantity
            The second vector.

        Returns
        -------
        jax.Array or Quantity
            The outer product of the two vectors.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        # The message names this method (it previously said ``yw_to_y``).
        raise NotImplementedError(f"{self.__class__}.yw_to_w is not implemented.")

    # ------------------------------------------------------------------
    # Operator protocol: every operator is abstract in the base class and
    # must be provided by the concrete sparse format.
    # ------------------------------------------------------------------

    def __abs__(self):
        raise NotImplementedError(f"{self.__class__}.__abs__ is not implemented.")

    def __neg__(self):
        raise NotImplementedError(f"{self.__class__}.__neg__ is not implemented.")

    def __pos__(self):
        raise NotImplementedError(f"{self.__class__}.__pos__ is not implemented.")

    def __matmul__(self, other):
        raise NotImplementedError(f"{self.__class__}.__matmul__ is not implemented.")

    def __rmatmul__(self, other):
        raise NotImplementedError(f"{self.__class__}.__rmatmul__ is not implemented.")

    def __mul__(self, other):
        raise NotImplementedError(f"{self.__class__}.__mul__ is not implemented.")

    def __rmul__(self, other):
        raise NotImplementedError(f"{self.__class__}.__rmul__ is not implemented.")

    def __add__(self, other):
        raise NotImplementedError(f"{self.__class__}.__add__ is not implemented.")

    def __radd__(self, other):
        raise NotImplementedError(f"{self.__class__}.__radd__ is not implemented.")

    def __sub__(self, other):
        raise NotImplementedError(f"{self.__class__}.__sub__ is not implemented.")

    def __rsub__(self, other):
        raise NotImplementedError(f"{self.__class__}.__rsub__ is not implemented.")

    # ``__div__``/``__rdiv__`` are not used by Python 3 operators; they are
    # kept so existing explicit callers keep working.
    def __div__(self, other):
        raise NotImplementedError(f"{self.__class__}.__div__ is not implemented.")

    def __rdiv__(self, other):
        raise NotImplementedError(f"{self.__class__}.__rdiv__ is not implemented.")

    def __truediv__(self, other):
        raise NotImplementedError(f"{self.__class__}.__truediv__ is not implemented.")

    def __rtruediv__(self, other):
        raise NotImplementedError(f"{self.__class__}.__rtruediv__ is not implemented.")

    def __floordiv__(self, other):
        raise NotImplementedError(f"{self.__class__}.__floordiv__ is not implemented.")

    def __rfloordiv__(self, other):
        raise NotImplementedError(f"{self.__class__}.__rfloordiv__ is not implemented.")

    def __mod__(self, other):
        raise NotImplementedError(f"{self.__class__}.__mod__ is not implemented.")

    def __rmod__(self, other):
        raise NotImplementedError(f"{self.__class__}.__rmod__ is not implemented.")

    def __getitem__(self, item):
        raise NotImplementedError(f"{self.__class__}.__getitem__ is not implemented.")