# Source code for brainstate.transform._ir_optim

# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from collections.abc import MutableSet
from typing import Union, Sequence

import jax
import numpy as np
from jax import lax
from jax._src.core import JaxprEqnContext
from jax.extend import source_info_util

from brainstate._compatible_import import (Literal, Var, Jaxpr, ClosedJaxpr, JaxprEqn)

# Public API of this module: the individual IR passes and the pass driver.
__all__ = [
    # individual optimization passes
    'constant_fold',
    'dead_code_elimination',
    'common_subexpression_elimination',
    'copy_propagation',
    'algebraic_simplification',
    # driver that chains the passes
    'optimize_jaxpr',
]


def _fallback_source_info(eqns: Sequence[JaxprEqn]) -> source_info_util.SourceInfo:
    """
    Choose a source location for an equation synthesized by an optimization pass.

    Parameters
    ----------
    eqns : sequence of JaxprEqn
        Equations emitted so far; only the last one is inspected.

    Returns
    -------
    SourceInfo
        The ``source_info`` of the last equation if it is set, otherwise a
        freshly created source info from ``source_info_util.new_source_info``.
    """
    if eqns:
        last_info = eqns[-1].source_info
        if last_info is not None:
            return last_info
    return source_info_util.new_source_info()


def _assign_literal(
    literal: Literal,
    outvar: Var,
    source_info: source_info_util.SourceInfo
) -> JaxprEqn:
    """
    Build an equation that defines ``outvar`` as a copy of ``literal``.

    The assignment is expressed as ``convert_element_type`` to ``outvar``'s own
    dtype, i.e. a conversion that leaves the dtype unchanged. Passes in this
    module also use it with a plain ``Var`` as the first argument to alias one
    variable to another.

    Parameters
    ----------
    literal : Literal
        Value (or variable) being assigned.
    outvar : Var
        Variable defined by the new equation.
    source_info : SourceInfo
        Source location attached to the new equation.

    Returns
    -------
    JaxprEqn
        ``outvar = convert_element_type[new_dtype=outvar.dtype] literal`` with an
        empty effect set.
    """
    conversion_params = dict(
        new_dtype=outvar.aval.dtype,
        weak_type=False,
        sharding=None,
    )
    eqn_context = JaxprEqnContext(None, True)
    return JaxprEqn(
        [literal],
        [outvar],
        lax.convert_element_type_p,
        conversion_params,
        set(),
        source_info,
        eqn_context,
    )


def _preserve_invars_outvars(result: Jaxpr, jaxpr: Jaxpr):
    """
    Give a rewritten jaxpr the exact interface (invars/outvars) of the original.

    A rewrite such as constant folding may replace an output of ``jaxpr`` by a
    ``Literal`` (or by another variable) in ``result.outvars``. For every such
    output, an assignment equation is appended so that the ORIGINAL output
    variable is defined again. Then the original ``invars`` and ``outvars`` are
    restored.

    Parameters
    ----------
    result : Jaxpr
        The rewritten jaxpr whose ``outvars`` line up positionally with
        ``jaxpr.outvars``.
    jaxpr : Jaxpr
        The original jaxpr whose interface must be kept.

    Returns
    -------
    Jaxpr
        ``result`` with the extra assignment equations and the original
        ``invars``/``outvars``.
    """
    eqns = list(result.eqns)
    # ids of original outputs that already received an assignment. The same
    # variable may appear several times in ``outvars`` (e.g. ``return y, y``)
    # but must be defined by exactly one equation.
    assigned = set()
    for new_out, old_out in zip(result.outvars, jaxpr.outvars):
        if not isinstance(old_out, Var) or new_out is old_out:
            # Literal outputs need no definition; unchanged outputs are still
            # defined by their original equation.
            continue
        if id(old_out) in assigned:
            continue
        assigned.add(id(old_out))
        eqns.append(_assign_literal(new_out, old_out, _fallback_source_info(eqns)))
    return result.replace(eqns=eqns, invars=jaxpr.invars, outvars=jaxpr.outvars)


