brainstate.transform.optimize_jaxpr
- brainstate.transform.optimize_jaxpr(jaxpr, max_iterations=3, optimizations=None, verbose=False)
Apply multiple optimization passes to a Jaxpr.
This function applies a sequence of optimizations in multiple iterations until convergence or the maximum number of iterations is reached. The optimizations work together to simplify the computation graph while preserving the function’s semantics and interface.
- Parameters:
  - jaxpr (Jaxpr | ClosedJaxpr) – The input Jaxpr or ClosedJaxpr to optimize.
  - max_iterations (int) – Maximum number of iterations, where each iteration runs the selected sequence of optimizations once. Default is 3.
  - optimizations (Sequence[str] | None) – Names of the optimizations to apply, in order. If None, all optimizations are applied in the recommended order: constant_fold, algebraic_simplification, copy_propagation, cse, dce. Pass a custom list to control which optimizations run and in what order.
  - verbose (bool) – If True, print detailed progress information, including equation counts and reduction statistics. Default is False.
- Returns:
An optimized Jaxpr or ClosedJaxpr of the same type as the input, with the same input and output variables and typically fewer equations.
- Return type:
  Jaxpr | ClosedJaxpr
- Raises:
  - TypeError – If the input is not a Jaxpr or ClosedJaxpr.
  - ValueError – If any optimization name in optimizations is invalid.
  - RuntimeError – If the input or output variables change during optimization (this indicates a bug in the optimization passes).
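The handling of the `optimizations` argument (the default order when it is None, and the `ValueError` for unknown names) can be sketched as below. `resolve_optimizations` and `DEFAULT_ORDER` are hypothetical names used only for illustration; they are not part of brainstate's documented API, and the sketch assumes nothing beyond the default order and error behavior stated above.

```python
from typing import List, Optional, Sequence

# Default pass order as documented for optimize_jaxpr. The helper below is a
# hypothetical sketch, not brainstate's actual implementation.
DEFAULT_ORDER = (
    "constant_fold",
    "algebraic_simplification",
    "copy_propagation",
    "cse",
    "dce",
)


def resolve_optimizations(optimizations: Optional[Sequence[str]] = None) -> List[str]:
    """Return the pass names to run, in order, rejecting unknown names."""
    if optimizations is None:
        return list(DEFAULT_ORDER)  # None means: all passes, recommended order
    unknown = [name for name in optimizations if name not in DEFAULT_ORDER]
    if unknown:
        raise ValueError(
            f"Unknown optimization(s) {unknown}; valid names are {list(DEFAULT_ORDER)}"
        )
    return list(optimizations)  # a custom list keeps the caller's order


print(resolve_optimizations())                # full default pipeline
print(resolve_optimizations(["cse", "dce"]))  # custom subset, caller's order
```

Validating every name before any pass runs means a typo fails fast instead of partway through a long optimization.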
Notes
Available optimizations:
- constant_fold: Evaluate constant expressions at compile time.
- algebraic_simplification: Apply algebraic identities (x + 0 = x, x * 1 = x, etc.).
- copy_propagation: Eliminate unnecessary copy operations.
- cse: Common subexpression elimination (reuse identical computations).
- dce: Dead code elimination (remove unused equations).
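To make the first three passes concrete, here is a minimal sketch on a hypothetical toy IR rather than on real Jaxpr objects: each equation is a tuple `(output_var, primitive, operands)`, operands are variable names or integer literals, and `outputs` lists the interface variables that must survive. The function names mirror the optimization names above, but the rewrite rules are simplified assumptions, not brainstate's code.

```python
from typing import Dict, List, Sequence, Tuple, Union

# Hypothetical toy IR, not JAX's Jaxpr: (output variable, primitive, operands).
Atom = Union[str, int]                   # variable name or integer literal
Eqn = Tuple[str, str, Tuple[Atom, ...]]

FOLDABLE = {"add": lambda a, b: a + b, "mul": lambda a, b: a * b}
IDENTITY = {"add": 0, "mul": 1}          # x + 0 = x, x * 1 = x


def substitute(args: Tuple[Atom, ...], env: Dict[str, Atom]) -> Tuple[Atom, ...]:
    """Replace variables that have a known replacement (literal or alias)."""
    return tuple(env.get(a, a) if isinstance(a, str) else a for a in args)


def constant_fold(eqns: List[Eqn], outputs: Sequence[str]) -> List[Eqn]:
    """Evaluate equations whose operands are all literals."""
    env: Dict[str, Atom] = {}
    result = []
    for var, op, args in eqns:
        args = substitute(args, env)
        if op in FOLDABLE and all(isinstance(a, int) for a in args):
            value = FOLDABLE[op](*args)
            if var in outputs:
                result.append((var, "copy", (value,)))  # keep interface variable
            else:
                env[var] = value  # later uses read the literal directly
        else:
            result.append((var, op, args))
    return result


def algebraic_simplification(eqns: List[Eqn]) -> List[Eqn]:
    """Rewrite x + 0, 0 + x, x * 1 and 1 * x into a copy of x."""
    result = []
    for var, op, args in eqns:
        if op in IDENTITY and len(args) == 2:
            lhs, rhs = args
            if rhs == IDENTITY[op]:
                result.append((var, "copy", (lhs,)))
                continue
            if lhs == IDENTITY[op]:
                result.append((var, "copy", (rhs,)))
                continue
        result.append((var, op, args))
    return result


def copy_propagation(eqns: List[Eqn], outputs: Sequence[str]) -> List[Eqn]:
    """Forward the source of each copy to later uses and drop the copy."""
    alias: Dict[str, Atom] = {}
    result = []
    for var, op, args in eqns:
        args = substitute(args, alias)
        if op == "copy" and var not in outputs:
            alias[var] = args[0]  # already resolved, so chains of copies collapse
        else:
            result.append((var, op, args))
    return result


outputs = ["d"]
program: List[Eqn] = [
    ("a", "add", (2, 3)),      # constant expression: a = 5
    ("b", "mul", ("x", 1)),    # identity: b = x
    ("c", "add", ("b", "a")),  # becomes x + 5
    ("d", "add", ("c", 0)),    # identity: d = c
]
program = constant_fold(program, outputs)
program = algebraic_simplification(program)
program = copy_propagation(program, outputs)
print(program)  # [('c', 'add', ('x', 5)), ('d', 'copy', ('c',))]
```

Note that algebraic_simplification only turns identities into copies; it is copy_propagation that actually removes them, which is one reason the recommended order runs copy_propagation after the first two passes.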
The optimization process iterates until either:
- no more equations can be eliminated (convergence), or
- the maximum number of iterations is reached.
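The stopping rule can be sketched as a fixed-point loop. This is a hypothetical driver on a toy IR (equations as `(output_var, primitive, operands)` tuples), not brainstate's code; it assumes that "an iteration that removes no equations" counts as convergence, as the list above phrases it. The deliberately weak `one_level_dce` pass, also hypothetical, removes only one link of a dead chain per sweep, which makes the effect of `max_iterations` visible.

```python
from typing import Callable, Dict, List, Sequence, Tuple

# Hypothetical toy IR, not JAX's Jaxpr: (output variable, primitive, operands).
Eqn = Tuple[str, str, Tuple[str, ...]]
Pass = Callable[[List[Eqn]], List[Eqn]]


def run_to_fixed_point(
    eqns: List[Eqn],
    passes: Dict[str, Pass],
    order: Sequence[str],
    max_iterations: int = 3,
) -> Tuple[List[Eqn], int]:
    """Run passes in order until an iteration removes nothing or the cap is hit.

    Returns the optimized equations and the number of iterations executed.
    """
    for iteration in range(1, max_iterations + 1):
        before = len(eqns)
        for name in order:
            eqns = passes[name](eqns)
        if len(eqns) == before:  # convergence: no equation eliminated this round
            return eqns, iteration
    return eqns, max_iterations  # stopped by the iteration cap


def one_level_dce(eqns: List[Eqn], outputs: Sequence[str]) -> List[Eqn]:
    """Deliberately weak DCE: drops only equations unused in the current list,
    so a chain of dead equations loses one link per sweep."""
    used = set(outputs)
    for _, _, args in eqns:
        used.update(args)
    return [eqn for eqn in eqns if eqn[0] in used]


outputs = ["out"]
program: List[Eqn] = [
    ("a", "sin", ("x",)),
    ("b", "sin", ("a",)),
    ("c", "sin", ("b",)),   # c is unused, which makes b and a dead as well
    ("out", "cos", ("x",)),
]
passes = {"dce": lambda e: one_level_dce(e, outputs)}

capped, n_capped = run_to_fixed_point(program, passes, ["dce"], max_iterations=2)
converged, n_converged = run_to_fixed_point(program, passes, ["dce"], max_iterations=10)
print(len(capped), n_capped)        # 2 2
print(len(converged), n_converged)  # 1 4
```

A practical DCE walks the equations backwards and removes a whole dead chain in one sweep; extra iterations matter more when one pass creates opportunities for another, for example when CSE leaves behind duplicates that only DCE can remove.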
All optimizations preserve the function interface (input and output variables) while optimizing the internal computation graph.
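The interface guarantee, and the RuntimeError listed under Raises, amount to comparing the input and output variables before and after each pass. A minimal sketch, assuming a hypothetical toy `Program` class as a stand-in for a Jaxpr's input variables, output variables, and equations (it is not a JAX or brainstate type), and a hypothetical `interface_checked` decorator:

```python
import functools
from dataclasses import dataclass, replace
from typing import Callable, Tuple

Eqn = Tuple[str, str, Tuple[str, ...]]  # (output variable, primitive, operands)


@dataclass(frozen=True)
class Program:
    """Hypothetical stand-in for a Jaxpr: inputs, outputs, equations."""
    invars: Tuple[str, ...]
    outvars: Tuple[str, ...]
    eqns: Tuple[Eqn, ...] = ()


def interface_checked(pass_fn: Callable[[Program], Program]) -> Callable[[Program], Program]:
    """Wrap a pass so that any change to the program interface raises RuntimeError."""
    @functools.wraps(pass_fn)
    def wrapper(program: Program) -> Program:
        result = pass_fn(program)
        if result.invars != program.invars or result.outvars != program.outvars:
            raise RuntimeError(
                f"{pass_fn.__name__} changed the interface: "
                f"{program.invars}->{program.outvars} became "
                f"{result.invars}->{result.outvars}"
            )
        return result
    return wrapper


@interface_checked
def drop_unused(program: Program) -> Program:
    """Well-behaved toy pass: removes equations whose results are never read."""
    used = set(program.outvars)
    for _, _, args in program.eqns:
        used.update(args)
    return replace(program, eqns=tuple(e for e in program.eqns if e[0] in used))


@interface_checked
def buggy_rename(program: Program) -> Program:
    """Buggy toy pass: renames the outputs, which breaks the interface."""
    return replace(program, outvars=tuple(v + "_new" for v in program.outvars))


prog = Program(invars=("x",), outvars=("y",),
               eqns=(("t", "sin", ("x",)), ("y", "cos", ("x",))))
print(drop_unused(prog).eqns)  # (('y', 'cos', ('x',)),)
try:
    buggy_rename(prog)
except RuntimeError as err:
    print("caught:", err)
```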
Examples
Apply all default optimizations:
>>> optimized = optimize_jaxpr(jaxpr)
Use more iterations for aggressive optimization:
>>> optimized = optimize_jaxpr(jaxpr, max_iterations=5)
Run only specific optimizations:
>>> optimized = optimize_jaxpr(jaxpr, optimizations=['constant_fold', 'dce'])
Enable verbose output to see optimization progress:
>>> optimized = optimize_jaxpr(jaxpr, verbose=True)
Starting optimization with 50 equations
Optimization sequence: constant_fold -> algebraic_simplification -> ...
Max iterations: 3
------------------------------------------------------------
Iteration 1:
constant_fold: 50 -> 45 equations (-5)
algebraic_simplification: 45 -> 42 equations (-3)
dce: 42 -> 38 equations (-4)
Converged after 2 iteration(s)
------------------------------------------------------------
Optimization complete:
Initial equations: 50
Final equations: 38
Reduction: 12 (24.0%)
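The statistics in this output are simple arithmetic: each pass line reports the equation count before and after that pass, and the summary reports the total removed (initial minus final) as a percentage of the initial count, here 12 / 50 = 24.0%. The hypothetical formatters below reproduce those lines from the counts shown; they only mimic the printed format and are not brainstate's logging code.

```python
from typing import List, Tuple


def pass_line(name: str, before: int, after: int) -> str:
    """One per-pass progress line, e.g. 'dce: 42 -> 38 equations (-4)'."""
    return f"{name}: {before} -> {after} equations ({after - before:+d})"


def reduction_line(initial: int, final: int) -> str:
    """Total reduction, as a count and as a percentage of the initial size."""
    removed = initial - final
    percent = 100.0 * removed / initial if initial else 0.0
    return f"Reduction: {removed} ({percent:.1f}%)"


# Per-pass equation counts copied from the verbose output above.
steps: List[Tuple[str, int, int]] = [
    ("constant_fold", 50, 45),
    ("algebraic_simplification", 45, 42),
    ("dce", 42, 38),
]
for name, before, after in steps:
    print(pass_line(name, before, after))
print(reduction_line(steps[0][1], steps[-1][2]))  # Reduction: 12 (24.0%)
```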
Custom optimization pipeline:
>>> # First fold constants, then eliminate dead code
>>> stage1 = optimize_jaxpr(jaxpr, optimizations=['constant_fold', 'dce'])
>>> # Then apply CSE and more DCE
>>> stage2 = optimize_jaxpr(stage1, optimizations=['cse', 'dce'])
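Pairing cse with dce in the second stage makes sense under a common formulation of CSE: the pass only redirects uses of a duplicate computation to its first occurrence, and the duplicate, now dead, is removed by DCE. The sketch below assumes that formulation on a hypothetical toy IR; brainstate's cse may instead delete duplicates directly.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical toy IR, not JAX's Jaxpr: (output variable, primitive, operands).
Eqn = Tuple[str, str, Tuple[str, ...]]


def cse(eqns: List[Eqn], outputs: Sequence[str]) -> List[Eqn]:
    """Redirect uses of duplicate computations to the first occurrence.

    Assumes every primitive is pure. Duplicate equations are left in place;
    they become dead and are removed by a later DCE pass.
    """
    first: Dict[Tuple[str, Tuple[str, ...]], str] = {}
    alias: Dict[str, str] = {}
    result = []
    for var, op, args in eqns:
        args = tuple(alias.get(a, a) for a in args)
        key = (op, args)
        if key in first and var not in outputs:
            alias[var] = first[key]  # later uses read the earlier result
        else:
            first.setdefault(key, var)
        result.append((var, op, args))
    return result


def dce(eqns: List[Eqn], outputs: Sequence[str]) -> List[Eqn]:
    """Keep only equations that contribute to the outputs (backward sweep)."""
    live = set(outputs)
    kept = []
    for var, op, args in reversed(eqns):
        if var in live:
            kept.append((var, op, args))
            live.update(args)
    return kept[::-1]


outputs = ["c"]
program: List[Eqn] = [
    ("a", "sin", ("x",)),
    ("b", "sin", ("x",)),      # same computation as a
    ("c", "add", ("a", "b")),
]
stage = cse(program, outputs)
print(len(stage), stage[-1])  # 3 ('c', 'add', ('a', 'a'))
stage = dce(stage, outputs)
print(stage)                  # [('a', 'sin', ('x',)), ('c', 'add', ('a', 'a'))]
```

CSE alone leaves the equation count unchanged here; the reduction only appears once DCE runs, which is why dce closes both stages.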