brainstate.transform.make_jaxpr

brainstate.transform.make_jaxpr(fun, static_argnums=(), static_argnames=(), axis_env=None, return_shape=False, return_only_write=False)

Creates a function that, when called with example arguments, returns the jaxpr of fun together with the State objects it accesses.

A jaxpr is JAX’s intermediate representation for program traces. The jaxpr language is based on the simply-typed first-order lambda calculus with let-bindings. make_jaxpr() adapts a function to return its jaxpr, which can be inspected to understand what JAX is doing internally. The returned jaxpr is a trace of fun abstracted to the ShapedArray level, so it depends only on the shapes and dtypes of the example arguments, not on their values. Other levels of abstraction exist internally.
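To see what such a trace looks like, the sketch below uses plain JAX's jax.make_jaxpr (which returns only the ClosedJaxpr, without any states) on the same sin(cos(x)) function used in the examples, and inspects the equations and abstract input/output values of the result. It illustrates the jaxpr format only; it is not a call to this brainstate function.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(jnp.cos(x))


# Plain JAX: jax.make_jaxpr returns just the ClosedJaxpr (no states).
closed = jax.make_jaxpr(f)(3.0)

# Pretty-print the let-bound program: cos of the input, then sin of that.
print(closed)

# Each equation binds one primitive; list their names in order.
prim_names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(prim_names)

# Inputs and outputs are ShapedArray-level avals: only shape and dtype
# are recorded, not the concrete value 3.0.
print(closed.in_avals, closed.out_avals)
```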

Parameters:
  • fun (Callable) – The function whose jaxpr is to be computed. Its positional arguments and return value should be arrays, scalars, or standard Python containers (tuple/list/dict) thereof.

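  • static_argnums (int | Iterable[int]) – Indices of positional arguments to treat as static (fixed at trace time) rather than traced. See the jax.jit() docstring.

  • static_argnames (str | Iterable[str]) – Names of arguments to treat as static. See the jax.jit() docstring.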

  • axis_env (Sequence[tuple[Hashable, int]] | None) – A sequence of pairs where the first element is an axis name and the second element is a positive integer representing the size of the mapped axis with that name. This parameter is useful when lowering functions that involve parallel communication collectives, and it specifies the axis name/size environment that would be set up by applications of jax.pmap().

  • return_shape (bool) – If True, the wrapped function additionally returns a pytree with the same structure as the output of fun, whose leaves are objects with shape and dtype attributes describing the corresponding output leaves.

  • return_only_write (bool) –

    If True, only return states that were written to during execution (not just read). This can reduce memory usage when you only care about modified states.

    Note

    This defaults to False (unlike StatefulFunction, which defaults to True) because make_jaxpr is primarily used for inspection, where seeing all state flows is typically desired.
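The static_argnums and axis_env descriptions above mirror those of plain JAX's jax.make_jaxpr (which accepts static_argnums and axis_env but has no static_argnames). Assuming they behave the same way here, the sketch below exercises both with plain JAX: a static loop count is baked into the trace instead of becoming an input, and axis_env supplies a named axis so that a collective such as lax.psum can be traced outside jax.pmap. The function names are illustrative only.

```python
import jax
from jax import lax


def repeated_square(x, n):
    # `n` drives a Python loop, so it must be a concrete int at trace time.
    for _ in range(n):
        x = x * x
    return x


# Argument 1 (`n`) is static: it is fixed during tracing and does not
# appear among the jaxpr's inputs, so the loop unrolls into `n` multiplies.
sq_jaxpr = jax.make_jaxpr(repeated_square, static_argnums=1)(2.0, 3)
num_muls = sum(eqn.primitive.name == "mul" for eqn in sq_jaxpr.jaxpr.eqns)
print(num_muls, len(sq_jaxpr.jaxpr.invars))

# Without static_argnums, `n` is traced and range(n) cannot convert it.
try:
    jax.make_jaxpr(repeated_square)(2.0, 3)
    traced_n_failed = False
except Exception:  # JAX raises a tracer-to-integer conversion error
    traced_n_failed = True


def centered(x):
    # psum over the named axis "i"; outside pmap, the name comes from axis_env.
    return x - lax.psum(x, "i")


# Declare a mapped axis named "i" of size 4, as an application of jax.pmap would.
psum_jaxpr = jax.make_jaxpr(centered, axis_env=[("i", 4)])(2.0)
print(psum_jaxpr)
```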

Returns:

A wrapped version of fun that, when applied to example arguments, returns a pair (jaxpr, states): jaxpr is the ClosedJaxpr representation of fun on those arguments, and states is a tuple of the State objects fun accesses (only those written to, if return_only_write is True). If return_shape is True, the wrapped function instead returns a triple (jaxpr, states, shapes), where shapes is a pytree describing the structure, shapes, and dtypes of the output of fun.
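The shapes element mirrors the output pytree of fun. As a sketch of what it contains, the example below uses plain JAX's jax.make_jaxpr with return_shape=True, which returns a pair (ClosedJaxpr, shapes) with no states; the brainstate function described above places the same kind of pytree after the states. The summarize function is illustrative only.

```python
import jax
import jax.numpy as jnp


def summarize(x):
    # Return a dict so that the output pytree has some structure.
    return {"total": jnp.sum(x), "doubled": 2 * x}


# Plain JAX returns (ClosedJaxpr, shapes); no states are involved here.
closed, out_shapes = jax.make_jaxpr(summarize, return_shape=True)(jnp.ones((3, 2)))

# out_shapes has the same dict structure as the output; each leaf
# carries .shape and .dtype describing the corresponding output leaf.
print(out_shapes["total"].shape, out_shapes["doubled"].shape)
print(out_shapes["doubled"].dtype)
```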

Return type:

Callable[..., Tuple[ClosedJaxpr, Tuple[State, ...]] | Tuple[ClosedJaxpr, Tuple[State, ...], PyTree]]

Examples

Basic usage:

>>> import jax
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> def f(x):
...     return jnp.sin(jnp.cos(x))
>>>
>>> # Create jaxpr maker
>>> jaxpr_maker = brainstate.transform.make_jaxpr(f)
>>> jaxpr, states = jaxpr_maker(3.0)

With gradient:

>>> jaxpr_grad_maker = brainstate.transform.make_jaxpr(jax.grad(f))
>>> jaxpr, states = jaxpr_grad_maker(3.0)

With shape information:

>>> jaxpr_maker_with_shape = brainstate.transform.make_jaxpr(f, return_shape=True)
>>> jaxpr, states, shapes = jaxpr_maker_with_shape(3.0)

With stateful function:

>>> state = brainstate.State(jnp.array([1.0, 2.0]))
>>>
>>> def stateful_f(x):
...     state.value += x
...     return state.value
>>>
>>> jaxpr_maker = brainstate.transform.make_jaxpr(stateful_f)
>>> jaxpr, states = jaxpr_maker(jnp.array([0.5, 0.5]))
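Conceptually, tracing stateful_f means treating the current value of state as an extra input and its updated value as an extra output, so that the jaxpr itself stays pure. The plain-JAX sketch below writes that functional form by hand and traces it with jax.make_jaxpr. stateful_f_pure is a hypothetical helper, not part of brainstate, and this illustrates the idea rather than how brainstate implements state tracking.

```python
import jax
import jax.numpy as jnp


def stateful_f_pure(state_value, x):
    # Hypothetical hand-written equivalent of `stateful_f`:
    # the state's old value arrives as an ordinary argument...
    new_state_value = state_value + x
    # ...and its new value leaves as an extra result, alongside the
    # function's own return value (which here is the same array).
    return new_state_value, new_state_value


init = jnp.array([1.0, 2.0])
x = jnp.array([0.5, 0.5])

# Two inputs (state value, x) and two outputs (new state value, result).
closed = jax.make_jaxpr(stateful_f_pure)(init, x)
print(closed)

# Running it shows the update that `state.value += x` performs.
new_value, result = stateful_f_pure(init, x)
print(new_value)
```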