class IdentitySet(MutableSet):
    """
    Mutable set whose membership is decided by object identity.

    Elements are keyed on ``id(obj)`` instead of ``hash(obj)``/``==``. Two
    distinct objects that compare equal are therefore different elements, and
    unhashable objects such as lists can be stored.

    Notes
    -----
    Like every ``collections.abc.Set``, instances are not hashable, so they
    cannot be used as dictionary keys or as elements of another set.

    Examples
    --------
    >>> s = IdentitySet()
    >>> a = [1, 2, 3]
    >>> b = [1, 2, 3]
    >>> s.add(a)
    >>> a in s
    True
    >>> b in s  # equal to ``a`` but a different object
    False
    """
    __module__ = 'brainstate.transform'

    def __init__(self, iterable=None):
        # Maps id(element) -> element. Holding a reference to the element keeps
        # it alive, so its id cannot be reused by another object meanwhile.
        self._data = {}
        if iterable is None:
            return
        self.update(iterable)

    def __contains__(self, value):
        key = id(value)
        return key in self._data

    def __iter__(self):
        stored = self._data.values()
        return iter(stored)

    def __len__(self):
        return len(self._data)

    def add(self, value):
        key = id(value)
        self._data[key] = value

    def discard(self, value):
        key = id(value)
        if key in self._data:
            del self._data[key]

    def update(self, iterable):
        """
        Insert every element of ``iterable``.

        Parameters
        ----------
        iterable : iterable
            Objects to add; each is inserted via :meth:`add`.
        """
        for element in iterable:
            self.add(element)

    def __repr__(self):
        shown = [repr(element) for element in self._data.values()]
        return f"IdentitySet({shown})"

    def __str__(self):
        shown = [str(element) for element in self._data.values()]
        return f"IdentitySet({shown})"


# Primitive names that constant folding leaves in place even when every input is
# known. NOTE(review): presumably this avoids materializing broadcast results as
# (potentially large) literals in the jaxpr -- confirm the original motivation.
_constant_fold_blacklist = {'broadcast', 'broadcast_in_dim'}


def _partial_eval_jaxpr(jaxpr, env):
    """
    Partially evaluate ``jaxpr`` given concrete values for some variables.

    Equations whose inputs are all known are evaluated eagerly, and their
    outputs are recorded in a local copy of ``env``. All other equations are
    kept, with known inputs replaced by ``Literal`` values.

    An equation is never folded when:

    * its primitive is in ``_constant_fold_blacklist``;
    * it carries effects, because evaluating it here would run the side effect
      at optimization time and then drop it from the program;
    * it is a ``closed_call`` whose body does not reduce to constants only.
      Inlining a partially evaluated body would leave references to inner
      variables the outer jaxpr never defines, so the call is kept intact.

    Parameters
    ----------
    jaxpr : Jaxpr
        The jaxpr to evaluate.
    env : dict
        Maps ``Var`` objects to known concrete values. It is copied, not mutated.

    Returns
    -------
    Jaxpr
        The remaining equations. ``invars`` is restricted to the input
        variables still referenced by an equation, and ``outvars`` entries with
        known values are replaced by ``Literal`` objects.
    """
    env = env.copy()
    new_eqns = []

    def read(var):
        # Known concrete value of ``var``, or ``None`` when it is unknown.
        if isinstance(var, Literal):
            return var.val
        return env.get(var, None)

    def read_or_self(var):
        # ``var`` itself when unknown, otherwise a ``Literal`` with its value.
        out = read(var)
        if out is None:
            return var
        if isinstance(out, Var):
            return out
        if isinstance(out, Literal):
            return Literal(out.val, var.aval)
        assert not isinstance(out, Jaxpr)
        return Literal(out, var.aval)

    for eqn in jaxpr.eqns:
        vals = [read(var) for var in eqn.invars]
        if (eqn.primitive.name in _constant_fold_blacklist
                or eqn.effects
                or any(val is None for val in vals)):
            new_eqns.append(eqn)
            continue

        out = _eval_eqn(eqn, vals)
        if isinstance(out, Jaxpr):
            # ``closed_call``: ``out`` is the partially evaluated callee body.
            # Fold only if every result is a constant; otherwise keep the call.
            if not all(isinstance(v, Literal) for v in out.outvars):
                new_eqns.append(eqn)
                continue
            out = [v.val for v in out.outvars]
        elif not isinstance(out, (tuple, list)):
            out = (out,)

        for var, val in zip(eqn.outvars, out):
            assert not isinstance(val, Jaxpr)
            env[var] = val.val if isinstance(val, Literal) else val

    # Substitute the now-known values into the surviving equations.
    out_eqns = [
        eqn.replace(invars=tuple(read_or_self(var) for var in eqn.invars))
        for eqn in new_eqns
    ]

    invars_still_used = IdentitySet()
    for eqn in out_eqns:
        for var in eqn.invars:
            if not isinstance(var, Literal):
                invars_still_used.add(var)
    invars = tuple(var for var in jaxpr.invars if var in invars_still_used)

    # Substitute known constants for the outputs.
    outvars = tuple(read_or_self(var) for var in jaxpr.outvars)

    return jaxpr.replace(eqns=out_eqns, outvars=outvars, invars=invars)


