brainstate.transform.sofo_grad
- brainstate.transform.sofo_grad(fun, loss_fn, grad_states=None, argnums=None, has_aux=None, return_value=False, check_states=True, loss='mse', tangent_size=100, damping=1e-05, key=None)
Second-order forward-mode optimization to compute loss and gradient.
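To make the phrase "second-order forward-mode" concrete, here is a minimal JAX sketch of one generic technique it can describe: a damped Gauss-Newton direction restricted to a random subspace and computed only with forward-mode JVPs. This page does not confirm that sofo_grad is implemented this way. The helper `subspace_gn_direction`, the toy `linear_model`, and the fixed squared-error loss are all hypothetical; `tangent_size` and `damping` borrow their names from the signature above, but the roles they play here are assumptions.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def subspace_gn_direction(model, params, x, y, key, tangent_size=2, damping=1e-5):
    """Hypothetical helper, not part of brainstate.

    Returns the loss 0.5 * ||model(params, x) - y||^2 and a damped
    Gauss-Newton direction restricted to the span of random tangents.
    """
    flat, unravel = ravel_pytree(params)

    def outputs(p_flat):
        return model(unravel(p_flat), x).ravel()

    # Random tangent directions in parameter space, one per column.
    V = jax.random.normal(key, (flat.size, tangent_size), dtype=flat.dtype)

    # Forward mode only: one JVP per tangent yields the columns of J @ V.
    def jvp_one(v):
        return jax.jvp(outputs, (flat,), (v,))[1]

    JV = jax.vmap(jvp_one, in_axes=1, out_axes=1)(V)

    residual = outputs(flat) - y.ravel()
    loss = 0.5 * jnp.sum(residual ** 2)

    # Gradient and Gauss-Newton matrix projected onto the tangents.
    g_sub = JV.T @ residual
    H_sub = JV.T @ JV + damping * jnp.eye(tangent_size, dtype=flat.dtype)

    # Solve the small system, then map the solution back to parameter space.
    direction = V @ jnp.linalg.solve(H_sub, g_sub)
    return loss, unravel(direction)


def linear_model(p, x):
    # Toy model used only for this illustration.
    return x @ p["w"] + p["b"]


key_data, key_tangent = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_data, (16, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3
params = {"w": jnp.zeros(3), "b": jnp.zeros(())}

loss_before, direction = subspace_gn_direction(linear_model, params, x, y, key_tangent)
new_params = jax.tree_util.tree_map(lambda p, d: p - d, params, direction)
loss_after, _ = subspace_gn_direction(linear_model, new_params, x, y, key_tangent)
```

The linear solve involves only a `tangent_size` by `tangent_size` matrix, so its cost is independent of the number of parameters, and no backward pass is needed. For this toy linear model, subtracting the returned direction from the parameters lowers the loss.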
When grad_states is None:
- has_aux=False + return_value=False => arg_grads.
- has_aux=True + return_value=False => (arg_grads, aux_data).
- has_aux=False + return_value=True => (arg_grads, fn_value).
- has_aux=True + return_value=True => (arg_grads, fn_value, aux_data).
When grad_states is not None and argnums is None:
- has_aux=False + return_value=False => var_grads.
- has_aux=True + return_value=False => (var_grads, aux_data).
- has_aux=False + return_value=True => (var_grads, fn_value).
- has_aux=True + return_value=True => (var_grads, fn_value, aux_data).
When grad_states is not None and argnums is not None:
- has_aux=False + return_value=False => (var_grads, arg_grads).
- has_aux=True + return_value=False => ((var_grads, arg_grads), aux_data).
- has_aux=False + return_value=True => ((var_grads, arg_grads), fn_value).
- has_aux=True + return_value=True => ((var_grads, arg_grads), fn_value, aux_data).
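The twelve cases above follow one pattern, which the pure-Python sketch below reproduces. `pack_result` is a hypothetical helper written only for illustration. It is not part of brainstate and is not claimed to be how sofo_grad assembles its output; it simply mirrors the listed return structures, using the `return_value` flag from the signature.

```python
def pack_result(var_grads, arg_grads, fn_value, aux_data,
                has_grad_states, has_argnums, has_aux, return_value):
    """Hypothetical helper mirroring the documented return structures."""
    # Select the gradient part based on what is being differentiated.
    if has_grad_states and has_argnums:
        grads = (var_grads, arg_grads)
    elif has_grad_states:
        grads = var_grads
    else:
        grads = arg_grads

    # Append the function value and auxiliary data, in that order, when requested.
    out = (grads,)
    if return_value:
        out += (fn_value,)
    if has_aux:
        out += (aux_data,)

    # A lone gradient is returned bare rather than as a 1-tuple.
    return out[0] if len(out) == 1 else out


# Placeholder strings stand in for real gradients and values.
only_args = pack_result("vg", "ag", 1.0, "aux", False, True, False, False)
states_with_aux = pack_result("vg", "ag", 1.0, "aux", True, False, True, False)
both_full = pack_result("vg", "ag", 1.0, "aux", True, True, True, True)
```

The gradient part always comes first; the function value, when requested, comes before the auxiliary data.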
- Parameters:
fun (Callable) – The scalar-valued function to be differentiated.
grad_states (State | Sequence[State] | Dict[str, State] | None) – The states in fun with respect to which gradients are taken.
argnums (int | Sequence[int] | None) – Specifies which positional argument(s) to differentiate with respect to.
has_aux (bool | None) – Indicates whether fun returns a pair whose first element is the output of the mathematical function to be differentiated and whose second element is auxiliary data.
return_value (bool | None) – Indicates whether to return the value of the function along with the gradient.
check_states (bool) – Whether to check that all grad_states are found in the function.
loss (str) – Loss function to use. Supported values are 'mse' and 'ce'.
- Returns:
A function which computes the gradient of fun. It takes the same arguments as fun, but returns the gradient instead. If has_aux is True, it returns a pair whose first element is the gradient and whose second element is the auxiliary data. If return_value is True, it returns a pair whose first element is the gradient and whose second element is the value of the function. If both are True, it returns the triple (gradient, function value, auxiliary data).
- Return type: