Installation

BrainState runs anywhere JAX runs. Install the variant that matches your hardware; the only difference is which JAX backend is pulled in.

Install

pip install -U "brainstate[cpu]"      # CPU only
pip install -U "brainstate[cuda12]"   # NVIDIA GPU, CUDA 12
pip install -U "brainstate[cuda13]"   # NVIDIA GPU, CUDA 13
pip install -U "brainstate[tpu]"      # TPU

The extras ([cpu], [cuda12], [cuda13], [tpu]) select the appropriate jax/jaxlib build. Plain pip install -U brainstate installs the library against whatever JAX is already present, which is convenient if you manage JAX yourself.

BrainState needs Python 3.10 or newer.
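If you are unsure what a plain install will build against, a short script can report whether the interpreter meets the minimum version and which JAX distributions are already present. This is a minimal sketch that uses only the standard library. The CUDA plugin distribution names are an assumption based on recent JAX releases and may differ for your JAX version.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Distributions to look for. The CUDA plugin names follow recent JAX
# releases; they are an assumption, so adjust them to your JAX version.
JAX_DISTRIBUTIONS = ("jax", "jaxlib", "jax-cuda12-plugin", "jax-cuda13-plugin")


def python_ok(minimum=(3, 10)):
    """Return True if the running interpreter is at least `minimum`."""
    return sys.version_info[:2] >= minimum


def installed_versions(names=JAX_DISTRIBUTIONS):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    print("Python >= 3.10:", python_ok())
    for name, ver in installed_versions().items():
        print(f"{name:>20}: {ver or 'not installed'}")
```

If jax and jaxlib are both reported, a plain install will reuse them; otherwise pick one of the extras above.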

The wider ecosystem

BrainState is one component of a brain-modeling ecosystem. To install it together with compatible versions of its companion packages (brainunit for physical units; braintools for optimizers, metrics, and initializers; brainpy for dynamical-systems modeling), install the bundle:

pip install -U BrainX

From source

To track the development version or contribute, install from the repository in editable mode:

git clone https://github.com/chaobrain/brainstate.git
cd brainstate
pip install -e .
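To confirm that Python imports BrainState from your clone rather than from an older copy in site-packages, you can ask the import system where a module would be loaded from. This is a minimal standard-library sketch; module_origin and is_loaded_from are hypothetical helper names, not part of BrainState.

```python
import importlib.util
from pathlib import Path


def module_origin(name):
    """Return the resolved file a top-level module would be loaded from, or None.

    Returns None if the module cannot be found or has no file location
    (for example, built-in modules such as `sys`).
    """
    spec = importlib.util.find_spec(name)
    if spec is None or not spec.has_location:
        return None
    return Path(spec.origin).resolve()


def is_loaded_from(name, root):
    """True if module `name` resolves to a file under directory `root`."""
    origin = module_origin(name)
    return origin is not None and origin.is_relative_to(Path(root).resolve())
```

Run from the repository root, is_loaded_from("brainstate", Path.cwd()) should return True after an editable install.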

Verify

Confirm the installation and check which backend JAX selected:

python -c "import brainstate; print(brainstate.__version__)"
python -c "import jax; print(jax.devices())"   # e.g. [CpuDevice(id=0)] or [CudaDevice(id=0)]

If jax.devices() lists only CPU devices on a machine with a GPU, JAX did not find a usable CUDA backend, usually because the CPU-only build was installed. Reinstall with the CUDA extra above that matches your system.
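The same check can be scripted, for example in a CI job or at the top of a notebook. The sketch below uses jax.default_backend(), which returns the platform name such as "cpu" or "gpu". Treating the presence of nvidia-smi on PATH as a sign that an NVIDIA GPU is available is a heuristic assumption, and check_gpu_backend is a hypothetical helper name.

```python
import importlib.util
import shutil


def check_gpu_backend():
    """Return a short diagnostic string about the backend JAX selected."""
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"
    import jax

    backend = jax.default_backend()  # e.g. "cpu", "gpu", or "tpu"
    # Heuristic (assumption): an NVIDIA driver usually puts nvidia-smi on PATH.
    has_nvidia_driver = shutil.which("nvidia-smi") is not None
    if backend == "cpu" and has_nvidia_driver:
        return ("NVIDIA driver found but JAX is using the CPU backend; "
                "reinstall with the matching CUDA extra")
    return f"JAX backend: {backend}"


if __name__ == "__main__":
    print(check_gpu_backend())
```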

Next steps