brainstate.transform.grad
- brainstate.transform.grad(fun=<brainstate.typing.Missing object>, grad_states=None, argnums=None, holomorphic=False, allow_int=False, has_aux=None, return_value=False, unit_aware=False, check_states=True, **kwargs)
Compute the gradient of a scalar-valued function with respect to its arguments.
The structure of the returned value depends on grad_states, argnums, has_aux, and return_value.

When grad_states is None:
- has_aux=False + return_value=False => arg_grads
- has_aux=True + return_value=False => (arg_grads, aux_data)
- has_aux=False + return_value=True => (arg_grads, loss_value)
- has_aux=True + return_value=True => (arg_grads, loss_value, aux_data)

When grad_states is not None and argnums is None:
- has_aux=False + return_value=False => var_grads
- has_aux=True + return_value=False => (var_grads, aux_data)
- has_aux=False + return_value=True => (var_grads, loss_value)
- has_aux=True + return_value=True => (var_grads, loss_value, aux_data)

When grad_states is not None and argnums is not None:
- has_aux=False + return_value=False => (var_grads, arg_grads)
- has_aux=True + return_value=False => ((var_grads, arg_grads), aux_data)
- has_aux=False + return_value=True => ((var_grads, arg_grads), loss_value)
- has_aux=True + return_value=True => ((var_grads, arg_grads), loss_value, aux_data)
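The twelve cases above follow one packing rule: the gradients come first (nested as (var_grads, arg_grads) when both are requested), then the loss value if return_value is set, then the auxiliary data if has_aux is set. The sketch below restates that rule in plain Python. The helper name pack_grad_outputs is hypothetical and is not part of brainstate; it only mirrors the documented ordering and says nothing about how the library assembles its outputs internally.

```python
def pack_grad_outputs(arg_grads=None, var_grads=None, loss_value=None,
                      aux_data=None, *, has_aux=False, return_value=False):
    """Hypothetical helper: order outputs the way the cases above describe."""
    # Gradients: states only, arguments only, or the pair (states, arguments).
    if var_grads is not None and arg_grads is not None:
        grads = (var_grads, arg_grads)
    elif var_grads is not None:
        grads = var_grads
    else:
        grads = arg_grads

    out = [grads]
    if return_value:   # the loss value comes right after the gradients
        out.append(loss_value)
    if has_aux:        # auxiliary data always comes last
        out.append(aux_data)

    # A single output is returned bare rather than as a 1-tuple.
    return out[0] if len(out) == 1 else tuple(out)


# Placeholder strings stand in for real gradient pytrees.
print(pack_grad_outputs(arg_grads="arg_grads"))
print(pack_grad_outputs(var_grads="var_grads", arg_grads="arg_grads",
                        loss_value=0.25, aux_data="aux",
                        has_aux=True, return_value=True))
```

When unpacking a result, the practical consequence is that the loss value always precedes the auxiliary data.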
- Parameters:
  - fun (Callable) – The scalar-valued function to be differentiated.
  - grad_states (State | Sequence[State] | Dict[str, State] | None) – The states used in fun with respect to which gradients are taken.
  - argnums (int | Sequence[int] | None) – Specifies which positional argument(s) to differentiate with respect to.
  - holomorphic (bool | None) – Whether fun is promised to be holomorphic.
  - allow_int (bool | None) – Whether to allow differentiating with respect to integer-valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0).
  - has_aux (bool | None) – Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data.
  - return_value (bool | None) – Indicates whether to return the value of the function along with the gradient.
  - unit_aware (bool) – Whether to return the gradient in unit-aware mode.
  - check_states (bool) – Whether to check that all grad_states are found in the function.
  - debug_nan (bool, default False) – Whether to enable NaN debugging for the gradient computation.
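The argnums, has_aux, and return_value options have close counterparts in plain JAX, which can help when reading the descriptions above. The sketch below uses only jax.value_and_grad, not brainstate, and assumes the flags carry the same meaning in both libraries. One difference to watch for: JAX returns ((value, aux), grads), while the ordering documented on this page puts the gradients first. The function f and its inputs are illustrative.

```python
import jax
import jax.numpy as jnp


def f(a, b):
    # Scalar output plus auxiliary data, i.e. the shape has_aux=True expects.
    out = jnp.sum(a * b ** 2)
    return out, {"product": a * b}


a = jnp.array([1.0, 2.0])
b = jnp.array([3.0, 4.0])

# argnums=(0, 1): differentiate with respect to both positional arguments.
# has_aux=True: only the first element of f's return value is differentiated.
(value, aux), (grad_a, grad_b) = jax.value_and_grad(
    f, argnums=(0, 1), has_aux=True
)(a, b)

print(value)   # 1*9 + 2*16 = 41.0
print(grad_a)  # d/da = b**2   -> [ 9. 16.]
print(grad_b)  # d/db = 2*a*b  -> [ 6. 16.]
```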
- Returns:
A function that computes the gradient of fun. It takes the same arguments as fun but returns the gradient instead of the function value. If has_aux is True, it returns a pair (gradient, auxiliary data). If return_value is True, it returns a pair (gradient, function value). If both are True, it returns a triple (gradient, function value, auxiliary data).
- Return type:
Examples
Basic gradient computation:
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Simple function gradient
>>> def f(x):
...     return jnp.sum(x ** 2)
>>>
>>> grad_f = brainstate.transform.grad(f)
>>> x = jnp.array([1.0, 2.0, 3.0])
>>> gradient = grad_f(x)
Gradient with respect to states:
>>> # Create states
>>> weight = brainstate.State(jnp.array([1.0, 2.0]))
>>> bias = brainstate.State(jnp.array([0.5]))
>>>
>>> def loss_fn(x):
...     prediction = jnp.dot(x, weight.value) + bias.value
...     # bias has shape (1,), so reduce to a scalar before differentiating
...     return jnp.sum(prediction ** 2)
>>>
>>> # Compute gradients with respect to states
>>> grad_fn = brainstate.transform.grad(loss_fn, grad_states=[weight, bias])
>>> x = jnp.array([1.0, 2.0])
>>> state_grads = grad_fn(x)
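The gradients this example should produce can be derived by hand. With prediction p = x·w + b and loss p², the chain rule gives d(loss)/dw = 2·p·x and d(loss)/db = 2·p. The dependency-free sketch below checks those formulas against central finite differences for the same numbers. It does not call brainstate, and the helper names (loss, central_diff) are illustrative only.

```python
def loss(w, b, x):
    # Same model as the state example: squared linear prediction.
    p = sum(xi * wi for xi, wi in zip(x, w)) + b
    return p ** 2


def central_diff(f, v, eps=1e-6):
    # Symmetric finite-difference estimate of df/dv at v.
    return (f(v + eps) - f(v - eps)) / (2 * eps)


x = [1.0, 2.0]
w = [1.0, 2.0]
b = 0.5

p = sum(xi * wi for xi, wi in zip(x, w)) + b   # 1 + 4 + 0.5 = 5.5
analytic_w = [2 * p * xi for xi in x]          # [11.0, 22.0]
analytic_b = 2 * p                             # 11.0

# Perturb one weight component at a time, holding everything else fixed.
numeric_w = [
    central_diff(lambda v, i=i: loss(w[:i] + [v] + w[i + 1:], b, x), w[i])
    for i in range(len(w))
]
numeric_b = central_diff(lambda v: loss(w, v, x), b)

print(analytic_w, analytic_b)
print(numeric_w, numeric_b)
```

The numeric estimates should agree with the analytic values up to floating-point rounding, which gives a reference point for inspecting the weight and bias gradients returned by the library.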
With auxiliary data and return value:
>>> def loss_with_aux(x):
...     prediction = jnp.dot(x, weight.value) + bias.value
...     # Reduce to a scalar; prediction has shape (1,) because of bias
...     loss = jnp.sum(prediction ** 2)
...     return loss, {"prediction": prediction}
>>>
>>> grad_fn = brainstate.transform.grad(
...     loss_with_aux,
...     grad_states=[weight, bias],
...     has_aux=True,
...     return_value=True
... )
>>> grads, loss_value, aux_data = grad_fn(x)