# Source code for brainevent._jit_scalar.main

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-


from typing import Union, Tuple, Optional, Dict

import brainunit as u
import jax
import numpy as np

from brainevent._compatible_import import Tracer
from brainevent._data import JITCMatrix
from brainevent._event.binary import BinaryArray
from brainevent._typing import MatrixShape, WeightScalar, Prob, Seed
from .binary import binary_jitsmv, binary_jitsmm
from .float import jits, jitsmv, jitsmm

# Public API of this module: only the concrete row- and column-oriented
# scalar-weight JIT connectivity matrix classes are exported.
__all__ = ['JITCScalarR', 'JITCScalarC']


class JITCScalarMatrix(JITCMatrix):
    """
    Base class for Just-In-Time Connectivity Scalar-weight matrices.

    This abstract class serves as the foundation for sparse matrix representations
    that use homogeneous (scalar) weights with stochastic connectivity patterns. It
    stores a single weight value applied to all non-zero elements, along with
    connectivity probability and a random seed that determines the sparse structure.

    Parameters
    ----------
    weight : WeightScalar or Tuple[WeightScalar, Prob, Seed]
        Either the homogeneous weight value for all non-zero elements,
        or a tuple containing ``(weight, prob, seed)``.
    prob : Prob, optional
        Connection probability determining matrix sparsity.
    seed : Seed, optional
        Random seed for reproducible sparse structure generation.
    shape : MatrixShape
        The shape of the matrix as a tuple ``(rows, columns)``.
    corder : bool, optional
        Memory layout order flag, by default False.
    backend : str or None, optional
        The computation backend to use.
    buffers : dict or None, optional
        Extra buffers forwarded to ``JITCMatrix.__init__``.

    Returns
    -------
    JITCScalarMatrix
        A new scalar-weight JIT connectivity matrix instance.

    Raises
    ------
    ValueError
        If ``prob`` is not a scalar, is not finite, or is outside ``[0, 1]``.

    See Also
    --------
    JITCScalarR : Row-oriented concrete subclass.
    JITCScalarC : Column-oriented concrete subclass.
    JITCMatrix : Parent class for all JIT connectivity matrices.

    Notes
    -----
    The matrix ``W`` is defined by a scalar weight ``w``, a connection probability
    ``p``, and a deterministic pseudo-random seed ``s``. Each element is given by:

    ``W[i, j] = w * B[i, j]``

    where ``B[i, j] ~ Bernoulli(p)`` is a binary mask whose realization is
    fully determined by the seed ``s``. Specifically, the mask is generated by
    a deterministic hash-based PRNG that, for a given ``(i, j, s)`` triple,
    always produces the same binary outcome. This means:

    - The same ``(weight, prob, seed, shape)`` always produces the identical matrix.
    - The expected number of non-zeros per row is ``p * n_cols``.
    - The matrix is never materialized in memory; it is regenerated on-the-fly
      during each operation (matvec, matmat, todense).

    The connection length parameter ``clen = 2 / p`` controls the average stride
    between successive non-zero entries in the sampling loop.

    Examples
    --------

    .. code-block:: python

        >>> from brainevent import JITCScalarR
        >>> mat = JITCScalarR((0.5, 0.1, 42), shape=(100, 50))
        >>> mat.weight   # 0.5
        >>> mat.prob     # 0.1
        >>> mat.seed     # 42

    Attributes
    ----------
    weight : Union[jax.Array, u.Quantity]
        The homogeneous weight value applied to all non-zero elements in the matrix.
        Can be a plain JAX array or a quantity with units.
    prob : Union[float, jax.Array]
        Connection probability determining the sparsity of the matrix.
        Values range from 0 (no connections) to 1 (fully connected).
    seed : Union[int, jax.Array]
        Random seed controlling the specific pattern of connections.
        Using the same seed produces identical connectivity patterns.
    shape : MatrixShape
        Tuple specifying the dimensions of the matrix as (rows, columns).
    corder : bool
        Flag indicating the memory layout order of the matrix.
        False (default) for Fortran-order (column-major), True for C-order (row-major).
    backend : str or None
        The computation backend to use (e.g., ``'numba'``, ``'pallas'``). If ``None``,
        the default backend is selected automatically.
    """
    __module__ = 'brainevent'

    weight: Union[jax.Array, u.Quantity]
    prob: Union[float, jax.Array]
    seed: Union[int, jax.Array]
    shape: MatrixShape
    corder: bool
    backend: Optional[str]

    def __init__(
        self,
        weight,
        prob=None,
        seed=None,
        *,
        shape: MatrixShape,
        corder: bool = False,
        backend: Optional[str] = None,
        buffers: Optional[Dict] = None,
    ):
        """
        Initialize a homogeneous sparse just-in-time connectivity matrix.

        Parameters
        ----------
        weight : WeightScalar or Tuple[WeightScalar, Prob, Seed]
            Either the homogeneous weight value for all non-zero elements,
            or a tuple containing (weight, prob, seed).
        prob : Prob, optional
            Connection probability determining matrix sparsity.
            If ``prob`` and ``seed`` are both None, ``weight`` is treated as a
            tuple of (weight, prob, seed).
        seed : Seed, optional
            Random seed for reproducible sparse structure generation.
        shape : MatrixShape
            The shape of the matrix as a tuple (rows, columns).
        corder : bool, optional
            Memory layout order flag, by default False.
            - False: Fortran-order (column-major)
            - True: C-order (row-major)
        backend : str or None, optional
            The computation backend to use. If ``None``, the default backend is
            selected automatically.
        buffers : dict or None, optional
            Extra buffers forwarded unchanged to ``JITCMatrix.__init__``.

        Raises
        ------
        ValueError
            If ``prob`` is concrete and is not a scalar, is not finite, or is
            outside [0, 1]. Traced ``prob`` values are not validated.

        Notes
        -----
        The constructor extracts the components from the data tuple and sets them
        as instance attributes. The weight is converted with ``u.math.asarray``,
        preserving any attached units.
        """
        if prob is None and seed is None:
            data = weight
        else:
            data = (weight, prob, seed)
        weight, self.prob, self.seed = data
        # Validation is only possible on concrete values; a Tracer has no value
        # available at trace time.
        if not isinstance(self.prob, Tracer):
            prob = np.asarray(self.prob)
            if prob.size != 1:
                raise ValueError(f"prob must be a scalar, but got shape {prob.shape}.")
            prob = float(prob.item())
            if not np.isfinite(prob):
                raise ValueError(f"prob must be finite, but got {prob}.")
            if not (0. <= prob <= 1.):
                raise ValueError(f"prob must be in [0, 1], but got {prob}.")
        self.weight = u.math.asarray(weight)
        self.corder = corder
        self.backend = backend
        super().__init__(data, shape=shape, buffers=buffers)

    def __repr__(self):
        """
        Return a string representation of the homogeneous matrix.

        Returns
        -------
        str
            A string showing the class name, shape, weight value, probability,
            seed, corder flag and backend of the matrix instance.

        Examples
        --------
        >>> matrix = JITCScalarR((0.5, 0.1, 42), shape=(10, 10))
        >>> repr(matrix)
        'JITCScalarR(shape=(10, 10), weight=0.5, prob=0.1, seed=42, corder=False, backend=None)'
        """
        return (
            f"{self.__class__.__name__}("
            f"shape={self.shape}, "
            f"weight={self.weight}, "
            f"prob={self.prob}, "
            f"seed={self.seed}, "
            f"corder={self.corder}, "
            f"backend={self.backend}"
            f")"
        )

    @property
    def dtype(self):
        """
        Get the data type of the matrix elements.

        Returns
        -------
        dtype
            The data type of the weight values in the matrix.

        Notes
        -----
        This property inherits the dtype directly from the weight attribute,
        ensuring consistent data typing throughout operations involving this matrix.
        """
        return self.weight.dtype

    @property
    def data(self) -> Tuple[WeightScalar, Prob, Seed]:
        """
        Returns the core data components of the homogeneous matrix.

        This property provides access to the three fundamental components that define
        the sparse matrix: weight value, connection probability, and the random seed.

        Returns
        -------
        Tuple[Weight, Prob, Seed]
            A tuple containing:
            - weight: The homogeneous weight value for non-zero elements
            - prob: Connection probability for the sparse structure
            - seed: Random seed used for generating the sparse connectivity pattern
        """
        return self.weight, self.prob, self.seed

    def with_data(self, weight: WeightScalar):
        """
        Create a new matrix instance with updated weight data but preserving other properties.

        This method returns a new instance of the same class with the provided weight value,
        while keeping the same probability, seed, shape, and other configuration parameters.
        It's useful for updating weights without changing the connectivity pattern.

        Parameters
        ----------
        weight : Weight
            The new weight value to use. Must have the same shape and unit as the current weight.

        Returns
        -------
        JITCScalarMatrix
            A new matrix instance of the same concrete type with the updated weight.

        Raises
        ------
        AssertionError
            If the provided weight has a different shape or unit than the current weight.
        """
        weight = u.math.asarray(weight)
        assert weight.shape == self.weight.shape, (
            f"weight shape mismatch: expected {self.weight.shape}, got {weight.shape}."
        )
        assert u.get_unit(weight) == u.get_unit(self.weight), (
            f"weight unit mismatch: expected {u.get_unit(self.weight)}, got {u.get_unit(weight)}."
        )
        return type(self)(
            (weight, self.prob, self.seed),
            shape=self.shape,
            corder=self.corder,
            backend=self.backend,
            buffers=self.buffers,
        )

    def tree_flatten(self):
        """
        Flatten the matrix into leaves and auxiliary data for JAX pytree registration.

        Returns
        -------
        tuple
            A 2-tuple ``(children, aux_data)`` where ``children`` is the tuple of
            JAX-traceable leaves ``(weight, prob, seed)`` and ``aux_data`` is the
            2-tuple ``(aux, buffers)``: ``aux`` is a dict of static attributes
            ``{'shape': ..., 'corder': ..., 'backend': ...}`` and ``buffers`` is
            ``self.buffers``.

        See Also
        --------
        tree_unflatten : Reconstruct the matrix from flattened representation.
        """
        aux = {'shape': self.shape, 'corder': self.corder, 'backend': self.backend}
        return (self.weight, self.prob, self.seed), (aux, self.buffers)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """
        Reconstruct a matrix instance from flattened pytree data.

        ``__init__`` is bypassed (via ``object.__new__``) so that no validation
        runs on possibly-traced leaves.

        Parameters
        ----------
        aux_data : tuple
            The 2-tuple ``(aux, buffers)`` produced by :meth:`tree_flatten`.
            ``aux`` is a dict with ``'shape'``, ``'corder'`` and ``'backend'``
            keys; ``buffers`` is a dict-like mapping of buffer names to values
            (``None`` is treated as no buffers).
        children : tuple
            Tuple of JAX-traceable leaves ``(weight, prob, seed)``.

        Returns
        -------
        JITCScalarMatrix
            A reconstructed matrix instance with attributes restored from both
            ``children`` and ``aux_data``.

        See Also
        --------
        tree_flatten : Flatten the matrix for JAX pytree operations.
        """
        obj = object.__new__(cls)
        obj.weight, obj.prob, obj.seed = children
        aux_data, buffer = aux_data
        if buffer is None:
            buffer = {}
        obj._buffer_registry = set(buffer.keys())
        for k, v in aux_data.items():
            setattr(obj, k, v)
        for k, v in buffer.items():
            setattr(obj, k, v)
        return obj

    def _check(self, other, op):
        """
        Validate compatibility between two JITCScalarMatrix instances for binary operations.

        Two matrices can only be combined element-wise when they describe the same
        connectivity pattern: same seed, probability, corder flag and shape.
        Seeds/probabilities that are the very same object (e.g. ``m + m``, also
        when traced under ``jax.jit``) are accepted without a value comparison.

        Parameters
        ----------
        other : JITCScalarMatrix
            The other matrix to check compatibility against.
        op : callable
            The binary operation being attempted, used in error messages.

        Raises
        ------
        NotImplementedError
            If the two matrices have different (or non-identical traced) seeds,
            different corder flags, different (or non-identical traced)
            probabilities, or different shapes.
        """
        name = self.__class__.__name__
        if self.seed is not other.seed:
            # A traced value cannot be compared at trace time, so reject as soon
            # as EITHER side is traced; comparing a tracer in an ``if`` would
            # otherwise raise a JAX concretization error.
            if isinstance(self.seed, Tracer) or isinstance(other.seed, Tracer):
                raise NotImplementedError(
                    f"binary operation {op} between two {name} objects with tracing seeds "
                    f"is not implemented currently."
                )
            if self.seed != other.seed:
                raise NotImplementedError(
                    f"binary operation {op} between two {name} objects with different seeds "
                    f"is not implemented currently."
                )
        if self.corder != other.corder:
            raise NotImplementedError(
                f"binary operation {op} between two {name} objects with different corder "
                f"is not implemented currently."
            )
        if self.prob is not other.prob:
            # Same reasoning as for seeds: traced probabilities cannot be compared.
            if isinstance(self.prob, Tracer) or isinstance(other.prob, Tracer):
                raise NotImplementedError(
                    f"binary operation {op} between two {name} objects "
                    f"with tracing prob is not supported."
                )
            if self.prob != other.prob:
                raise NotImplementedError(
                    f"binary operation {op} between two {name} objects "
                    f"with different prob is not supported."
                )
        if self.shape != other.shape:
            raise NotImplementedError(
                f"binary operation {op} between two {name} objects "
                f"with different shapes is not supported."
            )


