brainstate.transform.algebraic_simplification

brainstate.transform.algebraic_simplification(jaxpr)

Apply algebraic identities to simplify arithmetic operations.

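This optimization recognizes common algebraic identities (listed under Notes) and rewrites matching equations into simpler forms, removing redundant arithmetic. A simplified equation typically reduces to a plain alias of an existing variable or to a constant, which can leave other equations unused and so open the way for further optimizations.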

Parameters:

jaxpr (Jaxpr) – The input Jaxpr to optimize.

Returns:

A new Jaxpr with algebraic simplifications applied. All input and output variables are preserved.

Return type:

Jaxpr

Notes

This optimization preserves all input and output variables. When the equation defining an output variable is simplified away, an identity equation is added so that the output variable remains defined and the Jaxpr's interface is unchanged.
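
The interface-preserving step can be illustrated on a toy equation list instead of a real Jaxpr. This is a minimal sketch, not brainstate's implementation: the tuple representation `(out, prim, lhs, rhs)`, the `alias` map, the `"copy"` primitive name, and the helpers `resolve` and `preserve_outputs` are all hypothetical. Equations that were simplified into aliases are dropped and their uses renamed; any output variable that was aliased away gets an explicit identity equation, so the list of outputs stays the same.

```python
from typing import Dict, List, Optional, Tuple, Union

# Hypothetical toy IR (not brainstate's Jaxpr): an operand is a variable name or a
# numeric literal, and an equation is (out_var, primitive, lhs, rhs).
Operand = Union[str, float]
Eqn = Tuple[str, str, Operand, Optional[Operand]]


def resolve(op: Optional[Operand], alias: Dict[str, Operand]) -> Optional[Operand]:
    """Follow alias chains such as c -> b -> a -> x to the surviving operand."""
    while isinstance(op, str) and op in alias:
        op = alias[op]
    return op


def preserve_outputs(
    eqns: List[Eqn], alias: Dict[str, Operand], outvars: List[str]
) -> List[Eqn]:
    """Drop aliased equations, rename their uses, and keep every output defined."""
    kept: List[Eqn] = []
    for out, prim, lhs, rhs in eqns:
        if out in alias:
            continue  # this equation was simplified into an alias of another operand
        kept.append((out, prim, resolve(lhs, alias), resolve(rhs, alias)))
    for v in outvars:
        if v in alias:
            # Identity equation: `v` stays defined, so the output list is unchanged.
            kept.append((v, "copy", resolve(v, alias), None))
    return kept


# a = x + 0; b = a * 1; c = b - 0, where c is the only output
eqns = [("a", "add", "x", 0.0), ("b", "mul", "a", 1.0), ("c", "sub", "b", 0.0)]
print(preserve_outputs(eqns, {"a": "x", "b": "a", "c": "b"}, outvars=["c"]))
# [('c', 'copy', 'x', None)]
```

Without the final identity equation, the simplified program would have no equation producing `c`, and any caller expecting `c` as an output would break.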

The following algebraic identities are recognized:

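Addition:
  • 0 + x = x
  • x + 0 = x

Subtraction:
  • x - 0 = x
  • x - x = 0

Multiplication:
  • 0 * x = 0
  • x * 0 = 0
  • 1 * x = x
  • x * 1 = x

Division:
  • x / 1 = x
  • 0 / x = 0 (assuming x != 0)

The identities that produce 0 (x - x, 0 * x, x * 0) are exact only for finite x: under IEEE 754 floating point, inf - inf, NaN - NaN, 0 * inf, and 0 * NaN all evaluate to NaN rather than 0.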
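
The identities above can be expressed as a small rule table. The following is a minimal, non-authoritative sketch on a hypothetical toy IR of `(out, prim, lhs, rhs)` tuples; the primitive names `add`, `sub`, `mul`, `div` and the helpers `is_const`, `rewrite`, and `simplify` are assumptions for illustration, not brainstate's API. Each rule either aliases the result to an existing operand (for example `x * 1` becomes `x`) or replaces it with the literal `0.0`. Aliases are recorded in program order, so later equations see already-simplified operands and a chain such as `a = x + 0; b = a * 1` collapses completely. Output variables are not handled here; an output that ends up in the alias map would still need an identity equation, as described in the Notes. A real Jaxpr pass would also need the substituted zero to match the dtype and shape of `x`.

```python
from typing import Dict, List, Optional, Tuple, Union

# Hypothetical toy IR: an operand is a variable name or a numeric literal.
Operand = Union[str, float]
Eqn = Tuple[str, str, Operand, Operand]  # (out_var, primitive, lhs, rhs)


def is_const(op: Operand, value: float) -> bool:
    """True if `op` is a numeric literal equal to `value`."""
    return not isinstance(op, str) and op == value


def rewrite(prim: str, a: Operand, b: Operand) -> Optional[Operand]:
    """Return the operand this equation simplifies to, or None if no identity applies."""
    if prim == "add":
        if is_const(a, 0):
            return b  # 0 + x = x
        if is_const(b, 0):
            return a  # x + 0 = x
    elif prim == "sub":
        if is_const(b, 0):
            return a  # x - 0 = x
        if isinstance(a, str) and a == b:
            return 0.0  # x - x = 0 (exact only for finite x)
    elif prim == "mul":
        if is_const(a, 0) or is_const(b, 0):
            return 0.0  # 0 * x = x * 0 = 0 (exact only for finite x)
        if is_const(a, 1):
            return b  # 1 * x = x
        if is_const(b, 1):
            return a  # x * 1 = x
    elif prim == "div":
        if is_const(b, 1):
            return a  # x / 1 = x
        if is_const(a, 0):
            return 0.0  # 0 / x = 0 (assuming x != 0)
    return None


def simplify(eqns: List[Eqn]) -> Tuple[List[Eqn], Dict[str, Operand]]:
    """Apply `rewrite` in program order; return surviving equations and the alias map."""
    alias: Dict[str, Operand] = {}
    kept: List[Eqn] = []
    for out, prim, a, b in eqns:
        # Substitute earlier simplifications. Alias targets are already resolved,
        # so a single lookup is enough.
        a = alias.get(a, a) if isinstance(a, str) else a
        b = alias.get(b, b) if isinstance(b, str) else b
        target = rewrite(prim, a, b)
        if target is None:
            kept.append((out, prim, a, b))
        else:
            alias[out] = target  # the equation disappears; its result is an alias
    return kept, alias


eqns = [("a", "add", "x", 0.0), ("b", "mul", "a", 1.0),
        ("c", "sub", "b", 0.0), ("d", "mul", "c", "y")]
kept, alias = simplify(eqns)
print(kept)   # [('d', 'mul', 'x', 'y')]
print(alias)  # {'a': 'x', 'b': 'x', 'c': 'x'}
```

Returning an operand instead of rewriting equations in place keeps each rule a pure function, which makes the rule table easy to extend with further identities.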

Examples

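>>> import jax
>>> from brainstate.transform import algebraic_simplification
>>> def f(x):
...     a = x + 0     # simplified to a = x
...     b = a * 1     # simplified to b = a
...     return b - 0  # output c = b - 0 becomes c = b
>>> closed_jaxpr = jax.make_jaxpr(f)(1.0)
>>> optimized_jaxpr = algebraic_simplification(closed_jaxpr.jaxpr)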