brainstate.transform.constant_fold



brainstate.transform.constant_fold(jaxpr)

Perform constant folding optimization on a Jaxpr.

This optimization evaluates, at compile time, every operation whose inputs are all constants (except blacklisted primitives; see Notes) and replaces it with its computed constant value. This reduces runtime computation and can expose opportunities for further optimizations.
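A minimal sketch of the idea on a toy IR, not brainstate's implementation. It assumes a jaxpr modeled as plain Python lists, where each equation is a hypothetical `(outvar, primitive, args)` tuple; `constant_fold_toy`, `PRIMS`, and the `"const"` pseudo-primitive are illustrative names only. Equations are visited in order. Once every argument of an equation is a known constant, the equation is evaluated immediately and dropped, and its value is substituted into later equations.

```python
import operator

# Hypothetical primitive table for this toy IR (not JAX's or brainstate's).
PRIMS = {"add": operator.add, "sub": operator.sub, "mul": operator.mul}


def constant_fold_toy(invars, eqns, outvars, blacklist=frozenset()):
    """Toy constant folding over a jaxpr-like list of equations.

    Each equation is (outvar, prim, args); an arg is either a variable
    name (str) or a numeric literal. Returns (invars, new_eqns, outvars).
    """
    consts = {}     # variable name -> constant value it was folded to
    new_eqns = []
    for out, prim, args in eqns:
        # Substitute the values of previously folded variables.
        args = [consts.get(a, a) if isinstance(a, str) else a for a in args]
        if prim not in blacklist and not any(isinstance(a, str) for a in args):
            # Every input is constant: evaluate now and drop the equation.
            consts[out] = PRIMS[prim](*args)
        else:
            new_eqns.append((out, prim, args))
    # Keep the output names: an output that folded away is re-bound to
    # its constant through a hypothetical "const" pseudo-equation.
    for out in outvars:
        if out in consts:
            new_eqns.append((out, "const", [consts[out]]))
    return invars, new_eqns, outvars


# y = x + (2 + 3)  ->  y = x + 5
eqns = [("a", "add", [2, 3]), ("y", "add", ["x", "a"])]
print(constant_fold_toy(["x"], eqns, ["y"]))
# (['x'], [('y', 'add', ['x', 5])], ['y'])
```

Passing a `blacklist` such as `{"broadcast"}` leaves those equations in place even when their inputs are constant, mirroring the behavior described in the Notes.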

Parameters:

jaxpr (Jaxpr) – The input Jaxpr to optimize.

Returns:

A new Jaxpr with constant expressions evaluated. The input and output variables are preserved.

Return type:

Jaxpr

Notes

This optimization preserves the input and output variables of the jaxpr and modifies only its internal equations. Some primitives, such as ‘broadcast_in_dim’ and ‘broadcast’, are blacklisted and are never folded, even when all of their inputs are constant.
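To see which equations of a traced function involve the primitives named above, one can list the primitive names in its jaxpr with plain JAX. `NAMED_BLACKLIST` below is an assumption containing only the two primitives mentioned in these Notes; brainstate's actual list may be longer. In current JAX versions, `jnp.ones(3)` is typically staged as a `broadcast_in_dim` of a literal, that is, an equation whose only input is constant but which, per the Notes, would still not be folded.

```python
import jax
import jax.numpy as jnp

# Assumption: only the two primitives named in the Notes; the real blacklist may differ.
NAMED_BLACKLIST = {"broadcast_in_dim", "broadcast"}


def f(x):
    # jnp.ones(3) is staged as a broadcast of the literal 1.0.
    return x * jnp.ones(3)


jaxpr = jax.make_jaxpr(f)(jnp.zeros(3)).jaxpr   # ClosedJaxpr -> Jaxpr
prims = [eqn.primitive.name for eqn in jaxpr.eqns]
skipped = [p for p in prims if p in NAMED_BLACKLIST]
print("primitives:", prims)
print("never folded:", skipped)
```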

Examples

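>>> import jax
>>> import jax.numpy as jnp
>>> from brainstate.transform import constant_fold
>>> # y = x + (2 + 3); jnp.add keeps the constant subexpression in the traced jaxpr
>>> original_jaxpr = jax.make_jaxpr(lambda x: x + jnp.add(2.0, 3.0))(1.0).jaxpr
>>> optimized_jaxpr = constant_fold(original_jaxpr)  # now computes y = x + 5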