clip_grad_norm

class brainstate.nn.clip_grad_norm(grad, max_norm, norm_type=2.0, return_norm=False)

Clip the global norm of a PyTree of gradients.

The norm is computed over all gradients together, as if they were concatenated into a single vector. If that norm exceeds max_norm, every gradient is multiplied by the same factor so that the overall norm equals max_norm; otherwise the gradients are left unchanged.

Parameters:
  • grad (PyTree) – A PyTree structure (nested dict, list, tuple, etc.) containing JAX arrays representing the gradients to be clipped.

  • max_norm (float | Array) – Maximum allowed norm of the gradients. If the computed norm exceeds this value, gradients will be scaled down proportionally.

  • norm_type (int | float | str | None) –

    Type of the p-norm to compute. Default is 2.0 (L2 norm). Can be:

    • float: p-norm for any p >= 1

    • 'inf' or jnp.inf: infinity norm (maximum absolute value)

    • '-inf' or -jnp.inf: negative infinity norm (minimum absolute value)

    • int: integer p-norm

    • None: defaults to 2.0 (Euclidean norm)

  • return_norm (bool) – If True, returns a tuple (clipped_grad, total_norm). If False, returns only clipped_grad. Default is False.

Return type:

PyTree | tuple[PyTree, Array]

Returns:

  • clipped_grad (PyTree) – A PyTree with the same structure as grad, scaled so that its overall norm does not exceed max_norm.

  • total_norm (jax.Array, optional) – The computed norm of the gradients before clipping. Only returned if return_norm=True.

Notes

The gradient clipping is performed as:

\[g_{\text{clipped}} = g \cdot \min\left(1, \frac{\text{max\_norm}}{\|g\|_p}\right)\]

where \(\|g\|_p\) is the p-norm of the concatenated gradient vector.
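
The formula above can be sketched in plain JAX. This is an illustrative reimplementation, not the library's source: the helper names global_norm and clip_by_global_norm are hypothetical, only numeric p and ±jnp.inf are handled, and no epsilon is added to guard against a zero norm (whether brainstate does so is not stated here).

```python
import jax
import jax.numpy as jnp


def global_norm(tree, p=2.0):
    """p-norm of all leaves in `tree`, treated as one flat vector (hypothetical helper)."""
    leaves = jax.tree_util.tree_leaves(tree)
    flat = jnp.concatenate([jnp.ravel(x) for x in leaves])
    if p == jnp.inf:
        return jnp.max(jnp.abs(flat))   # infinity norm: largest magnitude
    if p == -jnp.inf:
        return jnp.min(jnp.abs(flat))   # negative infinity norm: smallest magnitude
    return jnp.sum(jnp.abs(flat) ** p) ** (1.0 / p)


def clip_by_global_norm(tree, max_norm, p=2.0):
    """Scale every leaf by min(1, max_norm / ||g||_p) (hypothetical helper)."""
    norm = global_norm(tree, p)
    scale = jnp.minimum(1.0, max_norm / norm)
    clipped = jax.tree_util.tree_map(lambda x: x * scale, tree)
    return clipped, norm


grads = {'w': jnp.array([3.0, 4.0]), 'b': jnp.array([12.0])}
clipped, norm = clip_by_global_norm(grads, max_norm=5.0)
```

Because a single scale factor is applied to every leaf, the direction of the overall gradient vector is preserved; only its length changes. Here the pre-clipping norm is 13 and the clipped tree has norm 5.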

Examples

>>> import jax.numpy as jnp
>>> import brainstate

>>> # Simple gradient clipping without returning norm
>>> grads = {'w': jnp.array([3.0, 4.0]), 'b': jnp.array([12.0])}
>>> clipped_grads = brainstate.nn.clip_grad_norm(grads, max_norm=5.0)
>>> print(f"Clipped w: {clipped_grads['w']}")
Clipped w: [1.1538461 1.5384616]

>>> # Gradient clipping with norm returned
>>> grads = {'w': jnp.array([3.0, 4.0]), 'b': jnp.array([12.0])}
>>> clipped_grads, norm = brainstate.nn.clip_grad_norm(grads, max_norm=5.0, return_norm=True)
>>> print(f"Original norm: {norm:.2f}")
Original norm: 13.00

>>> # Using different norm types
>>> grads = {'layer1': jnp.array([[-2.0, 3.0], [1.0, -4.0]])}
>>>
>>> # L2 norm (default)
>>> clipped_l2, norm_l2 = brainstate.nn.clip_grad_norm(grads, max_norm=3.0, norm_type=2, return_norm=True)
>>> print(f"L2 norm: {norm_l2:.2f}")
L2 norm: 5.48
>>>
>>> # L1 norm
>>> clipped_l1, norm_l1 = brainstate.nn.clip_grad_norm(grads, max_norm=5.0, norm_type=1, return_norm=True)
>>> print(f"L1 norm: {norm_l1:.2f}")
L1 norm: 10.00
>>>
>>> # Infinity norm
>>> clipped_inf, norm_inf = brainstate.nn.clip_grad_norm(grads, max_norm=2.0, norm_type='inf', return_norm=True)
>>> print(f"Inf norm: {norm_inf:.2f}")
Inf norm: 4.00
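
Clipping is typically applied between computing gradients and updating parameters. The following is a minimal sketch of that pattern in plain JAX, assuming a toy linear model and plain gradient descent. The functions l2_clip, loss_fn, and train_step are hypothetical; l2_clip is a local stand-in so the block runs without brainstate, and a call of the form brainstate.nn.clip_grad_norm(grads, max_norm, return_norm=True) could take its place per the signature documented above.

```python
import jax
import jax.numpy as jnp


def l2_clip(grads, max_norm):
    """Stand-in for the documented function: global L2-norm clipping."""
    leaves = jax.tree_util.tree_leaves(grads)
    norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    scale = jnp.minimum(1.0, max_norm / norm)
    return jax.tree_util.tree_map(lambda g: g * scale, grads), norm


def loss_fn(params, x, y):
    # Mean squared error of a toy linear model.
    pred = x @ params['w'] + params['b']
    return jnp.mean((pred - y) ** 2)


def train_step(params, x, y, lr, max_norm):
    grads = jax.grad(loss_fn)(params, x, y)
    # Clip before the update so one large gradient cannot cause a huge step.
    grads, norm = l2_clip(grads, max_norm)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, norm


params = {'w': jnp.zeros(2), 'b': jnp.zeros(())}
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([10.0, 20.0])
new_params, grad_norm = train_step(params, x, y, lr=0.1, max_norm=1.0)
```

Returning the pre-clipping norm alongside the updated parameters makes it easy to log how often clipping is active. In this run the raw gradient norm is far above max_norm, so the parameter update has length lr * max_norm.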