brainstate.transform.jvp
- brainstate.transform.jvp(fun, primals, tangents, *, has_aux=False)
Compute a state-aware Jacobian-vector product (forward-mode autodiff).
Trace fun (which may read and write State objects) into a pure function, apply jax.jvp() with respect to the positional arguments, and re-thread any written states back into their State objects.
States are treated as constants for the forward pass (zero tangent); states written inside fun are still updated. Differentiating with respect to state values (state tangents) is a future enhancement.
- Parameters:
fun (Callable) – Function to be differentiated. It may read/write State objects. If has_aux is True, it must return (output, aux).
primals (Sequence) – The positional arguments at which to evaluate fun, as a tuple/list matching fun’s signature.
tangents (Sequence) – Tangent vectors, a tuple/list with the same structure as primals.
has_aux (bool) – Whether fun returns (output, aux). The auxiliary data is not differentiated.
- Return type:
- Returns:
primal_out (PyTree) – The value fun(*primals) (the first element if has_aux).
tangent_out (PyTree) – The directional derivative of fun along tangents.
aux (PyTree) – Returned only when has_aux is True.
- Raises:
TypeError – If primals or tangents is not a tuple/list.
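As a rough picture of the state handling described above, the following plain-JAX sketch threads one piece of state through jax.jvp by hand. It is not brainstate's implementation: the state dict is a hypothetical stand-in for a State object, and pure_step and stateful_jvp are illustrative names. The state value is closed over, so it receives no tangent, and the updated value comes back as auxiliary output and is written back afterwards.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a State object: a dict holding a single value.
state = {"count": jnp.array(0.0)}


def pure_step(x, count):
    """Pure version of a function that reads and writes `count`."""
    new_count = count + 1.0                 # the "write" to the state
    out = new_count * jnp.sum(x ** 2)       # output depends on the state
    return out, new_count                   # new state returned as aux


def stateful_jvp(x, v):
    """JVP with respect to x only; the state is a constant (zero tangent)."""
    current = state["count"]
    out, tangent, new_count = jax.jvp(
        lambda x_: pure_step(x_, current),  # state value is closed over
        (x,),
        (v,),
        has_aux=True,                       # aux is passed through undifferentiated
    )
    state["count"] = new_count              # re-thread the written state
    return out, tangent


x = jnp.array([1.0, 2.0])
v = jnp.array([1.0, 1.0])

out1, tan1 = stateful_jvp(x, v)   # count becomes 1: out = 5, tangent = 6
out2, tan2 = stateful_jvp(x, v)   # count becomes 2: out = 10, tangent = 12
```

The second call sees the state written by the first, which is the behaviour described above: writes persist even though the state contributes nothing to the tangent.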
Examples
>>> import brainstate
>>> import jax.numpy as jnp
>>> def f(x):
...     return jnp.sum(x ** 2)
>>> out, tangent = brainstate.transform.jvp(f, (jnp.array([1.0, 2.0]),),
...                                         (jnp.array([1.0, 1.0]),))
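To sanity-check the values the example above should produce, the same computation can be run with jax.jvp directly. This sketch assumes that, for a function touching no State objects, brainstate.transform.jvp returns the same values as jax.jvp (the description says it applies jax.jvp() after tracing). The function f_aux and its auxiliary dictionary are invented for illustration of has_aux.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sum(x ** 2)


def f_aux(x):
    # Returns (output, aux); the aux value is not differentiated.
    y = jnp.sum(x ** 2)
    return y, {"norm": jnp.sqrt(y)}


x = jnp.array([1.0, 2.0])
v = jnp.array([1.0, 1.0])

# f(x) = 1 + 4 = 5; the gradient is 2 * x = [2, 4], so the
# directional derivative along v = [1, 1] is 2 + 4 = 6.
primal_out, tangent_out = jax.jvp(f, (x,), (v,))

# With has_aux=True a third value, the auxiliary data, is returned.
primal_aux, tangent_aux, aux = jax.jvp(f_aux, (x,), (v,), has_aux=True)
```

Note that primals and tangents are each wrapped in a one-element tuple, matching the single positional argument of f.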