brainstate.transform.grad

brainstate.transform.grad(fun=<brainstate.typing.Missing object>, grad_states=None, argnums=None, holomorphic=False, allow_int=False, has_aux=None, return_value=False, unit_aware=False, check_states=True, **kwargs)

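Compute the gradient of a scalar-valued function with respect to its positional arguments and/or the State objects it uses. The structure of the returned value depends on grad_states, argnums, has_aux, and return_value: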

  1. When grad_states is None

    • has_aux=False + return_value=False => arg_grads.

    • has_aux=True + return_value=False => (arg_grads, aux_data).

    • has_aux=False + return_value=True => (arg_grads, loss_value).

    • has_aux=True + return_value=True => (arg_grads, loss_value, aux_data).

  2. When grad_states is not None and argnums is None

    • has_aux=False + return_value=False => var_grads.

    • has_aux=True + return_value=False => (var_grads, aux_data).

    • has_aux=False + return_value=True => (var_grads, loss_value).

    • has_aux=True + return_value=True => (var_grads, loss_value, aux_data).

  3. When grad_states is not None and argnums is not None

    • has_aux=False + return_value=False => (var_grads, arg_grads).

    • has_aux=True + return_value=False => ((var_grads, arg_grads), aux_data).

    • has_aux=False + return_value=True => ((var_grads, arg_grads), loss_value).

    • has_aux=True + return_value=True => ((var_grads, arg_grads), loss_value, aux_data).
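To make the case-3 layout concrete, the sketch below mimics it in plain JAX rather than brainstate. A hypothetical scalar `scale` stands in for a State value and is passed explicitly as argument 0, while `x` is the ordinary positional argument. Because jax.value_and_grad returns ((value, aux), grads), the pieces are rearranged into the documented ((var_grads, arg_grads), loss_value, aux_data) order. brainstate tracks State objects itself instead of taking their values as arguments, so only the output layout is illustrated here.

```python
import jax
import jax.numpy as jnp


def loss(scale, x):
    # `scale` plays the role of a State value (hypothetical), `x` of a positional argument.
    sq_norm = jnp.sum(x ** 2)
    return scale * sq_norm, {"sq_norm": sq_norm}


scale = jnp.array(3.0)
x = jnp.array([1.0, 2.0])

# Plain JAX orders the output as ((value, aux), grads).
(loss_value, aux_data), (var_grad, arg_grad) = jax.value_and_grad(
    loss, argnums=(0, 1), has_aux=True
)(scale, x)

# Rearranged into the case-3 ordering for has_aux=True + return_value=True.
result = ((var_grad, arg_grad), loss_value, aux_data)
```

Here loss_value is 3 * (1 + 4) = 15, the "state" gradient is sum(x**2) = 5, and the argument gradient is 2 * scale * x = [6, 12].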

Parameters:
  • fun (Callable) – The scalar-valued function to be differentiated.

  • grad_states (State | Sequence[State] | Dict[str, State] | None) – The State object(s) used inside fun with respect to which gradients are computed.

  • argnums (int | Sequence[int] | None) – Specifies which positional argument(s) to differentiate with respect to.

  • holomorphic (bool | None) – Whether fun is promised to be holomorphic.

  • allow_int (bool | None) – Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0).

  • has_aux (bool | None) – Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data.

  • return_value (bool | None) – Indicates whether to return the value of the function along with the gradient.

  • unit_aware (bool) – Whether to compute the gradient in unit-aware mode.

  • check_states (bool) – Whether to check that every State in grad_states is actually used by fun.

  • debug_nan (bool, default False) – Whether to enable NaN debugging for the gradient computation.
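The allow_int description mirrors the allow_int option of jax.grad. The following plain-JAX sketch shows what a float0 gradient looks like; it assumes brainstate forwards this flag to JAX with the same semantics, which this page suggests but does not show.

```python
import jax


def f(x, n):
    # x is a float input, n is an integer input.
    return x * n


# With allow_int=True, differentiating w.r.t. the integer argument is permitted
# and its gradient carries the trivial float0 dtype.
gx, gn = jax.grad(f, argnums=(0, 1), allow_int=True)(2.0, 3)
```

gx is the ordinary derivative (3.0), while gn is a placeholder of dtype float0 rather than a numeric gradient.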

Returns:

A function that takes the same arguments as fun and returns gradients instead of fun's output. The exact structure of the result depends on grad_states, argnums, has_aux, and return_value, as enumerated above: with return_value=True the value of fun follows the gradients, and with has_aux=True the auxiliary data comes last, so setting both yields (grads, loss_value, aux_data). If fun is omitted, grad returns a decorator that accepts fun, as reflected in the return type.

Return type:

GradientTransform | Callable[[Callable], GradientTransform]
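The Callable[[Callable], GradientTransform] branch of the return type corresponds to the common "decorator with options" pattern. The sketch below is a hypothetical, simplified stand-in (sketch_grad and the _Missing sentinel are illustrative names, not brainstate API) that covers only case 1 and builds on jax.value_and_grad to reproduce the documented output ordering.

```python
import functools

import jax
import jax.numpy as jnp


class _Missing:
    """Hypothetical sentinel meaning 'no function was passed'."""


_MISSING = _Missing()


def sketch_grad(fun=_MISSING, *, has_aux=False, return_value=False):
    """Hypothetical, simplified grad transform covering case 1 only."""
    if isinstance(fun, _Missing):
        # Called without `fun`: return a decorator that remembers the options.
        return functools.partial(sketch_grad, has_aux=has_aux, return_value=return_value)

    value_and_grad_fn = jax.value_and_grad(fun, has_aux=has_aux)

    def wrapped(*args):
        out, grads = value_and_grad_fn(*args)
        # With has_aux=True, JAX packs the output as (value, aux).
        value, aux = out if has_aux else (out, None)
        if has_aux and return_value:
            return grads, value, aux
        if has_aux:
            return grads, aux
        if return_value:
            return grads, value
        return grads

    return wrapped


@sketch_grad(return_value=True)
def f(x):
    return jnp.sum(x ** 2)


grads, value = f(jnp.array([1.0, 2.0, 3.0]))
```

Calling sketch_grad(f) directly and using @sketch_grad(...) as a decorator produce the same kind of wrapped function; the sentinel default is what lets one entry point serve both uses.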

Examples

Basic gradient computation:

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Simple function gradient
>>> def f(x):
...     return jnp.sum(x ** 2)
>>>
>>> grad_f = brainstate.transform.grad(f)
>>> x = jnp.array([1.0, 2.0, 3.0])
>>> gradient = grad_f(x)
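Since the derivative of sum(x**2) is 2*x, gradient in the example above should be [2., 4., 6.]. The same value can be cross-checked with plain jax.grad:

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sum(x ** 2)


x = jnp.array([1.0, 2.0, 3.0])
expected = 2 * x               # analytic gradient of sum(x**2)
jax_gradient = jax.grad(f)(x)  # plain JAX, for comparison
```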

Gradient with respect to states:

>>> # Create states
>>> weight = brainstate.State(jnp.array([1.0, 2.0]))
>>> bias = brainstate.State(jnp.array([0.5]))
>>>
>>> def loss_fn(x):
...     prediction = jnp.dot(x, weight.value) + bias.value
...     return jnp.sum(prediction ** 2)  # reduce to a scalar; grad needs a scalar output
>>>
>>> # Compute gradients with respect to states
>>> grad_fn = brainstate.transform.grad(loss_fn, grad_states=[weight, bias])
>>> x = jnp.array([1.0, 2.0])
>>> state_grads = grad_fn(x)
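To make the expected numbers concrete, the same computation can be written in plain JAX with the State values passed explicitly as arguments (explicit_loss is a hypothetical name, and the squared prediction is reduced to a scalar with jnp.sum). The gradients that grad_fn returns for weight and bias should match these values numerically.

```python
import jax
import jax.numpy as jnp


def explicit_loss(w, b, x):
    # Same computation as loss_fn, with the State values passed explicitly.
    prediction = jnp.dot(x, w) + b
    return jnp.sum(prediction ** 2)


w = jnp.array([1.0, 2.0])  # value held by `weight`
b = jnp.array([0.5])       # value held by `bias`
x = jnp.array([1.0, 2.0])

# Gradients with respect to the first two arguments, i.e. the "state" values.
dw, db = jax.grad(explicit_loss, argnums=(0, 1))(w, b, x)
# prediction = 1*1 + 2*2 + 0.5 = 5.5, so dw = 2*5.5*x = [11, 22] and db = [11].
```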

With auxiliary data and return value:

>>> def loss_with_aux(x):
...     prediction = jnp.dot(x, weight.value) + bias.value
...     loss = jnp.sum(prediction ** 2)  # scalar loss
...     return loss, {"prediction": prediction}
>>>
>>> grad_fn = brainstate.transform.grad(
...     loss_with_aux,
...     grad_states=[weight, bias],
...     has_aux=True,
...     return_value=True
... )
>>> grads, loss_value, aux_data = grad_fn(x)