brainstate.transform.hessian

brainstate.transform.hessian(func, grad_states=None, argnums=None, return_value=False, holomorphic=False, has_aux=None, unit_aware=False, check_states=True, **kwargs)

Hessian of func as a dense array.

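The structure of the output depends on grad_states, argnums, has_aux, and return_value. Below, arg_grads and var_grads denote the Hessian blocks with respect to the selected arguments and states, loss_value is the value returned by func, and aux_data is its auxiliary output:

  1. When grad_states is None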

    • has_aux=False + return_value=False => arg_grads.

    • has_aux=True + return_value=False => (arg_grads, aux_data).

    • has_aux=False + return_value=True => (arg_grads, loss_value).

    • has_aux=True + return_value=True => (arg_grads, loss_value, aux_data).

  2. When grad_states is not None and argnums is None

    • has_aux=False + return_value=False => var_grads.

    • has_aux=True + return_value=False => (var_grads, aux_data).

    • has_aux=False + return_value=True => (var_grads, loss_value).

    • has_aux=True + return_value=True => (var_grads, loss_value, aux_data).

  3. When grad_states is not None and argnums is not None

    • has_aux=False + return_value=False => (var_grads, arg_grads).

    • has_aux=True + return_value=False => ((var_grads, arg_grads), aux_data).

    • has_aux=False + return_value=True => ((var_grads, arg_grads), loss_value).

    • has_aux=True + return_value=True => ((var_grads, arg_grads), loss_value, aux_data).

Parameters:
  • func (Callable) – Function whose Hessian is to be computed. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers thereof. It should return arrays, scalars, or standard Python containers thereof.

  • grad_states (State | Sequence[State] | Dict[str, State] | None) – The state(s) with respect to which the Hessian is computed.

  • argnums (int | Sequence[int] | None) – Specifies which positional argument(s) to differentiate with respect to. If both argnums and grad_states are None, it defaults to 0.

  • holomorphic (bool) – Indicates whether func is promised to be holomorphic. Default False.

  • return_value (bool) – Whether to also return the value of func (loss_value in the output structures above). Default False.

  • has_aux (bool | None) – Indicates whether func returns a pair, where the first element is the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • unit_aware (bool) – Whether to enable unit-aware mode. Default False.

  • check_states (bool) – Whether to check the states in grad_states. Default True.

Returns:

obj – The transformed object: a callable that takes the same arguments as func and returns the outputs described above.

Return type:

GradientTransform