@jax.tree_util.register_pytree_node_class
class JITCScalarR(JITCScalarMatrix):
    """
    Just-In-Time Connectivity Homogeneous matrix with Row-oriented representation.

    This class represents a row-oriented homogeneous sparse matrix optimized for JAX-based
    transformations. It follows the Compressed Sparse Row (CSR) format conceptually, storing
    a uniform weight value for all non-zero elements in the matrix, along with probability
    and seed information to determine the sparse structure.

    The class is designed for efficient neural network connectivity patterns where weights
    are homogeneous (identical) but connectivity is sparse and stochastically determined.
    The row-oriented structure makes row-based operations more efficient than column-based ones.

    Attributes
    ----------
    weight : Union[jax.Array, u.Quantity]
        The homogeneous value used for all non-zero elements in the matrix.
        Can be a plain JAX array or a quantity with units.
    prob : Union[float, jax.Array]
        Probability for each potential connection. Controls the sparsity level
        with 0.0 meaning no connections and 1.0 meaning all possible connections.
    seed : Union[int, jax.Array]
        Random seed used for initialization of the sparse structure.
        Using the same seed produces identical connectivity patterns.
    shape : MatrixShape
        The shape of the matrix as a tuple (rows, cols).
    corder : bool
        Flag indicating the memory layout order of the matrix.
        False (default) for Fortran-order (column-major), True for C-order (row-major).
    dtype
        The data type of the matrix elements (property inherited from parent).

    Examples
    --------

    .. code-block:: python

        >>> import jax
        >>> import brainunit as u
        >>> from brainevent import JITCScalarR

        # Create a homogeneous matrix with value 1.5, probability 0.1, and seed 42
        >>> homo_matrix = JITCScalarR((1.5, 0.1, 42), shape=(10, 10))
        >>> homo_matrix
        JITCHomoR(shape=(10, 10), weight=1.5, prob=0.1, seed=42, corder=False)

        # Create a matrix with units
        >>> weighted_matrix = JITCScalarR((1.5 * u.mV, 0.1, 42), shape=(10, 10))
        >>> weighted_matrix
        JITCHomoR(shape=(10, 10), weight=1.5 mV, prob=0.1, seed=42, corder=False)

        # Perform matrix-vector multiplication
        >>> vec = jax.numpy.ones(10)
        >>> result = homo_matrix @ vec
        >>> result.shape  # (10,)

        # Apply scalar operations
        >>> scaled = homo_matrix * 2.0
        >>> scaled.weight  # 3.0

        # Arithmetic operations maintain the sparse structure
        >>> neg_matrix = -homo_matrix
        >>> neg_matrix.weight  # -1.5

        # Convert to dense representation
        >>> dense_matrix = homo_matrix.todense()
        >>> dense_matrix.shape  # (10, 10)

        # Transpose operation returns a column-oriented matrix
        >>> col_matrix = homo_matrix.transpose()
        >>> isinstance(col_matrix, JITCScalarC)  # True
        >>> col_matrix.shape  # (10, 10)

    Notes
    -----
    The mathematical model for this matrix is:

    ``W[i, j] = w * Bernoulli(p)``

    where ``w`` is the scalar weight (``self.weight``), ``p`` is the connection
    probability (``self.prob``), and the Bernoulli draw is determined by a
    deterministic hash from the seed. The expected value of each element is
    ``E[W[i, j]] = w * p`` and the variance is ``Var[W[i, j]] = w^2 * p * (1 - p)``.

    For a matrix-vector product ``y = W @ x``:

    ``y[i] = sum_{j in C(i)} w * x[j]``

    where ``C(i)`` is the deterministic random connection set for row ``i``,
    with ``|C(i)| ~ Binomial(n_cols, p)``.

    Key properties:

    - JAX PyTree compatible for use with JAX transformations (jit, grad, vmap)
    - More memory-efficient than dense matrices for sparse connectivity patterns
    - Well-suited for neural network connectivity matrices with uniform weights
    - Optimized for matrix-vector operations common in neural simulations
    - The matrix is implicitly constructed based on the probability and seed;
      the actual sparse structure is materialized only when needed
    - When used with units (e.g., ``u.mV``), units are preserved through operations

    See Also
    --------
    JITCScalarC : Column-oriented counterpart of this class.
    JITCScalarMatrix : Base class providing shared functionality.
    """
    __module__ = 'brainevent'

[docs] def todense(self) -> Union[jax.Array, u.Quantity]: """ Convert the sparse scalar-weight matrix to dense format. Generates a full dense representation of the sparse matrix by materializing all entries ``W[i, j] = w * Bernoulli(p)`` determined by the probability and seed. Parameters ---------- None Returns ------- Union[jax.Array, u.Quantity] A dense matrix with the same shape as the sparse matrix. The data type will match the weight's data type, and if the weight has units (is a ``u.Quantity``), the returned array will have the same units. Raises ------ None See Also -------- jits : The underlying function that materializes the matrix. Notes ----- The dense matrix is generated by iterating over all ``(i, j)`` positions and placing the scalar weight ``w`` at each position where the deterministic PRNG indicates a connection: ``dense[i, j] = w if hash(seed, i, j) < p else 0`` Examples -------- .. code-block:: python >>> import brainunit as u >>> from brainevent import JITCScalarR >>> sparse_matrix = JITCScalarR((1.5 * u.mV, 0.5, 42), shape=(10, 4)) >>> dense_matrix = sparse_matrix.todense() >>> dense_matrix.shape # (10, 4) """ return jits( self.weight, self.prob, self.seed, shape=self.shape, transpose=False, corder=self.corder, backend=self.backend, )
[docs] def transpose(self, axes=None) -> 'JITCScalarC': """ Transposes the row-oriented matrix into a column-oriented matrix. This method returns a column-oriented matrix (JITCScalarC) with rows and columns swapped, preserving the same weight, probability, and seed values. The transpose operation effectively converts between row-oriented and column-oriented sparse matrix formats. Parameters ---------- axes : None Not supported. This parameter exists for compatibility with the NumPy API but only None is accepted. Returns ------- JITCScalarC A new column-oriented homogeneous matrix with transposed dimensions. Raises ------ AssertionError If axes is not None, since partial axis transposition is not supported. Examples -------- >>> import jax >>> import brainunit as u >>> from brainevent import JITCScalarR >>> >>> # Create a row-oriented matrix >>> row_matrix = JITCScalarR((1.5, 0.5, 42), shape=(30, 5)) >>> print(row_matrix.shape) # (30, 5) >>> >>> # Transpose to column-oriented matrix >>> col_matrix = row_matrix.transpose() >>> print(col_matrix.shape) # (5, 30) >>> isinstance(col_matrix, JITCScalarC) # True """ assert axes is None, "transpose does not support axes argument." return JITCScalarC( (self.weight, self.prob, self.seed), shape=(self.shape[1], self.shape[0]), corder=not self.corder, backend=self.backend, buffers=self.buffers, )
def _new_mat(self, weight, prob=None, seed=None): """ Create a new ``JITCScalarR`` with the given weight while preserving structure. Parameters ---------- weight : jax.Array or u.Quantity The new weight value. prob : float or None, optional Connection probability. If ``None``, uses ``self.prob``. seed : int or None, optional Random seed. If ``None``, uses ``self.seed``. Returns ------- JITCScalarR A new row-oriented matrix with the updated weight. """ return JITCScalarR( ( weight, self.prob if prob is None else prob, self.seed if seed is None else seed ), shape=self.shape, corder=self.corder, backend=self.backend, buffers=self.buffers, ) def _unitary_op(self, op) -> 'JITCScalarR': """ Apply a unary operation to the weight of this matrix. Parameters ---------- op : callable A unary function to apply to the weight (e.g., ``operator.neg``). Returns ------- JITCScalarR A new matrix with the transformed weight. """ return self._new_mat(op(self.weight), self.prob, self.seed) def _binary_op(self, other, op) -> 'JITCScalarR': """ Apply a binary operation between this matrix and another operand. Parameters ---------- other : JITCScalarR or u.sparse.SparseMatrix or scalar The right-hand operand for the binary operation. op : callable A binary function (e.g., ``operator.mul``). Returns ------- JITCScalarR A new matrix whose weight is ``op(self.weight, other_weight)``. Raises ------ NotImplementedError If ``other`` is a sparse matrix of a different type or a non-scalar array. 
""" if isinstance(other, JITCScalarR): self._check(other, op) return self._new_mat(op(self.weight, other.weight)) if isinstance(other, u.sparse.SparseMatrix): raise NotImplementedError(f"binary operation {op} between two sparse objects.") other = u.math.asarray(other) if other.size == 1: return self._new_mat(op(self.weight, other)) else: raise NotImplementedError(f"mul with object of shape {other.shape}") def _binary_rop(self, other, op) -> 'JITCScalarR': """ Apply a reflected binary operation (other op self). Parameters ---------- other : JITCScalarR or u.sparse.SparseMatrix or scalar The left-hand operand for the binary operation. op : callable A binary function (e.g., ``operator.mul``). Returns ------- JITCScalarR A new matrix whose weight is ``op(other_weight, self.weight)``. Raises ------ NotImplementedError If ``other`` is a sparse matrix of a different type or a non-scalar array. """ if isinstance(other, JITCScalarR): self._check(other, op) return self._new_mat(op(other.weight, self.weight)) if isinstance(other, u.sparse.SparseMatrix): raise NotImplementedError(f"binary operation {op} between two sparse objects.") other = u.math.asarray(other) if other.size == 1: return self._new_mat(op(other, self.weight)) else: raise NotImplementedError(f"mul with object of shape {other.shape}") def __matmul__(self, other) -> Union[jax.Array, u.Quantity]: """ Compute ``self @ other`` (matrix-vector or matrix-matrix product). Dispatches to the appropriate kernel depending on the type and dimensionality of ``other``: * 1-D ``BinaryArray`` -- event-driven matrix-vector product via ``binary_jitsmv``. * 2-D ``BinaryArray`` -- event-driven matrix-matrix product via ``binary_jitsmm``. * 1-D dense array -- float matrix-vector product via ``jitsmv``. * 2-D dense array -- float matrix-matrix product via ``jitsmm``. Parameters ---------- other : BinaryArray or jax.Array or u.Quantity The right-hand operand. Must be 1-D (vector) or 2-D (matrix). 
Returns ------- Union[jax.Array, u.Quantity] The result of the matrix multiplication. Units from both the weight and the operand are preserved. Raises ------ NotImplementedError If ``other`` is a sparse matrix or has more than 2 dimensions. """ # csr @ other if isinstance(other, u.sparse.SparseMatrix): raise NotImplementedError("matmul between two sparse objects.") weight = self.weight if isinstance(other, BinaryArray): other = other.value if other.ndim == 1: # JIT matrix @ events return binary_jitsmv(weight, self.prob, other, self.seed, shape=self.shape, transpose=False, corder=self.corder, backend=self.backend) elif other.ndim == 2: # JIT matrix @ events return binary_jitsmm(weight, self.prob, other, self.seed, shape=self.shape, transpose=False, corder=self.corder, backend=self.backend) else: raise NotImplementedError(f"matmul with object of shape {other.shape}") else: other = u.math.asarray(other) weight, other = u.math.promote_dtypes(self.weight, other) if other.ndim == 1: # JIT matrix @ vector return jitsmv( weight, self.prob, other, self.seed, shape=self.shape, transpose=False, corder=self.corder, backend=self.backend, ) elif other.ndim == 2: # JIT matrix @ matrix return jitsmm( weight, self.prob, other, self.seed, shape=self.shape, transpose=False, corder=self.corder, backend=self.backend, ) else: raise NotImplementedError(f"matmul with object of shape {other.shape}") def __rmatmul__(self, other) -> Union[jax.Array, u.Quantity]: """ Compute ``other @ self`` (right matrix multiplication). This is equivalent to ``self.T @ other`` for vectors, or ``(self.T @ other.T).T`` for matrices. Parameters ---------- other : BinaryArray or jax.Array or u.Quantity The left-hand operand. Must be 1-D (vector) or 2-D (matrix). Returns ------- Union[jax.Array, u.Quantity] The result of the matrix multiplication with units preserved. Raises ------ NotImplementedError If ``other`` is a sparse matrix or has more than 2 dimensions. 
""" # other @ csr if isinstance(other, u.sparse.SparseMatrix): raise NotImplementedError("matmul between two sparse objects.") weight = self.weight if isinstance(other, BinaryArray): other = other.value if other.ndim == 1: # # vector @ JIT matrix # == # JIT matrix.T @ vector # return binary_jitsmv( weight, self.prob, other, self.seed, shape=self.shape, transpose=True, corder=not self.corder, backend=self.backend, ) elif other.ndim == 2: # # matrix @ JIT matrix # == # (JIT matrix.T @ matrix.T).T # r = binary_jitsmm( weight, self.prob, other.T, self.seed, shape=self.shape, transpose=True, corder=not self.corder, backend=self.backend, ) return r.T else: raise NotImplementedError(f"matmul with object of shape {other.shape}") else: other = u.math.asarray(other) weight, other = u.math.promote_dtypes(self.weight, other) if other.ndim == 1: # # vector @ JIT matrix # == # JIT matrix.T @ vector # return jitsmv( weight, self.prob, other, self.seed, shape=self.shape, transpose=True, corder=not self.corder, # This is import to generate the same matrix as ``.todense()`` backend=self.backend, ) elif other.ndim == 2: # # matrix @ JIT matrix # == # (JIT matrix.T @ matrix.T).T # r = jitsmm( weight, self.prob, other.T, self.seed, shape=self.shape, transpose=True, corder=not self.corder, # This is import to generate the same matrix as ``.todense()`` backend=self.backend, ) return r.T else: raise NotImplementedError(f"matmul with object of shape {other.shape}") @jax.tree_util.register_pytree_node_class class JITCScalarC(JITCScalarMatrix): """ Just-In-Time Connectivity Homogeneous matrix with Column-oriented representation. This class represents a column-oriented homogeneous sparse matrix optimized for JAX-based transformations. It follows the Compressed Sparse Column (CSC) format conceptually, storing a uniform weight value for all non-zero elements in the matrix, along with probability and seed information to determine the sparse structure. 
The column-oriented structure makes column-based operations more efficient than row-based ones, making this class the transpose-oriented counterpart to JITCScalarR. Attributes ---------- weight : Union[jax.Array, u.Quantity] The homogeneous value used for all non-zero elements in the matrix. Can be a plain JAX array or a quantity with units. prob : Union[float, jax.Array] Probability for each potential connection. Controls the sparsity level with 0.0 meaning no connections and 1.0 meaning all possible connections. seed : Union[int, jax.Array] Random seed used for initialization of the sparse structure. Using the same seed produces identical connectivity patterns. shape : MatrixShape The shape of the matrix as a tuple (rows, cols). corder : bool Flag indicating the memory layout order of the matrix. False (default) for Fortran-order (column-major), True for C-order (row-major). dtype The data type of the matrix elements (property inherited from parent). Examples -------- .. code-block:: python >>> import jax >>> import brainunit as u >>> from brainevent import JITCScalarC # Create a homogeneous matrix with value 1.5, probability 0.1, and seed 42 >>> homo_matrix = JITCScalarC((1.5, 0.1, 42), shape=(10, 10)) >>> homo_matrix JITCHomoC(shape=(10, 10), weight=1.5, prob=0.1, seed=42, corder=False) # Create a matrix with units >>> weighted_matrix = JITCScalarC((1.5 * u.mV, 0.1, 42), shape=(10, 10)) >>> weighted_matrix JITCHomoC(shape=(10, 10), weight=1.5 mV, prob=0.1, seed=42, corder=False) # Perform matrix-vector multiplication >>> vec = jax.numpy.ones(10) >>> result = homo_matrix @ vec >>> result.shape # (10,) # Apply scalar operations >>> scaled = homo_matrix * 2.0 >>> scaled.weight # 3.0 # Arithmetic operations maintain the sparse structure >>> neg_matrix = -homo_matrix >>> neg_matrix.weight # -1.5 # Convert to dense representation >>> dense_matrix = homo_matrix.todense() >>> dense_matrix.shape # (10, 10) # Transpose operation returns a row-oriented matrix >>> 
row_matrix = homo_matrix.transpose() >>> isinstance(row_matrix, JITCScalarR) # True >>> row_matrix.shape # (10, 10) Notes ----- The mathematical model is the same as ``JITCScalarR``: ``W[i, j] = w * Bernoulli(p)`` where ``w`` is the scalar weight, ``p`` is the connection probability, and the Bernoulli draw is fully determined by the seed. The column-oriented representation means that ``JITCScalarC`` is conceptually the transpose of a ``JITCScalarR`` matrix with swapped dimensions. Key properties: - JAX PyTree compatible for use with JAX transformations (jit, grad, vmap) - More memory-efficient than dense matrices for sparse connectivity patterns - More efficient than ``JITCScalarR`` for column-based operations - Well-suited for neural network connectivity matrices with uniform weights - The matrix is implicitly constructed based on the probability and seed; the actual sparse structure is materialized only when needed - When used with units (e.g., ``u.mV``), units are preserved through operations See Also -------- JITCScalarR : Row-oriented counterpart of this class. JITCScalarMatrix : Base class providing shared functionality. """ __module__ = 'brainevent'
[docs] def todense(self) -> Union[jax.Array, u.Quantity]: """ Convert the sparse column-oriented scalar-weight matrix to dense format. Generates a full dense representation of the sparse matrix by materializing all entries ``W[i, j] = w * Bernoulli(p)`` determined by the probability and seed. The generated dense matrix always has ``self.shape``. Parameters ---------- None Returns ------- Union[jax.Array, u.Quantity] A dense matrix with the same shape as the sparse matrix. The data type will match the weight's data type, and if the weight has units (is a ``u.Quantity``), the returned array will have the same units. Raises ------ None See Also -------- jits : The underlying function that materializes the matrix. Notes ----- The dense matrix is generated identically to ``JITCScalarR.todense()``: ``dense[i, j] = w if hash(seed, i, j) < p else 0`` Examples -------- .. code-block:: python >>> import brainunit as u >>> from brainevent import JITCScalarC >>> sparse_matrix = JITCScalarC((1.5 * u.mV, 0.5, 42), shape=(3, 10)) >>> dense_matrix = sparse_matrix.todense() >>> dense_matrix.shape # (3, 10) """ return jits( self.weight, self.prob, self.seed, shape=self.shape, transpose=False, corder=self.corder, backend=self.backend, )
[docs] def transpose(self, axes=None) -> 'JITCScalarR': """ Transposes the column-oriented matrix into a row-oriented matrix. This method returns a row-oriented matrix (JITCScalarR) with rows and columns swapped, preserving the same weight, probability, and seed values. The transpose operation effectively converts between column-oriented and row-oriented sparse matrix formats. Parameters ---------- axes : None Not supported. This parameter exists for compatibility with the NumPy API but only None is accepted. Returns ------- JITCScalarR A new row-oriented homogeneous matrix with transposed dimensions. Raises ------ AssertionError If axes is not None, since partial axis transposition is not supported. Examples -------- >>> import jax >>> import brainunit as u >>> from brainevent import JITCScalarC >>> >>> # Create a column-oriented matrix >>> col_matrix = JITCScalarC((1.5, 0.5, 42), shape=(3, 5)) >>> print(col_matrix.shape) # (3, 5) >>> >>> # Transpose to row-oriented matrix >>> row_matrix = col_matrix.transpose() >>> print(row_matrix.shape) # (5, 3) >>> isinstance(row_matrix, JITCScalarR) # True """ assert axes is None, "transpose does not support axes argument." return JITCScalarR( (self.weight, self.prob, self.seed), shape=(self.shape[1], self.shape[0]), corder=not self.corder, backend=self.backend, buffers=self.buffers, )
def _new_mat(self, weight, prob=None, seed=None):
    """
    Create a new ``JITCScalarC`` with the given weight, keeping this matrix's layout.

    Parameters
    ----------
    weight : jax.Array or u.Quantity
        The weight of the new matrix.
    prob : float or None, optional
        Connection probability. If ``None``, ``self.prob`` is reused.
    seed : int or None, optional
        Random seed. If ``None``, ``self.seed`` is reused.

    Returns
    -------
    JITCScalarC
        A new column-oriented matrix that carries over ``self.shape``,
        ``self.corder``, ``self.backend`` and ``self.buffers``.
    """
    return JITCScalarC(
        (
            weight,
            self.prob if prob is None else prob,
            self.seed if seed is None else seed,
        ),
        shape=self.shape,
        corder=self.corder,
        backend=self.backend,
        buffers=self.buffers,
    )


