GradientTransform
- class brainstate.transform.GradientTransform(target, transform, grad_states=None, argnums=None, return_value=False, has_aux=False, transform_params=None, check_states=True, debug_nan=False)
Automatic Differentiation Transformations for the State system.
This class implements gradient transformations for functions that operate on State objects. It allows flexible configuration of gradient computation with respect to specified states and function arguments.
- Parameters:
  - target (Callable) – The function to be transformed.
  - transform (Callable) – The transformation function to apply.
  - grad_states (State | Sequence[State] | Dict[str, State] | None) – States to compute gradients for.
  - argnums (int | Sequence[int] | None) – Indices of arguments to differentiate with respect to.
  - return_value (bool) – Whether to return the function’s value along with gradients.
  - has_aux (bool) – Whether the function returns auxiliary data.
  - transform_params (Dict[str, Any] | None) – Additional parameters for the transformation function.
  - check_states (bool) – Whether to check that all grad_states are found in the function.
  - debug_nan (bool) – Whether to enable NaN debugging. When True, raises RuntimeError with detailed diagnostics if NaN is detected during gradient computation.
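The `argnums`, `has_aux` and `return_value` options have close counterparts in plain JAX. The sketch below uses only `jax.grad` and `jax.value_and_grad` to show what those counterparts return. It does not use `GradientTransform` itself; how `GradientTransform` forwards these options to `transform` is not shown here, and `loss_with_aux` is a made-up function for illustration.

```python
import jax
import jax.numpy as jnp


def loss_with_aux(x, scale):
    # Hypothetical loss: returns (scalar loss, auxiliary data).
    y = x * scale
    loss = jnp.sum(y ** 2)
    return loss, {"predictions": y}


x = jnp.array([1.0, 2.0])

# argnums picks the positional arguments to differentiate;
# has_aux=True tells JAX the function returns (output, aux).
grad_fn = jax.grad(loss_with_aux, argnums=(0, 1), has_aux=True)
(gx, gscale), aux = grad_fn(x, 3.0)

# value_and_grad also returns the function value, which is the
# plain-JAX analogue of asking for the value alongside the gradients.
vg_fn = jax.value_and_grad(loss_with_aux, argnums=(0, 1), has_aux=True)
(loss, aux2), (gx2, gscale2) = vg_fn(x, 3.0)
```

Note that plain `jax.value_and_grad` orders its result as `((value, aux), grads)`, which differs from the `grads, loss_value, aux_data` order used in the `GradientTransform` example in this section.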
- target
The function to be transformed.
- Type: callable
- stateful_target
A wrapper around the target function for state management.
- Type:
Examples
Basic gradient computation with states:
>>> import brainstate
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> # Create states
>>> weight = brainstate.State(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
>>> bias = brainstate.State(jnp.array([0.5, -0.5]))
>>>
>>> def loss_fn(x):
...     y = x @ weight.value + bias.value
...     return jnp.sum(y ** 2)
>>>
>>> # Create gradient transform
>>> grad_transform = brainstate.transform.GradientTransform(
...     target=loss_fn,
...     transform=jax.grad,
...     grad_states=[weight, bias]
... )
>>>
>>> # Compute gradients
>>> x = jnp.array([1.0, 2.0])
>>> grads = grad_transform(x)
With function arguments and auxiliary data:
>>> def loss_fn_with_aux(x, scale):
...     y = x @ weight.value + bias.value
...     loss = jnp.sum((y * scale) ** 2)
...     return loss, {"predictions": y, "scale": scale}
>>>
>>> grad_transform = brainstate.transform.GradientTransform(
...     target=loss_fn_with_aux,
...     transform=jax.grad,
...     grad_states=[weight, bias],
...     argnums=[0, 1],  # gradient w.r.t. x and scale
...     has_aux=True,
...     return_value=True
... )
>>>
>>> grads, loss_value, aux_data = grad_transform(x, 2.0)
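The idea of a state-managing wrapper around `target` can be pictured with a small plain-JAX sketch. It lifts the values of some mutable containers into explicit function arguments so that `jax.grad` can differentiate with respect to them. This is only an illustration under stated assumptions: `TinyState` and `grad_wrt_states` are hypothetical names invented here, and this is not how brainstate implements `stateful_target`.

```python
import jax
import jax.numpy as jnp


class TinyState:
    """Hypothetical stand-in for a mutable state container (not brainstate.State)."""

    def __init__(self, value):
        self.value = value


def grad_wrt_states(fn, states):
    """Return a function that computes d fn / d state.value for each state."""

    def pure_fn(state_values, *args):
        # Temporarily install the traced values so fn reads them via .value,
        # then restore the originals so no tracer leaks out of the call.
        old_values = [s.value for s in states]
        try:
            for s, v in zip(states, state_values):
                s.value = v
            return fn(*args)
        finally:
            for s, v in zip(states, old_values):
                s.value = v

    def wrapped(*args):
        # jax.grad differentiates w.r.t. the first argument (a list pytree).
        return jax.grad(pure_fn)([s.value for s in states], *args)

    return wrapped


weight = TinyState(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
bias = TinyState(jnp.array([0.5, -0.5]))


def loss_fn(x):
    y = x @ weight.value + bias.value
    return jnp.sum(y ** 2)


w_grad, b_grad = grad_wrt_states(loss_fn, [weight, bias])(jnp.array([1.0, 2.0]))
```

The sketch returns one gradient per state, in the order the states were passed in. It ignores everything else a real implementation would need, such as writes to states inside the function, state discovery, and caching.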