brainstate.transform.jaxpr_to_python_code


brainstate.transform.jaxpr_to_python_code(jaxpr, fn_name='generated_function')

Given a JAX jaxpr, return Python source code for a function, named fn_name, that expresses the computation described by that jaxpr.

Parameters:
  • jaxpr (Jaxpr) – The jaxpr to generate code for.

  • fn_name (str) – The name given to the generated Python function. Defaults to 'generated_function'.

Returns:

The generated Python source code for the jaxpr.
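The following simplified, hypothetical converter illustrates the general idea of turning a jaxpr into Python source. It is not brainstate's implementation. `toy_jaxpr_to_source` and its primitive table are illustrative assumptions. The converter handles only a few elementwise primitives, ignores equation parameters, and rejects jaxprs with constvars. It also assumes that jaxpr literals, unlike variables, carry a `.val` attribute.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical table: JAX primitive name -> Python expression template.
# Real jaxprs can contain many more primitives (and parameters) than this.
_TEMPLATES = {
    "sin": "jnp.sin({0})",
    "cos": "jnp.cos({0})",
    "neg": "-{0}",
    "add": "{0} + {1}",
    "sub": "{0} - {1}",
    "mul": "{0} * {1}",
    "div": "{0} / {1}",
}


def toy_jaxpr_to_source(jaxpr, fn_name="generated_function"):
    """Emit Python source for a jaxpr built only from primitives in _TEMPLATES."""
    if jaxpr.constvars:
        raise NotImplementedError("constvars are not handled in this sketch")
    names = {}  # maps each jaxpr variable to a generated Python identifier

    def name_of(v):
        if hasattr(v, "val"):  # assumed to be a Literal: inline its value
            return repr(np.asarray(v.val).item())
        if v not in names:
            names[v] = f"v{len(names)}"
        return names[v]

    args = ", ".join(name_of(v) for v in jaxpr.invars)
    lines = [f"def {fn_name}({args}):"]
    for eqn in jaxpr.eqns:
        prim = eqn.primitive.name
        if prim not in _TEMPLATES:
            raise NotImplementedError(f"primitive {prim!r} is not handled")
        expr = _TEMPLATES[prim].format(*(name_of(v) for v in eqn.invars))
        outs = ", ".join(name_of(v) for v in eqn.outvars)
        lines.append(f"    {outs} = {expr}")
    rets = ", ".join(name_of(v) for v in jaxpr.outvars)
    lines.append(f"    return ({rets},)")  # jaxpr outputs are a sequence
    return "\n".join(lines)


def f(x, y):
    return jnp.sin(x) * y + 1.0


closed = jax.make_jaxpr(f)(1.0, 2.0)
src = toy_jaxpr_to_source(closed.jaxpr, fn_name="f_generated")
print(src)

# Execute the generated source and call the resulting function.
namespace = {"jnp": jnp}
exec(src, namespace)
(out,) = namespace["f_generated"](0.5, 2.0)
```

Emitting one assignment per jaxpr equation mirrors the jaxpr's flat, single-assignment structure. Because of this, operator precedence never needs parentheses.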