Automatic Differentiation

BrainState provides a comprehensive automatic differentiation system built on top of JAX, designed specifically for stateful computations. This tutorial focuses on brainstate.transform.grad and related gradient transformations, demonstrating how to compute gradients with respect to function arguments and State objects.

Key Concepts

BrainState’s gradient system revolves around two key concepts:

  1. argnums: Select which function arguments to differentiate with respect to (inherited from JAX)

  2. grad_states: Select which State objects should receive gradients (BrainState’s extension)

Additionally, BrainState uses ParamState to mark trainable parameters in neural networks and provides utilities to discover and manage states in arbitrary functions.

import jax
import jax.numpy as jnp

import brainstate
from brainstate.transform import grad, StateFinder

1. Understanding argnums: Gradients w.r.t. Function Arguments

The argnums parameter works just like in JAX’s jax.grad. It specifies which positional arguments to differentiate with respect to.

def loss_fn(x, y, scale):
    """Simple loss function with multiple arguments."""
    return scale * jnp.sum((x - y) ** 2)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([0.5, 1.5, 2.5])
scale = 2.0

# Gradient w.r.t. the first argument (x)
grad_fn_x = grad(loss_fn, argnums=0)
grad_x = grad_fn_x(x, y, scale)
print("Gradient w.r.t. x:", grad_x)

# Gradient w.r.t. multiple arguments
grad_fn_xy = grad(loss_fn, argnums=[0, 1])
grad_x, grad_y = grad_fn_xy(x, y, scale)
print("Gradient w.r.t. x:", grad_x)
print("Gradient w.r.t. y:", grad_y)
Gradient w.r.t. x: [2. 2. 2.]
Gradient w.r.t. x: [2. 2. 2.]
Gradient w.r.t. y: [-2. -2. -2.]
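Because argnums follows jax.grad, these numbers can be checked against plain JAX and against the analytic derivatives ∂L/∂x = 2·scale·(x − y) and ∂L/∂y = −2·scale·(x − y). The sketch below uses only JAX. It shows that a tuple argnums yields one gradient per selected argument.

```python
import jax
import jax.numpy as jnp


def loss_fn(x, y, scale):
    """Same loss as above: scale * sum((x - y)**2)."""
    return scale * jnp.sum((x - y) ** 2)


x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([0.5, 1.5, 2.5])
scale = 2.0

# Plain JAX: a tuple argnums returns a tuple of gradients, one per selected argument.
jax_gx, jax_gy = jax.grad(loss_fn, argnums=(0, 1))(x, y, scale)

# Analytic derivatives of scale * sum((x - y)**2).
analytic_gx = 2 * scale * (x - y)
analytic_gy = -2 * scale * (x - y)

print(jax_gx, jax_gy)
```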

2. Understanding grad_states: Gradients w.r.t. State Objects

2.1 ParamState for Trainable Parameters

In BrainState, ParamState is used to mark parameters that should receive gradients during training. This is the standard way to define trainable parameters in neural network modules.

class LinearRegressor(brainstate.nn.Module):
    """Simple linear regression model."""
    
    def __init__(self, in_features: int, out_features: int = 1):
        super().__init__()
        # ParamState marks these as trainable parameters
        self.weight = brainstate.ParamState(jnp.zeros((in_features, out_features)))
        self.bias = brainstate.ParamState(jnp.zeros((out_features,)))

    def __call__(self, x: jax.Array) -> jax.Array:
        return x @ self.weight.value + self.bias.value


# Create model and training data
model = LinearRegressor(1)
xs = jnp.linspace(-1.0, 1.0, 5).reshape(-1, 1)
y_true = 3.0 * xs + 1.0


def mse_loss(x: jax.Array, target: jax.Array) -> jax.Array:
    """Mean squared error loss."""
    pred = model(x)
    return jnp.mean((pred - target) ** 2)


# Compute gradients w.r.t. model parameters
loss_grad = grad(
    mse_loss,
    grad_states=model.states(brainstate.ParamState),  # Get all ParamState instances
    return_value=True,
)

param_grads, loss_value = loss_grad(xs, y_true)
print(f"Loss: {float(loss_value):.4f}")
print("\nParameter gradients:")
for path, g in param_grads.items():
    print(f"  {path}: {g}")
Loss: 5.5000

Parameter gradients:
  ('bias',): [-2.]
  ('weight',): [[-3.]]
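Conceptually, grad_states lets you differentiate with respect to values held in State objects without threading them through the function's arguments. The plain-JAX sketch below passes the same parameters explicitly as a dict and reproduces the loss and gradients above. It illustrates the idea only and does not describe how BrainState works internally. The name mse_loss_explicit is made up for this example.

```python
import jax
import jax.numpy as jnp

# Conceptual sketch (not BrainState's implementation): the parameters live in a
# plain dict and are passed explicitly, so jax.grad can differentiate w.r.t. them.
params = {
    "weight": jnp.zeros((1, 1)),
    "bias": jnp.zeros((1,)),
}

xs = jnp.linspace(-1.0, 1.0, 5).reshape(-1, 1)
y_true = 3.0 * xs + 1.0


def mse_loss_explicit(params, x, target):
    """Same MSE loss as above, with the parameters as an explicit argument."""
    pred = x @ params["weight"] + params["bias"]
    return jnp.mean((pred - target) ** 2)


# value_and_grad returns (loss, grads); grads mirrors the structure of `params`.
loss_value, param_grads = jax.value_and_grad(mse_loss_explicit)(params, xs, y_true)
print(loss_value, param_grads)
```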

2.2 Retrieving States from Modules

BrainState provides two main ways to retrieve states from modules:

  1. module.states(*filter): Get states directly from a Module instance

  2. brainstate.graph.treefy_states(node, *filter): Get states from any object (more general)

# Method 1: Using module.states()
params_method1 = model.states(brainstate.ParamState)
print("Using model.states():")
for path, state in params_method1.items():
    print(f"  {path}: shape={state.value.shape}")

# Method 2: Using brainstate.graph.treefy_states()
params_method2 = brainstate.graph.treefy_states(model, brainstate.ParamState)
print("\nUsing brainstate.graph.treefy_states():")
for path, state in params_method2.to_flat().items():
    print(f"  {path}: shape={state.value.shape}")

# Both methods return the same states
assert set(params_method1.keys()) == set(params_method2.to_flat().keys())
Using model.states():
  ('bias',): shape=(1,)
  ('weight',): shape=(1, 1)

Using brainstate.graph.treefy_states():
  ('bias',): shape=(1,)
  ('weight',): shape=(1, 1)

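2.3 Using StateFinder for Arbitrary Functions

Not every function belongs to an nn.Module. For arbitrary functions, StateFinder reports which states a function uses when you call the finder with example inputs. The filter and usage options below restrict the result to ParamState instances that the function reads or writes.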

# Create some standalone states
scale = brainstate.ParamState(jnp.array(1.5), name="scale")
offset = brainstate.ParamState(jnp.array(-0.2), name="offset")
cache = brainstate.State(jnp.array(0.0), name="cache")  # Not a ParamState


def energy(x: jax.Array) -> jax.Array:
    """Energy function using external states."""
    shifted = x * scale.value + offset.value
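    # Dummy write so that `scale` is both read and written; usage='all' below
    # collects states that are read as well as states that are written.
    scale.value = scale.value + 0.0
    cache.value = jnp.sum(shifted)  # Written, but a plain State, so the ParamState filter excludes it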
    return jnp.sum(jnp.square(shifted))


# Use StateFinder to discover states used in the function
finder = StateFinder(
    energy,
    filter=brainstate.ParamState,  # Only find ParamState instances
    usage='all',  # Find both read and write states
    return_type='dict',  # Return as a dictionary
)

all_param_states = finder(jnp.ones((2,)))
print("States found by StateFinder:")
for name, state in all_param_states.items():
    print(f"  {name}: {state.name}")

# Now compute gradients w.r.t. these discovered states
energy_grad = grad(
    energy,
    grad_states=all_param_states,
    return_value=True,
)

state_grads, energy_value = energy_grad(jnp.array([1.0, 3.0]))
print(f"\nEnergy: {float(energy_value):.4f}")
print("Gradients:")
for key, g in state_grads.items():
    print(f"  {key}: {g}")
States found by StateFinder:
  0: scale
  1: offset

Energy: 20.1800
Gradients:
  0: 28.400001525878906
  1: 11.200000762939453
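The printed values can be checked by hand. With x = (1, 3), scale s = 1.5 and offset o = -0.2, the energy and its partial derivatives are as follows. The trailing digits in 28.400001525878906 come from float32 rounding.

```latex
\begin{aligned}
E(s, o) &= \sum_i (s\,x_i + o)^2, \qquad s\,x + o = (1.3,\ 4.3), \qquad E = 1.3^2 + 4.3^2 = 20.18 \\
\frac{\partial E}{\partial s} &= 2 \sum_i (s\,x_i + o)\,x_i = 2\,(1.3 \cdot 1 + 4.3 \cdot 3) = 28.4 \\
\frac{\partial E}{\partial o} &= 2 \sum_i (s\,x_i + o) = 2\,(1.3 + 4.3) = 11.2
\end{aligned}
```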

2.4 Important Note: Gradients are Not Limited to ParamState

While ParamState is the standard way to mark trainable parameters, grad_states accepts any State instance. The following example uses a plain State.

# Create a regular State (not ParamState)
regular_state = brainstate.State(jnp.array(2.0), name="regular_state")


def compute_with_state(x):
    return jnp.sum((x * regular_state.value) ** 2)


# Compute gradient w.r.t. regular State
grad_fn = grad(compute_with_state, grad_states=[regular_state])
gradient = grad_fn(jnp.array([1.0, 2.0, 3.0]))
print(f"Gradient w.r.t. regular State: {gradient[0]}")
Gradient w.r.t. regular State: 56.0
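The value 56.0 follows directly from the definition. With x = (1, 2, 3), the function is quadratic in the state value s:

```latex
f(s) = \sum_i (x_i\, s)^2 = s^2 \sum_i x_i^2 = 14\, s^2,
\qquad
f'(s) = 28\, s = 56 \ \text{at}\ s = 2
```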

3. Combining argnums and grad_states#

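You can compute gradients with respect to both function arguments and states simultaneously. In the example below, reg_model starts with zero weight and bias. The L2 penalty and its derivative with respect to l2_coeff are therefore both zero, and the state gradients coincide with those of the plain MSE loss in Section 2.1.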

reg_model = LinearRegressor(1)


def penalized_loss(l2_coeff: float, inputs: jax.Array, target: jax.Array) -> jax.Array:
    """Loss with L2 regularization."""
    pred = reg_model(inputs)
    mse = jnp.mean((pred - target) ** 2)
    # L2 penalty on parameters
    l2 = jnp.sum(reg_model.weight.value ** 2) + jnp.sum(reg_model.bias.value ** 2)
    return mse + l2_coeff * l2


# Compute gradients w.r.t. both states and the first argument
grad_penalized = grad(
    penalized_loss,
    grad_states=reg_model.states(brainstate.ParamState),
    argnums=0,  # Also differentiate w.r.t. l2_coeff
    return_value=True,
)

(state_grads, coeff_grad), loss_val = grad_penalized(0.5, xs, y_true)
print(f"Loss: {float(loss_val):.4f}")
print(f"Gradient w.r.t. l2_coeff: {float(coeff_grad):.4f}")
print("\nState gradients:")
for path, g in state_grads.items():
    print(f"  {path}: {g}")
Loss: 5.5000
Gradient w.r.t. l2_coeff: 0.0000

State gradients:
  ('bias',): [-2.]
  ('weight',): [[-3.]]
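Because the penalized loss is linear in l2_coeff, its derivative with respect to the coefficient equals the penalty value itself. With zero parameters that value is 0. The plain-JAX sketch below uses hypothetical non-zero parameters (weight 2.0, bias 0.5, chosen only for illustration) to show a non-zero coefficient gradient of 2² + 0.5² = 4.25.

```python
import jax
import jax.numpy as jnp

xs = jnp.linspace(-1.0, 1.0, 5).reshape(-1, 1)
y_true = 3.0 * xs + 1.0

# Hypothetical non-zero parameters, chosen only for illustration.
params = {"weight": jnp.array([[2.0]]), "bias": jnp.array([0.5])}


def penalized_loss(l2_coeff, params, inputs, target):
    """Same penalized loss as above, with the parameters passed explicitly."""
    pred = inputs @ params["weight"] + params["bias"]
    mse = jnp.mean((pred - target) ** 2)
    l2 = jnp.sum(params["weight"] ** 2) + jnp.sum(params["bias"] ** 2)
    return mse + l2_coeff * l2


# Differentiate w.r.t. the coefficient (arg 0) and the parameters (arg 1).
coeff_grad, param_grads = jax.grad(penalized_loss, argnums=(0, 1))(0.5, params, xs, y_true)

# The loss is linear in l2_coeff, so d(loss)/d(l2_coeff) is the penalty 2**2 + 0.5**2.
print(coeff_grad)  # expected: 4.25
```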

4. Return Value Structures

All gradient transformations in BrainState share a common signature pattern. The return structure depends on the combination of grad_states, argnums, has_aux, and return_value.

4.1 Basic Return Structures

When grad_states is None:

  • has_aux=False + return_value=False → arg_grads

  • has_aux=True + return_value=False → (arg_grads, aux_data)

  • has_aux=False + return_value=True → (arg_grads, loss_value)

  • has_aux=True + return_value=True → (arg_grads, loss_value, aux_data)

When grad_states is not None and argnums is None:

  • has_aux=False + return_value=False → var_grads

  • has_aux=True + return_value=False → (var_grads, aux_data)

  • has_aux=False + return_value=True → (var_grads, loss_value)

  • has_aux=True + return_value=True → (var_grads, loss_value, aux_data)

When both grad_states and argnums are not None:

  • has_aux=False + return_value=False → (var_grads, arg_grads)

  • has_aux=True + return_value=False → ((var_grads, arg_grads), aux_data)

  • has_aux=False + return_value=True → ((var_grads, arg_grads), loss_value)

  • has_aux=True + return_value=True → ((var_grads, arg_grads), loss_value, aux_data)

The same rules in table form:

| grad_states | argnums  | has_aux | return_value | result                              |
|-------------|----------|---------|--------------|-------------------------------------|
| None        | any      | False   | False        | arg_grads                           |
| None        | any      | True    | False        | (arg_grads, aux)                    |
| None        | any      | False   | True         | (arg_grads, loss)                   |
| None        | any      | True    | True         | (arg_grads, loss, aux)              |
| not None    | None     | False   | False        | var_grads                           |
| not None    | None     | True    | False        | (var_grads, aux)                    |
| not None    | None     | False   | True         | (var_grads, loss)                   |
| not None    | None     | True    | True         | (var_grads, loss, aux)              |
| not None    | not None | False   | False        | (var_grads, arg_grads)              |
| not None    | not None | True    | False        | ((var_grads, arg_grads), aux)       |
| not None    | not None | False   | True         | ((var_grads, arg_grads), loss)      |
| not None    | not None | True    | True         | ((var_grads, arg_grads), loss, aux) |
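If you are coming from plain JAX, note that this ordering differs from jax.value_and_grad, which returns ((value, aux), grads). The sketch below uses only JAX and rearranges its output into the order the table lists for grad_states=None, has_aux=True, return_value=True. The tuple brainstate_style is an illustration of that mapping, not a BrainState object.

```python
import jax
import jax.numpy as jnp


def loss_with_aux(w):
    """Toy scalar loss that also returns an auxiliary dict."""
    loss = jnp.sum(w ** 2)
    return loss, {"mean_abs": jnp.mean(jnp.abs(w))}


w = jnp.array([1.0, -2.0])

# Plain JAX orders the result as ((value, aux), grads).
(value, aux), grads = jax.value_and_grad(loss_with_aux, has_aux=True)(w)

# The table above lists (grads, loss, aux) for has_aux=True, return_value=True.
brainstate_style = (grads, value, aux)
print(brainstate_style)
```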

4.2 Complete Example: All Return Options

example_model = LinearRegressor(1)


def loss_with_metrics(l2_coeff: float, x: jax.Array, target: jax.Array):
    """Loss function that returns auxiliary metrics."""
    pred = example_model(x)
    mse = jnp.mean((pred - target) ** 2)
    l2 = jnp.sum(example_model.weight.value ** 2)
    loss = mse + l2_coeff * l2
    
    # Return loss and auxiliary metrics
    metrics = {
        "mae": jnp.mean(jnp.abs(pred - target)),
        "mse": mse,
        "l2": l2,
    }
    return loss, metrics


# Example: grad_states + argnums + has_aux + return_value
grad_complete = grad(
    loss_with_metrics,
    grad_states=example_model.states(brainstate.ParamState),
    argnums=0,
    has_aux=True,
    return_value=True,
)

((state_grads, coeff_grad), loss_val, aux_metrics) = grad_complete(0.3, xs, y_true)

print(f"Loss: {float(loss_val):.4f}")
print(f"\nGradient w.r.t. l2_coeff: {float(coeff_grad):.4f}")
print("\nState gradients:")
for path, g in state_grads.items():
    print(f"  {path}: {g}")
print("\nAuxiliary metrics:")
for key, val in aux_metrics.items():
    print(f"  {key}: {float(val):.4f}")
Loss: 5.5000

Gradient w.r.t. l2_coeff: 0.0000

State gradients:
  ('bias',): [-2.]
  ('weight',): [[-3.]]

Auxiliary metrics:
  l2: 0.0000
  mae: 2.0000
  mse: 5.5000

5. Other Gradient Transformations

BrainState provides several other gradient transformations, all sharing the same signature pattern as grad.

5.1 Vector Gradient

vector_grad is intended for vector-valued functions. It returns the sum of the gradients of all output components. This is the same as the gradient of the summed output, or a vector-Jacobian product with an all-ones cotangent.

from brainstate.transform import vector_grad


def vector_fun(x):
    """Vector-valued function."""
    return jnp.array([x[0] * x[1], jnp.sin(x[0]), x[0]**2 + x[1]**2])


x0 = jnp.array([1.0, 2.0])

# Vector gradient sums gradients across all outputs
vgrad = vector_grad(vector_fun)
result = vgrad(x0)
print("Vector gradient:", result)
Vector gradient: [4.5403023 5.       ]
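The numbers above match a vector-Jacobian product with an all-ones cotangent: (x₁ + cos x₀ + 2x₀, x₀ + 2x₁) = (4.5403, 5.0) at x = (1, 2). The plain-JAX sketch below reproduces them with jax.vjp and cross-checks against the gradient of the summed outputs. It illustrates the computation, not BrainState's internals.

```python
import jax
import jax.numpy as jnp


def vector_fun(x):
    return jnp.array([x[0] * x[1], jnp.sin(x[0]), x[0] ** 2 + x[1] ** 2])


x0 = jnp.array([1.0, 2.0])

# Vector-Jacobian product v^T J with v = (1, 1, 1).
y, vjp_fn = jax.vjp(vector_fun, x0)
(vjp_ones,) = vjp_fn(jnp.ones_like(y))

# The same quantity as the gradient of the summed outputs.
grad_of_sum = jax.grad(lambda x: jnp.sum(vector_fun(x)))(x0)

print(vjp_ones, grad_of_sum)
```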

5.2 Jacobians: jacrev and jacfwd

  • jacrev: Jacobian using reverse-mode autodiff (efficient for many inputs, few outputs)

  • jacfwd: Jacobian using forward-mode autodiff (efficient for few inputs, many outputs)

  • jacobian: Alias for jacrev

from brainstate.transform import jacrev, jacfwd, jacobian


def multi_output(x):
    """Function with multiple outputs."""
    return jnp.array([x[0] * x[1], jnp.sin(x[0]), jnp.exp(x[1])])


x0 = jnp.array([1.0, 2.0])

# Reverse-mode Jacobian
jac_rev = jacrev(multi_output)
result_rev = jac_rev(x0)
print("Jacobian (reverse-mode):")
print(result_rev)

# Forward-mode Jacobian
jac_fwd = jacfwd(multi_output)
result_fwd = jac_fwd(x0)
print("\nJacobian (forward-mode):")
print(result_fwd)

# They should be the same
assert jnp.allclose(result_rev, result_fwd)

# jacobian is an alias for jacrev
jac_alias = jacobian(multi_output)
result_alias = jac_alias(x0)
assert jnp.allclose(result_rev, result_alias)
Jacobian (reverse-mode):
[[2.        1.       ]
 [0.5403023 0.       ]
 [0.        7.389056 ]]

Jacobian (forward-mode):
[[2.        1.       ]
 [0.5403023 0.       ]
 [0.        7.389056 ]]
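The efficiency rule of thumb in the list above comes from how each mode builds the matrix. Reverse mode needs one VJP per output (one row at a time), while forward mode needs one JVP per input (one column at a time). The plain-JAX sketch below makes this explicit for multi_output. It follows the standard construction and is not necessarily how BrainState implements jacrev or jacfwd.

```python
import jax
import jax.numpy as jnp


def multi_output(x):
    return jnp.array([x[0] * x[1], jnp.sin(x[0]), jnp.exp(x[1])])


x0 = jnp.array([1.0, 2.0])

# Reverse mode: one VJP per output basis vector gives the Jacobian row by row.
y, vjp_fn = jax.vjp(multi_output, x0)
rows = jax.vmap(lambda v: vjp_fn(v)[0])(jnp.eye(y.shape[0], dtype=y.dtype))

# Forward mode: one JVP per input basis vector gives the Jacobian column by column.
cols = jax.vmap(lambda t: jax.jvp(multi_output, (x0,), (t,))[1])(
    jnp.eye(x0.shape[0], dtype=x0.dtype)
)
J_fwd = cols.T  # columns were stacked along axis 0, so transpose to (outputs, inputs)

print(rows)
print(J_fwd)
```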

5.3 Hessian

hessian computes second-order derivatives. For a scalar-valued function f(x), it returns the matrix of second partial derivatives ∂²f/∂xᵢ∂xⱼ.

from brainstate.transform import hessian


def quadratic(x):
    """Quadratic function."""
    return jnp.dot(x, x)


x0 = jnp.array([1.0, 2.0])

hess_fn = hessian(quadratic)
result = hess_fn(x0)
print("Hessian:")
print(result)

# For a quadratic form x^T x, the Hessian is 2*I
expected = 2 * jnp.eye(2)
assert jnp.allclose(result, expected)
Hessian:
[[2. 0.]
 [0. 2.]]
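The quadratic example has a constant Hessian, so it does not show any dependence on x. The plain-JAX sketch below uses f(x) = x₀²·x₁, whose Hessian is [[2x₁, 2x₀], [2x₀, 0]]. It computes the Hessian as forward-over-reverse, jacfwd(jacrev(f)), which is how jax.hessian is defined. Whether BrainState's hessian uses the same composition is an assumption.

```python
import jax
import jax.numpy as jnp


def cubic(x):
    """f(x) = x0**2 * x1, with Hessian [[2*x1, 2*x0], [2*x0, 0]]."""
    return x[0] ** 2 * x[1]


x0 = jnp.array([1.0, 2.0])

# Forward-over-reverse: differentiate the reverse-mode gradient with forward mode.
H = jax.jacfwd(jax.jacrev(cubic))(x0)
print(H)  # expected [[4. 2.] [2. 0.]]
```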

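5.4 Using Gradient Transformations with States

The other transformations accept grad_states in the same way as grad. In the example below, jacrev differentiates the model output with respect to the model's ParamState values. Each returned Jacobian has the output's shape followed by the parameter's shape, so the weight Jacobian has shape (1,) + (2, 1) = (1, 2, 1).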

# Example: Jacobian with states
jac_model = LinearRegressor(2)


def model_output(x):
    """Model output: shape (1,) for a single input with 2 features."""
    return jac_model(x)


# Compute Jacobian w.r.t. model parameters
jac_states = jacrev(
    model_output,
    grad_states=jac_model.states(brainstate.ParamState)
)

x_input = jnp.array([1.0, 2.0])
param_jacobian = jac_states(x_input)

print("Jacobian w.r.t. parameters:")
for path, jac in param_jacobian.items():
    print(f"  {path}: shape={jac.shape}")
Jacobian w.r.t. parameters:
  ('bias',): shape=(1, 1)
  ('weight',): shape=(1, 2, 1)
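The same shape rule can be reproduced in plain JAX by treating the parameters as an explicit dict with the shapes of LinearRegressor(2). In the sketch below, each Jacobian leaf has shape output.shape + parameter.shape. The weight entries equal the input, because ∂out/∂W[i, 0] = x[i]. This illustrates the shapes only and does not show how BrainState handles states internally.

```python
import jax
import jax.numpy as jnp

# Same shapes as LinearRegressor(2): weight (2, 1), bias (1,), zero-initialized.
params = {"weight": jnp.zeros((2, 1)), "bias": jnp.zeros((1,))}
x_input = jnp.array([1.0, 2.0])


def model_output(params, x):
    """Linear model with explicit parameters; output shape (1,)."""
    return x @ params["weight"] + params["bias"]


jac = jax.jacrev(model_output)(params, x_input)

# Expected shapes: bias (1, 1), weight (1, 2, 1).
print({k: v.shape for k, v in jac.items()})
```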

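6. Custom Gradient Transformations with GradientTransform

The GradientTransform class lets you build your own gradient transformation on top of any JAX differentiation function, while BrainState continues to manage the states listed in grad_states. In the examples below, the transform argument is a factory. It is called with the function to differentiate, argnums, has_aux, and the entries of transform_params as keyword arguments. It returns a function that computes the gradients together with auxiliary data.

6.1 Basic Custom Transform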

from brainstate.transform import GradientTransform


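def scaled_grad_transform(fun, *, argnums, has_aux, scale):
    """Custom gradient transform that scales gradients."""
    # The function handed in by GradientTransform returns (value, aux), so the base
    # transform is built with has_aux=True; the user-facing has_aux flag is handled
    # by GradientTransform itself.
    base = jax.grad(fun, argnums=argnums, has_aux=True)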

    def wrapped(*args, **kwargs):
        grads, aux = base(*args, **kwargs)
        # Scale all gradients
        grads = jax.tree.map(lambda g: scale * g, grads)
        return grads, aux

    return wrapped


def scaled_grad(
    fun,
    *,
    scale=1.0,
    grad_states=None,
    argnums=None,
    has_aux=False,
    return_value=False,
):
    """Create a gradient function with scaled gradients."""
    return GradientTransform(
        fun,
        transform=scaled_grad_transform,
        grad_states=grad_states,
        argnums=argnums,
        has_aux=has_aux,
        return_value=return_value,
        transform_params={"scale": scale},  # Pass custom parameters
    )


# Example usage
custom_model = LinearRegressor(1)


def custom_loss(x, target):
    pred = custom_model(x)
    return jnp.mean((pred - target) ** 2)


# Use custom scaled gradient
scaled_grad_fn = scaled_grad(
    custom_loss,
    scale=0.5,  # Scale gradients by 0.5
    grad_states=custom_model.states(brainstate.ParamState),
)

scaled_grads = scaled_grad_fn(xs, y_true)
print("Scaled gradients:")
for path, g in scaled_grads.items():
    print(f"  {path}: {g}")

# Compare with unscaled gradients
normal_grad_fn = grad(custom_loss, grad_states=custom_model.states(brainstate.ParamState))
normal_grads = normal_grad_fn(xs, y_true)
print("\nNormal gradients:")
for path, g in normal_grads.items():
    print(f"  {path}: {g}")
Scaled gradients:
  ('bias',): [-1.]
  ('weight',): [[-1.5]]

Normal gradients:
  ('bias',): [-2.]
  ('weight',): [[-3.]]

6.2 Advanced: Gradient Clipping Transform

def clipped_grad_transform(fun, *, argnums, has_aux, max_norm):
    """Custom gradient transform with gradient clipping."""
    base = jax.grad(fun, argnums=argnums, has_aux=True)

    def wrapped(*args, **kwargs):
        grads, aux = base(*args, **kwargs)
        
        # Compute global norm
        global_norm = jnp.sqrt(
            sum(jnp.sum(jnp.square(g)) for g in jax.tree.leaves(grads))
        )
        
        # Clip gradients
        scale = jnp.minimum(1.0, max_norm / (global_norm + 1e-6))
        grads = jax.tree.map(lambda g: scale * g, grads)
        
        return grads, aux

    return wrapped


def clipped_grad(
    fun,
    *,
    max_norm=1.0,
    grad_states=None,
    argnums=None,
    has_aux=False,
    return_value=False,
):
    """Create a gradient function with gradient clipping."""
    return GradientTransform(
        fun,
        transform=clipped_grad_transform,
        grad_states=grad_states,
        argnums=argnums,
        has_aux=has_aux,
        return_value=return_value,
        transform_params={"max_norm": max_norm},
    )


# Example: gradient clipping
clip_model = LinearRegressor(1)


def clip_loss(x, target):
    pred = clip_model(x)
    return jnp.mean((pred - target) ** 2)


clipped_grad_fn = clipped_grad(
    clip_loss,
    max_norm=0.1,  # Clip gradients to max norm of 0.1
    grad_states=clip_model.states(brainstate.ParamState),
)

clipped_grads = clipped_grad_fn(xs, y_true)
print("Clipped gradients:")
for path, g in clipped_grads.items():
    print(f"  {path}: {g}")
    print(f"    norm: {jnp.linalg.norm(g):.4f}")
Clipped gradients:
  ('bias',): [-0.05547]
    norm: 0.0555
  ('weight',): [[-0.08320501]]
    norm: 0.0832
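The per-leaf norms printed above (0.0555 and 0.0832) are both below max_norm. The clipping constraint applies to the global norm across all leaves: √(0.0555² + 0.0832²) ≈ 0.1. The standalone sketch below applies the same clipping formula to the unclipped gradients from Section 6.1 to check this. The helpers clip_by_global_norm and global_norm are hypothetical, written only for this example.

```python
import jax
import jax.numpy as jnp


def global_norm(grads):
    """Euclidean norm over all leaves of a gradient pytree."""
    return jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in jax.tree.leaves(grads)))


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale every leaf by min(1, max_norm / global_norm), as in the transform above."""
    scale = jnp.minimum(1.0, max_norm / (global_norm(grads) + eps))
    return jax.tree.map(lambda g: scale * g, grads)


# Unclipped gradients from the earlier examples.
raw = {"bias": jnp.array([-2.0]), "weight": jnp.array([[-3.0]])}
clipped = clip_by_global_norm(raw, max_norm=0.1)

print(global_norm(raw), global_norm(clipped))
```

Scaling every leaf by the same factor preserves the gradient's direction. Only its overall length is reduced.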

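7. Practical Example: Training Loop

Let's put everything together in a complete training example. It uses a jitted loss with L2 regularization, a gradient function built with grad_states, has_aux and return_value, and a plain SGD update. The noise added to the targets is drawn without a fixed seed, so your numbers will differ slightly from the output below.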

# Create a fresh model
training_model = LinearRegressor(1)

# Generate training data
true_weight = 3.0
true_bias = 1.0
x_train = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
y_train = true_weight * x_train + true_bias + 0.1 * brainstate.random.normal(size=x_train.shape)


@brainstate.transform.jit
def training_loss(x, y):
    """MSE loss with L2 regularization."""
    pred = training_model(x)
    mse = jnp.mean((pred - y) ** 2)
    l2 = 0.01 * (jnp.sum(training_model.weight.value ** 2) + jnp.sum(training_model.bias.value ** 2))
    return mse + l2, {"mse": mse, "l2": l2}


# Create gradient function
loss_grad_fn = grad(
    training_loss,
    grad_states=training_model.states(brainstate.ParamState),
    has_aux=True,
    return_value=True,
)

# Training loop
learning_rate = 0.1
num_epochs = 50

print("Training started...")
print(f"Initial weight: {training_model.weight.value}")
print(f"Initial bias: {training_model.bias.value}")

for epoch in range(num_epochs):
    # Compute gradients
    grads, loss_val, aux = loss_grad_fn(x_train, y_train)
    
    # Update parameters (simple SGD)
    for path, state in training_model.states(brainstate.ParamState).items():
        # Use `g`, not `grad`, so the imported grad transform is not shadowed
        g = grads[path]
        state.value = state.value - learning_rate * g
    
    # Print progress
    if (epoch + 1) % 10 == 0:
        print(f"\nEpoch {epoch + 1}:")
        print(f"  Loss: {float(loss_val):.4f}")
        print(f"  MSE: {float(aux['mse']):.4f}")
        print(f"  L2: {float(aux['l2']):.6f}")
        print(f"  Weight: {training_model.weight.value}")
        print(f"  Bias: {training_model.bias.value}")

print("\nTraining completed!")
print(f"Final weight: {training_model.weight.value} (true: {true_weight})")
print(f"Final bias: {training_model.bias.value} (true: {true_bias})")
Training started...
Initial weight: [[0.]]
Initial bias: [0.]

Epoch 10:
  Loss: 0.9506
  MSE: 0.9198
  L2: 0.030758
  Weight: [[1.6285777]]
  Bias: [0.9066533]

Epoch 20:
  Loss: 0.2827
  MSE: 0.2190
  L2: 0.063762
  Weight: [[2.3699088]]
  Bias: [1.0015979]

Epoch 30:
  Loss: 0.1478
  MSE: 0.0655
  L2: 0.082280
  Weight: [[2.7073638]]
  Bias: [1.0115404]

Epoch 40:
  Loss: 0.1199
  MSE: 0.0284
  L2: 0.091504
  Weight: [[2.8609738]]
  Bias: [1.0125817]

Epoch 50:
  Loss: 0.1141
  MSE: 0.0182
  L2: 0.095877
  Weight: [[2.9308972]]
  Bias: [1.0126907]

Training completed!
Final weight: [[2.9308972]] (true: 3.0)
Final bias: [1.0126907] (true: 1.0)
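The final weight sits below the true value of 3.0 partly because the L2 term shrinks the parameters. For this linear model the regularized loss has a closed-form minimizer, so you can compute the point the optimizer is heading toward directly. The plain-JAX sketch below does this for noise-free targets. This is an assumption made so that the result is reproducible, since the tutorial's y_train includes unseeded noise. It gives w ≈ 2.92 and b ≈ 0.99.

```python
import jax
import jax.numpy as jnp

lam = 0.01  # same L2 coefficient as in training_loss
x_train = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
y_clean = (3.0 * x_train + 1.0).ravel()  # noise-free targets, for illustration only

# Design matrix with a column of ones for the bias, so theta = (w, b).
X = jnp.hstack([x_train, jnp.ones_like(x_train)])
n = X.shape[0]

# Setting the gradient of mean((X @ theta - y)**2) + lam * ||theta||**2 to zero
# gives the linear system (X^T X / n + lam * I) theta = X^T y / n.
A = X.T @ X / n + lam * jnp.eye(2)
theta = jnp.linalg.solve(A, X.T @ y_clean / n)
w_opt, b_opt = theta


def reg_loss(theta):
    """Regularized loss from training_loss, written in terms of theta = (w, b)."""
    return jnp.mean((X @ theta - y_clean) ** 2) + lam * jnp.sum(theta ** 2)


print(w_opt, b_opt)
```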

Summary

In this tutorial, we covered:

  1. argnums: Specify which function arguments to differentiate (inherited from JAX)

  2. grad_states: Specify which State objects should receive gradients (BrainState extension)

  3. ParamState: Standard way to mark trainable parameters in modules

  4. Retrieving states: Use module.states() or brainstate.graph.treefy_states()

  5. StateFinder: Discover states used in arbitrary functions

  6. Return structures: How has_aux and return_value affect the output

  7. Other transforms: vector_grad, jacrev, jacfwd, jacobian, hessian

  8. Custom transforms: Build your own using GradientTransform

Key Takeaways

  • All gradient transformations share the same signature and return structure patterns

  • ParamState is the standard for trainable parameters, but gradients work with any State

  • StateFinder helps discover states in arbitrary functions

  • GradientTransform enables custom gradient transformations while maintaining state-awareness

  • BrainState builds on JAX's autodiff and adds state tracking, so gradients can be taken with respect to State objects without threading their values through function arguments