Creating Custom Models#
Guide to implementing custom neural mass models.
Model Template#
import brainstate
import brainunit as u
import jax.numpy as jnp
class MyCustomModel(brainstate.nn.Dynamics):
"""Custom neural mass model.
Args:
in_size: Number of regions/nodes
tau: Time constant (ms)
other_param: Description
"""
def __init__(self, in_size, tau=10.*u.ms, other_param=1.0):
super().__init__()
# Store parameters
self.in_size = in_size
self.tau = tau
self.other_param = other_param
# Initialize state variables
self.x = brainstate.HiddenState(
brainstate.init.Constant(0.),
sharding=brainstate.ShardingMode.REPLICATED
)
def init_state(self, batch_size=None):
"""Initialize state variables."""
shape = (batch_size, *self.in_size) if batch_size else self.in_size
self.x.value = jnp.zeros(shape)
def update(self, external_input=0.):
"""Update dynamics by one time step.
Args:
external_input: External driving input
Returns:
Observable output
"""
# Get current state
x = self.x.value
# Compute dynamics (example: dx/dt = -x/tau + input)
dx_dt = -x / self.tau + external_input
# Update state (Euler integration, dt assumed 1 ms)
dt = 1 * u.ms
self.x.value = x + dx_dt * dt
# Return observable
return self.x.value
Example: Custom Oscillator#
import brainstate
import brainunit as u
import jax.numpy as jnp
class DampedOscillator(brainstate.nn.Dynamics):
"""Damped harmonic oscillator.
Equations:
dx/dt = v
dv/dt = -omega^2 * x - 2*zeta*omega*v + F_ext
Args:
in_size: Number of oscillators
omega: Natural frequency (Hz)
zeta: Damping ratio (dimensionless)
"""
def __init__(self, in_size, omega=10*u.Hz, zeta=0.1):
super().__init__()
self.in_size = in_size
self.omega = omega
self.zeta = zeta
# State variables: position and velocity
self.x = brainstate.HiddenState(brainstate.init.Constant(0.))
self.v = brainstate.HiddenState(brainstate.init.Constant(0.))
def init_state(self, batch_size=None):
shape = (batch_size, *self.in_size) if batch_size else self.in_size
self.x.value = jnp.zeros(shape)
self.v.value = jnp.zeros(shape)
def update(self, F_ext=0.):
x = self.x.value
v = self.v.value
# Dynamics
dx_dt = v
dv_dt = -(self.omega**2) * x - 2*self.zeta*self.omega*v + F_ext
# Euler integration
dt = 1 * u.ms
self.x.value = x + dx_dt * dt
self.v.value = v + dv_dt * dt
return self.x.value
Adding Noise Support#
class MyModel(brainstate.nn.Dynamics):
def __init__(self, in_size, noise=None):
super().__init__()
self.in_size = in_size
self.noise = noise # Optional noise source
self.x = brainstate.HiddenState(brainstate.init.Constant(0.))
def init_state(self, batch_size=None):
shape = (batch_size, *self.in_size) if batch_size else self.in_size
self.x.value = jnp.zeros(shape)
# Initialize noise if present
if self.noise is not None:
self.noise.init_all_states(batch_size)
def update(self, external_input=0.):
x = self.x.value
# Dynamics
dx_dt = -x + external_input
# Add noise if present
if self.noise is not None:
noise_value = self.noise.update()
dx_dt += noise_value
# Integrate
dt = 1 * u.ms
self.x.value = x + dx_dt * dt
return self.x.value
Unit Handling#
Ensure dimensional consistency:
import brainunit as u
# Parameters with units
tau = 10 * u.ms
omega = 2 * jnp.pi * 10 * u.Hz
# Computations preserve units
frequency = 1 / tau # Result: Hz
period = 1 / omega # Result: seconds
Testing Your Model#
import pytest
import jax.numpy as jnp
def test_my_custom_model():
# Create model
model = MyCustomModel(in_size=10)
model.init_all_states()
# Test update
output = model.update(external_input=1.0)
# Check shape
assert output.shape == (10,)
# Test with batch
model.init_all_states(batch_size=32)
output = model.update()
assert output.shape == (32, 10)
Documentation#
Add comprehensive docstrings:
class MyModel(brainstate.nn.Dynamics):
"""One-line description.
Detailed explanation of the model, including:
- Mathematical equations
- Biological interpretation
- References to papers
Mathematical Details:
.. math::
\\frac{dx}{dt} = f(x, t)
Args:
in_size: Shape of input/number of regions
param1: Description with units
param2: Description with default behavior
Examples:
>>> model = MyModel(in_size=10)
>>> model.init_all_states()
>>> output = model.update()
References:
Author et al. (2020). Paper title. Journal.
"""
Contributing Your Model#
Implement model following template
Add tests in
tests/Create example notebook in
examples/Update API documentation
Submit pull request
See Also#
Architecture - Package structure
Extending Noise Processes - Custom noise processes
Contributing Guide - Contribution guidelines
brainstatedocumentation