Training and Metrics

This tutorial assembles the pieces you have met so far — Module, ParamState, grad, and jit — into a complete training loop, and introduces the tools BrainState provides around that loop: parameter counting, gradient clipping, optimizers, and metric tracking.

We use a small synthetic classification problem so the notebook runs in seconds with no downloads. The mechanics are identical for real datasets — only the data source changes.

You will learn to:

  • Inspect a model with count_parameters.

  • Drive parameter updates with a braintools.optim optimizer.

  • Stabilise training with clip_grad_norm.

  • Accumulate evaluation statistics with the MultiMetric system.

import jax.numpy as jnp

import brainstate
import braintools
from brainstate.nn import count_parameters, clip_grad_norm, MultiMetric, AverageMetric, AccuracyMetric
from braintools.metric import softmax_cross_entropy_with_integer_labels

brainstate.random.seed(42)
brainstate.__version__
'0.4.0'

A self-contained dataset

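Our task is to classify points drawn from three Gaussian clusters in an 8-dimensional space. The class centres are standard-normal vectors scaled by 2.0. make_blobs draws n_per points per class by adding unit-variance Gaussian noise to each centre. Everything is sampled with brainstate.random, so the data is reproducible under the seed set above.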

DIM, N_CLASSES = 8, 3
centers = brainstate.random.randn(N_CLASSES, DIM) * 2.0

def make_blobs(n_per):
    xs = [brainstate.random.randn(n_per, DIM) + centers[c] for c in range(N_CLASSES)]
    ys = [jnp.full((n_per,), c, dtype=jnp.int32) for c in range(N_CLASSES)]
    return jnp.concatenate(xs), jnp.concatenate(ys)

x_train, y_train = make_blobs(200)
x_test, y_test = make_blobs(50)
x_train.shape, y_train.shape
((600, 8), (600,))

Neural networks train far more reliably on inputs with comparable scales across features, so we standardise the data to zero mean and unit variance. The statistics come from the training set only — the test set is transformed with those same numbers, never its own, to avoid leaking test information into preprocessing.

mean, std = x_train.mean(axis=0), x_train.std(axis=0)
x_train = (x_train - mean) / std
x_test = (x_test - mean) / std
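The preprocessing can be checked in isolation. Below is a plain-NumPy re-creation of the blobs and the standardisation, with NumPy's generator standing in for brainstate.random. The small eps is an assumption not present in the tutorial; it guards against dividing by a zero-variance feature. The training split comes out with zero mean and unit variance by construction. The test split, transformed with the training statistics, is only approximately centred.

```python
import numpy as np

# NumPy stand-in for the tutorial's data (not brainstate.random).
rng = np.random.default_rng(0)
centers = rng.standard_normal((3, 8)) * 2.0

def make_blobs(n_per):
    xs = [rng.standard_normal((n_per, 8)) + centers[c] for c in range(3)]
    ys = [np.full(n_per, c, dtype=np.int32) for c in range(3)]
    return np.concatenate(xs), np.concatenate(ys)

x_train, _ = make_blobs(200)
x_test, _ = make_blobs(50)

# Statistics come from the training split only.
mean, std = x_train.mean(axis=0), x_train.std(axis=0)
eps = 1e-8  # assumed guard against zero variance; not in the tutorial
x_train_s = (x_train - mean) / (std + eps)
x_test_s = (x_test - mean) / (std + eps)

# Training means are zero up to rounding; test means are typically small but nonzero.
print(np.abs(x_train_s.mean(axis=0)).max(), np.abs(x_test_s.mean(axis=0)).max())
```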

Defining the model

A two-layer perceptron is enough for this problem. We compose two brainstate.nn.Linear layers with a ReLU nonlinearity; the final layer emits one logit per class.

class MLP(brainstate.nn.Module):
    def __init__(self, din, dhidden, dout):
        super().__init__()
        self.fc1 = brainstate.nn.Linear(din, dhidden)
        self.fc2 = brainstate.nn.Linear(dhidden, dout)

    def __call__(self, x):
        return self.fc2(brainstate.nn.relu(self.fc1(x)))

model = MLP(DIM, 32, N_CLASSES)
model
MLP(
  fc1=Linear(
    in_size=(8,),
    out_size=(32,),
    weight=ParamState(
      value={
        'bias': ShapedArray(float32[32]),
        'weight': ShapedArray(float32[8,32])
      }
    )
  ),
  fc2=Linear(
    in_size=(32,),
    out_size=(3,),
    weight=ParamState(
      value={
        'bias': ShapedArray(float32[3]),
        'weight': ShapedArray(float32[32,3])
      }
    )
  )
)
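The printed shapes store fc1's weight as (8, 32) rather than (32, 8). This suggests each Linear maps a batch of row vectors as x @ W + b. The NumPy sketch below writes out the forward pass under that assumption. The weights are random placeholders, not BrainState's initialisation scheme, and the helper names linear and mlp are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Shapes copied from the printed model; values are random placeholders.
params = {
    "fc1": {"weight": rng.standard_normal((8, 32)) * 0.1, "bias": np.zeros(32)},
    "fc2": {"weight": rng.standard_normal((32, 3)) * 0.1, "bias": np.zeros(3)},
}

def linear(p, x):
    # Assumed convention: inputs are row vectors, so y = x @ W + b.
    return x @ p["weight"] + p["bias"]

def mlp(params, x):
    h = np.maximum(linear(params["fc1"], x), 0.0)  # ReLU on the hidden layer
    return linear(params["fc2"], h)                 # one logit per class

x = rng.standard_normal((5, 8))  # a batch of 5 points
logits = mlp(params, x)
print(logits.shape)  # (5, 3)
```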

Counting parameters

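Before training, it is worth knowing how large a model is. count_parameters walks the module tree, sums the sizes of every ParamState and, by default, prints a breakdown with one row per state. Each Linear keeps its weight matrix and bias vector together in a single ParamState, so each row covers both arrays.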

n_params = count_parameters(model)
n_params
+-------------------+------------+
|      Modules      | Parameters |
+-------------------+------------+
| ('fc1', 'weight') |    288     |
| ('fc2', 'weight') |     99     |
|       Total       |    387     |
+-------------------+------------+
387

Pass return_table=True to capture the formatted table as a string instead of printing it — useful for logging or for embedding in a report.
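The totals in the table can be reproduced by hand from the printed shapes. Each ParamState holds a dict of arrays, so its count is the sum of the element counts of those arrays. The sketch below is plain arithmetic on shapes copied from the model's repr, not a call into count_parameters.

```python
import numpy as np

# Shapes copied from the printed model: one ParamState per Linear,
# each holding both a weight matrix and a bias vector.
param_shapes = {
    ("fc1", "weight"): {"weight": (8, 32), "bias": (32,)},
    ("fc2", "weight"): {"weight": (32, 3), "bias": (3,)},
}

per_state = {
    path: sum(int(np.prod(shape)) for shape in leaves.values())
    for path, leaves in param_shapes.items()
}
total = sum(per_state.values())
print(per_state)  # {('fc1', 'weight'): 288, ('fc2', 'weight'): 99}
print(total)      # 387
```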

The optimizer

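braintools.optim provides the usual family of optimizers (SGD, Adam, AdamW, Lion, …). After constructing one, register the states it is allowed to update. We pass exactly the model's ParamState instances, so any other state types (counters, running statistics) are left untouched. The printed representation confirms Adam's default betas=(0.9, 0.999) and eps=1e-08. It also shows that opt_state holds first- and second-moment estimates (mu, nu) for every registered parameter.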

optimizer = braintools.optim.Adam(lr=1e-2)
optimizer.register_trainable_weights(model.states(brainstate.ParamState))
Adam(
  betas=(0.9, 0.999),
  eps=1e-08,
  amsgrad=False,
  param_states=<braintools.optim.UniqueStateManager object at 0x778921c5c2f0>,
  weight_decay=0.0,
  step_count=OptimState(
    value=ShapedArray(int32[], weak_type=True)
  ),
  param_groups=[
    {
      'params': {
        ('fc1', 'weight'): ParamState(
          value={
            'bias': ShapedArray(float32[32]),
            'weight': ShapedArray(float32[8,32])
          }
        ),
        ('fc2', 'weight'): ParamState(
          value={
            'bias': ShapedArray(float32[3]),
            'weight': ShapedArray(float32[32,3])
          }
        )
      },
      'lr': OptimState(
        value=ShapedArray(float32[], weak_type=True)
      ),
      'weight_decay': 0.0
    }
  ],
  param_groups_opt_states=[],
  _schedulers=[],
  _lr_scheduler=<braintools.optim.ConstantLR object at 0x778921c5cec0>,
  _base_lr=0.01,
  _current_lr=OptimState(...),
  tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x778921c82200>, update=<function chain.<locals>.update_fn at 0x778921c822a0>),
  opt_state=OptimState(
    value=(ScaleByAdamState(count=ShapedArray(int32[]), mu={('fc1', 'weight'): {'bias': ShapedArray(float32[32]), 'weight': ShapedArray(float32[8,32])}, ('fc2', 'weight'): {'bias': ShapedArray(float32[3]), 'weight': ShapedArray(float32[32,3])}}, nu={('fc1', 'weight'): {'bias': ShapedArray(float32[32]), 'weight': ShapedArray(float32[8,32])}, ('fc2', 'weight'): {'bias': ShapedArray(float32[3]), 'weight': ShapedArray(float32[32,3])}}), ScaleByScheduleState(count=ShapedArray(int32[])))
  )
)
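The mu and nu entries in opt_state are the two running moments of the Adam update. The NumPy sketch below implements that update from the standard formulas with the printed hyperparameters; it is not taken from braintools' source. On the very first step, bias correction makes m_hat equal g and v_hat equal g squared. Each parameter therefore moves by about lr against the sign of its gradient, whatever the gradient's magnitude.

```python
import numpy as np

def adam_step(p, g, m, v, t, lr=1e-2, b1=0.9, b2=0.999, eps=1e-8):
    """One textbook Adam update; defaults mirror the printed repr."""
    m = b1 * m + (1 - b1) * g          # first-moment estimate (mu)
    v = b2 * v + (1 - b2) * g * g      # second-moment estimate (nu)
    m_hat = m / (1 - b1 ** t)          # bias correction for the zero init
    v_hat = v / (1 - b2 ** t)
    p = p - lr * m_hat / (np.sqrt(v_hat) + eps)
    return p, m, v

p = np.array([0.5, -0.2, 1.0])
g = np.array([3.0, -0.01, 0.4])      # gradients of very different sizes
m, v = np.zeros_like(p), np.zeros_like(p)
p_new, m, v = adam_step(p, g, m, v, t=1)
print(p_new - p)  # approximately [-0.01, 0.01, -0.01]
```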

The training step

A single training step does three things:

  1. grad differentiates the loss with respect to params, the model's ParamStates (the same states registered with the optimizer). With return_value=True it returns (grads, loss) from a single pass, so the loss comes for free.

  2. clip_grad_norm rescales the gradients so their global norm never exceeds max_norm. This is cheap insurance against the occasional exploding update.

  3. optimizer.update(grads) applies the rescaled gradients in place.
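Global-norm clipping treats every gradient array as part of one long vector. If that vector's L2 norm exceeds max_norm, all arrays are scaled by the same factor. This caps the step size while preserving its direction. The NumPy sketch below follows that standard rule. The helper name and the 1e-6 epsilon are assumptions, and BrainState's clip_grad_norm may differ in such details.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Hypothetical helper: rescale a dict of arrays so their joint L2 norm <= max_norm."""
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads.values()))
    scale = min(1.0, max_norm / (global_norm + 1e-6))  # assumed epsilon
    return {k: g * scale for k, g in grads.items()}, global_norm

# Joint norm is sqrt(3^2 + 4^2 + 12^2) = 13, so every array is scaled by ~1/13.
grads = {"w": np.array([3.0, 4.0]), "b": np.array([12.0])}
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)

# Gradients already under the limit pass through unchanged (scale == 1.0).
small, small_norm = clip_by_global_norm({"w": np.array([0.1, 0.2])}, max_norm=1.0)
```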

Wrapping the whole step in brainstate.transform.jit compiles it on the first call with each new input shape and reuses the compiled version on later calls with that shape. State reads and writes are tracked automatically across the transform boundary, so there is no manual parameter threading.

params = model.states(brainstate.ParamState)

@brainstate.transform.jit
def train_step(x, y):
    def loss_fn():
        logits = model(x)
        return softmax_cross_entropy_with_integer_labels(logits, y).mean()

    grads, loss = brainstate.transform.grad(loss_fn, params, return_value=True)()
    grads = clip_grad_norm(grads, max_norm=1.0)
    optimizer.update(grads)
    return loss
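Inside loss_fn, the per-example losses are averaged with .mean(). The per-example quantity is the negative log-probability that the softmax assigns to the true class, which can be computed as logsumexp(logits) minus the true-class logit. The NumPy sketch below illustrates that formula and is not braintools' implementation. It subtracts the row maximum before exponentiating so large logits do not overflow.

```python
import numpy as np

def softmax_xent_int(logits, labels):
    """Per-example cross-entropy with integer labels: logsumexp(logits) - logits[label]."""
    m = logits.max(axis=-1, keepdims=True)               # shift for numerical stability
    lse = m[:, 0] + np.log(np.exp(logits - m).sum(axis=-1))
    picked = np.take_along_axis(logits, labels[:, None], axis=-1)[:, 0]
    return lse - picked

logits = np.array([[2.0, 0.0, 0.0],
                   [0.0, 0.0, 0.0]])   # second row is a uniform prediction
labels = np.array([0, 2])
losses = softmax_xent_int(logits, labels)
print(losses.mean())  # the scalar that loss_fn would return
```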

Tracking metrics

MultiMetric bundles several metrics that are updated together. Each sub-metric reads the keyword arguments it needs from a single update(...) call: AverageMetric('loss') consumes loss=..., while AccuracyMetric consumes logits=... and labels=.... The lifecycle is always reset → update (per batch) → compute.

metrics = MultiMetric(
    loss=AverageMetric('loss'),
    accuracy=AccuracyMetric(),
)

@brainstate.transform.jit
def eval_step(x, y):
    logits = model(x)
    loss = softmax_cross_entropy_with_integer_labels(logits, y).mean()
    metrics.update(loss=loss, logits=logits, labels=y)
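The reset, update, compute lifecycle amounts to keeping a running total and a count. The sketch below uses a hypothetical RunningAverage class, not BrainState's AverageMetric or AccuracyMetric. It shows one subtlety: a scalar batch loss contributes once per update, while a per-example correctness array contributes once per example. A small final batch therefore weighs differently in the two averages. This tutorial does not say how the BrainState metrics weight their inputs. Here eval_step runs once on the full test set, so the distinction does not arise.

```python
import numpy as np

class RunningAverage:
    """Hypothetical accumulator illustrating reset -> update -> compute."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.total, self.count = 0.0, 0

    def update(self, values):
        values = np.asarray(values, dtype=float)
        self.total += values.sum()
        self.count += values.size

    def compute(self):
        return self.total / self.count

loss_avg, acc_avg = RunningAverage(), RunningAverage()
batches = [  # (mean batch loss, logits, labels); made-up values for illustration
    (0.2, np.array([[2.0, 0.1, 0.0], [0.0, 1.0, 3.0]]), np.array([0, 1])),
    (0.4, np.array([[0.0, 2.0, 1.0]]), np.array([1])),
]
for loss, logits, labels in batches:
    loss_avg.update(loss)                             # one scalar per batch
    acc_avg.update(logits.argmax(axis=-1) == labels)  # one bool per example
print(loss_avg.compute(), acc_avg.compute())  # loss averaged over 2 batches, accuracy over 3 examples
```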

The training loop

We iterate in mini-batches, reshuffling each epoch with brainstate.random.permutation. After every epoch we reset the metrics, run the evaluation step over the full test set and read back the accumulated statistics with compute. We print them every third epoch and after the last one.

def iter_batches(x, y, batch_size):
    order = brainstate.random.permutation(len(x))
    for i in range(0, len(x), batch_size):
        idx = order[i:i + batch_size]
        yield x[idx], y[idx]

for epoch in range(15):
    for xb, yb in iter_batches(x_train, y_train, batch_size=32):
        train_step(xb, yb)

    metrics.reset()
    eval_step(x_test, y_test)
    stats = metrics.compute()
    if epoch % 3 == 0 or epoch == 14:
        print(f"epoch {epoch:2d} | test loss {float(stats['loss']):.4f} | "
              f"test acc {float(stats['accuracy']):.3f}")
epoch  0 | test loss 0.1440 | test acc 0.973
epoch  3 | test loss 0.0471 | test acc 0.980
epoch  6 | test loss 0.0411 | test acc 0.980
epoch  9 | test loss 0.0355 | test acc 0.987
epoch 12 | test loss 0.0367 | test acc 0.987
epoch 14 | test loss 0.0394 | test acc 0.987
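Because 600 is not a multiple of 32, each epoch ends with a smaller batch. The NumPy stand-in for iter_batches below yields index batches rather than data and uses NumPy's generator in place of brainstate.random. It shows the layout: 18 full batches followed by one batch of 24, with every example visited exactly once. jit specialises compiled code to input shapes, so that final batch should trigger one extra compilation of train_step when it first appears. Dropping or padding the remainder would avoid this, at the cost of skipping or duplicating a few examples per epoch.

```python
import numpy as np

def iter_batches(n, batch_size, rng):
    """Yield shuffled index batches that cover 0..n-1 exactly once."""
    order = rng.permutation(n)
    for i in range(0, n, batch_size):
        yield order[i:i + batch_size]

rng = np.random.default_rng(0)
batches = list(iter_batches(600, 32, rng))
sizes = [len(b) for b in batches]
print(len(batches), sizes[-1])  # 19 24
```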

Evaluating

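metrics.compute() returns a plain dictionary keyed by the names given to MultiMetric ('loss' and 'accuracy'). The final numbers are therefore ordinary arrays you can log, compare or assert on. No update has happened since the last epoch's evaluation, so this call reports the same statistics as the epoch 14 line above.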

final = metrics.compute()
print(f"final test accuracy: {float(final['accuracy']):.3f}")
final test accuracy: 0.987

Summary

A BrainState training loop is built from four composable pieces:

  • count_parameters — inspect model size before you commit to training.

  • braintools.optim — optimizers that update registered ParamStates in place.

  • grad + clip_grad_norm + jit — a compiled step that differentiates, stabilises, and applies updates with no manual state bookkeeping.

  • MultiMetric — reset/update/compute accumulation for evaluation statistics.

See also