def _eval_eqn(eqn, vals) -> Union[Jaxpr, tuple, list, jax.Array]:
    """
    Evaluate a single equation on concrete input values.

    Parameters
    ----------
    eqn : JaxprEqn
        The equation to evaluate.
    vals : list
        Concrete values for ``eqn.invars``, in order.

    Returns
    -------
    Jaxpr or tuple or list or jax.Array
        For ``closed_call``, the callee body partially evaluated with ``vals``
        bound to its inputs (a ``Jaxpr``). For every other primitive
        (``scan`` included), the result of ``primitive.bind``.
    """
    prim = eqn.primitive
    if prim.name != "closed_call":
        return prim.bind(*vals, **eqn.params)

    assert prim.call_primitive
    assert not prim.map_primitive
    body = eqn.params['call_jaxpr'].jaxpr
    return _partial_eval_jaxpr(body, dict(zip(body.invars, vals)))


def constant_fold(jaxpr: Jaxpr) -> Jaxpr:
    """
    Fold equations whose inputs are all constants.

    Every equation whose inputs are known at optimization time is evaluated
    once, and its outputs are substituted as literals into the rest of the
    computation. This removes the work from runtime and can expose further
    simplifications.

    Parameters
    ----------
    jaxpr : Jaxpr
        The input Jaxpr to optimize.

    Returns
    -------
    Jaxpr
        A Jaxpr with constant sub-computations evaluated. Its input and output
        variables are the same objects as those of ``jaxpr``.

    Notes
    -----
    Only the internal computation changes; the interface is restored
    afterwards. Primitives listed in the blacklist (``'broadcast_in_dim'`` and
    ``'broadcast'``) are never folded.

    Examples
    --------
    >>> # Given a jaxpr that computes: y = x + (2 + 3)
    >>> # After constant folding: y = x + 5
    >>> optimized_jaxpr = constant_fold(original_jaxpr)
    """
    folded = _partial_eval_jaxpr(jaxpr, {})
    return _preserve_invars_outvars(folded, jaxpr)
def dead_code_elimination(jaxpr: Jaxpr) -> Jaxpr:
    """
    Remove equations whose outputs are never used (dead code elimination).

    A single backward pass over the equations marks every variable that is
    transitively needed by the jaxpr's outputs. Equations are kept when they
    produce a needed variable or when they carry effects.

    Parameters
    ----------
    jaxpr : Jaxpr
        The input Jaxpr to optimize.

    Returns
    -------
    Jaxpr
        A Jaxpr with dead equations removed. All input and output variables are
        preserved.

    Notes
    -----
    * All input and output variables are kept, so the function interface is
      unchanged.
    * Equations with a non-empty ``effects`` set are always kept, even if none
      of their outputs is used (they may have no outputs at all). Removing them
      would silently drop their side effects.
    * Jaxpr equations are in definition-before-use order, so when an equation is
      visited in reverse order, all of its consumers have already been visited.
      One backward pass therefore reaches the same fixed point as repeated
      passes would.

    Examples
    --------
    >>> # Given a jaxpr with unused intermediate computations
    >>> # Before: a = x + 1; b = x * 2; y = x + 2 (a and b unused)
    >>> # After: y = x + 2
    >>> optimized_jaxpr = dead_code_elimination(original_jaxpr)
    """
    # Outputs are the roots of liveness. Inputs are added too; they define the
    # interface (and are never produced by an equation anyway).
    used_vars = IdentitySet(jaxpr.outvars)
    used_vars.update(jaxpr.invars)

    kept_reversed = []
    for eqn in reversed(jaxpr.eqns):
        if eqn.effects or any(outvar in used_vars for outvar in eqn.outvars):
            kept_reversed.append(eqn)
            used_vars.update(v for v in eqn.invars if not isinstance(v, Literal))

    new_eqns = kept_reversed[::-1]
    # Keep all input and output variables unchanged.
    return jaxpr.replace(eqns=new_eqns, invars=jaxpr.invars, outvars=jaxpr.outvars)
