brainstate.transform.constant_fold
- brainstate.transform.constant_fold(jaxpr)
Perform constant folding optimization on a Jaxpr.
This optimization evaluates operations whose inputs are all constants at compile time and replaces each of them with its computed constant value. This removes that work from runtime and can expose further optimization opportunities.
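As an illustration only, the folding loop can be sketched on a toy, jaxpr-like IR in plain Python. This is not brainstate's implementation: `Lit`, the tuple equation format, the `PRIMS` table, and `fold` are hypothetical names introduced here, and the toy keeps any equation that defines an output variable as a simple stand-in for output preservation. Equations are visited in dependency order; when every input is a literal or an already-folded value, the equation is evaluated immediately and its result is substituted as a literal into later equations.

```python
import operator
from typing import Any, Dict, List, NamedTuple, Sequence, Tuple


class Lit(NamedTuple):
    """A constant input (toy stand-in for a JAX Literal)."""
    value: Any


# Toy equation: (primitive name, inputs, output variable name).
Eqn = Tuple[str, List[Any], str]

# Hypothetical primitive table for this toy IR; not JAX's primitive registry.
PRIMS = {"add": operator.add, "mul": operator.mul}


def fold(eqns: Sequence[Eqn], outvars: Sequence[str]) -> List[Eqn]:
    """Evaluate equations whose inputs are all constants; keep the rest."""
    known: Dict[str, Any] = {}  # variable name -> value computed while folding
    kept: List[Eqn] = []
    for prim, ins, out in eqns:  # equations are assumed to be in dependency order
        # Replace inputs whose value is already known with literals.
        ins = [Lit(known[a]) if isinstance(a, str) and a in known else a for a in ins]
        # Assumption: outputs are preserved by never folding their defining equations.
        if all(isinstance(a, Lit) for a in ins) and out not in outvars:
            known[out] = PRIMS[prim](*(a.value for a in ins))  # equation removed
        else:
            kept.append((prim, ins, out))
    return kept


# y = x + (2 + 3)
eqns = [("add", [Lit(2), Lit(3)], "a"), ("add", ["x", "a"], "y")]
print(fold(eqns, outvars=["y"]))  # [('add', ['x', Lit(value=5)], 'y')]
```

The intermediate variable `a` disappears from the result, and only the equation that still depends on the non-constant input `x` remains.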
- Parameters:
jaxpr (Jaxpr) – The input Jaxpr to optimize.
- Returns:
A new Jaxpr with constant expressions evaluated. The input and output variables are preserved.
- Return type:
Jaxpr
Notes
This optimization preserves the input and output variables of the jaxpr and modifies only the internal computation. Some primitives, such as ‘broadcast_in_dim’ and ‘broadcast’, are blacklisted and are never folded.
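The conditions in these Notes can be summarized as a per-equation predicate. The sketch below is a hypothetical helper on a toy IR (`Lit`, `can_fold`, and `BLACKLIST` are illustrative names, not brainstate API). In particular, reading "output variables are preserved" as "never fold an equation that defines a jaxpr output" is an assumption about how the preservation is achieved.

```python
from typing import Any, FrozenSet, NamedTuple, Sequence


class Lit(NamedTuple):
    """A constant input (toy stand-in for a JAX Literal)."""
    value: Any


# Primitives named in the Notes as never folded.
BLACKLIST: FrozenSet[str] = frozenset({"broadcast_in_dim", "broadcast"})


def can_fold(prim: str, inputs: Sequence[Any], outvar: str,
             jaxpr_outvars: Sequence[str]) -> bool:
    """Decide whether one toy equation may be replaced by its constant value."""
    return (
        all(isinstance(a, Lit) for a in inputs)  # every input is a constant
        and prim not in BLACKLIST                # primitive is not excluded
        and outvar not in jaxpr_outvars          # assumption: outputs stay as variables
    )


print(can_fold("add", [Lit(2), Lit(3)], "a", ["y"]))         # True
print(can_fold("broadcast_in_dim", [Lit(1.0)], "b", ["y"]))  # False: blacklisted
print(can_fold("add", ["x", Lit(5)], "y", ["y"]))            # False: 'x' is not constant
```

Because a blacklisted equation is left in place, its result is not a known constant, so equations that consume it are not folded either.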
Examples
>>> # Given a jaxpr that computes: y = x + (2 + 3)
>>> # After constant folding: y = x + 5
>>> optimized_jaxpr = constant_fold(original_jaxpr)