brainstate.transform.dead_code_elimination
- brainstate.transform.dead_code_elimination(jaxpr)
Remove equations whose outputs are not used (dead code elimination).
This optimization performs a backward pass to identify which variables are actually used, then removes every equation none of whose outputs are used. This reduces the number of computations and can improve performance.
- Parameters:
jaxpr (Jaxpr) – The input Jaxpr to optimize.
- Returns:
A new Jaxpr with dead code removed. All input and output variables are preserved.
- Return type:
Jaxpr
Notes
This optimization preserves all input and output variables to maintain the function interface. Only internal dead computations are eliminated.
The algorithm uses a two-phase approach:
1. Backward pass: mark all variables that are transitively used by the outputs.
2. Forward pass: keep only equations that produce marked variables.
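The two phases can be sketched in plain Python on a toy IR. The `Eqn` and `ToyJaxpr` classes and the `dce` function are hypothetical stand-ins for illustration, not brainstate's or JAX's actual types or implementation; the toy has no notion of side effects and omits literal operands. It assumes, as in a Jaxpr, that equations are listed in topological order, so one reverse sweep is enough to propagate liveness transitively.

```python
from dataclasses import dataclass


# Hypothetical toy IR standing in for a Jaxpr (not the real JAX/brainstate types).
@dataclass(frozen=True)
class Eqn:
    prim: str                 # primitive name, e.g. "add"
    invars: tuple[str, ...]   # variable operands only; literals are omitted
    outvars: tuple[str, ...]


@dataclass(frozen=True)
class ToyJaxpr:
    invars: tuple[str, ...]
    outvars: tuple[str, ...]
    eqns: tuple[Eqn, ...]     # assumed topologically ordered, as in a Jaxpr


def dce(jaxpr: ToyJaxpr) -> ToyJaxpr:
    # Phase 1 (backward): seed with the outputs, then walk equations in reverse.
    # If an equation produces any marked variable, mark all of its inputs too.
    used = set(jaxpr.outvars)
    for eqn in reversed(jaxpr.eqns):
        if any(v in used for v in eqn.outvars):
            used.update(eqn.invars)
    # Phase 2 (forward): keep only equations producing a marked variable,
    # preserving their original order.
    kept = tuple(e for e in jaxpr.eqns if any(v in used for v in e.outvars))
    # invars/outvars are copied unchanged, so the interface is preserved.
    return ToyJaxpr(jaxpr.invars, jaxpr.outvars, kept)


# Before: a = x + 1; b = x * 2; y = x + 2 (a and b unused)
before = ToyJaxpr(
    invars=("x",),
    outvars=("y",),
    eqns=(
        Eqn("add", ("x",), ("a",)),
        Eqn("mul", ("x",), ("b",)),
        Eqn("add", ("x",), ("y",)),
    ),
)
after = dce(before)
print([e.outvars for e in after.eqns])  # [('y',)]
```

On the Before/After example, only the equation producing `y` survives, while `invars` and `outvars` are carried over unchanged, mirroring the interface-preservation note above.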
Examples
>>> import jax
>>> import brainstate
>>> def f(x):
...     a = x + 1  # unused intermediate
...     b = x * 2  # unused intermediate
...     return x + 2
>>> original_jaxpr = jax.make_jaxpr(f)(1.0).jaxpr
>>> optimized_jaxpr = brainstate.transform.dead_code_elimination(original_jaxpr)