Conv2d

class brainstate.nn.Conv2d(in_size, out_channels, kernel_size, stride=1, padding='SAME', lhs_dilation=1, rhs_dilation=1, groups=1, w_init=XavierNormal(scale=1.0, mode='fan_avg', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([ 900 9244]), unit=Unit("1")), b_init=None, w_mask=None, channel_first=False, name=None, param_type=<class 'brainstate.ParamState'>)

Two-dimensional convolution layer.

Applies a 2D convolution over an input signal composed of several input planes. The input should be a 4D array of shape [B, H, W, C], where B is the batch size, H is the height, W is the width, and C is the number of input channels (channels-last format). An unbatched input of shape [H, W, C] is also accepted.

This layer creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs. It is commonly used in computer vision tasks.

Parameters:
  • in_size (Sequence[int]) – The input shape without the batch dimension. For Conv2d: (H, W, C), where H is height, W is width, and C is the number of input channels. Required, because the layer uses it to compute the output shape.

  • out_channels (int) – The number of output channels (also called filters or feature maps). These determine the depth of the output feature map.

  • kernel_size (int | Tuple[int, ...]) –

    The shape of the convolutional kernel. Can be:

    • An integer (e.g., 3): creates a square kernel (3, 3)

    • A tuple of two integers (e.g., (3, 5)): creates a (height, width) kernel

  • stride (int | Tuple[int, ...]) –

    The stride of the convolution. Controls how much the kernel moves at each step. Can be:

    • An integer: same stride for both dimensions

    • A tuple of two integers: (stride_height, stride_width)

    Default: 1.

  • padding (str | int | Tuple[int, int] | Sequence[Tuple[int, int]]) –

    The padding strategy. Options:

    • 'SAME': pads so that the output spatial size equals the input size when stride=1 (ceil(input_size / stride) in general)

    • 'VALID': no padding; with stride=1, each spatial dimension shrinks by kernel_size - 1

    • int: same symmetric padding for all dimensions

    • (pad_h, pad_w): different padding for each dimension

    • [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after)]: explicit padding

    Default: 'SAME'.

  • lhs_dilation (int | Tuple[int, ...]) – The dilation factor for the input (left-hand side). Controls the spacing between input elements. A value greater than 1 inserts zeros between input elements, which is the mechanism behind transposed (fractionally strided) convolution. Default: 1.

  • rhs_dilation (int | Tuple[int, ...]) – The dilation factor for the kernel (right-hand side). Also known as atrous convolution or dilated convolution. Increases the receptive field without increasing parameters by inserting zeros between kernel elements. Useful for capturing multi-scale context. Default: 1.

  • groups (int) –

    Number of groups for grouped convolution. Must divide both in_channels and out_channels.

    • groups=1: standard convolution (all-to-all connections)

    • groups>1: grouped convolution (reduces the number of kernel weights by a factor of groups)

    • groups=in_channels: depthwise convolution (each input channel convolved separately)

    Default: 1.

  • w_init (Callable | Array | ndarray | bool | number | bool | int | float | complex | Quantity) –

    Weight initializer for the convolutional kernel. Can be:

    • An initializer instance (e.g., braintools.init.XavierNormal())

    • A callable that returns an array given a shape

    • A direct array matching the kernel shape

    Default: XavierNormal().

  • b_init (Callable | Array | ndarray | bool | number | bool | int | float | complex | Quantity | None) – Bias initializer. If None, no bias term is added to the output. Default: None.

  • w_mask (Callable | Array | ndarray | bool | number | bool | int | float | complex | Quantity | None) – Optional weight mask for structured sparsity or custom connectivity. The mask is element-wise multiplied with the kernel weights during the forward pass. Default: None.

  • name (str) – Name identifier for this module instance. Default: None.

  • param_type (type) – The parameter state class to use for managing learnable parameters. Default: ParamState.
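The effect of the two dilation parameters on sizes can be checked with a little arithmetic. The sketch below is not part of brainstate; `dilated_extent` is a hypothetical helper, and it assumes the usual convention (the one used by `jax.lax.conv_general_dilated`) that a dilation factor d inserts d - 1 zeros between neighbouring elements.

```python
def dilated_extent(size, dilation):
    """Length of one axis after inserting (dilation - 1) zeros between elements.

    Hypothetical helper for illustration; assumes the standard dilation
    convention and is not taken from the brainstate source.
    """
    if size < 1 or dilation < 1:
        raise ValueError("size and dilation must be positive integers")
    return (size - 1) * dilation + 1


# rhs_dilation=2 on a 3-wide kernel: it spans 5 input positions,
# while still holding only 3 learnable weights along that axis.
effective_kernel = dilated_extent(3, 2)

# lhs_dilation=2 on a 32-wide input: the kernel slides over 63 positions.
dilated_input = dilated_extent(32, 2)

print(effective_kernel, dilated_input)
```

With a dilation of 1 the helper returns the size unchanged, which corresponds to the default behaviour of both parameters.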

in_size

The input shape (H, W, C) without batch dimension.

Type:

tuple of int

out_size

The output shape (H_out, W_out, out_channels) without batch dimension.

Type:

tuple of int

in_channels

Number of input channels.

Type:

int

out_channels

Number of output channels.

Type:

int

kernel_size

Size of the convolving kernel (height, width).

Type:

tuple of int

weight

The learnable weights (and bias if specified) of the module.

Type:

ParamState

Examples

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Create a 2D convolution layer
>>> conv = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=64, kernel_size=3)
>>>
>>> # Apply to input: batch_size=8, height=32, width=32, channels=3
>>> x = jnp.ones((8, 32, 32, 3))
>>> y = conv(x)
>>> print(y.shape)  # (8, 32, 32, 64) with 'SAME' padding
>>>
>>> # Without batch dimension
>>> x_single = jnp.ones((32, 32, 3))
>>> y_single = conv(x_single)
>>> print(y_single.shape)  # (32, 32, 64)
>>>
>>> # With custom kernel size and stride
>>> conv = brainstate.nn.Conv2d(
...     in_size=(224, 224, 3),
...     out_channels=128,
...     kernel_size=(5, 5),
...     stride=2,
...     padding='VALID'
... )
>>>
>>> # Depthwise convolution (groups = in_channels)
>>> conv = brainstate.nn.Conv2d(
...     in_size=(64, 64, 32),
...     out_channels=32,
...     kernel_size=3,
...     groups=32
... )

Notes

Output dimensions:

The output spatial dimensions depend on the padding mode:

  • 'SAME': output_size = ceil(input_size / stride)

  • 'VALID': output_size = ceil((input_size - kernel_size + 1) / stride)
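The two formulas above can be evaluated directly. The following is a minimal sketch in plain Python; `conv_output_size` is a hypothetical helper written for this note, not a brainstate function, and it covers only the string padding modes with no dilation.

```python
import math


def conv_output_size(input_size, kernel_size, stride=1, padding="SAME"):
    """Output size along one spatial dimension, following the formulas above.

    Hypothetical helper for illustration; handles only 'SAME' and 'VALID'.
    """
    if padding == "SAME":
        return math.ceil(input_size / stride)
    if padding == "VALID":
        return math.ceil((input_size - kernel_size + 1) / stride)
    raise ValueError(f"unsupported padding mode: {padding!r}")


# 32-pixel input, 3x3 kernel, stride 1, 'SAME' padding: size is preserved.
same_size = conv_output_size(32, 3, stride=1, padding="SAME")

# 224-pixel input, 5x5 kernel, stride 2, 'VALID' padding: (224 - 5 + 1) / 2 = 110.
valid_size = conv_output_size(224, 5, stride=2, padding="VALID")

print(same_size, valid_size)
```

Applying the helper to each spatial dimension and appending out_channels gives the shape that out_size is documented to hold.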

Grouped convolution:

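When groups > 1, the input and output channels are divided into groups, and each group is convolved independently. Each output channel therefore sees only in_channels / groups input channels, which cuts the kernel weights and the multiply-accumulate operations by a factor of groups, at the cost of removing connections between groups.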
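To make the saving concrete, the kernel weights can be counted for the depthwise example shown earlier. This is a sketch under one assumption: each output channel connects to in_channels / groups input channels, the standard grouped-convolution layout. `conv2d_weight_count` is a hypothetical helper, and the count excludes any bias term.

```python
def conv2d_weight_count(in_channels, out_channels, kernel_size, groups=1):
    """Number of kernel weights in a grouped 2D convolution (bias excluded).

    Hypothetical helper for illustration; assumes the standard layout in which
    each output channel connects to in_channels // groups input channels.
    """
    if in_channels % groups != 0 or out_channels % groups != 0:
        raise ValueError("groups must divide both in_channels and out_channels")
    kernel_h, kernel_w = kernel_size
    return kernel_h * kernel_w * (in_channels // groups) * out_channels


# Standard 3x3 convolution, 32 -> 32 channels: 3 * 3 * 32 * 32 weights.
standard = conv2d_weight_count(32, 32, (3, 3), groups=1)

# Depthwise variant (groups = in_channels): 3 * 3 * 1 * 32 weights.
depthwise = conv2d_weight_count(32, 32, (3, 3), groups=32)

print(standard, depthwise, standard // depthwise)
```

The ratio between the two counts equals groups, matching the reduction factor given in the groups parameter description.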