Batching
- class brainstate.mixin.Batching(batch_size=1, batch_axis=0)
Mode indicating batched computation.
This mode specifies that computations are performed on batches of data. It records the batch size and which axis represents the batch dimension.
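- Parameters:
  - batch_size (int) – Number of samples in a batch. Defaults to 1.
  - batch_axis (int) – Index of the axis that represents the batch dimension. Defaults to 0.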
Examples
Basic batching configuration:
>>> import brainstate
>>>
>>> # Create a batching mode
>>> batching = brainstate.mixin.Batching(batch_size=32, batch_axis=0)
>>> print(batching)  # Batching(in_size=32, axis=0)
>>>
>>> # Access batch parameters
>>> print(f"Processing {batching.batch_size} samples at once")
>>> print(f"Batch dimension is axis {batching.batch_axis}")
Using in a model:
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> class BatchedModel:
...     def __init__(self):
...         self.mode = None
...
...     def set_batch_mode(self, batch_size, batch_axis=0):
...         self.mode = brainstate.mixin.Batching(batch_size, batch_axis)
...
...     def process(self, x):
...         if self.mode is not None and self.mode.has(brainstate.mixin.Batching):
...             # Average over the batch axis
...             batch_size = self.mode.batch_size
...             axis = self.mode.batch_axis
...             return jnp.mean(x, axis=axis, keepdims=True)
...         return x
>>>
>>> model = BatchedModel()
>>> model.set_batch_mode(batch_size=64)
>>>
>>> # Process batched data
>>> data = jax.random.normal(jax.random.PRNGKey(0), (64, 100))  # 64 samples, 100 features
>>> result = model.process(data)
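The model above reduces over the batch axis. A related pattern is to apply a per-sample function along the batch axis with `jax.vmap`. The following is a minimal sketch that uses only JAX: `per_sample_norm` and `apply_over_batch` are hypothetical helpers, not part of brainstate, and `batch_axis` is passed as a plain integer such as the one stored on a `Batching` instance.

```python
import jax
import jax.numpy as jnp


def per_sample_norm(x):
    # Hypothetical per-sample function: L2 norm of one feature vector.
    return jnp.sqrt(jnp.sum(x ** 2))


def apply_over_batch(fn, data, batch_axis=0):
    # Hypothetical helper: map `fn` over the batch axis of `data`.
    return jax.vmap(fn, in_axes=batch_axis)(data)


# Batch on axis 0: 64 samples, 100 features each.
data = jnp.ones((64, 100))
norms = apply_over_batch(per_sample_norm, data, batch_axis=0)
print(norms.shape)  # (64,)

# Same data with the batch on axis 1 instead.
data_t = data.T  # shape (100, 64)
norms_t = apply_over_batch(per_sample_norm, data_t, batch_axis=1)
print(norms_t.shape)  # (64,)
```

In both calls the result has one entry per sample, regardless of where the batch dimension sits in the input.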
Combining with other modes:
>>> # Combine batching with training mode
>>> training = brainstate.mixin.Training()
>>> batching = brainstate.mixin.Batching(batch_size=128)
>>> combined = brainstate.mixin.JointMode(training, batching)
>>>
>>> # Use in a training loop
>>> def train_step(model, data, mode):
...     if mode.has(brainstate.mixin.Batching):
...         # Split data into batches
...         batch_size = mode.batch_size
...         # ... batched processing ...
...
...     if mode.has(brainstate.mixin.Training):
...         # Apply training-specific operations
...         # ... training logic ...
...         pass
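The `# Split data into batches` placeholder in `train_step` could be filled in along the following lines. This is a minimal sketch using plain `jax.numpy` slicing: `iter_batches` is a hypothetical helper, not part of brainstate, and it takes `batch_size` and `batch_axis` as plain integers such as those stored on a `Batching` instance.

```python
import jax.numpy as jnp


def iter_batches(data, batch_size, batch_axis=0):
    """Yield consecutive slices of `data` along `batch_axis` (hypothetical helper)."""
    n = data.shape[batch_axis]
    for start in range(0, n, batch_size):
        stop = min(start + batch_size, n)
        # Build an index that slices only the batch axis and keeps the others whole.
        index = [slice(None)] * data.ndim
        index[batch_axis] = slice(start, stop)
        yield data[tuple(index)]


# 300 samples with 4 features each, split into batches of 128.
data = jnp.arange(300 * 4).reshape(300, 4)
shapes = [batch.shape for batch in iter_batches(data, batch_size=128)]
print(shapes)  # [(128, 4), (128, 4), (44, 4)]
```

The final batch is smaller when the number of samples is not a multiple of `batch_size`. Code that requires a fixed batch shape would need to pad or drop that remainder.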