IR Optimization and Code Generation
Every BrainState transformation ultimately lowers your code to a jaxpr — JAX’s typed intermediate representation. BrainState exposes that IR and a small toolkit for working with it: inspect the computation graph, run optimization passes over it, and regenerate readable Python from it. These tools are useful for understanding what the compiler sees, verifying that state reads and writes are tracked correctly, and squeezing redundant work out of a hot path.
import jax.numpy as jnp
import brainstate
import brainstate.transform as T
brainstate.random.seed(0)
brainstate.__version__
'0.4.0'
Inspecting the IR with make_jaxpr
make_jaxpr(fn)(args) traces fn and returns its ClosedJaxpr together with the tuple of
States it touched. Unlike jax.make_jaxpr, it is state-aware: state reads appear as extra
inputs and state writes as extra outputs of the jaxpr.
def pure(x):
    return jnp.sum(x ** 2)
jaxpr, states = T.make_jaxpr(pure)(jnp.array([1.0, 2.0, 3.0]))
print('states touched:', len(states))
print(jaxpr)
states touched: 0
{ lambda ; a:f32[3]. let
    b:f32[3] = integer_pow[y=2] a
    c:f32[] = reduce_sum[axes=(0,) out_sharding=None] b
  in (c,) }
With a stateful function the difference is visible: the state value enters as an input and the
updated value leaves as an output, making the data flow through State explicit.
counter = brainstate.State(jnp.array(0.0))
def stateful(x):
    counter.value = counter.value + x
    return counter.value * 2
jaxpr, states = T.make_jaxpr(stateful)(jnp.array(1.0))
print('states touched:', len(states))
print(jaxpr)
states touched: 1
{ lambda ; a:f32[] b:f32[]. let
    c:f32[] = add b a
    d:f32[] = mul c 2.0:f32[]
  in (d, c) }
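The jaxpr above has the shape of a pure function in which the state is threaded through explicitly. Its inputs are (a = x, b = the state value) and its outputs are (d = the result, c = the updated state). The following is a minimal plain-Python sketch of that rewrite, often called functionalization. Box is a hypothetical stand-in for brainstate.State. The names toy_stateful, toy_stateful_pure and call_pure are illustrative only and are not brainstate API.

```python
class Box:
    """Hypothetical mutable cell standing in for brainstate.State."""
    def __init__(self, value):
        self.value = value

toy_counter = Box(0.0)

def toy_stateful(x):
    # Impure version: reads and writes a box captured from the enclosing scope.
    toy_counter.value = toy_counter.value + x
    return toy_counter.value * 2

def toy_stateful_pure(x, counter_value):
    # Pure version shaped like the jaxpr: inputs (x, state), outputs (result, new state).
    c = counter_value + x
    d = c * 2
    return d, c

def call_pure(x, box):
    # Thread the state in, run the pure function, then write the new value back.
    result, new_value = toy_stateful_pure(x, box.value)
    box.value = new_value
    return result

other = Box(0.0)
print(toy_stateful(1.0), call_pure(1.0, other))  # 2.0 2.0
print(toy_counter.value, other.value)            # 1.0 1.0
```

Because the pure version never touches anything outside its arguments, a compiler can trace, optimize and cache it freely. The writing back of state happens outside the traced computation.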
Optimizing a jaxpr
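optimize_jaxpr applies classic compiler passes to a jaxpr. Passes are selected by name through the optimizations argument; the available passes are:

| Pass | Effect |
|---|---|
| dce | Dead-code elimination: drop equations whose outputs are unused |
| cse | Common-subexpression elimination: reuse the result of a repeated computation |
| constant_fold | Constant folding: evaluate operations on known constants ahead of time |
| algebraic_simplification | Apply algebraic identities, such as removing a multiply-by-one |
| copy_propagation | Copy propagation: remove redundant copies |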
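The example that follows exercises dce, algebraic_simplification and cse. The other two passes, constant folding and copy propagation, can be sketched on a toy IR. This is a minimal illustration under stated assumptions and is not brainstate's implementation. The Eqn record (an output name, a primitive name, and a tuple of arguments that are either variable names or float constants) is hypothetical, as are the functions constant_fold and copy_propagate. Real jaxpr equations also carry types, parameters and effects, which this sketch ignores.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Eqn:
    out: str     # variable this equation defines
    op: str      # 'add', 'mul', or 'copy'
    args: tuple  # variable names (str) or float constants

# Evaluators for the foldable toy primitives.
OPS = {'add': lambda a, b: a + b, 'mul': lambda a, b: a * b}

def constant_fold(eqns, outputs):
    """Evaluate equations whose arguments are all known constants."""
    known = {}  # folded variable name -> constant value
    def sub(a):
        return known.get(a, a) if isinstance(a, str) else a
    kept = []
    for e in eqns:
        args = tuple(sub(a) for a in e.args)
        if e.op in OPS and all(isinstance(a, float) for a in args):
            known[e.out] = OPS[e.op](*args)  # computed now; the equation is dropped
        else:
            kept.append(Eqn(e.out, e.op, args))
    return kept, tuple(sub(o) for o in outputs)

def copy_propagate(eqns, outputs):
    """Replace uses of y (where y = copy x) with x, then drop the copy."""
    alias = {}  # copied variable -> original variable
    def sub(a):
        return alias.get(a, a) if isinstance(a, str) else a
    kept = []
    for e in eqns:
        args = tuple(sub(a) for a in e.args)
        if e.op == 'copy':
            alias[e.out] = args[0]
        else:
            kept.append(Eqn(e.out, e.op, args))
    return kept, tuple(sub(o) for o in outputs)

program = [
    Eqn('a', 'mul', (2.0, 3.0)),  # constant operands: folds to 6.0
    Eqn('b', 'copy', ('x',)),     # b is just another name for x
    Eqn('c', 'add', ('b', 'a')),  # becomes x + 6.0
]
eqns, outs = constant_fold(program, ('c',))
eqns, outs = copy_propagate(eqns, outs)
print(eqns, outs)  # [Eqn(out='c', op='add', args=('x', 6.0))] ('c',)
```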
The function below contains deliberate waste: an unused product and a multiply-by-one.
def wasteful(x):
    a = x + 1.0
    unused = x * 999.0  # dead code
    scaled = a * 1.0    # algebraic identity
    return scaled + a
jaxpr, _ = T.make_jaxpr(wasteful)(jnp.array([1.0, 2.0]))
before = len(jaxpr.jaxpr.eqns)
optimized = T.optimize_jaxpr(jaxpr, optimizations=['dce', 'algebraic_simplification', 'cse'])
after = len(optimized.jaxpr.eqns)
print(f'equations: {before} -> {after}')
print(optimized)
equations: 4 -> 2
{ lambda ; a:f32[2]. let
    b:f32[2] = add a 1.0:f32[]
    c:f32[2] = add b b
  in (c,) }
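The 4 -> 2 reduction can be reproduced by hand on a toy IR, which makes each pass concrete. This is a minimal sketch and not brainstate's implementation. Eqn, dce, algebraic_simplify, cse and optimize are hypothetical names. The equation list is a hand transcription of wasteful, not something produced by make_jaxpr. The passes also assume every primitive is pure, which is what makes dropping and merging equations safe.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Eqn:
    out: str     # variable this equation defines
    op: str      # primitive name, e.g. 'add' or 'mul'
    args: tuple  # variable names (str) or float constants

def _rename(items, alias):
    # Substitute aliased variable names; constants pass through unchanged.
    return tuple(alias.get(a, a) if isinstance(a, str) else a for a in items)

def dce(eqns, outputs):
    """Dead-code elimination: walk backwards, keeping only live equations."""
    live = {o for o in outputs if isinstance(o, str)}
    kept = []
    for e in reversed(eqns):
        if e.out in live:
            kept.append(e)
            live.update(a for a in e.args if isinstance(a, str))
    return kept[::-1], tuple(outputs)

# Neutral element of each binary primitive: x * 1.0 == x, x + 0.0 == x.
IDENTITY = {'mul': 1.0, 'add': 0.0}

def algebraic_simplify(eqns, outputs):
    """Replace y = x op unit (or unit op x) by aliasing y to x."""
    alias, kept = {}, []
    for e in eqns:
        args = _rename(e.args, alias)
        unit = IDENTITY.get(e.op)
        if unit is not None and len(args) == 2:
            lhs, rhs = args
            if rhs == unit and isinstance(lhs, str):
                alias[e.out] = lhs
                continue
            if lhs == unit and isinstance(rhs, str):
                alias[e.out] = rhs
                continue
        kept.append(Eqn(e.out, e.op, args))
    return kept, _rename(outputs, alias)

def cse(eqns, outputs):
    """Common-subexpression elimination: reuse an earlier identical (op, args)."""
    seen, alias, kept = {}, {}, []
    for e in eqns:
        args = _rename(e.args, alias)
        key = (e.op, args)
        if key in seen:
            alias[e.out] = seen[key]
        else:
            seen[key] = e.out
            kept.append(Eqn(e.out, e.op, args))
    return kept, _rename(outputs, alias)

PASSES = {'dce': dce, 'algebraic_simplification': algebraic_simplify, 'cse': cse}

def optimize(eqns, outputs, optimizations):
    # Apply the named passes in the order given.
    for name in optimizations:
        eqns, outputs = PASSES[name](eqns, outputs)
    return eqns, outputs

# Hand transcription of wasteful (input x, output e).
wasteful_ir = [
    Eqn('b', 'add', ('x', 1.0)),
    Eqn('c', 'mul', ('x', 999.0)),  # unused: removed by dce
    Eqn('d', 'mul', ('b', 1.0)),    # multiply-by-one: removed by simplification
    Eqn('e', 'add', ('d', 'b')),
]
eqns, outs = optimize(wasteful_ir, ('e',), ['dce', 'algebraic_simplification', 'cse'])
print(f'equations: {len(wasteful_ir)} -> {len(eqns)}')  # equations: 4 -> 2
```

Dead-code elimination drops the unused multiply by 999.0. The algebraic pass then turns the multiply-by-one into an alias, which leaves the same two additions as the optimized jaxpr above. CSE finds nothing to merge here. It pays off only when the same (op, args) pair appears twice.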
Generating Python from a jaxpr
jaxpr_to_python_code turns a jaxpr back into readable Python source, which is a convenient way to see the
effect of an optimization pass. fn_to_python_code is the one-step convenience: it traces a
function with the given example arguments and returns its generated source as a string.
print(T.jaxpr_to_python_code(optimized.jaxpr, fn_name='optimized_fn'))
def optimized_fn(a):
    b = a + 1.0
    c = b + b
    return c
print(T.fn_to_python_code(pure, jnp.array([1.0, 2.0, 3.0])))
def pure(a):
    b = jax.lax.integer_pow(a, 2)
    c = jax.numpy.sum(b, axis=(0,))
    return c
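For straight-line code, generating source from an IR comes down to emitting one assignment per equation. The sketch below does this for a toy IR. Eqn, INFIX and to_python_code are hypothetical names. Real jaxprs also need primitive-specific spellings, such as integer_pow becoming jax.lax.integer_pow in the output above, and this sketch does not attempt them.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Eqn:
    out: str     # variable this equation defines
    op: str      # primitive name
    args: tuple  # variable names (str) or float constants

# How each toy primitive is spelled in the emitted source (an assumption).
INFIX = {'add': '+', 'mul': '*'}

def to_python_code(inputs, eqns, outputs, fn_name='fn'):
    """Emit one assignment per binary equation, in program order."""
    def atom(a):
        return a if isinstance(a, str) else repr(a)  # variable name vs. literal
    lines = [f"def {fn_name}({', '.join(inputs)}):"]
    for e in eqns:
        lhs, rhs = (atom(a) for a in e.args)
        lines.append(f"    {e.out} = {lhs} {INFIX[e.op]} {rhs}")
    lines.append(f"    return {', '.join(outputs)}")
    return '\n'.join(lines)

src = to_python_code(
    ['a'],
    [Eqn('b', 'add', ('a', 1.0)), Eqn('c', 'add', ('b', 'b'))],
    ['c'],
    fn_name='optimized_fn',
)
print(src)

namespace = {}
exec(src, namespace)  # compile the generated source back into a callable
print(namespace['optimized_fn'](2.0))  # 6.0
```

Running the generated source back through exec is a cheap round-trip check that the emitted code means the same thing as the IR it came from.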
StatefulFunction: the engine behind the transforms
StatefulFunction is the lower-level wrapper that every state-aware transform builds on. It
traces a function once and then answers precise questions about it: which states it reads,
which it writes, and what its jaxpr looks like. As the output below shows, a state that is both
read and written, such as counter in stateful, is reported among the write states rather than
the read states. Construct it with ir_optimizations to apply optimization passes automatically
during tracing.
sf = T.StatefulFunction(stateful, ir_optimizations=['dce', 'cse'])
sf.make_jaxpr(jnp.array(1.0))
print('read states :', len(sf.get_read_states(jnp.array(1.0))))
print('write states:', len(sf.get_write_states(jnp.array(1.0))))
# Execute through the traced jaxpr, automatically threading state in and out.
counter.value = jnp.array(0.0)
print('jaxpr call result:', float(sf.jaxpr_call_auto(jnp.array(3.0))))
read states : 0
write states: 1
jaxpr call result: 6.0
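The output counts counter only as a write state, even though stateful also reads it. Below is a minimal sketch of how a tracer might record state accesses and classify them that way. Tracer and ToyState are hypothetical classes. The rule used here is an assumption inferred from the output above, not a description of brainstate's internals: written states are reported as writes, and the remaining accessed states as reads.

```python
class Tracer:
    """Records which (hypothetical) states were read and written."""
    def __init__(self):
        self.reads, self.writes = set(), set()

    def read_states(self):
        # Assumed rule: a state that was written is reported only as a write.
        return self.reads - self.writes

    def write_states(self):
        return set(self.writes)

class ToyState:
    """Hypothetical stand-in for a State that reports accesses to a tracer."""
    def __init__(self, name, value, tracer):
        self.name, self._value, self._tracer = name, value, tracer

    @property
    def value(self):
        self._tracer.reads.add(self.name)
        return self._value

    @value.setter
    def value(self, new_value):
        self._tracer.writes.add(self.name)
        self._value = new_value

tracer = Tracer()
tracked_counter = ToyState('counter', 0.0, tracer)
tracked_scale = ToyState('scale', 2.0, tracer)  # only ever read

def tracked_fn(x):
    tracked_counter.value = tracked_counter.value + x
    return tracked_counter.value * tracked_scale.value

print(tracked_fn(3.0))  # 6.0
print(sorted(tracer.read_states()), sorted(tracer.write_states()))  # ['scale'] ['counter']
```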
Summary
- make_jaxpr exposes the state-aware IR: reads become inputs, writes become outputs.
- optimize_jaxpr runs dce, cse, constant_fold, algebraic_simplification, and copy_propagation passes to shrink a jaxpr.
- jaxpr_to_python_code / fn_to_python_code regenerate readable Python from the IR.
- StatefulFunction is the underlying primitive: it reports read/write states and can apply IR optimizations as it traces.
See also
- JIT and compilation: how tracing and caching drive compilation.
- Debugging: inspecting values rather than structure.