brainevent.csr_slice_rows_p

brainevent.csr_slice_rows_p = <brainevent.XLACustomKernel object>

Low-level XLA custom-kernel primitive for csr_slice_rows.

This XLACustomKernel instance dispatches the CSR row slicing operation to registered backends (numba, pallas), using runtime shape/dtype metadata provided by the high-level wrapper.

Extracts selected rows from a CSR sparse matrix and returns a dense submatrix of shape (num_selected, n_cols). Each selected row is independently gathered: for row index r, non-zero entries in data[indptr[r]:indptr[r+1]] are scattered into the corresponding columns of the output. Out-of-bounds row indices produce zero rows.
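The row-slicing semantics described above can be sketched in plain Python. This is only a reference for the expected result, not the kernel itself: the function name csr_slice_rows_reference and its argument order are hypothetical, the column-index array is assumed to be called indices, and treating negative row indices as out of bounds is an assumption.

```python
def csr_slice_rows_reference(data, indices, indptr, row_indices, n_cols):
    """Return a dense (num_selected, n_cols) slice of a CSR matrix as nested lists.

    Hypothetical reference helper; not part of brainevent.
    """
    n_rows = len(indptr) - 1
    out = [[0.0] * n_cols for _ in row_indices]
    for k, r in enumerate(row_indices):
        if r < 0 or r >= n_rows:
            # Out-of-bounds row index: leave a zero row.
            # (Negative indices are treated as out of bounds here by assumption.)
            continue
        for j in range(indptr[r], indptr[r + 1]):
            # Scatter each stored value into its column of output row k.
            out[k][indices[j]] += data[j]
    return out


# CSR encoding of the 3 x 4 matrix
# [[1, 0, 2, 0],
#  [0, 0, 0, 0],
#  [0, 3, 0, 4]]
data = [1.0, 2.0, 3.0, 4.0]
indices = [0, 2, 1, 3]
indptr = [0, 2, 2, 4]

# Select row 2, row 0, and an out-of-bounds row 5.
result = csr_slice_rows_reference(data, indices, indptr, [2, 0, 5], n_cols=4)
# result == [[0.0, 3.0, 0.0, 4.0], [1.0, 0.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
```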

The operation is linear in data, so forward-mode (JVP) simply applies the same slice to the tangent. Reverse-mode (transpose) calls the companion gradient primitive csr_slice_rows_grad_p to gather cotangent contributions back into a vector of shape (nnz,).
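To make the transpose relationship concrete, the following pure-Python sketch pairs a reference forward slice with a reference transpose and checks the defining identity <slice(data), ct> = <data, transpose(ct)>. Both helper names and their signatures are hypothetical; the actual signature of csr_slice_rows_grad_p is not given here, and the column-index array name indices is an assumption.

```python
def slice_rows(data, indices, indptr, row_indices, n_cols):
    """Reference forward pass: dense (num_selected, n_cols) slice (hypothetical helper)."""
    n_rows = len(indptr) - 1
    out = [[0.0] * n_cols for _ in row_indices]
    for k, r in enumerate(row_indices):
        if 0 <= r < n_rows:
            for j in range(indptr[r], indptr[r + 1]):
                out[k][indices[j]] += data[j]
    return out


def slice_rows_transpose(ct, indices, indptr, row_indices):
    """Reference transpose: gather a dense cotangent back into a vector of length nnz."""
    n_rows = len(indptr) - 1
    grad = [0.0] * len(indices)
    for k, r in enumerate(row_indices):
        if 0 <= r < n_rows:
            for j in range(indptr[r], indptr[r + 1]):
                # A row selected more than once accumulates into the same entry.
                grad[j] += ct[k][indices[j]]
    return grad


data = [1.0, 2.0, 3.0, 4.0]
indices = [0, 2, 1, 3]
indptr = [0, 2, 2, 4]
row_indices = [2, 0, 2]  # row 2 is selected twice
ct = [[1.0, 2.0, 3.0, 4.0],
      [5.0, 6.0, 7.0, 8.0],
      [9.0, 10.0, 11.0, 12.0]]

grad = slice_rows_transpose(ct, indices, indptr, row_indices)
# grad == [5.0, 7.0, 12.0, 16.0]

# Transpose identity: both inner products agree.
out = slice_rows(data, indices, indptr, row_indices, n_cols=4)
lhs = sum(o * c for o_row, c_row in zip(out, ct) for o, c in zip(o_row, c_row))
rhs = sum(d * g for d, g in zip(data, grad))
```

Because the forward map is linear in data, applying slice_rows to a tangent vector in place of data gives the JVP directly.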

When row_indices is the only batched argument (e.g. under vmap), the batching rule flattens the batch of row index arrays into a single kernel call and reshapes the output, avoiding a sequential scan.
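The flatten-then-reshape idea behind this batching rule can be illustrated in plain Python. The sketch below is not the library's batching code: slice_rows and batched_slice_rows are hypothetical helpers, and the batch is assumed to be a rectangular list of row-index lists.

```python
def slice_rows(data, indices, indptr, row_indices, n_cols):
    """Reference dense row slice of a CSR matrix (hypothetical helper)."""
    n_rows = len(indptr) - 1
    out = [[0.0] * n_cols for _ in row_indices]
    for k, r in enumerate(row_indices):
        if 0 <= r < n_rows:
            for j in range(indptr[r], indptr[r + 1]):
                out[k][indices[j]] += data[j]
    return out


def batched_slice_rows(data, indices, indptr, batched_row_indices, n_cols):
    """Handle a (batch, k) block of row indices with a single slice call."""
    batch = len(batched_row_indices)
    k = len(batched_row_indices[0]) if batch else 0
    # Flatten (batch, k) -> (batch * k,).
    flat = [r for rows in batched_row_indices for r in rows]
    # One call produces a (batch * k, n_cols) result.
    flat_out = slice_rows(data, indices, indptr, flat, n_cols)
    # Reshape back to (batch, k, n_cols).
    return [flat_out[b * k:(b + 1) * k] for b in range(batch)]


data = [1.0, 2.0, 3.0, 4.0]
indices = [0, 2, 1, 3]
indptr = [0, 2, 2, 4]
batched = [[0, 2], [1, 5]]  # two batch entries, two row indices each

flattened_result = batched_slice_rows(data, indices, indptr, batched, n_cols=4)
# Sequential alternative: one call per batch entry.
looped_result = [slice_rows(data, indices, indptr, rows, 4) for rows in batched]
```

Both results agree element for element; the flattened version simply replaces the per-entry loop with one call over all row indices.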

Available backends can be queried with csr_slice_rows_p.available_backends(platform), and the default backend can be configured with csr_slice_rows_p.set_default(platform, backend).

See also

csr_slice_rows: High-level user-facing function wrapper.

csr_slice_rows_grad_p: Companion gradient primitive used by the transpose rule.