Source code for brainstate.transform._ir_processing

# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Sequence, Dict, List, Set
from collections import defaultdict
from jax.extend.core.primitives import dot_general_p, conv_general_dilated_p

from brainstate._compatible_import import is_jit_primitive, JaxprEqn, Jaxpr, ClosedJaxpr, Var, Literal
from brainstate._state import State


__all__ = [
    'eqns_to_closed_jaxpr',
    'eqns_to_jaxpr',
]


def eqns_to_jaxpr(
    eqns: Sequence[JaxprEqn],
    invars: Sequence[Var] = None,
    outvars: Sequence[Var] = None,
    constvars: Sequence[Var] = None,
) -> Jaxpr:
    """
    Convert a sequence of JaxprEqn into a Jaxpr.

    Any of ``invars``, ``outvars`` and ``constvars`` that is not given is
    inferred from ``eqns``. Inferred lists are ordered deterministically by
    first appearance in ``eqns`` (equation order, then operand order). This
    makes the result reproducible and lets inferred ``constvars`` be paired
    positionally with a list of constant values.

    Args:
        eqns: Sequence (or any iterable) of Jaxpr equations to convert.
        invars: Input variables. If None, every variable that the equations
            read but do not produce becomes an input, in order of first use.
            Variables explicitly listed in ``constvars`` are excluded.
        outvars: Output variables. If None, every variable produced by an
            equation and never read by any equation in ``eqns`` becomes an
            output, in production order.
        constvars: Constant variables. If None, every variable that the
            equations read but do not produce and that is not in ``invars``,
            in order of first use. When ``invars`` is also None this is
            always empty, because all free variables become inputs.

    Returns:
        Jaxpr: A Jaxpr object constructed from the equations.
    """
    # Materialise once: the equations are traversed several times below,
    # which would silently exhaust a generator argument.
    eqns = list(eqns)

    # Every variable bound by some equation.
    produced_vars = set()
    for eqn in eqns:
        produced_vars.update(eqn.outvars)

    # Free variables (read but never produced), deduplicated, in first-use
    # order. Literals are inline values, not variables, so they are skipped.
    free_vars = []
    seen = set()
    for eqn in eqns:
        for var in eqn.invars:
            if isinstance(var, Var) and var not in produced_vars and var not in seen:
                seen.add(var)
                free_vars.append(var)

    if constvars is not None:
        constvars = list(constvars)

    # Infer invars if not provided.
    if invars is None:
        # Explicit constvars are bound as constants. They must not also be
        # bound as inputs, or the jaxpr would bind the same variable twice.
        explicit_consts = set(constvars) if constvars is not None else set()
        invars = [var for var in free_vars if var not in explicit_consts]
    else:
        invars = list(invars)

    # Infer constvars if not provided: free variables that are not inputs.
    if constvars is None:
        invars_set = set(invars)
        constvars = [var for var in free_vars if var not in invars_set]

    # Infer outvars if not provided: variables that are produced but never
    # consumed by any equation.
    if outvars is None:
        consumed_vars = {
            var for eqn in eqns for var in eqn.invars if isinstance(var, Var)
        }
        # NOTE(review): equations may bind results to dropped variables
        # (JAX's DropVar); those would be collected here as well. Confirm
        # whether they should be filtered out.
        # dict.fromkeys deduplicates while keeping production order.
        outvars = list(dict.fromkeys(
            var for eqn in eqns for var in eqn.outvars if var not in consumed_vars
        ))
    else:
        outvars = list(outvars)

    return Jaxpr(
        constvars=constvars,
        invars=invars,
        outvars=outvars,
        eqns=eqns,
    )


def eqns_to_closed_jaxpr(
    eqns: Sequence[JaxprEqn],
    invars: Sequence[Var] = None,
    outvars: Sequence[Var] = None,
    constvars: Sequence[Var] = None,
    consts: Sequence = None,
) -> ClosedJaxpr:
    """
    Convert a sequence of JaxprEqn into a ClosedJaxpr.

    The open jaxpr is built with ``eqns_to_jaxpr``, which infers any of
    ``invars``/``outvars``/``constvars`` left as None. The result is then
    closed over ``consts``.

    Args:
        eqns: Sequence of Jaxpr equations to convert.
        invars: Input variables. If None, will be inferred from equations.
        outvars: Output variables. If None, will be inferred from equations.
        constvars: Constant variables. If None, will be automatically
            extracted from equations.
        consts: Constant values, paired positionally with the jaxpr's
            constvars. If None, treated as an empty list.

    Returns:
        ClosedJaxpr: A ClosedJaxpr object constructed from the equations.

    Raises:
        ValueError: If the number of ``consts`` differs from the number of
            constvars of the built jaxpr. In particular, this is raised when
            ``consts`` is omitted but the jaxpr has constvars, for example
            when they are inferred because ``invars`` was given and
            ``constvars`` was not. In that case pass both ``constvars`` and
            ``consts`` explicitly, e.g. taken from the original ClosedJaxpr.
    """
    # Build the open jaxpr (constvars are inferred when not provided).
    jaxpr = eqns_to_jaxpr(eqns, invars, outvars, constvars)

    # Missing consts mean "no constants". The length check below rejects this
    # whenever the jaxpr actually has constvars, rather than producing a
    # ClosedJaxpr that would fail later at evaluation time.
    consts = [] if consts is None else list(consts)

    if len(consts) != len(jaxpr.constvars):
        raise ValueError(
            f"consts length ({len(consts)}) does not match constvars length ({len(jaxpr.constvars)})"
        )
    return ClosedJaxpr(jaxpr, consts)