def _unitary_op(self, op) -> 'JITCScalarC':
    """
    Apply a unary operation to the weight of this matrix.

    Parameters
    ----------
    op : callable
        A unary function applied to the weight (e.g., ``operator.neg``).

    Returns
    -------
    JITCScalarC
        A new matrix with the transformed weight and the same structure.
    """
    return self._new_mat(op(self.weight))


def _binary_op(self, other, op) -> 'JITCScalarC':
    """
    Apply ``op(self, other)`` by combining ``other`` into the weight.

    Parameters
    ----------
    other : JITCScalarC or u.sparse.SparseMatrix or scalar
        The right-hand operand.
    op : callable
        A binary function (e.g., ``operator.mul``).

    Returns
    -------
    JITCScalarC
        A new matrix whose weight is ``op(self.weight, other_weight)``.

    Raises
    ------
    NotImplementedError
        If ``other`` is a sparse matrix of a different type, or an array with
        more than one element.
    """
    if isinstance(other, JITCScalarC):
        self._check(other, op)
        return self._new_mat(op(self.weight, other.weight))
    if isinstance(other, u.sparse.SparseMatrix):
        raise NotImplementedError(f"binary operation {op} between two sparse objects.")
    other = u.math.asarray(other)
    if other.size == 1:
        return self._new_mat(op(self.weight, other))
    # Only single-element operands can be folded into the homogeneous weight.
    raise NotImplementedError(f"binary operation {op} with object of shape {other.shape}")


def _binary_rop(self, other, op) -> 'JITCScalarC':
    """
    Apply the reflected operation ``op(other, self)`` by combining ``other`` into the weight.

    Parameters
    ----------
    other : JITCScalarC or u.sparse.SparseMatrix or scalar
        The left-hand operand.
    op : callable
        A binary function (e.g., ``operator.mul``).

    Returns
    -------
    JITCScalarC
        A new matrix whose weight is ``op(other_weight, self.weight)``.

    Raises
    ------
    NotImplementedError
        If ``other`` is a sparse matrix of a different type, or an array with
        more than one element.
    """
    if isinstance(other, JITCScalarC):
        self._check(other, op)
        return self._new_mat(op(other.weight, self.weight))
    if isinstance(other, u.sparse.SparseMatrix):
        raise NotImplementedError(f"binary operation {op} between two sparse objects.")
    other = u.math.asarray(other)
    if other.size == 1:
        return self._new_mat(op(other, self.weight))
    # Only single-element operands can be folded into the homogeneous weight.
    raise NotImplementedError(f"binary operation {op} with object of shape {other.shape}")


def _prepare_matmul_operand(self, other):
    """
    Normalize the dense operand of a matmul and select the matching kernels.

    Parameters
    ----------
    other : BinaryArray or jax.Array or u.Quantity
        The operand of ``self @ other`` or ``other @ self``.

    Returns
    -------
    tuple
        ``(weight, operand, matvec_kernel, matmat_kernel)``. For a
        ``BinaryArray``, the operand is its ``.value``, the weight is
        ``self.weight`` unchanged, and the binary (event-driven) kernels are
        selected. Otherwise the operand goes through ``u.math.asarray``, its
        dtype is promoted together with the weight, and the float kernels are
        selected.

    Raises
    ------
    NotImplementedError
        If ``other`` is a sparse matrix.
    """
    if isinstance(other, u.sparse.SparseMatrix):
        raise NotImplementedError("matmul between two sparse objects.")
    if isinstance(other, BinaryArray):
        return self.weight, other.value, binary_jitsmv, binary_jitsmm
    other = u.math.asarray(other)
    weight, other = u.math.promote_dtypes(self.weight, other)
    return weight, other, jitsmv, jitsmm


