brainstate.transform.jaxpr_to_python_code
- brainstate.transform.jaxpr_to_python_code(jaxpr, fn_name='generated_function')
Given a JAX jaxpr, generate Python source code for it.
- Parameters:
jaxpr (Jaxpr) – The jaxpr to generate code for.
fn_name (str) – The name of the generated function. Defaults to 'generated_function'.
- Returns:
The Python code generated for the jaxpr, defining a function named fn_name.