# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import annotations
import operator
from typing import Tuple, Union
import jax
import jax.numpy as jnp
import numpy as np
from jax import tree_util
from jax.experimental.sparse import JAXSparse, csr_fromdense_p, csr_todense_p, csr_matvec_p, csr_matmat_p
from saiunit._base_getters import (
get_mantissa,
get_unit,
maybe_decimal,
split_mantissa_unit,
)
from saiunit._base_quantity import Quantity
from saiunit._compatible_import import concrete_or_error
from saiunit._sparse_base import SparseMatrix
from saiunit.math._fun_array_creation import asarray
from saiunit.math._fun_keep_unit import promote_dtypes
# Public API of this module: the two unit-aware sparse matrix classes and
# their dense <-> sparse conversion functions.
__all__ = [
'CSR', 'CSC',
'csr_fromdense', 'csr_todense',
'csc_fromdense', 'csc_todense',
]
# Type alias for an array shape, i.e. a tuple of ints.
Shape = tuple[int, ...]
def _const_like(x: jax.Array, value: int) -> jax.Array:
    """Return ``value`` as a 0-d JAX array whose dtype matches ``x``.

    Parameters
    ----------
    x : jax.Array
        Reference array; only its ``dtype`` is used.
    value : int
        The scalar value to wrap.

    Returns
    -------
    jax.Array
        A 0-d array holding ``value`` with dtype ``x.dtype``.
    """
    target_dtype = x.dtype
    return jnp.asarray(value, dtype=target_dtype)
@tree_util.register_pytree_node_class
class CSR(SparseMatrix):
    """
    Unit-aware Compressed Sparse Row (CSR) matrix.

    Stores a 2-D sparse matrix in CSR format with optional physical-unit
    support via :class:`~saiunit.Quantity`.

    Parameters
    ----------
    args : tuple of (data, indices, indptr)
        ``data`` contains the non-zero values (``jax.Array`` or ``Quantity``),
        ``indices`` contains the column indices, and ``indptr`` contains the
        row pointer array.
    shape : tuple of int
        The ``(nrows, ncols)`` shape of the matrix.

    Attributes
    ----------
    data : jax.Array or Quantity
        Non-zero values of shape ``(nse,)``.
    indices : jax.Array
        Column indices of shape ``(nse,)``.
    indptr : jax.Array
        Row pointer array of shape ``(nrows + 1,)``.
    shape : tuple of int
        Shape of the matrix ``(nrows, ncols)``.
    nse : int
        Number of stored elements.
    dtype : dtype
        Data type of the stored values.

    Notes
    -----
    Element-wise binary operators (``+``, ``-``, ``*``, ``/``, ``%``) with a
    scalar (any size-1 operand) or with a dense array of the same shape are
    applied to the stored elements only.  The sparsity structure
    (``indices``/``indptr``) is reused unchanged, so implicit zeros are not
    affected.

    See Also
    --------
    CSC : Unit-aware Compressed Sparse Column matrix.
    csr_fromdense : Create a CSR matrix from a dense array.
    csr_todense : Convert a CSR matrix to a dense array.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit as u
        >>> import saiunit.sparse as susparse
        >>> dense = jnp.array([[1., 0., 2.], [0., 0., 3.]])
        >>> csr = susparse.CSR.fromdense(dense)
        >>> csr.shape
        (2, 3)
        >>> csr.todense()
        Array([[1., 0., 2.],
               [0., 0., 3.]], dtype=float32)
    """
    data: jax.Array | Quantity
    indices: jax.Array
    indptr: jax.Array
    shape: tuple[int, int]

    # Derived, read-only views of the stored buffers.
    nse = property(lambda self: self.data.size)
    dtype = property(lambda self: self.data.dtype)
    _bufs = property(lambda self: (self.data, self.indices, self.indptr))

    def __init__(self, args, *, shape):
        # Every buffer is passed through saiunit's ``asarray``.
        self.data, self.indices, self.indptr = map(asarray, args)
        super().__init__(args, shape=shape)

    @classmethod
    def fromdense(cls, mat, *, nse=None, index_dtype=np.int32):
        """
        Create a CSR matrix from a dense 2-D array.

        Parameters
        ----------
        mat : array_like or Quantity
            Dense matrix to convert.
        nse : int, optional
            Number of stored elements.  If ``None``, the number of non-zero
            entries of ``mat`` is used; this requires a concrete ``mat``.
        index_dtype : dtype, optional
            Dtype of ``indices`` and ``indptr``.  Default ``numpy.int32``.

        Returns
        -------
        CSR
            The CSR representation of ``mat``.
        """
        # Normalize array-likes (e.g. nested lists) so the non-zero count
        # below operates on an array rather than a Python sequence.
        mat = asarray(mat)
        if nse is None:
            nse = (get_mantissa(mat) != 0).sum()
        return csr_fromdense(mat, nse=nse, index_dtype=index_dtype)

    @classmethod
    def _empty(cls, shape, *, dtype=None, index_dtype='int32'):
        """Create an empty CSR instance. Public method is sparse.empty()."""
        shape = tuple(shape)
        if len(shape) != 2:
            raise ValueError(f"CSR must have ndim=2; got {shape=}")
        data = jnp.empty(0, dtype)
        indices = jnp.empty(0, index_dtype)
        # No stored elements: every row pointer is zero.
        indptr = jnp.zeros(shape[0] + 1, index_dtype)
        return cls((data, indices, indptr), shape=shape)

    @classmethod
    def _eye(cls, N, M, k, *, dtype=None, index_dtype='int32'):
        """Create an ``(N, M)`` CSR matrix with ones on the ``k``-th diagonal.

        ``k`` must be a concrete Python integer: it is compared with
        ``0`` in plain Python control flow below.
        """
        if k > 0:
            diag_size = min(N, M - k)
        else:
            diag_size = min(N + k, M)
        if diag_size <= 0:
            # if k is out of range, return an empty matrix.
            return cls._empty((N, M), dtype=dtype, index_dtype=index_dtype)
        data = jnp.ones(diag_size, dtype=dtype)
        idx = jnp.arange(diag_size, dtype=index_dtype)
        zero = _const_like(idx, 0)
        k = _const_like(idx, k)
        # Column of the i-th diagonal entry is i + max(k, 0).
        col = jax.lax.add(idx, jax.lax.cond(k <= 0, lambda: zero, lambda: k))
        indices = col.astype(index_dtype)
        # TODO(jakevdp): this can be done more efficiently.
        # Row of the i-th diagonal entry is i - min(k, 0).
        row = jax.lax.sub(idx, jax.lax.cond(k >= 0, lambda: zero, lambda: k))
        indptr = jnp.zeros(N + 1, dtype=index_dtype).at[1:].set(
            jnp.cumsum(jnp.bincount(row, length=N).astype(index_dtype)))
        return cls((data, indices, indptr), shape=(N, M))

    def with_data(self, data: jax.Array | Quantity) -> CSR:
        """
        Create a new CSR matrix with the same sparsity structure but different data.

        Parameters
        ----------
        data : jax.Array or Quantity
            New non-zero values. Must have the same shape, dtype, and unit as
            the current ``self.data``.

        Returns
        -------
        CSR
            A new CSR matrix sharing the same ``indices`` and ``indptr`` but
            holding the provided ``data``.

        Raises
        ------
        ValueError
            If the shape, dtype or unit of ``data`` differs from ``self.data``.

        Examples
        --------
        .. code-block:: python

            >>> import jax.numpy as jnp
            >>> import saiunit as u
            >>> import saiunit.sparse as susparse
            >>> dense = jnp.array([[1., 0.], [0., 2.]])
            >>> csr = susparse.CSR.fromdense(dense)
            >>> new_csr = csr.with_data(csr.data * 5)
            >>> new_csr.todense()
            Array([[ 5., 0.],
                   [ 0., 10.]], dtype=float32)
        """
        if data.shape != self.data.shape:
            raise ValueError(f"Shape mismatch: expected {self.data.shape}, got {data.shape}")
        if data.dtype != self.data.dtype:
            raise ValueError(f"Dtype mismatch: expected {self.data.dtype}, got {data.dtype}")
        if get_unit(data) != get_unit(self.data):
            raise ValueError(f"Unit mismatch: expected {get_unit(self.data)}, got {get_unit(data)}")
        return self.__class__((data, self.indices, self.indptr), shape=self.shape)

    def todense(self):
        """
        Convert this CSR matrix to a dense array.

        Returns
        -------
        jax.Array or Quantity
            Dense 2-D array equivalent to this sparse matrix.

        Examples
        --------
        .. code-block:: python

            >>> import jax.numpy as jnp
            >>> import saiunit as u
            >>> import saiunit.sparse as susparse
            >>> dense = jnp.array([[0., 3.], [4., 0.]])
            >>> csr = susparse.CSR.fromdense(dense)
            >>> csr.todense()
            Array([[0., 3.],
                   [4., 0.]], dtype=float32)
        """
        return csr_todense(self)

    def transpose(self, axes=None):
        """
        Return the transpose of this CSR matrix as a CSC matrix.

        The CSR buffers of ``A`` are exactly the CSC buffers of ``A.T``, so no
        data is moved; only the shape is reversed.

        Raises
        ------
        NotImplementedError
            If ``axes`` is not ``None``.
        """
        if axes is not None:
            raise NotImplementedError("axes argument to transpose()")
        return CSC((self.data, self.indices, self.indptr), shape=self.shape[::-1])

    def __abs__(self):
        return CSR((abs(self.data), self.indices, self.indptr), shape=self.shape)

    def __neg__(self):
        return CSR((-self.data, self.indices, self.indptr), shape=self.shape)

    def __pos__(self):
        return CSR((self.data.__pos__(), self.indices, self.indptr), shape=self.shape)

    def _binary(self, other, op, *, reverse: bool):
        """
        Apply the element-wise operator ``op`` to the stored values.

        Parameters
        ----------
        other : CSR, scalar, array_like or Quantity
            Second operand.  A CSR operand must share the *same* ``indices``
            and ``indptr`` objects as ``self``.  A dense operand must be
            size-1 or have exactly ``self.shape``.
        op : callable
            Binary operator, e.g. ``operator.add``.
        reverse : bool
            If true compute ``op(other, data)`` instead of ``op(data, other)``.

        Returns
        -------
        CSR
            A matrix with the same sparsity structure as ``self``.

        Raises
        ------
        NotImplementedError
            For other sparse operands or unsupported dense shapes.
        """

        def apply(values, operand):
            return op(operand, values) if reverse else op(values, operand)

        if isinstance(other, CSR):
            # Fast path only when both matrices literally share the structure.
            if other.indices is self.indices and other.indptr is self.indptr:
                return CSR(
                    (apply(self.data, other.data), self.indices, self.indptr),
                    shape=self.shape,
                )
        if isinstance(other, (JAXSparse, SparseMatrix)):
            raise NotImplementedError(f"binary operation {op} between two sparse objects.")
        other = asarray(other)
        if other.size == 1:
            # Collapse e.g. a (1, 1) operand to a 0-d value; otherwise it would
            # broadcast ``data`` from (nse,) to (1, nse) and corrupt the matrix.
            if other.ndim > 0:
                other = other.reshape(())
            values = apply(self.data, other)
        elif other.ndim == 2 and other.shape == self.shape:
            # Gather the dense entries located at the stored positions.
            rows, cols = _csr_to_coo(self.indices, self.indptr)
            values = apply(self.data, other[rows, cols])
        else:
            raise NotImplementedError(
                f"binary operation {op} with object of shape {other.shape}"
            )
        return CSR((values, self.indices, self.indptr), shape=self.shape)

    def _binary_op(self, other, op):
        """Compute ``op(self, other)`` on the stored elements."""
        return self._binary(other, op, reverse=False)

    def _binary_rop(self, other, op):
        """Compute ``op(other, self)`` on the stored elements."""
        return self._binary(other, op, reverse=True)

    def __mul__(self, other: jax.Array | Quantity) -> CSR:
        return self._binary_op(other, operator.mul)

    def __rmul__(self, other: jax.Array | Quantity) -> CSR:
        return self._binary_rop(other, operator.mul)

    def __div__(self, other: jax.Array | Quantity) -> CSR:
        return self._binary_op(other, operator.truediv)

    def __rdiv__(self, other: jax.Array | Quantity) -> CSR:
        return self._binary_rop(other, operator.truediv)

    def __truediv__(self, other) -> CSR:
        return self.__div__(other)

    def __rtruediv__(self, other) -> CSR:
        return self.__rdiv__(other)

    def __add__(self, other) -> CSR:
        return self._binary_op(other, operator.add)

    def __radd__(self, other) -> CSR:
        return self._binary_rop(other, operator.add)

    def __sub__(self, other) -> CSR:
        return self._binary_op(other, operator.sub)

    def __rsub__(self, other) -> CSR:
        return self._binary_rop(other, operator.sub)

    def __mod__(self, other) -> CSR:
        return self._binary_op(other, operator.mod)

    def __rmod__(self, other) -> CSR:
        return self._binary_rop(other, operator.mod)

    def __matmul__(self, other):
        """
        Compute ``self @ other`` for a dense 1-D or 2-D ``other``.

        Raises
        ------
        NotImplementedError
            If ``other`` is sparse or has more than two dimensions.
        """
        if isinstance(other, (JAXSparse, SparseMatrix)):
            raise NotImplementedError("matmul between two sparse objects.")
        other = asarray(other)
        data, other = promote_dtypes(self.data, other)
        if other.ndim == 1:
            return _csr_matvec(
                data,
                self.indices,
                self.indptr,
                other,
                shape=self.shape
            )
        elif other.ndim == 2:
            return _csr_matmat(
                data,
                self.indices,
                self.indptr,
                other,
                shape=self.shape
            )
        else:
            raise NotImplementedError(f"matmul with object of shape {other.shape}")

    def __rmatmul__(self, other):
        """
        Compute ``other @ self`` for a dense 1-D or 2-D ``other``.

        Uses ``v @ A == A.T @ v`` and ``B @ A == (A.T @ B.T).T`` so that only
        the CSR kernels (with ``transpose=True``) are needed.

        Raises
        ------
        NotImplementedError
            If ``other`` is sparse or has more than two dimensions.
        """
        if isinstance(other, (JAXSparse, SparseMatrix)):
            raise NotImplementedError("matmul between two sparse objects.")
        other = asarray(other)
        data, other = promote_dtypes(self.data, other)
        if other.ndim == 1:
            return _csr_matvec(
                data,
                self.indices,
                self.indptr,
                other,
                shape=self.shape,
                transpose=True
            )
        elif other.ndim == 2:
            other = other.T
            r = _csr_matmat(
                data,
                self.indices,
                self.indptr,
                other,
                shape=self.shape,
                transpose=True
            )
            return r.T
        else:
            raise NotImplementedError(f"matmul with object of shape {other.shape}")

    def tree_flatten(self):
        """Flatten for JAX pytrees.

        Only ``data`` is a leaf; ``shape``, ``indices`` and ``indptr`` are
        carried as auxiliary data.
        """
        return (self.data,), {"shape": self.shape, "indices": self.indices, "indptr": self.indptr}

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild a CSR from :meth:`tree_flatten` output without calling ``__init__``."""
        obj = object.__new__(cls)
        obj.data, = children
        if aux_data.keys() != {'shape', 'indices', 'indptr'}:
            raise ValueError(f"CSR.tree_unflatten: invalid {aux_data=}")
        obj.shape = aux_data['shape']
        obj.indices = aux_data['indices']
        obj.indptr = aux_data['indptr']
        return obj
@tree_util.register_pytree_node_class
class CSC(SparseMatrix):
    """
    Unit-aware Compressed Sparse Column (CSC) matrix.

    Stores a 2-D sparse matrix in CSC format with optional physical-unit
    support via :class:`~saiunit.Quantity`.

    Parameters
    ----------
    args : tuple of (data, indices, indptr)
        ``data`` contains the non-zero values (``jax.Array`` or ``Quantity``),
        ``indices`` contains the row indices, and ``indptr`` contains the
        column pointer array.
    shape : tuple of int
        The ``(nrows, ncols)`` shape of the matrix.

    Attributes
    ----------
    data : jax.Array or Quantity
        Non-zero values of shape ``(nse,)``.
    indices : jax.Array
        Row indices of shape ``(nse,)``.
    indptr : jax.Array
        Column pointer array of shape ``(ncols + 1,)``.
    shape : tuple of int
        Shape of the matrix ``(nrows, ncols)``.
    nse : int
        Number of stored elements.
    dtype : dtype
        Data type of the stored values.

    Notes
    -----
    The CSC buffers of ``A`` are the CSR buffers of ``A.T``.  Most operations
    therefore reuse the CSR kernels on the reversed shape.  Element-wise
    binary operators with a scalar (any size-1 operand) or a same-shape dense
    array act on the stored elements only and keep the sparsity structure.

    See Also
    --------
    CSR : Unit-aware Compressed Sparse Row matrix.
    csc_fromdense : Create a CSC matrix from a dense array.
    csc_todense : Convert a CSC matrix to a dense array.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit as u
        >>> import saiunit.sparse as susparse
        >>> dense = jnp.array([[1., 0., 2.], [0., 0., 3.]])
        >>> csc = susparse.CSC.fromdense(dense)
        >>> csc.shape
        (2, 3)
        >>> csc.todense()
        Array([[1., 0., 2.],
               [0., 0., 3.]], dtype=float32)
    """
    data: jax.Array | Quantity
    indices: jax.Array
    indptr: jax.Array
    shape: tuple[int, int]

    # Derived, read-only views of the stored buffers.
    nse = property(lambda self: self.data.size)
    dtype = property(lambda self: self.data.dtype)
    _bufs = property(lambda self: (self.data, self.indices, self.indptr))

    def __init__(self, args, *, shape):
        # Every buffer is passed through saiunit's ``asarray``.
        self.data, self.indices, self.indptr = map(asarray, args)
        super().__init__(args, shape=shape)

    @classmethod
    def fromdense(cls, mat, *, nse=None, index_dtype=np.int32):
        """
        Create a CSC matrix from a dense 2-D array.

        Builds the CSR form of ``mat.T`` and transposes it.

        Parameters
        ----------
        mat : array_like or Quantity
            Dense matrix to convert.
        nse : int, optional
            Number of stored elements.  If ``None``, the number of non-zero
            entries of ``mat`` is used; this requires a concrete ``mat``.
        index_dtype : dtype, optional
            Dtype of ``indices`` and ``indptr``.  Default ``numpy.int32``.

        Returns
        -------
        CSC
            The CSC representation of ``mat``.
        """
        # Normalize array-likes (e.g. nested lists) so that ``.T`` and the
        # non-zero count below operate on an array.
        mat = asarray(mat)
        if nse is None:
            nse = (get_mantissa(mat) != 0).sum()
        return csr_fromdense(mat.T, nse=nse, index_dtype=index_dtype).T

    @classmethod
    def _empty(cls, shape, *, dtype=None, index_dtype='int32'):
        """Create an empty CSC instance. Public method is sparse.empty()."""
        shape = tuple(shape)
        if len(shape) != 2:
            raise ValueError(f"CSC must have ndim=2; got {shape=}")
        data = jnp.empty(0, dtype)
        indices = jnp.empty(0, index_dtype)
        # No stored elements: every column pointer is zero.
        indptr = jnp.zeros(shape[1] + 1, index_dtype)
        return cls((data, indices, indptr), shape=shape)

    @classmethod
    def _eye(cls, N, M, k, *, dtype=None, index_dtype='int32'):
        """Create an ``(N, M)`` CSC matrix with ones on the ``k``-th diagonal.

        The ``k``-th diagonal of an ``(N, M)`` matrix is the transpose of the
        ``-k``-th diagonal of an ``(M, N)`` matrix, so the CSR builder is reused.
        """
        return CSR._eye(M, N, -k, dtype=dtype, index_dtype=index_dtype).T

    def with_data(self, data: jax.Array | Quantity) -> CSC:
        """
        Create a new CSC matrix with the same sparsity structure but different data.

        Parameters
        ----------
        data : jax.Array or Quantity
            New non-zero values. Must have the same shape, dtype, and unit as
            the current ``self.data``.

        Returns
        -------
        CSC
            A new CSC matrix sharing the same ``indices`` and ``indptr`` but
            holding the provided ``data``.

        Raises
        ------
        ValueError
            If the shape, dtype or unit of ``data`` differs from ``self.data``.

        Examples
        --------
        .. code-block:: python

            >>> import jax.numpy as jnp
            >>> import saiunit as u
            >>> import saiunit.sparse as susparse
            >>> dense = jnp.array([[1., 0.], [0., 2.]])
            >>> csc = susparse.CSC.fromdense(dense)
            >>> new_csc = csc.with_data(csc.data * 5)
            >>> new_csc.todense()
            Array([[ 5., 0.],
                   [ 0., 10.]], dtype=float32)
        """
        if data.shape != self.data.shape:
            raise ValueError(f"Shape mismatch: expected {self.data.shape}, got {data.shape}")
        if data.dtype != self.data.dtype:
            raise ValueError(f"Dtype mismatch: expected {self.data.dtype}, got {data.dtype}")
        if get_unit(data) != get_unit(self.data):
            raise ValueError(f"Unit mismatch: expected {get_unit(self.data)}, got {get_unit(data)}")
        return CSC((data, self.indices, self.indptr), shape=self.shape)

    def todense(self):
        """
        Convert this CSC matrix to a dense array.

        Returns
        -------
        jax.Array or Quantity
            Dense 2-D array equivalent to this sparse matrix.

        Examples
        --------
        .. code-block:: python

            >>> import jax.numpy as jnp
            >>> import saiunit as u
            >>> import saiunit.sparse as susparse
            >>> dense = jnp.array([[0., 3.], [4., 0.]])
            >>> csc = susparse.CSC.fromdense(dense)
            >>> csc.todense()
            Array([[0., 3.],
                   [4., 0.]], dtype=float32)
        """
        return csr_todense(self.T).T

    def transpose(self, axes=None):
        """
        Return the transpose of this CSC matrix as a CSR matrix.

        Parameters
        ----------
        axes : None, optional
            Not supported. Must be ``None``.

        Returns
        -------
        CSR
            The transposed matrix in CSR format.

        Raises
        ------
        NotImplementedError
            If ``axes`` is not ``None``.
        """
        if axes is not None:
            raise NotImplementedError("axes argument to transpose()")
        return CSR((self.data, self.indices, self.indptr), shape=self.shape[::-1])

    def __abs__(self):
        return CSC((abs(self.data), self.indices, self.indptr), shape=self.shape)

    def __neg__(self):
        return CSC((-self.data, self.indices, self.indptr), shape=self.shape)

    def __pos__(self):
        return CSC((self.data.__pos__(), self.indices, self.indptr), shape=self.shape)

    def _binary(self, other, op, *, reverse: bool):
        """
        Apply the element-wise operator ``op`` to the stored values.

        Parameters
        ----------
        other : CSC, scalar, array_like or Quantity
            Second operand.  A CSC operand must share the *same* ``indices``
            and ``indptr`` objects as ``self``.  A dense operand must be
            size-1 or have exactly ``self.shape``.
        op : callable
            Binary operator, e.g. ``operator.add``.
        reverse : bool
            If true compute ``op(other, data)`` instead of ``op(data, other)``.

        Returns
        -------
        CSC
            A matrix with the same sparsity structure as ``self``.

        Raises
        ------
        NotImplementedError
            For other sparse operands or unsupported dense shapes.
        """

        def apply(values, operand):
            return op(operand, values) if reverse else op(values, operand)

        if isinstance(other, CSC):
            # Fast path only when both matrices literally share the structure.
            if other.indices is self.indices and other.indptr is self.indptr:
                return CSC(
                    (apply(self.data, other.data), self.indices, self.indptr),
                    shape=self.shape,
                )
        if isinstance(other, (JAXSparse, SparseMatrix)):
            raise NotImplementedError(f"binary operation {op} between two sparse objects.")
        other = asarray(other)
        if other.size == 1:
            # Collapse e.g. a (1, 1) operand to a 0-d value; otherwise it would
            # broadcast ``data`` from (nse,) to (1, nse) and corrupt the matrix.
            if other.ndim > 0:
                other = other.reshape(())
            values = apply(self.data, other)
        elif other.ndim == 2 and other.shape == self.shape:
            # The CSC buffers are the CSR buffers of the transpose, so the
            # "row" returned by _csr_to_coo is the column of this matrix.
            cols, rows = _csr_to_coo(self.indices, self.indptr)
            values = apply(self.data, other[rows, cols])
        else:
            raise NotImplementedError(
                f"binary operation {op} with object of shape {other.shape}"
            )
        return CSC((values, self.indices, self.indptr), shape=self.shape)

    def _binary_op(self, other, op):
        """Compute ``op(self, other)`` on the stored elements."""
        return self._binary(other, op, reverse=False)

    def _binary_rop(self, other, op):
        """Compute ``op(other, self)`` on the stored elements."""
        return self._binary(other, op, reverse=True)

    def __mul__(self, other: jax.Array | Quantity) -> CSC:
        return self._binary_op(other, operator.mul)

    def __rmul__(self, other: jax.Array | Quantity) -> CSC:
        return self._binary_rop(other, operator.mul)

    def __div__(self, other: jax.Array | Quantity) -> CSC:
        return self._binary_op(other, operator.truediv)

    def __rdiv__(self, other: jax.Array | Quantity) -> CSC:
        return self._binary_rop(other, operator.truediv)

    def __truediv__(self, other) -> CSC:
        return self.__div__(other)

    def __rtruediv__(self, other) -> CSC:
        return self.__rdiv__(other)

    def __add__(self, other) -> CSC:
        return self._binary_op(other, operator.add)

    def __radd__(self, other) -> CSC:
        return self._binary_rop(other, operator.add)

    def __sub__(self, other) -> CSC:
        return self._binary_op(other, operator.sub)

    def __rsub__(self, other) -> CSC:
        return self._binary_rop(other, operator.sub)

    def __mod__(self, other) -> CSC:
        return self._binary_op(other, operator.mod)

    def __rmod__(self, other) -> CSC:
        return self._binary_rop(other, operator.mod)

    def __matmul__(self, other):
        """
        Compute ``self @ other`` for a dense 1-D or 2-D ``other``.

        The buffers are treated as a CSR matrix of shape ``self.shape[::-1]``
        (i.e. ``A.T``), and the CSR kernel is run with ``transpose=True``.

        Raises
        ------
        NotImplementedError
            If ``other`` is sparse or has more than two dimensions.
        """
        if isinstance(other, (JAXSparse, SparseMatrix)):
            raise NotImplementedError("matmul between two sparse objects.")
        other = asarray(other)
        data, other = promote_dtypes(self.data, other)
        if other.ndim == 1:
            return _csr_matvec(
                data,
                self.indices,
                self.indptr,
                other,
                shape=self.shape[::-1],
                transpose=True
            )
        elif other.ndim == 2:
            return _csr_matmat(
                data,
                self.indices,
                self.indptr,
                other,
                shape=self.shape[::-1],
                transpose=True
            )
        else:
            raise NotImplementedError(f"matmul with object of shape {other.shape}")

    def __rmatmul__(self, other):
        """
        Compute ``other @ self`` for a dense 1-D or 2-D ``other``.

        Uses ``v @ A == A.T @ v`` and ``B @ A == (A.T @ B.T).T``, where ``A.T``
        is exactly the CSR view of these buffers (``transpose=False``).

        Raises
        ------
        NotImplementedError
            If ``other`` is sparse or has more than two dimensions.
        """
        if isinstance(other, (JAXSparse, SparseMatrix)):
            raise NotImplementedError("matmul between two sparse objects.")
        other = asarray(other)
        data, other = promote_dtypes(self.data, other)
        if other.ndim == 1:
            return _csr_matvec(
                data,
                self.indices,
                self.indptr,
                other,
                shape=self.shape[::-1],
                transpose=False
            )
        elif other.ndim == 2:
            other = other.T
            r = _csr_matmat(
                data,
                self.indices,
                self.indptr,
                other,
                shape=self.shape[::-1],
                transpose=False
            )
            return r.T
        else:
            raise NotImplementedError(f"matmul with object of shape {other.shape}")

    def tree_flatten(self):
        """Flatten for JAX pytrees.

        Only ``data`` is a leaf; ``shape``, ``indices`` and ``indptr`` are
        carried as auxiliary data.
        """
        return (self.data,), {"shape": self.shape, "indices": self.indices, "indptr": self.indptr}

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild a CSC from :meth:`tree_flatten` output without calling ``__init__``."""
        obj = object.__new__(cls)
        obj.data, = children
        if aux_data.keys() != {'shape', 'indices', 'indptr'}:
            raise ValueError(f"CSC.tree_unflatten: invalid {aux_data=}")
        obj.shape = aux_data['shape']
        obj.indices = aux_data['indices']
        obj.indptr = aux_data['indptr']
        return obj
