Training Artificial Neural Networks
In recent years, artificial neural networks have developed rapidly and now play an important role in neuroscience research. As a high-performance computational framework for brain dynamics modeling, brainstate also supports training artificial neural networks, which makes it straightforward to combine neural dynamics models with them.
This tutorial shows how to train an artificial neural network with brainstate, using a simple two-layer multilayer perceptron (MLP) that classifies handwritten digits from the MNIST dataset.
```python
import jax.numpy as jnp
import numpy as np
from datasets import load_dataset  # Hugging Face datasets, used to download MNIST

import braintools
import brainstate
from braintools.metric import softmax_cross_entropy_with_integer_labels

brainstate.__version__
```
'0.2.3'
Preparing the Dataset
First, we download MNIST and binarize the images. We then wrap the arrays in an iterable that reshuffles the data at the start of each epoch and yields mini-batches of a fixed size.
```python
# Download MNIST from the Hugging Face Hub
dataset = load_dataset('mnist')

# Convert the PIL images to uint8 arrays of shape (N, 28, 28)
X_train = np.stack(dataset['train']['image']).astype(np.uint8)
X_test = np.stack(dataset['test']['image']).astype(np.uint8)

# Binarize: every non-zero pixel becomes 1.0, the background stays 0.0
X_train = (X_train > 0).astype(np.float32)
X_test = (X_test > 0).astype(np.float32)

Y_train = np.array(dataset['train']['label'], dtype=np.int32)
Y_test = np.array(dataset['test']['label'], dtype=np.int32)


class Dataset:
    """Iterates over (X, Y) in mini-batches, reshuffling at the start of every epoch."""

    def __init__(self, X, Y, batch_size, shuffle=True):
        self.X = X
        self.Y = Y
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = np.arange(len(X))
        self.current_index = 0

    def __iter__(self):
        # Start a new epoch: reset the cursor and optionally reshuffle
        self.current_index = 0
        if self.shuffle:
            np.random.shuffle(self.indices)
        return self

    def __next__(self):
        # Stop once every sample has been served
        if self.current_index >= len(self.X):
            raise StopIteration
        # Define the current batch; the last one may be smaller than batch_size
        start = self.current_index
        end = min(start + self.batch_size, len(self.X))
        self.current_index = end
        batch_indices = self.indices[start:end]
        return self.X[batch_indices], self.Y[batch_indices]


# Initialize training and testing datasets
batch_size = 32
train_dataset = Dataset(X_train, Y_train, batch_size, shuffle=True)
test_dataset = Dataset(X_test, Y_test, batch_size, shuffle=False)
```
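The Dataset class above is one way to batch the data. The same shuffle-then-slice logic can also be written as a plain generator function. The sketch below uses a hypothetical helper, `iterate_batches`, which is not part of brainstate or braintools. It draws permutations from a seeded `np.random.default_rng` rather than the global `np.random.shuffle`, and it runs on small synthetic arrays in place of MNIST so that it works on its own.

```python
import numpy as np


def iterate_batches(X, Y, batch_size, shuffle=True, rng=None):
    """Hypothetical helper: yield (X, Y) mini-batches with a fresh permutation per call."""
    rng = np.random.default_rng() if rng is None else rng
    indices = np.arange(len(X))
    if shuffle:
        rng.shuffle(indices)
    for start in range(0, len(X), batch_size):
        batch_indices = indices[start:start + batch_size]  # the last slice may be short
        yield X[batch_indices], Y[batch_indices]


# Synthetic stand-in for MNIST: 10 blank 28x28 "images" with labels 0..9
X = np.zeros((10, 28, 28), dtype=np.float32)
Y = np.arange(10, dtype=np.int32)

batches = list(iterate_batches(X, Y, batch_size=4, rng=np.random.default_rng(0)))
sizes = [len(by) for _, by in batches]
print(sizes)  # [4, 4, 2]

# Every label appears exactly once across the batches of one epoch
seen = np.sort(np.concatenate([by for _, by in batches]))
```

Each call creates a new permutation. Writing `for batch in iterate_batches(X_train, Y_train, 32)` inside the epoch loop therefore reshuffles every epoch, just as `Dataset.__iter__` does.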
Defining the Artificial Neural Network
To define an artificial neural network in brainstate, subclass brainstate.nn.Module. Create the network's layers in __init__(), calling super().__init__() first to initialize the base class. Define the forward pass in __call__().
Custom layers are defined the same way: each one is also a subclass of brainstate.nn.Module.
Every quantity that changes while the model runs must be wrapped in a State object. Trainable parameters go in the State subclass ParamState. Other mutable quantities that the optimizer should not update, such as the call counter below, go in another subclass, ShortTermState.
```python
# Define a linear layer
class Linear(brainstate.nn.Module):
    def __init__(self, din: int, dout: int):
        super().__init__()
        # Weights drawn uniformly from [0, 1), so every weight starts positive.
        # A zero-mean initialization scaled by 1/sqrt(din) (e.g. standard-normal
        # samples divided by jnp.sqrt(din)) keeps the initial logits much smaller.
        self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
        self.b = brainstate.ParamState(jnp.zeros((dout,)))  # Bias, initialized to zero

    def __call__(self, x):
        return x @ self.w.value + self.b.value  # Affine transformation


# Define a short-term state that counts how many times the model is called
class Count(brainstate.ShortTermState):
    pass


# Define the MLP model
class MLP(brainstate.nn.Module):
    def __init__(self, din, dhidden, dout):
        super().__init__()
        self.count = Count(jnp.array(0))  # Number of forward passes so far
        # brainstate ships standard layers, so this could also be written as
        # self.linear1 = brainstate.nn.Linear(din, dhidden)
        self.linear1 = Linear(din, dhidden)
        self.linear2 = Linear(dhidden, dout)
        self.flatten = brainstate.nn.Flatten(start_axis=1)  # (batch, 28, 28) -> (batch, 784)
        self.relu = brainstate.nn.ReLU()  # ReLU activation function

    def __call__(self, x):
        self.count.value += 1  # Increment the call count
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.relu(x)  # JAX functions also work here, e.g. x = jax.nn.relu(x)
        x = self.linear2(x)
        return x


# 28*28 input pixels -> 512 hidden units -> 10 output classes
model = MLP(din=28*28, dhidden=512, dout=10)
```
Optimizer Setup
The braintools.optim module provides a variety of optimizers.
After instantiating an optimizer, tell it which parameters to update by calling optimizer.register_trainable_weights().
Here, model.states() collects the State objects of the model and all of its submodules. Passing brainstate.ParamState restricts the result to trainable parameters. The model also holds states of other types, such as the Count instance, which the optimizer must not update.
```python
# Initialize the optimizer and register the model parameters
optimizer = braintools.optim.SGD(lr=1e-3)  # SGD with learning rate 1e-3
optimizer.register_trainable_weights(model.states(brainstate.ParamState))  # Parameters to optimize
```
SGD(
momentum=0.0,
nesterov=False,
param_states=<braintools.optim.UniqueStateManager object at 0x000001A3AF83CF50>,
weight_decay=0.0,
grad_clip_norm=None,
grad_clip_value=None,
step_count=OptimState(
value=ShapedArray(int32[], weak_type=True)
),
param_groups=[
{
'params': {
('linear1', 'b'): ParamState(
value=ShapedArray(float32[512])
),
('linear1', 'w'): ParamState(
value=ShapedArray(float32[784,512])
),
('linear2', 'b'): ParamState(
value=ShapedArray(float32[10])
),
('linear2', 'w'): ParamState(
value=ShapedArray(float32[512,10])
)
},
'lr': OptimState(
value=ShapedArray(float32[], weak_type=True)
),
'weight_decay': 0.0
}
],
param_groups_opt_states=[],
_schedulers=[],
_lr_scheduler=<braintools.optim.ConstantLR object at 0x000001A3AF83EC90>,
_base_lr=0.001,
_current_lr=OptimState(...),
tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x000001A3AE36CB80>, update=<function chain.<locals>.update_fn at 0x000001A3AE36C860>),
opt_state=OptimState(
value=(ScaleByScheduleState(count=ShapedArray(int32[])),)
)
)
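The output above shows momentum and weight decay both set to 0.0. With those settings, plain SGD moves each parameter against its gradient: w ← w − lr · ∂L/∂w. The minimal pure-JAX sketch below applies that rule to a dict of parameters to show what one update step does. It is not braintools' implementation, and `sgd_step` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def sgd_step(params, grads, lr):
    """Hypothetical helper: vanilla SGD (no momentum, no weight decay) on a pytree."""
    # Apply p - lr * g leaf by leaf; returns new arrays, inputs are unchanged
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}
grads = {'w': jnp.array([0.5, -1.0]), 'b': jnp.array(2.0)}

new_params = sgd_step(params, grads, lr=0.1)
print(new_params['w'], new_params['b'])
```

In the tutorial, braintools.optim.SGD plays this role: `update(grads)` writes the new values back into the registered ParamState objects, so the training code never handles the parameter arrays directly.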
Model Training
During training, gradients are computed with brainstate.transform.grad. It takes the loss function and the State objects to differentiate with respect to, and returns a function that evaluates the gradients.
The gradients are then passed to the optimizer's update() method, which updates the registered parameters.
To speed up training, the training step is decorated with brainstate.transform.jit, which compiles it just in time.
```python
# Training step function, compiled with JIT
@brainstate.transform.jit
def train_step(batch):
    x, y = batch

    # Mean cross-entropy loss of the model on this batch
    def loss_fn():
        return softmax_cross_entropy_with_integer_labels(model(x), y).mean()

    # Gradients of the loss with respect to every ParamState in the model
    grads = brainstate.transform.grad(loss_fn, model.states(brainstate.ParamState))()
    optimizer.update(grads)  # Update the registered parameters
```
Model Testing
The testing step is also decorated with brainstate.transform.jit. It runs a forward pass, computes the mean cross-entropy loss, and counts the correct predictions in the batch.
```python
# Testing step function, compiled with JIT
@brainstate.transform.jit
def test_step(batch):
    x, y = batch
    y_pred = model(x)  # Forward pass
    loss = softmax_cross_entropy_with_integer_labels(y_pred, y).mean()  # Mean loss
    correct = (y_pred.argmax(1) == y).sum()  # Number of correct predictions
    return {'loss': loss, 'correct': correct}
```
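Both metrics are easy to compute by hand. By its standard definition, softmax cross-entropy with integer labels is the negative log-softmax probability of the true class. A prediction counts as correct when the argmax of the logits equals the label. The sketch below computes both with `jax.nn.log_softmax` on made-up logits. It assumes braintools' `softmax_cross_entropy_with_integer_labels` follows this standard definition, and `manual_metrics` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def manual_metrics(logits, labels):
    """Hypothetical helper mirroring the loss and accuracy count in test_step."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Log-probability assigned to the true class of each sample
    true_class_log_probs = jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]
    loss = -true_class_log_probs.mean()
    correct = (logits.argmax(axis=1) == labels).sum()
    return {'loss': loss, 'correct': correct}


logits = jnp.array([[2.0, 0.0, 0.0],   # predicts class 0
                    [0.0, 0.0, 0.0],   # all-equal logits: argmax picks class 0
                    [0.0, 0.0, 5.0]])  # predicts class 2
labels = jnp.array([0, 1, 2])
logs = manual_metrics(logits, labels)
print(logs)
```

The second row has uniform logits, so its loss term is log 3 and its argmax falls on class 0 rather than the label 1. Two of the three predictions are therefore correct.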
Training Process
With all the pieces in place, the loop below trains the model for 10 epochs and reports the test loss and accuracy after each epoch.
```python
# Train for 10 epochs, evaluating on the test set after each one
for epoch in range(10):
    for batch in train_dataset:
        train_step(batch)  # One optimization step per mini-batch

    # Accumulate the per-batch mean losses and the number of correct predictions
    test_loss, correct, num_batches = 0.0, 0, 0
    for batch in test_dataset:
        logs = test_step(batch)
        test_loss += logs['loss']
        correct += logs['correct']
        num_batches += 1
    test_loss = test_loss / num_batches  # Average of the per-batch mean losses
    test_accuracy = correct / len(X_test)
    print(f"epoch: {epoch}, test loss: {test_loss}, test accuracy: {test_accuracy}")

print('times model called:', model.count.value)  # Total number of forward passes
```
epoch: 0, test loss: 448.7932434082031, test accuracy: 0.26989999413490295
epoch: 1, test loss: 150.93972778320312, test accuracy: 0.6607000231742859
epoch: 2, test loss: 105.59516906738281, test accuracy: 0.7075999975204468
epoch: 3, test loss: 80.24203491210938, test accuracy: 0.7462999820709229
epoch: 4, test loss: 120.68618774414062, test accuracy: 0.7113999724388123
epoch: 5, test loss: 39.29928207397461, test accuracy: 0.8547999858856201
epoch: 6, test loss: 89.67916107177734, test accuracy: 0.7944999933242798
epoch: 7, test loss: 53.42087173461914, test accuracy: 0.8274000287055969
epoch: 8, test loss: 35.70460510253906, test accuracy: 0.8694000244140625
epoch: 9, test loss: 37.64791488647461, test accuracy: 0.8648999929428101
times model called: 21880
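The reported call count can be checked by hand. Every call to train_step or test_step runs exactly one forward pass. MNIST has 60,000 training and 10,000 test images, so with batch_size = 32 the training set splits into exactly 1,875 batches. The test set needs 313 batches, the last of which holds only 16 images. Ten epochs therefore give 10 × (1875 + 313) = 21,880 calls, matching the output above. The snippet reproduces this arithmetic with ceiling division.

```python
import math

train_size, test_size = 60_000, 10_000  # MNIST split sizes
batch_size = 32
epochs = 10

train_batches = math.ceil(train_size / batch_size)  # divides evenly
test_batches = math.ceil(test_size / batch_size)    # one partial batch at the end
last_test_batch = test_size - (test_batches - 1) * batch_size

total_calls = epochs * (train_batches + test_batches)
print(train_batches, test_batches, last_test_batch, total_calls)  # 1875 313 16 21880
```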