clip_grad_norm
- class brainstate.nn.clip_grad_norm(grad, max_norm, norm_type=2.0, return_norm=False)
Clip the norm of a PyTree of gradients.
The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are scaled if their norm exceeds the specified maximum.
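To make "computed over all gradients together" concrete, here is a minimal sketch in plain JAX. It is not brainstate's implementation; it only contrasts a single norm over the flattened tree with per-leaf norms, using the same values as the examples on this page.

```python
import jax
import jax.numpy as jnp

grads = {'w': jnp.array([3.0, 4.0]), 'b': jnp.array([12.0])}

# Flatten the PyTree into its leaf arrays, then join them into one vector.
leaves = jax.tree_util.tree_leaves(grads)
flat = jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])

# One L2 norm for the whole tree: sqrt(3^2 + 4^2 + 12^2).
global_norm = jnp.linalg.norm(flat)

# Per-leaf L2 norms, shown only for contrast.
per_leaf_norms = jax.tree_util.tree_map(jnp.linalg.norm, grads)

print(float(global_norm))
print({name: float(value) for name, value in per_leaf_norms.items()})
```

With max_norm=5.0, the leaf 'w' on its own already has norm 5, yet it is still scaled down because the norm of the whole tree is 13.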
- Parameters:
  grad (PyTree) – A PyTree structure (nested dict, list, tuple, etc.) containing JAX arrays representing gradients to be normalized.
  max_norm (float | Array) – Maximum allowed norm of the gradients. If the computed norm exceeds this value, gradients will be scaled down proportionally.
  norm_type (int | float | str | None) – Type of the p-norm to compute. Default is 2.0 (L2 norm). Can be:
    - float: p-norm for any p >= 1
    - 'inf' or jnp.inf: infinity norm (maximum absolute value)
    - '-inf' or -jnp.inf: negative infinity norm (minimum absolute value)
    - int: integer p-norm
    - None: defaults to 2.0 (Euclidean norm)
  return_norm (bool) – If True, returns a tuple (clipped_grad, total_norm). If False, returns only clipped_grad. Default is False.
- Return type:
- Returns:
clipped_grad (PyTree) – The gradients, in the same PyTree structure as the input, scaled so that their total norm does not exceed max_norm.
total_norm (jax.Array, optional) – The computed norm of the gradients before clipping. Only returned if return_norm=True.
Notes
The gradient clipping is performed as:
\[g_{\text{clipped}} = g \cdot \min\left(1, \frac{\text{max\_norm}}{\|g\|_p}\right)\]
where \(\|g\|_p\) is the p-norm of the concatenated gradient vector.
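The formula can be sketched in a few lines of plain JAX. The helper name `clip_by_global_norm_sketch` and its parameter `p` are hypothetical and chosen for illustration; this is an assumption about how such clipping can be written, not brainstate's actual source. In particular, it has no guard against a zero norm combined with max_norm=0, which the library may handle differently.

```python
import jax
import jax.numpy as jnp


def clip_by_global_norm_sketch(grad, max_norm, p=2.0):
    """Hypothetical helper: scale every leaf by min(1, max_norm / ||g||_p)."""
    leaves = jax.tree_util.tree_leaves(grad)
    flat = jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])
    # For 1-D input, jnp.linalg.norm accepts ord values such as 1, 2,
    # jnp.inf and -jnp.inf.
    total_norm = jnp.linalg.norm(flat, ord=p)
    # The factor is 1 when the norm is already within max_norm.
    scale = jnp.minimum(1.0, max_norm / total_norm)
    clipped = jax.tree_util.tree_map(lambda leaf: leaf * scale, grad)
    return clipped, total_norm


grads = {'w': jnp.array([3.0, 4.0]), 'b': jnp.array([12.0])}
clipped, total_norm = clip_by_global_norm_sketch(grads, max_norm=5.0)
print(clipped['w'], float(total_norm))
```

Every leaf is multiplied by the same factor (here 5/13), so the direction of the overall gradient vector is preserved and only its length changes.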
Examples
>>> import jax.numpy as jnp
>>> import brainstate
>>> # Simple gradient clipping without returning norm
>>> grads = {'w': jnp.array([3.0, 4.0]), 'b': jnp.array([12.0])}
>>> clipped_grads = brainstate.nn.clip_grad_norm(grads, max_norm=5.0)
>>> print(f"Clipped w: {clipped_grads['w']}")
Clipped w: [1.1538461 1.5384616]
>>> # Gradient clipping with norm returned
>>> grads = {'w': jnp.array([3.0, 4.0]), 'b': jnp.array([12.0])}
>>> clipped_grads, norm = brainstate.nn.clip_grad_norm(grads, max_norm=5.0, return_norm=True)
>>> print(f"Original norm: {norm:.2f}")
Original norm: 13.00
>>> # Using different norm types
>>> grads = {'layer1': jnp.array([[-2.0, 3.0], [1.0, -4.0]])}
>>>
>>> # L2 norm (default)
>>> clipped_l2, norm_l2 = brainstate.nn.clip_grad_norm(grads, max_norm=3.0, norm_type=2, return_norm=True)
>>> print(f"L2 norm: {norm_l2:.2f}")
L2 norm: 5.48
>>>
>>> # L1 norm
>>> clipped_l1, norm_l1 = brainstate.nn.clip_grad_norm(grads, max_norm=5.0, norm_type=1, return_norm=True)
>>> print(f"L1 norm: {norm_l1:.2f}")
L1 norm: 10.00
>>>
>>> # Infinity norm
>>> clipped_inf, norm_inf = brainstate.nn.clip_grad_norm(grads, max_norm=2.0, norm_type='inf', return_norm=True)
>>> print(f"Inf norm: {norm_inf:.2f}")
Inf norm: 4.00