def __matmul__(self, other) -> Union[jax.Array, u.Quantity]:
    """
    Compute ``self @ other`` (matrix-vector or matrix-matrix product).

    ``JITCScalarC`` is treated as the transpose of a row-oriented matrix with
    shape ``self.shape[::-1]``. Both products therefore run the row-oriented
    kernels with ``transpose=True``:

    * ``self @ v`` is computed as ``JITCScalarR.T @ v``.
    * ``self @ B`` is computed as ``JITCScalarR.T @ B``.

    Parameters
    ----------
    other : BinaryArray or jax.Array or u.Quantity
        The right-hand operand. Must be 1-D (vector) or 2-D (matrix).

    Returns
    -------
    Union[jax.Array, u.Quantity]
        The product, with units preserved.

    Raises
    ------
    NotImplementedError
        If ``other`` is a sparse matrix or has more than 2 dimensions.
    """
    weight, other, matvec, matmat = self._prepare_matmul_operand(other)
    kernel_kwargs = dict(
        shape=self.shape[::-1],
        transpose=True,
        corder=self.corder,
        backend=self.backend,
    )
    if other.ndim == 1:
        # JITC_R.T @ v  ==  v @ JITC_R
        return matvec(weight, self.prob, other, self.seed, **kernel_kwargs)
    if other.ndim == 2:
        # JITC_R.T @ B  ==  (B.T @ JITC_R).T
        return matmat(weight, self.prob, other, self.seed, **kernel_kwargs)
    raise NotImplementedError(f"matmul with object of shape {other.shape}")


