brainstate.transform.vjp
- brainstate.transform.vjp(fun, *primals, grad_states=None, argnums=0, has_aux=False)
Compute a state-aware vector-Jacobian product (reverse-mode autodiff).
Trace fun (which may read and write State objects) into a pure function, apply jax.vjp(), re-thread any written states back into their State objects, and return the primal output together with a pullback (a.k.a. cotangent map or backward function).

Calling the returned vjp_fn with a cotangent v (an array or pytree with the same structure as the output of fun) yields v @ J, where J is the Jacobian of fun evaluated at primals. This is the building block of reverse-mode autodiff: a single forward trace can be reused for arbitrarily many backward passes, so it is the natural primitive for full Jacobians (one backward pass per output row) and for higher-order products such as Hessian-vector products.

What vjp_fn returns depends on grad_states and argnums, mirroring brainstate.transform.grad():

grad_states    argnums                 vjp_fn(v) returns
None           int / sequence          arg_cotangents
provided       int / sequence          (state_cts, arg_cts)
provided       None (or no primals)    state_cotangents

- Parameters:
fun (Callable) – Function to be differentiated. It may read and/or write State objects. If has_aux is True it must return (output, aux).
*primals – Positional arguments at which to evaluate fun and its pullback. May be omitted entirely when differentiating only with respect to grad_states (e.g. a parameterized model whose inputs are closed over or supplied through states).
grad_states (State | Sequence[State] | Dict[str, State] | None) – States to compute cotangents for. The returned state cotangents follow the structure of grad_states (an unwrapped array for a single State, a list for a sequence, a matching dict for a mapping). When given, the pullback also returns argument cotangents unless argnums is None or no positional primals are supplied.
argnums (int | Sequence[int] | None) – Positional argument(s) to differentiate with respect to. A single int yields an unwrapped argument cotangent; a sequence yields a tuple, one entry per index. None disables argument differentiation so the pullback returns only state cotangents (requires grad_states). If fun is called with no positional primals, argument differentiation is disabled automatically.
has_aux (bool) – Whether fun returns (output, aux). The auxiliary data is not differentiated but is returned to the caller.
- Return type:
- Returns:
primal_out (PyTree) – The value fun(*primals) (the first element of the returned pair if has_aux is True).
vjp_fn (callable) – Pullback mapping a cotangent of primal_out to input cotangents. The cotangent passed to vjp_fn must have the same pytree structure, shape, and dtype as primal_out. See the table above for the return structure.
aux (PyTree) – Returned only when has_aux is True.
- Raises:
ValueError – If a state in grad_states is never read by fun (so its cotangent is undefined), if argnums is out of range for the given primals, or if there is nothing to differentiate (no primals and no grad_states).
Notes
States that fun writes are re-threaded back into their State objects after the forward trace (the same side effects fun would have produced if called directly). Differentiation is always taken with respect to the input value a state held on entry, so reading and then writing a state listed in grad_states is well defined.
Because vjp traces fun once and returns a reusable pullback, it is more flexible than grad() when you need (a) multiple backward passes, (b) a non-scalar output, or (c) a custom cotangent v. Evaluating vjp_fn(1.0) on a scalar-output fun returns the same gradient as grad().
Examples
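The equivalence with grad() stated in the Notes can be checked in plain JAX. This is a minimal sketch that uses jax.vjp and jax.grad as stand-ins for the state-aware transforms, on the assumption that they agree on a pure function with no states. The name loss is illustrative only.

```python
import jax
import jax.numpy as jnp

def loss(x):
    # Scalar-output function, so a cotangent of 1.0 selects the gradient.
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])

# One forward trace; the pullback can then be called repeatedly.
out, pullback = jax.vjp(loss, x)

# jax.vjp returns one cotangent per primal, wrapped in a tuple.
(g_from_vjp,) = pullback(jnp.ones_like(out))
g_from_grad = jax.grad(loss)(x)

# For a scalar output, a custom cotangent simply scales the gradient.
(g_scaled,) = pullback(jnp.asarray(0.5, dtype=out.dtype))
```

Note that plain jax.vjp always wraps its cotangents in a tuple, whereas the function documented here returns an unwrapped cotangent for a single int argnums.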
Plain reverse-mode autodiff (no states). vjp matches jax.vjp() on a pure function; with a single int argnums the argument cotangent is returned unwrapped.

>>> import brainstate
>>> import jax.numpy as jnp
>>> def f(x):
...     return jnp.sum(x ** 2)
>>> x = jnp.array([1.0, 2.0, 3.0])
>>> out, vjp_fn = brainstate.transform.vjp(f, x)
>>> out
Array(14., dtype=float32)
>>> vjp_fn(1.0)  # d/dx sum(x**2) = 2x
Array([2., 4., 6.], dtype=float32)
Gradients with respect to states. Pass grad_states to also obtain state cotangents. The pullback then returns (state_cts, arg_cts).

>>> w = brainstate.State(jnp.array([2.0, 3.0]))
>>> def loss(x):
...     return jnp.sum(w.value * x)
>>> x = jnp.array([5.0, 7.0])
>>> out, vjp_fn = brainstate.transform.vjp(loss, x, grad_states=w)
>>> state_ct, arg_ct = vjp_fn(1.0)
>>> state_ct  # d/dw sum(w*x) = x
Array([5., 7.], dtype=float32)
>>> arg_ct  # d/dx sum(w*x) = w
Array([2., 3.], dtype=float32)
State-only gradients (no differentiable argument). This is the typical neural-network case: the loss closes over the trainable parameters, so the pullback returns just the state cotangents.

>>> weight = brainstate.State(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
>>> bias = brainstate.State(jnp.array([0.0, 0.0]))
>>> x = jnp.array([1.0, 1.0])
>>> def predict_loss():
...     y = x @ weight.value + bias.value
...     return jnp.sum(y ** 2)
>>> out, vjp_fn = brainstate.transform.vjp(predict_loss, grad_states=[weight, bias])
>>> grads = vjp_fn(1.0)  # list of state cotangents, no arg cotangent
>>> [g.shape for g in grads]
[(2, 2), (2,)]
Auxiliary data and state write-back. has_aux=True returns the side output untouched; states written inside fun keep their new values.

>>> counter = brainstate.State(jnp.array(0.0))
>>> def f(x):
...     counter.value = counter.value + 1.0
...     return jnp.sum(x ** 2), {'mean': jnp.mean(x)}
>>> x = jnp.array([1.0, 2.0])
>>> out, vjp_fn, aux = brainstate.transform.vjp(f, x, has_aux=True)
>>> aux['mean']
Array(1.5, dtype=float32)
>>> vjp_fn(1.0)
Array([2., 4.], dtype=float32)
>>> float(counter.value)  # write re-threaded back into the State
1.0
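The write-back behavior can be pictured in plain JAX by threading the state explicitly. The sketch below is only an illustration of the idea, not brainstate's implementation: the hypothetical pure_f takes the state's entry value as an argument and returns the updated value as auxiliary data, so differentiation is with respect to the value held on entry.

```python
import jax
import jax.numpy as jnp

def pure_f(counter_in, x):
    # Explicit-state version of a function that reads and then writes a counter:
    # the entry value comes in as an argument and the updated value goes out
    # as auxiliary (non-differentiated) data.
    counter_out = counter_in + 1.0
    out = jnp.sum(x ** 2) * counter_in
    return out, counter_out

counter_value = jnp.array(2.0)  # value the "state" holds on entry
x = jnp.array([1.0, 2.0])

# has_aux=True keeps the written state value out of the differentiation.
out, pullback, new_counter = jax.vjp(pure_f, counter_value, x, has_aux=True)

# Cotangents are taken with respect to the entry values of both inputs.
counter_ct, x_ct = pullback(jnp.ones_like(out))

# "Write-back": the caller stores the new value after the forward trace.
counter_value = new_counter
```

Here the counter cotangent is sum(x**2) and the argument cotangent is 2 * x scaled by the entry value of the counter, even though the counter is overwritten during the call.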
Full Jacobian by reusing the pullback. One trace, many backward passes: map the pullback over the rows of the identity to build the Jacobian.

>>> import jax
>>> def f(x):
...     return jnp.array([jnp.sum(x), jnp.sum(x ** 2)])
>>> x = jnp.array([1.0, 2.0, 3.0])
>>> out, vjp_fn = brainstate.transform.vjp(f, x)
>>> jac = jax.vmap(vjp_fn)(jnp.eye(2))
>>> jac  # rows are gradients of each output
Array([[1., 1., 1.],
       [2., 4., 6.]], dtype=float32)
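As a cross-check of the v @ J convention, the same construction can be compared against jax.jacrev in plain JAX. This is a sketch that assumes the state-free case, where jax.vjp stands in for the transform documented here; the cotangent v is an arbitrary illustrative choice.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([jnp.sum(x), jnp.sum(x ** 2)])

x = jnp.array([1.0, 2.0, 3.0])
out, pullback = jax.vjp(f, x)

# Each row of the identity selects one output; vmap batches the backward passes.
(jac_rows,) = jax.vmap(pullback)(jnp.eye(out.shape[0], dtype=out.dtype))

# A general cotangent v gives the row vector v @ J.
v = jnp.array([1.0, -1.0])
(vJ,) = pullback(v)

# Reference Jacobian computed directly with reverse-mode autodiff.
jac_ref = jax.jacrev(f)(x)
```

Building the Jacobian this way costs one backward pass per output, which is why the pullback suits functions with few outputs and many inputs.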
Multiple arguments. A sequence argnums returns a tuple of cotangents, one per requested argument.

>>> def f(x, y):
...     return jnp.sum(x * y)
>>> x = jnp.array([1.0, 2.0])
>>> y = jnp.array([3.0, 4.0])
>>> out, vjp_fn = brainstate.transform.vjp(f, x, y, argnums=(0, 1))
>>> gx, gy = vjp_fn(1.0)
>>> gx, gy  # (d/dx, d/dy) sum(x*y) = (y, x)
(Array([3., 4.], dtype=float32), Array([1., 2.], dtype=float32))
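Hessian-vector products, mentioned in the description as a higher-order use of pullbacks, can be sketched in plain JAX by taking the pullback of a gradient function. This assumes a pure, state-free function and uses jax.vjp and jax.grad rather than the brainstate transforms; whether the state-aware versions compose the same way is not confirmed here.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Scalar function whose Hessian is diag(6 * x).
    return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.0, -1.0])

# Pull v back through the gradient map: this gives v @ H,
# which equals H @ v because the Hessian is symmetric.
grad_out, pullback = jax.vjp(jax.grad(f), x)
(hvp,) = pullback(v)

# Reference value from the dense Hessian (only practical for small inputs).
hvp_ref = jax.hessian(f)(x) @ v
```

The pullback route never materializes the full Hessian, so its memory cost stays proportional to the size of x.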