def common_subexpression_elimination(jaxpr: Jaxpr) -> Jaxpr:
    """
    Eliminate redundant computations by reusing results (CSE).

    Equations that apply the same primitive, with the same parameters, to the
    same input variables are computed once. Later duplicates are removed, and
    their outputs are redirected to the first occurrence.

    Parameters
    ----------
    jaxpr : Jaxpr
        The input Jaxpr to optimize.

    Returns
    -------
    Jaxpr
        A Jaxpr with common subexpressions eliminated. All input and output
        variables are preserved.

    Notes
    -----
    Two equations are considered identical if they have:

    - the same primitive object,
    - the same input variables (compared by identity, after earlier
      replacements),
    - equal parameters.

    The following equations are never merged:

    - equations with a non-empty ``effects`` set, because each occurrence must
      run;
    - equations whose parameters contain unhashable values, since they cannot be
      compared.

    When an output variable of the jaxpr is eliminated, a single identity
    equation (``convert_element_type`` to the same dtype) re-defines it, so the
    interface is unchanged.

    Examples
    --------
    >>> # Given a jaxpr with duplicate computations
    >>> # Before: a = x + y; b = x * 2; c = x + y (c duplicates a)
    >>> # After: a = x + y; b = x * 2; c = a
    >>> optimized_jaxpr = common_subexpression_elimination(original_jaxpr)
    """
    # (primitive, ids of canonical inputs, params) -> outvars of first occurrence
    expr_cache = {}
    # eliminated output variable -> equivalent variable computed earlier
    var_map = {}

    def get_var(var):
        """Get the canonical variable, following replacements."""
        if isinstance(var, Literal):
            return var
        return var_map.get(var, var)

    def make_key(eqn):
        """Build a key describing ``eqn``; ``eqn.invars`` must be canonical."""
        invars_ids = tuple(id(v) for v in eqn.invars)
        # Param names are unique, so sorting never compares param values.
        param_items = tuple(sorted(eqn.params.items()))
        return (eqn.primitive, invars_ids, param_items)

    new_eqns = []
    for eqn in jaxpr.eqns:
        eqn = eqn.replace(invars=tuple(get_var(v) for v in eqn.invars))

        if eqn.effects:
            # Side-effecting equations must execute as many times as written.
            new_eqns.append(eqn)
            continue

        key = make_key(eqn)
        try:
            prev_outvars = expr_cache.get(key)
        except TypeError:
            # Some parameter value is unhashable; treat the equation as unique.
            new_eqns.append(eqn)
            continue

        if prev_outvars is not None and len(prev_outvars) == len(eqn.outvars):
            # Reuse the earlier result.
            for old_var, new_var in zip(eqn.outvars, prev_outvars):
                var_map[old_var] = new_var
        else:
            new_eqns.append(eqn)
            expr_cache[key] = eqn.outvars

    # Re-define eliminated output variables so the interface is preserved.
    # Each output gets at most one definition, even if it is listed repeatedly.
    final_eqns = list(new_eqns)
    assigned = set()
    for outvar in jaxpr.outvars:
        if isinstance(outvar, Literal) or id(outvar) in assigned:
            continue
        canonical = get_var(outvar)
        if canonical is outvar:
            continue
        assigned.add(id(outvar))
        # ``_assign_literal`` builds the same-dtype ``convert_element_type`` with
        # the complete parameter set used elsewhere in this module.
        final_eqns.append(
            _assign_literal(canonical, outvar, _fallback_source_info(new_eqns))
        )

    # Keep original outvars and invars.
    return jaxpr.replace(eqns=final_eqns, outvars=jaxpr.outvars, invars=jaxpr.invars)
