brainstate.transform.fn_to_python_code
- brainstate.transform.fn_to_python_code(fn, *args, **kwargs)
Given a function built from JAX primitives and the arguments to call it with, return the Python code that JAX would generate for that function.
- Parameters:
fn – The function to generate code for
args – The positional arguments to the function
kwargs – The keyword arguments to the function
- Returns:
The Python code that would be generated by JAX for that function
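`fn_to_python_code` takes a function plus concrete example arguments, which suggests that it traces the function at those arguments. This page does not describe the internals, so the sketch below is only an assumption-labelled illustration of the primitive-level view such a tool would start from. It uses only the public JAX APIs `jax.make_jaxpr` and `jax.lax` to trace a small function and list the primitives it binds. It does not call brainstate and does not reproduce the code that `fn_to_python_code` returns.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f(x, y):
    # Each lax call binds exactly one JAX primitive: sin, mul, add.
    return lax.add(lax.mul(lax.sin(x), y), x)


# Example arguments play the same role as *args passed to fn_to_python_code:
# they fix the shapes and dtypes that the function is traced with.
x = jnp.ones(3, dtype=jnp.float32)
y = jnp.full(3, 2.0, dtype=jnp.float32)

# make_jaxpr traces f and returns a ClosedJaxpr describing its primitives.
closed = jax.make_jaxpr(f)(x, y)
names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]

print(closed)  # human-readable jaxpr
print(names)
```

Because the trace depends on the example arguments, calling the function with arrays of a different shape or dtype can change the traced program. The same likely holds for the code generated by `fn_to_python_code`.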