Batching

class brainstate.mixin.Batching(batch_size=1, batch_axis=0)

Mode indicating batched computation.

This mode indicates that computations operate on batches of data. It records the batch size and the axis that holds the batch dimension.

Parameters:
  • batch_size (int, default 1) – The number of samples in each batch.

  • batch_axis (int, default 0) – The index of the axis that holds the batch dimension.
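One way to read the two parameters together: an input array's extent along `batch_axis` is expected to equal `batch_size`. The sketch below checks that relationship. It assumes only an object exposing `batch_size` and `batch_axis` attributes (a `Batching` instance would qualify), and it uses a `SimpleNamespace` stand-in so it runs without brainstate. `check_batched_shape` is a hypothetical helper, not part of the brainstate API.

```python
from types import SimpleNamespace
from typing import Sequence


def check_batched_shape(shape: Sequence[int], mode) -> bool:
    """Hypothetical helper: return True if ``shape`` matches ``mode``.

    ``mode`` is assumed to expose ``batch_size`` and ``batch_axis``
    attributes, as a Batching instance does.
    """
    ndim = len(shape)
    axis = mode.batch_axis
    # Normalise negative axes (e.g. -1 means the last dimension).
    if axis < 0:
        axis += ndim
    if not 0 <= axis < ndim:
        raise ValueError(
            f"batch_axis {mode.batch_axis} is out of range for shape {tuple(shape)}"
        )
    # The extent along the batch axis should equal the batch size.
    return shape[axis] == mode.batch_size


# Stand-in object with the same two attributes (no brainstate needed).
mode = SimpleNamespace(batch_size=32, batch_axis=0)
print(check_batched_shape((32, 100), mode))  # True
print(check_batched_shape((100, 32), mode))  # False
```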

batch_size

The number of samples in each batch.

Type:

int

batch_axis

The axis index representing the batch dimension.

Type:

int
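If `batch_axis` is taken to index the full (batched) shape, the shape of a batched input can be derived from a per-sample shape by inserting `batch_size` at position `batch_axis`. The helper below is a hypothetical sketch of that bookkeeping. It follows the axis conventions of NumPy's `expand_dims` (negative axes count from the end of the result) and is not a brainstate function.

```python
def batched_shape(sample_shape, batch_size, batch_axis=0):
    """Hypothetical helper: insert a batch dimension into a per-sample shape.

    ``batch_axis`` indexes the *resulting* shape, so it may range from
    ``-(len(sample_shape) + 1)`` to ``len(sample_shape)``.
    """
    shape = list(sample_shape)
    ndim = len(shape) + 1  # rank after adding the batch dimension
    axis = batch_axis + ndim if batch_axis < 0 else batch_axis
    if not 0 <= axis < ndim:
        raise ValueError(f"batch_axis {batch_axis} out of range for rank {ndim}")
    shape.insert(axis, batch_size)
    return tuple(shape)


print(batched_shape((100,), batch_size=64))                    # (64, 100)
print(batched_shape((20, 100), batch_size=64, batch_axis=1))   # (20, 64, 100)
print(batched_shape((20, 100), batch_size=64, batch_axis=-1))  # (20, 100, 64)
```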

Examples

Basic batching configuration:

>>> import brainstate
>>>
>>> # Create a batching mode
>>> batching = brainstate.mixin.Batching(batch_size=32, batch_axis=0)
>>> print(batching)  # Batching(in_size=32, axis=0)
>>>
>>> # Access batch parameters
>>> print(f"Processing {batching.batch_size} samples at once")
>>> print(f"Batch dimension is axis {batching.batch_axis}")
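A non-zero `batch_axis` is useful when data is not stored batch-first. Time-major sequences of shape `(time, batch, features)` are one example, which would use `batch_axis=1`. A common pattern, sketched here with NumPy under that assumed layout, is to move the batch axis to the front before per-sample processing. `to_batch_first` is a hypothetical helper, not part of brainstate.

```python
import numpy as np


def to_batch_first(x: np.ndarray, batch_axis: int) -> np.ndarray:
    """Hypothetical helper: move the batch dimension to axis 0."""
    return np.moveaxis(x, batch_axis, 0)


# Time-major data: 10 time steps, a batch of 32, 8 features.
x = np.zeros((10, 32, 8))
y = to_batch_first(x, batch_axis=1)
print(y.shape)  # (32, 10, 8)
```

With a `Batching` mode, `batching.batch_axis` would be passed as the second argument.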

Using in a model:

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> class BatchedModel:
...     def __init__(self):
...         self.mode = None
...
...     def set_batch_mode(self, batch_size, batch_axis=0):
...         self.mode = brainstate.mixin.Batching(batch_size, batch_axis)
...
...     def process(self, x):
...         if self.mode is not None and self.mode.has(brainstate.mixin.Batching):
...             # Average over the batch dimension (placeholder computation)
...             axis = self.mode.batch_axis
...             return jnp.mean(x, axis=axis, keepdims=True)
...         return x
>>>
>>> model = BatchedModel()
>>> model.set_batch_mode(batch_size=64)
>>>
>>> # Process batched data: 64 samples, 100 features
>>> data = jax.random.normal(jax.random.PRNGKey(0), (64, 100))
>>> result = model.process(data)
>>> result.shape
(1, 100)

Combining with other modes:

>>> # Combine batching with training mode
>>> training = brainstate.mixin.Training()
>>> batching = brainstate.mixin.Batching(batch_size=128)
>>> combined = brainstate.mixin.JointMode(training, batching)
>>>
>>> # Use in a training loop
>>> def train_step(model, data, mode):
...     if mode.has(brainstate.mixin.Batching):
...         # Split data into batches of mode.batch_size samples
...         batch_size = mode.batch_size
...         ...  # batched processing goes here
...     if mode.has(brainstate.mixin.Training):
...         # Apply training-specific operations
...         ...  # training logic goes here
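The `train_step` example above leaves the splitting step as a comment. One straightforward way to fill it in is to slice consecutive chunks of `batch_size` along `batch_axis`, assuming the data is a NumPy-compatible array and that a final, smaller batch is acceptable. `iter_batches` is a hypothetical helper, not part of brainstate.

```python
import numpy as np


def iter_batches(data: np.ndarray, batch_size: int, batch_axis: int = 0):
    """Hypothetical helper: yield consecutive slices of ``data`` along ``batch_axis``.

    The last batch is smaller when the axis length is not a multiple of
    ``batch_size``.
    """
    n = data.shape[batch_axis]
    for start in range(0, n, batch_size):
        stop = min(start + batch_size, n)
        # np.take with an index range keeps the sliced axis in place.
        yield np.take(data, np.arange(start, stop), axis=batch_axis)


data = np.arange(10 * 3).reshape(10, 3)  # 10 samples, 3 features
shapes = [b.shape for b in iter_batches(data, batch_size=4)]
print(shapes)  # [(4, 3), (4, 3), (2, 3)]
```

Inside `train_step`, this would be called as `iter_batches(data, mode.batch_size, ...)`, reading the batch size from the combined mode in the same way the example already does.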