JIT Compilation
brainstate.transform.jit extends jax.jit with state tracking and extra control
surfaces. This guide highlights how BrainState JIT differs from plain JAX JIT,
when to prefer each API, and how to decompose modules with
brainstate.graph.treefy_split or brainstate.graph.treefy_states.
```python
import jax
import jax.numpy as jnp
import brainstate
```
Why BrainState JIT?
brainstate.transform.jit understands State objects and automatically wires
read/write traces into the compiled function. The returned object is a
JittedFunction with helper methods such as compile, clear_cache, and
origin_fun. Pure functions still work, but stateful modules are first-class
citizens.
```python
@brainstate.transform.jit
def softplus(x: jax.Array) -> jax.Array:
    # Numerically stable softplus: log(1 + exp(x))
    return jnp.log1p(jnp.exp(-jnp.abs(x))) + jnp.maximum(x, 0)

xs = jnp.linspace(-5.0, 5.0, 7)
softplus(xs)
```

```
Array([0.00671535, 0.03505242, 0.17300805, 0.69314724, 1.839675 ,
       3.368386 , 5.0067153 ], dtype=float32)
```
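Subsequent calls with the same input shapes and dtypes reuse the compiled
executable. If you disable JIT, either globally with
jax.config.update("jax_disable_jit", True) or locally with the
jax.disable_jit() context manager, BrainState falls back to the original
Python implementation automatically.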
```python
with jax.disable_jit():
    # Runs the original Python body eagerly instead of the compiled executable
    softplus(xs * 2.0)
```
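Both behaviours (cache reuse on repeated calls, eager execution when JIT is disabled) can be observed with a Python-side counter. The sketch below uses plain jax.jit rather than brainstate.transform.jit so that it runs with JAX alone; the calls dictionary and softplus_counted are illustrative names, and we assume, as the text above states, that BrainState's wrapper behaves the same way.

```python
import jax
import jax.numpy as jnp

# Python-side counter: it only increments when the Python body actually runs,
# i.e. while JAX traces the function, or on every call when JIT is disabled.
calls = {"n": 0}

def softplus_counted(x):
    calls["n"] += 1
    return jnp.log1p(jnp.exp(-jnp.abs(x))) + jnp.maximum(x, 0)

jitted = jax.jit(softplus_counted)
xs = jnp.linspace(-5.0, 5.0, 7)

jitted(xs)          # first call: traces the Python body once, then compiles
jitted(xs + 1.0)    # same shape and dtype: reuses the cached executable
after_jit = calls["n"]

with jax.disable_jit():
    jitted(xs)      # JIT disabled: the Python body runs eagerly every time
    jitted(xs)
after_eager = calls["n"]

print(after_jit, after_eager)
```

Under JIT the body ran once for two calls; with JIT disabled it ran on each of the two extra calls.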
Stateful modules with zero boilerplate
BrainState keeps modules stateful inside compiled code. Below, a running-mean tracker updates hidden state at each call without any manual intervention.
```python
class RunningMean(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        # Running totals live in HiddenState objects so BrainState can track them
        self.sum = brainstate.HiddenState(jnp.array(0.0))
        self.count = brainstate.HiddenState(jnp.array(0))

    def __call__(self, batch: jax.Array) -> jax.Array:
        self.sum.value += jnp.sum(batch)
        self.count.value += batch.size
        return self.sum.value / self.count.value

tracker = RunningMean()

@brainstate.transform.jit
def update_running_mean(batch: jax.Array) -> jax.Array:
    # No explicit state threading: the tracker's states are picked up while tracing
    return tracker(batch)

for step in range(3):
    data = jnp.arange(4.0) + step
    print(f'step {step}: mean={float(update_running_mean(data)):.2f}')

float(tracker.sum.value), int(tracker.count.value)
```

```
step 0: mean=1.50
step 1: mean=2.00
step 2: mean=2.50
(30.0, 12)
```
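The hidden states remain in sync because BrainState records which State objects the function reads and writes while tracing, feeds their current values into the compiled executable, and writes the updated values back after each call.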
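To make the record-and-replay idea concrete, here is a deliberately simplified sketch in plain JAX. ToyState and toy_stateful_jit are hypothetical names, not BrainState APIs, and unlike BrainState the states are passed in explicitly instead of being discovered during tracing. It only illustrates the pattern: feed state values into a pure compiled function, then write the returned values back into the mutable containers.

```python
import jax
import jax.numpy as jnp

class ToyState:
    """Hypothetical mutable box standing in for a BrainState State."""
    def __init__(self, value):
        self.value = value

def toy_stateful_jit(fn, states):
    """Hypothetical wrapper that threads `states` through a pure jitted function."""
    def pure(state_values, *args):
        saved = [s.value for s in states]
        # Swap the traced values into the boxes so fn can read and write them.
        for s, v in zip(states, state_values):
            s.value = v
        out = fn(*args)
        new_values = [s.value for s in states]
        # Restore the concrete values so no tracers leak out of the trace.
        for s, v in zip(states, saved):
            s.value = v
        return out, new_values

    compiled = jax.jit(pure)

    def wrapper(*args):
        out, new_values = compiled([s.value for s in states], *args)
        # Write the updated values back after the compiled call.
        for s, v in zip(states, new_values):
            s.value = v
        return out

    return wrapper

total = ToyState(jnp.array(0.0))
count = ToyState(jnp.array(0))

def running_mean(batch):
    total.value = total.value + jnp.sum(batch)
    count.value = count.value + batch.size
    return total.value / count.value

step_fn = toy_stateful_jit(running_mean, [total, count])
means = [float(step_fn(jnp.arange(4.0) + i)) for i in range(3)]
print(means, float(total.value), int(count.value))
```

The means and final totals match the RunningMean example, even though fn itself mutates Python objects that jax.jit alone could not track.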
Extra controls exposed by JittedFunction
Unlike the function returned by bare jax.jit, BrainState's JittedFunction
exposes runtime helpers as methods: you can precompile an executable for
example inputs or drop cached traces explicitly.
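```python
# Compile ahead of time for this input shape/dtype, then call it
softplus.compile(jnp.ones((4,)))
softplus(jnp.ones((4,)))
```

```
Array([1.3132617, 1.3132617, 1.3132617, 1.3132617], dtype=float32)
```

```python
# Drop the cached traces; the next call traces and compiles again
softplus.clear_cache()
softplus(jnp.linspace(-1.0, 1.0, 5))
```

```
Array([0.3132617, 0.474077 , 0.6931472, 0.974077 , 1.3132617], dtype=float32)
```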
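For comparison, plain JAX reaches ahead-of-time compilation through its lowering API rather than a compile method on the jitted function. The sketch below uses jax.jit(...).lower(...).compile(); jax.ShapeDtypeStruct describes the input by shape and dtype alone, so no real data is needed to compile.

```python
import jax
import jax.numpy as jnp

def softplus(x):
    # Numerically stable softplus, same formula as above
    return jnp.log1p(jnp.exp(-jnp.abs(x))) + jnp.maximum(x, 0)

# Describe the input abstractly: shape (4,), float32
spec = jax.ShapeDtypeStruct((4,), jnp.float32)

lowered = jax.jit(softplus).lower(spec)  # trace and lower for this signature
compiled = lowered.compile()             # compile for the current backend

y = compiled(jnp.ones((4,), dtype=jnp.float32))  # run the precompiled executable
print(y)
```

The compiled object is tied to the shape and dtype it was lowered for, so it is meant to be called only with matching arguments.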
Working directly with jax.jit
If you prefer raw JAX primitives, you can still make modules jit-friendly by
separating their structure from their state and wrapping the module call in a
pure, stateless function. brainstate.graph.treefy_split returns a GraphDef plus
one or more state trees, which you must thread through the function manually.
```python
model = RunningMean()
# Separate the static structure (graph_def) from the HiddenState values
graph_def, hidden_state_tree = brainstate.graph.treefy_split(model, brainstate.HiddenState)

def running_mean_stateless(state_tree, batch):
    # Rebuild a module from the structure and the incoming state values
    module = brainstate.graph.treefy_merge(graph_def, state_tree)
    out = module(batch)
    # Extract the updated states so they can be returned as outputs
    new_state_tree = brainstate.graph.treefy_states(module, brainstate.HiddenState)
    return out, new_state_tree

jax_jitted = jax.jit(running_mean_stateless)

state_tree = hidden_state_tree
for step in range(3):
    batch = jnp.arange(4.0) + step
    # Thread the state explicitly: pass the old tree in, keep the new tree
    mean, state_tree = jax_jitted(state_tree, batch)
    print(f'step {step}: mean={float(mean):.2f}')

int(state_tree['count'].value), float(state_tree['sum'].value)
```

```
step 0: mean=1.50
step 1: mean=2.00
step 2: mean=2.50
(12, 30.0)
```
The JAX version works, but you are responsible for threading state containers and reconstructing modules yourself.
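Once state is an explicit argument, JAX itself can also do the threading. The plain-JAX sketch below replaces the BrainState module with a hand-written dict of state (running_mean_step and run_all are illustrative names) and uses jax.lax.scan to process all three batches inside a single compiled call instead of a Python loop.

```python
import jax
import jax.numpy as jnp

def running_mean_step(state, batch):
    """Pure step: takes the old state dict and returns (new_state, mean)."""
    new_state = {
        "sum": state["sum"] + jnp.sum(batch),
        "count": state["count"] + batch.size,
    }
    return new_state, new_state["sum"] / new_state["count"]

# Explicit dtypes keep the scan carry types identical across iterations.
init = {"sum": jnp.zeros((), jnp.float32), "count": jnp.zeros((), jnp.int32)}
batches = jnp.stack([jnp.arange(4.0) + step for step in range(3)])  # shape (3, 4)

# lax.scan carries the state dict from one batch to the next.
run_all = jax.jit(lambda s, xs: jax.lax.scan(running_mean_step, s, xs))
final_state, means = run_all(init, batches)
print(means, final_state)
```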
treefy_split vs treefy_states
Both helpers live in brainstate.graph but serve different purposes:

- treefy_split → returns (graph_def, state_tree1, state_tree2, ...). Use it when you need to rebuild modules (e.g. JAX interop or serialising complete graphs).
- treefy_states → returns one or more state trees without the graph definition. It's the lightweight choice when you only need a PyTree of parameters for optimisation or checkpointing.

See also BrainState Graph and Node System for more details on how to use these interfaces.
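If you already know JAX PyTrees, a loose analogy may help: jax.tree_util.tree_flatten separates a container into its structure (a treedef) and its leaf values, and tree_unflatten rebuilds it, much as treefy_split separates a module into a GraphDef and state trees that treefy_merge can recombine. This is only an analogy in plain JAX; a GraphDef describes a module graph, which carries more information than a treedef.

```python
import jax
import jax.numpy as jnp

# A plain dict of arrays standing in for a module's parameters
params = {"weight": jnp.array([[1.0]]), "bias": jnp.array([0.0])}

# "Split": structure (treedef) and values (leaves) come back separately
leaves, treedef = jax.tree_util.tree_flatten(params)

# Work on the values alone, e.g. scale every leaf
new_leaves = [2.0 * leaf for leaf in leaves]

# "Merge": rebuild a container with the original structure and the new values
rebuilt = jax.tree_util.tree_unflatten(treedef, new_leaves)
print(rebuilt)
```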
```python
class TinyLinear(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = brainstate.ParamState(jnp.array([[1.0]]))
        self.bias = brainstate.ParamState(jnp.array([0.0]))

    def __call__(self, x: jax.Array) -> jax.Array:
        return x @ self.weight.value + self.bias.value

lin = TinyLinear()

# Split into graph + states (useful for reconstruction / JAX interop).
# The trailing `...` filter collects every state not matched by ParamState.
lin_graph, param_tree, other_states = brainstate.graph.treefy_split(
    lin, brainstate.ParamState, ...,
)
print('treefy_split param paths:', list(param_tree.to_flat().keys()))

# Fetch only the parameter tree (perfect for gradient updates)
params_only = brainstate.graph.treefy_states(lin, brainstate.ParamState)
print('treefy_states param paths:', list(params_only.to_flat().keys()))
```

```
treefy_split param paths: [('bias',), ('weight',)]
treefy_states param paths: [('bias',), ('weight',)]
```
```python
# Example: compute gradients w.r.t. the ParamState tree using jax.value_and_grad
def mse_loss(params, others, x):
    # Rebuild the module from the graph definition and the (traced) states
    lin_recovered = brainstate.graph.treefy_merge(lin_graph, params, others)
    pred = lin_recovered(x)
    target = 2.0 * x + 1.0
    return jnp.mean((pred - target) ** 2)

# Differentiates w.r.t. the first argument (the parameter tree) only
loss_grad = jax.value_and_grad(mse_loss)
(loss_value, grads) = loss_grad(param_tree, other_states, jnp.array([[0.0], [1.0]]))
print('loss:', float(loss_value))
for path, g in grads.items():
    print('grad', path, g)
```

```
loss: 2.5
grad bias TreefyState(
  type=<class 'brainstate.ParamState'>,
  value=Array([-3.], dtype=float32),
  tag=None
)
grad weight TreefyState(
  type=<class 'brainstate.ParamState'>,
  value=Array([[-2.]], dtype=float32),
  tag=None
)
```
treefy_states drops directly into optimisation pipelines: you obtain a PyTree
keyed by parameter paths without carrying the GraphDef unless you plan to
reconstruct the module elsewhere.
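Because the result is an ordinary PyTree, standard JAX tree utilities apply to it. The sketch below shows one possible optimisation loop in plain JAX: a dict with the same weight/bias layout stands in for the tree returned by treefy_states, the loss is the TinyLinear MSE from above, and the learning rate and step count are illustrative choices, not values from BrainState.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # illustrative choice

# Plain dict standing in for the parameter tree of TinyLinear
params = {"weight": jnp.array([[1.0]]), "bias": jnp.array([0.0])}
x = jnp.array([[0.0], [1.0]])
target = 2.0 * x + 1.0

def mse_loss(p, x, target):
    pred = x @ p["weight"] + p["bias"]
    return jnp.mean((pred - target) ** 2)

@jax.jit
def sgd_step(p, x, target):
    loss, grads = jax.value_and_grad(mse_loss)(p, x, target)
    # grads has the same structure as p, so tree_map pairs them leaf by leaf
    new_p = jax.tree_util.tree_map(lambda w, g: w - LEARNING_RATE * g, p, grads)
    return new_p, loss

losses = []
for _ in range(200):
    params, loss = sgd_step(params, x, target)
    losses.append(float(loss))  # loss measured before each update

print(losses[0], losses[-1], params)
```

The first recorded loss matches the 2.5 reported above, and plain gradient descent drives weight and bias towards 2 and 1, the exact fit for these two points.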
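Static arguments still apply
Static-argument handling mirrors jax.jit: arguments listed in static_argnums
are treated as compile-time constants, so each distinct value gets its own
compiled program. The example below specialises the compiled program by
polynomial degree.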
```python
@brainstate.transform.jit(static_argnums=1)
def polynomial_series(x: jax.Array, degree: int) -> jax.Array:
    # degree is static, so this Python loop is unrolled at trace time
    powers = [x ** (i + 1) for i in range(degree)]
    coeffs = jnp.arange(1, degree + 1, dtype=x.dtype)
    # Weighted sum: sum_i coeffs[i] * x ** (i + 1)
    return jnp.tensordot(coeffs, jnp.stack(powers, axis=0), axes=1)

p1 = polynomial_series(jnp.array([1.0, 2.0]), 3)
p2 = polynomial_series(jnp.array([1.0, 2.0]), 3)  # same degree: reuses the compiled program
p3 = polynomial_series(jnp.array([1.0, 2.0]), 4)  # new degree: compiles a new program
print(p1, p2, p3)
```

```
[ 6. 34.] [ 6. 34.] [10. 98.]
```
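The same specialisation can be observed with plain jax.jit. The sketch uses functools.partial, the usual way to pass static_argnums when applying jax.jit as a decorator; the trace_count dictionary is an illustrative addition that increments only while JAX traces the function, so it counts one trace per distinct degree.

```python
import functools

import jax
import jax.numpy as jnp

trace_count = {"n": 0}

@functools.partial(jax.jit, static_argnums=1)
def polynomial_series(x, degree):
    trace_count["n"] += 1  # Python side effect: runs only during tracing
    powers = [x ** (i + 1) for i in range(degree)]
    coeffs = jnp.arange(1, degree + 1, dtype=x.dtype)
    return jnp.tensordot(coeffs, jnp.stack(powers, axis=0), axes=1)

x = jnp.array([1.0, 2.0])
p1 = polynomial_series(x, 3)
p2 = polynomial_series(x, 3)  # cache hit: degree 3 was already traced
p3 = polynomial_series(x, 4)  # new static value: traced and compiled again
print(p1, p3, trace_count["n"])
```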
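Which API should you choose?

| Scenario | brainstate.transform.jit | jax.jit |
|---|---|---|
| Stateful BrainState modules | ✅ Zero boilerplate | ⚠️ Requires manually splitting and merging module state |
| Pure stateless functions | ✅ Works (with helper methods) | ✅ Often the leanest choice |
| Need JittedFunction helpers (compile, clear_cache, origin_fun) | ✅ Built-in | ❌ Not available (ahead-of-time compilation goes through .lower(...).compile() instead) |
| Custom sharding / device placement | ✅ Same signature as jax.jit | ✅ |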
treefy_split is the workhorse when you need a GraphDef for reconstruction or
JAX interop. treefy_states is the light option for extracting parameter
PyTrees, for example before calling brainstate.transform.grad or saving a
checkpoint.
Summary

- brainstate.transform.jit tracks BrainState State objects automatically and returns a JittedFunction with extra controls.
- jax.jit still works, but you must explicitly split and merge module state.
- graph.treefy_split produces (graph_def, state_tree1, state_tree2, …) for reconstruction; graph.treefy_states returns just the requested state trees.
- Choose the interface that matches your workflow: use BrainState JIT for module-centric code, and drop down to JAX primitives when integrating with other systems.