brainstate.transform.make_jaxpr
- brainstate.transform.make_jaxpr(fun, static_argnums=(), static_argnames=(), axis_env=None, return_shape=False, return_only_write=False)
Creates a function that produces its jaxpr given example args.
A jaxpr is JAX’s intermediate representation for program traces. The jaxpr language is based on the simply-typed first-order lambda calculus with let-bindings. make_jaxpr() adapts a function to return its jaxpr, which we can inspect to understand what JAX is doing internally. The returned jaxpr is a trace of fun abstracted to the ShapedArray level. Other levels of abstraction exist internally.
- Parameters:
  - fun (Callable) – The function whose jaxpr is to be computed. Its positional arguments and return value should be arrays, scalars, or standard Python containers (tuple/list/dict) thereof.
  - static_argnums (int | Iterable[int]) – Indices of positional arguments to treat as static (compile-time constants). See the jax.jit() docstring.
  - static_argnames (str | Iterable[str]) – Names of keyword arguments to treat as static. See the jax.jit() docstring.
  - axis_env (Sequence[tuple[Hashable, int]] | None) – A sequence of pairs where the first element is an axis name and the second element is a positive integer representing the size of the mapped axis with that name. This parameter is useful when lowering functions that involve parallel communication collectives, and it specifies the axis name/size environment that would be set up by applications of jax.pmap().
  - return_shape (bool) – If True, the wrapped function returns a triple (jaxpr, states, shapes) instead of a pair, where shapes is a pytree with the same structure as the output of fun whose leaves are objects with shape and dtype attributes describing the corresponding output leaves.
  - return_only_write (bool) – If True, only return states that were written to during tracing (not just read). This can reduce memory usage when you only care about modified states.

    Note: This defaults to False (unlike StatefulFunction, which defaults to True) because make_jaxpr is primarily used for inspection, where seeing all state flows is typically desired.
- Returns:
  A wrapped version of fun that, when applied to example arguments, returns a pair (jaxpr, states). Here jaxpr is the ClosedJaxpr representation of fun on those arguments, and states is a tuple of the State objects that fun reads or writes while being traced (only the written ones when return_only_write is True). If return_shape is True, the returned function instead returns a triple (jaxpr, states, shapes), where shapes is a pytree describing the structure, shapes, and dtypes of the output of fun.
- Return type:
  Callable[..., Tuple[ClosedJaxpr, Tuple[State, ...]] | Tuple[ClosedJaxpr, Tuple[State, ...], PyTree]]
Examples
Basic usage:
>>> import jax
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> def f(x):
...     return jnp.sin(jnp.cos(x))
>>>
>>> # Create jaxpr maker
>>> jaxpr_maker = brainstate.transform.make_jaxpr(f)
>>> jaxpr, states = jaxpr_maker(3.0)
With gradient:
>>> jaxpr_grad_maker = brainstate.transform.make_jaxpr(jax.grad(f))
>>> jaxpr, states = jaxpr_grad_maker(3.0)
With shape information:
>>> jaxpr_maker_with_shape = brainstate.transform.make_jaxpr(f, return_shape=True)
>>> jaxpr, states, shapes = jaxpr_maker_with_shape(3.0)
With a stateful function:
>>> state = brainstate.State(jnp.array([1.0, 2.0]))
>>>
>>> def stateful_f(x):
...     state.value += x
...     return state.value
>>>
>>> jaxpr_maker = brainstate.transform.make_jaxpr(stateful_f)
>>> jaxpr, states = jaxpr_maker(jnp.array([0.5, 0.5]))