brainevent.load_cuda_inline

brainevent.load_cuda_inline(name, cuda_sources, functions=None, *, extra_cuda_cflags=None, extra_ldflags=None, extra_include_paths=None, build_directory=None, verbose=False, compute_capability=None, force_rebuild=False, auto_register=True, target_prefix=None, ninja_workers=None, optimization_level=3, use_fast_math=False, allow_cuda_graph=True)

Compile inline CUDA source and load the resulting module.

Parameters:
  • name (str) – Module name (used for caching and FFI target naming).

  • cuda_sources (str | list[str]) – CUDA C++ source code. Multiple strings are concatenated.

  • functions (dict[str, list[str]] | None) –

    Mapping from function name to its arg_spec token list. Example: {"vector_add": ["arg", "arg", "ret", "stream"]}

    If None, functions are discovered from // @BE function_name annotations in the source code. The arg_spec is auto-inferred from the C++ signature.

  • extra_cuda_cflags (list[str] | None) – Additional flags passed to nvcc when compiling the CUDA sources.

  • extra_ldflags (list[str] | None) – Additional flags passed to the linker.

  • extra_include_paths (list[str] | None) – Additional directories added to the include search path.

  • build_directory (str | None) – Override the build directory.

  • verbose (bool) – Print detailed compilation output.

  • compute_capability (str | None) – GPU architecture (e.g. "sm_86"). Auto-detected if None.

  • force_rebuild (bool) – Skip cache and recompile.

  • auto_register (bool) – Automatically register each function as a JAX FFI target named "<target_prefix>.<func_name>", or "<name>.<func_name>" if target_prefix is None.

  • target_prefix (str | None) – Prefix for auto-registered FFI target names.

  • ninja_workers (int | None) – Number of parallel ninja workers (default: all CPUs).

  • optimization_level (int) – Compiler optimization level passed as -O<n> to nvcc (0–3). Applies to both host code and device PTX generation. Default: 3.

  • use_fast_math (bool) – Pass --use_fast_math to nvcc. Enables approximate division/sqrt, flush-to-zero for denormals, and fused multiply-add. Can give a 10–30% speed-up on FP-heavy kernels, at the cost of strict IEEE 754 accuracy. Default: False.

  • allow_cuda_graph (bool) – Register kernels with the COMMAND_BUFFER_COMPATIBLE XLA trait so they can be captured and replayed by JAX’s CUDA-graph optimisation. This eliminates per-call CPU launch overhead inside jax.lax.fori_loop or across repeated calls to a jax.jit-compiled function. Set to False for kernels whose handler performs host-side work that must run on every call, because a replayed graph re-issues only the captured GPU work. Default: True.
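To make the discovery and naming rules above concrete, the following plain-Python sketch shows which FFI target names one would expect when functions=None and auto_register=True. The helpers discover_functions and ffi_target_name are hypothetical illustrations, not part of brainevent. The regular expression assumes each annotation is a line comment of the form // @BE function_name, the newline separator used to join multiple sources is an assumption, and arg_spec inference from the C++ signature is not modeled.

```python
import re

# Hypothetical helpers; NOT part of the brainevent API.
# Assumption: annotations are line comments such as "// @BE vector_add".
_BE_ANNOTATION = re.compile(r"//[ \t]*@BE[ \t]+([A-Za-z_]\w*)")


def discover_functions(cuda_sources):
    """Return annotated function names in source order (arg_spec inference not modeled)."""
    if isinstance(cuda_sources, (list, tuple)):
        # Multiple strings are concatenated; joining with "\n" here is an assumption.
        cuda_sources = "\n".join(cuda_sources)
    return _BE_ANNOTATION.findall(cuda_sources)


def ffi_target_name(name, func_name, target_prefix=None):
    """Build "<target_prefix>.<func_name>", falling back to "<name>.<func_name>"."""
    prefix = target_prefix if target_prefix is not None else name
    return f"{prefix}.{func_name}"


source = """
// @BE vector_add
void vector_add(/* signature omitted */);

// @BE vector_scale
void vector_scale(/* signature omitted */);
"""

for func in discover_functions(source):
    print(ffi_target_name("my_kernels", func))
# my_kernels.vector_add
# my_kernels.vector_scale
```

Passing target_prefix="be" to ffi_target_name would instead yield be.vector_add and be.vector_scale, matching the fallback rule described for auto_register.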

Return type:

CompiledModule
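The compiler-related parameters (compute_capability, optimization_level, use_fast_math, extra_include_paths, extra_cuda_cflags) can be pictured as contributions to an nvcc command line. The sketch below is a hypothetical build_nvcc_flags helper, not brainevent's actual implementation; its flag ordering is an assumption, it requires compute_capability explicitly because auto-detection is not modeled, and it omits extra_ldflags, which would belong to the link step.

```python
# Hypothetical sketch: how the documented options could map to nvcc flags.
# NOT brainevent's implementation; the real command line may differ.
def build_nvcc_flags(
    compute_capability,
    *,
    optimization_level=3,
    use_fast_math=False,
    extra_include_paths=None,
    extra_cuda_cflags=None,
):
    if not 0 <= optimization_level <= 3:
        raise ValueError("optimization_level must be in the range 0..3")
    # -O<n> as described for optimization_level; -arch selects the GPU target.
    flags = [f"-O{optimization_level}", f"-arch={compute_capability}"]
    if use_fast_math:
        # Per the nvcc documentation, this implies
        # --ftz=true --prec-div=false --prec-sqrt=false --fmad=true.
        flags.append("--use_fast_math")
    # Each extra include directory becomes an -I flag.
    flags += [f"-I{path}" for path in (extra_include_paths or [])]
    # User-supplied flags are appended last, unchanged.
    flags += list(extra_cuda_cflags or [])
    return flags


print(build_nvcc_flags("sm_86", use_fast_math=True, extra_include_paths=["include"]))
# ['-O3', '-arch=sm_86', '--use_fast_math', '-Iinclude']
```

Appending extra_cuda_cflags last is one reasonable choice, since for many nvcc options a later occurrence takes precedence, but whether brainevent orders flags this way is not stated here.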