BrainX Ecosystem
04 — Capability

Powered by JAX

JIT compilation, GPU/TPU acceleration, and composable function transforms: the JAX compiler stack for neuroscience.

Figure: BrainX rides on JAX's compiler infrastructure (XLA compiler, GPU/TPU acceleration, function transforms).

Built on JAX, BrainX leverages high-performance numerical computing and automatic differentiation to enable efficient brain modeling at scale.

  • JIT compilation and XLA fusion for high-performance numerical simulation
  • GPU and TPU acceleration for large-scale neural networks
  • Composable transforms that let brain models plug into JAX-based AI frameworks
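The first two points can be illustrated with a minimal sketch in plain JAX. This is not BrainX's own API. A population of leaky integrate-and-fire (LIF) neurons is stepped with `jax.lax.scan` inside `jax.jit`. The function names `lif_step` and `simulate`, and every parameter value, are hypothetical and chosen only for illustration. The same code runs on CPU, GPU, or TPU, whichever backend JAX detects.

```python
import jax
import jax.numpy as jnp

# Illustrative LIF parameters (assumptions, not BrainX defaults).
TAU = 10.0     # membrane time constant (ms)
V_TH = 1.0     # spike threshold
V_RESET = 0.0  # reset potential
DT = 0.1       # integration step (ms)


def lif_step(v, i_ext):
    """One forward-Euler step for every neuron in the population at once."""
    v = v + DT / TAU * (-v + i_ext)
    spikes = v >= V_TH
    v = jnp.where(spikes, V_RESET, v)  # reset neurons that crossed threshold
    return v, spikes  # (new carry, per-step output)


@jax.jit
def simulate(v0, currents):
    """Scan lif_step over time; currents has shape (n_steps, n_neurons)."""
    v_final, spike_raster = jax.lax.scan(lif_step, v0, currents)
    return v_final, spike_raster


n_neurons, n_steps = 1_000, 1_000
v0 = jnp.zeros(n_neurons)
currents = jnp.full((n_steps, n_neurons), 1.5)  # constant supra-threshold drive
v_final, spikes = simulate(v0, currents)

print(jax.devices())           # whichever backend JAX found: CPU, GPU or TPU
print(spikes.sum(axis=0)[:5])  # spikes per neuron: 9 each (one every 110 steps)
```

The time loop is a `lax.scan` rather than a Python `for` loop. As a result, the step function is traced once and XLA compiles the whole simulation as a single program.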

Why JAX, specifically

Other frameworks also provide GPU acceleration. What JAX gives us is the combination of three things: composable transforms (jit, grad, vmap, scan) that nest cleanly inside one another; XLA, an optimizing compiler that fuses operators into larger kernels; and a Python-first, NumPy-style API that matches the way computational neuroscientists actually write models.
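A second sketch, again in plain JAX with hypothetical names and illustrative values, shows the transforms composing. `jax.grad` differentiates a loss through a `lax.scan` simulation of a leaky integrator. `jax.vmap` evaluates that gradient for a whole sweep of membrane time constants at once, and `jax.jit` compiles the composition. The target value and the choice to tune `tau` are assumptions made for the example.

```python
import jax
import jax.numpy as jnp

# Illustrative values (assumptions, not BrainX defaults).
DT = 0.1       # integration step (ms)
N_STEPS = 200  # 20 ms of simulated time
I_EXT = 1.0    # constant input drive (arbitrary units)
TARGET = 0.5   # hypothetical desired membrane value at the end of the run


def final_voltage(tau):
    """Leaky integrator dv/dt = (-v + I) / tau, stepped with lax.scan."""
    def step(v, _):
        return v + DT / tau * (-v + I_EXT), None

    v_final, _ = jax.lax.scan(step, jnp.zeros_like(tau), None, length=N_STEPS)
    return v_final


def loss(tau):
    """Squared error between the final voltage and the target."""
    return (final_voltage(tau) - TARGET) ** 2


# grad differentiates through the scan, vmap maps that gradient over a batch
# of time constants, and jit compiles the whole composition with XLA.
batched_grad = jax.jit(jax.vmap(jax.grad(loss)))

taus = jnp.linspace(5.0, 50.0, 8)  # sweep of membrane time constants (ms)
print(batched_grad(taus))          # negative for small tau, positive for large tau
```

The integrator deliberately has no spike threshold. A hard threshold is not differentiable, so gradient-based fitting of spiking models usually needs extra machinery such as surrogate gradients, which this sketch leaves out. You can also inspect what XLA did: `jax.jit(loss).lower(10.0).compile().as_text()` returns the optimized program text. Fused operations usually appear there as `fusion` blocks, although the exact output depends on the backend.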
