ConvTranspose3d

class brainstate.nn.ConvTranspose3d(in_size, out_channels, kernel_size, stride=1, padding='SAME', lhs_dilation=1, rhs_dilation=1, groups=1, w_init=XavierNormal(scale=1.0, mode='fan_avg', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([ 900 9244]), unit=Unit("1")), b_init=None, w_mask=None, channel_first=False, name=None, param_type=<class 'brainstate.ParamState'>)

Three-dimensional transposed convolution layer (also known as deconvolution).

Applies a 3D transposed convolution over an input signal. Used for upsampling 3D feature maps in video generation, 3D segmentation, and volumetric reconstruction.

The input should be a 5D array with the shape of [B, H, W, D, C] where B is batch size, H is height, W is width, D is depth, and C is the number of input channels (channels-last format).

Parameters:
  • in_size (Sequence[int]) – The input shape without the batch dimension. For ConvTranspose3d: (H, W, D, C) where H is height, W is width, D is depth, and C is the number of input channels.

  • out_channels (int) – The number of output channels (feature maps) produced by the transposed convolution.

  • kernel_size (int | Tuple[int, ...]) –

    The shape of the convolutional kernel. Can be:

    • An integer (e.g., 4): creates a cubic kernel (4, 4, 4)

    • A tuple of three integers (e.g., (4, 4, 4)): creates a (height, width, depth) kernel

  • stride (int | Tuple[int, ...]) –

    The stride of the transposed convolution. Controls the upsampling factor. Can be:

    • An integer: same stride for all dimensions

    • A tuple of three integers: (stride_h, stride_w, stride_d)

    Larger strides produce larger outputs. Default: 1.

  • padding (str | int | Tuple[int, int] | Sequence[Tuple[int, int]]) –

    The padding strategy. Options:

    • 'SAME': output size approximately equals input_size * stride

    • 'VALID': no padding, maximum output size

    • int: same symmetric padding for all dimensions

    • (pad_h, pad_w, pad_d): different padding for each dimension

    • [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after), (pad_d_before, pad_d_after)]: explicit

    Default: 'SAME'.

  • lhs_dilation (int | Tuple[int, ...]) – The dilation factor for the input. For transposed convolution, this is typically set equal to stride internally. Default: 1.

  • rhs_dilation (int | Tuple[int, ...]) – The dilation factor for the kernel. Increases the receptive field without increasing parameters. Default: 1.

  • groups (int) – Number of groups for grouped transposed convolution. Must divide both in_channels and out_channels. Useful for reducing computational cost in 3D. Default: 1.

  • w_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity) – Weight initializer for the convolutional kernel. Default: XavierNormal().

  • b_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity | None) – Bias initializer. If None, no bias term is added. Default: None.

  • w_mask (Callable | Array | ndarray | bool | number | int | float | complex | Quantity | None) – Optional weight mask for structured sparsity. Default: None.

  • channel_first (bool) – If True, uses channels-first format (e.g., [B, C, H, W, D]). If False, uses channels-last format (e.g., [B, H, W, D, C]). Default: False (channels-last, JAX convention).

  • name (str) – Name identifier for this module instance. Default: None.

  • param_type (type) – The parameter state class to use. Default: ParamState.

in_size

The input shape (H, W, D, C) without batch dimension.

Type:

tuple of int

out_size

The output shape (H_out, W_out, D_out, out_channels) without batch dimension.

Type:

tuple of int

in_channels

Number of input channels.

Type:

int

out_channels

Number of output channels.

Type:

int

kernel_size

Size of the convolving kernel (height, width, depth).

Type:

tuple of int

weight

The learnable weights (and bias if specified) of the module.

Type:

ParamState

Examples

>>> import brainstate
>>> import braintools
>>> import jax.numpy as jnp
>>>
>>> # Create a 3D transposed convolution for video upsampling
>>> conv_transpose = brainstate.nn.ConvTranspose3d(
...     in_size=(8, 16, 16, 64),
...     out_channels=32,
...     kernel_size=4,
...     stride=2
... )
>>>
>>> # Apply to input: batch_size=4, frames=8, height=16, width=16, channels=64
>>> x = jnp.ones((4, 8, 16, 16, 64))
>>> y = conv_transpose(x)
>>> print(y.shape)  # Output will be approximately (4, 16, 32, 32, 32)
>>>
>>> # Without batch dimension
>>> x_single = jnp.ones((8, 16, 16, 64))
>>> y_single = conv_transpose(x_single)
>>>
>>> # For medical imaging reconstruction
>>> decoder = brainstate.nn.ConvTranspose3d(
...     in_size=(16, 16, 16, 128),
...     out_channels=64,
...     kernel_size=(4, 4, 4),
...     stride=2,
...     padding='SAME',
...     b_init=braintools.init.Constant(0.0)
... )
>>>
>>> # Channels-first format (PyTorch style)
>>> conv_transpose = brainstate.nn.ConvTranspose3d(
...     in_size=(64, 8, 16, 16),
...     out_channels=32,
...     kernel_size=4,
...     stride=2,
...     channel_first=True
... )
>>> x = jnp.ones((4, 64, 8, 16, 16))
>>> y = conv_transpose(x)

Notes

Output dimensions:

Transposed convolution increases spatial dimensions:

  • output_size ≈ input_size * stride (exact size depends on padding and kernel size)

  • 'SAME' padding: output_size = input_size * stride

  • 'VALID' padding: output_size = input_size * stride + max(kernel_size - stride, 0)
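
The size rules above can be checked with a small helper. The sketch below is plain Python and independent of BrainState. The name transposed_output_size is hypothetical, and the helper only encodes the two formulas stated in these notes for the string padding modes, so integer or tuple padding is not covered.

```python
from typing import Sequence, Tuple, Union

IntOrSeq = Union[int, Sequence[int]]


def _as_triple(value: IntOrSeq) -> Tuple[int, int, int]:
    """Expand an int to three equal values, or validate a length-3 sequence."""
    if isinstance(value, int):
        return (value, value, value)
    value = tuple(value)
    if len(value) != 3:
        raise ValueError(f"expected an int or 3 ints, got {value!r}")
    return value


def transposed_output_size(spatial, kernel_size, stride=1, padding="SAME"):
    """Spatial output size per the formulas in these notes (hypothetical helper)."""
    kernel = _as_triple(kernel_size)
    strides = _as_triple(stride)
    if padding == "SAME":
        # 'SAME': each spatial dimension is multiplied by its stride.
        return tuple(n * s for n, s in zip(spatial, strides))
    if padding == "VALID":
        # 'VALID': the kernel overhang beyond the stride is added on top.
        return tuple(n * s + max(k - s, 0) for n, k, s in zip(spatial, kernel, strides))
    raise ValueError("only 'SAME' and 'VALID' are handled in this sketch")


print(transposed_output_size((8, 16, 16), kernel_size=4, stride=2))
# (16, 32, 32)
print(transposed_output_size((8, 16, 16), kernel_size=4, stride=2, padding="VALID"))
# (18, 34, 34)
```

Comparing the result with out_size[:-1] of a constructed layer (or out_size[1:] when channel_first=True) is a quick way to confirm the padding mode behaves as expected.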

Computational considerations:

3D transposed convolutions are expensive in both compute and memory, because the cost grows with the kernel volume and the output volume. Consider:

  • Using grouped convolutions (groups > 1) to reduce parameters

  • Smaller kernel sizes

  • Progressive upsampling (multiple layers with stride=2)

  • Separable convolutions for large-scale applications
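
To see how much groups can save, the weight count can be estimated by hand. This sketch assumes the usual grouped-convolution layout, in which each output channel connects to in_channels / groups input channels. The name conv3d_weight_count is hypothetical, and the actual kernel shape used by BrainState is not taken from this page.

```python
def conv3d_weight_count(in_channels, out_channels, kernel_size, groups=1, bias=False):
    """Estimate learnable parameters of a (transposed) 3D convolution.

    Hypothetical helper: assumes each output channel sees
    in_channels // groups input channels.
    """
    if in_channels % groups or out_channels % groups:
        raise ValueError("groups must divide both in_channels and out_channels")
    kh, kw, kd = kernel_size
    weights = kh * kw * kd * (in_channels // groups) * out_channels
    return weights + (out_channels if bias else 0)


dense = conv3d_weight_count(128, 64, (4, 4, 4))
grouped = conv3d_weight_count(128, 64, (4, 4, 4), groups=4)
print(dense, grouped, dense // grouped)
# 524288 131072 4
```

Under this assumption the weight count shrinks by a factor of groups, while the output shape is unchanged.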

Common use cases:

  • Video generation and prediction

  • 3D medical image segmentation (U-Net 3D)

  • Volumetric reconstruction

  • 3D super-resolution

  • Video frame interpolation

  • 3D VAE decoders

Comparison with PyTorch:

  • PyTorch uses channels-first by default; BrainState uses channels-last

  • Set channel_first=True for PyTorch-compatible format

  • Kernel shape convention differs between frameworks

  • There is no separate counterpart to PyTorch’s output_padding parameter; the output size is adjusted through padding instead
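
When data moves between the two layouts, only the channel axis changes position. The sketch below works on shape tuples in plain Python so that it runs without JAX; the permutations are the same ones that would be passed to jnp.transpose. The names CHANNELS_LAST_TO_FIRST, CHANNELS_FIRST_TO_LAST, and permute are illustrative, not part of BrainState.

```python
# Axis permutations between [B, H, W, D, C] and [B, C, H, W, D].
CHANNELS_LAST_TO_FIRST = (0, 4, 1, 2, 3)
CHANNELS_FIRST_TO_LAST = (0, 2, 3, 4, 1)


def permute(shape, perm):
    """Shape that jnp.transpose(x, perm) would produce when x.shape == shape."""
    return tuple(shape[axis] for axis in perm)


channels_last = (4, 8, 16, 16, 64)  # [B, H, W, D, C]
channels_first = permute(channels_last, CHANNELS_LAST_TO_FIRST)
print(channels_first)
# (4, 64, 8, 16, 16)
print(permute(channels_first, CHANNELS_FIRST_TO_LAST))
# (4, 8, 16, 16, 64)
```

Transposing the input once and keeping channel_first=False avoids repeated layout conversions inside a model; setting channel_first=True is the alternative when the surrounding code is already channels-first.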

Tips:

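  • Use kernel_size=stride*2 for smooth upsampling (e.g., kernel_size=4, stride=2); a kernel size divisible by the stride keeps the kernel overlap uniform and reduces checkerboard artifacts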

  • Group normalization often works better than batch normalization for 3D

  • Consider using smaller batch sizes due to memory constraints

  • Progressive upsampling (2x at a time) is more stable than large strides
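
The progressive-upsampling tip can be planned before any layers are built. The sketch assumes 'SAME' padding, so each stage multiplies every spatial dimension by its stride, as stated in the output-dimension notes. The name plan_upsampling is hypothetical, and halving the channel count at each stage is only an illustrative choice.

```python
def plan_upsampling(in_size, num_stages, stride=2, min_channels=8):
    """List the channels-last shapes (H, W, D, C) after each upsampling stage.

    Hypothetical helper: assumes 'SAME' padding (size * stride per stage)
    and halves the channels at every stage, never going below min_channels.
    """
    *spatial, channels = in_size
    sizes = [tuple(in_size)]
    for _ in range(num_stages):
        spatial = [n * stride for n in spatial]
        channels = max(channels // 2, min_channels)
        sizes.append((*spatial, channels))
    return sizes


for stage, size in enumerate(plan_upsampling((4, 4, 4, 128), num_stages=3)):
    print(stage, size)
# 0 (4, 4, 4, 128)
# 1 (8, 8, 8, 64)
# 2 (16, 16, 16, 32)
# 3 (32, 32, 32, 16)
```

Each tuple can serve as the in_size of one ConvTranspose3d layer, with out_channels taken from the last entry of the next tuple.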