# Type aliases for the three buffers of a compressed-sparse representation,
# as returned by ``_csr_fromdense`` (values, index array, pointer array).
Data = Union[jax.Array, Quantity]
Indices = jax.Array
Indptr = jax.Array
def csr_fromdense(
    mat: jax.Array | Quantity,
    *, nse: int | None = None,
    index_dtype: jax.typing.DTypeLike = np.int32
) -> CSR:
    """
    Create a CSR-format sparse matrix from a dense matrix.

    Parameters
    ----------
    mat : jax.Array, array_like or Quantity
        Dense 2-D array to be converted to CSR format.  Array-likes such as
        nested lists are converted to an array first.
    nse : int or None, optional
        Number of specified (non-zero) entries in ``mat``. If ``None``
        (default), it is computed from the input matrix, which then has to be
        concrete (not traced).
    index_dtype : dtype, optional
        Data type for the sparse index arrays. Default is ``numpy.int32``.

    Returns
    -------
    CSR
        The CSR representation of the input matrix.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit as u
        >>> import saiunit.sparse as susparse
        >>> dense = jnp.array([[1., 0., 0.], [0., 2., 3.]])
        >>> csr = susparse.csr_fromdense(dense)
        >>> csr.shape
        (2, 3)
        >>> csr.todense()
        Array([[1., 0., 0.],
               [0., 2., 3.]], dtype=float32)
    """
    # Normalize the input up front: both ``mat.shape`` and the non-zero count
    # below require an array, not a Python sequence.
    mat = asarray(mat)
    if nse is None:
        nse = int((get_mantissa(mat) != 0).sum())
    nse_int = concrete_or_error(operator.index, nse, "csr_fromdense nse argument")
    return CSR(_csr_fromdense(mat, nse=nse_int, index_dtype=index_dtype), shape=mat.shape)
