brainstate.transform.optimize_jaxpr

brainstate.transform.optimize_jaxpr(jaxpr, max_iterations=3, optimizations=None, verbose=False)

Apply multiple optimization passes to a Jaxpr.

Each iteration runs the selected optimization passes once, in order. Iteration stops when a full iteration eliminates no equations (convergence) or when max_iterations iterations have run. The passes simplify the computation graph while preserving the function's semantics and its interface (input and output variables).

Parameters:
  • jaxpr (Jaxpr | ClosedJaxpr) – The input Jaxpr or ClosedJaxpr to optimize.

  • max_iterations (int) – Maximum number of optimization iterations; each iteration runs every selected optimization once. Default is 3.

  • optimizations (Sequence[str] | None) – Names of the optimizations to apply, in order. If None (the default), all optimizations run in the recommended order: constant_fold, algebraic_simplification, copy_propagation, cse, dce. Pass a custom sequence to control which optimizations run and in what order.

  • verbose (bool) – If True, print optimization progress, including the equation count after each pass and overall reduction statistics. Default is False.

Returns:

An optimized Jaxpr or ClosedJaxpr of the same type as the input. Its input and output variables are unchanged, and its equation count is typically reduced.

Return type:

Jaxpr | ClosedJaxpr

Raises:
  • TypeError – If the input is not a Jaxpr or ClosedJaxpr.

  • ValueError – If any optimization name in optimizations is invalid.

  • RuntimeError – If the input or output variables change during optimization (indicates a bug in the optimization passes).
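The ValueError and RuntimeError conditions above can be pictured as two small checks. The following is a minimal sketch, not brainstate's implementation: the helper names resolve_optimizations and check_interface are hypothetical, and variables are represented as plain strings rather than JAX variable objects. The TypeError check is omitted because it depends on JAX's Jaxpr types.

```python
from typing import Optional, Sequence

# Recommended default order, as listed for the `optimizations` parameter.
DEFAULT_ORDER = ("constant_fold", "algebraic_simplification",
                 "copy_propagation", "cse", "dce")


def resolve_optimizations(optimizations: Optional[Sequence[str]] = None) -> list:
    """Hypothetical helper: return the pass names to run, in order."""
    if optimizations is None:
        return list(DEFAULT_ORDER)
    unknown = [name for name in optimizations if name not in DEFAULT_ORDER]
    if unknown:
        raise ValueError(
            f"Unknown optimization(s) {unknown}; valid names are {list(DEFAULT_ORDER)}"
        )
    return list(optimizations)


def check_interface(invars_before, outvars_before, invars_after, outvars_after) -> None:
    """Hypothetical helper: fail loudly if a pass altered the inputs or outputs."""
    if list(invars_before) != list(invars_after):
        raise RuntimeError("optimization changed the input variables")
    if list(outvars_before) != list(outvars_after):
        raise RuntimeError("optimization changed the output variables")


print(resolve_optimizations(["constant_fold", "dce"]))  # ['constant_fold', 'dce']
try:
    resolve_optimizations(["constant_fold", "inline"])  # "inline" is not a valid name
except ValueError as err:
    print(err)
```

Validating names before running anything means a typo fails immediately instead of after some passes have already run.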

Notes

Available optimizations:

  • constant_fold: Evaluate equations whose inputs are all constants and replace them with their results

  • algebraic_simplification: Apply algebraic identities (x+0=x, x*1=x, etc.)

  • copy_propagation: Replace uses of a copied variable with its source and remove the redundant copy

  • cse: Common subexpression elimination (reuse identical computations)

  • dce: Dead code elimination (remove unused equations)
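To make the first three passes concrete, here is a minimal sketch on a hypothetical toy IR rather than a real jaxpr: each equation is an Eqn(out, prim, args) whose operands are variable names (strings) or numeric literals, and only add, mul, and copy primitives exist. The function names mirror the optimization names above, but the code illustrates the textbook techniques and is not brainstate's implementation. Each pass receives the output variables so that it never removes an equation defining one.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Eqn:
    out: str      # variable defined by this equation
    prim: str     # "add", "mul", or "copy" in this toy IR
    args: tuple   # operands: variable names (str) or numeric literals


# Evaluation rules used by constant folding.
PRIMS = {"add": lambda a, b: a + b, "mul": lambda a, b: a * b, "copy": lambda a: a}


def is_var(x):
    return isinstance(x, str)


def constant_fold(eqns, outvars):
    """Evaluate equations whose operands are all literals; substitute the results."""
    env, result = {}, []  # env: folded variable -> literal value
    for e in eqns:
        args = tuple(env.get(a, a) if is_var(a) else a for a in e.args)
        if not any(is_var(a) for a in args):
            env[e.out] = PRIMS[e.prim](*args)
            if e.out in outvars:  # an output must stay defined
                result.append(Eqn(e.out, "copy", (env[e.out],)))
        else:
            result.append(Eqn(e.out, e.prim, args))
    return result


def algebraic_simplification(eqns, outvars):  # outvars unused; kept for a uniform signature
    """Rewrite x + 0, 0 + x, x * 1, 1 * x as copies of x (ignores IEEE -0.0 corner cases)."""
    result = []
    for e in eqns:
        if e.prim in ("add", "mul"):
            identity = 0 if e.prim == "add" else 1
            a, b = e.args
            if is_var(a) and not is_var(b) and b == identity:
                e = Eqn(e.out, "copy", (a,))
            elif is_var(b) and not is_var(a) and a == identity:
                e = Eqn(e.out, "copy", (b,))
        result.append(e)
    return result


def copy_propagation(eqns, outvars):
    """Forward uses of y = copy(x) to x and drop the copy unless y is an output."""
    alias, result = {}, []  # alias: copied variable -> its source variable
    for e in eqns:
        args = tuple(alias.get(a, a) if is_var(a) else a for a in e.args)
        if e.prim == "copy" and is_var(args[0]) and e.out not in outvars:
            alias[e.out] = args[0]
        else:
            result.append(Eqn(e.out, e.prim, args))
    return result


# f(x) = (x * (2 + 3 - 4)) + 0, written out equation by equation.
program = [
    Eqn("a", "add", (2, 3)),      # a = 2 + 3
    Eqn("b", "add", ("a", -4)),   # b = a - 4, folds to 1
    Eqn("c", "mul", ("x", "b")),  # c = x * b, becomes x * 1
    Eqn("d", "add", ("c", 0)),    # d = c + 0
]
outvars = ["d"]
for p in (constant_fold, algebraic_simplification, copy_propagation):
    program = p(program, outvars)
print(program)  # [Eqn(out='d', prim='copy', args=('x',))]
```

The passes feed each other: folding exposes the x * 1 identity, simplification turns it into a copy, and copy propagation collapses the chain, leaving only the equation that defines the output.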

The optimization process iterates until:

  1. A full iteration eliminates no further equations (convergence), or

  2. The maximum number of iterations is reached
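The stopping rule above amounts to a fixed-point loop. The following is a minimal sketch, assuming each pass is a plain callable from an equation list to an equation list and using the equation count as the convergence test, as the rule is stated. The name run_passes and the toy passes in the usage lines are hypothetical, not brainstate API.

```python
from typing import Callable, Dict, Sequence, Tuple

Pass = Callable[[list], list]


def run_passes(eqns: list, passes: Dict[str, Pass], order: Sequence[str],
               max_iterations: int = 3) -> Tuple[list, int]:
    """Run the named passes in order, repeatedly, until an iteration removes nothing."""
    for iteration in range(1, max_iterations + 1):
        before = len(eqns)
        for name in order:
            eqns = passes[name](eqns)
        if len(eqns) == before:  # converged: this iteration removed nothing
            return eqns, iteration
    return eqns, max_iterations  # stopped at the iteration limit


# Toy stand-in passes over a list of strings that play the role of equations.
toy_passes = {
    "dedupe": lambda eqns: list(dict.fromkeys(eqns)),  # removes repeated entries
    "drop_dead": lambda eqns: [e for e in eqns if not e.startswith("dead")],
}
eqns = ["a", "a", "dead1", "b", "dead2"]
result, iterations = run_passes(eqns, toy_passes, ["dedupe", "drop_dead"])
print(result, iterations)  # ['a', 'b'] 2
```

Because convergence is only detected by an iteration that changes nothing, this sketch reports 2 iterations even though all of the work happened in the first one. Counting equations is a cheap test, but it does not notice rewrites that keep the count unchanged.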

All optimizations preserve the function interface (input and output variables) while optimizing the internal computation graph.
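Preserving the interface mainly constrains the passes that delete equations. The sketch below uses a hypothetical toy IR in which each equation is Eqn(out, prim, args) and variables are strings, to show CSE and DCE: CSE keys each equation on its primitive and operands, and DCE walks backwards from the output variables, so no equation an output depends on is removed and input and output names are never rewritten. It illustrates the standard techniques and is not brainstate's code. It also assumes every primitive is pure (no side effects), which a real DCE pass would have to check.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Eqn:
    out: str      # variable defined by this equation
    prim: str     # primitive name
    args: tuple   # operands: variable names (str) or numeric literals


def cse(eqns, outvars):
    """Reuse the first equation with the same primitive and operands."""
    seen, alias, result = {}, {}, []  # seen: (prim, args) -> defining variable
    for e in eqns:
        args = tuple(alias.get(a, a) if isinstance(a, str) else a for a in e.args)
        key = (e.prim, args)
        if key in seen and e.out not in outvars:
            alias[e.out] = seen[key]  # later uses read the earlier result
        else:  # first occurrence, or an output that must stay defined
            seen.setdefault(key, e.out)
            result.append(Eqn(e.out, e.prim, args))
    return result


def dce(eqns, outvars):
    """Keep only equations that some output (transitively) depends on."""
    live, kept = set(outvars), []
    for e in reversed(eqns):  # walk from the outputs back toward the inputs
        if e.out in live:
            kept.append(e)
            live.update(a for a in e.args if isinstance(a, str))
    return kept[::-1]


# Inputs x, y; single output z. u and v compute the same value; w is unused.
program = [
    Eqn("u", "mul", ("x", "y")),
    Eqn("v", "mul", ("x", "y")),  # duplicate of u
    Eqn("w", "add", ("x", 1)),    # never used by the output
    Eqn("z", "add", ("u", "v")),
]
program = dce(cse(program, ["z"]), ["z"])
print(program)
# [Eqn(out='u', prim='mul', args=('x', 'y')), Eqn(out='z', prim='add', args=('u', 'u'))]
```

Running DCE after CSE matters: CSE only redirects uses, and it is DCE that then drops whatever the outputs no longer reach.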

Examples

Apply all default optimizations:

>>> optimized = optimize_jaxpr(jaxpr)

Use more iterations for aggressive optimization:

>>> optimized = optimize_jaxpr(jaxpr, max_iterations=5)

Run only specific optimizations:

>>> optimized = optimize_jaxpr(jaxpr, optimizations=['constant_fold', 'dce'])

Enable verbose output to see optimization progress:

>>> optimized = optimize_jaxpr(jaxpr, verbose=True)
Starting optimization with 50 equations
Optimization sequence: constant_fold -> algebraic_simplification -> ...
Max iterations: 3
------------------------------------------------------------

Iteration 1:
  constant_fold: 50 -> 45 equations (-5)
  algebraic_simplification: 45 -> 42 equations (-3)
  dce: 42 -> 38 equations (-4)

Converged after 2 iteration(s)
------------------------------------------------------------
Optimization complete:
  Initial equations: 50
  Final equations:   38
  Reduction:         12 (24.0%)

Custom optimization pipeline:

>>> # First fold constants, then eliminate dead code
>>> stage1 = optimize_jaxpr(jaxpr, optimizations=['constant_fold', 'dce'])
>>> # Then apply CSE and more DCE
>>> stage2 = optimize_jaxpr(stage1, optimizations=['cse', 'dce'])