Unflatten

class brainstate.nn.Unflatten(axis, sizes, name=None, in_size=None)

Unflatten one axis of a tensor, expanding it into a desired shape. Intended for use with Sequential.

  • axis specifies the dimension of the input tensor to be unflattened, given as an int.

  • sizes is the new shape of the unflattened dimension, given as an integer or a sequence of integers (for example a tuple or list of ints). The product of its entries must equal the size of the input along axis.
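The shape rule can be checked without building the module. The helper below is a hypothetical pure-Python sketch (unflattened_shape is not part of brainstate) that computes the output shape from an input shape, axis and sizes. It assumes NumPy-style negative axis indexing, which the signature above does not confirm.

```python
import math
from typing import Sequence, Tuple


def unflattened_shape(in_shape: Sequence[int], axis: int,
                      sizes: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: shape after splitting in_shape[axis] into sizes."""
    ndim = len(in_shape)
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of range for a {ndim}-d input")
    axis %= ndim  # assumption: negative axes count from the end, as in NumPy
    if math.prod(sizes) != in_shape[axis]:
        raise ValueError(
            f"sizes {tuple(sizes)} multiply to {math.prod(sizes)}, "
            f"but dimension {axis} has size {in_shape[axis]}"
        )
    # Keep the leading and trailing dimensions, replace axis by sizes.
    return tuple(in_shape[:axis]) + tuple(sizes) + tuple(in_shape[axis + 1:])


print(unflattened_shape((2, 12, 5), 1, (3, 4)))   # (2, 3, 4, 5)
print(unflattened_shape((2, 12), -1, (2, 2, 3)))  # (2, 2, 2, 3)
```

A sizes value whose entries do not multiply to the size of the chosen axis is rejected, mirroring the product condition in the description.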

Shape:
  • Input: \((*, S_{\text{axis}}, *)\), where \(S_{\text{axis}}\) is the size of the input along axis and \(*\) means any number of dimensions, including none.

  • Output: \((*, U_1, \ldots, U_n, *)\), where \((U_1, \ldots, U_n)\) = sizes and \(\prod_{i=1}^n U_i = S_{\text{axis}}\).
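For intuition, the Shape rule can be reproduced with a plain row-major reshape. This is a minimal NumPy sketch, not brainstate's implementation: the unflatten function is hypothetical, and it assumes that unflattening regroups elements in C order, as a reshape does.

```python
import numpy as np


def unflatten(x: np.ndarray, axis: int, sizes) -> np.ndarray:
    """Hypothetical sketch: split dimension `axis` of x into `sizes`."""
    axis = axis % x.ndim  # assumption: accept negative axes, NumPy-style
    sizes = tuple(sizes)
    if int(np.prod(sizes)) != x.shape[axis]:
        raise ValueError(f"prod{sizes} != {x.shape[axis]}")
    new_shape = x.shape[:axis] + sizes + x.shape[axis + 1:]
    # A C-order reshape keeps element order, so only dimension `axis` is regrouped.
    return x.reshape(new_shape)


x = np.arange(2 * 12).reshape(2, 12)    # (*, S_axis) with S_axis = 12
y = unflatten(x, axis=1, sizes=(3, 4))  # (*, U_1, U_2) with U_1 * U_2 = 12
print(y.shape)                          # (2, 3, 4)
print(y[1, 2, 3] == x[1, 2 * 4 + 3])    # True
```

Because only one dimension is split, the other dimensions and the order of elements are unchanged: index (i, j, k) of the output maps to index (i, j * U_2 + k) of the input.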

Parameters:
  • axis (int) – Dimension to be unflattened.

  • sizes (int | Sequence[int] | integer | Sequence[integer]) – New shape of the unflattened dimension.

  • name (str) – The name of the module.

  • in_size (int | Sequence[int] | integer | Sequence[integer] | None) – The shape of the input tensor.