Migrating Concepts from PyTorch to BrainState#
BrainState borrows many familiar ideas from PyTorch—tensor computations, modules with parameters, automatic differentiation—while leaning on JAX for JIT compilation and functional programming. This note contrasts the key building blocks so you can translate existing PyTorch workflows into BrainState idioms quickly.
Concept map#
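| PyTorch | BrainState | Notes |
|---|---|---|
| `torch.Tensor` | `jax.Array` | Manipulated with `jax.numpy`. |
| `nn.Module` | `brainstate.nn.Module` | Define `__call__` (the counterpart of `forward`). |
| `nn.Parameter` | `brainstate.ParamState` | Holds differentiable weights; retrieved via `model.states(brainstate.ParamState)`. |
| Autograd (`loss.backward()`) | `brainstate.transform.grad` | Explicitly select which states or arguments receive gradients. |
| `torch.optim` | `braintools.optim` | Works on state trees. |
| `torch.jit` | `brainstate.transform.jit` | JIT compile pure or stateful functions; integrates with JAX. |
| `state_dict()` | `brainstate.graph.treefy_states` / `brainstate.graph.update_states` | Serialize/restore state trees. |
| Random number generators | `brainstate.random` | Keys are JAX PRNGs, automatically split in transforms. |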
PyTorch baseline#
Consider a minimal linear regression in PyTorch:
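```python
import torch
import torch.nn as nn
import torch.optim as optim

# Synthetic data for y = 3x + 1 plus noise (added so the snippet runs on its own)
inputs = torch.rand(64, 1) * 2.0 - 1.0
targets = 3.0 * inputs + 1.0 + 0.1 * torch.randn(64, 1)


class TorchLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)


model = TorchLinear()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-1)

for step in range(100):
    optimizer.zero_grad()             # clear gradients accumulated on the parameters
    preds = model(inputs)
    loss = criterion(preds, targets)
    loss.backward()                   # populate .grad on every parameter
    optimizer.step()                  # update the parameters in place
```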
BrainState follows the same logic but makes states and gradients explicit.
BrainState equivalent#
```python
import braintools.file
import jax
import jax.numpy as jnp
import numpy as np
import brainstate
from brainstate.transform import grad, jit
import braintools.optim as optim


# Synthetic dataset
def make_dataset(n=64):
    rng = np.random.default_rng(0)
    x = rng.uniform(-1.0, 1.0, (n, 1)).astype(np.float32)
    y = 3.0 * x + 1.0 + rng.normal(0.0, 0.1, (n, 1)).astype(np.float32)
    return jnp.asarray(x), jnp.asarray(y)


x_train, y_train = make_dataset()


class LinearModel(brainstate.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        k1, k2 = jax.random.split(jax.random.PRNGKey(0))
        # Trainable weights are wrapped in ParamState objects
        self.weight = brainstate.ParamState(jax.random.normal(k1, (in_features, out_features)))
        self.bias = brainstate.ParamState(jax.random.normal(k2, (out_features,)))

    def __call__(self, x):
        return x @ self.weight.value + self.bias.value


model = LinearModel(1, 1)
# Collect the trainable states and hand them to the optimizer
params = model.states(brainstate.ParamState)
optimizer = optim.SGD(lr=1e-1)
optimizer.register_trainable_weights(params)


@jit
def train_step(x, y):
    def loss_fn():
        preds = model(x)
        return jnp.mean((preds - y) ** 2)

    # Differentiate with respect to the listed states and also return the loss value
    (grads, loss) = grad(loss_fn, grad_states=params, return_value=True)()
    optimizer.update(grads)
    return loss


for step in range(200):
    loss = train_step(x_train, y_train)
    if step % 40 == 0:
        print(f"step {step:3d}, loss = {float(loss):.4f}")


@jit
def predict(x):
    return model(x)


print('predictions for x=0:', predict(jnp.array([[0.0]])))
```

```
step   0, loss = 13.1320
step  40, loss = 0.0144
step  80, loss = 0.0097
step 120, loss = 0.0097
step 160, loss = 0.0097
predictions for x=0: [[1.0059681]]
```
Key observations#
- Parameters are stored in `ParamState` objects, so gradients are a tree keyed by state paths (`params.to_flat()` mirrors `state_dict()`).
- `grad` explicitly lists `grad_states`; argument gradients can be included via `argnums` (similar to PyTorch’s manual `requires_grad`).
- Optimisers work on state trees instead of implicit parameter lists.
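To make the "gradients are a tree" observation concrete, here is a minimal sketch in plain JAX rather than BrainState. It assumes parameters held in an ordinary dict; the names `params`, `loss_fn`, `sgd_step`, and `LR` are illustrative and not part of any BrainState API. The gradient comes back with the same keys as the parameters, and the update is a tree-wide map.

```python
import jax
import jax.numpy as jnp

LR = 0.1  # illustrative learning rate

# Parameters live in an ordinary dict (a pytree); the key names are hypothetical.
params = {"weight": jnp.array([[0.5]]), "bias": jnp.array([0.0])}


def loss_fn(params, x, y):
    preds = x @ params["weight"] + params["bias"]
    return jnp.mean((preds - y) ** 2)


@jax.jit
def sgd_step(params, x, y):
    # value_and_grad returns the loss and a gradient tree shaped like `params`.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Plain SGD: apply the same update rule to every leaf of the tree.
    new_params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return new_params, loss


# Noise-free data from y = 3x + 1
x = jnp.linspace(-1.0, 1.0, 16).reshape(-1, 1)
y = 3.0 * x + 1.0

losses = []
for _ in range(200):
    params, loss = sgd_step(params, x, y)
    losses.append(float(loss))

print("parameter keys:", sorted(params.keys()))
print("first loss:", losses[0], "last loss:", losses[-1])
```

In the BrainState version above, `grad_states=params` plays the role of the explicit `params` argument here, and the optimizer performs the tree-wide update.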
Saving and loading state#
```python
state_tree = brainstate.graph.treefy_states(model)
print('stored keys:', list(state_tree.to_flat().keys()))

# Later (or in another process):
restored = LinearModel(1, 1)
brainstate.graph.update_states(restored, state_tree)
print('restored weight:', restored.weight.value)
```

```
stored keys: [('bias',), ('weight',)]
restored weight: [[3.0168793]]
```
Alternatively, you can use braintools.file.msgpack_save and braintools.file.msgpack_load.
```python
braintools.file.msgpack_save('example.msgpack', model.states(brainstate.ParamState))
```

```
Saving checkpoint into example.msgpack
```
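The tuple keys printed by `to_flat()` above, such as `('weight',)`, are paths into a nested tree. The following pure-Python sketch shows that idea under simple assumptions: the tree is a nested dict, and `flatten_tree` and `unflatten_tree` are hypothetical helpers written for illustration, not BrainState functions.

```python
def flatten_tree(tree, prefix=()):
    """Flatten a nested dict into a {path_tuple: leaf} mapping."""
    flat = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten_tree(value, path))
        else:
            flat[path] = value
    return flat


def unflatten_tree(flat):
    """Rebuild the nested dict from a {path_tuple: leaf} mapping."""
    tree = {}
    for path, leaf in flat.items():
        node = tree
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = leaf
    return tree


# Illustrative values only; a real state tree would hold arrays.
nested = {"linear": {"weight": [[3.0]], "bias": [1.0]}, "scale": 0.1}
flat = flatten_tree(nested)
restored = unflatten_tree(flat)

print(sorted(flat.keys()))
print(restored == nested)
```

A flat mapping like this is convenient for serialization because every leaf has a stable, human-readable key, much like the entries of a PyTorch `state_dict()`.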
Gradients with additional arguments#
Below, we take derivatives w.r.t. both model parameters and an explicit scalar.
```python
scale = jnp.array(0.1)


def scaled_loss(alpha, inputs, targets):
    preds = model(inputs)
    mse = jnp.mean((preds - targets) ** 2)
    return mse + alpha * jnp.sum(model.weight.value ** 2)


(grads_state, alpha_grad), loss_val = grad(
    scaled_loss,
    grad_states=params,
    argnums=0,
    return_value=True,
)(scale, x_train, y_train)

print('loss:', float(loss_val))
print('grad w.r.t alpha:', float(alpha_grad))
for path, g in grads_state.items():
    print(path, g.shape)
```

```
loss: 0.9198333024978638
grad w.r.t alpha: 9.101560592651367
('bias',) (1,)
('weight',) (1, 1)
```
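For comparison, the same pattern can be sketched in plain JAX, where the parameters are passed as an ordinary argument and `argnums` selects both the parameter dict and the scalar. This is an illustrative analogue under assumed names (`params`, `alpha`), not the BrainState call itself. The data are chosen so the fit is exact, which makes the gradient with respect to `alpha` simply the sum of squared weights.

```python
import jax
import jax.numpy as jnp


def scaled_loss(params, alpha, x, y):
    preds = x @ params["weight"] + params["bias"]
    mse = jnp.mean((preds - y) ** 2)
    return mse + alpha * jnp.sum(params["weight"] ** 2)


# Hypothetical values: the model y = 2x + 0.5 fits the data exactly, so mse = 0.
params = {"weight": jnp.array([[2.0]]), "bias": jnp.array([0.5])}
alpha = jnp.array(0.1)
x = jnp.array([[0.0], [1.0]])
y = jnp.array([[0.5], [2.5]])

# argnums=(0, 1) requests gradients for both `params` and `alpha`.
loss, (param_grads, alpha_grad) = jax.value_and_grad(scaled_loss, argnums=(0, 1))(
    params, alpha, x, y
)

print("loss:", float(loss))
print("grad w.r.t. alpha:", float(alpha_grad))
for name, g in param_grads.items():
    print(name, g.shape)
```

Here the loss reduces to `alpha * sum(weight ** 2)`, so the gradient with respect to `alpha` equals `sum(weight ** 2)` and the weight gradient equals `2 * alpha * weight`.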
Random numbers#
BrainState wraps JAX PRNG keys. Use brainstate.random.seed to set the global
seed, or instantiate a RandomState for module-specific randomness. Transforms
like vmap and pmap split keys automatically per batch element.
```python
import brainstate.random as brandom

brandom.seed(42)
rs = brandom.RandomState()
print('single sample:', rs.normal(size=(2,)))
```

```
single sample: [ 0.6630465 -0.72396195]
```
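Since BrainState wraps JAX PRNG keys, it may help to see what key splitting looks like when done by hand in plain JAX. The sketch below is an assumption-labelled illustration of the underlying mechanism, not of BrainState's internals: one key is split into independent subkeys, and a batch of keys is mapped over with `jax.vmap` so each batch element draws from its own stream.

```python
import jax

key = jax.random.PRNGKey(42)

# Split once into a carry-over key and two subkeys for independent draws.
key, sub_a, sub_b = jax.random.split(key, 3)
a = jax.random.normal(sub_a, (2,))
b = jax.random.normal(sub_b, (2,))

# Reusing a key reproduces the same sample; JAX keys are never mutated.
a_again = jax.random.normal(sub_a, (2,))

# One key per batch element, consumed inside vmap.
batch_keys = jax.random.split(key, 4)
samples = jax.vmap(lambda k: jax.random.normal(k, (2,)))(batch_keys)

print("same key, same sample:", bool((a == a_again).all()))
print("different keys differ:", bool((a != b).any()))
print("batched samples shape:", samples.shape)
```

The manual `split` before `vmap` is the bookkeeping that the section above says BrainState handles automatically inside its transforms.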
Debugging and JIT#
BrainState leans on JAX’s tooling. brainstate.transform.jit works on stateful
functions, while brainstate.transform.make_jaxpr inspects the computed graph.
```python
from brainstate.transform import make_jaxpr

jaxpr = make_jaxpr(model)
print(jaxpr(jnp.ones((1,))))
```

```
({ lambda ; a:f32[1] b:f32[1,1] c:f32[1]. let
    d:f32[1] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] a b
    e:f32[1] = add d c
  in (e, b, c) }, (ParamState(
  value=ShapedArray(float32[1,1])
), ParamState(
  value=ShapedArray(float32[1])
)))
```
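In the output above, the two parameter states appear as extra inputs (`b`, `c`) and extra outputs of the traced function. A pure-function analogue in plain JAX makes that reading easier: the sketch below passes the weight and bias explicitly, so the resulting jaxpr should contain the same `dot_general` and `add` primitives. The function name `affine` and the input values are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def affine(w, b, x):
    # Same computation as the model's __call__, with the states made explicit.
    return x @ w + b


w = jnp.ones((1, 1))
b = jnp.zeros((1,))
x = jnp.ones((1,))

# Trace the function without running it on a device and inspect the graph.
jaxpr_text = str(jax.make_jaxpr(affine)(w, b, x))
print(jaxpr_text)

# The jitted function should agree with the eager one.
jitted = jax.jit(affine)
print(jitted(w, b, x), affine(w, b, x))
```

Reading a jaxpr this way is a quick check that a function traces to the primitives you expect before you wrap it in `jit`.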
Summary#
- Replace `nn.Module` + `nn.Parameter` with `brainstate.nn.Module` + `ParamState`.
- Use `brainstate.transform.grad` / `jit` instead of PyTorch autograd and scripting.
- Retrieve and update parameter trees via `graph.treefy_states` and `graph.update_states`.
- Optimisers in `braintools.optim` mirror the familiar PyTorch API, operating on state dictionaries.
With these substitutions most PyTorch training loops can be ported one module at a time to BrainState.