Flatten

class brainstate.nn.Flatten(start_axis=0, end_axis=-1, in_size=None)

Flattens a contiguous range of dimensions of the input into a single dimension. Intended for use with Sequential.
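Flattening only regroups axes. It does not reorder elements. The pure-Python sketch below illustrates this, assuming row-major (C-order) layout, which NumPy and JAX `reshape` use by default. The helper name `merged_index` is hypothetical and not part of brainstate. It maps a multi-index over the merged axes to the single position that element occupies on the flattened axis.

```python
from typing import Sequence


def merged_index(sizes: Sequence[int], indices: Sequence[int]) -> int:
    """Map a multi-index over the merged axes to one index on the flattened axis.

    Hypothetical helper; assumes row-major (C-order) layout, so the last
    merged axis varies fastest.
    """
    if len(sizes) != len(indices):
        raise ValueError("sizes and indices must have the same length")
    k = 0
    for size, i in zip(sizes, indices):
        if not 0 <= i < size:
            raise IndexError(f"index {i} out of range for axis of size {size}")
        # Horner-style accumulation: equivalent to summing i * stride per axis.
        k = k * size + i
    return k
```

For merged axes of sizes `(1, 5, 5)`, the element at `(0, 2, 3)` lands at position `0 * 25 + 2 * 5 + 3 = 13` on the flattened axis.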

Shape:
  • Input: \((*, S_{\text{start}}, \ldots, S_{i}, \ldots, S_{\text{end}}, *)\), where \(S_{i}\) is the size of dimension \(i\) and \(*\) means any number of dimensions, including none.

  • Output: \((*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)\).
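The shape rule above can be written as a small function. This is a sketch of the documented formula, not brainstate's implementation. The name `flatten_shape` is hypothetical, negative axes are assumed to count from the end as in Python indexing, and raising `ValueError` for a reversed or out-of-range axis pair is a choice made for this sketch.

```python
from math import prod
from typing import Sequence, Tuple


def flatten_shape(shape: Sequence[int], start_axis: int = 0, end_axis: int = -1) -> Tuple[int, ...]:
    """Shape obtained by merging axes start_axis..end_axis (inclusive) into one.

    Hypothetical helper implementing the Input/Output shape rule.
    """
    ndim = len(shape)
    # Negative axes count from the end, as in Python indexing (assumption).
    start = start_axis + ndim if start_axis < 0 else start_axis
    end = end_axis + ndim if end_axis < 0 else end_axis
    if not 0 <= start <= end < ndim:
        raise ValueError(f"invalid axis range ({start_axis}, {end_axis}) for a {ndim}-d input")
    # The merged axis has size prod(S_start, ..., S_end); other axes are kept as-is.
    merged = prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])
```

With the input shape `(32, 1, 5, 5)`, the defaults give `(800,)`, `start_axis=1` gives `(32, 25)`, and `(0, 2)` gives `(160, 5)`.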

Parameters:
  • start_axis (int) – First dimension to flatten (default = 0).

  • end_axis (int) – Last dimension to flatten, inclusive (default = -1, i.e. the last dimension).

  • in_size (int | Sequence[int] | integer | Sequence[integer] | None) – The shape of the input tensor.

Examples

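>>> import brainstate
>>> inp = brainstate.random.randn(32, 1, 5, 5)
>>> # With default parameters, every dimension is flattened
>>> m = brainstate.nn.Flatten()
>>> output = m(inp)
>>> output.shape
(800,)
>>> # Keep the leading (batch) dimension and flatten the rest
>>> m = brainstate.nn.Flatten(start_axis=1)
>>> output = m(inp)
>>> output.shape
(32, 25)
>>> # Flatten dimensions 0 through 2
>>> m = brainstate.nn.Flatten(0, 2)
>>> output = m(inp)
>>> output.shape
(160, 5)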