# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
from typing import Union, Tuple, Optional, Dict
import brainunit as u
import jax
import numpy as np
from brainevent._compatible_import import Tracer
from brainevent._data import JITCMatrix
from brainevent._event.binary import BinaryArray
from brainevent._typing import MatrixShape, WeightScalar, Prob, Seed
from .binary import binary_jitsmv, binary_jitsmm
from .float import jits, jitsmv, jitsmm
# Public API of this module: the two concrete scalar-weight JIT connectivity matrices.
__all__ = ['JITCScalarR', 'JITCScalarC']
class JITCScalarMatrix(JITCMatrix):
    """
    Base class for Just-In-Time Connectivity Scalar-weight matrices.

    This abstract class serves as the foundation for sparse matrix representations
    that use homogeneous (scalar) weights with stochastic connectivity patterns. It
    stores a single weight value applied to all non-zero elements, along with
    connectivity probability and a random seed that determines the sparse structure.

    Parameters
    ----------
    weight : WeightScalar or Tuple[WeightScalar, Prob, Seed]
        Either the homogeneous weight value for all non-zero elements,
        or a tuple containing ``(weight, prob, seed)``.
    prob : Prob, optional
        Connection probability determining matrix sparsity.
    seed : Seed, optional
        Random seed for reproducible sparse structure generation.
    shape : MatrixShape
        The shape of the matrix as a tuple ``(rows, columns)``.
    corder : bool, optional
        Layout / generation-order flag forwarded to the kernels, by default False.
    backend : str or None, optional
        The computation backend to use.
    buffers : dict or None, optional
        Extra buffers forwarded to :class:`JITCMatrix`; they are carried through
        pytree flattening and restored as attributes on unflattening.

    Raises
    ------
    ValueError
        If ``prob`` is not a scalar, is not finite, or is outside ``[0, 1]``.

    See Also
    --------
    JITCScalarR : Row-oriented concrete subclass.
    JITCScalarC : Column-oriented concrete subclass.
    JITCMatrix : Parent class for all JIT connectivity matrices.

    Notes
    -----
    The matrix ``W`` is defined by a scalar weight ``w``, a connection probability
    ``p``, and a deterministic pseudo-random seed ``s``. Conceptually each element is

    ``W[i, j] = w * B[i, j]``

    where ``B`` is a binary connectivity mask with expected density ``p`` whose
    realization is fully determined by the seed ``s`` (together with ``shape`` and
    ``corder``). This means:

    - The same ``(weight, prob, seed, shape, corder)`` always produces the identical
      matrix.
    - The expected number of non-zeros per row is ``p * n_cols``.
    - The matrix is never materialized in memory; it is regenerated on-the-fly
      during each operation (matvec, matmat, todense).

    According to the kernel implementation, the connection length parameter
    ``clen = 2 / p`` controls the average stride between successive non-zero
    entries in the sampling loop.

    Examples
    --------
    >>> from brainevent import JITCScalarR
    >>> mat = JITCScalarR((0.5, 0.1, 42), shape=(100, 50))
    >>> mat.weight  # 0.5
    >>> mat.prob  # 0.1
    >>> mat.seed  # 42

    Attributes
    ----------
    weight : Union[jax.Array, u.Quantity]
        The homogeneous weight value applied to all non-zero elements in the matrix.
        Can be a plain JAX array or a quantity with units.
    prob : Union[float, jax.Array]
        Connection probability determining the sparsity of the matrix.
        Values range from 0 (no connections) to 1 (fully connected).
    seed : Union[int, jax.Array]
        Random seed controlling the specific pattern of connections.
        Using the same seed produces identical connectivity patterns.
    shape : MatrixShape
        Tuple specifying the dimensions of the matrix as (rows, columns).
    corder : bool
        Layout / generation-order flag forwarded to the kernels. False (default)
        corresponds to Fortran-order (column-major), True to C-order (row-major).
        ``transpose()`` flips it so the transposed matrix regenerates the same
        connectivity.
    backend : str or None
        The computation backend to use (e.g., ``'numba'``, ``'pallas'``). If ``None``,
        the default backend is selected automatically.
    """
    __module__ = 'brainevent'

    weight: Union[jax.Array, u.Quantity]
    prob: Union[float, jax.Array]
    seed: Union[int, jax.Array]
    shape: MatrixShape
    corder: bool
    backend: Optional[str]

    def __init__(
        self,
        weight,
        prob=None,
        seed=None,
        *,
        shape: MatrixShape,
        corder: bool = False,
        backend: Optional[str] = None,
        buffers: Optional[Dict] = None,
    ):
        """
        Initialize a homogeneous sparse just-in-time connectivity matrix.

        Parameters
        ----------
        weight : WeightScalar or Tuple[WeightScalar, Prob, Seed]
            Either the homogeneous weight value for all non-zero elements,
            or a tuple containing (weight, prob, seed).
        prob : Prob, optional
            Connection probability determining matrix sparsity.
            If both ``prob`` and ``seed`` are None, ``weight`` is treated as a
            tuple of (weight, prob, seed).
        seed : Seed, optional
            Random seed for reproducible sparse structure generation.
        shape : MatrixShape
            The shape of the matrix as a tuple (rows, columns).
        corder : bool, optional
            Layout / generation-order flag, by default False.

            - False: Fortran-order (column-major)
            - True: C-order (row-major)
        backend : str or None, optional
            The computation backend to use. If ``None``, the default backend is
            selected automatically.
        buffers : dict or None, optional
            Extra buffers forwarded to the parent ``JITCMatrix`` constructor.

        Raises
        ------
        ValueError
            If ``prob`` is not a scalar, is not finite, or is outside [0, 1].

        Notes
        -----
        The weight is converted with ``u.math.asarray``, preserving any attached
        units. Traced (``jax.jit``-time) probabilities are not validated because
        their values are not available eagerly.
        """
        # A lone positional argument is the packed ``(weight, prob, seed)`` tuple.
        if prob is None and seed is None:
            data = weight
        else:
            data = (weight, prob, seed)
        weight, self.prob, self.seed = data

        # Only concrete probabilities can be inspected; tracers are skipped.
        if not isinstance(self.prob, Tracer):
            prob_arr = np.asarray(self.prob)
            if prob_arr.size != 1:
                raise ValueError(f"prob must be a scalar, but got shape {prob_arr.shape}.")
            prob_value = float(prob_arr.item())
            if not np.isfinite(prob_value):
                raise ValueError(f"prob must be finite, but got {prob_value}.")
            if not (0. <= prob_value <= 1.):
                raise ValueError(f"prob must be in [0, 1], but got {prob_value}.")

        self.weight = u.math.asarray(weight)
        self.corder = corder
        self.backend = backend
        super().__init__(data, shape=shape, buffers=buffers)

    def __repr__(self):
        """
        Return a string representation of the homogeneous matrix.

        Returns
        -------
        str
            A string showing the class name, shape, weight value, probability,
            seed, corder flag and backend of the matrix instance.

        Examples
        --------
        >>> from brainevent import JITCScalarR
        >>> matrix = JITCScalarR((0.5, 0.1, 42), shape=(10, 10))
        >>> repr(matrix)
        'JITCScalarR(shape=(10, 10), weight=0.5, prob=0.1, seed=42, corder=False, backend=None)'
        """
        return (
            f"{self.__class__.__name__}("
            f"shape={self.shape}, "
            f"weight={self.weight}, "
            f"prob={self.prob}, "
            f"seed={self.seed}, "
            f"corder={self.corder}, "
            f"backend={self.backend}"
            f")"
        )

    @property
    def dtype(self):
        """
        Get the data type of the matrix elements.

        Returns
        -------
        dtype
            The data type of the weight values in the matrix (``self.weight.dtype``).
        """
        return self.weight.dtype

    @property
    def data(self) -> Tuple[WeightScalar, Prob, Seed]:
        """
        Return the core data components of the homogeneous matrix.

        These are the three components that define the sparse matrix: the weight,
        the connection probability, and the random seed.

        Returns
        -------
        Tuple[WeightScalar, Prob, Seed]
            ``(weight, prob, seed)`` where ``weight`` is the homogeneous weight for
            non-zero elements, ``prob`` the connection probability, and ``seed`` the
            random seed used to generate the connectivity pattern.
        """
        return self.weight, self.prob, self.seed

    def with_data(self, weight: WeightScalar):
        """
        Create a new matrix instance with updated weight but identical structure.

        The returned instance has the same concrete class, probability, seed,
        shape, ``corder``, backend and buffers; only the weight changes. This is
        useful for updating weights without changing the connectivity pattern.

        Parameters
        ----------
        weight : WeightScalar
            The new weight value. Must have the same shape and unit as the
            current weight.

        Returns
        -------
        JITCScalarMatrix
            A new matrix instance of the same concrete type with the updated weight.

        Raises
        ------
        AssertionError
            If the provided weight has a different shape or unit than the current
            weight.
        """
        weight = u.math.asarray(weight)
        # Explicit raises (instead of ``assert``) so validation survives ``python -O``.
        if weight.shape != self.weight.shape:
            raise AssertionError(
                f"weight shape mismatch: expected {self.weight.shape}, got {weight.shape}."
            )
        if not (u.get_unit(weight) == u.get_unit(self.weight)):
            raise AssertionError(
                f"weight unit mismatch: expected {u.get_unit(self.weight)}, "
                f"got {u.get_unit(weight)}."
            )
        return type(self)(
            (weight, self.prob, self.seed),
            shape=self.shape,
            corder=self.corder,
            backend=self.backend,
            buffers=self.buffers,
        )

    def tree_flatten(self):
        """
        Flatten the matrix into leaves and auxiliary data for JAX pytree registration.

        Returns
        -------
        tuple
            A 2-tuple ``(children, aux_data)`` where ``children`` is the tuple of
            JAX-traceable leaves ``(weight, prob, seed)`` and ``aux_data`` is the
            pair ``(aux, buffers)``: ``aux`` is the dict
            ``{'shape': ..., 'corder': ..., 'backend': ...}`` and ``buffers`` is
            ``self.buffers``.

        See Also
        --------
        tree_unflatten : Reconstruct the matrix from flattened representation.
        """
        aux = {'shape': self.shape, 'corder': self.corder, 'backend': self.backend}
        return (self.weight, self.prob, self.seed), (aux, self.buffers)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """
        Reconstruct a matrix instance from flattened pytree data.

        Parameters
        ----------
        aux_data : tuple
            The pair ``(aux, buffers)`` produced by :meth:`tree_flatten`. ``aux`` is
            a dict with ``'shape'``, ``'corder'`` and ``'backend'`` keys; ``buffers``
            is a dict of buffer attributes.
        children : tuple
            Tuple of JAX-traceable leaves ``(weight, prob, seed)``.

        Returns
        -------
        JITCScalarMatrix
            A reconstructed matrix instance with attributes restored from both
            ``children`` and ``aux_data``.

        See Also
        --------
        tree_flatten : Flatten the matrix for JAX pytree operations.
        """
        # Bypass ``__init__`` so traced leaves are not re-validated or converted.
        obj = object.__new__(cls)
        obj.weight, obj.prob, obj.seed = children
        aux_data, buffer = aux_data
        obj._buffer_registry = set(buffer.keys())
        for k, v in aux_data.items():
            setattr(obj, k, v)
        for k, v in buffer.items():
            setattr(obj, k, v)
        return obj

    def _check(self, other, op):
        """
        Validate compatibility between two JITCScalarMatrix instances for binary operations.

        Two matrices are compatible when they share the same seed, ``corder``,
        probability and shape, i.e. they describe the same connectivity pattern.
        Identical seed/prob objects (e.g. ``m + m`` inside ``jax.jit``) are accepted
        even when traced.

        Parameters
        ----------
        other : JITCScalarMatrix
            The other matrix to check compatibility against.
        op : callable
            The binary operation being attempted, used in error messages.

        Raises
        ------
        NotImplementedError
            If the two matrices have different (or non-identical traced) seeds,
            different corder flags, different (or non-identical traced)
            probabilities, or different shapes.
        """
        name = self.__class__.__name__
        if self.seed is not other.seed:
            # A traced seed cannot be compared eagerly: ``if tracer != x`` would
            # raise a JAX concretization error instead of a clear message.
            if isinstance(self.seed, Tracer) or isinstance(other.seed, Tracer):
                raise NotImplementedError(
                    f"binary operation {op} between two {name} objects with tracing seeds "
                    f"is not implemented currently."
                )
            if self.seed != other.seed:
                raise NotImplementedError(
                    f"binary operation {op} between two {name} objects with different seeds "
                    f"is not implemented currently."
                )
        if self.corder != other.corder:
            raise NotImplementedError(
                f"binary operation {op} between two {name} objects with different corder "
                f"is not implemented currently."
            )
        if self.prob is not other.prob:
            # Same reasoning as for seeds: traced probabilities cannot be compared eagerly.
            if isinstance(self.prob, Tracer) or isinstance(other.prob, Tracer):
                raise NotImplementedError(
                    f"binary operation {op} between two {name} objects "
                    f"with tracing prob is not supported."
                )
            if self.prob != other.prob:
                raise NotImplementedError(
                    f"binary operation {op} between two {name} objects "
                    f"with different prob is not supported."
                )
        if self.shape != other.shape:
            raise NotImplementedError(
                f"binary operation {op} between two {name} objects "
                f"with different shapes is not supported."
            )
