brainstate.transform.common_subexpression_elimination
- brainstate.transform.common_subexpression_elimination(jaxpr)
Eliminate redundant computations by reusing results (CSE).
Common Subexpression Elimination identifies equations that apply the same operation to identical inputs and reuses the earlier result instead of recomputing it. This removes redundant computation and can reduce memory usage by avoiding duplicate intermediate values.
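The core idea can be illustrated with a small pure-Python sketch. It is not brainstate's implementation and does not use JAX's real equation type: `Eqn` and `cse` are hypothetical names. Each equation is keyed on (primitive, inputs, params). A repeated key records a substitution instead of keeping the equation, and later inputs are rewritten through that substitution. The sketch assumes every primitive is pure and every parameter value is hashable.

```python
from typing import NamedTuple


class Eqn(NamedTuple):
    """Toy stand-in for a jaxpr equation (hypothetical, not a JAX class)."""
    primitive: str   # e.g. "add", "mul"
    inputs: tuple    # variable names or hashable literals
    params: tuple    # sorted (name, value) pairs; values must be hashable
    out: str         # name of the single output variable


def cse(eqns):
    """Toy CSE: keep the first equation for each (primitive, inputs, params) key.

    Returns the kept equations and a map from each eliminated output
    variable to the variable that already holds the same value.
    """
    subst = {}  # eliminated var -> surviving var
    seen = {}   # (primitive, inputs, params) -> surviving output var
    kept = []
    for eqn in eqns:
        # Rewrite inputs through earlier substitutions so that
        # duplicates of duplicates are also detected.
        inputs = tuple(subst.get(v, v) for v in eqn.inputs)
        key = (eqn.primitive, inputs, eqn.params)
        if key in seen:
            subst[eqn.out] = seen[key]
        else:
            seen[key] = eqn.out
            kept.append(eqn._replace(inputs=inputs))
    return kept, subst


# Mirrors the case in the Examples section: c duplicates a.
before = [
    Eqn("add", ("x", "y"), (), "a"),
    Eqn("mul", ("x", 2), (), "b"),
    Eqn("add", ("x", "y"), (), "c"),
]
after, subst = cse(before)
print([e.out for e in after])  # ['a', 'b']
print(subst)                   # {'c': 'a'}
```

Inputs are rewritten before the key is built. As a result, an equation such as `d = c * 3` becomes `d = a * 3` once `c` has been mapped to `a`, and a later `e = a * 3` is then recognized as a duplicate of `d`.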
- Parameters:
jaxpr (Jaxpr) – The input Jaxpr to optimize.
- Returns:
A new Jaxpr with common subexpressions eliminated. All input and output variables are preserved.
- Return type:
Jaxpr
Notes
This optimization preserves all input and output variables. When CSE maps an output variable to another variable, identity equations (using convert_element_type with the same dtype) are added to maintain the correct interface.
Two equations are considered identical if they have:
- The same primitive operation
- The same input variables (compared by identity)
- The same parameters
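The interface-preserving step described in these notes can be sketched in the same toy setting. `Eqn` and `preserve_outputs` are hypothetical names. The `new_dtype` entry only mimics the parameter of JAX's `convert_element_type` primitive, and real jaxpr variables carry abstract values rather than the plain dtype strings used here. Given the substitution map from a CSE pass, each eliminated output variable is re-bound with an identity conversion from the variable that survived.

```python
from typing import NamedTuple


class Eqn(NamedTuple):
    """Toy stand-in for a jaxpr equation (hypothetical, not a JAX class)."""
    primitive: str
    inputs: tuple
    params: tuple
    out: str


def preserve_outputs(eqns, outvars, subst, dtypes):
    """Re-bind output variables that CSE replaced with another variable.

    eqns:    equations kept by a CSE pass
    outvars: the jaxpr's original output variable names
    subst:   eliminated variable -> surviving variable
    dtypes:  variable name -> dtype string (toy stand-in for avals)
    """
    result = list(eqns)
    rebound = set()  # avoid binding the same output twice
    for v in outvars:
        if v in subst and v not in rebound:
            # Identity conversion: the target dtype is v's own dtype,
            # which matches the surviving variable's, so the value is unchanged.
            result.append(Eqn("convert_element_type", (subst[v],),
                              (("new_dtype", dtypes[v]),), v))
            rebound.add(v)
    return result


# CSE removed `c = x + y` and mapped c -> a, but c is still an output.
kept = [Eqn("add", ("x", "y"), (), "a"), Eqn("mul", ("x", 2), (), "b")]
final = preserve_outputs(kept, ["a", "b", "c"], {"c": "a"},
                         {"a": "float32", "b": "float32", "c": "float32"})
print(final[-1])  # the identity equation that re-binds c from a
```

Outputs that CSE left untouched, such as `a` and `b`, need no extra equation. Only `c` gets one.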
Examples
>>> # Given a jaxpr with duplicate computations
>>> # Before: a = x + y; b = x * 2; c = x + y (c duplicates a)
>>> # After: a = x + y; b = x * 2; c = a
>>> optimized_jaxpr = common_subexpression_elimination(original_jaxpr)