brainstate.transform.vector_grad



brainstate.transform.vector_grad(func=<brainstate.typing.Missing object>, grad_states=None, argnums=None, return_value=False, has_aux=None, unit_aware=False, check_states=True, **kwargs)

Take vector-valued gradients of the function func.
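A minimal sketch of what a "vector-valued gradient" means, written in plain JAX. It assumes (this page does not state it) that vector_grad pulls an all-ones cotangent back through jax.vjp, so the result has the shape of the input and equals the gradient of the summed outputs. The helper ones_vjp is hypothetical, not part of brainstate.

```python
import jax
import jax.numpy as jnp


def ones_vjp(f, x):
    """Hypothetical helper: pull an all-ones cotangent back through f."""
    y, vjp_fn = jax.vjp(f, x)
    # One cotangent leaf per output leaf, filled with ones.
    cotangent = jax.tree_util.tree_map(jnp.ones_like, y)
    (grad_x,) = vjp_fn(cotangent)
    return grad_x


x = jnp.array([1.0, 2.0, 3.0])
# For an elementwise function this is the elementwise derivative, 2 * x.
g = ones_vjp(lambda v: v ** 2, x)
# It matches the ordinary gradient of the summed output.
g_sum = jax.grad(lambda v: jnp.sum(v ** 2))(x)
```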

As with grad(), jacrev(), and jacfwd(), the structure of the returned value depends on the argument settings:

  1. When grad_states is None

    • has_aux=False + return_value=False => arg_grads.

    • has_aux=True + return_value=False => (arg_grads, aux_data).

    • has_aux=False + return_value=True => (arg_grads, loss_value).

    • has_aux=True + return_value=True => (arg_grads, loss_value, aux_data).

  2. When grad_states is not None and argnums is None

    • has_aux=False + return_value=False => var_grads.

    • has_aux=True + return_value=False => (var_grads, aux_data).

    • has_aux=False + return_value=True => (var_grads, loss_value).

    • has_aux=True + return_value=True => (var_grads, loss_value, aux_data).

  3. When grad_states is not None and argnums is not None

    • has_aux=False + return_value=False => (var_grads, arg_grads).

    • has_aux=True + return_value=False => ((var_grads, arg_grads), aux_data).

    • has_aux=False + return_value=True => ((var_grads, arg_grads), loss_value).

    • has_aux=True + return_value=True => ((var_grads, arg_grads), loss_value, aux_data).
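The twelve cases above follow one rule: the gradients come first (a (var_grads, arg_grads) pair when both are requested), then loss_value if return_value=True, then aux_data if has_aux=True, and a lone gradient is returned unwrapped. The hypothetical pack_outputs function below restates that rule in plain Python; it illustrates the table and is not brainstate's internal code.

```python
def pack_outputs(var_grads=None, arg_grads=None, loss_value=None, aux_data=None,
                 *, has_states, has_argnums, has_aux, return_value):
    """Hypothetical restatement of the return-structure table."""
    if has_states and has_argnums:
        grads = (var_grads, arg_grads)  # case 3
    elif has_states:
        grads = var_grads               # case 2
    else:
        grads = arg_grads               # case 1
    extras = []
    if return_value:
        extras.append(loss_value)
    if has_aux:
        extras.append(aux_data)
    # With nothing extra requested, the gradients are returned unwrapped.
    return grads if not extras else (grads, *extras)
```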

Parameters:
  • func (Callable) – Function whose gradient is to be computed. If omitted, vector_grad returns a decorator that accepts func, as reflected in the return type.

  • grad_states (State | Sequence[State] | Dict[str, State] | None) – The states used inside func with respect to which gradients are taken.

  • argnums (int | Sequence[int] | None) – Specifies which positional argument(s) to differentiate with respect to.

  • return_value (bool) – Whether to also return the value of func (loss_value).

  • has_aux (bool | None) – Whether func returns a pair whose first element is the output to be differentiated and whose second element is auxiliary data.

  • unit_aware (bool) – Whether to compute the gradient in unit-aware mode.

  • check_states (bool) – Whether to check that all grad_states are found in the function.
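To illustrate argnums and has_aux together, here is a plain-JAX sketch under the same all-ones-cotangent assumption: func returns (output, aux), only the positional arguments listed in argnums are differentiated, and the auxiliary data is passed through untouched. The names ones_vjp_argnums and f are hypothetical.

```python
import jax
import jax.numpy as jnp


def ones_vjp_argnums(func, argnums, *args):
    """Hypothetical: ones-cotangent VJP w.r.t. args[i] for i in argnums.

    func must return (output, aux); returns (tuple_of_grads, aux).
    """
    argnums = (argnums,) if isinstance(argnums, int) else tuple(argnums)

    def partial_f(*diff_args):
        # Rebuild the full argument list, swapping in the differentiated ones.
        full = list(args)
        for i, a in zip(argnums, diff_args):
            full[i] = a
        return func(*full)

    diff_args = [args[i] for i in argnums]
    out, vjp_fn, aux = jax.vjp(partial_f, *diff_args, has_aux=True)
    grads = vjp_fn(jax.tree_util.tree_map(jnp.ones_like, out))
    return grads, aux


def f(x, w):
    y = w * x ** 2
    return y, {"mean": jnp.mean(y)}  # (output, auxiliary data)


x = jnp.array([1.0, 2.0])
w = jnp.array([3.0, 4.0])
(gx, gw), aux = ones_vjp_argnums(f, (0, 1), x, w)
```

Here gx is 2 * w * x and gw is x ** 2, each with the shape of its argument.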

Returns:

The vector gradient function.

Return type:

GradientTransform | Callable[[Callable], GradientTransform]
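The two return types correspond to two calling styles: with func supplied, a GradientTransform is returned; with func left at its Missing default, a decorator is returned that takes func later, e.g. @brainstate.transform.vector_grad(grad_states=...). The pure-Python sketch below shows this sentinel-default pattern with a hypothetical transform function and a hypothetical scale option; it does not reproduce brainstate's own code.

```python
import functools

_MISSING = object()  # hypothetical sentinel, analogous to brainstate.typing.Missing


def transform(func=_MISSING, *, scale=1.0):
    """Hypothetical transform usable directly or as a decorator factory."""
    if func is _MISSING:
        # Called without func: return a decorator that remembers the options.
        return functools.partial(transform, scale=scale)

    def wrapped(*args, **kwargs):
        return scale * func(*args, **kwargs)

    return wrapped


# Direct call with func.
direct = transform(lambda x: x + 1, scale=2.0)


# Decorator-factory call without func.
@transform(scale=3.0)
def decorated(x):
    return x + 1
```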

Examples

Basic vector gradient computation:

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Vector-valued function
>>> def f(x):
...     return jnp.array([x[0]**2, x[1]**3, x[0]*x[1]])
>>>
>>> vector_grad_f = brainstate.transform.vector_grad(f)
>>> x = jnp.array([2.0, 3.0])
>>> gradients = vector_grad_f(x)  # Shape: (2,), the same as x
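Assuming the all-ones-cotangent behavior, the result for f can be cross-checked in plain JAX: it should equal the full Jacobian summed over the three outputs, which is also the gradient of jnp.sum(f(x)). This check uses only jax.jacrev and jax.grad, not brainstate.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.array([x[0] ** 2, x[1] ** 3, x[0] * x[1]])


x = jnp.array([2.0, 3.0])
jac = jax.jacrev(f)(x)      # full Jacobian, shape (3, 2)
summed = jac.sum(axis=0)    # summed over the outputs, shape (2,)
grad_of_sum = jax.grad(lambda v: jnp.sum(f(v)))(x)
```

The Jacobian rows are [4, 0], [0, 27] and [3, 2], so the summed gradient is [7, 29].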

With states:

>>> params = brainstate.State(jnp.array([1.0, 2.0]))
>>>
>>> def model(x):
...     return jnp.array([
...         x * params.value[0],
...         x**2 * params.value[1]
...     ])
>>>
>>> vector_grad_fn = brainstate.transform.vector_grad(
...     model, grad_states=[params]
... )
>>> x = 3.0
>>> param_grads = vector_grad_fn(x)  # gradients with respect to params
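In the state example, params acts as an extra input. Passing its value as an explicit argument gives a plain-JAX reference for the numbers param_grads should contain under the all-ones-cotangent assumption: the gradient of sum(model(x)) with respect to the parameters is [x, x**2], i.e. [3., 9.] at x = 3.0. The function model_explicit is a hypothetical rewrite, and how brainstate packages the result for grad_states=[params] is not specified on this page.

```python
import jax
import jax.numpy as jnp


def model_explicit(p, x):
    # Same computation as `model`, with the state value passed in as `p`.
    return jnp.array([x * p[0], x ** 2 * p[1]])


p = jnp.array([1.0, 2.0])
x = 3.0
p_grad = jax.grad(lambda q: jnp.sum(model_explicit(q, x)))(p)
```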