brainstate.transform.eqns_to_closed_jaxpr


brainstate.transform.eqns_to_closed_jaxpr(eqns, invars=None, outvars=None, constvars=None, consts=None)

Convert a sequence of JaxprEqn objects into a ClosedJaxpr.

Parameters:
  • eqns (Sequence[JaxprEqn]) – Sequence of Jaxpr equations to convert.

  • invars (Sequence[Var], optional) – Input variables. If None, they are inferred from the equations.

  • outvars (Sequence[Var], optional) – Output variables. If None, they are inferred from the equations.

  • constvars (Sequence[Var], optional) – Constant variables. If None, they are extracted automatically from the equations.

  • consts (Sequence, optional) – Constant values corresponding to constvars. If None, defaults to an empty list.

Returns:

A ClosedJaxpr object constructed from the given equations.

Return type:

ClosedJaxpr

Note

If constvars are extracted automatically from the equations but consts is not provided, the resulting ClosedJaxpr has an empty consts list. Evaluating it may then fail at runtime if the equations actually depend on those constants. In that case, explicitly pass both constvars and consts taken from the original jaxpr.