DropoutFixed

class brainstate.nn.DropoutFixed(in_size, prob=0.5, name=None)

A dropout layer with a fixed dropout mask along the time axis.

In training, each element is kept with probability prob and zeroed otherwise. To compensate for the fraction of input values dropped, all surviving values are multiplied by 1 / prob.
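
The rescaling can be checked with a minimal pure-Python sketch of inverted dropout. It assumes `keep_prob` is the probability of keeping an element (the meaning the parameter list gives to `prob`). The function `inverted_dropout` is hypothetical and is not part of brainstate. A survivor is scaled by `1 / keep_prob` and survives with probability `keep_prob`, so the expected value of each element is unchanged.

```python
import random


def inverted_dropout(x, keep_prob, rng):
    """Hypothetical helper: keep each element with probability keep_prob and scale it by 1 / keep_prob."""
    if not 0.0 < keep_prob <= 1.0:
        raise ValueError("keep_prob must be in (0, 1]")
    return [xi / keep_prob if rng.random() < keep_prob else 0.0 for xi in x]


rng = random.Random(0)
x = [1.0] * 100_000
y = inverted_dropout(x, keep_prob=0.8, rng=rng)

# Surviving values become 1 / 0.8 = 1.25 and dropped values become 0.
print(sorted(set(y)))  # [0.0, 1.25]
# The empirical mean stays close to the original mean of 1.0.
print(sum(y) / len(y))
```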

This layer is active only during training, i.e. when the `fit` flag of the `brainstate.environ` context is set, as in the example below (`brainstate.environ.context(fit=True)`). Otherwise it is a no-op and returns its input unchanged.

This kind of dropout is particularly useful for spiking neural networks (SNNs), where the same dropout mask must be applied across all time steps within a single mini-batch iteration.

Parameters:
  • in_size (int | Sequence[int] | integer | Sequence[integer]) – The shape of a single input sample, excluding the batch dimension.

  • prob (float) – Probability of keeping each element of the tensor. Default is 0.5.

  • name (str | None) – The name of the dynamic system.
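
To make the parameters and the training/no-op behavior concrete, here is a hypothetical pure-Python stand-in, `FixedMaskDropout`, written for illustration only. It is not brainstate's implementation. It assumes a 1-D `in_size`, treats `prob` as the keep probability, samples one Bernoulli mask of shape `(batch_size, in_size)` in `init_state`, and replaces the brainstate `fit` context with an explicit `fit` argument.

```python
import random


class FixedMaskDropout:
    """Hypothetical stand-in for DropoutFixed: one Bernoulli mask per init_state call."""

    def __init__(self, in_size, prob=0.5, seed=None):
        if not 0.0 < prob <= 1.0:
            raise ValueError("prob (keep probability) must be in (0, 1]")
        self.in_size = in_size  # features per sample (1-D only in this sketch)
        self.prob = prob
        self.rng = random.Random(seed)
        self.mask = None

    def init_state(self, batch_size):
        # Sample the mask once; shape (batch_size, in_size). True means "keep".
        self.mask = [[self.rng.random() < self.prob for _ in range(self.in_size)]
                     for _ in range(batch_size)]

    def update(self, x, fit):
        if not fit:
            return x  # outside training the layer is a no-op
        if (self.mask is None or len(x) != len(self.mask)
                or any(len(row) != self.in_size for row in x)):
            raise ValueError("input shape must be (batch_size, in_size) as set by init_state")
        # Reuse the stored mask and rescale survivors by 1 / prob.
        return [[v / self.prob if keep else 0.0 for v, keep in zip(row, mrow)]
                for row, mrow in zip(x, self.mask)]


layer = FixedMaskDropout(in_size=20, prob=0.8, seed=0)
layer.init_state(batch_size=10)
x_t1 = [[1.0] * 20 for _ in range(10)]  # input at time step 1
x_t2 = [[2.0] * 20 for _ in range(10)]  # input at time step 2
out1 = layer.update(x_t1, fit=True)
out2 = layer.update(x_t2, fit=True)

# The same positions are zeroed at both time steps.
same_zeros = all((a == 0.0) == (b == 0.0)
                 for r1, r2 in zip(out1, out2) for a, b in zip(r1, r2))
print(same_zeros)  # True
print(layer.update(x_t1, fit=False) is x_t1)  # True
```

Because the mask is drawn in `init_state` rather than in `update`, consecutive `update` calls (time steps) zero exactly the same positions.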

Notes

As described in [2], dropout is applied slightly differently in SNNs than in ANNs. In ANNs, each training epoch consists of several mini-batch iterations. In each iteration, a randomly selected fraction \(p\) of the units (the dropout ratio) is disconnected from the network, and the remaining units are weighted according to the keep probability \(1-p\). Note that \(p\) here is the drop ratio, so the prob parameter of this layer corresponds to \(1-p\).

In SNNs, however, each iteration involves more than one forward pass: one per time step of the spike train. In the training method of [2], the output error is back-propagated and the network parameters are updated only at the last time step. For dropout to be effective there, the set of connected units must not change within a mini-batch iteration, so that the network consists of the same random subset of units during every forward pass of that iteration.
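
The schedule described above, one mask per mini-batch iteration shared by all of its time steps, can be sketched as follows. `run_iterations` is a hypothetical helper that only records which mask each forward pass would use; the network, the loss, and the back-propagation step are assumed and omitted.

```python
import random


def run_iterations(n_iters, n_steps, n_units, keep_prob, seed=0):
    """Hypothetical helper: return the mask used at every (iteration, time step)."""
    rng = random.Random(seed)
    history = []
    for _ in range(n_iters):
        # Draw a new random subset of units for this mini-batch iteration...
        mask = tuple(rng.random() < keep_prob for _ in range(n_units))
        # ...and reuse it unchanged for every forward pass of the spike train.
        history.append([mask for _ in range(n_steps)])
        # In a real SNN loop, the error would be back-propagated here, after the last step.
    return history


history = run_iterations(n_iters=3, n_steps=5, n_units=8, keep_prob=0.5)
print(all(len(set(steps)) == 1 for steps in history))  # True: fixed within each iteration
```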

If instead the set of connected units were re-sampled at every time step, the effect of dropout would be averaged out over the forward passes within an iteration, and it would fade out by the time the output error is back-propagated and the parameters are updated at the last time step. The set of randomly connected units must therefore be kept fixed for the entire time window of an iteration.
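
A rough numerical illustration of the averaging argument follows. It assumes, as a simplification, that what matters for a unit over the time window is its time-averaged dropout gain (0 when dropped, `1 / keep_prob` when kept), which ignores membrane dynamics. `time_averaged_gain` is a hypothetical helper.

```python
import random


def time_averaged_gain(n_units, n_steps, keep_prob, resample_each_step, seed=0):
    """Hypothetical helper: per-unit dropout gain averaged over the time window."""
    rng = random.Random(seed)
    totals = [0.0] * n_units
    mask = [rng.random() < keep_prob for _ in range(n_units)]
    for _ in range(n_steps):
        if resample_each_step:
            mask = [rng.random() < keep_prob for _ in range(n_units)]
        for i, keep in enumerate(mask):
            totals[i] += (1.0 / keep_prob) if keep else 0.0
    return [t / n_steps for t in totals]


fixed = time_averaged_gain(1000, 200, 0.5, resample_each_step=False)
per_step = time_averaged_gain(1000, 200, 0.5, resample_each_step=True)

# Fixed mask: each unit is fully off (0.0) or fully on (2.0) for the whole window.
print(sorted(set(fixed)))  # [0.0, 2.0]
# Per-step masks: every unit's average gain drifts toward 1.0.
print(max(abs(g - 1.0) for g in per_step))
```

With a fixed mask each unit stays fully on or fully off for the whole window. Per-step resampling instead pushes every unit's average gain toward 1, which is the same as applying no dropout.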

References

Examples

>>> import brainstate
>>> layer = brainstate.nn.DropoutFixed(in_size=(20,), prob=0.8)
>>> layer.init_state(batch_size=10)
>>> x = brainstate.random.randn(10, 20)
>>> with brainstate.environ.context(fit=True):
...     output = layer.update(x)
>>> output.shape
(10, 20)
init_state(batch_size=None, **kwargs)

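State initialization function. Call it with the batch size before calling update (as in the example above); the dropout mask that is reused across time steps is part of this state.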