Automatic Differentiation
BrainState provides a comprehensive automatic differentiation system built on top of JAX, designed specifically for stateful computations. This tutorial focuses on brainstate.transform.grad and related gradient transformations, demonstrating how to compute gradients with respect to function arguments and State objects.
Key Concepts
BrainState’s gradient system revolves around two key concepts:
argnums: Select which function arguments to differentiate with respect to (inherited from JAX)
grad_states: Select which State objects should receive gradients (BrainState’s extension)
Additionally, BrainState uses ParamState to mark trainable parameters in neural networks and provides utilities to discover and manage states in arbitrary functions.
import jax
import jax.numpy as jnp
import brainstate
from brainstate.transform import grad, StateFinder
1. Understanding argnums: Gradients w.r.t. Function Arguments
The argnums parameter works just like in JAX’s jax.grad. It specifies which positional arguments to differentiate with respect to.
def loss_fn(x, y, scale):
    """Simple loss function with multiple arguments."""
    return scale * jnp.sum((x - y) ** 2)
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([0.5, 1.5, 2.5])
scale = 2.0
# Gradient w.r.t. the first argument (x)
grad_fn_x = grad(loss_fn, argnums=0)
grad_x = grad_fn_x(x, y, scale)
print("Gradient w.r.t. x:", grad_x)
# Gradient w.r.t. multiple arguments
grad_fn_xy = grad(loss_fn, argnums=[0, 1])
grad_x, grad_y = grad_fn_xy(x, y, scale)
print("Gradient w.r.t. x:", grad_x)
print("Gradient w.r.t. y:", grad_y)
Gradient w.r.t. x: [2. 2. 2.]
Gradient w.r.t. x: [2. 2. 2.]
Gradient w.r.t. y: [-2. -2. -2.]
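The printed values can be cross-checked against the closed form: for scale * sum((x - y) ** 2) the gradient is 2 * scale * (x - y) with respect to x, its negative with respect to y, and sum((x - y) ** 2) with respect to scale. The sketch below uses plain jax.grad rather than brainstate.transform.grad, on the assumption that argnums behaves the same way in both; the names gx, gy, g_scale and expected_gx are illustrative only.

```python
import jax
import jax.numpy as jnp


def loss_fn(x, y, scale):
    return scale * jnp.sum((x - y) ** 2)


x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([0.5, 1.5, 2.5])
scale = 2.0

# A tuple of argnums yields a tuple of gradients in the same order.
gx, gy = jax.grad(loss_fn, argnums=(0, 1))(x, y, scale)

# A single integer selects one argument, here the scalar `scale`.
g_scale = jax.grad(loss_fn, argnums=2)(x, y, scale)

# Closed-form gradients for comparison.
expected_gx = 2.0 * scale * (x - y)
assert jnp.allclose(gx, expected_gx)
assert jnp.allclose(gy, -expected_gx)
assert jnp.allclose(g_scale, jnp.sum((x - y) ** 2))
```

Each gradient has the same shape as the argument it belongs to, so gx and gy are length-3 arrays while g_scale is a scalar.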
2. Understanding grad_states: Gradients w.r.t. State Objects
2.1 ParamState for Trainable Parameters
In BrainState, ParamState is used to mark parameters that should receive gradients during training. This is the standard way to define trainable parameters in neural network modules.
class LinearRegressor(brainstate.nn.Module):
    """Simple linear regression model."""

    def __init__(self, in_features: int, out_features: int = 1):
        super().__init__()
        # ParamState marks these as trainable parameters
        self.weight = brainstate.ParamState(jnp.zeros((in_features, out_features)))
        self.bias = brainstate.ParamState(jnp.zeros((out_features,)))

    def __call__(self, x: jax.Array) -> jax.Array:
        return x @ self.weight.value + self.bias.value
# Create model and training data
model = LinearRegressor(1)
xs = jnp.linspace(-1.0, 1.0, 5).reshape(-1, 1)
y_true = 3.0 * xs + 1.0
def mse_loss(x: jax.Array, target: jax.Array) -> jax.Array:
    """Mean squared error loss."""
    pred = model(x)
    return jnp.mean((pred - target) ** 2)
# Compute gradients w.r.t. model parameters
loss_grad = grad(
    mse_loss,
    grad_states=model.states(brainstate.ParamState),  # Get all ParamState instances
    return_value=True,
)
param_grads, loss_value = loss_grad(xs, y_true)
print(f"Loss: {float(loss_value):.4f}")
print("\nParameter gradients:")
for path, g in param_grads.items():
    print(f" {path}: {g}")
Loss: 5.5000
Parameter gradients:
('bias',): [-2.]
('weight',): [[-3.]]
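The numbers above follow directly from the MSE formula. With residual r = pred - target over N samples, the loss is mean(r ** 2), the weight gradient is (2 / N) * X^T r and the bias gradient is 2 * mean(r). The following NumPy sketch recomputes them by hand for the zero-initialised model; it does not use BrainState at all, and the variable names are illustrative.

```python
import numpy as np

# Same data as the tutorial example.
xs = np.linspace(-1.0, 1.0, 5).reshape(-1, 1)
y_true = 3.0 * xs + 1.0

# Zero-initialised parameters, matching LinearRegressor(1).
w = np.zeros((1, 1))
b = np.zeros((1,))

pred = xs @ w + b
residual = pred - y_true           # shape (5, 1)
n = residual.size

loss = np.mean(residual ** 2)
grad_w = 2.0 * xs.T @ residual / n      # shape (1, 1)
grad_b = 2.0 * residual.mean(axis=0)    # shape (1,)

print(loss, grad_w, grad_b)
```

The hand-derived gradients have the same shapes as the parameters they belong to, which is also what the path-keyed gradient dictionary above contains.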
2.2 Retrieving States from Modules
BrainState provides two main ways to retrieve states from modules:
module.states(*filter): Get states directly from a Module instance
brainstate.graph.treefy_states(node, *filter): Get states from any object (more general)
# Method 1: Using module.states()
params_method1 = model.states(brainstate.ParamState)
print("Using model.states():")
for path, state in params_method1.items():
    print(f" {path}: shape={state.value.shape}")
# Method 2: Using brainstate.graph.treefy_states()
params_method2 = brainstate.graph.treefy_states(model, brainstate.ParamState)
print("\nUsing brainstate.graph.treefy_states():")
for path, state in params_method2.to_flat().items():
    print(f" {path}: shape={state.value.shape}")
# Both methods return the same states
assert set(params_method1.keys()) == set(params_method2.to_flat().keys())
Using model.states():
('bias',): shape=(1,)
('weight',): shape=(1, 1)
Using brainstate.graph.treefy_states():
('bias',): shape=(1,)
('weight',): shape=(1, 1)
2.3 Using StateFinder for Arbitrary Functions
Not every function is an nn.Module. For arbitrary functions, you can use StateFinder to discover which states are read or written inside the function.
# Create some standalone states
scale = brainstate.ParamState(jnp.array(1.5), name="scale")
offset = brainstate.ParamState(jnp.array(-0.2), name="offset")
cache = brainstate.State(jnp.array(0.0), name="cache") # Not a ParamState
def energy(x: jax.Array) -> jax.Array:
    """Energy function using external states."""
    shifted = x * scale.value + offset.value
    # Update a state to track it as a write operation
    scale.value = scale.value + 0.0  # Dummy update to mark as written
    cache.value = jnp.sum(shifted)  # Write to cache
    return jnp.sum(jnp.square(shifted))
# Use StateFinder to discover states used in the function
finder = StateFinder(
    energy,
    filter=brainstate.ParamState,  # Only find ParamState instances
    usage='all',  # Find both read and write states
    return_type='dict',  # Return as a dictionary
)
all_param_states = finder(jnp.ones((2,)))
print("States found by StateFinder:")
for name, state in all_param_states.items():
    print(f" {name}: {state.name}")
# Now compute gradients w.r.t. these discovered states
energy_grad = grad(
    energy,
    grad_states=all_param_states,
    return_value=True,
)
state_grads, energy_value = energy_grad(jnp.array([1.0, 3.0]))
print(f"\nEnergy: {float(energy_value):.4f}")
print("Gradients:")
# The gradients are keyed the same way as all_param_states
for key in state_grads:
    print(f" {key}: {state_grads[key]}")
States found by StateFinder:
0: scale
1: offset
Energy: 20.1800
Gradients:
0: 28.400001525878906
1: 11.200000762939453
2.4 Important Note: Gradients are Not Limited to ParamState
While ParamState is the standard way to mark trainable parameters, gradient computation is not restricted to it: any State instance can be passed in grad_states.
# Create a regular State (not ParamState)
regular_state = brainstate.State(jnp.array(2.0), name="regular_state")
def compute_with_state(x):
    return jnp.sum((x * regular_state.value) ** 2)
# Compute gradient w.r.t. regular State
grad_fn = grad(compute_with_state, grad_states=[regular_state])
gradient = grad_fn(jnp.array([1.0, 2.0, 3.0]))
print(f"Gradient w.r.t. regular State: {gradient[0]}")
Gradient w.r.t. regular State: 56.0
3. Combining argnums and grad_states
You can compute gradients with respect to both function arguments and states simultaneously.
reg_model = LinearRegressor(1)
def penalized_loss(l2_coeff: float, inputs: jax.Array, target: jax.Array) -> jax.Array:
    """Loss with L2 regularization."""
    pred = reg_model(inputs)
    mse = jnp.mean((pred - target) ** 2)
    # L2 penalty on parameters
    l2 = jnp.sum(reg_model.weight.value ** 2) + jnp.sum(reg_model.bias.value ** 2)
    return mse + l2_coeff * l2
# Compute gradients w.r.t. both states and the first argument
grad_penalized = grad(
    penalized_loss,
    grad_states=reg_model.states(brainstate.ParamState),
    argnums=0,  # Also differentiate w.r.t. l2_coeff
    return_value=True,
)
(state_grads, coeff_grad), loss_val = grad_penalized(0.5, xs, y_true)
print(f"Loss: {float(loss_val):.4f}")
print(f"Gradient w.r.t. l2_coeff: {float(coeff_grad):.4f}")
print("\nState gradients:")
for path, g in state_grads.items():
    print(f" {path}: {g}")
Loss: 5.5000
Gradient w.r.t. l2_coeff: 0.0000
State gradients:
('bias',): [-2.]
('weight',): [[-3.]]
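The zero gradient for l2_coeff is expected here: the derivative of mse + l2_coeff * l2 with respect to l2_coeff is simply l2, and l2 is zero while the parameters are still zero-initialised. The sketch below shows the same relationship with non-zero parameters. It is a plain-JAX analogue in which the parameters are passed explicitly as a dictionary instead of being held in State objects; the params dictionary and its values are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def penalized_loss(l2_coeff, params, inputs, target):
    pred = inputs @ params["weight"] + params["bias"]
    mse = jnp.mean((pred - target) ** 2)
    l2 = jnp.sum(params["weight"] ** 2) + jnp.sum(params["bias"] ** 2)
    return mse + l2_coeff * l2


xs = jnp.linspace(-1.0, 1.0, 5).reshape(-1, 1)
y_true = 3.0 * xs + 1.0

# Non-zero parameters (illustrative values) so the L2 term is non-zero.
params = {"weight": jnp.array([[2.0]]), "bias": jnp.array([0.5])}

# Differentiate w.r.t. the coefficient (arg 0) and the parameter pytree (arg 1).
coeff_grad, param_grads = jax.grad(penalized_loss, argnums=(0, 1))(
    0.5, params, xs, y_true
)

# d(loss)/d(l2_coeff) equals the L2 penalty itself.
l2_value = jnp.sum(params["weight"] ** 2) + jnp.sum(params["bias"] ** 2)
assert jnp.allclose(coeff_grad, l2_value)
```

The parameter gradients come back as a dictionary with the same keys as params, which mirrors how the state gradients above are keyed by state path.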
4. Return Value Structures
All gradient transformations in BrainState share a common signature pattern. The return structure depends on the combination of grad_states, argnums, has_aux, and return_value.
4.1 Basic Return Structures
When grad_states is None:
has_aux=False + return_value=False → arg_grads
has_aux=True + return_value=False → (arg_grads, aux_data)
has_aux=False + return_value=True → (arg_grads, loss_value)
has_aux=True + return_value=True → (arg_grads, loss_value, aux_data)
When grad_states is not None and argnums is None:
has_aux=False + return_value=False → var_grads
has_aux=True + return_value=False → (var_grads, aux_data)
has_aux=False + return_value=True → (var_grads, loss_value)
has_aux=True + return_value=True → (var_grads, loss_value, aux_data)
When both grad_states and argnums are not None:
has_aux=False + return_value=False → (var_grads, arg_grads)
has_aux=True + return_value=False → ((var_grads, arg_grads), aux_data)
has_aux=False + return_value=True → ((var_grads, arg_grads), loss_value)
has_aux=True + return_value=True → ((var_grads, arg_grads), loss_value, aux_data)
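The same rules, listed as a table for clarity:
| grad_states | argnums | has_aux | return_value | result |
|---|---|---|---|---|
| None | any | False | False | arg_grads |
| None | any | True | False | (arg_grads, aux_data) |
| None | any | False | True | (arg_grads, loss_value) |
| None | any | True | True | (arg_grads, loss_value, aux_data) |
| not None | None | False | False | var_grads |
| not None | None | True | False | (var_grads, aux_data) |
| not None | None | False | True | (var_grads, loss_value) |
| not None | None | True | True | (var_grads, loss_value, aux_data) |
| not None | not None | False | False | (var_grads, arg_grads) |
| not None | not None | True | False | ((var_grads, arg_grads), aux_data) |
| not None | not None | False | True | ((var_grads, arg_grads), loss_value) |
| not None | not None | True | True | ((var_grads, arg_grads), loss_value, aux_data) |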
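Readers coming from plain JAX may find the ordering worth a second look, because jax.value_and_grad puts the value first and nests the auxiliary data with it, whereas the structures listed above put the gradients first. The sketch below contrasts the two layouts using only JAX. The helper grads_first is hypothetical: it only illustrates the (grads, loss_value, aux_data) ordering and is not how BrainState implements it.

```python
import jax
import jax.numpy as jnp


def loss_with_aux(w, x):
    pred = w * x
    loss = jnp.sum(pred ** 2)
    return loss, {"pred_mean": jnp.mean(pred)}


w = 2.0
x = jnp.array([1.0, 2.0, 3.0])

# Plain JAX, grad with has_aux=True: (grads, aux)
g, aux = jax.grad(loss_with_aux, has_aux=True)(w, x)

# Plain JAX, value_and_grad with has_aux=True: ((value, aux), grads)
(value, aux2), g2 = jax.value_and_grad(loss_with_aux, has_aux=True)(w, x)


def grads_first(fun):
    """Hypothetical helper: reorder the result to (grads, value, aux)."""
    value_and_grad_fn = jax.value_and_grad(fun, has_aux=True)

    def wrapped(*args):
        (val, aux_out), grads = value_and_grad_fn(*args)
        return grads, val, aux_out

    return wrapped


grads, loss_value, aux_data = grads_first(loss_with_aux)(w, x)
```

Unpacking with the wrong layout usually fails loudly (a tuple where an array was expected), so it helps to write the unpacking pattern out explicitly at the call site, as the examples in this tutorial do.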
4.2 Complete Example: All Return Options
example_model = LinearRegressor(1)
def loss_with_metrics(l2_coeff: float, x: jax.Array, target: jax.Array):
    """Loss function that returns auxiliary metrics."""
    pred = example_model(x)
    mse = jnp.mean((pred - target) ** 2)
    l2 = jnp.sum(example_model.weight.value ** 2)
    loss = mse + l2_coeff * l2
    # Return loss and auxiliary metrics
    metrics = {
        "mae": jnp.mean(jnp.abs(pred - target)),
        "mse": mse,
        "l2": l2,
    }
    return loss, metrics
# Example: grad_states + argnums + has_aux + return_value
grad_complete = grad(
    loss_with_metrics,
    grad_states=example_model.states(brainstate.ParamState),
    argnums=0,
    has_aux=True,
    return_value=True,
)
((state_grads, coeff_grad), loss_val, aux_metrics) = grad_complete(0.3, xs, y_true)
print(f"Loss: {float(loss_val):.4f}")
print(f"\nGradient w.r.t. l2_coeff: {float(coeff_grad):.4f}")
print("\nState gradients:")
for path, g in state_grads.items():
    print(f" {path}: {g}")
print("\nAuxiliary metrics:")
for key, val in aux_metrics.items():
    print(f" {key}: {float(val):.4f}")
Loss: 5.5000
Gradient w.r.t. l2_coeff: 0.0000
State gradients:
('bias',): [-2.]
('weight',): [[-3.]]
Auxiliary metrics:
l2: 0.0000
mae: 2.0000
mse: 5.5000
5. Other Gradient Transformations
BrainState provides several other gradient transformations, all sharing the same signature pattern as grad.
5.1 Vector Gradient
vector_grad is used for vector-valued functions. It returns the per-output gradients summed over all output dimensions, which is the same as the gradient of the sum of the outputs.
from brainstate.transform import vector_grad
def vector_fun(x):
    """Vector-valued function."""
    return jnp.array([x[0] * x[1], jnp.sin(x[0]), x[0]**2 + x[1]**2])
x0 = jnp.array([1.0, 2.0])
# Vector gradient sums gradients across all outputs
vgrad = vector_grad(vector_fun)
result = vgrad(x0)
print("Vector gradient:", result)
Vector gradient: [4.5403023 5. ]
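One way to read the result above is as a vector-Jacobian product with an all-ones cotangent, which sums the rows of the Jacobian. The sketch below reproduces the same numbers with plain jax.vjp and jax.jacrev. Whether vector_grad is implemented exactly this way internally is an assumption; the point is only that the two computations agree.

```python
import jax
import jax.numpy as jnp


def vector_fun(x):
    return jnp.array([x[0] * x[1], jnp.sin(x[0]), x[0] ** 2 + x[1] ** 2])


x0 = jnp.array([1.0, 2.0])

# VJP with a cotangent of ones: sums the gradient of every output.
out, vjp_fn = jax.vjp(vector_fun, x0)
(summed,) = vjp_fn(jnp.ones_like(out))

# Equivalent view: sum the rows of the full Jacobian.
jac = jax.jacrev(vector_fun)(x0)       # shape (3, 2)
assert jnp.allclose(summed, jac.sum(axis=0))

# Equivalent view: gradient of the summed outputs.
grad_of_sum = jax.grad(lambda x: jnp.sum(vector_fun(x)))(x0)
assert jnp.allclose(summed, grad_of_sum)
```

The VJP form needs a single backward pass, whereas building the full Jacobian first costs one pass per output, so the former is preferable when only the summed gradient is needed.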
5.2 Jacobians: jacrev and jacfwd
jacrev: Jacobian using reverse-mode autodiff (efficient for many inputs, few outputs)
jacfwd: Jacobian using forward-mode autodiff (efficient for few inputs, many outputs)
jacobian: Alias for jacrev
from brainstate.transform import jacrev, jacfwd, jacobian
def multi_output(x):
    """Function with multiple outputs."""
    return jnp.array([x[0] * x[1], jnp.sin(x[0]), jnp.exp(x[1])])
x0 = jnp.array([1.0, 2.0])
# Reverse-mode Jacobian
jac_rev = jacrev(multi_output)
result_rev = jac_rev(x0)
print("Jacobian (reverse-mode):")
print(result_rev)
# Forward-mode Jacobian
jac_fwd = jacfwd(multi_output)
result_fwd = jac_fwd(x0)
print("\nJacobian (forward-mode):")
print(result_fwd)
# They should be the same
assert jnp.allclose(result_rev, result_fwd)
# jacobian is an alias for jacrev
jac_alias = jacobian(multi_output)
result_alias = jac_alias(x0)
assert jnp.allclose(result_rev, result_alias)
Jacobian (reverse-mode):
[[2. 1. ]
[0.5403023 0. ]
[0. 7.389056 ]]
Jacobian (forward-mode):
[[2. 1. ]
[0.5403023 0. ]
[0. 7.389056 ]]
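The efficiency remark about inputs and outputs can be made concrete: forward mode recovers the Jacobian one column at a time (one JVP per input dimension), while reverse mode recovers it one row at a time (one VJP per output dimension). The sketch below builds the same 3 x 2 Jacobian both ways with plain jax.jvp and jax.vjp. It illustrates the principle and is not a description of how jacfwd or jacrev are implemented.

```python
import jax
import jax.numpy as jnp


def multi_output(x):
    return jnp.array([x[0] * x[1], jnp.sin(x[0]), jnp.exp(x[1])])


x0 = jnp.array([1.0, 2.0])

# Forward mode: one JVP per input basis vector gives one column each.
cols = [jax.jvp(multi_output, (x0,), (e,))[1] for e in jnp.eye(2)]
jac_from_jvp = jnp.stack(cols, axis=1)   # shape (3, 2), built from 2 passes

# Reverse mode: one VJP per output basis vector gives one row each.
_, vjp_fn = jax.vjp(multi_output, x0)
rows = [vjp_fn(e)[0] for e in jnp.eye(3)]
jac_from_vjp = jnp.stack(rows, axis=0)   # shape (3, 2), built from 3 passes

assert jnp.allclose(jac_from_jvp, jac_from_vjp)
assert jnp.allclose(jac_from_jvp, jax.jacfwd(multi_output)(x0))
```

For this function forward mode needs two passes and reverse mode three; with thousands of parameters and a scalar loss the balance tips strongly towards reverse mode.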
5.3 Hessian
hessian computes the matrix of second-order partial derivatives of a scalar-valued function.
from brainstate.transform import hessian
def quadratic(x):
    """Quadratic function."""
    return jnp.dot(x, x)
x0 = jnp.array([1.0, 2.0])
hess_fn = hessian(quadratic)
result = hess_fn(x0)
print("Hessian:")
print(result)
# For a quadratic form x^T x, the Hessian is 2*I
expected = 2 * jnp.eye(2)
assert jnp.allclose(result, expected)
Hessian:
[[2. 0.]
[0. 2.]]
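A Hessian is the Jacobian of the gradient, so it can also be obtained by composing first-order transforms. The sketch below does this in plain JAX for a function whose Hessian is not constant, f(x) = x0^2 * x1 + sin(x1), and compares against the closed form [[2*x1, 2*x0], [2*x0, -sin(x1)]]. The function f is an illustrative choice and does not appear elsewhere in this tutorial.

```python
import math

import jax
import jax.numpy as jnp


def f(x):
    return x[0] ** 2 * x[1] + jnp.sin(x[1])


x0 = jnp.array([1.0, 2.0])

# Hessian as the forward-mode Jacobian of the reverse-mode gradient.
hess_composed = jax.jacfwd(jax.grad(f))(x0)

# JAX's built-in helper gives the same matrix.
hess_builtin = jax.hessian(f)(x0)

# Closed form evaluated at x0 = (1, 2).
expected = jnp.array([[4.0, 2.0], [2.0, -math.sin(2.0)]])

assert jnp.allclose(hess_composed, expected, atol=1e-5)
assert jnp.allclose(hess_builtin, expected, atol=1e-5)
```

The result is symmetric, as expected for a twice continuously differentiable function.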
5.4 Using Gradient Transformations with States
# Example: Jacobian with states
jac_model = LinearRegressor(2)
def model_output(x):
    """Multiple outputs from a model."""
    return jac_model(x)
# Compute Jacobian w.r.t. model parameters
jac_states = jacrev(
    model_output,
    grad_states=jac_model.states(brainstate.ParamState)
)
x_input = jnp.array([1.0, 2.0])
param_jacobian = jac_states(x_input)
print("Jacobian w.r.t. parameters:")
for path, jac in param_jacobian.items():
    print(f" {path}: shape={jac.shape}")
Jacobian w.r.t. parameters:
('bias',): shape=(1, 1)
('weight',): shape=(1, 2, 1)
6. Custom Gradient Transformations with GradientTransform
You can create custom gradient transformations by using the GradientTransform class. This allows you to wrap any JAX gradient function while maintaining BrainState’s state-aware behavior.
6.1 Basic Custom Transform
from brainstate.transform import GradientTransform
def scaled_grad_transform(fun, *, argnums, has_aux, scale):
    """Custom gradient transform that scales gradients."""
    # Use JAX's grad as the base transformation
    base = jax.grad(fun, argnums=argnums, has_aux=True)

    def wrapped(*args, **kwargs):
        grads, aux = base(*args, **kwargs)
        # Scale all gradients
        grads = jax.tree.map(lambda g: scale * g, grads)
        return grads, aux

    return wrapped
def scaled_grad(
    fun,
    *,
    scale=1.0,
    grad_states=None,
    argnums=None,
    has_aux=False,
    return_value=False,
):
    """Create a gradient function with scaled gradients."""
    return GradientTransform(
        fun,
        transform=scaled_grad_transform,
        grad_states=grad_states,
        argnums=argnums,
        has_aux=has_aux,
        return_value=return_value,
        transform_params={"scale": scale},  # Pass custom parameters
    )
# Example usage
custom_model = LinearRegressor(1)
def custom_loss(x, target):
    pred = custom_model(x)
    return jnp.mean((pred - target) ** 2)
# Use custom scaled gradient
scaled_grad_fn = scaled_grad(
    custom_loss,
    scale=0.5,  # Scale gradients by 0.5
    grad_states=custom_model.states(brainstate.ParamState),
)
scaled_grads = scaled_grad_fn(xs, y_true)
print("Scaled gradients:")
for path, g in scaled_grads.items():
    print(f" {path}: {g}")
# Compare with unscaled gradients
normal_grad_fn = grad(custom_loss, grad_states=custom_model.states(brainstate.ParamState))
normal_grads = normal_grad_fn(xs, y_true)
print("\nNormal gradients:")
for path, g in normal_grads.items():
    print(f" {path}: {g}")
Scaled gradients:
('bias',): [-1.]
('weight',): [[-1.5]]
Normal gradients:
('bias',): [-2.]
('weight',): [[-3.]]
6.2 Advanced: Gradient Clipping Transform
def clipped_grad_transform(fun, *, argnums, has_aux, max_norm):
    """Custom gradient transform with gradient clipping."""
    base = jax.grad(fun, argnums=argnums, has_aux=True)

    def wrapped(*args, **kwargs):
        grads, aux = base(*args, **kwargs)
        # Compute global norm
        global_norm = jnp.sqrt(
            sum(jnp.sum(jnp.square(g)) for g in jax.tree.leaves(grads))
        )
        # Clip gradients
        scale = jnp.minimum(1.0, max_norm / (global_norm + 1e-6))
        grads = jax.tree.map(lambda g: scale * g, grads)
        return grads, aux

    return wrapped
def clipped_grad(
    fun,
    *,
    max_norm=1.0,
    grad_states=None,
    argnums=None,
    has_aux=False,
    return_value=False,
):
    """Create a gradient function with gradient clipping."""
    return GradientTransform(
        fun,
        transform=clipped_grad_transform,
        grad_states=grad_states,
        argnums=argnums,
        has_aux=has_aux,
        return_value=return_value,
        transform_params={"max_norm": max_norm},
    )
# Example: gradient clipping
clip_model = LinearRegressor(1)
def clip_loss(x, target):
    pred = clip_model(x)
    return jnp.mean((pred - target) ** 2)
clipped_grad_fn = clipped_grad(
    clip_loss,
    max_norm=0.1,  # Clip gradients to max norm of 0.1
    grad_states=clip_model.states(brainstate.ParamState),
)
clipped_grads = clipped_grad_fn(xs, y_true)
print("Clipped gradients:")
for path, g in clipped_grads.items():
    print(f" {path}: {g}")
    print(f" norm: {jnp.linalg.norm(g):.4f}")
Clipped gradients:
('bias',): [-0.05547]
norm: 0.0555
('weight',): [[-0.08320501]]
norm: 0.0832
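Note that the norms printed above are per-parameter norms, so each one is below max_norm=0.1; the clipping itself acts on the global norm taken over all parameters together. The NumPy sketch below redoes the arithmetic starting from the unclipped gradients (-2 for the bias, -3 for the weight) to show that the combined norm lands at the 0.1 limit. It repeats the formula from clipped_grad_transform by hand and does not call BrainState.

```python
import numpy as np

# Unclipped gradients of the zero-initialised model, as computed earlier.
grads = {"bias": np.array([-2.0]), "weight": np.array([[-3.0]])}
max_norm = 0.1

# Global norm over every leaf: sqrt(2^2 + 3^2) = sqrt(13).
global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads.values()))

# Same rule as in the transform: shrink only if the norm exceeds max_norm.
scale = min(1.0, max_norm / (global_norm + 1e-6))
clipped = {name: scale * g for name, g in grads.items()}

# The per-leaf norms differ, but together they recombine to (about) max_norm.
clipped_norm = np.sqrt(sum(np.sum(g ** 2) for g in clipped.values()))
print(clipped, clipped_norm)
```

Because every leaf is multiplied by the same factor, the direction of the overall gradient is preserved; only its length changes.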
7. Practical Example: Training Loop
Let’s put everything together in a complete training example.
# Create a fresh model
training_model = LinearRegressor(1)
# Generate training data
true_weight = 3.0
true_bias = 1.0
x_train = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
y_train = true_weight * x_train + true_bias + 0.1 * brainstate.random.normal(size=x_train.shape)
@brainstate.transform.jit
def training_loss(x, y):
    """MSE loss with L2 regularization."""
    pred = training_model(x)
    mse = jnp.mean((pred - y) ** 2)
    l2 = 0.01 * (jnp.sum(training_model.weight.value ** 2) + jnp.sum(training_model.bias.value ** 2))
    return mse + l2, {"mse": mse, "l2": l2}
# Create gradient function
loss_grad_fn = grad(
    training_loss,
    grad_states=training_model.states(brainstate.ParamState),
    has_aux=True,
    return_value=True,
)
# Training loop
learning_rate = 0.1
num_epochs = 50
print("Training started...")
print(f"Initial weight: {training_model.weight.value}")
print(f"Initial bias: {training_model.bias.value}")
for epoch in range(num_epochs):
    # Compute gradients
    grads, loss_val, aux = loss_grad_fn(x_train, y_train)
    # Update parameters (simple SGD)
    for path, state in training_model.states(brainstate.ParamState).items():
        # Index grads directly: a local variable named grad would shadow the imported grad transform
        state.value = state.value - learning_rate * grads[path]
    # Print progress
    if (epoch + 1) % 10 == 0:
        print(f"\nEpoch {epoch + 1}:")
        print(f" Loss: {float(loss_val):.4f}")
        print(f" MSE: {float(aux['mse']):.4f}")
        print(f" L2: {float(aux['l2']):.6f}")
        print(f" Weight: {training_model.weight.value}")
        print(f" Bias: {training_model.bias.value}")
print("\nTraining completed!")
print(f"Final weight: {training_model.weight.value} (true: {true_weight})")
print(f"Final bias: {training_model.bias.value} (true: {true_bias})")
Training started...
Initial weight: [[0.]]
Initial bias: [0.]
Epoch 10:
Loss: 0.9506
MSE: 0.9198
L2: 0.030758
Weight: [[1.6285777]]
Bias: [0.9066533]
Epoch 20:
Loss: 0.2827
MSE: 0.2190
L2: 0.063762
Weight: [[2.3699088]]
Bias: [1.0015979]
Epoch 30:
Loss: 0.1478
MSE: 0.0655
L2: 0.082280
Weight: [[2.7073638]]
Bias: [1.0115404]
Epoch 40:
Loss: 0.1199
MSE: 0.0284
L2: 0.091504
Weight: [[2.8609738]]
Bias: [1.0125817]
Epoch 50:
Loss: 0.1141
MSE: 0.0182
L2: 0.095877
Weight: [[2.9308972]]
Bias: [1.0126907]
Training completed!
Final weight: [[2.9308972]] (true: 3.0)
Final bias: [1.0126907] (true: 1.0)
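For comparison, here is a rough plain-JAX counterpart of the same loop, with the parameters carried explicitly as a dictionary instead of living in ParamState objects. It is a sketch under stated assumptions: the data is noise-free so the run is reproducible, and the names loss_fn, sgd_step and params are illustrative rather than part of BrainState.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    """MSE loss with the same 0.01 L2 penalty as the tutorial example."""
    pred = x @ params["weight"] + params["bias"]
    mse = jnp.mean((pred - y) ** 2)
    l2 = 0.01 * (jnp.sum(params["weight"] ** 2) + jnp.sum(params["bias"] ** 2))
    return mse + l2, {"mse": mse, "l2": l2}


@jax.jit
def sgd_step(params, x, y, lr):
    # value_and_grad with has_aux=True returns ((loss, aux), grads).
    (loss, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, x, y)
    # Functional update: build a new pytree instead of mutating state.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss, aux


x_train = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
y_train = 3.0 * x_train + 1.0  # noise-free targets (assumption, for reproducibility)

params = {"weight": jnp.zeros((1, 1)), "bias": jnp.zeros((1,))}
first_loss = None
for epoch in range(50):
    params, loss, aux = sgd_step(params, x_train, y_train, 0.1)
    if first_loss is None:
        first_loss = float(loss)

print(params, float(loss))
```

The main design difference is where the parameters live: the functional version threads them through every call and returns new values, while the BrainState version updates state.value in place and lets grad_states tell the transform which states to differentiate. In both versions the L2 penalty pulls the fitted weight slightly below the true value of 3.0.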
Summary
In this tutorial, we covered:
argnums: Specify which function arguments to differentiate (inherited from JAX)
grad_states: Specify which State objects should receive gradients (BrainState extension)
ParamState: Standard way to mark trainable parameters in modules
Retrieving states: Use module.states() or brainstate.graph.treefy_states()
StateFinder: Discover states used in arbitrary functions
Return structures: How has_aux and return_value affect the output
Other transforms: vector_grad, jacrev, jacfwd, jacobian, hessian
Custom transforms: Build your own using GradientTransform
Key Takeaways
All gradient transformations share the same signature and return structure patterns
ParamState is the standard for trainable parameters, but gradients work with any State
StateFinder helps discover states in arbitrary functions
GradientTransform enables custom gradient transformations while maintaining state-awareness
The system seamlessly integrates JAX’s autodiff with BrainState’s stateful computation model