brainevent.csr_slice_rows_p
- brainevent.csr_slice_rows_p = <brainevent.XLACustomKernel object>
Low-level XLA custom-kernel primitive for csr_slice_rows.

This XLACustomKernel instance dispatches the CSR row-slicing operation to the registered backends (numba, pallas), using the runtime shape and dtype metadata supplied by the high-level wrapper.

Extracts selected rows from a CSR sparse matrix and returns a dense submatrix of shape (num_selected, n_cols). Each selected row is gathered independently: for row index r, the non-zero entries in data[indptr[r]:indptr[r+1]] are scattered into their corresponding columns of the output row. Out-of-bounds row indices produce zero rows.

The operation is linear in data, so forward mode (JVP) simply applies the same slice to the tangent. Reverse mode (transpose) calls the companion gradient primitive csr_slice_rows_grad_p to gather the cotangent contributions back into a vector of shape (nnz,).

When row_indices is the only batched argument (e.g. under vmap), the batching rule flattens the batch of row-index arrays into a single kernel call and reshapes the output. This avoids a sequential scan.

Available backends can be queried with csr_slice_rows_p.available_backends(platform), and the default backend can be configured with csr_slice_rows_p.set_default(platform, backend).

See also
csr_slice_rows: High-level user-facing function wrapper.
csr_slice_rows_grad_p: Companion gradient primitive used by the transpose rule.
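The NumPy sketch below is a plain reference for the semantics described above. It is not brainevent's numba or pallas kernel, and it does not call brainevent or JAX. The helper names csr_slice_rows_ref and csr_slice_rows_grad_ref are hypothetical, as is the argument order (data, indices, indptr, row_indices, shape).

The sketch makes the following assumptions:
- Any r outside [0, n_rows) counts as out of bounds, including negative indices.
- Duplicate column indices within a row, if present, are summed.

The sketch also checks three properties:
- The forward slice.
- The transpose, using the adjoint identity <slice(data), ct> == <data, grad(ct)>.
- The flatten-then-reshape idea behind the batching rule.

```python
import numpy as np


def csr_slice_rows_ref(data, indices, indptr, row_indices, shape):
    """Hypothetical reference: dense (num_selected, n_cols) slice of a CSR matrix."""
    n_rows, n_cols = shape
    out = np.zeros((len(row_indices), n_cols), dtype=data.dtype)
    for i, r in enumerate(row_indices):
        if 0 <= r < n_rows:  # assumption: anything outside [0, n_rows) yields a zero row
            start, end = indptr[r], indptr[r + 1]
            # Scatter this row's non-zeros into their columns (duplicates are summed).
            np.add.at(out[i], indices[start:end], data[start:end])
    return out


def csr_slice_rows_grad_ref(ct, indices, indptr, row_indices, shape):
    """Hypothetical reference transpose: gather a (num_selected, n_cols) cotangent into (nnz,)."""
    n_rows, _ = shape
    grad = np.zeros(indices.shape[0], dtype=ct.dtype)
    for i, r in enumerate(row_indices):
        if 0 <= r < n_rows:
            start, end = indptr[r], indptr[r + 1]
            # Each stored entry receives the cotangent at its (output row, column);
            # a row selected several times accumulates all of its contributions.
            grad[start:end] += ct[i, indices[start:end]]
    return grad


# Example matrix A = [[1, 0, 2], [0, 0, 3], [4, 5, 0]] in CSR form.
data = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
indices = np.array([0, 2, 2, 0, 1])
indptr = np.array([0, 2, 3, 5])
shape = (3, 3)

# Forward: select rows 2 and 0, plus the out-of-bounds index 7.
out = csr_slice_rows_ref(data, indices, indptr, np.array([2, 0, 7]), shape)
print(out)
# [[4. 5. 0.]
#  [1. 0. 2.]
#  [0. 0. 0.]]

# Transpose check: the gradient map is the adjoint of the (linear) slice in data.
rng = np.random.default_rng(0)
rows = np.array([1, 2, 2, -1])
ct = rng.standard_normal((rows.size, shape[1]))
lhs = np.sum(csr_slice_rows_ref(data, indices, indptr, rows, shape) * ct)
rhs = np.dot(data, csr_slice_rows_grad_ref(ct, indices, indptr, rows, shape))
print(np.isclose(lhs, rhs))  # True

# Batching idea: flatten a (batch, k) index array, slice once, then reshape.
batched_rows = np.array([[0, 1], [2, 2], [1, 5]])
flat = csr_slice_rows_ref(data, indices, indptr, batched_rows.reshape(-1), shape)
batched = flat.reshape(batched_rows.shape + (shape[1],))
looped = np.stack([csr_slice_rows_ref(data, indices, indptr, r, shape) for r in batched_rows])
print(np.array_equal(batched, looped))  # True
```

The single flattened call produces the same result as slicing each batch element in turn. The description above attributes this property to the batching rule when only row_indices is batched.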