def copy_propagation(jaxpr: Jaxpr) -> Jaxpr:
    """
    Eliminate unnecessary copy operations by propagating original variables.

    When a value is merely copied via an identity operation (``copy``,
    ``device_put``, or a dtype-preserving ``convert_element_type``), later uses
    of the copy are rewritten to use the original value. The copy equation is
    then dropped.

    Parameters
    ----------
    jaxpr : Jaxpr
        The input Jaxpr to optimize.

    Returns
    -------
    Jaxpr
        A Jaxpr with copies propagated. All input and output variables are
        preserved.

    Notes
    -----
    Copies that define one of the jaxpr's output variables are kept, so the
    interface stays intact.

    The following operations are considered identity operations:

    - ``copy``: always, with one input and one output;
    - ``device_put``: always, with one input and one output;
    - ``convert_element_type``: only when the input dtype equals ``new_dtype``.

    The propagated value may be a ``Literal``.

    Examples
    --------
    >>> # Given a jaxpr with unnecessary copies
    >>> # Before: a = copy(x); b = a + 1; c = copy(b)
    >>> # After: b = x + 1; c = copy(b)
    >>> optimized_jaxpr = copy_propagation(original_jaxpr)
    """
    # eliminated copy output -> the value it copies
    var_map = {}
    # ids of the jaxpr outputs; copies defining them must be kept.
    protected_ids = {id(v) for v in jaxpr.outvars}

    def get_canonical(var):
        """Follow the chain of eliminated copies back to the original value."""
        seen = set()
        # Stop at Literals. They are never keys of ``var_map``, and they may not
        # be hashable, so they must not be looked up in the dict.
        while (not isinstance(var, Literal)
               and id(var) not in seen
               and var in var_map):
            seen.add(id(var))
            var = var_map[var]
        return var

    def is_identity_op(eqn, invars):
        """Whether ``eqn`` (with canonicalized ``invars``) just forwards its input."""
        name = eqn.primitive.name
        if name not in ('copy', 'device_put', 'convert_element_type'):
            return False
        if len(invars) != 1 or len(eqn.outvars) != 1:
            return False
        if name != 'convert_element_type':
            return True
        invar, outvar = invars[0], eqn.outvars[0]
        # A conversion is only an identity when it keeps the dtype.
        return (hasattr(invar, 'aval') and hasattr(outvar, 'aval')
                and invar.aval.dtype == eqn.params.get('new_dtype'))

    new_eqns = []
    for eqn in jaxpr.eqns:
        new_invars = tuple(get_canonical(v) for v in eqn.invars)
        if is_identity_op(eqn, new_invars) and id(eqn.outvars[0]) not in protected_ids:
            var_map[eqn.outvars[0]] = new_invars[0]
            continue
        new_eqns.append(eqn.replace(invars=new_invars))

    # Keep all input and output variables unchanged.
    return jaxpr.replace(eqns=new_eqns, invars=jaxpr.invars, outvars=jaxpr.outvars)
