Mode
- class brainstate.mixin.Mode
Base class for computation behavior modes.
Modes represent different computational contexts or behaviors, such as training vs. evaluation or batched vs. single-sample processing. Models and components can query their current mode to decide how to behave in each scenario.
Examples
Creating a custom mode:
>>> import brainstate
>>>
>>> class InferenceMode(brainstate.mixin.Mode):
...     def __init__(self, use_cache=True):
...         self.use_cache = use_cache
>>>
>>> # Create mode instances
>>> inference = InferenceMode(use_cache=True)
>>> print(inference)
InferenceMode
Checking mode types:
>>> class FastMode(brainstate.mixin.Mode):
...     pass
>>>
>>> class SlowMode(brainstate.mixin.Mode):
...     pass
>>>
>>> fast = FastMode()
>>> slow = SlowMode()
>>>
>>> # Check exact mode type
>>> assert fast.is_a(FastMode)
>>> assert not fast.is_a(SlowMode)
>>>
>>> # Check if mode is an instance of a type
>>> assert fast.has(brainstate.mixin.Mode)
Using modes in a model:
>>> class Model:
...     def __init__(self):
...         self.mode = brainstate.mixin.Training()
...
...     def forward(self, x):
...         if self.mode.has(brainstate.mixin.Training):
...             # Training-specific logic
...             return self.train_forward(x)
...         else:
...             # Inference logic
...             return self.eval_forward(x)
...
...     def train_forward(self, x):
...         return x + 0.1  # Placeholder for training-only behavior (e.g. noise injection)
...
...     def eval_forward(self, x):
...         return x  # Identity during evaluation
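The model above keeps the mode it was constructed with. The sketch below shows switching a model from training to evaluation by assigning a different mode object. So that it runs without brainstate installed, it uses minimal stand-in classes `Mode` and `Training` that only mimic the documented `has` behavior; these are assumptions, not brainstate's source. Using the bare base `Mode()` as the evaluation mode is likewise an illustrative choice.

```python
# Minimal stand-ins mimicking the documented behavior of brainstate.mixin.Mode
# and brainstate.mixin.Training. These are assumptions, not brainstate's source.
class Mode:
    def has(self, mode: type) -> bool:
        # Instance check: subclasses of `mode` also match.
        return isinstance(self, mode)


class Training(Mode):
    pass


class Model:
    def __init__(self, mode: Mode) -> None:
        self.mode = mode

    def forward(self, x: float) -> float:
        if self.mode.has(Training):
            return x + 0.1  # training-only branch (placeholder offset)
        return x            # evaluation branch: identity


model = Model(Training())
train_out = model.forward(1.0)

# Switch to evaluation by assigning a mode that is not a Training instance.
# Using the bare base Mode() here is an illustrative choice.
model.mode = Mode()
eval_out = model.forward(1.0)

print(train_out, eval_out)
```

Because the branch is chosen on every call, swapping `model.mode` is enough to change behavior; no other state in the model needs to be touched.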
- has(mode)
Check whether the mode includes the desired mode type.
This checks whether the current mode is an instance of the specified type; a mode whose class is a subclass of that type also matches.
- Parameters:
  mode (type) – The mode type to check for.
- Returns:
  True if this mode is an instance of the specified type (or of a subclass of it), otherwise False.
- Return type:
  bool
Examples
>>> import brainstate
>>>
>>> # Create a custom mode that extends Training
>>> class AdvancedTraining(brainstate.mixin.Training):
...     pass
>>>
>>> advanced = AdvancedTraining()
>>> assert advanced.has(brainstate.mixin.Training)  # True (subclass)
>>> assert advanced.has(brainstate.mixin.Mode)  # True (base class)
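Since `has` is documented as an instance check that includes subclasses, a mode class inheriting from several mode types answers True for each of them. The sketch below demonstrates this with stand-in classes; `Mode`, `Training`, `Batching`, and `BatchedTraining` here are hypothetical definitions that only mimic the documented semantics, and whether brainstate's own mode classes are meant to be combined through multiple inheritance is not covered by this page.

```python
# Hypothetical stand-ins mimicking the documented has() semantics
# (an isinstance check); not brainstate's source.
class Mode:
    def has(self, mode: type) -> bool:
        return isinstance(self, mode)


class Training(Mode):
    pass


class Batching(Mode):
    pass


# Hypothetical mode combining both behaviors via multiple inheritance.
class BatchedTraining(Training, Batching):
    pass


m = BatchedTraining()
checks = {
    "Training": m.has(Training),  # parent class in the MRO
    "Batching": m.has(Batching),  # other parent class in the MRO
    "Mode": m.has(Mode),          # common root
}

# A plain Training instance is unrelated to its sibling Batching.
plain_has_batching = Training().has(Batching)
print(checks, plain_has_batching)
```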
- is_a(mode)
Check whether the mode is exactly the desired mode type.
This performs an exact type match: a mode whose class is a subclass of the specified type does not match.
- Parameters:
  mode (type) – The mode type to check against.
- Returns:
  True if this mode is exactly of the specified type, otherwise False.
- Return type:
  bool
Examples
>>> import brainstate
>>>
>>> training_mode = brainstate.mixin.Training()
>>> assert training_mode.is_a(brainstate.mixin.Training)
>>> assert not training_mode.is_a(brainstate.mixin.Batching)
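The difference between `has` and `is_a` only shows up with subclasses. The sketch below contrasts the two on a subclass instance, using stand-in classes that mimic the documented semantics (instance check for `has`, exact type match for `is_a`); the class definitions are assumptions, not brainstate's source, and `AdvancedTraining` is hypothetical.

```python
# Stand-ins mimicking the documented semantics; not brainstate's source.
class Mode:
    def has(self, mode: type) -> bool:
        # Subclass-aware: any class in the instance's MRO matches.
        return isinstance(self, mode)

    def is_a(self, mode: type) -> bool:
        # Exact match only: subclasses do not count.
        return type(self) is mode


class Training(Mode):
    pass


class AdvancedTraining(Training):  # hypothetical subclass
    pass


adv = AdvancedTraining()
results = {
    "has(Training)": adv.has(Training),
    "is_a(Training)": adv.is_a(Training),
    "is_a(AdvancedTraining)": adv.is_a(AdvancedTraining),
}
print(results)
```

As a rule of thumb under these semantics, `has` suits behavior that a specialized subclass should inherit, while `is_a` suits behavior that must apply to one exact mode type only.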