@jax.tree_util.register_pytree_node_class
class JITCScalarR(JITCScalarMatrix):
    """
    Just-In-Time Connectivity Homogeneous matrix with Row-oriented representation.

    This class represents a row-oriented homogeneous sparse matrix optimized for JAX-based
    transformations. It follows the Compressed Sparse Row (CSR) format conceptually, storing
    a uniform weight value for all non-zero elements in the matrix, along with probability
    and seed information to determine the sparse structure.

    The class is designed for efficient neural network connectivity patterns where weights
    are homogeneous (identical) but connectivity is sparse and stochastically determined.
    The row-oriented structure makes row-based operations more efficient than column-based ones.

    Attributes
    ----------
    weight : Union[jax.Array, u.Quantity]
        The homogeneous value used for all non-zero elements in the matrix.
        Can be a plain JAX array or a quantity with units.
    prob : Union[float, jax.Array]
        Probability for each potential connection. Controls the sparsity level
        with 0.0 meaning no connections and 1.0 meaning all possible connections.
    seed : Union[int, jax.Array]
        Random seed used for initialization of the sparse structure.
        Using the same seed produces identical connectivity patterns.
    shape : MatrixShape
        The shape of the matrix as a tuple (rows, cols).
    corder : bool
        Layout / generation-order flag forwarded to the kernels.
        False (default) for Fortran-order (column-major), True for C-order (row-major).
    backend : str or None
        The computation backend forwarded to the kernels.
    dtype
        The data type of the matrix elements (property inherited from parent).

    Examples
    --------
    >>> import jax
    >>> import brainunit as u
    >>> from brainevent import JITCScalarR, JITCScalarC

    Create a homogeneous matrix with value 1.5, probability 0.1, and seed 42:

    >>> homo_matrix = JITCScalarR((1.5, 0.1, 42), shape=(10, 10))
    >>> print(homo_matrix)  # JITCScalarR(shape=(10, 10), weight=1.5, prob=0.1, seed=42, ...)

    Create a matrix with units:

    >>> weighted_matrix = JITCScalarR((1.5 * u.mV, 0.1, 42), shape=(10, 10))
    >>> weighted_matrix.weight  # 1.5 mV

    Perform matrix-vector multiplication:

    >>> vec = jax.numpy.ones(10)
    >>> result = homo_matrix @ vec
    >>> result.shape  # (10,)

    Apply scalar operations; arithmetic keeps the sparse structure:

    >>> scaled = homo_matrix * 2.0
    >>> scaled.weight  # 3.0
    >>> neg_matrix = -homo_matrix
    >>> neg_matrix.weight  # -1.5

    Convert to a dense representation:

    >>> dense_matrix = homo_matrix.todense()
    >>> dense_matrix.shape  # (10, 10)

    Transposition returns a column-oriented matrix:

    >>> col_matrix = homo_matrix.transpose()
    >>> isinstance(col_matrix, JITCScalarC)  # True
    >>> col_matrix.shape  # (10, 10)

    Notes
    -----
    Conceptually the matrix is

    ``W[i, j] = w * B[i, j]``

    where ``w`` is the scalar weight (``self.weight``) and ``B`` is a binary
    connectivity mask with expected density ``p`` (``self.prob``) that is
    deterministically regenerated from the seed. The expected value of each
    element is therefore ``w * p``.

    For a matrix-vector product ``y = W @ x``:

    ``y[i] = sum_{j in C(i)} w * x[j]``

    where ``C(i)`` is the deterministic random connection set for row ``i``
    (expected size ``p * n_cols``).

    Key properties:

    - JAX PyTree compatible for use with JAX transformations (jit, grad, vmap)
    - More memory-efficient than dense matrices for sparse connectivity patterns
    - Well-suited for neural network connectivity matrices with uniform weights
    - Optimized for matrix-vector operations common in neural simulations
    - The matrix is implicitly constructed based on the probability and seed;
      the actual sparse structure is materialized only when needed
    - When used with units (e.g., ``u.mV``), units are preserved through operations

    See Also
    --------
    JITCScalarC : Column-oriented counterpart of this class.
    JITCScalarMatrix : Base class providing shared functionality.
    """
    __module__ = 'brainevent'

    def todense(self) -> Union[jax.Array, u.Quantity]:
        """
        Convert the sparse scalar-weight matrix to dense format.

        Materializes every entry of the matrix as determined by the weight,
        probability, seed and ``corder``, using the ``jits`` kernel.

        Returns
        -------
        Union[jax.Array, u.Quantity]
            A dense matrix with the same shape as the sparse matrix. The data type
            will match the weight's data type, and if the weight has units (is a
            ``u.Quantity``), the returned array will have the same units.

        See Also
        --------
        jits : The underlying function that materializes the matrix.

        Notes
        -----
        The kernel places the scalar weight ``w`` at every position selected by
        the seeded PRNG and zero elsewhere.

        Examples
        --------
        >>> import brainunit as u
        >>> from brainevent import JITCScalarR
        >>> sparse_matrix = JITCScalarR((1.5 * u.mV, 0.5, 42), shape=(10, 4))
        >>> dense_matrix = sparse_matrix.todense()
        >>> dense_matrix.shape  # (10, 4)
        """
        return jits(
            self.weight,
            self.prob,
            self.seed,
            shape=self.shape,
            transpose=False,
            corder=self.corder,
            backend=self.backend,
        )

    def transpose(self, axes=None) -> 'JITCScalarC':
        """
        Transpose the row-oriented matrix into a column-oriented matrix.

        Returns a :class:`JITCScalarC` with rows and columns swapped, preserving the
        same weight, probability, seed, backend and buffers. ``corder`` is flipped so
        that the transposed matrix regenerates the same connectivity pattern.

        Parameters
        ----------
        axes : None
            Not supported. This parameter exists for compatibility with the NumPy API
            but only None is accepted.

        Returns
        -------
        JITCScalarC
            A new column-oriented homogeneous matrix with transposed dimensions.

        Raises
        ------
        AssertionError
            If axes is not None, since partial axis transposition is not supported.

        Examples
        --------
        >>> from brainevent import JITCScalarR, JITCScalarC
        >>> row_matrix = JITCScalarR((1.5, 0.5, 42), shape=(30, 5))
        >>> row_matrix.shape  # (30, 5)
        >>> col_matrix = row_matrix.transpose()
        >>> col_matrix.shape  # (5, 30)
        >>> isinstance(col_matrix, JITCScalarC)  # True
        """
        assert axes is None, "transpose does not support axes argument."
        return JITCScalarC(
            (self.weight, self.prob, self.seed),
            shape=(self.shape[1], self.shape[0]),
            corder=not self.corder,
            backend=self.backend,
            buffers=self.buffers,
        )

    def _new_mat(self, weight, prob=None, seed=None):
        """
        Create a new ``JITCScalarR`` with the given weight while preserving structure.

        Parameters
        ----------
        weight : jax.Array or u.Quantity
            The new weight value.
        prob : float or None, optional
            Connection probability. If ``None``, uses ``self.prob``.
        seed : int or None, optional
            Random seed. If ``None``, uses ``self.seed``.

        Returns
        -------
        JITCScalarR
            A new row-oriented matrix with the updated weight.
        """
        return JITCScalarR(
            (
                weight,
                self.prob if prob is None else prob,
                self.seed if seed is None else seed
            ),
            shape=self.shape,
            corder=self.corder,
            backend=self.backend,
            buffers=self.buffers,
        )

    def _unitary_op(self, op) -> 'JITCScalarR':
        """
        Apply a unary operation to the weight of this matrix.

        Parameters
        ----------
        op : callable
            A unary function to apply to the weight (e.g., ``operator.neg``).

        Returns
        -------
        JITCScalarR
            A new matrix with the transformed weight and the same structure.
        """
        # ``_new_mat`` defaults prob/seed to this matrix's values.
        return self._new_mat(op(self.weight))

    def _binary_op(self, other, op) -> 'JITCScalarR':
        """
        Apply a binary operation between this matrix and another operand.

        Parameters
        ----------
        other : JITCScalarR or u.sparse.SparseMatrix or scalar
            The right-hand operand for the binary operation.
        op : callable
            A binary function (e.g., ``operator.mul``).

        Returns
        -------
        JITCScalarR
            A new matrix whose weight is ``op(self.weight, other_weight)``.

        Raises
        ------
        NotImplementedError
            If ``other`` is an incompatible ``JITCScalarR``, a sparse matrix of a
            different type, or a non-scalar array.
        """
        # Must precede the generic sparse check: a JITCScalarR is itself sparse.
        if isinstance(other, JITCScalarR):
            self._check(other, op)
            return self._new_mat(op(self.weight, other.weight))
        if isinstance(other, u.sparse.SparseMatrix):
            raise NotImplementedError(f"binary operation {op} between two sparse objects.")
        other = u.math.asarray(other)
        if other.size == 1:
            return self._new_mat(op(self.weight, other))
        else:
            raise NotImplementedError(f"mul with object of shape {other.shape}")

    def _binary_rop(self, other, op) -> 'JITCScalarR':
        """
        Apply a reflected binary operation (``other op self``).

        Parameters
        ----------
        other : JITCScalarR or u.sparse.SparseMatrix or scalar
            The left-hand operand for the binary operation.
        op : callable
            A binary function (e.g., ``operator.mul``).

        Returns
        -------
        JITCScalarR
            A new matrix whose weight is ``op(other_weight, self.weight)``.

        Raises
        ------
        NotImplementedError
            If ``other`` is an incompatible ``JITCScalarR``, a sparse matrix of a
            different type, or a non-scalar array.
        """
        if isinstance(other, JITCScalarR):
            self._check(other, op)
            return self._new_mat(op(other.weight, self.weight))
        if isinstance(other, u.sparse.SparseMatrix):
            raise NotImplementedError(f"binary operation {op} between two sparse objects.")
        other = u.math.asarray(other)
        if other.size == 1:
            return self._new_mat(op(other, self.weight))
        else:
            raise NotImplementedError(f"mul with object of shape {other.shape}")

    def __matmul__(self, other) -> Union[jax.Array, u.Quantity]:
        """
        Compute ``self @ other`` (matrix-vector or matrix-matrix product).

        Dispatches to the appropriate kernel depending on the type and
        dimensionality of ``other``:

        * 1-D ``BinaryArray`` -- event-driven matrix-vector product via
          ``binary_jitsmv``.
        * 2-D ``BinaryArray`` -- event-driven matrix-matrix product via
          ``binary_jitsmm``.
        * 1-D dense array -- float matrix-vector product via ``jitsmv``.
        * 2-D dense array -- float matrix-matrix product via ``jitsmm``.

        Parameters
        ----------
        other : BinaryArray or jax.Array or u.Quantity
            The right-hand operand. Must be 1-D (vector) or 2-D (matrix).

        Returns
        -------
        Union[jax.Array, u.Quantity]
            The result of the matrix multiplication. Units from both the
            weight and the operand are preserved.

        Raises
        ------
        NotImplementedError
            If ``other`` is a sparse matrix or is not 1-D or 2-D.
        """
        # csr @ other
        if isinstance(other, u.sparse.SparseMatrix):
            raise NotImplementedError("matmul between two sparse objects.")
        weight = self.weight
        if isinstance(other, BinaryArray):
            other = other.value
            if other.ndim == 1:
                # JIT matrix @ events
                return binary_jitsmv(weight, self.prob, other, self.seed, shape=self.shape,
                                     transpose=False, corder=self.corder, backend=self.backend)
            elif other.ndim == 2:
                # JIT matrix @ events
                return binary_jitsmm(weight, self.prob, other, self.seed, shape=self.shape,
                                     transpose=False, corder=self.corder, backend=self.backend)
            else:
                raise NotImplementedError(f"matmul with object of shape {other.shape}")
        else:
            other = u.math.asarray(other)
            # Float kernels require weight and operand to share a dtype.
            weight, other = u.math.promote_dtypes(self.weight, other)
            if other.ndim == 1:
                # JIT matrix @ vector
                return jitsmv(
                    weight,
                    self.prob,
                    other,
                    self.seed,
                    shape=self.shape,
                    transpose=False,
                    corder=self.corder,
                    backend=self.backend,
                )
            elif other.ndim == 2:
                # JIT matrix @ matrix
                return jitsmm(
                    weight,
                    self.prob,
                    other,
                    self.seed,
                    shape=self.shape,
                    transpose=False,
                    corder=self.corder,
                    backend=self.backend,
                )
            else:
                raise NotImplementedError(f"matmul with object of shape {other.shape}")

    def __rmatmul__(self, other) -> Union[jax.Array, u.Quantity]:
        """
        Compute ``other @ self`` (right matrix multiplication).

        This is evaluated as ``self.T @ other`` for vectors, or
        ``(self.T @ other.T).T`` for matrices, using the kernels' ``transpose=True``
        mode with ``corder`` flipped so the same connectivity as ``todense()`` is
        regenerated.

        Parameters
        ----------
        other : BinaryArray or jax.Array or u.Quantity
            The left-hand operand. Must be 1-D (vector) or 2-D (matrix).

        Returns
        -------
        Union[jax.Array, u.Quantity]
            The result of the matrix multiplication with units preserved.

        Raises
        ------
        NotImplementedError
            If ``other`` is a sparse matrix or is not 1-D or 2-D.
        """
        # other @ csr
        if isinstance(other, u.sparse.SparseMatrix):
            raise NotImplementedError("matmul between two sparse objects.")
        weight = self.weight
        if isinstance(other, BinaryArray):
            other = other.value
            if other.ndim == 1:
                #
                # vector @ JIT matrix
                # ==
                # JIT matrix.T @ vector
                #
                return binary_jitsmv(
                    weight,
                    self.prob,
                    other,
                    self.seed,
                    shape=self.shape,
                    transpose=True,
                    corder=not self.corder,
                    backend=self.backend,
                )
            elif other.ndim == 2:
                #
                # matrix @ JIT matrix
                # ==
                # (JIT matrix.T @ matrix.T).T
                #
                r = binary_jitsmm(
                    weight,
                    self.prob,
                    other.T,
                    self.seed,
                    shape=self.shape,
                    transpose=True,
                    corder=not self.corder,
                    backend=self.backend,
                )
                return r.T
            else:
                raise NotImplementedError(f"matmul with object of shape {other.shape}")
        else:
            other = u.math.asarray(other)
            # Float kernels require weight and operand to share a dtype.
            weight, other = u.math.promote_dtypes(self.weight, other)
            if other.ndim == 1:
                #
                # vector @ JIT matrix
                # ==
                # JIT matrix.T @ vector
                #
                return jitsmv(
                    weight,
                    self.prob,
                    other,
                    self.seed,
                    shape=self.shape,
                    transpose=True,
                    corder=not self.corder,  # This is important to generate the same matrix as ``.todense()``
                    backend=self.backend,
                )
            elif other.ndim == 2:
                #
                # matrix @ JIT matrix
                # ==
                # (JIT matrix.T @ matrix.T).T
                #
                r = jitsmm(
                    weight,
                    self.prob,
                    other.T,
                    self.seed,
                    shape=self.shape,
                    transpose=True,
                    corder=not self.corder,  # This is important to generate the same matrix as ``.todense()``
                    backend=self.backend,
                )
                return r.T
            else:
                raise NotImplementedError(f"matmul with object of shape {other.shape}")
