brainstate.transform.jacfwd
- brainstate.transform.jacfwd(func, grad_states=None, argnums=None, has_aux=None, return_value=False, holomorphic=False, unit_aware=False, check_states=True, **kwargs)
Forward-mode automatic Jacobian of func, extended to classes.

This function extends the official JAX jacfwd so that Jacobians can be computed automatically for both plain functions and class methods. It also supports returning the function value (return_value) and returning auxiliary data (has_aux). The structure of the result depends on grad_states, argnums, has_aux and return_value:

When grad_states is None:
  - has_aux=False, return_value=False => arg_grads
  - has_aux=True, return_value=False => (arg_grads, aux_data)
  - has_aux=False, return_value=True => (arg_grads, loss_value)
  - has_aux=True, return_value=True => (arg_grads, loss_value, aux_data)
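For the case where grad_states is None, the result has the same shape as what plain jax.jacfwd produces for positional arguments. The sketch below uses only JAX, not brainstate, on a hypothetical function f. It emulates return_value=True by passing the primal output through JAX's aux channel. That emulation is an assumption made for illustration and is not necessarily how brainstate implements the flag.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical vector-valued function R^3 -> R^2 plus auxiliary data.
    y = jnp.stack([x[0] * x[1], jnp.sin(x[2])])
    aux = {"norm": jnp.linalg.norm(x)}  # returned as-is, not differentiated
    return y, aux


def f_with_value(x):
    # Emulates return_value=True (assumption): pass the output through aux.
    y, aux = f(x)
    return y, (y, aux)


x = jnp.array([1.0, 2.0, 3.0])

# With has_aux=True, jax.jacfwd returns (jacobian, aux). Unpacking it yields
# the three pieces of the (arg_grads, loss_value, aux_data) layout above.
arg_grads, (value, aux_data) = jax.jacfwd(f_with_value, has_aux=True)(x)

print(arg_grads.shape)  # (2, 3)
```

The Jacobian has one row per output element and one column per input element. Routing the value through the aux output avoids calling f a second time just to obtain it.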
When grad_states is not None and argnums is None:
  - has_aux=False, return_value=False => var_grads
  - has_aux=True, return_value=False => (var_grads, aux_data)
  - has_aux=False, return_value=True => (var_grads, loss_value)
  - has_aux=True, return_value=True => (var_grads, loss_value, aux_data)
When grad_states is not None and argnums is not None:
  - has_aux=False, return_value=False => (var_grads, arg_grads)
  - has_aux=True, return_value=False => ((var_grads, arg_grads), aux_data)
  - has_aux=False, return_value=True => ((var_grads, arg_grads), loss_value)
  - has_aux=True, return_value=True => ((var_grads, arg_grads), loss_value, aux_data)
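When both grad_states and argnums are given, the result pairs the state Jacobians with the argument Jacobians. In brainstate, states are read inside func rather than passed in as arguments. The pure-JAX sketch below approximates this by passing hypothetical state values explicitly as a dict, then differentiating with argnums=(0, 1). This is a simplification for illustration only, not brainstate's implementation. The names state_values, f and f_with_value are all hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical state values. brainstate would keep these in State objects
# read inside func; here they are passed explicitly as a dict.
state_values = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
x = jnp.array([3.0, 4.0])


def f(states, x):
    # Vector-valued output so every Jacobian block is non-trivial.
    y = states["w"] * x + states["b"]
    return y, {"mean": jnp.mean(y)}


def f_with_value(states, x):
    # Emulates return_value=True (assumption) by returning the output via aux.
    y, aux = f(states, x)
    return y, (y, aux)


# argnums=(0, 1) differentiates w.r.t. the state dict and x. JAX nests the
# result as ((jac_states, jac_x), (value, aux)).
(var_grads, arg_grads), (value, aux_data) = jax.jacfwd(
    f_with_value, argnums=(0, 1), has_aux=True
)(state_values, x)

# Regroup into the ((var_grads, arg_grads), loss_value, aux_data) layout.
result = ((var_grads, arg_grads), value, aux_data)

print(var_grads["w"].shape, var_grads["b"].shape, arg_grads.shape)  # (2, 2) (2,) (2, 2)
```

The state Jacobians keep the dict structure of the inputs. Each entry has the output shape followed by that state's own shape. For example, the scalar b yields a Jacobian of shape (2,).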
- Parameters:
  - func (Callable) – The function or class method to differentiate.
  - grad_states (State | Sequence[State] | Dict[str, State] | None) – The states (variables) used in func with respect to which Jacobians are taken.
  - has_aux (bool | None) – Whether func returns a pair in which the first element is the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
  - return_value (bool) – Whether to also return the function (loss) value. Default False.
  - argnums (int | Sequence[int] | None) – The positional argument(s) to differentiate with respect to (default 0).
  - holomorphic (bool) – Whether func is promised to be holomorphic, as in jax.jacfwd. Default False.
  - unit_aware (bool) – Whether to use unit-aware mode. Default False.
  - check_states (bool) – Whether to check the states in grad_states. Default True.
- Returns:
obj – The transformed object.
- Return type:
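The holomorphic parameter is described only briefly above. Assuming it is forwarded to the flag of the same name in jax.jacfwd (an assumption this page does not confirm), its effect can be seen in plain JAX. The function g below is hypothetical. Without holomorphic=True, jax.jacfwd rejects complex inputs. With the flag, it returns the complex derivative.

```python
import jax
import jax.numpy as jnp


def g(z):
    # Hypothetical holomorphic map C^2 -> C^2 (polynomials are complex-analytic).
    return jnp.stack([z[0] ** 2, z[0] * z[1]])


z = jnp.array([1.0 + 1.0j, 2.0 - 1.0j])

# holomorphic=True requires complex inputs and outputs; entry [i, j] of the
# result is the complex derivative d g_i / d z_j.
jac = jax.jacfwd(g, holomorphic=True)(z)
print(jac)
```

Here the rows are [2 * z0, 0] and [z1, z0], evaluated at the given point.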