brainstate.transform.fn_to_python_code

brainstate.transform.fn_to_python_code(fn, *args, **kwargs)

Given a function built from JAX primitives, together with the arguments to call it with, return the Python code that JAX would generate for that function.

Parameters:
  • fn – The function to generate code for.

  • args – The positional arguments to pass to fn.

  • kwargs – The keyword arguments to pass to fn.

Returns:

The Python code that JAX would generate for that function.
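With brainstate installed, the function is called with the same arguments that fn itself expects, e.g. brainstate.transform.fn_to_python_code(f, x, y). The docstring does not say how the code is produced, so the following is only a rough, hypothetical sketch of the idea: it traces the function with jax.make_jaxpr and renders each primitive equation of the resulting jaxpr as one line of Python. The names toy_fn_to_python_code and generated, and the output format, are assumptions made for illustration; this is not brainstate's implementation, and the real function's output will differ.

```python
import jax
import jax.numpy as jnp


def toy_fn_to_python_code(fn, *args, **kwargs):
    """Hypothetical sketch (NOT brainstate's implementation).

    Traces `fn` with the given arguments via `jax.make_jaxpr` and renders
    each primitive equation of the resulting jaxpr as one line of Python.
    """
    jaxpr = jax.make_jaxpr(fn)(*args, **kwargs).jaxpr
    names = {}

    def name(atom):
        # Literals (constants inlined by JAX) carry their value in `.val`.
        if hasattr(atom, "val"):
            return repr(atom.val)
        # Variables get readable names in order of first appearance.
        if atom not in names:
            names[atom] = f"v{len(names)}"
        return names[atom]

    # Captured constants (constvars) are exposed as extra leading parameters.
    params = ", ".join(name(v) for v in [*jaxpr.constvars, *jaxpr.invars])
    lines = [f"def generated({params}):"]
    for eqn in jaxpr.eqns:
        # Positional operands first, then the primitive's static parameters.
        call_args = [name(v) for v in eqn.invars]
        call_args += [f"{k}={v!r}" for k, v in eqn.params.items()]
        outs = ", ".join(name(v) for v in eqn.outvars)
        lines.append(f"    {outs} = {eqn.primitive.name}({', '.join(call_args)})")
    results = ", ".join(name(v) for v in jaxpr.outvars)
    lines.append(f"    return {results}")
    return "\n".join(lines)


def f(x, y):
    # Built directly from lax primitives so the jaxpr has one equation per call.
    return jax.lax.add(jax.lax.mul(jax.lax.sin(x), y), x)


code = toy_fn_to_python_code(f, jnp.ones(3), jnp.arange(3.0))
print(code)
```

The sketch emits readable pseudo-Python named after JAX primitives (sin, mul, add) rather than executable code, since primitive parameters such as nested jaxprs do not always have a valid Python repr.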