ConvTranspose1d
- class brainstate.nn.ConvTranspose1d(in_size, out_channels, kernel_size, stride=1, padding='SAME', lhs_dilation=1, rhs_dilation=1, groups=1, w_init=XavierNormal(scale=1.0, mode='fan_avg', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([ 900 9244]), unit=Unit("1")), b_init=None, w_mask=None, channel_first=False, name=None, param_type=<class 'brainstate.ParamState'>)
One-dimensional transposed convolution layer (also known as deconvolution).
Applies a 1D transposed convolution over an input signal. Transposed convolution is used for upsampling: it reverses the change in sequence length produced by a regular convolution, without inverting the values that convolution computed. It is commonly used in autoencoders, GANs, and semantic segmentation networks.
The input should be a 3D array with the shape of [B, L, C], where B is the batch size, L is the sequence length, and C is the number of input channels (channels-last format).
- Parameters:
in_size (Sequence[int]) – The input shape without the batch dimension. For ConvTranspose1d: (L, C), where L is the sequence length and C is the number of input channels.
out_channels (int) – The number of output channels (feature maps) produced by the transposed convolution.
kernel_size (int | Tuple[int, ...]) – The shape of the convolutional kernel. For 1D, can be an integer or a single-element tuple.
stride (int | Tuple[int, ...]) – The stride of the transposed convolution. Larger strides produce larger output sizes, which is the opposite behavior of regular convolution. Default: 1.
padding (str | int | Tuple[int, int] | Sequence[Tuple[int, int]]) – The padding strategy. Options:
'SAME': output length approximately equals input_length * stride
'VALID': no padding, maximum output size
int: symmetric padding
(pad_before, pad_after): explicit padding for the sequence dimension
Default: 'SAME'.
lhs_dilation (int | Tuple[int, ...]) – The dilation factor for the input. For transposed convolution, this is typically set equal to stride internally. Default: 1.
rhs_dilation (int | Tuple[int, ...]) – The dilation factor for the kernel. Increases the receptive field without increasing parameters by inserting zeros between kernel elements. Default: 1.
groups (int) – Number of groups for grouped transposed convolution. Both in_channels and out_channels must be divisible by groups. Default: 1.
w_init (Callable | Array | ndarray | bool | number | bool | int | float | complex | Quantity) – The initializer for the convolutional kernel weights. Default: XavierNormal().
b_init (Callable | Array | ndarray | bool | number | bool | int | float | complex | Quantity | None) – The initializer for the bias. If None, no bias is added. Default: None.
w_mask (Callable | Array | ndarray | bool | number | bool | int | float | complex | Quantity | None) – An optional mask applied to the weights during the forward pass. Default: None.
channel_first (bool) – If True, uses channels-first format (e.g., [B, C, L]). If False, uses channels-last format (e.g., [B, L, C]). Default: False (channels-last, JAX convention).
name (str) – The name of the module. Default: None.
param_type (type) – The type of parameter state to use. Default: ParamState.
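The rhs_dilation and groups descriptions above can be illustrated with a small dependency-free sketch. The helpers dilate_kernel and channels_per_group are hypothetical names introduced only for this illustration. They are not part of brainstate and do not claim to show how the layer is implemented internally.

```python
def dilate_kernel(kernel, rhs_dilation):
    """Insert (rhs_dilation - 1) zeros between neighbouring kernel taps."""
    if rhs_dilation < 1:
        raise ValueError("rhs_dilation must be >= 1")
    dilated = []
    for i, tap in enumerate(kernel):
        if i > 0:
            dilated.extend([0.0] * (rhs_dilation - 1))
        dilated.append(tap)
    return dilated


def channels_per_group(in_channels, out_channels, groups):
    """Return (in, out) channels handled by each group, or raise if not divisible."""
    if in_channels % groups or out_channels % groups:
        raise ValueError("in_channels and out_channels must both be divisible by groups")
    return in_channels // groups, out_channels // groups


kernel = [1.0, 2.0, 3.0, 4.0]                    # kernel_size = 4
dilated = dilate_kernel(kernel, rhs_dilation=2)
print(dilated)                                   # [1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0]
print(len(dilated))                              # 7 == 2 * (4 - 1) + 1
print(channels_per_group(16, 8, groups=4))       # (4, 2)
```

The dilated kernel still has four learnable taps but spans seven positions, which is the sense in which rhs_dilation enlarges the receptive field without adding parameters.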
- weight
The learnable weights (and bias if specified) of the module.
- Type:
Examples
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Create a 1D transposed convolution layer for upsampling
>>> conv_transpose = brainstate.nn.ConvTranspose1d(
...     in_size=(28, 16),
...     out_channels=8,
...     kernel_size=4,
...     stride=2
... )
>>>
>>> # Apply to input: batch_size=2, length=28, channels=16
>>> x = jnp.ones((2, 28, 16))
>>> y = conv_transpose(x)
>>> print(y.shape)  # Output will be upsampled
>>>
>>> # Without batch dimension
>>> x_single = jnp.ones((28, 16))
>>> y_single = conv_transpose(x_single)
>>>
>>> # Channels-first format (PyTorch style): in_size is (C, L)
>>> conv_transpose = brainstate.nn.ConvTranspose1d(
...     in_size=(16, 28),
...     out_channels=8,
...     kernel_size=4,
...     stride=2,
...     channel_first=True
... )
>>> x = jnp.ones((2, 16, 28))
>>> y = conv_transpose(x)
Notes
Output dimensions:
Unlike regular convolution, transposed convolution increases spatial dimensions. With stride > 1, the output is larger than the input:
output_length ≈ input_length * stride (depends on padding and kernel size)
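The approximation above can be made concrete with the textbook length formula for a transposed convolution. The sketch below is an assumption, not a statement of what brainstate computes: transposed_conv_length is a hypothetical helper, and pad_before / pad_after are taken to mean the number of elements trimmed from each end of the full output (the convention used in PyTorch's documentation). The exact lengths produced by the layer's 'SAME' and 'VALID' modes should be checked against y.shape.

```python
def transposed_conv_length(input_length, kernel_size, stride=1,
                           rhs_dilation=1, pad_before=0, pad_after=0):
    """Textbook output length of a 1D transposed convolution.

    pad_before / pad_after count elements trimmed from each end of the
    full output (an assumed, PyTorch-style convention).
    """
    effective_kernel = rhs_dilation * (kernel_size - 1) + 1
    full_length = (input_length - 1) * stride + effective_kernel
    return full_length - pad_before - pad_after


# Setting from the Examples section: length 28, kernel_size=4, stride=2.
print(transposed_conv_length(28, kernel_size=4, stride=2))            # 58
print(transposed_conv_length(28, 4, 2, pad_before=1, pad_after=1))    # 56 == 28 * 2
```

With no trimming the output exceeds input_length * stride by kernel_size - stride elements; trimming that excess gives exactly input_length * stride, which is the behavior the 'SAME' option is described as approximating.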
Relationship to regular convolution:
A transposed convolution computes the same operation as the gradient of a regular convolution with respect to its input. It is sometimes called “deconvolution”, but that term is mathematically imprecise: the operation restores the input's shape and does not invert the convolution.
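The gradient relationship can be checked numerically in the single-channel case. The following is a minimal pure-Python sketch with hypothetical helper names (conv1d_valid, conv_transpose1d_valid, dot), assuming no padding and no dilation. It verifies the adjoint identity ⟨conv(x), y⟩ = ⟨x, conv_transpose(y)⟩, which is what "gradient with respect to the input" means for a linear operation.

```python
def conv1d_valid(x, w, stride=1):
    """Regular 1D cross-correlation with no padding (single channel)."""
    k = len(w)
    n_out = (len(x) - k) // stride + 1
    return [sum(w[j] * x[i * stride + j] for j in range(k)) for i in range(n_out)]


def conv_transpose1d_valid(y, w, stride=1):
    """Transposed 1D convolution: each input element scatters a scaled kernel copy."""
    k = len(w)
    out = [0.0] * ((len(y) - 1) * stride + k)
    for i, y_i in enumerate(y):
        for j in range(k):
            out[i * stride + j] += y_i * w[j]
    return out


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


x = [1.0, 2.0, -1.0, 0.5, 3.0, -2.0, 4.0, 1.5]   # length 8
w = [0.5, -1.0, 2.0, 0.25]                        # kernel_size = 4
y = [1.0, -2.0, 3.0]                              # same length as conv1d_valid(x, w, 2)

forward = conv1d_valid(x, w, stride=2)             # length 8 -> 3
backward = conv_transpose1d_valid(y, w, stride=2)  # length 3 -> 8

lhs = dot(forward, y)    # <conv(x), y>
rhs = dot(x, backward)   # <x, conv_transpose(y)>
print(len(forward), len(backward))   # 3 8
print(abs(lhs - rhs) < 1e-9)         # True
```

The transposed convolution maps a length-3 signal back to length 8, the shape of the original input, but backward is not equal to x, which is why "deconvolution" is a misleading name.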
Common use cases:
Upsampling in encoder-decoder architectures
Generative models (GANs, VAEs)
Semantic segmentation (U-Net, FCN)
Super-resolution networks
Comparison with PyTorch:
PyTorch uses channels-first by default; BrainState uses channels-last
Set channel_first=True for a PyTorch-compatible layout
There is no separate counterpart to PyTorch’s output_padding argument; the equivalent adjustment is made through the padding parameter