brainstate.transform.sofo_grad

brainstate.transform.sofo_grad(fun, loss_fn, grad_states=None, argnums=None, has_aux=None, return_value=False, check_states=True, loss='mse', tangent_size=100, damping=1e-05, key=None)

Second-order forward-mode optimization (SOFO) to compute the loss and gradient.
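
The general idea of second-order forward-mode optimization can be sketched in NumPy. This is a minimal illustration under stated assumptions, not brainstate's implementation: it samples tangent_size random tangent directions, computes Jacobian-tangent products by forward-mode differentiation (written out by hand for a hypothetical linear model), builds the damped generalized Gauss-Newton matrix of an MSE loss restricted to those directions, and solves the resulting small tangent_size × tangent_size system. The names model, model_jvp, and sofo_direction are hypothetical, and treating the returned direction like a gradient (subtracted from the parameters) is an assumption about how the result is consumed.

```python
import numpy as np


def model(w, X):
    # Hypothetical toy model: a linear regressor y = X @ w.
    return X @ w


def model_jvp(w, X, V):
    # Forward-mode Jacobian-tangent products J @ V, one column per tangent.
    # For this linear model the Jacobian is X, so J @ V = X @ V.
    return X @ V


def sofo_direction(w, X, targets, tangent_size, damping, rng):
    """One SOFO-style step for the MSE loss 0.5 * ||model(w, X) - targets||^2."""
    V = rng.standard_normal((w.shape[0], tangent_size))  # random tangents, shape (n, K)
    U = model_jvp(w, X, V)                                # J @ V, shape (m, K)
    residual = model(w, X) - targets                      # dL/dy for this MSE loss
    g_sub = U.T @ residual                                # V^T g: gradient projected on the tangents
    G_sub = U.T @ U                                       # V^T J^T H J V with output Hessian H = I
    alpha = np.linalg.solve(G_sub + damping * np.eye(tangent_size), g_sub)
    return V @ alpha                                      # preconditioned direction, shape (n,)


# Usage: one step on a noiseless least-squares problem.
rng = np.random.default_rng(0)
X = rng.standard_normal((50, 5))
targets = X @ rng.standard_normal(5)


def loss(w):
    return 0.5 * np.sum((model(w, X) - targets) ** 2)


w = np.zeros(5)
before = loss(w)
w = w - sofo_direction(w, X, targets, tangent_size=5, damping=1e-5, rng=rng)
after = loss(w)
```

With tangent_size equal to the number of parameters, the random tangents span the whole parameter space, so for this quadratic loss the step is, up to damping, a full Newton step and the loss drops to nearly zero. With fewer tangents the step is restricted to a random subspace, but it still decreases this loss, because the damped subspace system keeps the direction a descent direction.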

  1. When grad_states is None

    • has_aux=False + return_value=False => arg_grads.

    • has_aux=True + return_value=False => (arg_grads, aux_data).

    • has_aux=False + return_value=True => (arg_grads, fn_value).

    • has_aux=True + return_value=True => (arg_grads, fn_value, aux_data).

  2. When grad_states is not None and argnums is None

    • has_aux=False + return_value=False => var_grads.

    • has_aux=True + return_value=False => (var_grads, aux_data).

    • has_aux=False + return_value=True => (var_grads, fn_value).

    • has_aux=True + return_value=True => (var_grads, fn_value, aux_data).

  3. When grad_states is not None and argnums is not None

    • has_aux=False + return_value=False => (var_grads, arg_grads).

    • has_aux=True + return_value=False => ((var_grads, arg_grads), aux_data).

    • has_aux=False + return_value=True => ((var_grads, arg_grads), fn_value).

    • has_aux=True + return_value=True => ((var_grads, arg_grads), fn_value, aux_data).
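
The twelve cases above follow two rules: the gradient part depends on which of grad_states and argnums are given, and has_aux and return_value then append the function value and the auxiliary data, in that order. The helpers select_grads and pack_outputs below are hypothetical and only illustrate this packing; they are not functions provided by brainstate.

```python
def select_grads(var_grads, arg_grads, has_grad_states, has_argnums):
    # Which gradients are returned, following cases 1-3 above.
    if not has_grad_states:
        return arg_grads                  # case 1: grad_states is None
    if not has_argnums:
        return var_grads                  # case 2: grad_states given, argnums is None
    return (var_grads, arg_grads)         # case 3: both given


def pack_outputs(grads, fn_value, aux_data, has_aux, return_value):
    # The function value comes before the auxiliary data when both are returned.
    if has_aux and return_value:
        return grads, fn_value, aux_data
    if has_aux:
        return grads, aux_data
    if return_value:
        return grads, fn_value
    return grads


# Example: case 3 with has_aux=True and return_value=True.
grads = select_grads("var_grads", "arg_grads", has_grad_states=True, has_argnums=True)
print(pack_outputs(grads, "fn_value", "aux_data", has_aux=True, return_value=True))
# (('var_grads', 'arg_grads'), 'fn_value', 'aux_data')
```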

Parameters:
  • fun (Callable) – The scalar-valued function to be differentiated.

  • grad_states (State | Sequence[State] | Dict[str, State] | None) – The states used in fun with respect to which gradients are computed.

  • argnums (int | Sequence[int] | None) – Specifies which positional argument(s) to differentiate with respect to.

  • has_aux (bool | None) – Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data.

  • return_value (bool | None) – Indicates whether to return the value of the function along with the gradient.

  • check_states (bool) – Whether to check that all grad_states are found in the function.

  • loss (str) – Loss function to use. Supported values are ‘mse’ (mean squared error) and ‘ce’ (cross-entropy).
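
In Gauss-Newton-type second-order methods, the curvature of the loss with respect to the model output y enters the update, which may be why only these two losses are offered; that link to this function's internals is an assumption. Both losses have simple closed-form derivatives, shown here for a single sample with target t and without any averaging factor. For 'ce', y are logits, p = softmax(y), and t sums to one.

```latex
\begin{aligned}
\text{mse:}\quad & \ell(y) = \tfrac{1}{2}\lVert y - t\rVert^2, &
\nabla_y \ell &= y - t, &
\nabla_y^2 \ell &= I \\
\text{ce:}\quad & \ell(y) = -\textstyle\sum_i t_i \log p_i,\; p = \operatorname{softmax}(y), &
\nabla_y \ell &= p - t, &
\nabla_y^2 \ell &= \operatorname{diag}(p) - p\,p^\top
\end{aligned}
```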

Returns:

A function that computes the gradient of fun. It takes the same arguments as fun and returns the gradient instead of the function value. If has_aux is True, the auxiliary data is returned as well; if return_value is True, the value of fun is returned as well. When both are True, the result is a triple (gradient, value, auxiliary data). The exact structure for each combination of grad_states, argnums, has_aux, and return_value is listed in the cases above.

Return type:

GradientTransform | Callable[[Callable], GradientTransform]