Migrating Concepts from PyTorch to BrainState

BrainState borrows many familiar ideas from PyTorch—tensor computations, modules with parameters, automatic differentiation—while leaning on JAX for JIT compilation and functional programming. This note contrasts the key building blocks so you can translate existing PyTorch workflows into BrainState idioms quickly.

Concept map

| PyTorch | BrainState | Notes |
|---|---|---|
| torch.Tensor | jax.Array (jnp.ndarray) | Manipulated with jax.numpy semantics. |
| nn.Module | brainstate.nn.Module | Define State attributes (e.g. ParamState, HiddenState). |
| nn.Parameter | brainstate.ParamState | Holds differentiable weights; collected with module.states(brainstate.ParamState). |
| torch.autograd.grad / loss.backward() | brainstate.transform.grad | Explicitly select which states or arguments receive gradients. |
| torch.optim optimisers | braintools.optim (optional) | Operates on the tree returned by .states(brainstate.ParamState). |
| torch.jit.script / torch.jit.trace (or torch.compile) | brainstate.transform.jit | JIT-compiles pure or stateful functions through JAX. |
| state_dict() / load_state_dict() | brainstate.graph.treefy_states / brainstate.graph.update_states | Serialize and restore state trees. |
| Random number generators (torch.manual_seed) | brainstate.random.seed / RandomState | Keys are JAX PRNG keys, split automatically inside transforms. |

PyTorch baseline

Consider a minimal linear regression in PyTorch:

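import torch
import torch.nn as nn
import torch.optim as optim

# Synthetic data for y = 3x + 1 plus noise (the same task as the BrainState example)
torch.manual_seed(0)
inputs = torch.rand(64, 1) * 2 - 1  # uniform in [-1, 1)
targets = 3.0 * inputs + 1.0 + 0.1 * torch.randn(64, 1)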

class TorchLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

model = TorchLinear()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-1)

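for step in range(100):
    optimizer.zero_grad()             # clear gradients accumulated in .grad
    preds = model(inputs)
    loss = criterion(preds, targets)
    loss.backward()                   # populate .grad on every nn.Parameter
    optimizer.step()                  # in-place SGD update of the parameters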

BrainState follows the same logic but makes states and gradients explicit.

BrainState equivalent

import braintools.file
import jax
import jax.numpy as jnp
import numpy as np

import brainstate
from brainstate.transform import grad, jit
import braintools.optim as optim

# Synthetic dataset
def make_dataset(n=64):
    rng = np.random.default_rng(0)
    x = rng.uniform(-1.0, 1.0, (n, 1)).astype(np.float32)
    y = 3.0 * x + 1.0 + rng.normal(0.0, 0.1, (n, 1)).astype(np.float32)
    return jnp.asarray(x), jnp.asarray(y)

x_train, y_train = make_dataset()

class LinearModel(brainstate.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        k1, k2 = jax.random.split(jax.random.PRNGKey(0))
        self.weight = brainstate.ParamState(jax.random.normal(k1, (in_features, out_features)))
        self.bias = brainstate.ParamState(jax.random.normal(k2, (out_features,)))

    def __call__(self, x):
        return x @ self.weight.value + self.bias.value

model = LinearModel(1, 1)
params = model.states(brainstate.ParamState)
optimizer = optim.SGD(lr=1e-1)
optimizer.register_trainable_weights(params)

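@jit
def train_step(x, y):
    def loss_fn():
        # The model reads its ParamState values directly; no parameters are passed in
        preds = model(x)
        return jnp.mean((preds - y) ** 2)

    # Differentiate w.r.t. the ParamStates in `params` and also return the loss value
    grads, loss = grad(loss_fn, grad_states=params, return_value=True)()
    optimizer.update(grads)  # apply the SGD update to the registered ParamStates
    return loss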

for step in range(200):
    loss = train_step(x_train, y_train)
    if step % 40 == 0:
        print(f"step {step:3d}, loss = {float(loss):.4f}")

@jit
def predict(x):
    return model(x)

print('predictions for x=0:', predict(jnp.array([[0.0]])))

Output:
step   0, loss = 13.1320
step  40, loss = 0.0144
step  80, loss = 0.0097
step 120, loss = 0.0097
step 160, loss = 0.0097
predictions for x=0: [[1.0059681]]
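For comparison, the same regression can be written in plain JAX, where the parameters are an explicit pytree threaded through every call and the SGD update is written by hand. This is a sketch for contrast only, not how BrainState is implemented; the names loss_fn, train_step and LR are illustrative. In the BrainState version above, ParamState, grad(grad_states=...) and braintools.optim take over this bookkeeping.

```python
import jax
import jax.numpy as jnp
import numpy as np

LR = 0.1  # illustrative learning rate, matching the example above

# Same synthetic task: y = 3x + 1 + noise
rng = np.random.default_rng(0)
x = jnp.asarray(rng.uniform(-1.0, 1.0, (64, 1)).astype(np.float32))
y = 3.0 * x + 1.0 + jnp.asarray(rng.normal(0.0, 0.1, (64, 1)).astype(np.float32))

# Parameters live in an explicit pytree (a dict) instead of ParamState objects
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = {
    "weight": jax.random.normal(k1, (1, 1)),
    "bias": jax.random.normal(k2, (1,)),
}

def loss_fn(params, x, y):
    preds = x @ params["weight"] + params["bias"]
    return jnp.mean((preds - y) ** 2)

@jax.jit
def train_step(params, x, y):
    # value_and_grad differentiates w.r.t. the first argument (the params pytree)
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Plain SGD, applied leaf by leaf: p <- p - LR * g
    new_params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return new_params, loss

for step in range(200):
    params, loss = train_step(params, x, y)

print("final loss:", float(loss))
print("weight:", params["weight"], "bias:", params["bias"])
```

The functional style makes every state transition visible in the function signature, which is exactly what BrainState hides behind State objects while keeping the code JIT-compatible.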

Key observations

  • Parameters are stored in ParamState objects, so the gradients returned by grad form a tree keyed by the same state paths; params.to_flat() gives a flat view comparable to state_dict().

  • grad only differentiates the states listed in grad_states; gradients with respect to function arguments are requested separately via argnums. This explicit selection replaces toggling requires_grad on tensors in PyTorch.

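  • Optimisers are bound to an explicit ParamState tree with register_trainable_weights and consume gradient trees via update(grads). There is no zero_grad(): gradients are returned by grad rather than accumulated on the parameters.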
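Because parameters and gradients are addressed by path tuples such as ('weight',), it helps to see how a nested module tree flattens into such keys and how they relate to PyTorch's dotted state_dict() names. The sketch below uses plain dictionaries and a hypothetical flatten_paths helper; it is not BrainState's to_flat() implementation, and the sorted key order is a choice made here for determinism.

```python
# Hypothetical helper: flatten a nested dict into {path_tuple: leaf}.
# Keys are visited in sorted order so the result is deterministic.
def flatten_paths(tree, prefix=()):
    flat = {}
    for key in sorted(tree):
        value = tree[key]
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten_paths(value, path))  # recurse into submodules
        else:
            flat[path] = value
    return flat

# Illustrative nested "model": two submodules, each with a weight and a bias
nested = {
    "encoder": {"weight": [[0.5]], "bias": [0.0]},
    "readout": {"weight": [[2.0]], "bias": [1.0]},
}

flat = flatten_paths(nested)
# PyTorch's state_dict() encodes the same paths as dotted strings
torch_style = {".".join(path): value for path, value in flat.items()}

for path in flat:
    print(path)
print(list(torch_style))
# ['encoder.bias', 'encoder.weight', 'readout.bias', 'readout.weight']
```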

Saving and loading state

state_tree = brainstate.graph.treefy_states(model)
print('stored keys:', list(state_tree.to_flat().keys()))

# Later, copy the stored values into a freshly constructed model:
restored = LinearModel(1, 1)
brainstate.graph.update_states(restored, state_tree)
print('restored weight:', restored.weight.value)

Output:
stored keys: [('bias',), ('weight',)]
restored weight: [[3.0168793]]

To write states to disk instead of keeping them in memory, use braintools.file.msgpack_save and braintools.file.msgpack_load.

braintools.file.msgpack_save('example.msgpack', model.states(brainstate.ParamState))

Output:
Saving checkpoint into example.msgpack
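If you prefer a format that does not depend on braintools, one option is to store the flattened arrays with NumPy. The sketch below assumes you have already built a flat {path_tuple: array} mapping (for example from to_flat(), after converting each entry to a plain array); save_flat and load_flat are hypothetical helpers, the values are made up, and writing the loaded arrays back into the model's ParamState values is not shown.

```python
import os
import tempfile

import numpy as np

# Illustrative flat state: path tuple -> array (values are made up)
flat_state = {
    ("bias",): np.array([1.0], dtype=np.float32),
    ("weight",): np.array([[3.0]], dtype=np.float32),
}

def save_flat(filename, flat):
    # .npz entry names must be strings, so join each path tuple with "."
    np.savez(filename, **{".".join(path): np.asarray(v) for path, v in flat.items()})

def load_flat(filename):
    # Split names back into path tuples; assumes no path component contains "."
    with np.load(filename) as data:
        return {tuple(name.split(".")): data[name] for name in data.files}

with tempfile.TemporaryDirectory() as tmp:
    filename = os.path.join(tmp, "linear_params.npz")
    save_flat(filename, flat_state)
    restored = load_flat(filename)

for path, value in restored.items():
    print(path, value)
```

Using "." as the separator mirrors PyTorch's state_dict() naming, at the cost of forbidding dots inside individual path components.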

Gradients with additional arguments

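The next example differentiates with respect to both the model parameters (selected with grad_states) and an explicit scalar argument, the L2-penalty coefficient alpha (selected with argnums=0).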

scale = jnp.array(0.1)

def scaled_loss(alpha, inputs, targets):
    preds = model(inputs)
    mse = jnp.mean((preds - targets) ** 2)
    return mse + alpha * jnp.sum(model.weight.value ** 2)

(grads_state, alpha_grad), loss_val = grad(
    scaled_loss,
    grad_states=params,
    argnums=0,
    return_value=True,
)(scale, x_train, y_train)

print('loss:', float(loss_val))
print('grad w.r.t alpha:', float(alpha_grad))
for path, g in grads_state.items():
    print(path, g.shape)

Output:
loss: 0.9198333024978638
grad w.r.t alpha: 9.101560592651367
('bias',) (1,)
('weight',) (1, 1)
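In plain JAX the same idea is expressed by passing the parameters as an explicit pytree and listing both arguments in argnums; in the BrainState call, grad_states plays the role of argnums for state-held parameters. The sketch below uses made-up weights and noise-free targets so the gradients can be checked by hand: d(loss)/d(alpha) equals the sum of squared weights, which is why the alpha gradient printed above (9.1016) matches the trained weight 3.0168793 squared.

```python
import jax
import jax.numpy as jnp

# Illustrative fixed parameters and data (not the trained model above)
params = {"weight": jnp.array([[3.0]]), "bias": jnp.array([1.0])}
x = jnp.array([[0.0], [1.0], [-1.0]])
y = 3.0 * x + 1.0  # exact targets, so the MSE term and its gradient are zero

def scaled_loss(params, alpha, x, y):
    preds = x @ params["weight"] + params["bias"]
    mse = jnp.mean((preds - y) ** 2)
    return mse + alpha * jnp.sum(params["weight"] ** 2)

# argnums=(0, 1): differentiate w.r.t. both the params pytree and alpha
loss, (param_grads, alpha_grad) = jax.value_and_grad(scaled_loss, argnums=(0, 1))(
    params, jnp.array(0.1), x, y
)

print("loss:", float(loss))               # alpha * w**2 = 0.1 * 9
print("grad w.r.t. alpha:", float(alpha_grad))  # sum(w**2) = 9
print("weight grad:", param_grads["weight"])     # 2 * alpha * w = 0.6
print("bias grad:", param_grads["bias"])         # 0, since the fit is exact
```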

Random numbers

BrainState wraps JAX PRNG keys. Use brainstate.random.seed to set the global seed, or instantiate a RandomState for module-specific randomness. Transforms like vmap and pmap split keys automatically per batch element.

import brainstate.random as brandom

brandom.seed(42)
rs = brandom.RandomState()
print('single sample:', rs.normal(size=(2,)))

Output:
single sample: [ 0.6630465  -0.72396195]
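To see what BrainState automates, here is the explicit key management that plain JAX requires: reusing a key repeats its numbers, so a fresh subkey is split off for each draw, and batched sampling needs one key per batch element. This sketch uses only jax.random and jax.vmap; it does not show how brainstate.random splits keys internally.

```python
import jax

# Explicit, functional PRNG handling in plain JAX
key = jax.random.PRNGKey(42)
key, k1 = jax.random.split(key)
key, k2 = jax.random.split(key)

a = jax.random.normal(k1, (2,))
a_again = jax.random.normal(k1, (2,))  # same key -> identical numbers
b = jax.random.normal(k2, (2,))        # fresh key -> different numbers

# One key per batch element, then vmap the sampler over the keys
batch_keys = jax.random.split(key, 4)
samples = jax.vmap(lambda k: jax.random.normal(k, (2,)))(batch_keys)
print(samples.shape)  # (4, 2)
```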

Debugging and JIT

BrainState builds on JAX's tooling: brainstate.transform.jit compiles functions that read and write State objects, and brainstate.transform.make_jaxpr traces a call and returns its jaxpr together with the states it touched.

from brainstate.transform import make_jaxpr

jaxpr = make_jaxpr(model)
print(jaxpr(jnp.ones((1,))))

Output:
({ lambda ; a:f32[1] b:f32[1,1] c:f32[1]. let
    d:f32[1] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] a b
    e:f32[1] = add d c
  in (e, b, c) }, (ParamState(
  value=ShapedArray(float32[1,1])
), ParamState(
  value=ShapedArray(float32[1])
)))
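For comparison, jax.make_jaxpr on a pure function needs the parameters passed in as explicit arguments. In the BrainState output above, the ParamState values appear as extra inputs (b and c) and are returned alongside the result, which is how a stateful call is exposed as a pure jaxpr. The sketch below is plain JAX with illustrative names and values.

```python
import jax
import jax.numpy as jnp

def linear(params, x):
    # Pure function: every input, including the parameters, is an argument
    return x @ params["weight"] + params["bias"]

params = {"weight": jnp.ones((1, 1)), "bias": jnp.zeros((1,))}
jaxpr = jax.make_jaxpr(linear)(params, jnp.ones((1,)))
print(jaxpr)

# Two parameter leaves plus x become the jaxpr's three input variables
print(len(jaxpr.jaxpr.invars))
```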

Summary

  • Replace nn.Module + nn.Parameter with brainstate.nn.Module + ParamState.

  • Use brainstate.transform.grad/jit instead of PyTorch autograd and scripting.

  • Retrieve and update parameter trees via brainstate.graph.treefy_states and brainstate.graph.update_states.

  • Optimisers in braintools.optim take the place of torch.optim: register the ParamState tree once with register_trainable_weights, then call update(grads) on each step.

With these substitutions most PyTorch training loops can be ported one module at a time to BrainState.