def algebraic_simplification(jaxpr: Jaxpr) -> Jaxpr:
    """
    Apply algebraic identities to simplify arithmetic operations.

    Binary ``add``/``sub``/``mul``/``div`` equations with a literal ``0`` or ``1``
    operand, or with identical operands, are replaced by their known result.
    Later uses are redirected to that result.

    Parameters
    ----------
    jaxpr : Jaxpr
        The input Jaxpr to optimize.

    Returns
    -------
    Jaxpr
        A Jaxpr with algebraic simplifications applied. All input and output
        variables are preserved.

    Notes
    -----
    When an output variable is simplified away, an identity equation
    re-defines it (once, even if it is listed repeatedly), so the interface is
    unchanged.

    The following identities are applied when the surviving operand has the
    same shape and dtype as the result:

    - ``0 + x = x``, ``x + 0 = x``
    - ``x - 0 = x``
    - ``1 * x = x``, ``x * 1 = x``
    - ``x / 1 = x``

    The following value-producing rewrites are applied only to scalar results:

    - ``0 * x = 0``, ``x * 0 = 0``: integer or boolean dtypes only;
    - ``x - x = 0``: integer dtypes only;
    - ``0 / x = 0``: integer dtypes only, assuming ``x != 0``.

    These rewrites are never applied to floating-point results, because
    ``inf * 0``, ``inf - inf`` and ``0 / 0`` are NaN, so the rewrite would change
    the result. The zero literal is created with the result's exact dtype.

    Examples
    --------
    >>> # Given a jaxpr with algebraic simplifications
    >>> # Before: a = x + 0; b = a * 1; c = b - 0
    >>> # After: a = x; b = a; c = b
    >>> optimized_jaxpr = algebraic_simplification(original_jaxpr)
    """
    # simplified output variable -> value that replaces it
    var_map = {}

    def get_var(var):
        """Get the canonical variable."""
        if isinstance(var, Literal):
            return var
        return var_map.get(var, var)

    def is_constant_value(var, value):
        """Check if ``var`` is a scalar literal equal to ``value``."""
        if not isinstance(var, Literal):
            return False
        val = var.val
        try:
            if isinstance(val, (int, float, complex)):
                return val == value
            if hasattr(val, '__array__'):
                arr = np.asarray(val)
                return arr.shape == () and arr.item() == value
        except Exception:
            # Best effort: a literal we cannot inspect is "not this constant".
            pass
        return False

    def is_zero(var):
        return is_constant_value(var, 0)

    def is_one(var):
        return is_constant_value(var, 1)

    def same_type(var, outvar):
        # An operand may only stand in for the result if shape and dtype agree.
        return (var.aval.shape == outvar.aval.shape
                and var.aval.dtype == outvar.aval.dtype)

    def is_scalar_of(aval, *kinds):
        # True for a 0-d result whose dtype is a subtype of one of ``kinds``.
        if aval.shape != ():
            return False
        return any(jax.dtypes.issubdtype(aval.dtype, kind) for kind in kinds)

    def zero_literal(aval):
        # Zero of exactly ``aval``'s dtype. A bare Python ``0`` would not match
        # the result dtype when the literal is consumed downstream.
        return Literal(np.zeros((), dtype=aval.dtype)[()], aval)

    def simplify(name, lhs, rhs, outvar):
        """Return the value replacing ``outvar``, or ``None`` to keep the equation."""
        aval = outvar.aval
        if name == 'add':
            if is_zero(lhs) and same_type(rhs, outvar):  # 0 + x = x
                return rhs
            if is_zero(rhs) and same_type(lhs, outvar):  # x + 0 = x
                return lhs
        elif name == 'sub':
            if is_zero(rhs) and same_type(lhs, outvar):  # x - 0 = x
                return lhs
            if lhs is rhs and is_scalar_of(aval, np.integer):  # x - x = 0
                return zero_literal(aval)
        elif name == 'mul':
            if ((is_zero(lhs) or is_zero(rhs))
                    and is_scalar_of(aval, np.integer, np.bool_)):  # x * 0 = 0
                return zero_literal(aval)
            if is_one(lhs) and same_type(rhs, outvar):  # 1 * x = x
                return rhs
            if is_one(rhs) and same_type(lhs, outvar):  # x * 1 = x
                return lhs
        elif name == 'div':
            if is_one(rhs) and same_type(lhs, outvar):  # x / 1 = x
                return lhs
            if is_zero(lhs) and is_scalar_of(aval, np.integer):  # 0 / x = 0
                return zero_literal(aval)
        return None

    new_eqns = []
    for eqn in jaxpr.eqns:
        canonical_invars = tuple(get_var(v) for v in eqn.invars)
        replacement = None
        if len(canonical_invars) >= 2 and len(eqn.outvars) == 1:
            replacement = simplify(
                eqn.primitive.name, canonical_invars[0], canonical_invars[1], eqn.outvars[0]
            )
        if replacement is None:
            new_eqns.append(eqn.replace(invars=canonical_invars))
        else:
            var_map[eqn.outvars[0]] = replacement

    # Re-define simplified output variables so the interface is preserved.
    final_eqns = list(new_eqns)
    assigned = set()
    for outvar in jaxpr.outvars:
        if isinstance(outvar, Literal) or id(outvar) in assigned:
            continue
        canonical = get_var(outvar)
        if canonical is outvar:
            continue
        assigned.add(id(outvar))
        final_eqns.append(_assign_literal(canonical, outvar, _fallback_source_info(final_eqns)))

    # Keep original outvars and invars.
    return jaxpr.replace(eqns=final_eqns, outvars=jaxpr.outvars, invars=jaxpr.invars)
