brainstate.transform.jacrev

brainstate.transform.jacrev(fun, grad_states=None, argnums=None, has_aux=None, return_value=False, holomorphic=False, allow_int=False, unit_aware=False, check_states=True, **kwargs)

Extends reverse-mode automatic Jacobian computation of fun to classes.

This function extends JAX's jacrev so that Jacobians can be computed for both plain functions and class methods. It can additionally return the function value (return_value) and auxiliary data (has_aux). The structure of the result depends on grad_states, argnums, has_aux, and return_value:

  1. When grad_states is None

    • has_aux=False + return_value=False => arg_grads.

    • has_aux=True + return_value=False => (arg_grads, aux_data).

    • has_aux=False + return_value=True => (arg_grads, loss_value).

    • has_aux=True + return_value=True => (arg_grads, loss_value, aux_data).

  2. When grad_states is not None and argnums is None

    • has_aux=False + return_value=False => var_grads.

    • has_aux=True + return_value=False => (var_grads, aux_data).

    • has_aux=False + return_value=True => (var_grads, loss_value).

    • has_aux=True + return_value=True => (var_grads, loss_value, aux_data).

  3. When grad_states is not None and argnums is not None

    • has_aux=False + return_value=False => (var_grads, arg_grads).

    • has_aux=True + return_value=False => ((var_grads, arg_grads), aux_data).

    • has_aux=False + return_value=True => ((var_grads, arg_grads), loss_value).

    • has_aux=True + return_value=True => ((var_grads, arg_grads), loss_value, aux_data).
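The return layouts for case 1 can be illustrated with plain jax.jacrev, the function this transform extends. This is a minimal sketch, not the brainstate implementation: the function f, its input x, and the aux dictionary are hypothetical, and because jax.jacrev has no return_value option, the value is obtained here by simply evaluating f a second time.

```python
import jax
import jax.numpy as jnp


def f(x):
    """Hypothetical function R^3 -> R^2 that also returns auxiliary data."""
    y = jnp.stack([x[0] * x[1], jnp.sin(x[2])])
    aux = {"norm": jnp.linalg.norm(x)}
    return y, aux


x = jnp.array([1.0, 2.0, 3.0])

# has_aux=True, return_value=False  =>  (arg_grads, aux_data)
jac, aux = jax.jacrev(f, has_aux=True)(x)

# Emulating return_value=True by evaluating f once more
# (an assumption of this sketch)  =>  (arg_grads, loss_value, aux_data)
value, _ = f(x)
result = (jac, value, aux)

print(jac.shape)    # one row per output element, one column per input element
print(value.shape)  # shape of the function output
```

Note that for a Jacobian, the "loss_value" entry is the full (possibly vector-valued) output of fun, not necessarily a scalar loss.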

Parameters:
  • fun (Callable) – Function whose Jacobian is to be computed.

  • grad_states (State | Sequence[State] | Dict[str, State] | None) – The states used in fun with respect to which the Jacobian is computed.

  • has_aux (bool | None) – Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • return_value (bool) – Whether to also return the value computed by fun (the loss_value entry above). Default False.

  • argnums (int | Sequence[int] | None) – Specifies which positional argument(s) to differentiate with respect to (default 0).

  • holomorphic (bool) – Indicates whether fun is promised to be holomorphic. Default False.

  • allow_int (bool) – Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.

  • unit_aware (bool) – Whether to compute the Jacobian in unit-aware mode. Default False.

  • check_states (bool) – Whether to check the states in grad_states. Default True.
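The argnums and holomorphic parameters keep the meaning they have in jax.jacrev, so their effect can be sketched with JAX alone. The functions g and h and all inputs below are hypothetical, and the sketch does not exercise grad_states, unit_aware, or check_states, which are specific to this transform.

```python
import jax
import jax.numpy as jnp


def g(a, b):
    """Hypothetical elementwise product, R^2 x R^2 -> R^2."""
    return a * b


a = jnp.array([1.0, 2.0])
b = jnp.array([3.0, 4.0])

# argnums as a sequence: one Jacobian per selected positional argument.
jac_a, jac_b = jax.jacrev(g, argnums=(0, 1))(a, b)


def h(z):
    """Hypothetical holomorphic function, z -> z * z elementwise."""
    return z * z


z = jnp.array([1.0 + 1.0j, 2.0 - 1.0j])

# holomorphic=True requires complex-valued inputs and outputs.
jac_h = jax.jacrev(h, holomorphic=True)(z)

print(jac_a.shape, jac_b.shape, jac_h.shape)
```

With an integer argnums the result is a single Jacobian; with a sequence it is a tuple with one Jacobian per listed argument, which corresponds to the arg_grads entry in the return layouts.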

Returns:

fun – The transformed object.

Return type:

GradientTransform