@jax.tree_util.register_pytree_node_class
class JITCScalarC(JITCScalarMatrix):
"""
Just-In-Time Connectivity Homogeneous matrix with Column-oriented representation.
This class represents a column-oriented homogeneous sparse matrix optimized for JAX-based
transformations. It follows the Compressed Sparse Column (CSC) format conceptually, storing
a uniform weight value for all non-zero elements in the matrix, along with probability
and seed information to determine the sparse structure.
The column-oriented structure makes column-based operations more efficient than row-based
ones, making this class the transpose-oriented counterpart to JITCScalarR.
Attributes
----------
weight : Union[jax.Array, u.Quantity]
The homogeneous value used for all non-zero elements in the matrix.
Can be a plain JAX array or a quantity with units.
prob : Union[float, jax.Array]
Probability for each potential connection. Controls the sparsity level
with 0.0 meaning no connections and 1.0 meaning all possible connections.
seed : Union[int, jax.Array]
Random seed used for initialization of the sparse structure.
Using the same seed produces identical connectivity patterns.
shape : MatrixShape
The shape of the matrix as a tuple (rows, cols).
corder : bool
Flag indicating the memory layout order of the matrix.
False (default) for Fortran-order (column-major), True for C-order (row-major).
dtype
The data type of the matrix elements (property inherited from parent).
Examples
--------
.. code-block:: python
>>> import jax
>>> import brainunit as u
>>> from brainevent import JITCScalarC
# Create a homogeneous matrix with value 1.5, probability 0.1, and seed 42
>>> homo_matrix = JITCScalarC((1.5, 0.1, 42), shape=(10, 10))
>>> homo_matrix
JITCHomoC(shape=(10, 10), weight=1.5, prob=0.1, seed=42, corder=False)
# Create a matrix with units
>>> weighted_matrix = JITCScalarC((1.5 * u.mV, 0.1, 42), shape=(10, 10))
>>> weighted_matrix
JITCHomoC(shape=(10, 10), weight=1.5 mV, prob=0.1, seed=42, corder=False)
# Perform matrix-vector multiplication
>>> vec = jax.numpy.ones(10)
>>> result = homo_matrix @ vec
>>> result.shape # (10,)
# Apply scalar operations
>>> scaled = homo_matrix * 2.0
>>> scaled.weight # 3.0
# Arithmetic operations maintain the sparse structure
>>> neg_matrix = -homo_matrix
>>> neg_matrix.weight # -1.5
# Convert to dense representation
>>> dense_matrix = homo_matrix.todense()
>>> dense_matrix.shape # (10, 10)
# Transpose operation returns a row-oriented matrix
>>> row_matrix = homo_matrix.transpose()
>>> isinstance(row_matrix, JITCScalarR) # True
>>> row_matrix.shape # (10, 10)
Notes
-----
The mathematical model is the same as ``JITCScalarR``:
``W[i, j] = w * Bernoulli(p)``
where ``w`` is the scalar weight, ``p`` is the connection probability, and
the Bernoulli draw is fully determined by the seed. The column-oriented
representation means that ``JITCScalarC`` is conceptually the transpose of
a ``JITCScalarR`` matrix with swapped dimensions.
Key properties:
- JAX PyTree compatible for use with JAX transformations (jit, grad, vmap)
- More memory-efficient than dense matrices for sparse connectivity patterns
- More efficient than ``JITCScalarR`` for column-based operations
- Well-suited for neural network connectivity matrices with uniform weights
- The matrix is implicitly constructed based on the probability and seed;
the actual sparse structure is materialized only when needed
- When used with units (e.g., ``u.mV``), units are preserved through operations
See Also
--------
JITCScalarR : Row-oriented counterpart of this class.
JITCScalarMatrix : Base class providing shared functionality.
"""
__module__ = 'brainevent'
[docs]
def todense(self) -> Union[jax.Array, u.Quantity]:
"""
Convert the sparse column-oriented scalar-weight matrix to dense format.
Generates a full dense representation of the sparse matrix by
materializing all entries ``W[i, j] = w * Bernoulli(p)`` determined by
the probability and seed. The generated dense matrix always has
``self.shape``.
Parameters
----------
None
Returns
-------
Union[jax.Array, u.Quantity]
A dense matrix with the same shape as the sparse matrix. The data type
will match the weight's data type, and if the weight has units (is a
``u.Quantity``), the returned array will have the same units.
Raises
------
None
See Also
--------
jits : The underlying function that materializes the matrix.
Notes
-----
The dense matrix is generated identically to ``JITCScalarR.todense()``:
``dense[i, j] = w if hash(seed, i, j) < p else 0``
Examples
--------
.. code-block:: python
>>> import brainunit as u
>>> from brainevent import JITCScalarC
>>> sparse_matrix = JITCScalarC((1.5 * u.mV, 0.5, 42), shape=(3, 10))
>>> dense_matrix = sparse_matrix.todense()
>>> dense_matrix.shape # (3, 10)
"""
return jits(
self.weight,
self.prob,
self.seed,
shape=self.shape,
transpose=False,
corder=self.corder,
backend=self.backend,
)
[docs]
def transpose(self, axes=None) -> 'JITCScalarR':
"""
Transposes the column-oriented matrix into a row-oriented matrix.
This method returns a row-oriented matrix (JITCScalarR) with rows and columns
swapped, preserving the same weight, probability, and seed values.
The transpose operation effectively converts between column-oriented and
row-oriented sparse matrix formats.
Parameters
----------
axes : None
Not supported. This parameter exists for compatibility with the NumPy API
but only None is accepted.
Returns
-------
JITCScalarR
A new row-oriented homogeneous matrix with transposed dimensions.
Raises
------
AssertionError
If axes is not None, since partial axis transposition is not supported.
Examples
--------
>>> import jax
>>> import brainunit as u
>>> from brainevent import JITCScalarC
>>>
>>> # Create a column-oriented matrix
>>> col_matrix = JITCScalarC((1.5, 0.5, 42), shape=(3, 5))
>>> print(col_matrix.shape) # (3, 5)
>>>
>>> # Transpose to row-oriented matrix
>>> row_matrix = col_matrix.transpose()
>>> print(row_matrix.shape) # (5, 3)
>>> isinstance(row_matrix, JITCScalarR) # True
"""
assert axes is None, "transpose does not support axes argument."
return JITCScalarR(
(self.weight, self.prob, self.seed),
shape=(self.shape[1], self.shape[0]),
corder=not self.corder,
backend=self.backend,
buffers=self.buffers,
)
def _new_mat(self, weight, prob=None, seed=None):
"""
Create a new ``JITCScalarC`` with the given weight while preserving structure.
Parameters
----------
weight : jax.Array or u.Quantity
The new weight value.
prob : float or None, optional
Connection probability. If ``None``, uses ``self.prob``.
seed : int or None, optional
Random seed. If ``None``, uses ``self.seed``.
Returns
-------
JITCScalarC
A new column-oriented matrix with the updated weight.
"""
return JITCScalarC(
(
weight,
self.prob if prob is None else prob,
self.seed if seed is None else seed
),
shape=self.shape,
corder=self.corder,
backend=self.backend,
buffers=self.buffers,
)
def _unitary_op(self, op) -> 'JITCScalarC':
"""
Apply a unary operation to the weight of this matrix.
Parameters
----------
op : callable
A unary function to apply to the weight (e.g., ``operator.neg``).
Returns
-------
JITCScalarC
A new matrix with the transformed weight.
"""
return self._new_mat(op(self.weight))
def _binary_op(self, other, op) -> 'JITCScalarC':
    """
    Apply ``op(self, other)`` element-wise to the scalar weight.

    Parameters
    ----------
    other : JITCScalarC or u.sparse.SparseMatrix or scalar
        The right-hand operand for the binary operation.
    op : callable
        A binary function (e.g., ``operator.mul``).

    Returns
    -------
    JITCScalarC
        A new matrix whose weight is ``op(self.weight, other_weight)``.

    Raises
    ------
    NotImplementedError
        If ``other`` is a sparse matrix of a different type or a non-scalar array.
    """
    if isinstance(other, JITCScalarC):
        # ``_check`` is defined elsewhere; presumably it validates that both
        # matrices are compatible before their weights are combined.
        self._check(other, op)
        return self._new_mat(op(self.weight, other.weight))
    if isinstance(other, u.sparse.SparseMatrix):
        raise NotImplementedError(f"binary operation {op} between two sparse objects.")
    other = u.math.asarray(other)
    if other.size != 1:
        # A single homogeneous weight can only absorb a single-element operand.
        # Name the actual operation (not just "mul") in the error.
        raise NotImplementedError(f"binary operation {op} with object of shape {other.shape}")
    return self._new_mat(op(self.weight, other))
