Mixin System
This tutorial explains the mixin utilities that ship with brainstate. After working through the examples you will:
Understand what a mixin is and when to use one.
Reuse behaviors by inheriting from brainstate.mixin.Mixin.
Capture reusable constructor presets with ParamDesc and ParamDescriber.
Express rich type expectations with JointTypes and OneOfTypes.
Control runtime behaviour with the built-in mode mixins such as Training, Batching, and JointMode.
import datetime
from dataclasses import dataclass
import jax.numpy as jnp
import brainstate
from brainstate import mixin
What is a mixin?
A mixin is a lightweight class that contributes behaviour (methods or attributes) without forcing a rigid inheritance hierarchy.
In BrainState every mixin inherits from brainstate.mixin.Mixin, signalling that the class
provides optional behaviour and should not define its own __init__.
Mixins are usually paired with core components such as brainstate.nn.Module to keep reusable code close to the consumer.
class LoggingMixin(mixin.Mixin):
    """Attach timestamped logging to any class without touching its constructor."""

    def log(self, message: str) -> None:
        stamp = datetime.datetime.now().strftime('%H:%M:%S')
        print(f'[LOG {stamp}] {self.__class__.__name__}: {message}')


class Accumulator(brainstate.nn.Module, LoggingMixin):
    """Simple module that reuses the logging helper."""

    def __init__(self):
        super().__init__()
        self.total = 0.0

    def add(self, value):
        self.total += float(value)
        self.log(f'updated running total to {self.total:.2f}')
        return self.total


acc = Accumulator()
_ = acc.add(1.25)
_ = acc.add(2.75)
Design tips
A mixin should only provide behaviour; avoid introducing new required constructor arguments.
Keep mixins focused. Several small mixins compose better than a single, opinionated base class.
Document expectations about host classes (e.g. attributes a mixin reads or writes).
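The last tip can be made concrete with a small plain-Python sketch. The names `ClampMixin`, `Sensor`, `lower`, and `upper` are hypothetical and exist only for illustration. The sketch deliberately avoids `brainstate.mixin.Mixin` so that it runs without BrainState installed; the same pattern should carry over to a real mixin.

```python
class ClampMixin:
    """Clamp values into a range supplied by the host class.

    Host contract (hypothetical): the host must define numeric
    attributes ``lower`` and ``upper`` before ``clamp`` is called.
    The mixin defines no ``__init__`` and adds no constructor arguments.
    """

    def clamp(self, value: float) -> float:
        # Read the host's attributes; fail loudly if the contract is broken.
        try:
            lower, upper = self.lower, self.upper
        except AttributeError as exc:
            raise TypeError(
                f'{type(self).__name__} must define `lower` and `upper` to use ClampMixin'
            ) from exc
        return max(lower, min(upper, value))


class Sensor(ClampMixin):
    """Host class that fulfils the mixin's documented contract."""

    def __init__(self, lower: float, upper: float):
        self.lower = lower
        self.upper = upper


sensor = Sensor(0.0, 1.0)
clamped_high = sensor.clamp(1.7)  # above the range
clamped_mid = sensor.clamp(0.4)   # inside the range
print(clamped_high, clamped_mid)
```

Stating the contract in the docstring and checking it at the point of use keeps the failure close to its cause when a host class forgets an attribute.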
Parameter descriptors with ParamDesc
ParamDesc helps you capture reusable constructor presets.
The desc() class method stores the provided arguments inside a ParamDescriber, which you can later call
to instantiate new objects while still overriding any argument on demand.
class DenseBlock(mixin.ParamDesc):
    """Toy layer that records its configuration for inspection."""

    def __init__(self, in_features: int, out_features: int, *, activation: str = 'relu'):
        self.in_features = in_features
        self.out_features = out_features
        self.activation = activation

    def summary(self) -> str:
        return f'{self.activation} dense {self.in_features} → {self.out_features}'


encoder_block = DenseBlock.desc(256, 128, activation='gelu')
decoder_block = DenseBlock.desc(128, 64, activation='relu')
print(encoder_block().summary())
print(encoder_block(activation='relu').summary()) # override at call time
print(decoder_block().summary())
ParamDesc stores descriptors in a hashable structure. This plays nicely with caching systems because
descriptor.identifier is safe to use as a dictionary key.
print(encoder_block.identifier)
Using ParamDescriber directly
If you want to describe classes that do not inherit from ParamDesc, you can work with
ParamDescriber manually.
@dataclass
class OptimConfig:
    lr: float
    beta1: float = 0.9
    beta2: float = 0.999


adam_template = mixin.ParamDescriber(OptimConfig, lr=1e-3, beta1=0.95)
opt_a = adam_template()
opt_b = adam_template(lr=5e-4) # override a keyword
print(opt_a)
print(opt_b)
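For comparison, the standard library's `functools.partial` offers a similar "preset now, override at call time" pattern for keyword arguments. The sketch below is an analogy only and says nothing about how `ParamDescriber` is implemented; `SketchConfig` is a hypothetical stand-in that mirrors `OptimConfig`.

```python
from dataclasses import dataclass
from functools import partial


@dataclass
class SketchConfig:
    lr: float
    beta1: float = 0.9
    beta2: float = 0.999


# Preset two keyword arguments; beta2 keeps its dataclass default.
template = partial(SketchConfig, lr=1e-3, beta1=0.95)

cfg_a = template()         # uses the presets as stored
cfg_b = template(lr=5e-4)  # a call-time keyword replaces the preset value
print(cfg_a)
print(cfg_b)
```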
Type combinators: JointTypes and OneOfTypes
BrainState ships two helpers that make intent explicit when a value must satisfy multiple interfaces or just one of several options:
JointTypes[A, B, ...] behaves like an intersection: an instance must satisfy all listed types.
OneOfTypes[A, B, ...] behaves like a union: an instance may satisfy any listed type.
class Persistable:
    def save(self):
        raise NotImplementedError


class Visualisable:
    def plot(self):
        raise NotImplementedError


class Report(Persistable, Visualisable):
    def save(self):
        return 'saved to disk'

    def plot(self):
        return 'rendering preview'


FullFeatureType = mixin.JointTypes[Persistable, Visualisable]
OptionalNumber = mixin.OneOfTypes[int, float, type(None)]
report = Report()
print(isinstance(report, FullFeatureType))
print(isinstance(3.14, OptionalNumber), isinstance(None, OptionalNumber))
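The intersection and union semantics described above can be spelled out with ordinary `isinstance` calls. This sketch is plain Python and does not reproduce how `JointTypes` or `OneOfTypes` are implemented; `satisfies_all`, `satisfies_any`, `Saver`, `Plotter`, and `Both` are hypothetical names.

```python
from typing import Any


class Saver:
    def save(self) -> str:
        return 'saved'


class Plotter:
    def plot(self) -> str:
        return 'plotted'


class Both(Saver, Plotter):
    """Inherits from both interfaces, so it satisfies their intersection."""


def satisfies_all(obj: Any, *types: type) -> bool:
    # Intersection: every listed type must match.
    return all(isinstance(obj, t) for t in types)


def satisfies_any(obj: Any, *types: type) -> bool:
    # Union: isinstance accepts a tuple and matches any member.
    return isinstance(obj, types)


both_ok = satisfies_all(Both(), Saver, Plotter)
saver_only = satisfies_all(Saver(), Saver, Plotter)
number_ok = satisfies_any(3.14, int, float, type(None))
text_ok = satisfies_any('3.14', int, float, type(None))
print(both_ok, saver_only, number_ok, text_ok)
```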
Mode mixins for runtime behaviour
Mode objects capture the context in which computation happens.
The base Mode class is lightweight, and the built-ins Training, Batching, and JointMode cover
common runtime switches.
class ToyPipeline:
    """A tiny module that responds to different mode configurations."""

    def __init__(self):
        self.mode: mixin.Mode = mixin.Mode()

    def set_mode(self, *modes: mixin.Mode):
        if not modes:
            self.mode = mixin.Mode()
        elif len(modes) == 1:
            self.mode = modes[0]
        else:
            self.mode = mixin.JointMode(*modes)

    def forward(self, values):
        x = jnp.asarray(values, dtype=jnp.float32)
        if self.mode.has(mixin.Training):
            x = x + 0.1  # emulate noise or dropout
        if self.mode.has(mixin.Batching):
            batch = self.mode.batch_size
            x = x.reshape((batch, -1)).mean(axis=1)
        return x


pipeline = ToyPipeline()
print('default', pipeline.forward(jnp.arange(4.0)))
pipeline.set_mode(mixin.Training())
print('training', pipeline.forward(jnp.arange(4.0)))
pipeline.set_mode(mixin.Training(), mixin.Batching(batch_size=2))
print('joint', pipeline.forward(jnp.arange(4.0)))
print('joint exposes batch size:', pipeline.mode.batch_size)
The joint mode exposes the attributes of its members, so accessing pipeline.mode.batch_size works even
though the current mode is a JointMode instance.
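One way such attribute forwarding can work is through `__getattr__`, which Python calls only when normal attribute lookup fails. The classes below (`SketchMode`, `SketchTraining`, `SketchBatching`, `SketchJointMode`) are hypothetical and written in plain Python; they illustrate the idea under that assumption and are not BrainState's implementation.

```python
class SketchMode:
    """Hypothetical base mode."""

    def has(self, mode_type: type) -> bool:
        return isinstance(self, mode_type)


class SketchTraining(SketchMode):
    pass


class SketchBatching(SketchMode):
    def __init__(self, batch_size: int = 1):
        self.batch_size = batch_size


class SketchJointMode(SketchMode):
    """Combine several modes and forward unknown attributes to them."""

    def __init__(self, *modes: SketchMode):
        self.modes = tuple(modes)

    def has(self, mode_type: type) -> bool:
        # The joint mode "has" a type if any member is an instance of it.
        return any(isinstance(m, mode_type) for m in self.modes)

    def __getattr__(self, name: str):
        # Only reached when normal lookup on the joint mode itself fails.
        if name == 'modes':
            # Guard against infinite recursion if `modes` is not set yet.
            raise AttributeError(name)
        for m in self.modes:
            if hasattr(m, name):
                return getattr(m, name)
        raise AttributeError(f'no member mode defines {name!r}')


joint = SketchJointMode(SketchTraining(), SketchBatching(batch_size=2))
print(joint.has(SketchTraining), joint.has(SketchBatching), joint.batch_size)
```

In this sketch the first member that defines an attribute wins, so member order matters when two modes share an attribute name.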
Putting it together
When you combine these mixin tools you can:
Add reusable behaviour (logging, validation, metrics) without disturbing core module hierarchies.
Parameterise component templates and reuse them safely through descriptors.
Encode clear expectations about inputs or collaborators via JointTypes/OneOfTypes.
Toggle runtime semantics with mode objects instead of ad-hoc boolean flags.
Next steps
Audit your own modules for behaviours that could live in a mixin.
Wrap frequently reused constructor arguments with ParamDesc.
Adopt mode objects in your training scripts to centralise feature flags (e.g. evaluation vs training).
Explore brainstate.mixin.not_implemented to clearly mark unsupported operations.