def csr_todense(mat: CSR) -> jax.Array | Quantity:
    """
    Convert a CSR-format sparse matrix to a dense matrix.

    Parameters
    ----------
    mat : CSR
        The CSR sparse matrix to convert.

    Returns
    -------
    jax.Array or Quantity
        Dense 2-D array equivalent to ``mat``; the unit of ``mat.data`` (if
        any) is carried over.

    Raises
    ------
    TypeError
        If ``mat`` is not an instance of :class:`CSR`.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit as u
        >>> import saiunit.sparse as susparse
        >>> dense = jnp.array([[5., 0.], [0., 6.]])
        >>> csr = susparse.csr_fromdense(dense)
        >>> susparse.csr_todense(csr)
        Array([[5., 0.],
               [0., 6.]], dtype=float32)
    """
    if isinstance(mat, CSR):
        return _csr_todense(mat.data, mat.indices, mat.indptr, shape=mat.shape)
    raise TypeError(f"Expected CSR, got {type(mat)}")
def csc_todense(mat: CSC) -> jax.Array | Quantity:
    """
    Convert a CSC-format sparse matrix to a dense matrix.

    Parameters
    ----------
    mat : CSC
        The CSC sparse matrix to convert.

    Returns
    -------
    jax.Array or Quantity
        Dense 2-D array equivalent to ``mat`` (delegates to
        :meth:`CSC.todense`).

    Raises
    ------
    TypeError
        If ``mat`` is not an instance of :class:`CSC`.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit as u
        >>> import saiunit.sparse as susparse
        >>> dense = jnp.array([[5., 0.], [0., 6.]])
        >>> csc = susparse.csc_fromdense(dense)
        >>> susparse.csc_todense(csc)
        Array([[5., 0.],
               [0., 6.]], dtype=float32)
    """
    if isinstance(mat, CSC):
        return mat.todense()
    raise TypeError(f"Expected CSC, got {type(mat)}")
