Training

Contents

Training

class brainstate.mixin.Training

Mode indicating training computation.

This mode specifies that the model is being trained. Training mode typically enables training-specific behaviors such as applying dropout, using per-batch statistics in batch normalization, and computing gradients.

Examples

Basic training mode:

>>> import brainstate
>>>
>>> # Create training mode
>>> training = brainstate.mixin.Training()
>>> print(training)
Training
>>>
>>> # Check mode
>>> assert training.is_a(brainstate.mixin.Training)
>>> assert training.has(brainstate.mixin.Mode)

Using it in a model with dropout:

>>> import brainstate
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> class ModelWithDropout:
...     def __init__(self, dropout_rate=0.5):
...         self.dropout_rate = dropout_rate
...         self.mode = None
...
...     def set_training(self, is_training=True):
...         if is_training:
...             self.mode = brainstate.mixin.Training()
...         else:
...             self.mode = brainstate.mixin.Mode()  # Evaluation mode
...
...     def forward(self, x, rng_key):
...         # Apply dropout only during training
...         if self.mode is not None and self.mode.has(brainstate.mixin.Training):
...             keep_prob = 1.0 - self.dropout_rate
...             mask = jax.random.bernoulli(rng_key, keep_prob, x.shape)
...             x = jnp.where(mask, x / keep_prob, 0)
...         return x
>>>
>>> model = ModelWithDropout()
>>>
>>> # Training mode
>>> model.set_training(True)
>>> key = jax.random.PRNGKey(0)
>>> x_train = jnp.ones((10, 20))
>>> out_train = model.forward(x_train, key)  # Dropout applied
>>>
>>> # Evaluation mode
>>> model.set_training(False)
>>> out_eval = model.forward(x_train, key)  # No dropout

Combining with batching:

>>> # Create combined training and batching mode
>>> training = brainstate.mixin.Training()
>>> batching = brainstate.mixin.Batching(batch_size=32)
>>> mode = brainstate.mixin.JointMode(training, batching)
>>>
>>> # Use in training configuration
>>> class Trainer:
...     def __init__(self, model, mode):
...         self.model = model
...         self.mode = mode
...
...     def train_epoch(self, data):
...         if self.mode.has(brainstate.mixin.Training):
...             # Enable training-specific behaviors
...             self.model.set_training(True)
...
...         if self.mode.has(brainstate.mixin.Batching):
...             # Process in batches
...             batch_size = self.mode.batch_size
...             # ... batched training loop ...
...         pass