ConvTranspose2d

class brainstate.nn.ConvTranspose2d(in_size, out_channels, kernel_size, stride=1, padding='SAME', lhs_dilation=1, rhs_dilation=1, groups=1, w_init=XavierNormal(scale=1.0, mode='fan_avg', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([900 9244]), unit=Unit("1")), b_init=None, w_mask=None, channel_first=False, name=None, param_type=<class 'brainstate.ParamState'>)

Two-dimensional transposed convolution layer (also known as deconvolution).

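Applies a 2D transposed convolution over an input signal. Mathematically, a transposed convolution applies the transpose (adjoint) of the linear map computed by a regular convolution, so it coincides with the gradient of that convolution with respect to its input. It is commonly used to upsample feature maps in encoder-decoder architectures, GANs, and segmentation networks.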
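To make the "gradient with respect to the input" statement concrete, here is a minimal 1D sketch in plain Python (single channel, no padding; `conv1d` and `conv1d_transpose` are illustrative names, not BrainState API). It writes a strided convolution as an explicit loop, writes its transpose by scattering each output value back through the same kernel taps, and checks the adjoint identity <conv(x), y> = <x, conv_T(y)>, which is exactly what makes conv_T(y) the input-gradient of <conv(x), y>.

```python
# 1D sketch of a strided convolution and its transpose, using plain loops.
def conv1d(x, w, stride):
    """Strided cross-correlation without padding: y[i] = sum_k x[i*stride + k] * w[k]."""
    n_out = (len(x) - len(w)) // stride + 1
    return [sum(x[i * stride + k] * w[k] for k in range(len(w))) for i in range(n_out)]


def conv1d_transpose(y, w, stride, n_in):
    """Adjoint of conv1d: scatter each y[i] back through the same kernel taps."""
    x = [0.0] * n_in
    for i, yi in enumerate(y):
        for k, wk in enumerate(w):
            x[i * stride + k] += yi * wk
    return x


w = [1.0, 2.0, 3.0]
x = [0.5, -1.0, 2.0, 0.0, 1.5, 3.0, -2.0]  # length 7
y = [1.0, -0.5, 2.0]                       # conv1d output length: (7 - 3) // 2 + 1 = 3

# Adjoint identity: <conv(x), y> == <x, conv_T(y)>.
lhs = sum(a * b for a, b in zip(conv1d(x, w, 2), y))
rhs = sum(a * b for a, b in zip(x, conv1d_transpose(y, w, 2, len(x))))
print(lhs, rhs)  # 4.25 4.25
```

The transposed direction maps 3 values back to 7, which matches the 'VALID' length 3 * 2 + max(3 - 2, 0) = 7 from the output-size formulas under Notes.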

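The input should be a 4D array of shape [B, H, W, C], where B is the batch size, H is the height, W is the width, and C is the number of input channels (channels-last format, the default). An unbatched 3D array of shape [H, W, C] is also accepted. With channel_first=True, the layout is [B, C, H, W] instead.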
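The two layouts hold the same data in a different axis order. A small NumPy sketch of converting between them (jax.numpy.transpose takes the same arguments), for example when feeding channels-first data to a layer built with the default channel_first=False:

```python
import numpy as np

# A small channels-last batch: [B, H, W, C] = [2, 4, 4, 3] (the default layout).
x_nhwc = np.arange(2 * 4 * 4 * 3, dtype=np.float32).reshape(2, 4, 4, 3)

# Channels-first [B, C, H, W], the layout expected when channel_first=True.
x_nchw = np.transpose(x_nhwc, (0, 3, 1, 2))

# Same element, different index order: x_nchw[b, c, h, w] == x_nhwc[b, h, w, c].
assert x_nchw[1, 2, 3, 0] == x_nhwc[1, 3, 0, 2]

# Round trip back to channels-last.
x_back = np.transpose(x_nchw, (0, 2, 3, 1))
print(x_nchw.shape, x_back.shape)  # (2, 3, 4, 4) (2, 4, 4, 3)
```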

Parameters:
  • in_size (Sequence[int]) – The input shape without the batch dimension. For ConvTranspose2d this is (H, W, C), where H is the height, W is the width, and C is the number of input channels; with channel_first=True it is (C, H, W).

  • out_channels (int) – The number of output channels (feature maps) produced by the transposed convolution.

  • kernel_size (int | Tuple[int, ...]) –

    The shape of the convolutional kernel. Can be:

    • An integer (e.g., 4): creates a square kernel (4, 4)

    • A tuple of two integers (e.g., (4, 4)): creates a (height, width) kernel

  • stride (int | Tuple[int, ...]) –

    The stride of the transposed convolution. Controls the upsampling factor. Can be:

    • An integer: same stride for both dimensions

    • A tuple of two integers: (stride_height, stride_width)

    Larger strides produce larger outputs. Default: 1.

  • padding (str | int | Tuple[int, int] | Sequence[Tuple[int, int]]) –

    The padding strategy. Options:

    • ‘SAME’: output size equals input_size * stride

    • ‘VALID’: no padding, giving the largest output size

    • int: same symmetric padding for all dimensions

    • (pad_h, pad_w): different padding for each dimension

    • [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after)]: explicit padding

    Default: ‘SAME’.

  • lhs_dilation (int | Tuple[int, ...]) – The dilation factor for the input. For transposed convolution, this is typically set equal to stride internally. Default: 1.

  • rhs_dilation (int | Tuple[int, ...]) – The dilation factor for the kernel. Increases the receptive field without increasing parameters by inserting zeros between kernel elements. Default: 1.

  • groups (int) – Number of groups for grouped transposed convolution. Must divide both in_channels and out_channels. Default: 1.

  • w_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity) – Weight initializer for the convolutional kernel. Default: XavierNormal().

  • b_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity | None) – Bias initializer. If None, no bias term is added. Default: None.

  • w_mask (Callable | Array | ndarray | bool | number | int | float | complex | Quantity | None) – Optional weight mask for structured sparsity. Default: None.

  • channel_first (bool) – If True, uses channels-first format (e.g., [B, C, H, W]). If False, uses channels-last format (e.g., [B, H, W, C]). Default: False (channels-last, JAX convention).

  • name (str) – Name identifier for this module instance. Default: None.

  • param_type (type) – The parameter state class to use. Default: ParamState.
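Several parameters accept either an int or a per-dimension value. The sketch below shows one plausible way to normalize them; the helpers `to_pair`, `normalize_padding`, `effective_kernel`, and `check_groups` are hypothetical and not part of BrainState, whose internal handling may differ. The effective-kernel formula (k - 1) * d + 1 is the standard spatial extent of a kernel dilated by rhs_dilation d.

```python
from typing import Sequence, Tuple, Union

IntOrPair = Union[int, Sequence[int]]


def to_pair(value: IntOrPair, name: str) -> Tuple[int, int]:
    """Expand an int to (v, v); accept a length-2 sequence as (height, width)."""
    if isinstance(value, int):
        return (value, value)
    value = tuple(value)
    if len(value) != 2:
        raise ValueError(f"{name} must be an int or a pair, got {value}")
    return value


def normalize_padding(padding):
    """Map the accepted padding forms onto explicit ((h_lo, h_hi), (w_lo, w_hi))."""
    if isinstance(padding, str):
        return padding.upper()  # 'SAME' / 'VALID' are resolved later from the shapes
    if isinstance(padding, int):
        return ((padding, padding), (padding, padding))
    padding = tuple(padding)
    if all(isinstance(p, int) for p in padding):  # (pad_h, pad_w): symmetric per dimension
        ph, pw = to_pair(padding, "padding")
        return ((ph, ph), (pw, pw))
    (h_lo, h_hi), (w_lo, w_hi) = padding  # [(before, after), (before, after)]
    return ((h_lo, h_hi), (w_lo, w_hi))


def effective_kernel(kernel_size: IntOrPair, rhs_dilation: IntOrPair) -> Tuple[int, int]:
    """Spatial extent of a dilated kernel: (k - 1) * d + 1 per dimension."""
    k = to_pair(kernel_size, "kernel_size")
    d = to_pair(rhs_dilation, "rhs_dilation")
    return tuple((ki - 1) * di + 1 for ki, di in zip(k, d))


def check_groups(in_channels: int, out_channels: int, groups: int) -> None:
    """Raise if groups does not divide both channel counts."""
    if in_channels % groups or out_channels % groups:
        raise ValueError("groups must divide both in_channels and out_channels")


print(to_pair(4, "kernel_size"))            # (4, 4)
print(normalize_padding((1, 2)))            # ((1, 1), (2, 2))
print(normalize_padding([(0, 1), (2, 3)]))  # ((0, 1), (2, 3))
print(effective_kernel(3, 2))               # (5, 5)
check_groups(64, 32, groups=4)              # passes: 4 divides both 64 and 32
```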

in_size

The input shape (H, W, C) without batch dimension.

Type:

tuple of int

out_size

The output shape (H_out, W_out, out_channels) without batch dimension.

Type:

tuple of int

in_channels

Number of input channels.

Type:

int

out_channels

Number of output channels.

Type:

int

kernel_size

Size of the convolving kernel (height, width).

Type:

tuple of int

weight

The learnable weights (and bias if specified) of the module.

Type:

ParamState

Examples

>>> import brainstate
>>> import braintools
>>> import jax.numpy as jnp
>>>
>>> # Create a 2D transposed convolution for upsampling
>>> conv_transpose = brainstate.nn.ConvTranspose2d(
...     in_size=(32, 32, 64),
...     out_channels=32,
...     kernel_size=4,
...     stride=2
... )
>>>
>>> # Apply to input: batch_size=8, height=32, width=32, channels=64
>>> x = jnp.ones((8, 32, 32, 64))
>>> y = conv_transpose(x)
>>> print(y.shape)  # (8, 64, 64, 32): with 'SAME' padding, 32 * stride = 64
>>>
>>> # Without batch dimension
>>> x_single = jnp.ones((32, 32, 64))
>>> y_single = conv_transpose(x_single)
>>>
>>> # Decoder in autoencoder (upsampling path)
>>> decoder = brainstate.nn.ConvTranspose2d(
...     in_size=(16, 16, 128),
...     out_channels=64,
...     kernel_size=4,
...     stride=2,
...     padding='SAME',
...     b_init=braintools.init.Constant(0.0)
... )
>>>
>>> # Channels-first format (PyTorch style)
>>> conv_transpose = brainstate.nn.ConvTranspose2d(
...     in_size=(64, 32, 32),
...     out_channels=32,
...     kernel_size=4,
...     stride=2,
...     channel_first=True
... )
>>> x = jnp.ones((8, 64, 32, 32))
>>> y = conv_transpose(x)

Notes

Output dimensions:

Transposed convolution increases spatial dimensions, with the upsampling factor primarily controlled by stride:

  • output_size ≈ input_size * stride (exact size depends on padding and kernel size)

  • ‘SAME’ padding: output_size = input_size * stride

  • ‘VALID’ padding: output_size = input_size * stride + max(kernel_size - stride, 0)
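The two padding formulas can be evaluated directly. A plain-Python sketch (the function name is illustrative). As far as we can tell these are also the lengths jax.lax.conv_transpose produces for 'SAME' and 'VALID', but check them against the layer's out_size attribute for your configuration:

```python
def conv_transpose_out_size(in_size: int, kernel_size: int, stride: int, padding: str) -> int:
    """Output length along one spatial axis, following the formulas listed above."""
    padding = padding.upper()
    if padding == "SAME":
        return in_size * stride
    if padding == "VALID":
        return in_size * stride + max(kernel_size - stride, 0)
    raise ValueError(f"unsupported padding: {padding!r}")


# Input length 32, stride 2, several kernel sizes.
for k in (4, 3, 2, 1):
    print(k, conv_transpose_out_size(32, k, 2, "SAME"), conv_transpose_out_size(32, k, 2, "VALID"))
# 4 64 66
# 3 64 65
# 2 64 64
# 1 64 64
```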

Relationship to regular convolution:

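Transposed convolution is the backward pass of a regular convolution with respect to its input. If a regular convolution reduces spatial dimensions from X to Y, a transposed convolution with the same kernel size, stride, and padding increases them from Y back to approximately X. Only the shape is restored, not the values: a transposed convolution is not the inverse of a convolution.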
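A shape round trip shows why the recovery is only approximate. The sketch assumes the usual XLA rules for a regular strided convolution ('SAME': ceil(X / s); 'VALID': floor((X - k) / s) + 1) and the transposed formulas from Output dimensions; the function names are illustrative.

```python
import math


def conv_out_size(x: int, k: int, s: int, padding: str) -> int:
    """Output length of a regular strided convolution along one axis."""
    if padding == "SAME":
        return math.ceil(x / s)
    return (x - k) // s + 1  # 'VALID'


def conv_transpose_out_size(y: int, k: int, s: int, padding: str) -> int:
    """Output length of the transposed convolution along one axis."""
    if padding == "SAME":
        return y * s
    return y * s + max(k - s, 0)  # 'VALID'


k, s = 4, 2
for x in (64, 65):
    for pad in ("SAME", "VALID"):
        y = conv_out_size(x, k, s, pad)
        print(x, pad, y, conv_transpose_out_size(y, k, s, pad))
# 64 SAME 32 64
# 64 VALID 31 64
# 65 SAME 33 66
# 65 VALID 31 64
```

With stride 2, X = 64 and X = 65 both map to Y = 31 under 'VALID', so the transposed layer cannot tell which one to return; this ambiguity is what PyTorch's output_padding resolves.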

Common use cases:

  • Image segmentation (U-Net, SegNet, FCN)

  • Image-to-image translation (pix2pix, CycleGAN)

  • Generative models (DCGAN, VAE decoders)

  • Super-resolution networks

  • Autoencoders (decoder path)

Comparison with PyTorch:

  • PyTorch uses channels-first by default; BrainState uses channels-last

  • Set channel_first=True for PyTorch-compatible format

  • Kernel shape convention: PyTorch stores (C_in, C_out, H, W), BrainState uses (H, W, C_out, C_in)

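  • PyTorch’s output_padding parameter, which adds extra rows or columns on one side of the output to resolve the size ambiguity when stride > 1, has no counterpart here; control the output size through the padding parameter instead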
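When porting PyTorch weights, the axis reordering implied by the layouts stated above can be sketched with NumPy. This assumes groups=1 and the (H, W, C_out, C_in) layout given in this list. It only permutes axes: whether the spatial axes must also be flipped to reproduce PyTorch outputs depends on how BrainState invokes the underlying JAX primitive, which this page does not specify, so verify numerically before relying on ported weights.

```python
import numpy as np

c_in, c_out, kh, kw = 64, 32, 4, 4

# Stand-in for a PyTorch ConvTranspose2d weight, stored as (C_in, C_out, H, W).
w_torch = np.random.default_rng(0).standard_normal((c_in, c_out, kh, kw)).astype(np.float32)

# (C_in, C_out, H, W) -> (H, W, C_out, C_in): take source axes 2, 3, 1, 0 in that order.
w_bst = np.transpose(w_torch, (2, 3, 1, 0))
print(w_bst.shape)  # (4, 4, 32, 64)

# Index mapping check: w_bst[h, w, o, i] == w_torch[i, o, h, w].
assert w_bst[1, 2, 3, 5] == w_torch[5, 3, 1, 2]
```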

Tips:

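  • Choose a kernel_size divisible by stride (e.g., kernel_size=4, stride=2) so that kernel footprints overlap evenly; this reduces checkerboard artifacts and gives smoother upsampling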

  • Initialize with bilinear upsampling weights for better convergence in segmentation

  • Combine with batch normalization or group normalization for stable training
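The first two tips can be checked numerically. `overlap_counts` counts how many kernel taps land on each output position of a 1D transposed convolution (before any cropping by padding): interior counts are uniform when kernel_size is divisible by stride and alternate otherwise, which is the uneven pattern behind checkerboard artifacts. `bilinear_kernel` builds the FCN-style bilinear filter as a (k, k, C, C) array that is identity across channels. Both helpers are hypothetical. Passing the result as w_init assumes in_channels == out_channels, groups=1, and a layer kernel of shape (k, k, C, C). Because the array is diagonal over channels and the filter is symmetric, it does not depend on whether the channel axes are ordered (C_out, C_in) or (C_in, C_out), or on a spatial flip.

```python
import numpy as np


def overlap_counts(n_in: int, k: int, s: int) -> list:
    """Kernel taps landing on each output position of a 1D transposed conv (uncropped)."""
    counts = [0] * ((n_in - 1) * s + k)
    for i in range(n_in):
        for t in range(k):
            counts[i * s + t] += 1
    return counts


def bilinear_kernel(k: int, channels: int) -> np.ndarray:
    """Bilinear-upsampling kernel of shape (k, k, channels, channels), identity across channels."""
    factor = (k + 1) // 2
    center = factor - 1 if k % 2 == 1 else factor - 0.5
    filt_1d = 1 - np.abs(np.arange(k) - center) / factor
    filt_2d = np.outer(filt_1d, filt_1d)
    w = np.zeros((k, k, channels, channels), dtype=np.float32)
    for c in range(channels):
        w[:, :, c, c] = filt_2d  # each channel is upsampled independently
    return w


print(overlap_counts(4, 4, 2))  # [1, 1, 2, 2, 2, 2, 2, 2, 1, 1]  (k % s == 0: even interior)
print(overlap_counts(4, 3, 2))  # [1, 1, 2, 1, 2, 1, 2, 1, 1]     (k % s != 0: alternating)
print(bilinear_kernel(4, 1)[:, 0, 0, 0] / 0.25)  # [1. 3. 3. 1.]
```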