brainstate.transform.inline_jit
- brainstate.transform.inline_jit(jaxpr, should_expand=None)
Rewrite a jaxpr by expanding (inlining) jit equations that satisfy the given condition.
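This function recursively traverses a jaxpr and expands (inlines) JIT-compiled function calls for which a user-provided predicate returns True. When a call is inlined, the callee's input and output variables are remapped onto the caller's operands and results, so the inlined equations refer to the correct variables in the enclosing scope.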
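The inlining step with variable remapping can be sketched on a tiny, hypothetical IR. This is not brainstate's implementation and does not use JAX's real jaxpr classes: `Var`, `Eqn`, `Jaxpr`, `inline_calls`, the `"jit"` primitive name, and its `"jaxpr"` parameter are stand-ins chosen for illustration, and the toy IR has no literals or constants.

```python
from dataclasses import dataclass, field
from itertools import count


# Hypothetical toy IR standing in for JAX's Var / JaxprEqn / Jaxpr
# (no literals or constvars, to keep the sketch short).
@dataclass(frozen=True)
class Var:
    name: str


@dataclass
class Eqn:
    primitive: str
    invars: list
    outvars: list
    params: dict = field(default_factory=dict)


@dataclass
class Jaxpr:
    invars: list
    outvars: list
    eqns: list


_counter = count()


def fresh_var():
    # Fresh names keep callee intermediates from colliding with caller vars.
    return Var(f"_inl{next(_counter)}")


def inline_calls(jaxpr, should_expand=None):
    subst = {}  # caller var -> var that now defines it after inlining

    def read(v):
        return subst.get(v, v)

    new_eqns = []
    for eqn in jaxpr.eqns:
        invars = [read(v) for v in eqn.invars]
        is_call = eqn.primitive == "jit"
        if is_call and (should_expand is None or should_expand(eqn)):
            # Expand nested calls inside the callee first.
            callee = inline_calls(eqn.params["jaxpr"], should_expand)
            # Bind the callee's inputs to the caller's operands.
            env = dict(zip(callee.invars, invars))
            for ceqn in callee.eqns:
                outs = [fresh_var() for _ in ceqn.outvars]
                new_eqns.append(
                    Eqn(ceqn.primitive, [env[v] for v in ceqn.invars], outs, ceqn.params)
                )
                env.update(zip(ceqn.outvars, outs))
            # Caller results now alias whatever the callee's outputs became.
            subst.update(zip(eqn.outvars, [env[v] for v in callee.outvars]))
        else:
            params = eqn.params
            if is_call:
                # Kept call: still rewrite its body so deeper calls can be expanded.
                params = {**params, "jaxpr": inline_calls(params["jaxpr"], should_expand)}
            new_eqns.append(Eqn(eqn.primitive, invars, list(eqn.outvars), params))
    return Jaxpr(list(jaxpr.invars), [read(v) for v in jaxpr.outvars], new_eqns)


# Usage: outer(x) = inner(x) * inner(x) via one call, where inner(a) = a + a.
a, b = Var("a"), Var("b")
inner = Jaxpr([a], [b], [Eqn("add", [a, a], [b])])
x, z, y = Var("x"), Var("z"), Var("y")
outer = Jaxpr([x], [y], [
    Eqn("jit", [x], [z], {"jaxpr": inner}),
    Eqn("mul", [z, z], [y]),
])

flat = inline_calls(outer)                    # every call expanded
kept = inline_calls(outer, lambda eqn: False)  # no call expanded
```

The substitution map is the central design choice: rather than renaming the callee's outputs directly to the caller's result variables, which breaks when a callee returns one of its inputs or the same value twice, later caller equations are rewritten to read the callee's renamed outputs.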
- Parameters:
jaxpr (Jaxpr | ClosedJaxpr) – The input jaxpr to rewrite; either a Jaxpr or a ClosedJaxpr.
should_expand (Callable[[JaxprEqn], bool] | None) – A predicate that takes a JaxprEqn and returns True if that jit equation should be expanded. If None, all jit equations are expanded. The predicate can inspect equation parameters such as call_jaxpr (or jaxpr, depending on the primitive) to decide based on the called function's complexity, size, or content.
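Besides size-based checks, a predicate can select calls by name. The sketch below assumes that a jit equation records the wrapped function's name under `eqn.params["name"]` (recent JAX versions appear to do this for pjit equations, but verify against your version); `expand_only` is a hypothetical helper, not part of brainstate.

```python
from types import SimpleNamespace
from typing import Callable, Iterable


def expand_only(names: Iterable[str]) -> Callable[[object], bool]:
    """Hypothetical helper: build a should_expand predicate that accepts
    jit equations whose recorded function name is in ``names``."""
    wanted = frozenset(names)

    def predicate(eqn) -> bool:
        # Assumption: the equation stores the wrapped function's name
        # under params["name"]; equations without it are never expanded.
        return eqn.params.get("name") in wanted

    return predicate


# Quick check with stand-in objects (not real JaxprEqn instances).
is_inner = expand_only(["inner"])
fake_inner = SimpleNamespace(params={"name": "inner"})
fake_other = SimpleNamespace(params={"name": "outer_helper"})
fake_unnamed = SimpleNamespace(params={})
```

With the real function, the predicate would be passed as `inline_jit(jaxpr, expand_only(["inner"]))`.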
- Returns:
A new jaxpr in which every jit equation accepted by should_expand is expanded. The return type matches the input type (a Jaxpr input yields a Jaxpr, a ClosedJaxpr input yields a ClosedJaxpr).
- Return type:
Jaxpr | ClosedJaxpr
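The return-type guarantee suggests a simple dispatch: rewrite the underlying open jaxpr and, if the input was closed, re-wrap the result with its constants. A minimal sketch on hypothetical stand-in classes (`OpenIR`, `ClosedIR`, and `rewrite_preserving_type` are illustrative names, not brainstate or JAX API), assuming the rewrite leaves the constant variables untouched:

```python
from dataclasses import dataclass
from typing import Callable, Union


# Hypothetical stand-ins for Jaxpr and ClosedJaxpr.
@dataclass(frozen=True)
class OpenIR:
    eqns: tuple


@dataclass(frozen=True)
class ClosedIR:
    jaxpr: OpenIR
    consts: tuple


def rewrite_preserving_type(
    j: Union[OpenIR, ClosedIR], rewrite: Callable[[OpenIR], OpenIR]
) -> Union[OpenIR, ClosedIR]:
    """Apply ``rewrite`` to the open form and re-wrap if the input was closed."""
    if isinstance(j, ClosedIR):
        # Assumption: the rewrite does not add or remove constant variables,
        # so the original consts can be reattached unchanged.
        return ClosedIR(rewrite(j.jaxpr), j.consts)
    return rewrite(j)


def drop_first(ir: OpenIR) -> OpenIR:
    # Toy rewrite for the demo: remove the first equation.
    return OpenIR(ir.eqns[1:])


open_out = rewrite_preserving_type(OpenIR(("jit", "mul")), drop_first)
closed_out = rewrite_preserving_type(ClosedIR(OpenIR(("jit", "mul")), (3.0,)), drop_first)
```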
Examples
>>> import jax
>>> from jax import make_jaxpr
>>> from brainstate.transform import inline_jit
>>>
>>> @jax.jit
... def inner(x):
...     return x + 1
>>>
>>> def outer(x):
...     return inner(x) * 2
>>>
>>> jaxpr = make_jaxpr(outer)(1.0)
>>> expanded = inline_jit(jaxpr.jaxpr)  # Expands all jits
>>>
>>> # Conditional expansion: only expand small functions
>>> def expand_small(eqn):
...     call_jaxpr = eqn.params.get('call_jaxpr') or eqn.params.get('jaxpr')
...     return call_jaxpr is not None and len(call_jaxpr.eqns) <= 5
>>> expanded = inline_jit(jaxpr.jaxpr, expand_small)