brainstate.transform.common_subexpression_elimination


brainstate.transform.common_subexpression_elimination(jaxpr)

Eliminate redundant computations by reusing results (CSE).

Common subexpression elimination (CSE) identifies equations that apply the same operation to identical inputs and reuses the earlier result instead of recomputing it. Removing these duplicate equations reduces redundant computation and can lower memory usage.
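The idea can be illustrated on a toy intermediate representation. This is a minimal sketch, not brainstate's implementation: `Eqn` and `cse` are hypothetical names, variables are plain strings, and parameters are pre-frozen tuples so they can be hashed. It walks the equations in order, keys each one by (primitive, inputs, params), and records a renaming whenever a key has already been seen. It also ignores side effects; a real pass would need to avoid merging effectful primitives.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Eqn:
    """Hypothetical toy equation: out = primitive(*inputs, **dict(params))."""
    primitive: str
    inputs: tuple[str, ...]
    out: str
    params: tuple[tuple[str, Any], ...] = ()


def cse(eqns: list[Eqn]) -> tuple[list[Eqn], dict[str, str]]:
    """Drop duplicate equations; return the kept equations and a rename map."""
    seen: dict[tuple, str] = {}   # (primitive, inputs, params) -> var holding the result
    rename: dict[str, str] = {}   # eliminated var -> surviving var
    kept: list[Eqn] = []
    for eqn in eqns:
        # Substitute earlier renames first, so equations that consume an
        # eliminated var can themselves be recognised as duplicates.
        inputs = tuple(rename.get(v, v) for v in eqn.inputs)
        key = (eqn.primitive, inputs, eqn.params)
        if key in seen:
            rename[eqn.out] = seen[key]  # reuse the earlier result
        else:
            seen[key] = eqn.out
            kept.append(Eqn(eqn.primitive, inputs, eqn.out, eqn.params))
    return kept, rename


# Toy version of: a = x + y; b = x * 2; c = x + y
eqns = [
    Eqn("add", ("x", "y"), "a"),
    Eqn("mul", ("x", "2"), "b"),
    Eqn("add", ("x", "y"), "c"),  # duplicates a
]
kept, rename = cse(eqns)
print([e.out for e in kept])  # ['a', 'b']
print(rename)                 # {'c': 'a'}
```

Rewriting inputs before building the key is what lets the pass catch chains: once `c` is renamed to `a`, a later `neg(c)` becomes `neg(a)` and can match an existing `neg(a)`.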

Parameters:

jaxpr (Jaxpr) – The input Jaxpr to optimize.

Returns:

A new Jaxpr with common subexpressions eliminated. All input and output variables are preserved.

Return type:

Jaxpr

Notes

This optimization preserves all input and output variables. When CSE maps an output variable to another variable, an identity equation (a convert_element_type to the same dtype) is added so that the original output variable is still defined and the Jaxpr's interface is unchanged.
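A minimal sketch of this output-preservation step, again on a hypothetical toy representation rather than real Jaxpr objects. Given the renaming produced by CSE, it appends one identity `convert_element_type` equation for each output variable that was eliminated. The `restore_outputs` helper, the `dtypes` mapping, and the single `new_dtype` parameter are simplifying assumptions; the real JAX primitive carries additional parameters such as `weak_type`.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Eqn:
    """Hypothetical toy equation: out = primitive(*inputs, **dict(params))."""
    primitive: str
    inputs: tuple[str, ...]
    out: str
    params: tuple[tuple[str, Any], ...] = ()


def restore_outputs(eqns: list[Eqn], outvars: list[str],
                    rename: dict[str, str], dtypes: dict[str, str]) -> list[Eqn]:
    """Re-define every eliminated output var with an identity cast.

    rename maps an eliminated var to the var that now holds its value;
    dtypes maps a var to its dtype name.
    """
    result = list(eqns)
    added: set[str] = set()  # an output var may appear more than once
    for out in outvars:
        src = rename.get(out)
        if src is not None and out not in added:
            # Casting to the source's own dtype leaves the value unchanged.
            result.append(Eqn("convert_element_type", (src,), out,
                              (("new_dtype", dtypes[src]),)))
            added.add(out)
    return result


# After CSE, c was folded into a, but c is still an output of the Jaxpr.
kept = [Eqn("add", ("x", "y"), "a"), Eqn("mul", ("x", "2"), "b")]
fixed = restore_outputs(kept, ["a", "b", "c"], {"c": "a"}, {"a": "float32"})
print(fixed[-1])
```

Callers that relied on the name `c` still find it defined; only its defining equation changed from a recomputation to a no-op cast.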

Two equations are considered identical if they have:

- The same primitive operation
- The same input variables (by identity)
- The same parameters
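These three criteria translate naturally into a hashable lookup key. In this sketch, `Var` and `eqn_key` are hypothetical: `Var` keeps Python's default identity-based equality, so two variables with the same name are still distinct, while parameters are compared by value after sorting. It assumes all parameter values are hashable.

```python
from typing import Any


class Var:
    """Hypothetical toy variable; default object equality means identity."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __repr__(self) -> str:
        return self.name


def eqn_key(primitive: str, inputs: list[Var], params: dict[str, Any]) -> tuple:
    """Hashable key: for pure primitives, equal keys mean equal results."""
    return (
        primitive,                      # same primitive operation
        tuple(inputs),                  # same input variables, compared by identity
        tuple(sorted(params.items())),  # same parameters, compared by value
    )


x, y = Var("x"), Var("y")
x_other = Var("x")  # same name, but a different variable

print(eqn_key("add", [x, y], {}) == eqn_key("add", [x, y], {}))        # True
print(eqn_key("add", [x, y], {}) == eqn_key("add", [x_other, y], {}))  # False
print(eqn_key("reduce_sum", [x], {"axes": (0,)})
      == eqn_key("reduce_sum", [x], {"axes": (1,)}))                   # False
```

Sorting the parameter items makes the key independent of the order in which parameters were supplied.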

Examples

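>>> import jax
>>> from brainstate.transform import common_subexpression_elimination
>>> # Before: a = x + y; b = x * 2; c = x + y  (c duplicates a)
>>> # After:  a = x + y; b = x * 2; c = a
>>> def f(x, y):
...     return x + y, x * 2, x + y
>>> original_jaxpr = jax.make_jaxpr(f)(1.0, 2.0).jaxpr
>>> optimized_jaxpr = common_subexpression_elimination(original_jaxpr)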