GradientTransform

class brainstate.transform.GradientTransform(target, transform, grad_states=None, argnums=None, return_value=False, has_aux=False, transform_params=None, check_states=True, debug_nan=False)

Automatic Differentiation Transformations for the State system.

This class implements gradient transformations for functions that operate on State objects. It allows for flexible configuration of gradient computation with respect to specified states and function arguments.

Parameters:
  • target (Callable) – The function to be transformed.

  • transform (Callable) – The transformation function to apply to target, for example jax.grad.

  • grad_states (State | Sequence[State] | Dict[str, State] | None) – States to compute gradients for.

  • argnums (int | Sequence[int] | None) – Indices of arguments to differentiate with respect to.

  • return_value (bool) – Whether to return the function’s value along with gradients.

  • has_aux (bool) – Whether the function returns auxiliary data.

  • transform_params (Dict[str, Any] | None) – Additional parameters for the transformation function.

  • check_states (bool) – Whether to check that all grad_states are found in the function.

  • debug_nan (bool) – Whether to enable NaN debugging. When True, a RuntimeError with detailed diagnostics is raised if a NaN is detected during gradient computation.
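To make the kind of failure that debug_nan targets concrete, here is a minimal sketch in plain JAX. It is not brainstate's implementation: the helper assert_finite_gradients and the function masked_log are hypothetical names introduced only for illustration. The sketch relies on a well-known JAX pitfall, where the gradient of a jnp.where expression is NaN even though the function value is finite.

```python
import jax
import jax.numpy as jnp


def masked_log(x):
    # The value is finite at x == 0, but the gradient is NaN there, because
    # the unselected log branch contributes 0 * inf during backpropagation.
    return jnp.where(x > 0.0, jnp.log(x), 0.0)


def assert_finite_gradients(grads):
    # Hypothetical helper (not part of brainstate): scan every leaf of the
    # gradient pytree and raise if any leaf contains a NaN.
    leaves = jax.tree_util.tree_leaves(grads)
    bad = [i for i, leaf in enumerate(leaves) if bool(jnp.any(jnp.isnan(leaf)))]
    if bad:
        raise RuntimeError(f"NaN detected in gradient leaves {bad}")
    return grads


good_grad = jax.grad(masked_log)(2.0)  # finite gradient
bad_grad = jax.grad(masked_log)(0.0)   # NaN gradient

assert_finite_gradients(good_grad)     # passes silently
try:
    assert_finite_gradients(bad_grad)
except RuntimeError as err:
    print("caught:", err)
```

A check like this runs eagerly on concrete arrays. How debug_nan behaves under jit, and what its diagnostics contain, is determined by brainstate itself and is not shown here.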

target

The function to be transformed.

Type:

callable

stateful_target

A wrapper around the target function for state management.

Type:

StatefulFunction

raw_argnums

The original argnums specified by the user.

Type:

int, sequence of int, or None

true_argnums

The adjusted argnums used internally.

Type:

int or tuple of int

return_value

Whether to return the function’s value along with gradients.

Type:

bool

has_aux

Whether the function returns auxiliary data.

Type:

bool

Examples

Basic gradient computation with states:

>>> import brainstate
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> # Create states
>>> weight = brainstate.State(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
>>> bias = brainstate.State(jnp.array([0.5, -0.5]))
>>>
>>> def loss_fn(x):
...     y = x @ weight.value + bias.value
...     return jnp.sum(y ** 2)
>>>
>>> # Create gradient transform
>>> grad_transform = brainstate.transform.GradientTransform(
...     target=loss_fn,
...     transform=jax.grad,
...     grad_states=[weight, bias]
... )
>>>
>>> # Compute gradients
>>> x = jnp.array([1.0, 2.0])
>>> grads = grad_transform(x)

With function arguments and auxiliary data:

>>> def loss_fn_with_aux(x, scale):
...     y = x @ weight.value + bias.value
...     loss = jnp.sum((y * scale) ** 2)
...     return loss, {"predictions": y, "scale": scale}
>>>
>>> grad_transform = brainstate.transform.GradientTransform(
...     target=loss_fn_with_aux,
...     transform=jax.grad,
...     grad_states=[weight, bias],
...     argnums=[0, 1],  # gradient w.r.t x and scale
...     has_aux=True,
...     return_value=True
... )
>>>
>>> grads, loss_value, aux_data = grad_transform(x, 2.0)
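For comparison, the argument-gradient part of the example above can be sketched in plain JAX, without State objects. This is only an analogy for what argnums, has_aux and return_value mean, not a description of how GradientTransform is implemented. Here the weight and bias are passed as ordinary arguments w and b (an assumption of this sketch), and only x and scale are differentiated.

```python
import jax
import jax.numpy as jnp


def loss_fn_with_aux(x, scale, w, b):
    # Same computation as the example above, with the parameters passed
    # explicitly instead of being read from State objects.
    y = x @ w + b
    loss = jnp.sum((y * scale) ** 2)
    return loss, {"predictions": y, "scale": scale}


w = jnp.array([[1.0, 2.0], [3.0, 4.0]])
b = jnp.array([0.5, -0.5])
x = jnp.array([1.0, 2.0])

# argnums=(0, 1): differentiate with respect to x and scale only.
# has_aux=True: the function returns (loss, aux), and only loss is differentiated.
# value_and_grad returns the loss value alongside the gradients.
(loss_value, aux_data), (dx, dscale) = jax.value_and_grad(
    loss_fn_with_aux, argnums=(0, 1), has_aux=True
)(x, 2.0, w, b)

print(loss_value, dx, dscale, aux_data["predictions"])
```

Note that plain jax.value_and_grad returns ((value, aux), grads), whereas the example above unpacks the GradientTransform result as grads, loss_value, aux_data.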