brainstate.transform.hessian
- brainstate.transform.hessian(func, grad_states=None, argnums=None, return_value=False, holomorphic=False, has_aux=None, unit_aware=False, check_states=True, **kwargs)
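Hessian of func as a dense array.

The structure of the result depends on grad_states, argnums, has_aux and return_value. Below, arg_grads denotes the Hessian with respect to the positional arguments selected by argnums, var_grads the Hessian with respect to grad_states, loss_value the output of func, and aux_data its auxiliary data.

When grad_states is None:
  - has_aux=False + return_value=False => arg_grads
  - has_aux=True + return_value=False => (arg_grads, aux_data)
  - has_aux=False + return_value=True => (arg_grads, loss_value)
  - has_aux=True + return_value=True => (arg_grads, loss_value, aux_data)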
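In this first case only positional arguments are differentiated. As a rough illustration, the sketch below computes the same kind of quantity with plain jax.hessian rather than brainstate. The test function func and the input x are made up for the example, and we assume (without having checked it here) that brainstate.transform.hessian(func, argnums=0)(x) returns an equal dense array.

```python
import jax
import jax.numpy as jnp


def func(x):
    # Hypothetical scalar test function: f(x) = x0**2 * x1 + x1**3
    return x[0] ** 2 * x[1] + x[1] ** 3


x = jnp.array([1.0, 2.0])

# Dense Hessian with respect to positional argument 0 (the arg_grads case).
# Assumed brainstate counterpart: brainstate.transform.hessian(func, argnums=0)(x)
arg_hess = jax.hessian(func, argnums=0)(x)

# Analytic Hessian: [[2*x1, 2*x0], [2*x0, 6*x1]] = [[4, 2], [2, 12]]
print(arg_hess)
```

The result is an (n, n) array for an input of length n, which is what "dense" refers to: every second partial derivative is materialized.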
When grad_states is not None and argnums is None:
  - has_aux=False + return_value=False => var_grads
  - has_aux=True + return_value=False => (var_grads, aux_data)
  - has_aux=False + return_value=True => (var_grads, loss_value)
  - has_aux=True + return_value=True => (var_grads, loss_value, aux_data)
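With grad_states set, the Hessian is taken with respect to the values held by the given State objects while the call arguments stay fixed. One plain-JAX way to picture this is to lift the state value into an explicit argument. The names loss_of_w, w and x below are hypothetical, and the sketch does not use brainstate's State machinery at all.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0])   # ordinary input, held fixed
w = jnp.array([0.5, -1.0])  # stands in for the value of a State in grad_states


def loss_of_w(w):
    # Hypothetical loss: sum_i (w_i * x_i)**2, viewed as a function of w only
    return jnp.sum((w * x) ** 2)


# Hessian with respect to the "state" value (the var_grads case)
var_hess = jax.hessian(loss_of_w)(w)

# Analytic Hessian: diag(2 * x**2) = [[2, 0], [0, 8]]
print(var_hess)
```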
When grad_states is not None and argnums is not None:
  - has_aux=False + return_value=False => (var_grads, arg_grads)
  - has_aux=True + return_value=False => ((var_grads, arg_grads), aux_data)
  - has_aux=False + return_value=True => ((var_grads, arg_grads), loss_value)
  - has_aux=True + return_value=True => ((var_grads, arg_grads), loss_value, aux_data)
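When both are given, the result involves second derivatives with respect to states and to arguments. For orientation, plain jax.hessian with a tuple argnums returns a nested tuple of blocks, where H[i][j] holds the second derivatives with respect to inputs i and j. The function f and its inputs are hypothetical, and how brainstate distributes such blocks between var_grads and arg_grads is not specified above, so treat any one-to-one mapping as an assumption.

```python
import jax
import jax.numpy as jnp


def f(w, x):
    # Hypothetical scalar function of a "state" value w and an argument x
    return w ** 2 * x ** 3


w, x = 2.0, 1.0

# With a tuple argnums, jax.hessian returns nested blocks H[i][j]
H = jax.hessian(f, argnums=(0, 1))(w, x)

# Analytic blocks at (w, x) = (2, 1):
#   d2f/dw2 = 2*x**3 = 2,  d2f/dw dx = 6*w*x**2 = 12,  d2f/dx2 = 6*w**2*x = 24
print(H[0][0], H[0][1], H[1][0], H[1][1])
```

The off-diagonal blocks H[0][1] and H[1][0] are the mixed state/argument terms; for a twice continuously differentiable f they agree.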
- Parameters:
  - func (Callable) – Function whose Hessian is to be computed. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers thereof. It should return arrays, scalars, or standard Python containers thereof.
  - grad_states (State | Sequence[State] | Dict[str, State] | None) – The states with respect to which the Hessian is computed.
  - argnums (int | Sequence[int] | None) – Specifies which positional argument(s) to differentiate with respect to (default 0).
  - holomorphic (bool) – Indicates whether func is promised to be holomorphic. Default False.
  - return_value (bool) – Whether to also return the value of func (loss_value in the result structures above). Default False.
  - has_aux (bool | None) – Indicates whether func returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
  - unit_aware (bool) – Whether to compute the Hessian in unit-aware mode. Default False.
  - check_states (bool) – Whether to check the states in grad_states. Default True.
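To illustrate has_aux and return_value together, the sketch below uses jax.hessian(..., has_aux=True), which returns (hessian, aux). jax.hessian has no return_value flag, so the hypothetical helper hessian_with_value routes the function value through the auxiliary output to reproduce the (hessian, loss_value, aux_data) order listed for has_aux=True + return_value=True. This is an analogy in plain JAX, not brainstate's implementation.

```python
import jax
import jax.numpy as jnp


def func(x):
    # Main scalar output plus auxiliary data (has_aux=True convention)
    value = jnp.sum(x ** 3)
    return value, {"mean": jnp.mean(x)}


def hessian_with_value(f, x):
    # Hypothetical helper: mirrors the (hessian, loss_value, aux_data) ordering
    # by passing the undifferentiated value along inside the aux output.
    def g(x):
        value, aux = f(x)
        return value, (value, aux)

    hess, (value, aux) = jax.hessian(g, has_aux=True)(x)
    return hess, value, aux


x = jnp.array([1.0, 2.0])
hess, value, aux = hessian_with_value(func, x)

# Analytic: hess = diag(6 * x) = [[6, 0], [0, 12]], value = 9, aux["mean"] = 1.5
print(hess, value, aux)
```

The auxiliary output is returned as computed and is not differentiated, which is why the value can ride along in it unchanged.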
- Returns:
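obj – The transformed object: a callable that, when invoked with the arguments of func, returns one of the result structures listed above.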
- Return type: