Flatten
- class brainstate.nn.Flatten(start_axis=0, end_axis=-1, in_size=None)
Flattens a contiguous range of dimensions of a tensor into a single dimension. For use with Sequential.
- Shape:
Input: \((*, S_{\text{start}}, ..., S_{i}, ..., S_{\text{end}}, *)\), where \(S_{i}\) is the size at dimension \(i\) and \(*\) means any number of dimensions including none.
Output: \((*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)\).
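- Parameters:
start_axis – first axis to flatten (default: 0).
end_axis – last axis to flatten, inclusive (default: -1).
in_size – size of the input, optional (default: None).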
Examples
>>> import brainstate
>>> inp = brainstate.random.randn(32, 1, 5, 5)
>>> # With default parameters (start_axis=0, end_axis=-1), every axis is flattened
>>> m = brainstate.nn.Flatten()
>>> output = m(inp)
>>> output.shape
(800,)
>>> # With non-default parameters
>>> m = brainstate.nn.Flatten(0, 2)
>>> output = m(inp)
>>> output.shape
(160, 5)
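The shape rule given under Shape can be checked without running the layer. The following is a minimal sketch in plain Python, assuming that the axis range is inclusive and that negative axes count from the end. `flattened_shape` is a hypothetical helper written for illustration and is not part of brainstate.

```python
from math import prod


def flattened_shape(shape, start_axis=0, end_axis=-1):
    """Hypothetical helper: shape after merging axes start_axis..end_axis (inclusive)."""
    ndim = len(shape)
    # Map negative axes to their non-negative equivalents.
    start = start_axis + ndim if start_axis < 0 else start_axis
    end = end_axis + ndim if end_axis < 0 else end_axis
    if not (0 <= start <= end < ndim):
        raise ValueError(
            f"invalid axis range ({start_axis}, {end_axis}) for {ndim} dimensions"
        )
    # Leading axes are kept, the selected range collapses into one axis
    # whose size is the product of the merged sizes, trailing axes are kept.
    return tuple(shape[:start]) + (prod(shape[start:end + 1]),) + tuple(shape[end + 1:])


print(flattened_shape((32, 1, 5, 5), 0, 2))   # (160, 5)
print(flattened_shape((32, 1, 5, 5), 1, -1))  # (32, 25)
```

Passing `start_axis=1` keeps the leading axis intact, which is the usual choice when that axis indexes a batch.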