def csc_fromdense(
    mat: jax.Array | Quantity,
    *,
    nse: int | None = None,
    index_dtype: jax.typing.DTypeLike = np.int32
) -> CSC:
    """
    Create a CSC-format sparse matrix from a dense matrix.

    Parameters
    ----------
    mat : jax.Array or Quantity
        Dense 2-D array to be converted to CSC format.
    nse : int or None, optional
        Number of specified (non-zero) entries in ``mat``. If ``None``
        (default), it is computed automatically from the input matrix.
        An explicit value must be a concrete integer.
    index_dtype : dtype, optional
        Data type for the sparse index arrays. Default is ``numpy.int32``.

    Returns
    -------
    CSC
        The CSC representation of the input matrix.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit as u
        >>> import saiunit.sparse as susparse
        >>> dense = jnp.array([[1., 0., 0.], [0., 2., 3.]])
        >>> csc = susparse.csc_fromdense(dense)
        >>> csc.shape
        (2, 3)
        >>> csc.todense()
        Array([[1., 0., 0.],
               [0., 2., 3.]], dtype=float32)
    """
    checked_nse = (
        None if nse is None
        else concrete_or_error(operator.index, nse, "csc_fromdense nse argument")
    )
    return CSC.fromdense(mat, nse=checked_nse, index_dtype=index_dtype)
