Mode

class brainstate.mixin.Mode

Base class for computation behavior modes.

A mode represents a computational context, such as training versus evaluation or batched versus single-sample processing. Models and components can query the active mode to decide how to behave in each context.

Examples

Creating a custom mode:

>>> import brainstate
>>>
>>> class InferenceMode(brainstate.mixin.Mode):
...     def __init__(self, use_cache=True):
...         self.use_cache = use_cache
>>>
>>> # Create mode instances
>>> inference = InferenceMode(use_cache=True)
>>> print(inference)  # Output: InferenceMode

Checking mode types:

>>> class FastMode(brainstate.mixin.Mode):
...     pass
>>>
>>> class SlowMode(brainstate.mixin.Mode):
...     pass
>>>
>>> fast = FastMode()
>>> slow = SlowMode()
>>>
>>> # Check exact mode type
>>> assert fast.is_a(FastMode)
>>> assert not fast.is_a(SlowMode)
>>>
>>> # Check whether the mode is an instance of a type (subclasses included)
>>> assert fast.has(brainstate.mixin.Mode)

Using modes in a model:

>>> class Model:
...     def __init__(self):
...         self.mode = brainstate.mixin.Training()
...
...     def forward(self, x):
...         if self.mode.has(brainstate.mixin.Training):
...             # Training-specific logic
...             return self.train_forward(x)
...         else:
...             # Inference logic
...             return self.eval_forward(x)
...
...     def train_forward(self, x):
...         return x + 0.1  # Constant offset standing in for training-time noise
...
...     def eval_forward(self, x):
...         return x  # No noise during evaluation
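
The Model class above is defined but never called. The sketch below exercises the same dispatch pattern end to end. To stay runnable without brainstate installed, it uses stand-in `Mode`, `Training`, and `Evaluation` classes, which are assumptions for illustration only. `Evaluation` in particular is a hypothetical name, and the stand-in `has` simply mirrors the documented instance check.

```python
class Mode:
    """Stand-in for brainstate.mixin.Mode (an assumption, not the library class)."""

    def has(self, mode: type) -> bool:
        # Mirrors the documented behavior: instance check, subclasses included.
        return isinstance(self, mode)


class Training(Mode):
    """Stand-in for brainstate.mixin.Training."""


class Evaluation(Mode):
    """Hypothetical mode used here to represent inference."""


class Model:
    def __init__(self, mode: Mode):
        self.mode = mode

    def forward(self, x: float) -> float:
        if self.mode.has(Training):
            # Training branch: constant offset as a placeholder for noise.
            return x + 0.1
        # Inference branch: the input passes through unchanged.
        return x


train_out = Model(Training()).forward(1.0)
eval_out = Model(Evaluation()).forward(1.0)
print(train_out, eval_out)
```

Passing the mode to the constructor, rather than hard-coding it, lets the same model object be built for either context.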

has(mode)

Check whether the mode includes the desired mode type.

This returns True when the current mode is an instance of the specified type or of any of its subclasses.

Parameters:

mode (type) – The mode type to check for.

Returns:

True if this mode is an instance of the specified type.

Return type:

bool

Examples

>>> import brainstate
>>>
>>> # Create a custom mode that extends Training
>>> class AdvancedTraining(brainstate.mixin.Training):
...     pass
>>>
>>> advanced = AdvancedTraining()
>>> assert advanced.has(brainstate.mixin.Training)  # True (subclass)
>>> assert advanced.has(brainstate.mixin.Mode)      # True (base class)

is_a(mode)

Check whether the mode is exactly the desired mode type.

This performs an exact type match: an instance of a subclass of the specified type does not match.

Parameters:

mode (type) – The mode type to check against.

Returns:

True if this mode is exactly of the specified type.

Return type:

bool

Examples

>>> import brainstate
>>>
>>> training_mode = brainstate.mixin.Training()
>>> assert training_mode.is_a(brainstate.mixin.Training)
>>> assert not training_mode.is_a(brainstate.mixin.Batching)
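
The difference between is_a and has only shows up with subclasses. The sketch below contrasts the two on a small hierarchy. It uses stand-in classes rather than brainstate itself, with is_a written as an exact type comparison and has as an isinstance check. These method bodies are assumptions that mirror the documented behavior, not the library's source.

```python
class Mode:
    """Stand-in for brainstate.mixin.Mode (an assumption, not the library class)."""

    def is_a(self, mode: type) -> bool:
        # Exact match: the instance's class must be this very type.
        return type(self) is mode

    def has(self, mode: type) -> bool:
        # Inclusive match: subclasses of `mode` also count.
        return isinstance(self, mode)


class Training(Mode):
    """Stand-in for brainstate.mixin.Training."""


class AdvancedTraining(Training):
    """Custom mode that extends Training."""


advanced = AdvancedTraining()

results = {
    "is_a(AdvancedTraining)": advanced.is_a(AdvancedTraining),
    "is_a(Training)": advanced.is_a(Training),
    "has(Training)": advanced.has(Training),
    "has(Mode)": advanced.has(Mode),
}
for name, value in results.items():
    print(f"{name}: {value}")
```

Under these assumptions, has fits cases where subclasses of a mode should be treated like the mode itself, and is_a fits cases where only that exact class should match.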