ConvTranspose3d
class brainstate.nn.ConvTranspose3d(in_size, out_channels, kernel_size, stride=1, padding='SAME', lhs_dilation=1, rhs_dilation=1, groups=1, w_init=XavierNormal(scale=1.0, mode='fan_avg', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([ 900 9244]), unit=Unit("1")), b_init=None, w_mask=None, channel_first=False, name=None, param_type=<class 'brainstate.ParamState'>)
Three-dimensional transposed convolution layer (also known as deconvolution).
Applies a 3D transposed convolution to an input signal. It is typically used to upsample 3D feature maps in video generation, 3D segmentation, and volumetric reconstruction.
The input should be a 5D array of shape [B, H, W, D, C], where B is the batch size, H is the height, W is the width, D is the depth, and C is the number of input channels (channels-last format). An unbatched 4D input of shape [H, W, D, C] is also accepted.
Parameters:
- in_size (Sequence[int]) – The input shape without the batch dimension. For ConvTranspose3d this is (H, W, D, C), where H is height, W is width, D is depth, and C is the number of input channels; with channel_first=True it is (C, H, W, D).
- out_channels (int) – The number of output channels (feature maps) produced by the transposed convolution.
- kernel_size (int | Tuple[int, ...]) – The shape of the convolutional kernel. Can be:
  - an integer (e.g., 4), which creates a cubic kernel (4, 4, 4);
  - a tuple of three integers (e.g., (4, 4, 4)), which creates a (height, width, depth) kernel.
- stride (int | Tuple[int, ...]) – The stride of the transposed convolution, which sets the upsampling factor. Can be:
  - an integer, used for all three spatial dimensions;
  - a tuple of three integers, (stride_h, stride_w, stride_d).
  Larger strides produce larger outputs. Default: 1.
- padding (str | int | Tuple[int, int] | Sequence[Tuple[int, int]]) – The padding strategy. Options:
  - 'SAME': output size equals input_size * stride;
  - 'VALID': no padding, which gives the largest output size;
  - int: the same symmetric padding for all spatial dimensions;
  - (pad_h, pad_w, pad_d): a different padding for each dimension;
  - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after), (pad_d_before, pad_d_after)]: explicit before/after padding for each dimension.
  Default: 'SAME'.
- lhs_dilation (int | Tuple[int, ...]) – The dilation factor applied to the input. For a transposed convolution this is typically set equal to the stride internally. Default: 1.
- rhs_dilation (int | Tuple[int, ...]) – The dilation factor applied to the kernel, which enlarges the receptive field without adding parameters. Default: 1.
- groups (int) – Number of groups for grouped transposed convolution. Must divide both the number of input channels and out_channels. Grouping reduces parameters and compute, which matters for costly 3D layers. Default: 1.
- w_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity) – Weight initializer for the convolutional kernel. Default: XavierNormal().
- b_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity | None) – Bias initializer. If None, no bias term is added. Default: None.
- w_mask (Callable | Array | ndarray | bool | number | int | float | complex | Quantity | None) – Optional weight mask for structured sparsity. Default: None.
- channel_first (bool) – If True, uses channels-first format ([B, C, H, W, D]); if False, uses channels-last format ([B, H, W, D, C]). Default: False (channels-last, the JAX convention).
- name (str) – Name identifier for this module instance. Default: None.
- param_type (type) – The parameter state class used to store the weights. Default: ParamState.
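The remark on lhs_dilation reflects how a transposed convolution is commonly computed: spread the input out by inserting stride - 1 zeros between elements (input dilation), pad it, and run an ordinary stride-1 convolution. The sketch below checks this with plain JAX rather than BrainState, comparing jax.lax.conv_transpose with jax.lax.conv_general_dilated using lhs_dilation equal to the stride. The explicit padding of (2, 2) per axis is what JAX derives for 'SAME' with kernel 4 and stride 2; that ConvTranspose3d pads the same way is an assumption.

```python
import jax
import jax.numpy as jnp

# Channels-last layout matching [B, H, W, D, C]; kernel is (kH, kW, kD, C_in, C_out).
DIMS = ("NHWDC", "HWDIO", "NHWDC")

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (2, 4, 5, 6, 3))   # B=2, spatial (4, 5, 6), C_in=3
w = jax.random.normal(key_w, (4, 4, 4, 3, 8))   # 4x4x4 kernel, C_in=3, C_out=8
stride = (2, 2, 2)

# Reference: JAX's built-in transposed convolution with 'SAME' padding.
y_ref = jax.lax.conv_transpose(x, w, strides=stride, padding="SAME",
                               dimension_numbers=DIMS)

# The same computation spelled out: dilate the input by the stride
# (lhs_dilation), pad, then run a stride-1 convolution.
# Assumption: (2, 2) per axis is JAX's 'SAME' padding for kernel 4, stride 2,
# not necessarily what BrainState computes internally.
y_dilated = jax.lax.conv_general_dilated(
    x, w,
    window_strides=(1, 1, 1),
    padding=[(2, 2)] * 3,
    lhs_dilation=stride,
    dimension_numbers=DIMS,
)

print(y_ref.shape)                      # (2, 8, 10, 12, 8)
print(jnp.allclose(y_ref, y_dilated))   # True
```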
- out_size
The output shape (H_out, W_out, D_out, out_channels), without the batch dimension.
- weight
The learnable weights (and bias, if specified) of the module.
- Type: ParamState (more generally, an instance of param_type)
Examples
>>> import brainstate
>>> import braintools
>>> import jax.numpy as jnp
>>>
>>> # Create a 3D transposed convolution for video upsampling
>>> conv_transpose = brainstate.nn.ConvTranspose3d(
...     in_size=(8, 16, 16, 64),
...     out_channels=32,
...     kernel_size=4,
...     stride=2
... )
>>>
>>> # Apply to input: batch_size=4, spatial dims (8, 16, 16) (e.g., frames, height, width), channels=64
>>> x = jnp.ones((4, 8, 16, 16, 64))
>>> y = conv_transpose(x)
>>> print(y.shape)  # expected (4, 16, 32, 32, 32): 'SAME' multiplies each spatial size by the stride
>>>
>>> # Without batch dimension
>>> x_single = jnp.ones((8, 16, 16, 64))
>>> y_single = conv_transpose(x_single)
>>>
>>> # For medical imaging reconstruction
>>> decoder = brainstate.nn.ConvTranspose3d(
...     in_size=(16, 16, 16, 128),
...     out_channels=64,
...     kernel_size=(4, 4, 4),
...     stride=2,
...     padding='SAME',
...     b_init=braintools.init.Constant(0.0)
... )
>>>
>>> # Channels-first format (PyTorch style)
>>> conv_transpose = brainstate.nn.ConvTranspose3d(
...     in_size=(64, 8, 16, 16),
...     out_channels=32,
...     kernel_size=4,
...     stride=2,
...     channel_first=True
... )
>>> x = jnp.ones((4, 64, 8, 16, 16))
>>> y = conv_transpose(x)
Notes
Output dimensions:
Transposed convolution increases the spatial dimensions:
- output_size ≈ input_size * stride (the exact size depends on padding and kernel size)
- 'SAME' padding: output_size = input_size * stride
- 'VALID' padding: output_size = input_size * stride + max(kernel_size - stride, 0)
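The two size rules can be checked numerically. This sketch uses jax.lax.conv_transpose as a stand-in for ConvTranspose3d (an assumption: it presumes the layer follows JAX's padding conventions, as the formulas above suggest) together with a hypothetical helper, conv_transpose_out_size, that encodes the formulas for one spatial axis.

```python
import jax
import jax.numpy as jnp


def conv_transpose_out_size(in_size: int, kernel: int, stride: int, padding: str) -> int:
    """Hypothetical helper: output length along one spatial axis."""
    if padding == "SAME":
        return in_size * stride
    if padding == "VALID":
        return in_size * stride + max(kernel - stride, 0)
    raise ValueError(f"unsupported padding: {padding!r}")


# Channels-last layout [B, H, W, D, C]; kernel (kH, kW, kD, C_in, C_out).
DIMS = ("NHWDC", "HWDIO", "NHWDC")
x = jnp.ones((1, 8, 16, 16, 4))

for kernel, stride, padding in [(4, 2, "SAME"), (4, 2, "VALID"),
                                (3, 2, "VALID"), (2, 2, "VALID")]:
    w = jnp.ones((kernel, kernel, kernel, 4, 2))
    y = jax.lax.conv_transpose(x, w, strides=(stride,) * 3, padding=padding,
                               dimension_numbers=DIMS)
    expected = tuple(conv_transpose_out_size(n, kernel, stride, padding)
                     for n in x.shape[1:4])
    # The formula and JAX agree on every spatial axis.
    assert y.shape[1:4] == expected, (y.shape, expected)
    print(kernel, stride, padding, y.shape)
```

With kernel 4 and stride 2, 'VALID' adds max(4 - 2, 0) = 2 to each doubled axis, so an 8-long axis becomes 18 rather than 16.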
Computational considerations:
3D transposed convolutions are computationally expensive. To reduce the cost, consider:
- grouped convolutions (groups > 1) to reduce parameters;
- smaller kernel sizes;
- progressive upsampling (several layers with stride=2);
- separable convolutions for large-scale applications.
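To make the savings concrete, here is a plain-Python parameter count using the standard formula for grouped convolutions (kernel volume × in_channels / groups × out_channels). The helper name is hypothetical, and it assumes ConvTranspose3d stores one weight per kernel tap, per input channel within a group, and per output channel, as other frameworks do.

```python
def conv3d_transpose_params(c_in: int, c_out: int, kernel: int,
                            groups: int = 1, bias: bool = False) -> int:
    """Hypothetical helper: parameter count of a grouped 3D (transposed)
    convolution with a cubic kernel.

    Each of the `groups` groups maps c_in // groups input channels to
    c_out // groups output channels, so the kernel holds
    kernel**3 * (c_in // groups) * c_out weights in total.
    """
    if c_in % groups or c_out % groups:
        raise ValueError("groups must divide both c_in and c_out")
    n = kernel ** 3 * (c_in // groups) * c_out
    return n + (c_out if bias else 0)


# Grouping: 128 -> 64 channels with a 4x4x4 kernel.
print(conv3d_transpose_params(128, 64, 4))            # 524288
print(conv3d_transpose_params(128, 64, 4, groups=4))  # 131072

# 4x upsampling of 64 channels: one stride-4 layer with an 8x8x8 kernel
# versus two stride-2 layers with 4x4x4 kernels (kernel = 2 * stride in both).
one_step = conv3d_transpose_params(64, 64, 8)
two_steps = 2 * conv3d_transpose_params(64, 64, 4)
print(one_step, two_steps)                            # 2097152 524288
```

Because the kernel volume grows with the cube of its side, the two-step variant needs a quarter of the weights of the single large-stride layer in this setting.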
Common use cases:
- Video generation and prediction
- 3D medical image segmentation (U-Net 3D)
- Volumetric reconstruction
- 3D super-resolution
- Video frame interpolation
- 3D VAE decoders
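Comparison with PyTorch:
- PyTorch uses channels-first by default; BrainState uses channels-last.
- Set channel_first=True for a PyTorch-compatible input layout.
- Kernel layouts differ between frameworks: PyTorch stores ConvTranspose3d weights as (in_channels, out_channels // groups, kernel_size[0], kernel_size[1], kernel_size[2]), so ported weights must be rearranged.
- There is no output_padding argument; PyTorch's output_padding is handled through padding here.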
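When porting a PyTorch model, two mechanical differences come up: activation layout and output-size arithmetic. This sketch switches layouts with jnp.moveaxis and encodes the output-size formula from the torch.nn.ConvTranspose3d documentation. The helper names are hypothetical, and the statement that PyTorch's kernel 4 / stride 2 / padding 1 matches 'SAME' here concerns output shapes only, assuming ConvTranspose3d follows the 'SAME' rule in the notes.

```python
import jax.numpy as jnp


def to_channels_last(x):
    """Hypothetical helper: [B, C, H, W, D] (PyTorch style) -> [B, H, W, D, C]."""
    return jnp.moveaxis(x, 1, -1)


def to_channels_first(x):
    """Hypothetical helper: [B, H, W, D, C] -> [B, C, H, W, D]."""
    return jnp.moveaxis(x, -1, 1)


def torch_conv_transpose_out_size(n, kernel, stride, padding=0,
                                  output_padding=0, dilation=1):
    """Output length per axis for torch.nn.ConvTranspose3d (PyTorch docs formula)."""
    return (n - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


x_torch = jnp.zeros((4, 64, 8, 16, 16))
x_jax = to_channels_last(x_torch)
print(x_jax.shape)                      # (4, 8, 16, 16, 64)
print(to_channels_first(x_jax).shape)   # (4, 64, 8, 16, 16)

# PyTorch kernel_size=4, stride=2, padding=1 doubles each axis, like 'SAME' here.
print([torch_conv_transpose_out_size(n, 4, 2, padding=1) for n in (8, 16, 16)])  # [16, 32, 32]

# With kernel 3, stride 2, padding 1 the output is 15; output_padding=1 brings it to 16.
print(torch_conv_transpose_out_size(8, 3, 2, padding=1, output_padding=1))        # 16
```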
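Tips:
- Choose a kernel size divisible by the stride, such as kernel_size = 2 * stride (e.g., kernel_size=4, stride=2), so that kernel placements overlap evenly; otherwise checkerboard artifacts tend to appear.
- Group normalization often works better than batch normalization for 3D models, because the small batches typical in 3D make batch statistics noisy.
- Consider smaller batch sizes because of memory constraints.
- Progressive upsampling (2x at a time) is usually more stable than a single large stride.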
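The kernel-size tip can be seen directly: when the kernel size is not a multiple of the stride, neighboring output positions receive different numbers of overlapping kernel contributions, which shows up as a checkerboard pattern. This sketch measures that overlap with jax.lax.conv_transpose on all-ones data; it illustrates the general effect rather than BrainState internals, and the helper overlap_counts is hypothetical.

```python
import jax
import jax.numpy as jnp

# Channels-last layout [B, H, W, D, C]; kernel (kH, kW, kD, C_in, C_out).
DIMS = ("NHWDC", "HWDIO", "NHWDC")


def overlap_counts(kernel: int, stride: int, n: int = 4):
    """Hypothetical helper: transposed-convolve an all-ones input with an
    all-ones kernel, so each output value counts how many kernel placements
    cover that position."""
    x = jnp.ones((1, n, n, n, 1))
    w = jnp.ones((kernel, kernel, kernel, 1, 1))
    y = jax.lax.conv_transpose(x, w, strides=(stride,) * 3, padding="VALID",
                               dimension_numbers=DIMS)
    # Keep only the interior, where no contributing input is cut off by the border.
    lo, hi = kernel - 1, (n - 1) * stride + 1
    return y[0, lo:hi, lo:hi, lo:hi, 0]


print(jnp.unique(overlap_counts(kernel=4, stride=2)))  # uniform overlap: a single value, 8
print(jnp.unique(overlap_counts(kernel=3, stride=2)))  # uneven overlap: 1, 2, 4 and 8
```

With kernel 4 and stride 2, every interior position is covered by 2 placements per axis (8 in 3D); with kernel 3 the count alternates between 1 and 2 per axis, producing the checkerboard.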