Getting Started#
Welcome to BrainState! This tutorial will guide you through the basics of using BrainState, a state-based transformation system designed for brain modeling and neural network programming.
By the end of this tutorial, you will:
Understand what BrainState is and why it’s useful
Know how to install and set up BrainState
Learn the core concepts and design philosophy
Build your first simple neural network with BrainState
What is BrainState?#
BrainState is a powerful Python library built on top of JAX that provides:
🧠 Stateful Programming Model: Manage mutable states in a JAX-compatible way
🚀 High Performance: Leverage JAX’s JIT compilation, automatic differentiation, and vectorization
🔧 Modular Design: Build complex models from simple, composable components
🌐 Brain Modeling: Specialized tools for computational neuroscience and brain-inspired computing
BrainState bridges the gap between the functional programming paradigm of JAX and the intuitive, stateful programming style commonly used in neural network frameworks.
Installation and Environment Setup#
Prerequisites#
Before installing BrainState, ensure you have:
Python 3.9 or higher
pip package manager
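A quick way to check the Python prerequisite before installing is sketched below. It uses only the standard library; the name `MIN_PYTHON` is chosen here for illustration.

```python
import sys

# Minimum version stated in the prerequisites above
MIN_PYTHON = (3, 9)

# sys.version_info compares element-wise against a plain tuple
meets_requirement = sys.version_info >= MIN_PYTHON

print("Running Python", ".".join(str(part) for part in sys.version_info[:3]))
print("Meets the 3.9+ requirement:", meets_requirement)
```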
Installing BrainState#
The easiest way to install BrainState is via pip:
pip install brainstate --upgrade
Installing the Complete Ecosystem#
For a complete brain modeling ecosystem, you can install BrainX, which bundles BrainState with other compatible packages:
pip install BrainX -U
This includes:
brainstate: Core state management and transformations
brainunit: Physical units and dimensional analysis
braintools: Optimization algorithms and utilities
brainpy: Spiking neural network modeling
Verifying Installation#
Let’s verify that BrainState is installed correctly:
import brainstate
import braintools
import jax.numpy as jnp
print(f"BrainState version: {brainstate.__version__}")
print("Installation successful! ✓")
BrainState version: 0.2.3
Installation successful! ✓
Core Concepts Overview#
BrainState is built around several key concepts that work together to enable stateful, high-performance neural network programming.
1. State: Managing Mutable Variables#
In pure functional programming, the style JAX follows, all data is immutable. However, neural networks and brain models inherently involve mutable states (e.g., neuron membrane potentials, network weights).
BrainState’s State provides a solution by wrapping mutable variables in a way that’s compatible with JAX transformations.
# Creating a State object
voltage = brainstate.State(jnp.array([0.0, -70.0, -55.0]))
print("Initial voltage:", voltage.value)
# Updating the state
voltage.value = voltage.value + 10.0
print("Updated voltage:", voltage.value)
Initial voltage: [ 0. -70. -55.]
Updated voltage: [ 10. -60. -45.]
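For contrast, here is a minimal sketch of the same kind of update in plain JAX, without a State wrapper. It uses only `jax.numpy` and nothing from BrainState. It shows that in-place assignment on a JAX array is rejected, while the functional `.at[...].set(...)` form returns a new array and leaves the original untouched.

```python
import jax.numpy as jnp

voltage = jnp.array([0.0, -70.0, -55.0])

# JAX arrays are immutable: item assignment raises TypeError
try:
    voltage[0] = 10.0
    in_place_ok = True
except TypeError:
    in_place_ok = False

# The functional alternative builds a new array instead of mutating
updated = voltage.at[0].set(10.0)

print("In-place assignment allowed:", in_place_ok)
print("Original:", voltage)
print("Updated copy:", updated)
```

With plain arrays, every function that changes a value has to return the new array and every caller has to keep track of it. Holding the array inside an object and assigning to `.value`, as in the example above, removes that bookkeeping from user code.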
Key Types of States:
State: Generic mutable state
ParamState: Trainable parameters (weights, biases)
HiddenState: Hidden activations (membrane potentials, hidden layer outputs)
ShortTermState: Temporary states (spike times, current values)
LongTermState: Long-term states (running statistics, momentum)
We’ll explore these in detail in the next tutorial.
2. Module: Building Blocks of Neural Networks#
The Module class (which derives from graph.Node) is the base class for all neural network components in BrainState. It automatically manages states and provides a clean interface for building complex models.
class SimpleNeuron(brainstate.nn.Module):
    """A simple leaky integrate-and-fire neuron."""

    def __init__(self, threshold=1.0, reset=0.0, tau=10.0):
        super().__init__()
        self.threshold = threshold
        self.reset = reset
        self.tau = tau
        # Membrane potential is a hidden state
        self.V = brainstate.HiddenState(jnp.array(0.0))

    def __call__(self, I_input):
        """Update neuron state given input current."""
        # Leaky integration
        dV = (-self.V.value + I_input) / self.tau
        self.V.value = self.V.value + dV
        # Spike and reset
        spike = self.V.value >= self.threshold
        self.V.value = jnp.where(spike, self.reset, self.V.value)
        return spike

# Create and test the neuron
neuron = SimpleNeuron()
print("Initial voltage:", neuron.V.value)

# Simulate with input current
for t in range(20):
    spike = neuron(2.0)  # constant input
    if spike:
        print(f"Spike at time {t}! V={neuron.V.value}")
Initial voltage: 0.0
Spike at time 6! V=0.0
Spike at time 13! V=0.0
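The spike times in this output can be checked by hand. Starting from V = 0 with a constant input I, the update V ← V + (I − V)/tau gives V after n steps as I·(1 − (1 − 1/tau)^n), so the first spike happens at the smallest n for which this reaches the threshold. The sketch below is plain Python with no BrainState dependency; the helper names `steps_to_first_spike` and `simulate_spike_times` are hypothetical.

```python
import math

def steps_to_first_spike(I_input, threshold=1.0, tau=10.0):
    """Smallest n with I * (1 - (1 - 1/tau)**n) >= threshold, starting from V = 0."""
    if I_input <= threshold:
        raise ValueError("a constant input at or below threshold never triggers a spike")
    decay = 1.0 - 1.0 / tau
    return math.ceil(math.log(1.0 - threshold / I_input) / math.log(decay))

def simulate_spike_times(I_input, n_steps, threshold=1.0, reset=0.0, tau=10.0):
    """Step the same update rule directly and record the spiking time indices."""
    V, times = 0.0, []
    for t in range(n_steps):
        V += (-V + I_input) / tau
        if V >= threshold:
            times.append(t)
            V = reset
    return times

n = steps_to_first_spike(2.0)
print(n)                               # 7 steps, so the first spike is at t = 6
print(simulate_spike_times(2.0, 20))   # [6, 13]
```

Because the neuron resets to 0 after each spike, every cycle repeats the first one, which is why the spikes are exactly 7 steps apart.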
3. Transform: JAX Transformations with States#
BrainState provides state-aware versions of JAX transformations:
brainstate.transform.jit: Just-in-time compilation
brainstate.transform.grad: Automatic differentiation
brainstate.transform.vmap: Vectorization (batching)
brainstate.transform.scan: Efficient loops
These transformations automatically handle state management for you.
# Reset neuron
neuron.V.value = jnp.array(0.0)
# Simulate with varying input
inputs = jnp.array([1.5, 2.0, 2.5, 3.0, 1.0] * 4)
spikes = brainstate.transform.for_loop(neuron, inputs)
print("Spike train:", spikes.astype(int))
Spike train: [0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0]
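To see what "handling state for you" saves, here is a sketch of the same simulation written against plain `jax.lax.scan`, where the membrane potential has to be passed in and returned explicitly as the loop carry. This is an illustration of the functional pattern, not a description of how BrainState implements `for_loop`; the function name `lif_step` is hypothetical.

```python
import jax
import jax.numpy as jnp

def lif_step(V, I_input, threshold=1.0, reset=0.0, tau=10.0):
    """One LIF update as a pure function: (old V, input) -> (new V, spike)."""
    V = V + (-V + I_input) / tau
    spike = V >= threshold
    V = jnp.where(spike, reset, V)
    return V, spike

inputs = jnp.array([1.5, 2.0, 2.5, 3.0, 1.0] * 4)

# scan threads the carry (V) through the steps and stacks the per-step outputs
V_final, spikes = jax.lax.scan(lif_step, jnp.zeros(()), inputs)

spike_indices = [int(i) for i in jnp.nonzero(spikes)[0]]
print("Spike train:", spikes.astype(int))
print("Spike indices:", spike_indices)  # [7, 13]
```

The spike indices agree with the spike train printed above. The difference is only in bookkeeping: here the caller owns `V`, whereas the Module version keeps it inside the neuron.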
4. Random: Stateful Random Number Generation#
BrainState provides a stateful random number generator that’s compatible with JAX’s functional random number generation while maintaining a simple, NumPy-like interface.
# Set random seed for reproducibility
brainstate.random.seed(42)
# Generate random numbers
uniform_samples = brainstate.random.rand(5)
normal_samples = brainstate.random.randn(5)
print("Uniform samples:", uniform_samples)
print("Normal samples:", normal_samples)
Uniform samples: [0.72766423 0.78786755 0.18169427 0.26263022 0.11072934]
Normal samples: [-0.21089035 -1.3627948 -0.04500385 -1.1536394 1.9141139 ]
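For comparison, this is a sketch of how plain JAX handles randomness: every sampling call takes an explicit key, and fresh keys come from splitting an existing one. Only `jax.random` is used here; the sampled values will differ from the BrainState output above, so none are shown.

```python
import jax
import jax.numpy as jnp

# An explicit key replaces the global seed
key = jax.random.PRNGKey(42)

# Split once per independent draw; keep `key` for later splits
key, k_uniform, k_normal = jax.random.split(key, 3)

uniform_samples = jax.random.uniform(k_uniform, (5,))
normal_samples = jax.random.normal(k_normal, (5,))

# Reusing a key reproduces the same numbers exactly
repeat = jax.random.uniform(k_uniform, (5,))

print("Uniform samples:", uniform_samples)
print("Normal samples:", normal_samples)
print("Same key gives same samples:", bool(jnp.array_equal(uniform_samples, repeat)))
```

A stateful generator with a NumPy-like interface spares user code from carrying and splitting keys by hand, while seeding still makes the results reproducible.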
Hello World: Building Your First Neural Network#
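Let’s build a simple feedforward neural network sized for handwritten-digit classification: 784 inputs (a flattened 28×28 image) and 10 output classes. To keep the example self-contained, it is trained on random dummy data rather than a real dataset. This example demonstrates the key concepts working together.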
Step 1: Define the Network#
class MLP(brainstate.nn.Module):
    """A simple multi-layer perceptron."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Initialize weights and biases as trainable parameters
        self.w1 = brainstate.ParamState(brainstate.random.randn(input_dim, hidden_dim) * 0.1)
        self.b1 = brainstate.ParamState(jnp.zeros(hidden_dim))
        self.w2 = brainstate.ParamState(brainstate.random.randn(hidden_dim, output_dim) * 0.1)
        self.b2 = brainstate.ParamState(jnp.zeros(output_dim))

    def __call__(self, x):
        """Forward pass through the network."""
        # Hidden layer with ReLU activation
        hidden = jnp.maximum(0, x @ self.w1.value + self.b1.value)
        # Output layer
        logits = hidden @ self.w2.value + self.b2.value
        return logits

# Create the network
brainstate.random.seed(0)
model = MLP(input_dim=784, hidden_dim=128, output_dim=10)
print("Network created!")
print(f"Total parameters: {784*128 + 128 + 128*10 + 10:,}")
Network created!
Total parameters: 101,770
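The hard-coded sum in the print statement can also be derived from the layer sizes, which keeps the count correct if the dimensions change. This is a small plain-Python sketch; the names `layer_sizes` and `count_dense_params` are hypothetical.

```python
def count_dense_params(layer_sizes):
    """Weights (n_in * n_out) plus biases (n_out) for each consecutive layer pair."""
    return sum(
        n_in * n_out + n_out
        for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
    )

layer_sizes = [784, 128, 10]  # input, hidden, output
total = count_dense_params(layer_sizes)
print(f"Total parameters: {total:,}")  # Total parameters: 101,770
```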
Step 2: Define Loss Function and Training Step#
def cross_entropy_loss(logits, labels):
    """Compute cross-entropy loss."""
    # One-hot encode labels
    one_hot_labels = jnp.eye(10)[labels]
    # Compute log-softmax
    log_probs = logits - jnp.log(jnp.sum(jnp.exp(logits), axis=-1, keepdims=True))
    # Compute loss
    loss = -jnp.mean(jnp.sum(one_hot_labels * log_probs, axis=-1))
    return loss

def accuracy(logits, labels):
    """Compute classification accuracy."""
    predictions = jnp.argmax(logits, axis=-1)
    return jnp.mean(predictions == labels)

def loss_fn(x, y):
    """Compute loss for the model."""
    logits = model(x)
    return cross_entropy_loss(logits, y)

# Generate dummy data for demonstration
brainstate.random.seed(42)
X_train = brainstate.random.randn(100, 784) * 0.1 # 100 samples
y_train = brainstate.random.randint(0, 10, 100) # Random labels
# Create gradient function
param_states = brainstate.transform.StateFinder(loss_fn, brainstate.ParamState)(X_train, y_train)
grad_fn = brainstate.transform.grad(loss_fn, grad_states=param_states)
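One caveat about the log-softmax written out in `cross_entropy_loss`: exponentiating the raw logits overflows in float32 once they grow large, which does not happen with the small values in this tutorial but can in real training. A sketch of a numerically safer variant using `jax.nn.log_softmax` follows. The function names `naive_cross_entropy` and `stable_cross_entropy` are hypothetical, and this is one possible alternative rather than what BrainState prescribes.

```python
import jax
import jax.numpy as jnp

def naive_cross_entropy(logits, labels):
    """Same formula as in the tutorial: exp() on raw logits."""
    one_hot = jnp.eye(logits.shape[-1])[labels]
    log_probs = logits - jnp.log(jnp.sum(jnp.exp(logits), axis=-1, keepdims=True))
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))

def stable_cross_entropy(logits, labels):
    """log_softmax subtracts the row maximum internally, so exp() cannot overflow."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick the log-probability of the true class for each sample
    picked = jnp.take_along_axis(log_probs, labels[:, None], axis=-1)
    return -jnp.mean(picked)

labels = jnp.array([0, 3, 5, 9])

# Equal logits mean a uniform prediction over 10 classes
uniform_loss = stable_cross_entropy(jnp.zeros((4, 10)), labels)

# Very large logits break the naive formula but not the stable one
big_logits = jnp.full((4, 10), 1000.0)
naive_big = naive_cross_entropy(big_logits, labels)
stable_big = stable_cross_entropy(big_logits, labels)

print("Uniform-prediction loss:", uniform_loss)
print("Naive loss on large logits is finite:", bool(jnp.isfinite(naive_big)))
print("Stable loss on large logits is finite:", bool(jnp.isfinite(stable_big)))
```

The uniform-prediction loss equals ln 10 ≈ 2.303, which is a useful reference point: a ten-class model that has learned nothing should report a loss near that value.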
Step 3: Training Loop#
optimizer = braintools.optim.SGD(1e-1)
_ = optimizer.register_trainable_weights(param_states)
@brainstate.transform.jit
def train_step(x, y):
    """Perform one training step."""
    # Compute gradients
    grads = grad_fn(x, y)
    # Update parameters using gradient descent
    optimizer.update(grads)
    # Compute metrics
    logits = model(x)
    loss = cross_entropy_loss(logits, y)
    acc = accuracy(logits, y)
    return loss, acc

# Training loop
print("Starting training...\n")
for epoch in range(10):
    loss, acc = train_step(X_train, y_train)
    if (epoch + 1) % 2 == 0:
        print(f"Epoch {epoch+1:2d}: Loss = {loss:.4f}, Accuracy = {acc:.4f}")
print("\nTraining complete!")
Starting training...
Epoch 2: Loss = 2.2960, Accuracy = 0.1000
Epoch 4: Loss = 2.2739, Accuracy = 0.1200
Epoch 6: Loss = 2.2529, Accuracy = 0.1500
Epoch 8: Loss = 2.2326, Accuracy = 0.1900
Epoch 10: Loss = 2.2130, Accuracy = 0.2300
Training complete!
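To make explicit what the gradient function and the optimizer do together, here is a sketch of the same kind of training step in plain JAX, with the parameters passed around as a dictionary instead of ParamState objects. It is not the BrainState or braintools implementation. The names `init_params`, `loss_fn`, and `sgd_step` are hypothetical, the random keys are arbitrary, and the printed losses will not match the numbers above.

```python
import jax
import jax.numpy as jnp

def init_params(key, input_dim=784, hidden_dim=128, output_dim=10):
    """Same shapes and 0.1 weight scale as the MLP in this tutorial."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (input_dim, hidden_dim)) * 0.1,
        "b1": jnp.zeros(hidden_dim),
        "w2": jax.random.normal(k2, (hidden_dim, output_dim)) * 0.1,
        "b2": jnp.zeros(output_dim),
    }

def loss_fn(params, x, y):
    """Forward pass plus cross-entropy, as a pure function of the parameters."""
    hidden = jnp.maximum(0, x @ params["w1"] + params["b1"])
    logits = hidden @ params["w2"] + params["b2"]
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(jnp.take_along_axis(log_probs, y[:, None], axis=-1))

@jax.jit
def sgd_step(params, x, y, lr=0.1):
    """One step of plain gradient descent: p <- p - lr * dL/dp."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# Dummy data with the same shapes as in the tutorial
k_params, k_x, k_y = jax.random.split(jax.random.PRNGKey(0), 3)
params = init_params(k_params)
x = jax.random.normal(k_x, (100, 784)) * 0.1
y = jax.random.randint(k_y, (100,), 0, 10)

losses = []
for _ in range(10):
    # The updated parameters must be returned and rebound on every step
    params, loss = sgd_step(params, x, y)
    losses.append(float(loss))

print(f"First loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

In this functional form the caller has to rebind `params` after every step, whereas in the BrainState version the update happens inside the states. As a sanity check on the reported numbers: a uniform guess over ten classes has a cross-entropy of ln 10 ≈ 2.303, so losses just below 2.30 on random labels mean the model has only begun to fit the 100 dummy samples.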
Step 4: Making Predictions#
@brainstate.transform.jit
def predict(x):
    """Make predictions with the model."""
    logits = model(x)
    return jnp.argmax(logits, axis=-1)

# Generate test data
X_test = brainstate.random.randn(10, 784) * 0.1
predictions = predict(X_test)
print("Predictions on test data:")
print(predictions)
Predictions on test data:
[3 9 7 7 7 7 7 9 5 0]
Key Takeaways#
Congratulations! You’ve just built your first neural network with BrainState. Here are the key concepts we covered:
States wrap mutable variables and make them compatible with JAX transformations
Modules (via nn.Module) provide a clean way to organize neural network components
Transformations like jit and grad work seamlessly with stateful code
Random number generation is stateful yet reproducible
What’s Next?#
Now that you understand the basics, continue with the following tutorials:
State Management - Deep dive into different types of states and advanced state management techniques
Random Number Generation - Learn about BrainState’s random number generation system
Neural Network Modules - Explore pre-built layers and learn to create custom modules
Program Transformations - Master JIT compilation, automatic differentiation, and vectorization
Happy coding with BrainState! 🧠✨