brainstate.transform.vjp

brainstate.transform.vjp(fun, *primals, grad_states=None, argnums=0, has_aux=False)

Compute a state-aware vector-Jacobian product (reverse-mode autodiff).

Trace fun (which may read and write State objects) into a pure function, apply jax.vjp(), re-thread any written states back into their State objects, and return the primal output together with a pullback (a.k.a. cotangent map or backward function).
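The trace-then-re-thread pattern can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not brainstate's implementation. A hypothetical Box class stands in for State. A hypothetical make_pure helper swaps state values in and out so that fun becomes a pure function of (state values, x). The values fun leaves behind travel back through has_aux, and the caller writes them into the boxes afterwards.

```python
import jax
import jax.numpy as jnp


class Box:
    """Hypothetical stand-in for a State: a mutable holder with a .value."""

    def __init__(self, value):
        self.value = value


def make_pure(fun, boxes):
    """Hypothetical helper: turn fun(x), which reads/writes `boxes`, into a
    pure function (box_values, x) -> (out, new_box_values)."""

    def pure(box_values, x):
        saved = [b.value for b in boxes]
        try:
            # Swap the (possibly traced) inputs into the boxes, run fun,
            # and collect whatever values fun left behind.
            for b, v in zip(boxes, box_values):
                b.value = v
            out = fun(x)
            new_values = [b.value for b in boxes]
        finally:
            # Restore the original values so no tracer leaks out of the trace.
            for b, v in zip(boxes, saved):
                b.value = v
        return out, new_values

    return pure


counter = Box(jnp.array(0.0))
scale = Box(jnp.array(2.0))


def impure_f(x):
    counter.value = counter.value + 1.0   # write
    return jnp.sum(scale.value * x ** 2)  # read


boxes = [counter, scale]
pure = make_pure(impure_f, boxes)
x = jnp.array([1.0, 2.0])

# One forward trace; the written values come back as auxiliary output.
out, vjp_fn, new_values = jax.vjp(pure, [b.value for b in boxes], x, has_aux=True)

# Re-thread the writes into the boxes, as if impure_f had been called directly.
for b, v in zip(boxes, new_values):
    b.value = v

state_cts, x_ct = vjp_fn(jnp.ones_like(out))
# out = 10.0 and counter.value = 1.0
# state_cts = [0.0, 5.0]  (counter does not affect out; d out/d scale = sum(x**2))
# x_ct = [4.0, 8.0]       (2 * scale * x)
```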

Calling the returned vjp_fn with a cotangent v (an array shaped like the output of fun) yields v @ J, where J is the Jacobian of fun evaluated at primals. This is the building block of reverse-mode autodiff: a single forward trace can be reused for arbitrarily many backward passes, which makes it the natural primitive for full Jacobians (one backward pass per output row) and for higher-order products such as Hessian-vector products.
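To make the v @ J statement concrete, the sketch below uses plain jax.vjp on a stateless function (brainstate is not involved) and compares the pullback against an explicit Jacobian from jax.jacrev. The second half obtains a Hessian-vector product by taking a VJP of jax.grad (reverse-over-reverse); this is one possible construction, shown only to illustrate the "higher-order products" remark.

```python
import jax
import jax.numpy as jnp


def f(x):
    # R^3 -> R^2, so the Jacobian J has shape (2, 3).
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])


x = jnp.array([1.0, 2.0, 3.0])
out, vjp_fn = jax.vjp(f, x)

v = jnp.array([0.5, -1.0])  # cotangent: same shape and dtype as out
(vJ,) = vjp_fn(v)           # jax.vjp returns one cotangent per primal
J = jax.jacrev(f)(x)        # explicit (2, 3) Jacobian, for comparison only
# vJ equals v @ J


# Hessian-vector product, reverse-over-reverse: a VJP of the gradient.
def g(x):
    return jnp.sum(x ** 3)  # Hessian is diag(6 * x)


u = jnp.array([1.0, 0.0, 2.0])
_, grad_vjp = jax.vjp(jax.grad(g), x)
(hvp,) = grad_vjp(u)        # u @ H, which equals H @ u since H is symmetric
# diag(6 * [1, 2, 3]) @ [1, 0, 2] = [6., 0., 36.]
```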

What vjp_fn returns depends on grad_states and argnums, mirroring brainstate.transform.grad():

grad_states    argnums                  vjp_fn(v) returns
-----------    ---------------------    ---------------------
None           int / sequence           arg_cotangents
provided       int / sequence           (state_cts, arg_cts)
provided       None (or no primals)     state_cotangents

Parameters:
  • fun (Callable) – Function to be differentiated. It may read and/or write State objects. If has_aux is True it must return (output, aux).

  • *primals – Positional arguments at which to evaluate fun and its pullback. May be omitted entirely when differentiating only with respect to grad_states (e.g. a parameterized model whose inputs are closed over or supplied through states).

  • grad_states (State | Sequence[State] | Dict[str, State] | None) – States to compute cotangents for. The returned state cotangents follow the structure of grad_states (an unwrapped array for a single State, a list for a sequence, a matching dict for a mapping). When given, the pullback also returns argument cotangents unless argnums is None or no positional primals are supplied.

  • argnums (int | Sequence[int] | None) – Positional argument(s) to differentiate with respect to. A single int yields an unwrapped argument cotangent; a sequence yields a tuple, one entry per index. None disables argument differentiation so the pullback returns only state cotangents (requires grad_states). If fun is called with no positional primals, argument differentiation is disabled automatically.

  • has_aux (bool) – Whether fun returns (output, aux). The auxiliary data is not differentiated but is returned to the caller.
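JAX's own jax.vjp has no argnums parameter; it returns one cotangent for every positional primal. To illustrate what argnums selection means (this is not brainstate's implementation), the hypothetical helper vjp_argnums below closes over the arguments that are not differentiated and unwraps the result when argnums is an int, following the return conventions described for argnums above.

```python
import jax
import jax.numpy as jnp


def vjp_argnums(fun, *primals, argnums=0):
    """Hypothetical helper: VJP of `fun` w.r.t. the positional args in `argnums`.

    An int argnums gives an unwrapped cotangent; a sequence gives a tuple.
    """
    single = isinstance(argnums, int)
    idx = (argnums,) if single else tuple(argnums)

    def restricted(*diff_args):
        # Rebuild the full argument list; non-selected args stay fixed (closed over).
        args = list(primals)
        for i, a in zip(idx, diff_args):
            args[i] = a
        return fun(*args)

    out, raw_vjp = jax.vjp(restricted, *(primals[i] for i in idx))

    def vjp_fn(v):
        cts = raw_vjp(v)  # always a tuple, one entry per selected argument
        return cts[0] if single else cts

    return out, vjp_fn


def f(x, y, z):
    return jnp.sum(x * y) + jnp.sum(z)


x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])
z = jnp.array([5.0])

out, vjp_y = vjp_argnums(f, x, y, z, argnums=1)
gy = vjp_y(jnp.ones_like(out))        # unwrapped array: d/dy = x

out, vjp_xz = vjp_argnums(f, x, y, z, argnums=(0, 2))
gx, gz = vjp_xz(jnp.ones_like(out))   # tuple: (d/dx, d/dz) = (y, [1.])
```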

Return type:

tuple

Returns:

  • primal_out (PyTree) – The value of fun(*primals); when has_aux is True, only the output part of fun's (output, aux) pair.

  • vjp_fn (callable) – Pullback mapping a cotangent of primal_out to input cotangents. The cotangent passed to vjp_fn must have the same pytree structure, shape, and dtype as primal_out; for a scalar floating-point output, the examples below pass the Python scalar 1.0. See the table above for the return structure.

  • aux (PyTree) – Returned only when has_aux is True.
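The requirement that the cotangent mirror primal_out matters once the output is a pytree rather than a single array. With plain jax.vjp (used here so the sketch does not depend on brainstate), a dict-valued output needs a dict cotangent with the same keys, shapes, and dtypes; jax.tree_util.tree_map with jnp.ones_like is one convenient way to build one.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Pytree output: a dict holding a scalar and a vector.
    return {"loss": jnp.sum(x ** 2), "pred": 3.0 * x}


x = jnp.array([1.0, 2.0])
out, vjp_fn = jax.vjp(f, x)

# Cotangent with exactly the structure, shapes and dtypes of `out`.
ct = jax.tree_util.tree_map(jnp.ones_like, out)
(x_ct,) = vjp_fn(ct)                  # 2x + 3 = [5., 7.]

# Zeroing a leaf removes that output's contribution: only "loss" counts here.
ct_loss_only = {"loss": jnp.ones_like(out["loss"]),
                "pred": jnp.zeros_like(out["pred"])}
(x_ct_loss,) = vjp_fn(ct_loss_only)   # 2x = [2., 4.]
```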

Raises:
  • TypeError – If any entry of grad_states is not a State.

  • ValueError – If a grad_state is never read by fun (so its cotangent is undefined), if argnums is out of range for the given primals, or if there is nothing to differentiate (no primals and no grad_states).

See also

jvp, grad, jacrev

Notes

States that fun writes are re-threaded back into their State objects after the forward trace, producing the same side effects fun would have had if called directly. Differentiation is always taken with respect to the value each state held on entry to fun, so reading and then writing a grad_state is well defined.
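A small worked case of the "value on entry" rule: suppose fun reads a state s, writes s ← 3s, and returns the square of the new value. Differentiating with respect to the entry value s0 gives d(9 s0²)/ds0 = 18 s0, whereas treating the written value as the variable would give 2 · (3 s0) = 6 s0. The sketch models the state as an explicit input in plain JAX; that modeling is an assumption used for illustration, not brainstate's code.

```python
import jax
import jax.numpy as jnp


def pure_step(s0):
    # Models `s.value = 3 * s.value; return s.value ** 2` with s passed explicitly.
    s_new = 3.0 * s0
    out = s_new ** 2
    return out, s_new  # (differentiated output, value to write back)


s0 = jnp.array(2.0)
out, vjp_fn, s_written = jax.vjp(pure_step, s0, has_aux=True)
(ds0,) = vjp_fn(jnp.ones_like(out))
# out = 36, s_written = 6, ds0 = 18 * s0 = 36 (not 6 * s0 = 12)
```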

Because vjp traces fun once and returns a reusable pullback, it is more flexible than grad() when you need (a) multiple backward passes, (b) a non-scalar output, or (c) a custom cotangent v. For a scalar-output fun, vjp_fn(1.0) returns the same result as grad() called with the same grad_states and argnums.
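The relation to gradients can be checked on a pure function with plain JAX: for a scalar-valued function, pulling back a cotangent of 1.0 gives the gradient, and the same pullback can be called again with other cotangents without re-running the forward pass. This sketch uses jax.vjp and jax.grad, not brainstate.transform.grad().

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    return jnp.sum(jnp.tanh(w * x))


w = jnp.array([0.5, -1.0, 2.0])
x = jnp.array([1.0, 2.0, 3.0])

out, vjp_fn = jax.vjp(lambda w_: loss(w_, x), w)  # one forward trace
(g_vjp,) = vjp_fn(jnp.ones_like(out))             # cotangent 1.0
g_grad = jax.grad(loss)(w, x)                     # argnums=0 by default

# Reusing the pullback: the map is linear, so doubling v doubles the result.
(g_scaled,) = vjp_fn(2.0 * jnp.ones_like(out))
```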

Examples

Plain reverse-mode autodiff (no states). vjp computes the same cotangents as jax.vjp() on a pure function, except that with an int argnums (the default) the argument cotangent is returned unwrapped rather than inside a tuple.

>>> import brainstate
>>> import jax.numpy as jnp
>>> def f(x):
...     return jnp.sum(x ** 2)
>>> x = jnp.array([1.0, 2.0, 3.0])
>>> out, vjp_fn = brainstate.transform.vjp(f, x)
>>> out
Array(14., dtype=float32)
>>> vjp_fn(1.0)            # d/dx sum(x**2) = 2x
Array([2., 4., 6.], dtype=float32)

Gradients with respect to states. Pass grad_states to also obtain state cotangents. The pullback then returns (state_cts, arg_cts).

>>> w = brainstate.State(jnp.array([2.0, 3.0]))
>>> def loss(x):
...     return jnp.sum(w.value * x)
>>> x = jnp.array([5.0, 7.0])
>>> out, vjp_fn = brainstate.transform.vjp(loss, x, grad_states=w)
>>> state_ct, arg_ct = vjp_fn(1.0)
>>> state_ct              # d/dw sum(w*x) = x
Array([5., 7.], dtype=float32)
>>> arg_ct                # d/dx sum(w*x) = w
Array([2., 3.], dtype=float32)

State-only gradients (no differentiable argument). This is the typical neural-network case: the loss closes over the trainable parameters, so the pullback returns just the state cotangents.

>>> weight = brainstate.State(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
>>> bias = brainstate.State(jnp.array([0.0, 0.0]))
>>> x = jnp.array([1.0, 1.0])
>>> def predict_loss():
...     y = x @ weight.value + bias.value
...     return jnp.sum(y ** 2)
>>> out, vjp_fn = brainstate.transform.vjp(predict_loss, grad_states=[weight, bias])
>>> grads = vjp_fn(1.0)           # list of state cotangents, no arg cotangent
>>> [g.shape for g in grads]
[(2, 2), (2,)]

Auxiliary data and state write-back. has_aux=True returns the side output untouched; states written inside fun keep their new values.

>>> counter = brainstate.State(jnp.array(0.0))
>>> def f(x):
...     counter.value = counter.value + 1.0
...     return jnp.sum(x ** 2), {'mean': jnp.mean(x)}
>>> x = jnp.array([1.0, 2.0])
>>> out, vjp_fn, aux = brainstate.transform.vjp(f, x, has_aux=True)
>>> aux['mean']
Array(1.5, dtype=float32)
>>> vjp_fn(1.0)
Array([2., 4.], dtype=float32)
>>> float(counter.value)  # write re-threaded back into the State
1.0

Full Jacobian by reusing the pullback. One trace, many backward passes: map the pullback over the rows of the identity to build the Jacobian.

>>> import jax
>>> def f(x):
...     return jnp.array([jnp.sum(x), jnp.sum(x ** 2)])
>>> x = jnp.array([1.0, 2.0, 3.0])
>>> out, vjp_fn = brainstate.transform.vjp(f, x)
>>> jac = jax.vmap(vjp_fn)(jnp.eye(2))
>>> jac                   # rows are gradients of each output
Array([[1., 1., 1.],
       [2., 4., 6.]], dtype=float32)
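As a cross-check on a pure function, the vmap-over-pullback construction can be compared with jax.jacrev, which is itself built by vmapping a pullback over standard basis vectors. The sketch uses plain jax.vjp, whose pullback returns a 1-tuple, so each row is taken with [0]; brainstate's pullback in the example above already returns the unwrapped array.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.stack([jnp.sum(x), jnp.sum(x ** 2)])


x = jnp.array([1.0, 2.0, 3.0])
out, vjp_fn = jax.vjp(f, x)

# One backward pass per output row: pull back each one-hot cotangent.
jac_vjp = jax.vmap(lambda v: vjp_fn(v)[0])(jnp.eye(out.shape[0]))
jac_ref = jax.jacrev(f)(x)
# both equal [[1., 1., 1.], [2., 4., 6.]]
```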

Multiple arguments. A sequence argnums returns a tuple of cotangents, one per requested argument.

>>> def f(x, y):
...     return jnp.sum(x * y)
>>> x = jnp.array([1.0, 2.0])
>>> y = jnp.array([3.0, 4.0])
>>> out, vjp_fn = brainstate.transform.vjp(f, x, y, argnums=(0, 1))
>>> gx, gy = vjp_fn(1.0)
>>> gx, gy                # (d/dx, d/dy) sum(x*y) = (y, x)
(Array([3., 4.], dtype=float32), Array([1., 2.], dtype=float32))