ConvTranspose2d
- class brainstate.nn.ConvTranspose2d(in_size, out_channels, kernel_size, stride=1, padding='SAME', lhs_dilation=1, rhs_dilation=1, groups=1, w_init=XavierNormal( scale=1.0, mode='fan_avg', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([ 900 9244]), unit=Unit("1") ), b_init=None, w_mask=None, channel_first=False, name=None, param_type=<class 'brainstate.ParamState'>)
Two-dimensional transposed convolution layer (also known as deconvolution).
Applies a 2D transposed convolution over an input signal. A transposed convolution computes the gradient of a regular convolution with respect to its input; that is, it applies the transpose of the convolution's linear map. It is commonly used to upsample feature maps in encoder-decoder architectures, GANs, and segmentation networks.
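The input should be a 4D array of shape [B, H, W, C], where B is the batch size, H is the height, W is the width, and C is the number of input channels (channels-last format). With channel_first=True the expected layout is [B, C, H, W]. A single unbatched sample of shape [H, W, C] is also accepted.
- Parameters: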
  - in_size (Sequence[int]) – The input shape without the batch dimension. For ConvTranspose2d this is (H, W, C), where H is the height, W is the width, and C is the number of input channels; with channel_first=True it is (C, H, W).
  - out_channels (int) – The number of output channels (feature maps) produced by the transposed convolution.
  - kernel_size (int | Tuple[int, ...]) – The shape of the convolutional kernel. Can be:
    - an integer (e.g., 4), which creates a square (4, 4) kernel;
    - a tuple of two integers (e.g., (4, 4)), which gives the (height, width) of the kernel.
  - stride (int | Tuple[int, ...]) – The stride of the transposed convolution, which controls the upsampling factor. Can be:
    - an integer: the same stride for both dimensions;
    - a tuple of two integers: (stride_height, stride_width).
    Larger strides produce larger outputs. Default: 1.
  - padding (str | int | Tuple[int, int] | Sequence[Tuple[int, int]]) – The padding strategy. Options:
    - 'SAME': the output size equals input_size * stride;
    - 'VALID': no padding, which yields the largest output (see Notes);
    - int: the same symmetric padding for all dimensions;
    - (pad_h, pad_w): a different padding for each dimension;
    - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after)]: explicit padding on each side.
    Default: 'SAME'.
  - lhs_dilation (int | Tuple[int, ...]) – The dilation factor for the input. For transposed convolution, this is typically set equal to the stride internally. Default: 1.
  - rhs_dilation (int | Tuple[int, ...]) – The dilation factor for the kernel. Inserting zeros between kernel elements enlarges the receptive field without adding parameters. Default: 1.
  - groups (int) – Number of groups for grouped transposed convolution. Must divide both the number of input channels (the channel entry of in_size) and out_channels. Default: 1.
  - w_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity) – Weight initializer for the convolutional kernel. Default: XavierNormal().
  - b_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity | None) – Bias initializer. If None, no bias term is added. Default: None.
  - w_mask (Callable | Array | ndarray | bool | number | int | float | complex | Quantity | None) – Optional weight mask for structured sparsity. Default: None.
  - channel_first (bool) – If True, uses the channels-first format ([B, C, H, W]); if False, uses the channels-last format ([B, H, W, C]). Default: False (channels-last, the JAX convention).
  - name (str | None) – Name identifier for this module instance. Default: None.
  - param_type (type) – The parameter state class used to store the learnable parameters. Default: ParamState.
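The kernel_size, stride, and rhs_dilation arguments accept either an int or a pair, and groups must divide both channel counts. The plain-Python sketch below shows one common way to normalize and check such arguments and how rhs_dilation stretches the kernel's footprint (a kernel of size k with dilation d spans d * (k - 1) + 1 positions). The helper names to_pair, effective_kernel_size, and check_groups are hypothetical, not part of BrainState's API, and the library's own validation may differ.

```python
from typing import Sequence, Tuple, Union

IntOrPair = Union[int, Tuple[int, int]]


def to_pair(value: IntOrPair) -> Tuple[int, int]:
    """Expand an int to (v, v); pass a 2-element tuple through unchanged."""
    if isinstance(value, int):
        return (value, value)
    if len(value) != 2:
        raise ValueError(f"expected an int or 2 values, got {value!r}")
    return (int(value[0]), int(value[1]))


def effective_kernel_size(kernel_size: IntOrPair, rhs_dilation: IntOrPair = 1) -> Tuple[int, int]:
    """Spatial extent of the kernel after inserting (d - 1) zeros between taps."""
    (kh, kw), (dh, dw) = to_pair(kernel_size), to_pair(rhs_dilation)
    return (dh * (kh - 1) + 1, dw * (kw - 1) + 1)


def check_groups(in_size: Sequence[int], out_channels: int, groups: int,
                 channel_first: bool = False) -> None:
    """Raise ValueError unless groups divides both channel counts."""
    # The channel entry of in_size is first for (C, H, W), last for (H, W, C).
    in_channels = in_size[0] if channel_first else in_size[-1]
    if in_channels % groups != 0 or out_channels % groups != 0:
        raise ValueError(
            f"groups={groups} must divide in_channels={in_channels} "
            f"and out_channels={out_channels}"
        )


print(effective_kernel_size(3, rhs_dilation=2))        # (5, 5): 9 weights spread over a 5x5 window
check_groups((32, 32, 64), out_channels=32, groups=4)  # OK: 4 divides both 64 and 32
```

Dilation changes only the footprint, not the parameter count: the 3x3 kernel above still has 9 weights.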
- weight
The learnable weights (and bias, if specified) of the module.
- Type: param_type (ParamState by default)
Examples
>>> import brainstate
>>> import braintools
>>> import jax.numpy as jnp
>>>
>>> # Create a 2D transposed convolution for 2x upsampling
>>> conv_transpose = brainstate.nn.ConvTranspose2d(
...     in_size=(32, 32, 64),
...     out_channels=32,
...     kernel_size=4,
...     stride=2
... )
>>>
>>> # Apply to input: batch_size=8, height=32, width=32, channels=64
>>> x = jnp.ones((8, 32, 32, 64))
>>> y = conv_transpose(x)
>>> print(y.shape)  # expected (8, 64, 64, 32): 'SAME' padding gives H_out = H_in * stride
>>>
>>> # Without a batch dimension: input shape (H, W, C)
>>> x_single = jnp.ones((32, 32, 64))
>>> y_single = conv_transpose(x_single)
>>>
>>> # Decoder in an autoencoder (upsampling path) with a zero-initialized bias
>>> decoder = brainstate.nn.ConvTranspose2d(
...     in_size=(16, 16, 128),
...     out_channels=64,
...     kernel_size=4,
...     stride=2,
...     padding='SAME',
...     b_init=braintools.init.Constant(0.0)
... )
>>>
>>> # Channels-first format (PyTorch style): in_size is (C, H, W)
>>> conv_transpose = brainstate.nn.ConvTranspose2d(
...     in_size=(64, 32, 32),
...     out_channels=32,
...     kernel_size=4,
...     stride=2,
...     channel_first=True
... )
>>> x = jnp.ones((8, 64, 32, 32))
>>> y = conv_transpose(x)
Notes
Output dimensions:
Transposed convolution increases spatial dimensions, with the upsampling factor primarily controlled by stride:
output_size ≈ input_size * stride (exact size depends on padding and kernel size)
‘SAME’ padding: output_size = input_size * stride
‘VALID’ padding: output_size = input_size * stride + max(kernel_size - stride, 0)
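The two padding rules above can be turned into a small shape calculator, which is handy when planning the stages of a decoder. This is a plain-Python sketch of only the 'SAME' and 'VALID' formulas stated in these Notes; it does not model explicit integer padding, lhs_dilation, or rhs_dilation, and the function names are hypothetical.

```python
from typing import Tuple, Union


def conv_transpose_output_size(input_size: int, kernel_size: int, stride: int,
                               padding: str = 'SAME') -> int:
    """Output length along one spatial axis, using the 'SAME'/'VALID' rules above."""
    if padding == 'SAME':
        return input_size * stride
    if padding == 'VALID':
        return input_size * stride + max(kernel_size - stride, 0)
    raise ValueError("this sketch only handles 'SAME' and 'VALID'")


def conv_transpose_output_hw(in_hw: Tuple[int, int],
                             kernel_size: Union[int, Tuple[int, int]],
                             stride: Union[int, Tuple[int, int]],
                             padding: str = 'SAME') -> Tuple[int, int]:
    """Apply the per-axis rule independently to height and width."""
    ks = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
    ss = (stride, stride) if isinstance(stride, int) else stride
    return tuple(conv_transpose_output_size(n, k, s, padding)
                 for n, k, s in zip(in_hw, ks, ss))


print(conv_transpose_output_hw((32, 32), 4, 2, 'SAME'))   # (64, 64)
print(conv_transpose_output_hw((32, 32), 4, 2, 'VALID'))  # (66, 66), i.e. (32 - 1) * 2 + 4
```

When kernel_size >= stride, the 'VALID' rule reduces to the familiar (input_size - 1) * stride + kernel_size.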
Relationship to regular convolution:
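Transposed convolution is the backward pass (with respect to the input) of a regular convolution. If a regular convolution reduces spatial dimensions from X to Y, a transposed convolution with the same kernel size, stride, and padding increases them from Y back to approximately X. Only the shape is recovered, not the original values: a transposed convolution is not the inverse of a convolution.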
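The "backward pass" relationship can be checked on a tiny 1D case. The pure-Python sketch below is an illustration, not BrainState's implementation: it defines a strided 'VALID' cross-correlation (what deep-learning libraries call convolution) and its transpose as a scatter-add, then verifies the adjoint identity <conv(x), g> = <x, conv_transpose(g)>, which is exactly what makes the transposed convolution the input gradient of the convolution. The inputs are arbitrary made-up numbers.

```python
from typing import List


def conv1d_valid(x: List[float], w: List[float], stride: int) -> List[float]:
    """Strided 'VALID' cross-correlation: y[i] = sum_k w[k] * x[i * stride + k]."""
    n_out = (len(x) - len(w)) // stride + 1
    return [sum(w[k] * x[i * stride + k] for k in range(len(w))) for i in range(n_out)]


def conv_transpose1d_valid(y: List[float], w: List[float], stride: int) -> List[float]:
    """Transpose of conv1d_valid: scatter-add each y[i] * w into the output."""
    out = [0.0] * ((len(y) - 1) * stride + len(w))
    for i, yi in enumerate(y):
        for k, wk in enumerate(w):
            out[i * stride + k] += yi * wk
    return out


def dot(a: List[float], b: List[float]) -> float:
    return sum(ai * bi for ai, bi in zip(a, b))


x = [1.0, 2.0, -1.0, 0.5, 3.0, -2.0, 1.5]  # length X = 7
w = [0.5, -1.0, 2.0]                       # kernel size 3
stride = 2

y = conv1d_valid(x, w, stride)             # length Y = (7 - 3) // 2 + 1 = 3
g = [1.0, -0.5, 2.0]                       # any vector with the output's length
x_back = conv_transpose1d_valid(g, w, stride)  # length (3 - 1) * 2 + 3 = 7, same as x

print(dot(y, g), dot(x, x_back))           # prints 7.0 7.0 (adjoint identity holds)
```

Note that x_back has the shape of x but different values, matching the point that only the shape is recovered.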
Common use cases:
Image segmentation (U-Net, SegNet, FCN)
Image-to-image translation (pix2pix, CycleGAN)
Generative models (DCGAN, VAE decoders)
Super-resolution networks
Autoencoders (decoder path)
Comparison with PyTorch:
PyTorch uses channels-first by default; BrainState uses channels-last
Set channel_first=True for PyTorch-compatible format
Kernel shape convention: PyTorch stores (C_in, C_out, H, W), BrainState uses (H, W, C_out, C_in)
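PyTorch's output_padding parameter adds extra size on one side of the output to choose among the output sizes that are ambiguous when stride > 1; here, control the output size through the padding parameter instead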
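When porting a pretrained PyTorch ConvTranspose2d, the kernel axes must be reordered to the layout listed above. The NumPy sketch below assumes groups=1 and the (H, W, C_out, C_in) convention stated in these Notes; confirm it against the shape of the module's actual weight before relying on it. Whether a spatial flip of the kernel is also needed depends on how the kernel is applied internally, so verify numerically on a small input. The random arrays stand in for real PyTorch tensors.

```python
import numpy as np

# Stand-in for the weight of a PyTorch nn.ConvTranspose2d(64, 32, kernel_size=4), groups=1.
# PyTorch stores this weight as (C_in, C_out, kH, kW).
c_in, c_out, kh, kw = 64, 32, 4, 4
rng = np.random.default_rng(0)
w_torch = rng.standard_normal((c_in, c_out, kh, kw)).astype(np.float32)

# Reorder axes (C_in, C_out, H, W) -> (H, W, C_out, C_in), the layout stated in the Notes.
w_brainstate = np.transpose(w_torch, (2, 3, 1, 0))

# Possibly also required (check numerically): flip the kernel along both spatial axes.
w_flipped = w_brainstate[::-1, ::-1]

# Activations: PyTorch NCHW -> channels-last NHWC (or keep NCHW with channel_first=True).
x_nchw = rng.standard_normal((8, c_in, 32, 32)).astype(np.float32)
x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))

print(w_brainstate.shape)  # (4, 4, 32, 64)
print(x_nhwc.shape)        # (8, 32, 32, 64)
```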
Tips:
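Use kernel_size=stride*2 for smooth upsampling (e.g., kernel_size=4, stride=2): because the kernel size is then divisible by the stride, every interior output position receives the same number of kernel contributions, which helps avoid checkerboard artifacts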
Initialize with bilinear upsampling weights for better convergence in segmentation
Combine with batch normalization or group normalization for stable training
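The first two tips fit together: a bilinear kernel of size 2 * stride turns a transposed convolution into (approximately) bilinear interpolation. The pure-Python sketch below builds the widely used bilinear filter (the construction found, for example, in FCN implementations) and checks on a 1D scatter-add that a constant input stays constant away from the borders. Function names are hypothetical. To use such a filter as an initial weight, one option is an H x W x C x C array (C_in == C_out) holding the 2D kernel at [:, :, c, c] for each channel and zeros elsewhere; that diagonal pattern is symmetric in the two channel axes, so it does not depend on their order. Passing an array through w_init relies on the Array option listed for that parameter and should be checked against the module's weight shape.

```python
from typing import List


def bilinear_kernel_1d(kernel_size: int) -> List[float]:
    """1D bilinear-interpolation weights: a triangle peaking at the kernel center."""
    factor = (kernel_size + 1) // 2
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    return [1.0 - abs(i - center) / factor for i in range(kernel_size)]


def bilinear_kernel_2d(kernel_size: int) -> List[List[float]]:
    """Separable 2D kernel: outer product of the 1D weights."""
    k = bilinear_kernel_1d(kernel_size)
    return [[a * b for b in k] for a in k]


def conv_transpose1d(y: List[float], w: List[float], stride: int) -> List[float]:
    """1D transposed convolution as a scatter-add (no padding)."""
    out = [0.0] * ((len(y) - 1) * stride + len(w))
    for i, yi in enumerate(y):
        for k, wk in enumerate(w):
            out[i * stride + k] += yi * wk
    return out


k1 = bilinear_kernel_1d(4)   # [0.25, 0.75, 0.75, 0.25] for stride 2
up = conv_transpose1d([1.0] * 5, k1, stride=2)
print(up)  # [0.25, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.75, 0.25]
```

Each interior output receives exactly two taps (0.25 + 0.75), so constants are preserved; only the two positions at each border are partial sums.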