brainstate.transform.jvp

brainstate.transform.jvp(fun, primals, tangents, *, has_aux=False)

Compute a state-aware Jacobian-vector product (forward-mode autodiff).

Trace fun (which may read and write State objects) into a pure function, apply jax.jvp() to it with respect to the positional arguments, and write any states modified during the call back into their State objects.

States are treated as constants during the forward pass (their tangents are zero), but states written inside fun are still updated. Differentiation with respect to state values (state tangents) is not yet supported and is planned as a future enhancement.
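The tracing and state re-threading described above can be illustrated with a minimal pure-Python sketch. It is not brainstate's implementation: `ToyState` is a hypothetical stand-in for `State`, `Dual` (a dual-number class) stands in for the forward-mode tracing that `jax.jvp()` performs, and, unlike brainstate, the states involved are passed in explicitly instead of being discovered by tracing. State values enter the forward pass with a zero tangent, and any written states are copied back after the call.

```python
# Hypothetical sketch of state threading around a forward-mode JVP.
# NOT brainstate's implementation: ToyState stands in for brainstate.State,
# Dual stands in for jax.jvp's tracing, and states are passed explicitly.


class Dual:
    """A primal value paired with a tangent (forward-mode autodiff)."""

    def __init__(self, val, tan=0.0):
        self.val, self.tan = val, tan

    @staticmethod
    def lift(x):
        # Plain numbers, such as state values, get a zero tangent.
        return x if isinstance(x, Dual) else Dual(x, 0.0)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.val + other.val, self.tan + other.tan)

    def __mul__(self, other):
        other = Dual.lift(other)
        # Product rule: d(uv) = u dv + v du
        return Dual(self.val * other.val,
                    self.val * other.tan + self.tan * other.val)

    __radd__ = __add__
    __rmul__ = __mul__


class ToyState:
    """Hypothetical mutable container standing in for brainstate.State."""

    def __init__(self, value):
        self.value = value


scale = ToyState(3.0)  # read inside fun
calls = ToyState(0)    # written inside fun


def fun(x):
    calls.value = calls.value + 1   # state write
    return scale.value * x * x      # state read


def toy_jvp(fun, primals, tangents, states):
    snapshot = [s.value for s in states]

    def pure(state_vals, *args):
        # Pure view of fun: state values in, (output, new state values) out.
        for s, v in zip(states, state_vals):
            s.value = v
        out = fun(*args)
        return out, [s.value for s in states]

    # Forward pass: primals carry tangents, state values enter as constants.
    duals = [Dual(p, t) for p, t in zip(primals, tangents)]
    out, new_vals = pure(snapshot, *duals)

    # Re-thread written states (primal part only) into their containers.
    for s, v in zip(states, new_vals):
        s.value = v.val if isinstance(v, Dual) else v
    out = Dual.lift(out)
    return out.val, out.tan


primal, tangent = toy_jvp(fun, (2.0,), (1.0,), [scale, calls])
print(primal, tangent, calls.value)  # 12.0 12.0 1
```

Here the derivative of 3x² at x = 2 is 6x = 12: `scale` enters only as the constant factor 3, and the write to `calls` persists after the call.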

Parameters:
  • fun (Callable) – Function to be differentiated. It may read and write State objects. If has_aux is True, it must return a pair (output, aux).

  • primals (Sequence) – The positional arguments at which to evaluate fun, as a tuple/list matching fun’s signature.

  • tangents (Sequence) – Tangent vectors, a tuple/list with the same structure as primals.

  • has_aux (bool) – Whether fun returns (output, aux). The auxiliary data is not differentiated.

Return type:

tuple

Returns:

  • primal_out (PyTree) – The value of fun(*primals); when has_aux is True, only the first element (output) of the returned pair.

  • tangent_out (PyTree) – The directional derivative of fun along tangents.

  • aux (PyTree) – Returned only when has_aux is True.
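In equation form, for positional arguments x_k = primals[k] with tangents v_k = tangents[k], tangent_out is the sum of Jacobian-vector products over the arguments (states contribute nothing, since their tangents are zero). Applied to the f(x) = sum(x ** 2) example in the Examples section:

```latex
\begin{aligned}
\text{tangent\_out} &= \sum_{k=1}^{n} \frac{\partial f}{\partial x_k}(x_1, \dots, x_n)\, v_k, \\
\text{with } f(x) = \textstyle\sum_i x_i^2,\ x = (1, 2),\ v = (1, 1):\quad
\text{primal\_out} &= 1^2 + 2^2 = 5, \\
\text{tangent\_out} &= 2x^\top v = 2 \cdot 1 \cdot 1 + 2 \cdot 2 \cdot 1 = 6.
\end{aligned}
```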

Raises:

TypeError – If primals or tangents is not a tuple/list.
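The argument and return conventions above can be sketched in plain Python. `fd_jvp` is a hypothetical helper, not part of brainstate: it approximates the JVP with a central finite difference instead of forward-mode autodiff, accepts only positional arguments that are lists of floats, ignores State objects, and assumes fun returns a scalar. Only its TypeError mirrors the documented behavior; the length check and its ValueError are this sketch's own choice.

```python
# Hypothetical helper illustrating the primals/tangents/has_aux contract.
# NOT brainstate's implementation: it uses a central finite difference
# instead of forward-mode autodiff and supports only list-of-float arguments
# and a scalar output.


def fd_jvp(fun, primals, tangents, *, has_aux=False, eps=1e-6):
    # Mirrors the documented TypeError for non-sequence primals/tangents.
    if not isinstance(primals, (tuple, list)) or not isinstance(tangents, (tuple, list)):
        raise TypeError("primals and tangents must be a tuple or list")
    # This sketch's own structure check: one tangent per primal, same length.
    if len(primals) != len(tangents) or any(
        len(x) != len(v) for x, v in zip(primals, tangents)
    ):
        raise ValueError("tangents must have the same structure as primals")

    def shifted(sign):
        # Every positional argument moved by sign * eps along its tangent.
        return [[p + sign * eps * t for p, t in zip(x, v)]
                for x, v in zip(primals, tangents)]

    def call(args):
        # Normalize to (output, aux); aux is None when has_aux is False.
        res = fun(*args)
        return res if has_aux else (res, None)

    out, aux = call(list(primals))
    plus, _ = call(shifted(+1))
    minus, _ = call(shifted(-1))
    tangent_out = (plus - minus) / (2 * eps)  # directional derivative estimate
    return (out, tangent_out, aux) if has_aux else (out, tangent_out)


def f(x):
    return sum(xi ** 2 for xi in x)


def f_aux(x):
    return sum(xi ** 2 for xi in x), {"n": len(x)}


out, tangent = fd_jvp(f, ([1.0, 2.0],), ([1.0, 1.0],))
out2, tangent2, aux = fd_jvp(f_aux, ([1.0, 2.0],), ([1.0, 1.0],), has_aux=True)
print(out, round(tangent, 6), aux)  # 5.0 6.0 {'n': 2}
```

A finite difference is used here only to keep the sketch dependency-free; it is an approximation, whereas jax.jvp() computes the directional derivative exactly up to floating-point rounding.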

See also

vjp, jacfwd, fwd_grad

Examples

>>> import brainstate
>>> import jax.numpy as jnp
>>> def f(x):
...     return jnp.sum(x ** 2)
>>> out, tangent = brainstate.transform.jvp(f, (jnp.array([1.0, 2.0]),),
...                                          (jnp.array([1.0, 1.0]),))
>>> # out = 1 + 4; tangent = 2x . v = 2*1*1 + 2*2*1
>>> float(out), float(tangent)
(5.0, 6.0)