def _csr_fromdense(
    mat: jax.Array | Quantity,
    *,
    nse: int,
    index_dtype: jax.typing.DTypeLike = np.int32
) -> Tuple[Data, Indices, Indptr]:
    """Create CSR-format sparse buffers from a dense matrix.

    Args:
      mat : array (or Quantity) to be converted to CSR.
      nse : number of specified entries in ``mat``; must be concrete.
      index_dtype : dtype of sparse indices

    Returns:
      A 3-tuple ``(data, indices, indptr)``:
        data : array of shape ``(nse,)`` and dtype ``mat.dtype``; a Quantity
          carrying ``mat``'s unit when ``mat`` has one.
        indices : array of shape ``(nse,)`` and dtype ``index_dtype``
        indptr : array of shape ``(mat.shape[0] + 1,)`` and dtype ``index_dtype``
    """
    mat = asarray(mat)
    mantissa, unit = split_mantissa_unit(mat)
    nse = concrete_or_error(operator.index, nse, "nse argument of csr_fromdense()")
    data, indices, indptr = csr_fromdense_p.bind(
        mantissa, nse=nse, index_dtype=np.dtype(index_dtype)
    )
    if not unit.is_unitless:
        # Re-attach the unit to the stored values only; index buffers stay plain.
        data = maybe_decimal(data * unit)
    # Always return a tuple (the primitive yields a list) so both branches
    # match the annotated return type.
    return data, indices, indptr
