Mixin System

This tutorial explains the mixin utilities that ship with brainstate. After working through the examples you will:

  • Understand what a mixin is and when to use one.

  • Reuse behaviors by inheriting from brainstate.mixin.Mixin.

  • Capture reusable constructor presets with ParamDesc and ParamDescriber.

  • Express rich type expectations with JointTypes and OneOfTypes.

  • Control runtime behaviour with the built-in mode mixins such as Training, Batching, and JointMode.

import datetime
from dataclasses import dataclass

import jax.numpy as jnp

import brainstate
from brainstate import mixin

What is a mixin?

A mixin is a lightweight class that contributes behaviour (methods or attributes) without forcing a rigid inheritance hierarchy. In BrainState every mixin inherits from brainstate.mixin.Mixin, signalling that the class provides optional behaviour and should not define its own __init__. Mixins are usually combined with core components such as brainstate.nn.Module through multiple inheritance, so that a behaviour written once can be shared by every class that lists the mixin among its bases.

class LoggingMixin(mixin.Mixin):
    """Attach timestamped logging to any class without touching its constructor."""

    def log(self, message: str) -> None:
        stamp = datetime.datetime.now().strftime('%H:%M:%S')
        print(f'[LOG {stamp}] {self.__class__.__name__}: {message}')


class Accumulator(brainstate.nn.Module, LoggingMixin):
    """Simple module that reuses the logging helper."""

    def __init__(self):
        super().__init__()
        self.total = 0.0

    def add(self, value):
        self.total += float(value)
        self.log(f'updated running total to {self.total:.2f}')
        return self.total


acc = Accumulator()
_ = acc.add(1.25)
_ = acc.add(2.75)

Design tips

  • A mixin should only provide behaviour; avoid introducing new required constructor arguments.

  • Keep mixins focused. Several small mixins compose better than a single, opinionated base class.

  • Document expectations about host classes (e.g. attributes a mixin reads or writes).
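One way to follow the last tip is to declare the host contract in code as well as in the docstring, so a missing attribute produces a clear error instead of a confusing one. The sketch below is plain Python so it runs without brainstate; SummaryMixin, required_attributes, and missing_attributes are hypothetical names, and in a brainstate project the mixin would also inherit from mixin.Mixin.

```python
class SummaryMixin:
    """Adds a one-line summary to any host class.

    Host contract: the host must set ``self.total`` before ``summary()``
    is called. The mixin itself defines no ``__init__``.
    """

    # Declaring the contract makes the dependency discoverable and checkable.
    required_attributes = ('total',)

    def missing_attributes(self):
        # Names from the contract that the host has not set yet.
        return [name for name in self.required_attributes if not hasattr(self, name)]

    def summary(self) -> str:
        missing = self.missing_attributes()
        if missing:
            raise AttributeError(f'{type(self).__name__} is missing {missing}')
        return f'{type(self).__name__}: total={self.total:.2f}'


class Counter(SummaryMixin):
    def __init__(self):
        self.total = 0.0  # the host owns its state; the mixin only reads it

    def add(self, value):
        self.total += float(value)
        return self.total


class Incomplete(SummaryMixin):
    """Forgets to define ``total``, so the contract check fails."""


counter = Counter()
counter.add(1.5)
print(counter.summary())                # Counter: total=1.50
print(Incomplete().missing_attributes())  # ['total']
```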

Parameter descriptors with ParamDesc

ParamDesc helps you capture reusable constructor presets. Its desc() class method stores the provided arguments inside a ParamDescriber, which you can later call to instantiate new objects, supplying extra arguments or overriding stored keyword arguments at call time.

class DenseBlock(mixin.ParamDesc):
    """Toy layer that records its configuration for inspection."""

    def __init__(self, in_features: int, out_features: int, *, activation: str = 'relu'):
        self.in_features = in_features
        self.out_features = out_features
        self.activation = activation

    def summary(self) -> str:
        return f'{self.activation} dense {self.in_features} → {self.out_features}'


encoder_block = DenseBlock.desc(256, 128, activation='gelu')
decoder_block = DenseBlock.desc(128, 64, activation='relu')

print(encoder_block().summary())
print(encoder_block(activation='relu').summary())  # override at call time
print(decoder_block().summary())
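To see what desc() does under the hood, here is a simplified pure-Python stand-in. MiniParamDesc and MiniDescriber are hypothetical names, not brainstate classes, and letting call-time keywords replace stored ones is an assumption made to mirror the override shown above; check the ParamDescriber source of your installed version for its exact merging rules.

```python
class MiniDescriber:
    """Hypothetical stand-in: remembers a class plus constructor arguments."""

    def __init__(self, cls, *args, **kwargs):
        self.cls = cls
        self.args = args
        self.kwargs = kwargs

    def __call__(self, *args, **kwargs):
        # Stored positionals come first; call-time positionals are appended.
        # Call-time keywords replace stored keywords of the same name (assumed).
        merged = {**self.kwargs, **kwargs}
        return self.cls(*self.args, *args, **merged)


class MiniParamDesc:
    """Hypothetical stand-in for a ParamDesc-style mixin: adds only desc()."""

    @classmethod
    def desc(cls, *args, **kwargs):
        return MiniDescriber(cls, *args, **kwargs)


class Dense(MiniParamDesc):
    def __init__(self, in_features, out_features, *, activation='relu'):
        self.in_features = in_features
        self.out_features = out_features
        self.activation = activation


template = Dense.desc(256, 128, activation='gelu')
a = template()
b = template(activation='tanh')
print(a.activation, b.activation)  # gelu tanh
```

Because stored positional arguments are always passed first and call-time positionals are appended after them, settings you expect to vary are best stored as keywords.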

A ParamDescriber also records its class and arguments in a hashable identifier. This plays nicely with caching systems: descriptor.identifier can serve as a dictionary key, provided every stored argument is itself hashable.

print(encoder_block.identifier)

Using ParamDescriber directly

To describe a class that does not inherit from ParamDesc, construct a ParamDescriber yourself, passing the class followed by the preset arguments.

@dataclass
class OptimConfig:
    lr: float
    beta1: float = 0.9
    beta2: float = 0.999


adam_template = mixin.ParamDescriber(OptimConfig, lr=1e-3, beta1=0.95)
opt_a = adam_template()
opt_b = adam_template(lr=5e-4)  # override a keyword

print(opt_a)
print(opt_b)
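For comparison, the standard library's functools.partial also pairs a callable with preset arguments and lets call-time keywords replace the presets. What it lacks is a value-based identity: two partial objects built from the same presets compare unequal, which is the gap a hashable identifier such as descriptor.identifier fills. This snippet uses only the standard library.

```python
from dataclasses import dataclass
from functools import partial


@dataclass
class OptimConfig:
    lr: float
    beta1: float = 0.9
    beta2: float = 0.999


# A callable plus preset keyword arguments.
adam_partial = partial(OptimConfig, lr=1e-3, beta1=0.95)
opt = adam_partial(lr=5e-4)  # partial lets call-time keywords replace presets
print(opt)  # OptimConfig(lr=0.0005, beta1=0.95, beta2=0.999)

# partial objects compare by identity, so equal presets do not give equal keys.
same_presets = partial(OptimConfig, lr=1e-3, beta1=0.95)
print(adam_partial == same_presets)  # False
```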

Type combinators: JointTypes and OneOfTypes

BrainState ships two helpers that make intent explicit when a value must satisfy several interfaces at once or just one of several options:

  • JointTypes[A, B, ...] behaves like an intersection: a value matches only if it is an instance of every listed type.

  • OneOfTypes[A, B, ...] behaves like a union: a value matches if it is an instance of at least one listed type.

class Persistable:
    def save(self):
        raise NotImplementedError


class Visualisable:
    def plot(self):
        raise NotImplementedError


class Report(Persistable, Visualisable):
    def save(self):
        return 'saved to disk'

    def plot(self):
        return 'rendering preview'


FullFeatureType = mixin.JointTypes[Persistable, Visualisable]
OptionalNumber = mixin.OneOfTypes[int, float, type(None)]

report = Report()
print(isinstance(report, FullFeatureType))
print(isinstance(3.14, OptionalNumber), isinstance(None, OptionalNumber))
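Both behaviours can be reproduced in plain Python with a metaclass that overrides __instancecheck__, the hook that isinstance() consults. This is a hypothetical sketch (all_of and any_of are illustrative names), not brainstate's implementation of JointTypes and OneOfTypes. For the union case, the built-in isinstance(x, (int, float, type(None))) already gives the same answers.

```python
class _AllOfMeta(type):
    def __instancecheck__(cls, obj):
        # Intersection: obj must be an instance of every listed type.
        return all(isinstance(obj, t) for t in cls.types)


class _AnyOfMeta(type):
    def __instancecheck__(cls, obj):
        # Union: obj may be an instance of any listed type.
        return any(isinstance(obj, t) for t in cls.types)


def all_of(*types):
    # Build a throwaway class whose isinstance() check uses _AllOfMeta.
    return _AllOfMeta('AllOf', (), {'types': types})


def any_of(*types):
    return _AnyOfMeta('AnyOf', (), {'types': types})


class Persistable:
    def save(self):
        raise NotImplementedError


class Visualisable:
    def plot(self):
        raise NotImplementedError


class Report(Persistable, Visualisable):
    def save(self):
        return 'saved to disk'

    def plot(self):
        return 'rendering preview'


class Archive(Persistable):
    def save(self):
        return 'archived'


Full = all_of(Persistable, Visualisable)
OptionalNumber = any_of(int, float, type(None))

print(isinstance(Report(), Full), isinstance(Archive(), Full))           # True False
print(isinstance(3.14, OptionalNumber), isinstance('3.14', OptionalNumber))  # True False
```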

Mode mixins for runtime behaviour

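Mode objects capture the context in which computation happens, for example whether a model is training or processing batched inputs. Components query the active mode with mode.has(...), passing a mode class. The base Mode class is lightweight, and the built-in Training, Batching, and JointMode classes cover common runtime switches.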

class ToyPipeline:
    """A tiny module that responds to different mode configurations."""

    def __init__(self):
        self.mode: mixin.Mode = mixin.Mode()

    def set_mode(self, *modes: mixin.Mode):
        if not modes:
            self.mode = mixin.Mode()
        elif len(modes) == 1:
            self.mode = modes[0]
        else:
            self.mode = mixin.JointMode(*modes)

    def forward(self, values):
        x = jnp.asarray(values, dtype=jnp.float32)
        if self.mode.has(mixin.Training):
            x = x + 0.1  # emulate noise or dropout
        if self.mode.has(mixin.Batching):
            batch = self.mode.batch_size
            x = x.reshape((batch, -1)).mean(axis=1)
        return x


pipeline = ToyPipeline()
print('default', pipeline.forward(jnp.arange(4.0)))

pipeline.set_mode(mixin.Training())
print('training', pipeline.forward(jnp.arange(4.0)))

pipeline.set_mode(mixin.Training(), mixin.Batching(batch_size=2))
print('joint', pipeline.forward(jnp.arange(4.0)))
print('joint exposes batch size:', pipeline.mode.batch_size)

The joint mode exposes the attributes of its members, so accessing pipeline.mode.batch_size works even though the current mode is a JointMode instance.
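Attribute forwarding of this kind is commonly built on __getattr__, which Python calls only after normal attribute lookup fails. The classes below are simplified, standalone stand-ins that reuse brainstate's names for readability; they are not imported from brainstate.mixin and show one plausible design rather than the library's source.

```python
class Mode:
    def has(self, mode_type):
        # A single mode "has" a type if it is an instance of it.
        return isinstance(self, mode_type)


class Training(Mode):
    pass


class Batching(Mode):
    def __init__(self, batch_size=1):
        self.batch_size = batch_size


class JointMode(Mode):
    def __init__(self, *modes):
        self.modes = tuple(modes)

    def has(self, mode_type):
        # A joint mode has a type if any of its members does.
        return any(m.has(mode_type) for m in self.modes)

    def __getattr__(self, name):
        # Called only after normal lookup fails: forward the request to the
        # first member mode that defines the attribute.
        for m in self.__dict__.get('modes', ()):
            if hasattr(m, name):
                return getattr(m, name)
        raise AttributeError(f'{type(self).__name__} has no attribute {name!r}')


joint = JointMode(Training(), Batching(batch_size=2))
print(joint.has(Training), joint.has(Batching))  # True True
print(joint.batch_size)                          # 2
print(Mode().has(Training))                      # False
```

Returning the first match means that if two member modes define the same attribute, the earlier one wins, so the order of modes passed to JointMode matters in this design.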

Putting it together

When you combine these mixin tools you can:

  1. Add reusable behaviour (logging, validation, metrics) without disturbing core module hierarchies.

  2. Parameterise component templates and reuse them safely through descriptors.

  3. Encode clear expectations about inputs or collaborators via JointTypes/OneOfTypes.

  4. Toggle runtime semantics with mode objects instead of ad-hoc boolean flags.

Next steps

  • Audit your own modules for behaviours that could live in a mixin.

  • Wrap frequently reused constructor arguments with ParamDesc.

  • Adopt mode objects in your training scripts to centralise feature flags (e.g. evaluation vs training).

  • Explore brainstate.mixin.not_implemented to clearly mark unsupported operations.