JointMode

class brainstate.mixin.JointMode(*modes)

A mode that combines multiple modes simultaneously.

JointMode expresses that a computation is in several modes at once, for example both training and batching. Use it when a model's behavior depends on more than one independent aspect, such as whether training-only operations (like dropout) should run and whether inputs carry a batch dimension.

Parameters:

*modes (Mode) – The modes to combine.

modes

The individual modes that are combined.

Type:

tuple of Mode

types

The types of the combined modes.

Type:

set of type

Raises:

TypeError – If any of the provided arguments is not a Mode instance.
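The constructor contract above (validate every argument, then keep the modes as a tuple and their classes as a set) can be sketched in plain Python. This is a minimal stand-in that assumes only the behavior documented here, not brainstate's source; `SketchMode`, `SketchTraining`, `SketchBatching`, and `SketchJointMode` are hypothetical names.

```python
class SketchMode:
    """Hypothetical stand-in for a base mode class."""

    def __repr__(self):
        return self.__class__.__name__


class SketchTraining(SketchMode):
    pass


class SketchBatching(SketchMode):
    def __init__(self, batch_size=1, batch_axis=0):
        self.batch_size = batch_size
        self.batch_axis = batch_axis


class SketchJointMode(SketchMode):
    """Hypothetical sketch of JointMode's constructor, not the library code."""

    def __init__(self, *modes):
        # Reject anything that is not a mode instance (the documented TypeError).
        for m in modes:
            if not isinstance(m, SketchMode):
                raise TypeError(f"Expected Mode instances, got {m!r}")
        self.modes = tuple(modes)              # the individual modes, in order
        self.types = {type(m) for m in modes}  # the set of their classes


joint = SketchJointMode(SketchTraining(), SketchBatching(batch_size=32))
print(joint.modes)                                      # (SketchTraining, SketchBatching)
print(joint.types == {SketchTraining, SketchBatching})  # True

try:
    SketchJointMode(SketchTraining(), "batching")
except TypeError as err:
    print("TypeError:", err)
```

Because `types` is a set, two modes of the same class would contribute only one entry, while `modes` keeps both instances.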

Examples

Combining training and batching modes:

>>> import brainstate
>>>
>>> # Create individual modes
>>> training = brainstate.mixin.Training()
>>> batching = brainstate.mixin.Batching(batch_size=32)
>>>
>>> # Combine them
>>> joint = brainstate.mixin.JointMode(training, batching)
>>> print(joint)  # JointMode(Training, Batching(in_size=32, axis=0))
>>>
>>> # Check if specific modes are present
>>> assert joint.has(brainstate.mixin.Training)
>>> assert joint.has(brainstate.mixin.Batching)
>>>
>>> # Access attributes from combined modes
>>> print(joint.batch_size)  # 32 (from Batching mode)

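The last line of the example, where `joint.batch_size` resolves to the `Batching` mode's value, is attribute delegation: when the joint mode itself lacks an attribute, the combined modes are consulted. A minimal sketch of that idea with `__getattr__`, assuming first-match-wins lookup in the order the modes were passed; the class names are hypothetical stand-ins, not brainstate's code.

```python
class SketchMode:
    pass


class SketchTraining(SketchMode):
    pass


class SketchBatching(SketchMode):
    def __init__(self, batch_size=1, batch_axis=0):
        self.batch_size = batch_size
        self.batch_axis = batch_axis


class SketchJointMode(SketchMode):
    """Hypothetical sketch of attribute delegation, not the library code."""

    def __init__(self, *modes):
        self.modes = tuple(modes)

    def __getattr__(self, name):
        # __getattr__ runs only after normal lookup fails. Refusing "modes" here
        # keeps a half-built object (before __init__ sets it) from recursing.
        if name == "modes":
            raise AttributeError(name)
        for mode in self.modes:
            if hasattr(mode, name):
                return getattr(mode, name)  # first mode defining the attribute wins
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")


joint = SketchJointMode(SketchTraining(), SketchBatching(batch_size=32))
print(joint.batch_size)  # 32
print(joint.batch_axis)  # 0
```

First-match-wins is an assumption of this sketch: if two combined modes defined the same attribute, it would return the value from the mode listed first.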
Using in model configuration:

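>>> class NeuralNetwork:
...     def __init__(self):
...         self.mode = None
...
...     def set_train_mode(self, batch_size=1):
...         # Set both training and batching modes
...         training = brainstate.mixin.Training()
...         batching = brainstate.mixin.Batching(batch_size=batch_size)
...         self.mode = brainstate.mixin.JointMode(training, batching)
...
...     # Hypothetical placeholder steps so the example runs end to end;
...     # a real model would implement actual computation here.
...     def apply_dropout(self, x):
...         return x
...
...     def process(self, x):
...         return x
...
...     def batch_process(self, x, batch_size):
...         return x
...
...     def forward(self, x):
...         if self.mode is None:
...             # No mode configured yet: plain processing
...             return self.process(x)
...
...         if self.mode.has(brainstate.mixin.Training):
...             x = self.apply_dropout(x)
...
...         if self.mode.has(brainstate.mixin.Batching):
...             # Process in batches
...             batch_size = self.mode.batch_size
...             return self.batch_process(x, batch_size)
...
...         return self.process(x)
>>>
>>> model = NeuralNetwork()
>>> model.set_train_mode(batch_size=64)

has(mode)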

Check whether any of the combined modes is an instance of the desired type or of one of its subclasses.

Parameters:

mode (type) – The mode type to check for.

Returns:

True if any of the combined modes is or inherits from the specified type.

Return type:

bool
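The check described above amounts to an `issubclass` test against the class of every combined mode, which is why a base class such as `Mode` also matches. A minimal sketch under that assumption; `SketchMode`, `SketchTraining`, `SketchBatching`, and `SketchJointMode` are hypothetical stand-ins rather than brainstate classes, and the `TypeError` for non-type arguments is this sketch's own choice.

```python
class SketchMode:
    pass


class SketchTraining(SketchMode):
    pass


class SketchBatching(SketchMode):
    pass


class SketchJointMode(SketchMode):
    """Hypothetical sketch of the has() check, not the library code."""

    def __init__(self, *modes):
        self.modes = tuple(modes)
        self.types = {type(m) for m in modes}

    def has(self, mode):
        # The argument is a class, not an instance.
        if not isinstance(mode, type):
            raise TypeError("has() expects a mode type")
        # True if any combined mode's class is `mode` or inherits from it.
        return any(issubclass(cls, mode) for cls in self.types)


joint = SketchJointMode(SketchTraining())
print(joint.has(SketchTraining))  # True
print(joint.has(SketchMode))      # True (base class)
print(joint.has(SketchBatching))  # False (not combined)
```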

Examples

>>> import brainstate
>>>
>>> training = brainstate.mixin.Training()
>>> batching = brainstate.mixin.Batching(batch_size=16)
>>> joint = brainstate.mixin.JointMode(training, batching)
>>>
>>> assert joint.has(brainstate.mixin.Training)
>>> assert joint.has(brainstate.mixin.Batching)
>>> assert joint.has(brainstate.mixin.Mode)  # Base class

is_a(cls)

Check whether the joint mode is exactly the desired combined type.

Unlike has(), which succeeds if any single combined mode matches, this check succeeds only when the joint mode's full combination of mode types matches cls.

Parameters:

cls (type) – The combined type to check against.

Returns:

True if the joint mode exactly matches the specified type combination.

Return type:

bool
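The contrast between `has` (any one mode matches, subclasses included) and `is_a` (the whole combination matches) can be illustrated with a set-equality sketch. The source describes `cls` as a combined type; this sketch assumes instead that the combination can be given as a collection of component classes, which is a simplification and not brainstate's signature. All class names are hypothetical.

```python
class SketchMode:
    pass


class SketchTraining(SketchMode):
    pass


class SketchBatching(SketchMode):
    pass


class SketchJointMode(SketchMode):
    """Hypothetical sketch contrasting has() and is_a(), not the library code."""

    def __init__(self, *modes):
        self.modes = tuple(modes)
        self.types = {type(m) for m in modes}

    def has(self, mode):
        # Partial check: any combined mode is `mode` or inherits from it.
        return any(issubclass(cls, mode) for cls in self.types)

    def is_a(self, mode_types):
        # Exact check: the combined classes must equal `mode_types`,
        # no more and no fewer; order is ignored and subclasses do not count.
        return self.types == set(mode_types)


joint = SketchJointMode(SketchTraining(), SketchBatching())
print(joint.is_a((SketchTraining, SketchBatching)))  # True
print(joint.is_a((SketchTraining,)))                 # False: Batching is also present
print(joint.has(SketchTraining))                     # True
```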