brainstate.transform.jacobian
- brainstate.transform.jacobian(fun, grad_states=None, argnums=None, has_aux=None, return_value=False, holomorphic=False, allow_int=False, unit_aware=False, check_states=True, **kwargs)
Extend automatic (reverse-mode) Jacobian computation of fun to classes. This function extends the official JAX jacrev so that Jacobians can be computed for both plain functions and class methods. It also supports returning the function value (return_value) and auxiliary data (has_aux).

When grad_states is None:
- has_aux=False + return_value=False => arg_grads
- has_aux=True + return_value=False => (arg_grads, aux_data)
- has_aux=False + return_value=True => (arg_grads, loss_value)
- has_aux=True + return_value=True => (arg_grads, loss_value, aux_data)
When grad_states is not None and argnums is None:
- has_aux=False + return_value=False => var_grads
- has_aux=True + return_value=False => (var_grads, aux_data)
- has_aux=False + return_value=True => (var_grads, loss_value)
- has_aux=True + return_value=True => (var_grads, loss_value, aux_data)
When grad_states is not None and argnums is not None:
- has_aux=False + return_value=False => (var_grads, arg_grads)
- has_aux=True + return_value=False => ((var_grads, arg_grads), aux_data)
- has_aux=False + return_value=True => ((var_grads, arg_grads), loss_value)
- has_aux=True + return_value=True => ((var_grads, arg_grads), loss_value, aux_data)
- Parameters:
fun (Callable) – Function whose Jacobian is to be computed.
grad_states (State | Sequence[State] | Dict[str, State] | None) – The states used in fun with respect to which the Jacobian is computed.
has_aux (bool | None) – Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
return_value (bool) – Whether to also return the value of fun (loss_value in the return formats above). Default False.
argnums (int | Sequence[int] | None) – Specifies which positional argument(s) to differentiate with respect to (default 0 when grad_states is None).
holomorphic (bool) – Indicates whether fun is promised to be holomorphic. Default False.
allow_int (bool) – Whether to allow differentiating with respect to integer-valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
unit_aware (bool) – Whether to enable unit-aware mode. Default False.
check_states (bool) – Whether to check the states in grad_states. Default True.
- Returns:
fun – The transformed object.
- Return type: