# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import annotations
import operator
from typing import Any, Tuple, Sequence, NamedTuple
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax import tree_util
from jax.experimental.sparse import JAXSparse, coo_todense_p, coo_fromdense_p, coo_matmat_p, coo_matvec_p
from saiunit._base_getters import (
get_mantissa,
get_unit,
maybe_decimal,
split_mantissa_unit,
)
from saiunit._base_quantity import Quantity
from saiunit._compatible_import import concrete_or_error
from saiunit._sparse_base import SparseMatrix
from saiunit.math._fun_array_creation import asarray
from saiunit.math._fun_keep_unit import promote_dtypes
# Public names exported by ``from <this module> import *``.
__all__ = ['COO', 'coo_todense', 'coo_fromdense']

# Placeholder alias for anything usable as a dtype.
Dtype = Any

# A shape is a tuple of ints of arbitrary length.
Shape = tuple[int, ...]
def _const_like(x: jax.Array, value: int) -> jax.Array:
return jnp.asarray(value, dtype=x.dtype)
class COOInfo(NamedTuple):
    """Static metadata describing a COO matrix.

    Carries the matrix shape together with two flags recording whether the
    row / column index buffers are known to be sorted.
    """
    # Shape of the represented matrix.
    shape: Shape
    # True when the row indices are known to be sorted.
    rows_sorted: bool = False
    # True when the column indices are known to be sorted.
    cols_sorted: bool = False