def optimize_jaxpr(
    jaxpr: Jaxpr | ClosedJaxpr,
    max_iterations: int = 3,
    optimizations: Sequence[str] | str | None = None,
    verbose: bool = False,
) -> Jaxpr | ClosedJaxpr:
    """
    Apply multiple optimization passes to a Jaxpr.

    The selected passes are applied in order, repeatedly, until an iteration no
    longer changes the number of equations or ``max_iterations`` is reached.
    All passes preserve the function's interface (input and output variables).

    Parameters
    ----------
    jaxpr : Jaxpr or ClosedJaxpr
        The input Jaxpr or ClosedJaxpr to optimize.
    max_iterations : int, optional
        Maximum number of optimization iterations. Default is 3.
    optimizations : sequence of str, str, or None, optional
        Names of the passes to apply, in order:

        - ``None`` (default): no passes are applied; only the interface check
          is performed.
        - ``'all'``: the full recommended pipeline ``constant_fold``,
          ``algebraic_simplification``, ``copy_propagation``, ``cse``, ``dce``.
        - any other string: that single pass.
        - a sequence of names: exactly those passes, in that order.
    verbose : bool, optional
        If True, print progress information, including equation counts and
        reduction statistics. Default is False.

    Returns
    -------
    Jaxpr or ClosedJaxpr
        The optimized program, of the same type as the input. For a
        ``ClosedJaxpr`` input, the original constants are re-attached.

    Raises
    ------
    TypeError
        If the input is not a Jaxpr or ClosedJaxpr.
    ValueError
        If any name in ``optimizations`` is not an available pass.
    RuntimeError
        If the input or output variables change during optimization (this
        indicates a bug in one of the passes).

    Notes
    -----
    Available optimizations:

    - **constant_fold**: Evaluate constant expressions at compile time
    - **algebraic_simplification**: Apply algebraic identities (x+0=x, x*1=x, etc.)
    - **copy_propagation**: Eliminate unnecessary copy operations
    - **cse**: Common subexpression elimination (reuse identical computations)
    - **dce**: Dead code elimination (remove unused equations)

    The optimization process iterates until one of the following holds:

    1. An iteration leaves the equation count unchanged (convergence).
    2. The maximum number of iterations is reached.

    Examples
    --------
    Apply the full pipeline:

    .. code-block:: python

        >>> optimized = optimize_jaxpr(jaxpr, optimizations='all')

    Use more iterations for aggressive optimization:

    .. code-block:: python

        >>> optimized = optimize_jaxpr(jaxpr, optimizations='all', max_iterations=5)

    Run only specific optimizations:

    .. code-block:: python

        >>> optimized = optimize_jaxpr(jaxpr, optimizations=['constant_fold', 'dce'])

    Enable verbose output to see optimization progress:

    .. code-block:: python

        >>> optimized = optimize_jaxpr(jaxpr, optimizations='all', verbose=True)
        Starting optimization with 50 equations
        Optimization sequence: constant_fold -> algebraic_simplification -> ...
        Max iterations: 3
        ------------------------------------------------------------
        Iteration 1:
          constant_fold: 50 -> 45 equations (+5)
          algebraic_simplification: 45 -> 42 equations (+3)
          dce: 42 -> 38 equations (+4)
        ...
        ------------------------------------------------------------
        Optimization complete:
          Initial equations: 50
          Final equations: 38
          Reduction: 12 (24.0%)

    Custom optimization pipeline:

    .. code-block:: python

        >>> # First fold constants, then eliminate dead code
        >>> stage1 = optimize_jaxpr(jaxpr, optimizations=['constant_fold', 'dce'])
        >>> # Then apply CSE and more DCE
        >>> stage2 = optimize_jaxpr(stage1, optimizations=['cse', 'dce'])
    """
    # Full pipeline in the recommended order, selected with ``'all'``.
    all_optimizations = [
        'constant_fold',
        'algebraic_simplification',
        'copy_propagation',
        'cse',
        'dce',
    ]

    # Normalize ``optimizations`` to a list of pass names. ``None`` deliberately
    # selects no passes; use ``'all'`` for the full pipeline.
    if optimizations is None:
        optimizations = []
    elif isinstance(optimizations, str):
        optimizations = list(all_optimizations) if optimizations == 'all' else [optimizations]
    else:
        # Materialize once: the names are iterated several times below.
        optimizations = list(optimizations)

    # Parse input
    if isinstance(jaxpr, Jaxpr):
        closed_jaxpr = None
    elif isinstance(jaxpr, ClosedJaxpr):
        closed_jaxpr = jaxpr
        jaxpr = jaxpr.jaxpr
    else:
        raise TypeError(f'Expected Jaxpr or ClosedJaxpr, got {type(jaxpr)}')

    # Store original interface
    invars_before = tuple(jaxpr.invars)
    outvars_before = tuple(jaxpr.outvars)
    initial_eqns = len(jaxpr.eqns)

    optimization_map = {
        'constant_fold': constant_fold,
        'algebraic_simplification': algebraic_simplification,
        'copy_propagation': copy_propagation,
        'cse': common_subexpression_elimination,
        'dce': dead_code_elimination,
    }

    # Validate optimization names
    invalid_opts = set(optimizations) - set(optimization_map)
    if invalid_opts:
        available = ', '.join(sorted(optimization_map))
        invalid = ', '.join(sorted(str(name) for name in invalid_opts))
        raise ValueError(
            f"Invalid optimization(s): {invalid}. "
            f"Available optimizations: {available}"
        )

    if verbose:
        print(f"Starting optimization with {initial_eqns} equations")
        print(f"Optimization sequence: {' -> '.join(optimizations)}")
        print(f"Max iterations: {max_iterations}")
        print("-" * 60)

    # Apply optimization iterations
    for iteration in range(max_iterations):
        prev_num_eqns = len(jaxpr.eqns)

        if verbose:
            print(f"\nIteration {iteration + 1}:")

        for opt_name in optimizations:
            opt_func = optimization_map[opt_name]
            prev_eqns = len(jaxpr.eqns)
            jaxpr = opt_func(jaxpr)
            current_eqns = len(jaxpr.eqns)

            if verbose and current_eqns != prev_eqns:
                reduction = prev_eqns - current_eqns
                print(f"  {opt_name}: {prev_eqns} -> {current_eqns} equations "
                      f"({reduction:+d})")

        # Converged when a whole iteration leaves the equation count unchanged.
        if len(jaxpr.eqns) == prev_num_eqns:
            if verbose:
                print(f"\nConverged after {iteration + 1} iteration(s)")
            break
    else:
        if verbose:
            print(f"\nReached max iterations ({max_iterations})")

    # Final statistics
    final_eqns = len(jaxpr.eqns)
    if verbose:
        removed = initial_eqns - final_eqns
        # Guard against an empty input jaxpr (no equations at all).
        percent = 100 * removed / initial_eqns if initial_eqns else 0.0
        print("-" * 60)
        print("Optimization complete:")
        print(f"  Initial equations: {initial_eqns}")
        print(f"  Final equations: {final_eqns}")
        print(f"  Reduction: {removed} ({percent:.1f}%)")

    # Validate that the interface is preserved
    invars_after = tuple(jaxpr.invars)
    outvars_after = tuple(jaxpr.outvars)
    if invars_before != invars_after:
        raise RuntimeError(
            f'Input variables changed during optimization. '
            f'Before: {len(invars_before)}, After: {len(invars_after)}'
        )
    if outvars_before != outvars_after:
        raise RuntimeError(
            f'Output variables changed during optimization. '
            f'Before: {len(outvars_before)}, After: {len(outvars_after)}'
        )

    # Restore ClosedJaxpr if needed
    if closed_jaxpr is not None:
        jaxpr = ClosedJaxpr(jaxpr, closed_jaxpr.consts)
    return jaxpr