def __rmatmul__(self, other) -> Union[jax.Array, u.Quantity]:
    """
    Compute ``other @ self`` (left-hand matrix multiplication).

    ``JITCScalarC`` is treated as the transpose of a row-oriented matrix with
    shape ``self.shape[::-1]``, so ``other @ self`` runs the row-oriented
    kernels with ``transpose=False`` and ``corder=not self.corder``:

    * ``v @ self`` is computed as ``JITCScalarR @ v``.
    * ``B @ self`` is computed as ``(JITCScalarR @ B.T).T``.

    Parameters
    ----------
    other : BinaryArray or jax.Array or u.Quantity
        The left-hand operand. Must be 1-D (vector) or 2-D (matrix).

    Returns
    -------
    Union[jax.Array, u.Quantity]
        The product, with units preserved.

    Raises
    ------
    NotImplementedError
        If ``other`` is a sparse matrix or has more than 2 dimensions.
    """
    weight, other, matvec, matmat = self._prepare_matmul_operand(other)
    kernel_kwargs = dict(
        shape=self.shape[::-1],
        transpose=False,
        corder=not self.corder,
        backend=self.backend,
    )
    if other.ndim == 1:
        # v @ JITC_R.T  ==  JITC_R @ v
        return matvec(weight, self.prob, other, self.seed, **kernel_kwargs)
    if other.ndim == 2:
        # B @ JITC_R.T  ==  (JITC_R @ B.T).T
        return matmat(weight, self.prob, other.T, self.seed, **kernel_kwargs).T
    raise NotImplementedError(f"matmul with object of shape {other.shape}")