def _csr_todense(
    data: jax.Array | Quantity,
    indices: jax.Array,
    indptr: jax.Array, *,
    shape: Shape
) -> jax.Array | Quantity:
    """Convert CSR-format sparse buffers to a dense matrix.

    Args:
      data : array (or Quantity) of shape ``(nse,)``.
      indices : array of shape ``(nse,)``
      indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
      shape : length-2 tuple representing the matrix shape

    Returns:
      mat : array with specified shape and dtype matching ``data``.  It is a
        Quantity carrying ``data``'s unit when ``data`` has one (via
        ``maybe_decimal``).
    """
    mantissa, unit = split_mantissa_unit(data)
    dense = csr_todense_p.bind(mantissa, indices, indptr, shape=shape)
    return maybe_decimal(dense * unit)
def _csr_matvec(
    data: jax.Array | Quantity,
    indices: jax.Array,
    indptr: jax.Array,
    v: jax.Array | Quantity,
    *,
    shape: Shape,
    transpose: bool = False
) -> jax.Array | Quantity:
    """Multiply a CSR sparse matrix by a dense vector, tracking units.

    Args:
      data : array (or Quantity) of shape ``(nse,)``.
      indices : array of shape ``(nse,)``
      indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
      v : array (or Quantity) of shape ``(shape[0] if transpose else shape[1],)``
        and dtype ``data.dtype``
      shape : length-2 tuple representing the matrix shape
      transpose : if true, multiply by the transpose of the sparse matrix.

    Returns:
      y : array of shape ``(shape[1] if transpose else shape[0],)``.  Its unit
        is the product of the units of ``data`` and ``v``.
    """
    data_mantissa, data_unit = split_mantissa_unit(data)
    vec_mantissa, vec_unit = split_mantissa_unit(v)
    product = csr_matvec_p.bind(
        data_mantissa, indices, indptr, vec_mantissa,
        shape=shape, transpose=transpose,
    )
    # The kernel works on bare mantissas; re-attach the combined unit afterwards.
    return maybe_decimal(product * data_unit * vec_unit)
