brainstate.transform.algebraic_simplification
- brainstate.transform.algebraic_simplification(jaxpr)
Apply algebraic identities to simplify arithmetic operations.
This optimization recognizes common algebraic identities and rewrites the matching operations into simpler forms. This removes redundant arithmetic and can expose further opportunities for later optimization passes.
- Parameters:
jaxpr (Jaxpr) – The input Jaxpr to optimize.
- Returns:
A new Jaxpr with algebraic simplifications applied. All input and output variables are preserved.
- Return type:
Jaxpr
Notes
This optimization preserves all input and output variables. When output variables are simplified, identity equations are added to maintain the correct interface.
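The output-preservation step can be pictured with a minimal sketch. It assumes a toy equation representation rather than brainstate's or JAX's actual data structures, and the function `preserve_outputs` and the identity primitive name `copy` are hypothetical. It also assumes simplification has already produced the surviving equations and a map from each eliminated variable to the operand that replaced it.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union

# Hypothetical toy IR: an operand is a variable name (str) or a numeric literal.
Operand = Union[str, int, float]


@dataclass(frozen=True)
class Eqn:
    out: str                   # variable defined by this equation
    prim: str                  # primitive name
    args: Tuple[Operand, ...]  # input operands


def preserve_outputs(
    kept: List[Eqn], env: Dict[str, Operand], outvars: List[str]
) -> List[Eqn]:
    """Re-bind every output variable whose defining equation was simplified away.

    `kept` holds the surviving equations; `env` maps each eliminated variable
    to its replacement operand. "copy" is a hypothetical identity primitive.
    """
    defined = {e.out for e in kept}
    result = list(kept)
    for v in outvars:
        # Input variables never appear in `env`, so they are left untouched.
        if v in env and v not in defined:
            result.append(Eqn(v, "copy", (env[v],)))  # v = copy(replacement)
            defined.add(v)  # emit at most one identity equation per output
    return result
```

For example, if simplifying a = x + 0; b = a * 1; c = b - 0 removes all three equations and maps a, b and c to x, then `preserve_outputs([], {"a": "x", "b": "x", "c": "x"}, ["c"])` returns a single equation c = copy(x), so the output variable c still exists.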
The following algebraic identities are recognized:
- Addition:
0 + x = x, x + 0 = x
- Subtraction:
x - 0 = x, x - x = 0
- Multiplication:
0 * x = 0, x * 0 = 0, 1 * x = x, x * 1 = x
- Division:
x / 1 = x, 0 / x = 0 (assuming x != 0)
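The identities above can be expressed as a small rule table. The following is a minimal sketch over a hypothetical toy IR: `Eqn`, `simplify` and `simplify_eqns` are illustrative names, not brainstate's implementation. It matches one binary equation at a time and forwards each eliminated variable to its replacement so that later equations see the simplified operand.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

# Hypothetical toy IR (not brainstate's or JAX's data structures):
# an operand is either a variable name (str) or a numeric literal.
Operand = Union[str, int, float]


@dataclass(frozen=True)
class Eqn:
    out: str      # variable defined by this equation
    prim: str     # "add", "sub", "mul" or "div"
    lhs: Operand
    rhs: Operand


def is_lit(v: Operand, value: float) -> bool:
    """True if `v` is a numeric literal equal to `value`."""
    return not isinstance(v, str) and v == value


def simplify(eqn: Eqn) -> Optional[Operand]:
    """Return the operand that can replace `eqn.out`, or None if no identity applies."""
    p, a, b = eqn.prim, eqn.lhs, eqn.rhs
    if p == "add":
        if is_lit(a, 0):
            return b                      # 0 + x = x
        if is_lit(b, 0):
            return a                      # x + 0 = x
    elif p == "sub":
        if is_lit(b, 0):
            return a                      # x - 0 = x
        if isinstance(a, str) and a == b:
            return 0                      # x - x = 0
    elif p == "mul":
        if is_lit(a, 0) or is_lit(b, 0):
            return 0                      # 0 * x = 0, x * 0 = 0
        if is_lit(a, 1):
            return b                      # 1 * x = x
        if is_lit(b, 1):
            return a                      # x * 1 = x
    elif p == "div":
        if is_lit(b, 1):
            return a                      # x / 1 = x
        if is_lit(a, 0):
            return 0                      # 0 / x = 0 (assumes x != 0)
    return None


def simplify_eqns(eqns: List[Eqn]) -> Tuple[List[Eqn], Dict[str, Operand]]:
    """Apply `simplify` in order, forwarding each eliminated variable to its replacement."""
    env: Dict[str, Operand] = {}          # eliminated variable -> replacement operand
    kept: List[Eqn] = []

    def resolve(v: Operand) -> Operand:
        return env.get(v, v) if isinstance(v, str) else v

    for e in eqns:
        e = Eqn(e.out, e.prim, resolve(e.lhs), resolve(e.rhs))
        replacement = simplify(e)
        if replacement is None:
            kept.append(e)                # no identity matched; keep rewritten equation
        else:
            env[e.out] = replacement      # drop the equation, forward later uses
    return kept, env
```

Running `simplify_eqns` on a = x + 0; b = a * 1; c = b - 0 removes all three equations and maps a, b and c to x. Note that the rules for x - x, 0 * x and x * 0 are only exact for finite values: under IEEE 754 floating point they yield NaN when x is NaN or infinite.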
Examples
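>>> import jax
>>> import jax.numpy as jnp
>>> from brainstate.transform import algebraic_simplification
>>> def f(x):  # Before: a = x + 0; b = a * 1; c = b - 0
...     return ((x + 0) * 1) - 0
>>> original_jaxpr = jax.make_jaxpr(f)(jnp.ones(3)).jaxpr
>>> optimized_jaxpr = algebraic_simplification(original_jaxpr)  # After: a = x; b = a; c = b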