Utility Functions#

Index Conversion#

csr_to_coo_index(indptr, indices)

Convert CSR format index arrays to COO format index arrays.
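The conversion amounts to expanding the row pointer array into one explicit row index per stored element. The following is a minimal NumPy sketch of that idea, not the library's implementation; the name `csr_to_coo_sketch` is hypothetical, and the return order (rows, then columns) is an assumption.

```python
import numpy as np


def csr_to_coo_sketch(indptr, indices):
    """Expand CSR row pointers into per-element row indices (illustrative only)."""
    indptr = np.asarray(indptr)
    indices = np.asarray(indices)
    n_rows = indptr.shape[0] - 1
    # Row r owns indptr[r + 1] - indptr[r] stored elements.
    counts = np.diff(indptr)
    # Repeat each row id once per stored element in that row.
    rows = np.repeat(np.arange(n_rows), counts)
    # Column indices are already stored per element in CSR.
    return rows, indices


# 3x3 pattern with entries at (0, 0), (0, 2) and (2, 1); row 1 is empty.
indptr = np.array([0, 2, 2, 3])
indices = np.array([0, 2, 1])
rows, cols = csr_to_coo_sketch(indptr, indices)
```

Empty rows contribute a repeat count of zero, so they simply do not appear in `rows`.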

coo_to_csc_index(pre_ids, indices, *, shape)

Convert COO format index arrays to CSC format.
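One common way to perform this conversion is a stable sort by column followed by a per-column count. The sketch below uses plain NumPy and assumes the two positional inputs are row and column indices; the function name and the extra `order` return value (the permutation to apply to a data array) are hypothetical and may not match the library's actual outputs.

```python
import numpy as np


def coo_to_csc_sketch(rows, cols, *, shape):
    """Build CSC index arrays from COO indices (illustrative only)."""
    rows = np.asarray(rows)
    cols = np.asarray(cols)
    n_rows, n_cols = shape
    # A stable sort keeps the original relative order within each column.
    order = np.argsort(cols, kind="stable")
    csc_indices = rows[order]
    # Number of stored elements per column, including empty columns.
    counts = np.bincount(cols, minlength=n_cols)
    # Prefix sum of the counts gives the column pointer array.
    csc_indptr = np.concatenate(([0], np.cumsum(counts)))
    return csc_indptr, csc_indices, order


rows = np.array([0, 0, 2])
cols = np.array([0, 2, 1])
csc_indptr, csc_indices, order = coo_to_csc_sketch(rows, cols, shape=(3, 3))
```

Values stored alongside the COO indices would be reordered with the same permutation, for example `data[order]`.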

csr_to_csc_index(csr_indptr, csr_indices, *, ...)

Convert CSR format index arrays to CSC format.
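A direct CSR to CSC conversion can be written as a counting sort over column indices. The sketch below shows that classic two-pass approach in plain Python/NumPy loops; it is an illustration rather than the library's kernel, and the function name, the `n_cols` keyword, and the returned permutation `perm` are assumptions (the real keyword arguments are elided in the signature above).

```python
import numpy as np


def csr_to_csc_sketch(csr_indptr, csr_indices, *, n_cols):
    """Counting-sort conversion of CSR index arrays to CSC (illustrative only)."""
    csr_indptr = np.asarray(csr_indptr)
    csr_indices = np.asarray(csr_indices)
    n_rows = len(csr_indptr) - 1
    nnz = len(csr_indices)

    # Pass 1: count stored elements per column, then prefix-sum into pointers.
    counts = np.zeros(n_cols + 1, dtype=np.int64)
    for c in csr_indices:
        counts[c + 1] += 1
    csc_indptr = np.cumsum(counts)

    # Pass 2: scatter each element into the next free slot of its column.
    next_slot = csc_indptr[:-1].copy()
    csc_indices = np.empty(nnz, dtype=np.int64)
    perm = np.empty(nnz, dtype=np.int64)  # CSC position -> CSR position
    for r in range(n_rows):
        for k in range(csr_indptr[r], csr_indptr[r + 1]):
            c = csr_indices[k]
            dst = next_slot[c]
            csc_indices[dst] = r
            perm[dst] = k
            next_slot[c] += 1
    return csc_indptr, csc_indices, perm


csc_ptr, csc_idx, csc_perm = csr_to_csc_sketch(
    np.array([0, 2, 2, 3]), np.array([0, 2, 1]), n_cols=3
)
```

Because rows are visited in increasing order, the row indices within each column come out sorted.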

GPU/TPU Random Number Generator#

PallasLFSR88RNG(seed)

Combined LFSR random number generator by L'Ecuyer (LFSR88).
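For orientation, the three-component combined generator published by L'Ecuyer (often called taus88) can be sketched in pure Python as follows. This is not the Pallas implementation: the function names are hypothetical, the 32-bit masking emulates unsigned integer arithmetic, and how the class derives its internal state from a single `seed` is not documented here, so the example state is simply an arbitrary valid one.

```python
MASK32 = 0xFFFFFFFF


def lfsr88_step(s1, s2, s3):
    """Advance the three LFSR components once and return (new_state, 32-bit output)."""
    b = (((s1 << 13) & MASK32) ^ s1) >> 19
    s1 = (((s1 & 0xFFFFFFFE) << 12) & MASK32) ^ b
    b = (((s2 << 2) & MASK32) ^ s2) >> 25
    s2 = (((s2 & 0xFFFFFFF8) << 4) & MASK32) ^ b
    b = (((s3 << 3) & MASK32) ^ s3) >> 11
    s3 = (((s3 & 0xFFFFFFF0) << 17) & MASK32) ^ b
    # The generator output is the XOR of the three component states.
    return (s1, s2, s3), s1 ^ s2 ^ s3


def lfsr88_uniform(state):
    """Return (new_state, float in [0, 1)) by scaling the 32-bit output."""
    state, bits = lfsr88_step(*state)
    return state, bits / 2**32


# The components must be seeded with s1 >= 2, s2 >= 8, s3 >= 16.
state = (12345, 67890, 13579)
samples = []
for _ in range(5):
    state, u = lfsr88_uniform(state)
    samples.append(u)
```

Keeping the state as an explicit tuple mirrors how a functional kernel would thread RNG state through a loop rather than mutate it globally.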

PallasLFSR113RNG(seed)

Combined LFSR random number generator by L'Ecuyer (LFSR113).

PallasLFSR128RNG(seed)

Combined LFSR random number generator (LFSR128).

PallasLFSRRNG(seed)

Factory: create a Pallas RNG instance using the globally configured algorithm.

get_pallas_lfsr_rng_class()

Return the Pallas RNG class for the current global LFSR algorithm.

Numba RNG LFSR88#

Numba RNG LFSR113#

Numba RNG LFSR128#

Numba RNG Dispatch#

The dispatch helpers are used internally by the JIT kernel generators to resolve the correct LFSR functions at kernel-generation time.

Benchmarking#

BenchmarkResult(records[, primitive_name])

Unified container for benchmark timing records across all (config × backend) pairs.

benchmark_function(fn, n_warmup, n_runs[, ...])

Benchmark a function and return timing statistics.
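The usual warmup-then-measure pattern behind such a helper can be sketched with the standard library alone. The function below is a hypothetical stand-in, not the library's `benchmark_function`; the returned dictionary keys are assumptions, and the optional arguments elided in the signature above are not modelled.

```python
import statistics
import time


def time_function(fn, n_warmup, n_runs):
    """Run fn a few times untimed, then collect wall-clock timings (illustrative only)."""
    # Warmup runs absorb one-off costs such as JIT compilation or cache fills.
    for _ in range(n_warmup):
        fn()
    samples = []
    for _ in range(n_runs):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return {
        "mean": statistics.fmean(samples),
        "std": statistics.stdev(samples) if n_runs > 1 else 0.0,
        "min": min(samples),
        "max": max(samples),
        "n_runs": n_runs,
    }


stats = time_function(lambda: sum(range(10_000)), n_warmup=2, n_runs=5)
```

When timing JAX code with a pattern like this, the timed callable should wait for the result (for example via `jax.block_until_ready`), since dispatch is asynchronous and the call may otherwise return before the computation finishes.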

Kernel Helpers#

defjvp(primitive, *jvp_rules)

Define per-input JVP rules for a JAX primitive.
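The idea behind per-input JVP rules is that each rule maps the tangent of one input to its contribution to the output tangent, and the contributions are summed. The sketch below illustrates this on an ordinary function and checks it against `jax.jvp`; it does not register a primitive, and the rule signature `(tangent, *primals)` and the helper `combined_jvp` are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    return x * jnp.sin(y)


def jvp_wrt_x(x_dot, x, y):
    # d/dx [x * sin(y)] = sin(y)
    return x_dot * jnp.sin(y)


def jvp_wrt_y(y_dot, x, y):
    # d/dy [x * sin(y)] = x * cos(y)
    return x * jnp.cos(y) * y_dot


def combined_jvp(primals, tangents, rules):
    """Sum the per-input tangent contributions (illustrative only)."""
    out = f(*primals)
    out_dot = sum(rule(t, *primals) for rule, t in zip(rules, tangents))
    return out, out_dot


primals = (jnp.float32(2.0), jnp.float32(0.5))
tangents = (jnp.float32(1.0), jnp.float32(-3.0))
out, out_dot = combined_jvp(primals, tangents, (jvp_wrt_x, jvp_wrt_y))

# Reference result from JAX's own forward-mode differentiation.
ref_out, ref_dot = jax.jvp(f, primals, tangents)
```

Splitting the rule per input keeps each derivative small and lets an input without a meaningful tangent be handled separately from the others.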

general_batching_rule(prim, args, axes, **kwargs)

General-purpose batching rule for custom JAX primitives.

jaxtype_to_warptype(dtype)

Convert a JAX / NumPy dtype to the corresponding Warp scalar type.

jaxinfo_to_warpinfo(jax_info)

Convert a jax.ShapeDtypeStruct to a Warp array type descriptor.