Conv1d

class brainstate.nn.Conv1d(in_size, out_channels, kernel_size, stride=1, padding='SAME', lhs_dilation=1, rhs_dilation=1, groups=1, w_init=XavierNormal(scale=1.0, mode='fan_avg', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([ 900 9244]), unit=Unit("1")), b_init=None, w_mask=None, channel_first=False, name=None, param_type=<class 'brainstate.ParamState'>)

One-dimensional convolution layer.

Applies a 1D convolution over an input signal composed of several input planes. The input is expected to be a 3D array of shape [B, L, C], where B is the batch size, L is the sequence length, and C is the number of input channels. As the example below shows, an unbatched input of shape [L, C] is also accepted.

This layer creates a convolution kernel that is convolved with the layer input over a single spatial dimension to produce a tensor of outputs.

Parameters:
  • in_size (Sequence[int]) – The input shape without the batch dimension, used to compute the output shape. For Conv1d: (L, C); for Conv2d: (H, W, C); for Conv3d: (H, W, D, C).

  • out_channels (int) – The number of output channels (also called filters or feature maps).

  • kernel_size (int | Tuple[int, ...]) – The shape of the convolutional kernel. For 1D convolution, the kernel size can be passed as an integer. For 2D and 3D convolutions, it should be a tuple of integers or a single integer (which will be replicated for all spatial dimensions).

  • stride (int | Tuple[int, ...]) – The stride of the convolution. An integer or a sequence of n integers, representing the inter-window strides along each spatial dimension. Default: 1.

  • padding (str | int | Tuple[int, int] | Sequence[Tuple[int, int]]) –

    The padding strategy. Can be:

    • 'SAME': pads the input so that the output has the same spatial shape as the input when stride=1

    • 'VALID': no padding

    • int: symmetric padding applied to all spatial dimensions

    • tuple of (low, high): padding for each dimension

    • sequence of tuples: explicit padding for each spatial dimension

    Default: 'SAME'.

  • lhs_dilation (int | Tuple[int, ...]) – The dilation factor for the input. An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs. Convolution with input dilation d is equivalent to transposed convolution with stride d. Default: 1.

  • rhs_dilation (int | Tuple[int, ...]) – The dilation factor for the kernel. An integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Convolution with kernel dilation is also known as ‘atrous convolution’, which increases the receptive field without increasing the number of parameters. Default: 1.

  • groups (int) – Number of groups for grouped convolution. Controls the connections between inputs and outputs. Both in_channels and out_channels must be divisible by groups. When groups=1 (default), all inputs are convolved to all outputs. When groups>1, the input and output channels are divided into groups, and each group is convolved independently. When groups=in_channels, this becomes a depthwise convolution. Default: 1.

  • w_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity) – The initializer for the convolutional kernel weights. Can be an initializer instance or a direct array. Default: XavierNormal().

  • b_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity | None) – The initializer for the bias. If None, no bias is added. Default: None.

  • w_mask (Callable | Array | ndarray | bool | number | int | float | complex | Quantity | None) – An optional mask applied to the weights during the forward pass. Useful for implementing structured sparsity or custom connectivity patterns. Default: None.

  • name (str) – The name of the module. Default: None.

  • param_type (type) – The type of parameter state to use. Default: ParamState.
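
The padding and dilation arguments above interact in ways that are easy to check with plain arithmetic. The sketch below is not brainstate code. It assumes the layer follows the XLA/JAX convention, in which a kernel with dilation d spans (k - 1) * d + 1 input positions, input dilation d inserts d - 1 zeros between neighbouring samples, and 'SAME' splits the total padding so that any odd leftover element goes on the high side. All helper names are hypothetical.

```python
import math


def effective_kernel_size(kernel_size: int, rhs_dilation: int = 1) -> int:
    """Number of input positions spanned by a (possibly dilated) kernel."""
    return (kernel_size - 1) * rhs_dilation + 1


def dilated_input_length(length: int, lhs_dilation: int = 1) -> int:
    """Length after inserting (lhs_dilation - 1) zeros between samples."""
    return (length - 1) * lhs_dilation + 1 if length > 0 else 0


def same_padding(length: int, kernel_size: int, stride: int = 1,
                 rhs_dilation: int = 1) -> tuple:
    """(low, high) padding that yields ceil(length / stride) outputs.

    Assumption: the XLA/JAX 'SAME' rule, with the extra element on the
    high side when the total padding is odd.
    """
    k_eff = effective_kernel_size(kernel_size, rhs_dilation)
    out_len = math.ceil(length / stride)
    total = max((out_len - 1) * stride + k_eff - length, 0)
    low = total // 2
    return low, total - low


print(same_padding(28, 5))             # (2, 2)
print(same_padding(100, 3, stride=2))  # (0, 1)
print(effective_kernel_size(3, 2))     # 5
print(dilated_input_length(4, 2))      # 7
```

Under these assumptions, passing padding=[(2, 2)] explicitly would be equivalent to 'SAME' for a kernel of size 5 with stride 1.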

in_size

The input shape (L, C) without batch dimension.

Type:

tuple of int

out_size

The output shape (L_out, out_channels) without batch dimension.

Type:

tuple of int

in_channels

Number of input channels.

Type:

int

out_channels

Number of output channels.

Type:

int

kernel_size

Size of the convolving kernel.

Type:

tuple of int

weight

The learnable weights (and bias if specified) of the module.

Type:

ParamState

Examples

>>> import brainstate
>>> import braintools
>>> import jax.numpy as jnp
>>>
>>> # Create a 1D convolution layer
>>> conv = brainstate.nn.Conv1d(in_size=(28, 3), out_channels=16, kernel_size=5)
>>>
>>> # Apply to input: batch_size=2, length=28, channels=3
>>> x = jnp.ones((2, 28, 3))
>>> y = conv(x)
>>> print(y.shape)  # (2, 28, 16) with 'SAME' padding
>>>
>>> # Without batch dimension
>>> x_single = jnp.ones((28, 3))
>>> y_single = conv(x_single)
>>> print(y_single.shape)  # (28, 16)
>>>
>>> # With custom parameters
>>> conv = brainstate.nn.Conv1d(
...     in_size=(100, 8),
...     out_channels=32,
...     kernel_size=3,
...     stride=2,
...     padding='VALID',
...     b_init=braintools.init.ZeroInit()
... )

Notes

Output dimensions:

The output shape depends on the padding mode:

  • 'SAME': output length = ceil(input_length / stride)

  • 'VALID': output length = ceil((input_length - kernel_size + 1) / stride)
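
The two formulas above can be checked with a few lines of plain Python. This is a minimal sketch rather than the library's own shape inference: the function name is hypothetical, it assumes lhs_dilation = rhs_dilation = 1, and clamping the 'VALID' result at zero for inputs shorter than the kernel is a choice made here, not documented behavior.

```python
import math


def conv1d_output_length(length, kernel_size, stride=1, padding="SAME"):
    """Output length along L for dilation 1, using the formulas above."""
    if padding == "SAME":
        return math.ceil(length / stride)
    if padding == "VALID":
        # Clamp at zero when the input is shorter than the kernel
        # (an assumption of this sketch).
        return max(math.ceil((length - kernel_size + 1) / stride), 0)
    raise ValueError(f"unsupported padding mode: {padding!r}")


# The two layers from the Examples section.
print(conv1d_output_length(28, 5))                              # 28
print(conv1d_output_length(100, 3, stride=2, padding="VALID"))  # 49
```

By these formulas, the second layer in the Examples section (in_size=(100, 8), kernel_size=3, stride=2, padding='VALID') would be expected to report an out_size of (49, 32).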

Grouped convolution:

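When groups > 1, the input and output channels are split into groups and each group is convolved independently. Each output channel then sees only in_channels / groups input channels, so the number of weights and multiply-accumulate operations drops by a factor of groups relative to groups=1.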
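
To make the saving concrete, the sketch below counts kernel weights for a grouped 1D convolution. It assumes the kernel is stored with shape (kernel_size, in_channels // groups, out_channels), which matches the in_axis=-2, out_axis=-1 defaults of the initializer in the signature but is not confirmed here as the layer's actual layout. The function name is hypothetical, and the bias is not counted.

```python
def conv1d_weight_count(in_channels, out_channels, kernel_size, groups=1):
    """Number of kernel weights for a grouped 1D convolution.

    Assumes a kernel of shape
    (kernel_size, in_channels // groups, out_channels).
    """
    if in_channels % groups or out_channels % groups:
        raise ValueError(
            "in_channels and out_channels must both be divisible by groups"
        )
    return kernel_size * (in_channels // groups) * out_channels


# 8 input channels, 32 output channels, kernel size 3.
print(conv1d_weight_count(8, 32, 3, groups=1))  # 768
print(conv1d_weight_count(8, 32, 3, groups=4))  # 192
print(conv1d_weight_count(8, 32, 3, groups=8))  # 96 (depthwise: groups == in_channels)
```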