JIT Compilation

brainstate.transform.jit extends jax.jit with state tracking and additional runtime controls. This guide explains how BrainState JIT differs from plain JAX JIT, when to prefer each API, and how to decompose modules with brainstate.graph.treefy_split or brainstate.graph.treefy_states.

import jax
import jax.numpy as jnp

import brainstate

Why BrainState JIT?

brainstate.transform.jit understands State objects and automatically threads their reads and writes through the compiled function. The returned object is a JittedFunction that exposes helpers such as the compile and clear_cache methods and the origin_fun attribute. Pure functions still work, but stateful modules are first-class citizens.

@brainstate.transform.jit
def softplus(x: jax.Array) -> jax.Array:
    return jnp.log1p(jnp.exp(-jnp.abs(x))) + jnp.maximum(x, 0)

xs = jnp.linspace(-5.0, 5.0, 7)
softplus(xs)
Array([0.00671535, 0.03505242, 0.17300805, 0.69314724, 1.839675  ,
       3.368386  , 5.0067153 ], dtype=float32)
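The function above is written as log1p(exp(-|x|)) + max(x, 0) rather than the textbook log(1 + exp(x)). The two are mathematically equal, but the textbook form overflows in float32 once exp(x) exceeds the largest representable value (around x ≈ 88.7). A quick plain-JAX comparison, using the hypothetical helper names stable_softplus and naive_softplus and jax.nn.softplus as a reference:

```python
import jax
import jax.numpy as jnp


def stable_softplus(x):
    # log(1 + e^x) == log1p(e^{-|x|}) + max(x, 0); e^{-|x|} never exceeds 1
    return jnp.log1p(jnp.exp(-jnp.abs(x))) + jnp.maximum(x, 0)


def naive_softplus(x):
    # Textbook form: e^x overflows float32 to inf for x above roughly 88.7
    return jnp.log(1.0 + jnp.exp(x))


xs = jnp.array([-5.0, 0.0, 5.0, 100.0], dtype=jnp.float32)
stable = stable_softplus(xs)
naive = naive_softplus(xs)
reference = jax.nn.softplus(xs)

print(stable)  # finite everywhere; last entry is 100.0
print(naive)   # last entry is inf
```

Both forms agree wherever the naive one stays finite, so the rewrite costs nothing in accuracy.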

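Subsequent calls with the same input shapes and dtypes reuse the compiled executable. If you disable JIT globally (jax.config.update('jax_disable_jit', True)) or locally with the jax.disable_jit() context manager, BrainState automatically falls back to the original Python implementation.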

with jax.disable_jit():
    softplus(xs * 2.0)
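The caching and fallback behaviour can be observed with plain jax.jit, which brainstate.transform.jit builds on. The sketch below counts how often the Python body runs by incrementing a hypothetical global counter (trace_count) as a side effect: under jit the body runs once per new input signature, while under jax.disable_jit() it runs on every call. It uses only public JAX APIs and is an illustration, not BrainState's own implementation.

```python
import jax
import jax.numpy as jnp

# Hypothetical counter: incremented every time the Python body executes.
trace_count = 0


def softplus(x):
    global trace_count
    trace_count += 1  # side effect runs during tracing, or on every eager call
    return jnp.log1p(jnp.exp(-jnp.abs(x))) + jnp.maximum(x, 0)


jitted = jax.jit(softplus)
xs = jnp.linspace(-5.0, 5.0, 7)

jitted(xs)        # first call: trace and compile
jitted(xs * 2.0)  # same shape/dtype: reuses the cached executable
calls_under_jit = trace_count

with jax.disable_jit():
    jitted(xs)        # runs the Python body eagerly
    jitted(xs * 2.0)  # and again, with no caching
eager_calls = trace_count - calls_under_jit

print(calls_under_jit, eager_calls)  # 1 2
```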

Stateful modules with zero boilerplate

BrainState keeps modules stateful inside compiled code. Below, a running-mean tracker updates hidden state at each call without any manual intervention.

class RunningMean(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.sum = brainstate.HiddenState(jnp.array(0.0))
        self.count = brainstate.HiddenState(jnp.array(0))

    def __call__(self, batch: jax.Array) -> jax.Array:
        self.sum.value += jnp.sum(batch)
        self.count.value += batch.size
        return self.sum.value / self.count.value


tracker = RunningMean()

@brainstate.transform.jit
def update_running_mean(batch: jax.Array) -> jax.Array:
    return tracker(batch)

for step in range(3):
    data = jnp.arange(4.0) + step
    print(f'step {step}: mean={float(update_running_mean(data)):.2f}')

float(tracker.sum.value), int(tracker.count.value)
step 0: mean=1.50
step 1: mean=2.00
step 2: mean=2.50
(30.0, 12)

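The hidden states stay in sync because BrainState traces which State objects the function reads and writes, feeds their current values into the compiled executable as inputs, and writes the updated values back into those State objects after each call.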
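Conceptually, this is the functional state-passing pattern you would otherwise write by hand: current state values go in as explicit inputs, and updated values come out as explicit outputs that the caller stores for the next call. The plain-JAX sketch below reproduces the running-mean example that way; the dict layout and the name running_mean_step are illustrative assumptions, not BrainState internals.

```python
import jax
import jax.numpy as jnp


def running_mean_step(state, batch):
    # state is a plain dict PyTree carrying the running totals
    new_sum = state["sum"] + jnp.sum(batch)
    new_count = state["count"] + batch.size  # batch.size is static under tracing
    new_state = {"sum": new_sum, "count": new_count}
    return new_sum / new_count, new_state


step_fn = jax.jit(running_mean_step)

state = {"sum": jnp.array(0.0), "count": jnp.array(0)}
means = []
for step in range(3):
    data = jnp.arange(4.0) + step
    mean, state = step_fn(state, data)  # write the returned state back
    means.append(float(mean))

print(means)  # [1.5, 2.0, 2.5]
```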

Extra controls exposed by JittedFunction

Unlike bare jax.jit, BrainState’s wrapper exposes runtime helpers. You can precompile executables or drop cached traces explicitly.

softplus.compile(jnp.ones((4,)))
softplus(jnp.ones((4,)))
Array([1.3132617, 1.3132617, 1.3132617, 1.3132617], dtype=float32)
softplus.clear_cache()
softplus(jnp.linspace(-1.0, 1.0, 5))
Array([0.3132617, 0.474077 , 0.6931472, 0.974077 , 1.3132617], dtype=float32)
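For comparison, plain JAX performs ahead-of-time compilation through its lowering API rather than through a compile method on the jitted function. In the minimal sketch below, jax.jit(...).lower(...) traces for a concrete input signature and .compile() produces an executable that is then called directly. That executable is fixed to the signature it was compiled for, so a differently shaped input is rejected instead of triggering a retrace.

```python
import jax
import jax.numpy as jnp


def softplus(x):
    return jnp.log1p(jnp.exp(-jnp.abs(x))) + jnp.maximum(x, 0)


# Trace and lower for a concrete input signature, then compile ahead of time.
compiled = jax.jit(softplus).lower(jnp.ones((4,))).compile()

out = compiled(jnp.ones((4,)))  # each entry is softplus(1), about 1.3133
print(out)

# The compiled object never retraces: a new input shape is rejected.
try:
    compiled(jnp.ones((5,)))
    rejected = False
except (TypeError, ValueError):
    rejected = True
print(rejected)
```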

Working directly with jax.jit

If you prefer raw JAX primitives, you can still make modules jit-friendly by rewriting them as pure, stateless functions. brainstate.graph.treefy_split returns a GraphDef plus one or more state trees, which you must thread through the function manually.

model = RunningMean()

graph_def, hidden_state_tree = brainstate.graph.treefy_split(model, brainstate.HiddenState)


def running_mean_stateless(state_tree, batch):
    module = brainstate.graph.treefy_merge(graph_def, state_tree)
    out = module(batch)
    new_state_tree = brainstate.graph.treefy_states(module, brainstate.HiddenState)
    return out, new_state_tree


jax_jitted = jax.jit(running_mean_stateless)

state_tree = hidden_state_tree
for step in range(3):
    batch = jnp.arange(4.0) + step
    mean, state_tree = jax_jitted(state_tree, batch)
    print(f'step {step}: mean={float(mean):.2f}')

int(state_tree['count'].value), float(state_tree['sum'].value)
step 0: mean=1.50
step 1: mean=2.00
step 2: mean=2.50
(12, 30.0)

The JAX version works, but you are responsible for threading state containers and reconstructing modules yourself.
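The split/merge step follows the same idea as JAX's own pytree utilities: a static description of the structure (a treedef, playing the role of the GraphDef) is separated from the array leaves, and only the leaves cross the jit boundary. The sketch below uses jax.tree_util.tree_flatten and tree_unflatten on a plain dict as a stand-in for a module; it illustrates the analogy and is not how BrainState implements GraphDef.

```python
import jax
import jax.numpy as jnp

state = {"sum": jnp.array(0.0), "count": jnp.array(0)}

# "Split": hashable, static structure vs. a list of array leaves.
leaves, treedef = jax.tree_util.tree_flatten(state)


@jax.jit
def step(leaves, batch):
    s = jax.tree_util.tree_unflatten(treedef, leaves)  # "merge" inside the trace
    s = {"sum": s["sum"] + jnp.sum(batch), "count": s["count"] + batch.size}
    new_leaves, _ = jax.tree_util.tree_flatten(s)     # "split" the updated state
    return s["sum"] / s["count"], new_leaves


means = []
for i in range(3):
    mean, leaves = step(leaves, jnp.arange(4.0) + i)
    means.append(float(mean))

final = jax.tree_util.tree_unflatten(treedef, leaves)
print(means, float(final["sum"]), int(final["count"]))
```

Because the treedef is captured as a Python constant, only the leaves are traced; the same reasoning explains why the GraphDef can stay outside jax.jit in the example above.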

treefy_split vs treefy_states

Both helpers live in brainstate.graph but serve different purposes:

  • treefy_split → returns (graph_def, state_tree1, state_tree2, ...). Use it when you need to rebuild modules (e.g. JAX interop or serialising complete graphs).

  • treefy_states → returns one or more state trees without the graph definition. It’s the lightweight choice when you only need a PyTree of parameters for optimisation or checkpointing.
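As an illustration of the checkpointing use case, a parameter PyTree keyed by paths can be flattened into named arrays and written with NumPy. The sketch assumes a plain dict standing in for the tree that treefy_states would return, and uses jax.tree_util.tree_flatten_with_path with an in-memory np.savez archive; it is not a BrainState checkpoint format.

```python
import io

import jax
import jax.numpy as jnp
import numpy as np

# Plain dict standing in for a parameter tree (assumption: not a BrainState object).
params = {"weight": jnp.array([[1.0]]), "bias": jnp.array([0.0])}

# Flatten to {"path/to/leaf": ndarray}; dict paths consist of DictKey entries with a .key field.
path_leaves, _ = jax.tree_util.tree_flatten_with_path(params)
flat = {"/".join(str(k.key) for k in path): np.asarray(leaf) for path, leaf in path_leaves}

buffer = io.BytesIO()
np.savez(buffer, **flat)  # write the checkpoint into memory
buffer.seek(0)

with np.load(buffer) as archive:
    restored = {name: archive[name] for name in archive.files}

print(sorted(restored))  # ['bias', 'weight']
```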

See also BrainState Graph and Node System for more details on how to use these interfaces.

class TinyLinear(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = brainstate.ParamState(jnp.array([[1.0]]))
        self.bias = brainstate.ParamState(jnp.array([0.0]))

    def __call__(self, x: jax.Array) -> jax.Array:
        return x @ self.weight.value + self.bias.value


lin = TinyLinear()

# Split into graph + states (useful for reconstruction / JAX interop)
lin_graph, param_tree, other_states = brainstate.graph.treefy_split(
    lin, brainstate.ParamState, ...,
)
print('treefy_split param paths:', list(param_tree.to_flat().keys()))

# Fetch only the parameter tree (perfect for gradient updates)
params_only = brainstate.graph.treefy_states(lin, brainstate.ParamState)
print('treefy_states param paths:', list(params_only.to_flat().keys()))
treefy_split param paths: [('bias',), ('weight',)]
treefy_states param paths: [('bias',), ('weight',)]
# Example: compute gradients w.r.t. the ParamState tree using jax.value_and_grad
def mse_loss(params, others, x):
    lin_recovered = brainstate.graph.treefy_merge(lin_graph, params, others)
    pred = lin_recovered(x)
    target = 2.0 * x + 1.0
    return jnp.mean((pred - target) ** 2)

loss_grad = jax.value_and_grad(mse_loss)

(loss_value, grads) = loss_grad(param_tree, other_states, jnp.array([[0.0], [1.0]]))
print('loss:', float(loss_value))
for path, g in grads.items():
    print('grad', path, g)
loss: 2.5
grad bias TreefyState(
  type=<class 'brainstate.ParamState'>,
  value=Array([-3.], dtype=float32),
  tag=None
)
grad weight TreefyState(
  type=<class 'brainstate.ParamState'>,
  value=Array([[-2.]], dtype=float32),
  tag=None
)

treefy_states drops directly into optimisation pipelines: it yields a PyTree keyed by parameter paths, and you only need to carry the GraphDef if you plan to reconstruct the module elsewhere.
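To make the optimisation use case concrete, here is a plain-JAX sketch that treats a parameter dict as the PyTree, differentiates the same MSE loss with jax.value_and_grad, and applies one SGD step with jax.tree_util.tree_map. The dict keys, the learning rate of 0.1, and the name sgd_step are illustrative assumptions. Since jax.value_and_grad already treats the tree from treefy_split as a PyTree in the example above, the same tree_map update should apply to its leaves as well.

```python
import jax
import jax.numpy as jnp

# Plain dict standing in for a parameter PyTree (illustrative assumption).
params = {"weight": jnp.array([[1.0]]), "bias": jnp.array([0.0])}
x = jnp.array([[0.0], [1.0]])


def mse_loss(params, x):
    pred = x @ params["weight"] + params["bias"]
    target = 2.0 * x + 1.0
    return jnp.mean((pred - target) ** 2)


def sgd_step(params, grads, learning_rate=0.1):
    # Apply the same update rule to every leaf of the tree
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)


loss, grads = jax.value_and_grad(mse_loss)(params, x)
new_params = sgd_step(params, grads)
new_loss = mse_loss(new_params, x)

print(float(loss), grads["bias"], grads["weight"])  # loss 2.5, bias grad [-3.], weight grad [[-2.]]
print(float(new_loss))  # lower than the initial loss
```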

Static arguments still apply

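Static-argument handling mirrors jax.jit: each new value of a static argument triggers a fresh trace and compilation, while repeated values reuse the cached program. The example below specialises the compiled program by polynomial degree.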

@brainstate.transform.jit(static_argnums=1)
def polynomial_series(x: jax.Array, degree: int) -> jax.Array:
    powers = [x ** (i + 1) for i in range(degree)]
    coeffs = jnp.arange(1, degree + 1, dtype=x.dtype)
    return jnp.tensordot(coeffs, jnp.stack(powers, axis=0), axes=1)


p1 = polynomial_series(jnp.array([1.0, 2.0]), 3)
p2 = polynomial_series(jnp.array([1.0, 2.0]), 3)
p3 = polynomial_series(jnp.array([1.0, 2.0]), 4)
print(p1, p2, p3)
[ 6. 34.] [ 6. 34.] [10. 98.]
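The specialisation can be observed with plain jax.jit and the same static_argnums setting, which brainstate.transform.jit is described above as mirroring. A hypothetical global counter records how often the Python body runs: repeating degree=3 reuses the cached program, while degree=4 triggers a second trace.

```python
import jax
import jax.numpy as jnp

# Hypothetical counter: incremented each time JAX traces the Python body.
trace_count = 0


def polynomial_series(x, degree):
    global trace_count
    trace_count += 1
    powers = [x ** (i + 1) for i in range(degree)]  # degree must be a Python int
    coeffs = jnp.arange(1, degree + 1, dtype=x.dtype)
    return jnp.tensordot(coeffs, jnp.stack(powers, axis=0), axes=1)


poly = jax.jit(polynomial_series, static_argnums=1)
x = jnp.array([1.0, 2.0])

p1 = poly(x, 3)  # trace for degree=3
p2 = poly(x, 3)  # cache hit: same static value and input signature
p3 = poly(x, 4)  # new static value: retrace and recompile

print(trace_count)  # 2
```

Static values must be hashable, because they form part of the cache key.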

Which API should you choose?

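| Scenario | brainstate.transform.jit | jax.jit |
|---|---|---|
| Stateful BrainState modules | ✅ Zero boilerplate | ⚠️ Requires treefy_split and manual state threading |
| Pure stateless functions | ✅ Works (with helper methods) | ✅ Often the leanest choice |
| Need compile() / clear_cache() | ✅ Built-in | ❌ No compile() method; ahead-of-time compilation goes through jax.jit(f).lower(...).compile() |
| Custom sharding / device placement | ✅ Same signature as jax.jit | ✅ Native (in_shardings / out_shardings) |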

treefy_split is the workhorse when you need a GraphDef for reconstruction or JAX interop. treefy_states is the lightweight option for extracting parameter PyTrees, for example before differentiating with jax.value_and_grad or saving a checkpoint.

Summary

  • brainstate.transform.jit tracks BrainState State objects automatically and returns a JittedFunction with extra controls.

  • jax.jit still works, but you must explicitly split and merge module state.

  • graph.treefy_split produces (graph_def, state_tree1, state_tree2, …) for reconstruction; graph.treefy_states returns just the requested state trees.

  • Choose the interface that matches your workflow: use BrainState JIT for module-centric code, drop down to JAX primitives when integrating with other systems.