def _binary_rop(self, other, op) -> 'JITCScalarC':
    """
    Apply a reflected binary operation ``op(other, self)`` to the scalar weight.

    Parameters
    ----------
    other : JITCScalarC or u.sparse.SparseMatrix or scalar
        The left-hand operand for the binary operation.
    op : callable
        A binary function (e.g., ``operator.mul``).

    Returns
    -------
    JITCScalarC
        A new matrix whose weight is ``op(other_weight, self.weight)``.

    Raises
    ------
    NotImplementedError
        If ``other`` is a sparse matrix of a different type or a non-scalar array.
    """
    if isinstance(other, JITCScalarC):
        # ``_check`` is defined elsewhere; presumably it validates that both
        # matrices are compatible before their weights are combined.
        self._check(other, op)
        return self._new_mat(op(other.weight, self.weight))
    if isinstance(other, u.sparse.SparseMatrix):
        raise NotImplementedError(f"binary operation {op} between two sparse objects.")
    other = u.math.asarray(other)
    if other.size != 1:
        # A single homogeneous weight can only absorb a single-element operand.
        # Name the actual operation (not just "mul") in the error.
        raise NotImplementedError(f"binary operation {op} with object of shape {other.shape}")
    return self._new_mat(op(other, self.weight))
def __matmul__(self, other) -> Union[jax.Array, u.Quantity]:
    """
    Compute ``self @ other`` (matrix-vector or matrix-matrix product).

    ``JITCScalarC`` of shape ``(m, n)`` is evaluated as the transpose of a
    row-oriented matrix of shape ``(n, m)``. The kernels are therefore called
    with the reversed shape and ``transpose=True``:

    * ``self @ v`` is computed as ``JITCScalarR.T @ v``.
    * ``self @ B`` is computed as ``JITCScalarR.T @ B``.

    ``BinaryArray`` operands are routed to the event-driven kernels
    (``binary_jitsmv`` / ``binary_jitsmm``). Everything else goes to the float
    kernels (``jitsmv`` / ``jitsmm``) after dtype promotion.

    Parameters
    ----------
    other : BinaryArray or jax.Array or u.Quantity
        The right-hand operand. Must be 1-D (vector) or 2-D (matrix).

    Returns
    -------
    Union[jax.Array, u.Quantity]
        The result of the matrix multiplication with units preserved.

    Raises
    ------
    NotImplementedError
        If ``other`` is a sparse matrix or has more than 2 dimensions.
    """
    if isinstance(other, u.sparse.SparseMatrix):
        raise NotImplementedError("matmul between two sparse objects.")

    if isinstance(other, BinaryArray):
        # Event-driven path: unwrap the raw values; the weight is used as-is.
        w = self.weight
        other = other.value
        vector_kernel, matrix_kernel = binary_jitsmv, binary_jitsmm
    else:
        # Dense path: both weight and operand are promoted to a common dtype.
        other = u.math.asarray(other)
        w, other = u.math.promote_dtypes(self.weight, other)
        vector_kernel, matrix_kernel = jitsmv, jitsmm

    if other.ndim == 1:
        # JITC_R.T @ vector  ==  vector @ JITC_R
        kernel = vector_kernel
    elif other.ndim == 2:
        # JITC_R.T @ matrix  ==  (matrix.T @ JITC_R).T
        kernel = matrix_kernel
    else:
        raise NotImplementedError(f"matmul with object of shape {other.shape}")

    return kernel(
        w,
        self.prob,
        other,
        self.seed,
        shape=self.shape[::-1],
        transpose=True,
        corder=self.corder,
        backend=self.backend,
    )
