JointMode
- class brainstate.mixin.JointMode(*modes)
A mode that combines multiple modes simultaneously.
JointMode expresses that a computation is in several modes at once, for example both in training mode and in batching mode. This is useful when a component's behavior depends on more than one aspect at the same time, such as applying dropout during training while also processing inputs in batches.
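The combination semantics can be illustrated with a small, self-contained sketch. It does not import brainstate: `Mode`, `Training`, `Batching`, and `SimpleJointMode` below are hypothetical stand-ins written only to show the idea that a joint mode holds several mode objects, answers type queries against all of them, and forwards attribute lookups to the first held mode that defines the attribute. The real brainstate class may differ in its details.

```python
class Mode:
    """Hypothetical stand-in for brainstate.mixin.Mode."""
    def __repr__(self):
        return type(self).__name__


class Training(Mode):
    """Hypothetical stand-in for brainstate.mixin.Training."""


class Batching(Mode):
    """Hypothetical stand-in for brainstate.mixin.Batching."""
    def __init__(self, batch_size=1, batch_axis=0):
        self.batch_size = batch_size
        self.batch_axis = batch_axis


class SimpleJointMode(Mode):
    """Illustrative sketch: holds several modes at once."""
    def __init__(self, *modes):
        self.modes = tuple(modes)

    def has(self, mode_type):
        # True if any held mode is an instance of mode_type (subclasses included).
        return any(isinstance(m, mode_type) for m in self.modes)

    def __getattr__(self, name):
        # Called only when normal lookup fails: forward the lookup to the
        # first held mode that defines the attribute.
        for m in self.__dict__.get("modes", ()):
            if hasattr(m, name):
                return getattr(m, name)
        raise AttributeError(name)

    def __repr__(self):
        return f"{type(self).__name__}({', '.join(map(repr, self.modes))})"


joint = SimpleJointMode(Training(), Batching(batch_size=32))
print(joint)                                     # SimpleJointMode(Training, Batching)
print(joint.has(Training), joint.has(Batching))  # True True
print(joint.batch_size)                          # 32, forwarded from Batching
```

Forwarding through `__getattr__` keeps the joint object thin: it stores no copies of mode attributes, so `joint.batch_size` always reflects the held `Batching` instance.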
- Parameters:
*modes (Mode) – The modes to combine.
- Raises:
TypeError – If any of the provided arguments is not a Mode instance.
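The argument check described under Raises can be sketched as follows, again with hypothetical stand-in classes rather than brainstate itself. Every positional argument is tested with `isinstance` before the modes are stored, so passing the class `Training` instead of an instance `Training()` is rejected. The helper name `validate_modes` is illustrative only.

```python
class Mode:
    """Hypothetical stand-in for brainstate.mixin.Mode."""


class Training(Mode):
    """Hypothetical stand-in for brainstate.mixin.Training."""


def validate_modes(*modes):
    """Hypothetical helper: reject any argument that is not a Mode instance."""
    for m in modes:
        if not isinstance(m, Mode):
            raise TypeError(f"Expected Mode instances, got {m!r}")
    return tuple(modes)


# Instances pass the check.
print(len(validate_modes(Training(), Training())))  # 2

# Passing the class itself (a common slip) is rejected.
error_message = None
try:
    validate_modes(Training)
except TypeError as err:
    error_message = str(err)
print(error_message)
```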
Examples
Combining training and batching modes:
>>> import brainstate
>>>
>>> # Create individual modes
>>> training = brainstate.mixin.Training()
>>> batching = brainstate.mixin.Batching(batch_size=32)
>>>
>>> # Combine them
>>> joint = brainstate.mixin.JointMode(training, batching)
>>> print(joint)  # JointMode(Training, Batching(in_size=32, axis=0))
>>>
>>> # Check if specific modes are present
>>> assert joint.has(brainstate.mixin.Training)
>>> assert joint.has(brainstate.mixin.Batching)
>>>
>>> # Access attributes from combined modes
>>> print(joint.batch_size)  # 32 (from Batching mode)
Using in model configuration:
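>>> import brainstate
>>>
>>> class NeuralNetwork:
...     def __init__(self):
...         self.mode = None
...
...     def set_train_mode(self, batch_size=1):
...         # Set both training and batching modes
...         training = brainstate.mixin.Training()
...         batching = brainstate.mixin.Batching(batch_size=batch_size)
...         self.mode = brainstate.mixin.JointMode(training, batching)
...
...     # Hypothetical placeholder hooks; replace with real computations
...     def apply_dropout(self, x):
...         return x
...
...     def batch_process(self, x, batch_size):
...         return x
...
...     def process(self, x):
...         return x
...
...     def forward(self, x):
...         if self.mode is None:  # no mode configured yet
...             return self.process(x)
...         if self.mode.has(brainstate.mixin.Training):
...             x = self.apply_dropout(x)
...
...         if self.mode.has(brainstate.mixin.Batching):
...             # Process in batches
...             batch_size = self.mode.batch_size
...             return self.batch_process(x, batch_size)
...
...         return self.process(x)
>>>
>>> model = NeuralNetwork()
>>> model.set_train_mode(batch_size=64)
>>> model.forward([1.0, 2.0])
[1.0, 2.0]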
- has(mode)
Check whether any of the combined modes is of the given mode type (or of a subclass of it).
- Parameters:
mode (type) – The mode type to check for.
- Returns:
True if any of the combined modes is or inherits from the specified type.
- Return type:
bool
Examples
>>> import brainstate
>>>
>>> training = brainstate.mixin.Training()
>>> batching = brainstate.mixin.Batching(batch_size=16)
>>> joint = brainstate.mixin.JointMode(training, batching)
>>>
>>> assert joint.has(brainstate.mixin.Training)
>>> assert joint.has(brainstate.mixin.Batching)
>>> assert joint.has(brainstate.mixin.Mode)  # Base class
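The "is or inherits from" rule under Returns can be sketched with plain `issubclass` checks. The classes below are hypothetical stand-ins (not brainstate), and `FineTuning` and `Evaluation` are invented purely for illustration: because the check follows inheritance, a query for `Training` succeeds even though the only training-like mode held is a `FineTuning` instance, whereas an exact-type comparison would not.

```python
class Mode:
    """Hypothetical stand-in for brainstate.mixin.Mode."""


class Training(Mode):
    """Hypothetical stand-in for brainstate.mixin.Training."""


class FineTuning(Training):
    """Hypothetical mode that specializes Training."""


class Batching(Mode):
    """Hypothetical stand-in for brainstate.mixin.Batching."""


class Evaluation(Mode):
    """Hypothetical mode that is not part of the combination."""


def joint_has(modes, mode_type):
    # Mirrors the documented rule: True if any combined mode's class is,
    # or inherits from, mode_type.
    return any(issubclass(type(m), mode_type) for m in modes)


modes = (FineTuning(), Batching())
print(joint_has(modes, Training))               # True: FineTuning inherits from Training
print(joint_has(modes, Mode))                   # True: every mode inherits from Mode
print(joint_has(modes, Evaluation))             # False: no Evaluation mode is combined
print(any(type(m) is Training for m in modes))  # False: no exact Training instance
```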