brainstate.transform.dead_code_elimination

brainstate.transform.dead_code_elimination(jaxpr)

Remove equations whose outputs are not used (dead code elimination).

This optimization walks the equations backward from the outputs to find every variable that is used, directly or transitively, and then removes the equations whose outputs are never used. The resulting Jaxpr has fewer equations to evaluate, which can improve performance.

Parameters:

jaxpr (Jaxpr) – The input Jaxpr to optimize.

Returns:

A new Jaxpr with dead code removed. All input and output variables are preserved.

Return type:

Jaxpr

Notes

This optimization preserves all input and output variables to maintain the function interface. Only internal dead computations are eliminated.

The algorithm uses a two-phase approach:

1. Backward pass: mark all variables that are transitively used.
2. Forward pass: keep only equations that produce marked variables.
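The two phases can be illustrated on a toy equation list. This is a minimal sketch in plain Python, not the brainstate implementation: `Eqn` and `toy_dce` are hypothetical names introduced here, variables are plain strings, and the sketch assumes every equation is free of side effects and that each variable is defined before it is used.

```python
from typing import NamedTuple


class Eqn(NamedTuple):
    """Toy equation `outs = op(*ins)`; a hypothetical stand-in for a Jaxpr equation."""
    outs: tuple
    op: str
    ins: tuple


def toy_dce(invars, eqns, outvars):
    """Return (invars, kept_eqns, outvars) with unused equations dropped."""
    # Phase 1 (backward): seed the marked set with the outputs, then walk the
    # equations in reverse. Whenever an equation produces a marked variable,
    # mark its inputs too (literal operands are marked as well, which is harmless).
    used = set(outvars)
    for eqn in reversed(eqns):
        if any(v in used for v in eqn.outs):
            used.update(eqn.ins)

    # Phase 2 (forward): keep, in original order, only the equations that
    # produce at least one marked variable.
    kept = [eqn for eqn in eqns if any(v in used for v in eqn.outs)]

    # The interface (invars and outvars) is returned unchanged.
    return invars, kept, outvars


eqns = [
    Eqn(("a",), "add", ("x", 1)),  # a = x + 1  (unused)
    Eqn(("b",), "mul", ("x", 2)),  # b = x * 2  (unused)
    Eqn(("y",), "add", ("x", 2)),  # y = x + 2
]
invars, kept, outvars = toy_dce(("x",), eqns, ("y",))
print([e.outs[0] for e in kept])  # ['y']
```

Because the backward walk visits consumers before producers, a chain such as `a = x + 1; c = a * 3` with `c` unused is removed entirely: `c` is never marked, so `a` is never marked either.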

Examples

>>> import jax
>>> import brainstate
>>> def f(x):
...     a = x + 1  # unused
...     b = x * 2  # unused
...     return x + 2
>>> original_jaxpr = jax.make_jaxpr(f)(1.0).jaxpr  # Before: equations for a, b, and y
>>> optimized_jaxpr = brainstate.transform.dead_code_elimination(original_jaxpr)  # After: y = x + 2