@tree_util.register_pytree_node_class
class COO(SparseMatrix):
    """
    Unit-aware Coordinate (COO) sparse matrix.

    Stores a 2-D sparse matrix in COO (coordinate) format with optional
    physical-unit support via :class:`~saiunit.Quantity`.

    Parameters
    ----------
    args : tuple of (data, row, col)
        ``data`` contains the non-zero values (``jax.Array`` or ``Quantity``),
        ``row`` contains the row indices, and ``col`` contains the column
        indices.
    shape : tuple of int
        The ``(nrows, ncols)`` shape of the matrix.
    rows_sorted : bool, optional
        Whether the row indices are sorted. Default is ``False``.
    cols_sorted : bool, optional
        Whether the column indices are sorted. Default is ``False``.

    Attributes
    ----------
    data : jax.Array or Quantity
        Non-zero values of shape ``(nse,)``.
    row : jax.Array
        Row indices of shape ``(nse,)``.
    col : jax.Array
        Column indices of shape ``(nse,)``.
    shape : tuple of int
        Shape of the matrix ``(nrows, ncols)``.
    nse : int
        Number of stored elements.
    dtype : dtype
        Data type of the stored values.

    See Also
    --------
    CSR : Unit-aware Compressed Sparse Row matrix.
    CSC : Unit-aware Compressed Sparse Column matrix.
    coo_fromdense : Create a COO matrix from a dense array.
    coo_todense : Convert a COO matrix to a dense array.

    Notes
    -----
    This class has minimal compatibility with JAX transforms such as
    ``grad`` and ``jit``, and offers limited functionality compared to
    :class:`jax.experimental.sparse.BCOO`. Additionally, there are known
    failures when ``nse`` is larger than the true number of non-zeros in the
    represented matrix.

    Element-wise binary operators (``*``, ``/``, ``+``, ``-``, ``%``) act on
    the stored entries only and always return a COO matrix with the same
    sparsity pattern as ``self``.

    Examples
    --------

    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit as u
        >>> import saiunit.sparse as susparse
        >>> dense = jnp.array([[1., 0., 2.], [0., 0., 3.]])
        >>> coo = susparse.COO.fromdense(dense)
        >>> coo.shape
        (2, 3)
        >>> coo.todense()
        Array([[1., 0., 2.],
               [0., 0., 3.]], dtype=float32)
    """
    data: jax.Array | Quantity
    row: jax.Array
    col: jax.Array
    shape: tuple[int, int]
    _rows_sorted: bool
    _cols_sorted: bool

    @property
    def nse(self):
        """Number of stored elements (size of ``data``)."""
        return self.data.size

    @property
    def dtype(self):
        """Data type of the stored values."""
        return self.data.dtype

    @property
    def _info(self) -> COOInfo:
        """Static metadata (shape and sortedness flags) as a ``COOInfo``."""
        return COOInfo(
            shape=self.shape,
            rows_sorted=self._rows_sorted,
            cols_sorted=self._cols_sorted,
        )

    @property
    def _bufs(self):
        """The ``(data, row, col)`` buffers as a tuple."""
        return self.data, self.row, self.col

    def __init__(
        self,
        args: Tuple[jax.Array | Quantity, jax.Array, jax.Array],
        *,
        shape: Shape,
        rows_sorted: bool = False,
        cols_sorted: bool = False
    ):
        # Each buffer is normalised through the unit-aware ``asarray``.
        self.data, self.row, self.col = map(asarray, args)
        self._rows_sorted = rows_sorted
        self._cols_sorted = cols_sorted
        # The base-class initializer receives the original ``args`` and ``shape``.
        super().__init__(args, shape=shape)

    @classmethod
    def fromdense(
        cls,
        mat: jax.Array,
        *,
        nse: int | None = None,
        index_dtype: jax.typing.DTypeLike = np.int32
    ) -> COO:
        """Build a COO matrix from a dense matrix; delegates to ``coo_fromdense``."""
        return coo_fromdense(mat, nse=nse, index_dtype=index_dtype)

    def _replace_data(self, data: jax.Array | Quantity) -> COO:
        """Return a COO with the same indices, shape and flags but new ``data``.

        Unlike :meth:`with_data` this performs no shape/dtype/unit validation,
        because operators may legitimately change the dtype or unit.
        """
        return COO(
            (data, self.row, self.col),
            shape=self.shape,
            rows_sorted=self._rows_sorted,
            cols_sorted=self._cols_sorted
        )

    def _sort_indices(self) -> COO:
        """Return a copy of the COO matrix with sorted indices.

        The matrix is sorted by row indices and column indices per row.
        If self._rows_sorted is True, this returns ``self`` without a copy.
        """
        # TODO(jakevdp): would be benefit from lowering this to cusparse sort_rows utility?
        if self._rows_sorted:
            return self
        # Sort on the raw mantissa; the unit is re-attached afterwards.
        data, unit = split_mantissa_unit(self.data)
        row, col, data = lax.sort((self.row, self.col, data), num_keys=2)
        return self.__class__(
            (
                maybe_decimal(Quantity(data, unit=unit)),
                row,
                col
            ),
            shape=self.shape,
            rows_sorted=True
        )

    @classmethod
    def _empty(
        cls,
        shape: Sequence[int],
        *,
        dtype: jax.typing.DTypeLike | None = None,
        index_dtype: jax.typing.DTypeLike = 'int32'
    ) -> COO:
        """Create an empty COO instance. Public method is sparse.empty()."""
        shape = tuple(shape)
        if len(shape) != 2:
            raise ValueError(f"COO must have ndim=2; got {shape=}")
        data = jnp.empty(0, dtype)
        row = col = jnp.empty(0, index_dtype)
        return cls(
            (data, row, col),
            shape=shape,
            rows_sorted=True,
            cols_sorted=True
        )

    @classmethod
    def _eye(
        cls,
        N: int,
        M: int,
        k: int,
        *,
        dtype: jax.typing.DTypeLike | None = None,
        index_dtype: jax.typing.DTypeLike = 'int32'
    ) -> COO:
        """Create an ``(N, M)`` COO matrix with ones on the ``k``-th diagonal.

        Returns an empty matrix when the requested diagonal has no entries.
        """
        if k > 0:
            diag_size = min(N, M - k)
        else:
            diag_size = min(N + k, M)
        if diag_size <= 0:
            # if k is out of range, return an empty matrix.
            return cls._empty((N, M), dtype=dtype, index_dtype=index_dtype)
        data = jnp.ones(diag_size, dtype=dtype)
        idx = jnp.arange(diag_size, dtype=index_dtype)
        zero = _const_like(idx, 0)
        k = _const_like(idx, k)
        # Negative k shifts the rows down; positive k shifts the columns right.
        row = lax.sub(idx, lax.cond(k >= 0, lambda: zero, lambda: k))
        col = lax.add(idx, lax.cond(k <= 0, lambda: zero, lambda: k))
        return cls(
            (data, row, col),
            shape=(N, M),
            rows_sorted=True,
            cols_sorted=True
        )

    def with_data(self, data: jax.Array | Quantity) -> COO:
        """
        Create a new COO matrix with the same sparsity structure but different data.

        Parameters
        ----------
        data : jax.Array or Quantity
            New non-zero values. Must have the same shape, dtype, and unit as
            the current ``self.data``.

        Returns
        -------
        COO
            A new COO matrix sharing the same ``row`` and ``col`` indices but
            holding the provided ``data``.

        Raises
        ------
        ValueError
            If the shape, dtype or unit of ``data`` differs from ``self.data``.

        Examples
        --------

        .. code-block:: python

            >>> import jax.numpy as jnp
            >>> import saiunit as u
            >>> import saiunit.sparse as susparse
            >>> dense = jnp.array([[1., 0.], [0., 2.]])
            >>> coo = susparse.COO.fromdense(dense)
            >>> new_coo = coo.with_data(coo.data * 5)
            >>> new_coo.todense()
            Array([[ 5.,  0.],
                   [ 0., 10.]], dtype=float32)
        """
        if data.shape != self.data.shape:
            raise ValueError(f"Shape mismatch: expected {self.data.shape}, got {data.shape}")
        if data.dtype != self.data.dtype:
            raise ValueError(f"Dtype mismatch: expected {self.data.dtype}, got {data.dtype}")
        if get_unit(data) != get_unit(self.data):
            raise ValueError(f"Unit mismatch: expected {get_unit(self.data)}, got {get_unit(data)}")
        return self._replace_data(data)

    def todense(self) -> jax.Array | Quantity:
        """
        Convert this COO matrix to a dense array.

        Returns
        -------
        jax.Array or Quantity
            Dense 2-D array equivalent to this sparse matrix.

        Examples
        --------

        .. code-block:: python

            >>> import jax.numpy as jnp
            >>> import saiunit as u
            >>> import saiunit.sparse as susparse
            >>> dense = jnp.array([[0., 3.], [4., 0.]])
            >>> coo = susparse.COO.fromdense(dense)
            >>> coo.todense()
            Array([[0., 3.],
                   [4., 0.]], dtype=float32)
        """
        return coo_todense(self)

    def transpose(self, axes: Tuple[int, ...] | None = None) -> COO:
        """Return the transposed matrix by swapping ``row`` and ``col``.

        Raises ``NotImplementedError`` when ``axes`` is given.
        """
        if axes is not None:
            raise NotImplementedError("axes argument to transpose()")
        return COO(
            (self.data, self.col, self.row),
            shape=self.shape[::-1],
            # Sortedness flags swap together with the index buffers.
            rows_sorted=self._cols_sorted,
            cols_sorted=self._rows_sorted
        )

    def tree_flatten(self) -> Tuple[
        Tuple[jax.Array | Quantity,], dict[str, Any]
    ]:
        """Flatten for pytree registration.

        Only ``data`` is returned as a child; ``row``, ``col`` and the static
        metadata travel in the auxiliary dictionary.
        """
        aux = self._info._asdict()
        aux['row'] = self.row
        aux['col'] = self.col
        return (self.data,), aux

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild a COO from ``tree_flatten`` output without calling ``__init__``."""
        obj = object.__new__(cls)
        obj.data, = children
        if aux_data.keys() != {'shape', 'rows_sorted', 'cols_sorted', 'row', 'col'}:
            raise ValueError(f"COO.tree_unflatten: invalid {aux_data=}")
        obj.shape = aux_data['shape']
        obj._rows_sorted = aux_data['rows_sorted']
        obj._cols_sorted = aux_data['cols_sorted']
        obj.row = aux_data['row']
        obj.col = aux_data['col']
        return obj

    def __abs__(self):
        return self._replace_data(self.data.__abs__())

    def __neg__(self):
        return self._replace_data(-self.data)

    def __pos__(self):
        return self._replace_data(self.data.__pos__())

    def _apply_binary(self, other: Any, op: Any, *, reflected: bool) -> COO:
        """Shared implementation of the element-wise binary operators.

        Parameters
        ----------
        other : COO, array-like or Quantity
            The second operand.
        op : callable
            Two-argument function such as ``operator.mul``.
        reflected : bool
            When ``True`` compute ``op(other, self.data)``; otherwise
            ``op(self.data, other)``.

        Returns
        -------
        COO
            Matrix with the sparsity pattern of ``self`` and the combined data.

        Raises
        ------
        NotImplementedError
            For sparse operands that do not share index buffers with ``self``
            and for dense operands that are neither size-1 nor of the same
            2-D shape as ``self``.
        """

        def combine(rhs):
            return op(rhs, self.data) if reflected else op(self.data, rhs)

        if isinstance(other, COO):
            # Only matrices sharing the very same index buffers can be combined
            # entry by entry; identity (not equality) is the intended test.
            if self.row is other.row and self.col is other.col:
                return self._replace_data(combine(other.data))
        if isinstance(other, (JAXSparse, SparseMatrix)):
            raise NotImplementedError(f"binary operation {op} between two sparse objects.")
        other = asarray(other)
        if other.size == 1:
            if other.ndim > 0:
                # Collapse e.g. shape (1, 1) to a scalar so that broadcasting
                # cannot add dimensions to the 1-D ``data`` buffer.
                other = other[(0,) * other.ndim]
            return self._replace_data(combine(other))
        if other.ndim == 2 and other.shape == self.shape:
            # Pick the dense entries located at the stored coordinates.
            return self._replace_data(combine(other[self.row, self.col]))
        raise NotImplementedError(
            f"binary operation {op} with object of shape {other.shape}"
        )

    def _binary_op(self, other, op):
        """Compute ``op(self, other)`` on the stored entries."""
        return self._apply_binary(other, op, reflected=False)

    def _binary_rop(self, other, op):
        """Compute ``op(other, self)`` on the stored entries."""
        return self._apply_binary(other, op, reflected=True)

    def __mul__(self, other: jax.Array | Quantity) -> COO:
        return self._binary_op(other, operator.mul)

    def __rmul__(self, other: jax.Array | Quantity) -> COO:
        return self._binary_rop(other, operator.mul)

    def __div__(self, other: jax.Array | Quantity) -> COO:
        return self._binary_op(other, operator.truediv)

    def __rdiv__(self, other: jax.Array | Quantity) -> COO:
        return self._binary_rop(other, operator.truediv)

    def __truediv__(self, other) -> COO:
        return self.__div__(other)

    def __rtruediv__(self, other) -> COO:
        return self.__rdiv__(other)

    def __add__(self, other) -> COO:
        return self._binary_op(other, operator.add)

    def __radd__(self, other) -> COO:
        return self._binary_rop(other, operator.add)

    def __sub__(self, other) -> COO:
        return self._binary_op(other, operator.sub)

    def __rsub__(self, other) -> COO:
        return self._binary_rop(other, operator.sub)

    def __mod__(self, other) -> COO:
        return self._binary_op(other, operator.mod)

    def __rmod__(self, other) -> COO:
        return self._binary_rop(other, operator.mod)

    def _promote_for_matmul(self, other: Any):
        """Validate ``other`` and promote both operands to a common dtype.

        Returns a ``(COO, dense)`` pair ready for ``coo_matvec`` / ``coo_matmat``.
        Raises ``NotImplementedError`` when ``other`` is itself sparse.
        """
        if isinstance(other, (JAXSparse, SparseMatrix)):
            raise NotImplementedError("matmul between two sparse objects.")
        other = asarray(other)
        data, other = promote_dtypes(self.data, other)
        promoted = COO(
            (data, self.row, self.col),
            **self._info._asdict()
        )
        return promoted, other

    def __matmul__(
        self, other: jax.typing.ArrayLike
    ) -> jax.Array | Quantity:
        """Compute ``self @ other`` for a dense 1-D or 2-D ``other``."""
        self_promoted, other = self._promote_for_matmul(other)
        if other.ndim == 1:
            return coo_matvec(self_promoted, other)
        elif other.ndim == 2:
            return coo_matmat(self_promoted, other)
        else:
            raise NotImplementedError(f"matmul with object of shape {other.shape}")

    def __rmatmul__(
        self,
        other: jax.typing.ArrayLike
    ) -> jax.Array | Quantity:
        """Compute ``other @ self`` for a dense 1-D or 2-D ``other``."""
        self_promoted, other = self._promote_for_matmul(other)
        if other.ndim == 1:
            return coo_matvec(self_promoted, other, transpose=True)
        elif other.ndim == 2:
            # other @ A == (A.T @ other.T).T
            other = other.T
            return coo_matmat(self_promoted, other, transpose=True).T
        else:
            raise NotImplementedError(f"matmul with object of shape {other.shape}")
def coo_todense(mat: COO) -> jax.Array | Quantity:
    """
    Convert a COO-format sparse matrix to a dense matrix.

    Parameters
    ----------
    mat : COO
        The COO sparse matrix to convert.

    Returns
    -------
    jax.Array or Quantity
        Dense 2-D array equivalent to ``mat``.

    Examples
    --------

    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit as u
        >>> import saiunit.sparse as susparse
        >>> dense = jnp.array([[5., 0.], [0., 6.]])
        >>> coo = susparse.coo_fromdense(dense)
        >>> susparse.coo_todense(coo)
        Array([[5., 0.],
               [0., 6.]], dtype=float32)
    """
    values, row_idx, col_idx = mat.data, mat.row, mat.col
    return _coo_todense(values, row_idx, col_idx, spinfo=mat._info)
def coo_fromdense(
    mat: jax.Array | Quantity,
    *,
    nse: int | None = None,
    index_dtype: jax.typing.DTypeLike = jnp.int32
) -> COO:
    """
    Create a COO-format sparse matrix from a dense matrix.

    Parameters
    ----------
    mat : jax.Array or Quantity
        Dense 2-D array to be converted to COO format. Array-like inputs
        are converted with ``asarray`` first.
    nse : int or None, optional
        Number of specified (non-zero) entries in ``mat``. If ``None``
        (default), it is computed automatically from the input matrix.
    index_dtype : dtype, optional
        Data type for the sparse index arrays. Default is ``jnp.int32``.

    Returns
    -------
    COO
        The COO representation of the input matrix.

    Examples
    --------

    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit as u
        >>> import saiunit.sparse as susparse
        >>> dense = jnp.array([[1., 0., 0.], [0., 2., 3.]])
        >>> coo = susparse.coo_fromdense(dense)
        >>> coo.shape
        (2, 3)
        >>> coo.todense()
        Array([[1., 0., 0.],
               [0., 2., 3.]], dtype=float32)
    """
    # Normalise up front so that counting non-zeros and reading ``.shape``
    # below operate on an array rather than on a raw array-like.
    mat = asarray(mat)
    if nse is None:
        # Count the non-zero entries on the unit-less mantissa.
        nse = int((get_mantissa(mat) != 0).sum())
    # ``nse`` must be a concrete integer; traced values are rejected here.
    nse_int = concrete_or_error(operator.index, nse, "coo_fromdense nse argument")
    return COO(
        _coo_fromdense(mat, nse=nse_int, index_dtype=index_dtype),
        shape=mat.shape,
        rows_sorted=True
    )
def _coo_todense(
    data: jax.Array | Quantity,
    row: jax.Array,
    col: jax.Array,
    *,
    spinfo: COOInfo
) -> jax.Array | Quantity:
    """Convert COO-format sparse matrix buffers to a dense matrix.

    Args:
      data : array of shape ``(nse,)``.
      row : array of shape ``(nse,)``
      col : array of shape ``(nse,)`` and dtype ``row.dtype``
      spinfo : COOInfo object containing matrix metadata

    Returns:
      mat : array with specified shape and dtype matching ``data``
    """
    # The primitive works on the raw mantissa; the unit is re-attached after.
    mantissa, unit = split_mantissa_unit(data)
    dense = coo_todense_p.bind(mantissa, row, col, spinfo=spinfo)
    return maybe_decimal(dense * unit)
def _coo_fromdense(
    mat: jax.Array | Quantity,
    *,
    nse: int,
    index_dtype: jax.typing.DTypeLike = jnp.int32
) -> Tuple[jax.Array | Quantity, jax.Array, jax.Array]:
    """Create COO-format sparse matrix buffers from a dense matrix.

    Args:
      mat : array to be converted to COO.
      nse : number of specified entries in ``mat``
      index_dtype : dtype of sparse indices

    Returns:
      data : array of shape ``(nse,)`` and dtype ``mat.dtype``
      row : array of shape ``(nse,)`` and dtype ``index_dtype``
      col : array of shape ``(nse,)`` and dtype ``index_dtype``
    """
    dense, unit = split_mantissa_unit(asarray(mat))
    # ``nse`` has to be a concrete Python integer for the primitive.
    nse = concrete_or_error(operator.index, nse, "nse argument of coo_fromdense()")
    buffers = coo_fromdense_p.bind(dense, nse=nse, index_dtype=index_dtype)
    if not unit.is_unitless:
        # Re-attach the unit to the values only; indices stay plain arrays.
        values, row, col = buffers
        return values * unit, row, col
    return buffers
def coo_matvec(
    mat: COO,
    v: jax.Array | Quantity,
    transpose: bool = False
) -> jax.Array | Quantity:
    """Product of COO sparse matrix and a dense vector.

    Args:
      mat : COO matrix
      v : one-dimensional array of size ``(shape[0] if transpose else shape[1],)`` and
        dtype ``mat.dtype``
      transpose : boolean specifying whether to transpose the sparse matrix
        before computing.

    Returns:
      y : array of shape ``(mat.shape[1] if transpose else mat.shape[0],)`` representing
        the matrix vector product.
    """
    # ``_bufs`` is the (data, row, col) triple of the matrix.
    return _coo_matvec(*mat._bufs, v, spinfo=mat._info, transpose=transpose)
def _coo_matvec(
    data: jax.Array | Quantity,
    row: jax.Array,
    col: jax.Array,
    v: jax.Array | Quantity,
    *,
    spinfo: COOInfo,
    transpose: bool = False
) -> jax.Array | Quantity:
    """Product of COO sparse matrix and a dense vector.

    Args:
      data : array of shape ``(nse,)``.
      row : array of shape ``(nse,)``
      col : array of shape ``(nse,)`` and dtype ``row.dtype``
      v : array of shape ``(shape[0] if transpose else shape[1],)`` and
        dtype ``data.dtype``
      spinfo : COOInfo object containing matrix metadata (including its shape)
      transpose : boolean specifying whether to transpose the sparse matrix
        before computing.

    Returns:
      y : array of shape ``(shape[1] if transpose else shape[0],)`` representing
        the matrix vector product.
    """
    # Strip units from both operands, multiply the mantissas, then re-attach
    # the matrix unit followed by the vector unit.
    mat_values, mat_unit = split_mantissa_unit(data)
    vec_values, vec_unit = split_mantissa_unit(v)
    product = coo_matvec_p.bind(
        mat_values, row, col, vec_values, spinfo=spinfo, transpose=transpose
    )
    return maybe_decimal(product * mat_unit * vec_unit)
def coo_matmat(
    mat: COO,
    B: jax.Array | Quantity,
    *,
    transpose: bool = False
) -> jax.Array | Quantity:
    """Product of COO sparse matrix and a dense matrix.

    Args:
      mat : COO matrix
      B : array of shape ``(mat.shape[0] if transpose else mat.shape[1], cols)`` and
        dtype ``mat.dtype``
      transpose : boolean specifying whether to transpose the sparse matrix
        before computing.

    Returns:
      C : array of shape ``(mat.shape[1] if transpose else mat.shape[0], cols)``
        representing the matrix-matrix product.
    """
    # ``_bufs`` is the (data, row, col) triple of the matrix.
    return _coo_matmat(*mat._bufs, B, spinfo=mat._info, transpose=transpose)
def _coo_matmat(
    data: jax.Array | Quantity,
    row: jax.Array,
    col: jax.Array,
    B: jax.Array | Quantity,
    *,
    spinfo: COOInfo,
    transpose: bool = False
) -> jax.Array:
    """Product of COO sparse matrix and a dense matrix.

    Args:
      data : array of shape ``(nse,)``.
      row : array of shape ``(nse,)``
      col : array of shape ``(nse,)`` and dtype ``row.dtype``
      B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
        dtype ``data.dtype``
      spinfo : COOInfo object containing matrix metadata (including its shape)
      transpose : boolean specifying whether to transpose the sparse matrix
        before computing.

    Returns:
      C : array of shape ``(shape[1] if transpose else shape[0], cols)``
        representing the matrix-matrix product.
    """
    # Strip units from both operands, multiply the mantissas, then re-attach
    # the sparse-matrix unit followed by the dense-matrix unit.
    mat_values, mat_unit = split_mantissa_unit(data)
    dense_values, dense_unit = split_mantissa_unit(B)
    product = coo_matmat_p.bind(
        mat_values, row, col, dense_values, spinfo=spinfo, transpose=transpose
    )
    return maybe_decimal(product * mat_unit * dense_unit)