def _csr_matmat(
    data: jax.Array | Quantity,
    indices: jax.Array,
    indptr: jax.Array,
    B: jax.Array | Quantity,
    *,
    shape: Shape,
    transpose: bool = False
) -> jax.Array | Quantity:
    """Multiply a CSR sparse matrix by a dense matrix, tracking units.

    Args:
      data : array (or Quantity) of shape ``(nse,)``.
      indices : array of shape ``(nse,)``
      indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
      B : array (or Quantity) of shape ``(shape[0] if transpose else shape[1], cols)``
        and dtype ``data.dtype``
      shape : length-2 tuple representing the matrix shape
      transpose : if true, multiply by the transpose of the sparse matrix.

    Returns:
      C : array of shape ``(shape[1] if transpose else shape[0], cols)``.  Its
        unit is the product of the units of ``data`` and ``B``.
    """
    data_mantissa, data_unit = split_mantissa_unit(data)
    rhs_mantissa, rhs_unit = split_mantissa_unit(B)
    product = csr_matmat_p.bind(
        data_mantissa, indices, indptr, rhs_mantissa,
        shape=shape, transpose=transpose,
    )
    # The kernel works on bare mantissas; re-attach the combined unit afterwards.
    return maybe_decimal(product * data_unit * rhs_unit)
@jax.jit
def _csr_to_coo(indices: jax.Array, indptr: jax.Array) -> Tuple[jax.Array, jax.Array]:
    """Given CSR ``(indices, indptr)``, return the COO ``(row, col)`` pair.

    Each ``indptr[r]`` marks where row ``r`` starts; scatter-adding a 1 at
    every start position and taking a running sum (minus one) labels each
    stored element with its row.  The final ``indptr`` entry equals
    ``len(indices)`` and falls outside the scatter target.  JAX drops such
    out-of-bounds scatter updates by default, so it contributes nothing.
    """
    row_starts = jnp.zeros(indices.shape, dtype=indices.dtype).at[indptr].add(1)
    rows = jnp.cumsum(row_starts) - 1
    return rows, indices