brainstate.transform.fwd_grad
- brainstate.transform.fwd_grad(func=<brainstate.typing.Missing object>, grad_states=None, argnums=None, return_value=False, has_aux=None, tangent_size=None, drct_der_clip=None, key=None, **kwargs)
Take forward first-order gradients for function func.

As with grad(), jacrev(), and jacfwd(), what this function returns depends on the argument settings.

When grad_states is None:
- has_aux=False + return_value=False => arg_grads.
- has_aux=True + return_value=False => (arg_grads, aux_data).
- has_aux=False + return_value=True => (arg_grads, loss_value).
- has_aux=True + return_value=True => (arg_grads, loss_value, aux_data).

When grad_states is not None and argnums is None:
- has_aux=False + return_value=False => var_grads.
- has_aux=True + return_value=False => (var_grads, aux_data).
- has_aux=False + return_value=True => (var_grads, loss_value).
- has_aux=True + return_value=True => (var_grads, loss_value, aux_data).

When grad_states is not None and argnums is not None:
- has_aux=False + return_value=False => (var_grads, arg_grads).
- has_aux=True + return_value=False => ((var_grads, arg_grads), aux_data).
- has_aux=False + return_value=True => ((var_grads, arg_grads), loss_value).
- has_aux=True + return_value=True => ((var_grads, arg_grads), loss_value, aux_data).
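The three tables of return structures follow one rule: grad_states and argnums choose the gradient part, then loss_value and aux_data are appended, in that order, when requested. The pure-Python sketch below encodes that rule; pack_grad_outputs is a hypothetical helper written for illustration and is not part of brainstate.

```python
from typing import Any


def pack_grad_outputs(
    var_grads: Any,
    arg_grads: Any,
    loss_value: Any,
    aux_data: Any,
    *,
    grad_states_given: bool,
    argnums_given: bool,
    has_aux: bool,
    return_value: bool,
) -> Any:
    """Hypothetical helper: arrange outputs following the documented rules."""
    # 1. Choose the gradient part from grad_states / argnums.
    if not grad_states_given:
        grads = arg_grads
    elif not argnums_given:
        grads = var_grads
    else:
        grads = (var_grads, arg_grads)

    # 2. Append the loss value, then the auxiliary data, when requested.
    extras = []
    if return_value:
        extras.append(loss_value)
    if has_aux:
        extras.append(aux_data)

    # 3. With no extras, the gradient part is returned on its own.
    return (grads, *extras) if extras else grads


# grad_states and argnums both given, has_aux=True, return_value=True
print(pack_grad_outputs("v", "a", "L", "aux",
                        grad_states_given=True, argnums_given=True,
                        has_aux=True, return_value=True))
# (('v', 'a'), 'L', 'aux')
```

Note that loss_value always precedes aux_data when both are returned.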
- Parameters:
func (Callable) – Function whose gradient is to be computed.
grad_states (State | Sequence[State] | Dict[str, State] | None) – The states used in func whose gradients are taken.
argnums (int | Sequence[int] | None) – Specifies which positional argument(s) to differentiate with respect to.
return_value (bool) – Whether to return the loss value.
has_aux (bool | None) – Indicates whether func returns a pair where the first element is the output of the mathematical function to be differentiated and the second element is auxiliary data.
unit_aware (bool, default False) – Whether to return the gradient in unit-aware mode.
check_states (bool, default True) – Whether to check that all grad_states are found in the function.
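The page describes fwd_grad as taking forward first-order gradients. Assuming "forward" refers to forward-mode automatic differentiation, the plain-JAX sketch below recovers the gradient of a scalar function by pushing one standard basis tangent per input coordinate through jax.jvp, and treats has_aux and return_value as the parameters above describe. forward_grad is a hypothetical name; this is not brainstate's implementation, and it does not model tangent_size, drct_der_clip, or key.

```python
import jax
import jax.numpy as jnp


def forward_grad(f, has_aux=False, return_value=False):
    """Hypothetical forward-mode gradient of scalar f w.r.t. a 1-D array x."""

    def grad_fn(x):
        def scalar_f(v):
            # Keep only the differentiated output when f also returns aux data.
            return f(v)[0] if has_aux else f(v)

        # One standard basis tangent e_i per coordinate of x.
        basis = jnp.eye(x.shape[0], dtype=x.dtype)

        def directional(e):
            # jax.jvp returns (primal_out, tangent_out); keep the tangent,
            # which is the directional derivative of scalar_f along e.
            return jax.jvp(scalar_f, (x,), (e,))[1]

        # Stacking the derivatives along each e_i gives the gradient.
        grads = jax.vmap(directional)(basis)

        if not (has_aux or return_value):
            return grads
        out = f(x)
        value, aux = (out[0], out[1]) if has_aux else (out, None)
        # Same ordering as the tables above: grads, loss value, aux data.
        extras = ([value] if return_value else []) + ([aux] if has_aux else [])
        return (grads, *extras)

    return grad_fn


def loss(x):
    return jnp.sum(x ** 2) + x[0] * x[1]


x = jnp.array([1.0, 2.0, 3.0])
print(forward_grad(loss)(x))  # [4. 5. 6.]
```

Forward mode needs one JVP per input coordinate, so this exact approach scales with the input dimension; reverse mode (jax.grad) needs one VJP per scalar output.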
- Returns:
The gradient function.
- Return type:
Examples
Basic vector gradient computation:
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Vector-valued function
>>> def f(x):
...     return jnp.array([x[0]**2, x[1]**3, x[0]*x[1]])
>>>
>>> vector_grad_f = brainstate.transform.vector_grad(f)
>>> x = jnp.array([2.0, 3.0])
>>> gradients = vector_grad_f(x)  # Shape: (2,), same as x
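The shape of a "vector gradient" depends on the convention used. The plain-JAX sketch below contrasts the full Jacobian of f, with shape (3, 2), against a VJP seeded with a vector of ones, which equals the gradient of sum(f(x)) and has the same shape as x. Which convention vector_grad follows is an assumption to check against its implementation; the sketch does not call brainstate.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Same vector-valued function as in the example: R^2 -> R^3.
    return jnp.array([x[0] ** 2, x[1] ** 3, x[0] * x[1]])


x = jnp.array([2.0, 3.0])

# Full Jacobian: one row per output, one column per input.
jac = jax.jacrev(f)(x)

# VJP seeded with ones: sums the Jacobian rows, giving the shape of x.
y, vjp_fn = jax.vjp(f, x)
(vec_grad,) = vjp_fn(jnp.ones_like(y))

print(jac.shape)       # (3, 2)
print(vec_grad.shape)  # (2,)
print(vec_grad)        # [ 7. 29.]
```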
With states:
>>> params = brainstate.State(jnp.array([1.0, 2.0]))
>>>
>>> def model(x):
...     return jnp.array([
...         x * params.value[0],
...         x**2 * params.value[1]
...     ])
>>>
>>> vector_grad_fn = brainstate.transform.vector_grad(
...     model, grad_states=[params]
... )
>>> x = 3.0
>>> param_grads = vector_grad_fn(x)
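To cross-check the numbers in the state example without brainstate, the plain-JAX sketch below passes the parameters as an explicit argument p instead of a brainstate.State. It assumes (not confirmed on this page) that vector_grad seeds its VJP with ones, in which case the gradient with respect to the parameters is [x, x**2]; the full Jacobian is shown for comparison. This does not reflect how brainstate tracks states internally.

```python
import jax
import jax.numpy as jnp


def model(x, p):
    # Same model as in the example, with the parameters passed explicitly.
    return jnp.array([x * p[0], x ** 2 * p[1]])


p = jnp.array([1.0, 2.0])
x = 3.0

# Ones-seeded VJP w.r.t. p: gradient of sum(model(x, p)).
y, vjp_fn = jax.vjp(lambda q: model(x, q), p)
(p_grad,) = vjp_fn(jnp.ones_like(y))  # [x, x**2] = [3., 9.]

# Full Jacobian w.r.t. p, computed in forward mode: diag([x, x**2]).
p_jac = jax.jacfwd(lambda q: model(x, q))(p)

print(p_grad)  # [3. 9.]
```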