brainstate.transform.inline_jit


brainstate.transform.inline_jit(jaxpr, should_expand=None)

Rewrite a jaxpr by expanding (inlining) the jit equations that satisfy a given predicate.

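This function recursively traverses a jaxpr and expands (inlines) the JIT-compiled function calls for which a user-provided predicate returns True. When a call is inlined, the callee's input and output variables are remapped to the corresponding variables of the caller, so that the inlined equations read the caller's values and the equations downstream of the call read the inlined results.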

Parameters:
  • jaxpr (Jaxpr | ClosedJaxpr) – The input jaxpr to rewrite. Can be either a Jaxpr or ClosedJaxpr.

  • should_expand (Callable[[JaxprEqn], bool] | None) – A predicate function that takes a JaxprEqn and returns True if the jit should be expanded. If None, all jit equations are expanded. The predicate can inspect the equation's parameters, such as the callee jaxpr (stored under 'jaxpr' for jit equations in recent JAX versions and under 'call_jaxpr' for older call primitives), to decide based on the function’s complexity, size, or content.
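The kind of inspection described for should_expand can be sketched as a small predicate factory. `make_should_expand` and its `max_eqns` and `forbidden_primitives` arguments are hypothetical, not part of brainstate; the sketch assumes jit equations use the primitive name `pjit` (or `jit`) and keep the callee under `eqn.params["jaxpr"]` and the function name under `eqn.params["name"]`. It only inspects the callee's top-level equations and does not look inside nested jits.

```python
import jax
import jax.numpy as jnp

# Assumption: jit equations are named 'pjit' (or 'jit' in some newer JAX
# releases) and keep the callee under eqn.params['jaxpr'].
JIT_PRIMITIVES = {"pjit", "jit"}


def make_should_expand(max_eqns=None, forbidden_primitives=()):
    """Build a hypothetical should_expand predicate from simple criteria."""
    forbidden = set(forbidden_primitives)

    def should_expand(eqn):
        if eqn.primitive.name not in JIT_PRIMITIVES:
            return False
        callee = eqn.params.get("jaxpr", eqn.params.get("call_jaxpr"))
        if callee is None:
            return False
        body = getattr(callee, "jaxpr", callee)  # unwrap a ClosedJaxpr if needed
        # Size criterion: number of top-level equations in the callee.
        if max_eqns is not None and len(body.eqns) > max_eqns:
            return False
        # Content criterion: refuse callees that use any forbidden primitive.
        return not any(e.primitive.name in forbidden for e in body.eqns)

    return should_expand


@jax.jit
def small(x):
    return x + 1


@jax.jit
def matmul_fn(a, b):
    return jax.lax.dot(a, b)  # binds dot_general directly


def f(x, m):
    return matmul_fn(m, m) + small(x)


closed = jax.make_jaxpr(f)(jnp.float32(1.0), jnp.ones((2, 2)))
pred = make_should_expand(forbidden_primitives={"dot_general"})
for eqn in closed.jaxpr.eqns:
    if eqn.primitive.name in JIT_PRIMITIVES:
        print(eqn.params["name"], pred(eqn))
```

A predicate built this way could then be passed as the second argument, for example `inline_jit(closed.jaxpr, pred)`.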

Returns:

A new jaxpr in which the jit equations selected by should_expand have been inlined. The return type matches the input type: a Jaxpr input returns a Jaxpr, and a ClosedJaxpr input returns a ClosedJaxpr.

Return type:

Jaxpr | ClosedJaxpr
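The contract that the output type mirrors the input type can be illustrated with a small wrapper around a Jaxpr-to-Jaxpr rewrite. `rewrite_preserving_type` is a hypothetical helper, and the sketch assumes the rewrite leaves the constvars unchanged so the original constants can be reattached as-is; an inliner whose callees carry their own constants would also have to extend `consts`.

```python
import jax

try:  # newer JAX exposes ClosedJaxpr under jax.extend.core
    from jax.extend.core import ClosedJaxpr
except ImportError:  # older JAX releases
    from jax.core import ClosedJaxpr


def rewrite_preserving_type(rewrite, jaxpr):
    """Apply a Jaxpr -> Jaxpr rewrite to a Jaxpr or ClosedJaxpr.

    A ClosedJaxpr input yields a ClosedJaxpr; a plain Jaxpr yields a Jaxpr.
    Assumes `rewrite` does not change the constvars.
    """
    if isinstance(jaxpr, ClosedJaxpr):
        # Rewrite the open jaxpr and reattach the original constants.
        return ClosedJaxpr(rewrite(jaxpr.jaxpr), jaxpr.consts)
    return rewrite(jaxpr)


closed = jax.make_jaxpr(lambda x: x * 2)(1.0)
identity = lambda j: j  # stand-in for a real rewrite
print(type(rewrite_preserving_type(identity, closed)).__name__)        # ClosedJaxpr
print(type(rewrite_preserving_type(identity, closed.jaxpr)).__name__)  # Jaxpr
```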

Examples

>>> import jax
>>> from jax import make_jaxpr
>>> from brainstate.transform import inline_jit
>>>
>>> @jax.jit
... def inner(x):
...     return x + 1
>>>
>>> def outer(x):
...     return inner(x) * 2
>>>
>>> jaxpr = make_jaxpr(outer)(1.0)         # ClosedJaxpr containing one jit equation
>>> expanded = inline_jit(jaxpr.jaxpr)      # Expands all jits; Jaxpr in, Jaxpr out
>>> expanded_closed = inline_jit(jaxpr)     # ClosedJaxpr in, ClosedJaxpr out
>>>
>>> # Conditional expansion: only expand small functions
>>> def expand_small(eqn):
...     # The callee is under 'jaxpr' in recent JAX and 'call_jaxpr' in older call primitives
...     call_jaxpr = eqn.params.get('jaxpr', eqn.params.get('call_jaxpr'))
...     return call_jaxpr is not None and len(call_jaxpr.eqns) <= 5
>>> expanded = inline_jit(jaxpr.jaxpr, expand_small)