def __rmatmul__(self, other) -> Union[jax.Array, u.Quantity]:
    """
    Compute ``other @ self`` (right matrix multiplication).

    ``JITCScalarC`` of shape ``(m, n)`` is the transpose of a row-oriented
    matrix of shape ``(n, m)``. ``other @ self`` is therefore evaluated with
    the non-transposed kernels on the reversed shape, with ``corder`` negated:

    * ``v @ self`` is computed as ``JITCScalarR @ v``.
    * ``B @ self`` is computed as ``(JITCScalarR @ B.T).T``.

    Parameters
    ----------
    other : BinaryArray or jax.Array or u.Quantity
        The left-hand operand. Must be 1-D (vector) or 2-D (matrix).

    Returns
    -------
    Union[jax.Array, u.Quantity]
        The result of the matrix multiplication with units preserved.

    Raises
    ------
    NotImplementedError
        If ``other`` is a sparse matrix or has more than 2 dimensions.
    """
    if isinstance(other, u.sparse.SparseMatrix):
        raise NotImplementedError("matmul between two sparse objects.")

    if isinstance(other, BinaryArray):
        # Event-driven path: unwrap the raw values; the weight is used as-is.
        w = self.weight
        other = other.value
        vector_kernel, matrix_kernel = binary_jitsmv, binary_jitsmm
    else:
        # Dense path: both weight and operand are promoted to a common dtype.
        other = u.math.asarray(other)
        w, other = u.math.promote_dtypes(self.weight, other)
        vector_kernel, matrix_kernel = jitsmv, jitsmm

    if other.ndim not in (1, 2):
        raise NotImplementedError(f"matmul with object of shape {other.shape}")

    kernel_kwargs = dict(
        shape=self.shape[::-1],
        transpose=False,
        corder=not self.corder,
        backend=self.backend,
    )
    if other.ndim == 1:
        # vector @ JITC_R.T  ==  JITC_R @ vector
        return vector_kernel(w, self.prob, other, self.seed, **kernel_kwargs)
    # matrix @ JITC_R.T  ==  (JITC_R @ matrix.T).T
    product = matrix_kernel(w, self.prob, other.T, self.seed, **kernel_kwargs)
    return product.T