ConvTranspose1d

class brainstate.nn.ConvTranspose1d(in_size, out_channels, kernel_size, stride=1, padding='SAME', lhs_dilation=1, rhs_dilation=1, groups=1, w_init=XavierNormal(scale=1.0, mode='fan_avg', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([900 9244]), unit=Unit("1")), b_init=None, w_mask=None, channel_first=False, name=None, param_type=<class 'brainstate.ParamState'>)

One-dimensional transposed convolution layer (sometimes called deconvolution).

Applies a 1D transposed convolution over an input signal. Transposed convolution is used for upsampling: it maps an input back to the spatial size that a regular convolution with the same kernel and stride would have reduced it from. It restores the shape, not the original values. It’s commonly used in autoencoders, GANs, and semantic segmentation networks.

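By default (channel_first=False), the input is a 3D array of shape [B, L, C], where B is the batch size, L is the sequence length, and C is the number of input channels (channels-last format). With channel_first=True the expected shape is [B, C, L]. The batch dimension may be omitted.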
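Channels-first data can also be rearranged into the default layout instead of setting channel_first=True. A minimal NumPy sketch (the shapes match the examples below; jnp.moveaxis behaves the same way on JAX arrays):

```python
import numpy as np

# A channels-first (PyTorch-style) batch: [B, C, L] = [2, 16, 28]
x_cf = np.arange(2 * 16 * 28, dtype=np.float32).reshape(2, 16, 28)

# Move the channel axis to the end: [B, C, L] -> [B, L, C]
x_cl = np.moveaxis(x_cf, 1, -1)
print(x_cl.shape)  # (2, 28, 16)

# Unbatched sample: [C, L] -> [L, C]
x_single = np.moveaxis(x_cf[0], 0, -1)
print(x_single.shape)  # (28, 16)
```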

Parameters:
  • in_size (Sequence[int]) – The input shape without the batch dimension. For ConvTranspose1d this is (L, C), where L is the sequence length and C is the number of input channels, or (C, L) when channel_first=True.

  • out_channels (int) – The number of output channels (feature maps) produced by the transposed convolution.

  • kernel_size (int | Tuple[int, ...]) – The shape of the convolutional kernel. For 1D, can be an integer or a single-element tuple.

  • stride (int | Tuple[int, ...]) – The stride of the transposed convolution. Larger strides produce longer outputs, the opposite of regular convolution, where larger strides shrink the output. Default: 1.

  • padding (str | int | Tuple[int, int] | Sequence[Tuple[int, int]]) –

    The padding strategy. Options:

    • 'SAME': output length approximately equals input_length * stride

    • 'VALID': no padding, maximum output size

    • int: symmetric padding (the same amount on both ends of the sequence)

    • (pad_before, pad_after): explicit padding for the sequence dimension

    Default: 'SAME'.

  • lhs_dilation (int | Tuple[int, ...]) – The dilation factor for the input: a factor of d inserts d - 1 zeros between adjacent input elements. For transposed convolution, this is typically set equal to stride internally. Default: 1.

  • rhs_dilation (int | Tuple[int, ...]) – The dilation factor for the kernel. Inserting zeros between kernel elements enlarges the receptive field without adding parameters. Default: 1.

  • groups (int) – Number of groups for grouped transposed convolution. Both in_channels and out_channels must be divisible by groups. Default: 1.

  • w_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity) – The initializer for the convolutional kernel weights. Default: XavierNormal().

  • b_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity | None) – The initializer for the bias. If None, no bias is added. Default: None.

  • w_mask (Callable | Array | ndarray | bool | number | int | float | complex | Quantity | None) – An optional mask applied to the weights during the forward pass. Default: None.

  • channel_first (bool) – If True, uses channels-first format (e.g., [B, C, L]). If False, uses channels-last format (e.g., [B, L, C]). Default: False (channels-last, JAX convention).

  • name (str) – The name of the module. Default: None.

  • param_type (type) – The type of parameter state to use. Default: ParamState.
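To build intuition for stride and kernel_size, here is a minimal NumPy sketch of a transposed convolution written as a scatter-add: each input position is multiplied by the kernel and added into the output starting at offset i * stride, so overlapping kernel footprints sum. It is not this layer's implementation: it omits padding, bias, dilation, and groups, it does not flip the kernel (libraries differ on this convention), and the function name conv_transpose1d_ref is hypothetical.

```python
import numpy as np


def conv_transpose1d_ref(x, w, stride=1):
    """Hypothetical reference 1D transposed convolution via scatter-add.

    x: input of shape (L, C_in), channels-last
    w: kernel of shape (K, C_in, C_out)
    returns: array of shape ((L - 1) * stride + K, C_out), no cropping applied
    """
    length, c_in = x.shape
    k, w_in, c_out = w.shape
    assert c_in == w_in, "input channels must match the kernel"
    out = np.zeros(((length - 1) * stride + k, c_out), dtype=x.dtype)
    for i in range(length):
        # Input position i writes to output positions i*stride ... i*stride + K - 1;
        # for each kernel tap t, x[i] (C_in,) is mapped through w[t] (C_in, C_out).
        out[i * stride : i * stride + k] += np.einsum("c,kco->ko", x[i], w)
    return out


x = np.ones((3, 1), dtype=np.float32)     # L=3, C_in=1
w = np.ones((4, 1, 1), dtype=np.float32)  # K=4, C_in=1, C_out=1
y = conv_transpose1d_ref(x, w, stride=2)
print(y[:, 0])  # [1. 1. 2. 2. 2. 2. 1. 1.]
```

With stride=2 the three input samples land two positions apart, their length-4 footprints overlap in the middle, and the output grows to (L - 1) * stride + K = 8.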

in_size

The input shape without the batch dimension: (L, C), or (C, L) when channel_first=True.

Type:

tuple of int

out_size

The output shape (L_out, out_channels) without the batch dimension.

Type:

tuple of int

in_channels

Number of input channels.

Type:

int

out_channels

Number of output channels.

Type:

int

kernel_size

Size of the convolving kernel.

Type:

tuple of int

weight

The learnable weights (and bias if specified) of the module.

Type:

ParamState

Examples

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Create a 1D transposed convolution layer for upsampling
>>> conv_transpose = brainstate.nn.ConvTranspose1d(
...     in_size=(28, 16),
...     out_channels=8,
...     kernel_size=4,
...     stride=2
... )
>>>
>>> # Apply to input: batch_size=2, length=28, channels=16
>>> x = jnp.ones((2, 28, 16))
>>> y = conv_transpose(x)
>>> print(y.shape)  # length axis is upsampled by the stride (about 28 * 2 with 'SAME')
>>>
>>> # Without batch dimension
>>> x_single = jnp.ones((28, 16))
>>> y_single = conv_transpose(x_single)
>>>
>>> # Channels-first format (PyTorch style)
>>> conv_transpose = brainstate.nn.ConvTranspose1d(
...     in_size=(16, 28),
...     out_channels=8,
...     kernel_size=4,
...     stride=2,
...     channel_first=True
... )
>>> # Input layout: batch_size=2, channels=16, length=28
>>> x = jnp.ones((2, 16, 28))
>>> y = conv_transpose(x)

Notes

Output dimensions:

Unlike regular convolution, which shrinks or preserves the spatial size, transposed convolution enlarges it when stride > 1:

  • output_length ≈ input_length * stride (the exact value depends on padding, kernel size, and kernel dilation)
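The exact length can be computed from the padding rules used by jax.lax.conv_transpose for 'SAME' and 'VALID'. A minimal sketch, assuming ConvTranspose1d follows those same rules (this is an assumption, not something stated above); the helper name conv_transpose_out_len is hypothetical:

```python
def conv_transpose_out_len(length, kernel_size, stride=1, padding="SAME", rhs_dilation=1):
    """Hypothetical helper: 1D transposed-convolution output length.

    Mirrors the 'SAME'/'VALID' padding rules of jax.lax.conv_transpose;
    that ConvTranspose1d uses identical rules is an assumption.
    """
    # Effective kernel size after inserting rhs_dilation - 1 zeros between taps
    k_eff = (kernel_size - 1) * rhs_dilation + 1
    if padding == "SAME":
        return length * stride
    if padding == "VALID":
        # Full output: overlapping footprints extend past length * stride when k_eff > stride
        return length * stride + max(k_eff - stride, 0)
    raise ValueError("padding must be 'SAME' or 'VALID'")


print(conv_transpose_out_len(28, 4, stride=2, padding="SAME"))   # 56
print(conv_transpose_out_len(28, 4, stride=2, padding="VALID"))  # 58
```

Under that assumption, the first example in this page (length 28, kernel_size=4, stride=2, default 'SAME') would produce an output of length 56.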

Relationship to regular convolution:

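Transposed convolution computes the gradient of a regular convolution with respect to its input; equivalently, it multiplies by the transpose of the matrix that represents the convolution. It’s sometimes called “deconvolution”, but the term is mathematically imprecise: deconvolution properly means inverting a convolution, which this operation does not do.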
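The gradient relationship can be checked directly with JAX primitives. This sketch uses jax.lax.conv_general_dilated and jax.lax.conv_transpose rather than ConvTranspose1d itself, because how the layer stores and flips its kernel is not specified here. With transpose_kernel=True, JAX's transposed convolution should match the input gradient of a strided convolution that uses the same kernel:

```python
import jax
import jax.numpy as jnp
import numpy as np

dn = ("NWC", "WIO", "NWC")  # channels-last: batch, length, channels
key_x, key_w, key_g = jax.random.split(jax.random.PRNGKey(0), 3)

x = jax.random.normal(key_x, (1, 8, 2))  # forward input: L=8, 2 channels
w = jax.random.normal(key_w, (4, 2, 3))  # kernel: K=4, in=2, out=3


def conv(inp):
    # Regular strided convolution: length 8 -> 3, channels 2 -> 3
    return jax.lax.conv_general_dilated(
        inp, w, window_strides=(2,), padding="VALID", dimension_numbers=dn
    )


y, conv_vjp = jax.vjp(conv, x)
g = jax.random.normal(key_g, y.shape)  # an arbitrary cotangent, shape (1, 3, 3)
(grad_x,) = conv_vjp(g)                # gradient w.r.t. the input, shape (1, 8, 2)

# Transposed convolution of g with the same kernel: length 3 -> 8, channels 3 -> 2
y_t = jax.lax.conv_transpose(
    g, w, strides=(2,), padding="VALID", dimension_numbers=dn, transpose_kernel=True
)

print(y.shape, y_t.shape)                             # (1, 3, 3) (1, 8, 2)
print(np.allclose(grad_x, y_t, rtol=1e-4, atol=1e-4))  # True
```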

Common use cases:

  • Upsampling in encoder-decoder architectures

  • Generative models (GANs, VAEs)

  • Semantic segmentation (U-Net, FCN)

  • Super-resolution networks

Comparison with PyTorch:

  • PyTorch uses channels-first by default; BrainState uses channels-last

  • Set channel_first=True for PyTorch-compatible format

  • PyTorch’s output_padding is handled through the padding parameter
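When porting a PyTorch model it helps to know which output length to reproduce. PyTorch documents the output length of torch.nn.ConvTranspose1d as L_out = (L_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1. The sketch below only evaluates that formula (the helper name is hypothetical); choosing this layer's padding to hit the same length is left to you.

```python
def torch_conv_transpose1d_out_len(l_in, kernel_size, stride=1, padding=0,
                                   output_padding=0, dilation=1):
    """Hypothetical helper: output length of torch.nn.ConvTranspose1d (PyTorch's formula)."""
    return ((l_in - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


# Kernel 4, stride 2, padding 1 doubles the length exactly
print(torch_conv_transpose1d_out_len(28, 4, stride=2, padding=1))                    # 56
# Kernel 3, stride 2, padding 1 falls one short; output_padding=1 closes the gap
print(torch_conv_transpose1d_out_len(28, 3, stride=2, padding=1))                    # 55
print(torch_conv_transpose1d_out_len(28, 3, stride=2, padding=1, output_padding=1))  # 56
```

output_padding exists because a strided convolution maps several input lengths to the same output length: with kernel 3, stride 2, and padding 1, inputs of length 55 and 56 both convolve down to 28, so the transposed direction needs an extra hint to pick one.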