Conv2d
- class brainstate.nn.Conv2d(in_size, out_channels, kernel_size, stride=1, padding='SAME', lhs_dilation=1, rhs_dilation=1, groups=1, w_init=XavierNormal(scale=1.0, mode='fan_avg', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([ 900 9244]), unit=Unit("1")), b_init=None, w_mask=None, channel_first=False, name=None, param_type=<class 'brainstate.ParamState'>)
Two-dimensional convolution layer.
Applies a 2D convolution over an input signal composed of several input planes. The input should be a 4D array with the shape of
[B, H, W, C] where B is batch size, H is height, W is width, and C is the number of input channels (channels-last format).
This layer creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs. It is commonly used in computer vision tasks.
- Parameters:
in_size (Sequence[int]) – The input shape without the batch dimension. For Conv2d: (H, W, C) where H is height, W is width, and C is the number of input channels. This argument is required because it is used to compute the output shape.
out_channels (int) – The number of output channels (also called filters or feature maps). This sets the depth of the output feature map.
kernel_size (int|Tuple[int,...]) – The shape of the convolutional kernel. Can be:
An integer (e.g., 3): creates a square kernel (3, 3)
A tuple of two integers (e.g., (3, 5)): creates a (height, width) kernel
stride (int|Tuple[int,...]) – The stride of the convolution. Controls how far the kernel moves at each step. Can be:
An integer: same stride for both dimensions
A tuple of two integers: (stride_height, stride_width)
Default: 1.
padding (str|int|Tuple[int,int]|Sequence[Tuple[int,int]]) – The padding strategy. Options:
'SAME': pads so that the output spatial size equals the input size when stride=1
'VALID': no padding; each spatial dimension shrinks by kernel_size - 1 when stride=1
int: same symmetric padding for all dimensions
(pad_h, pad_w): symmetric padding, with a different amount for each dimension
[(pad_h_before, pad_h_after), (pad_w_before, pad_w_after)]: explicit padding
Default: 'SAME'.
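The padding forms listed above can all be reduced to explicit (before, after) pairs. The following is a minimal pure-Python sketch of that reduction, not the library's implementation. The helper names `same_padding` and `explicit_padding` are hypothetical, and the split of 'SAME' padding (extra element on the trailing side) follows the common XLA/TensorFlow convention, which is an assumption here.

```python
import math


def same_padding(input_size, kernel_size, stride=1):
    """Return (pad_before, pad_after) for one spatial dimension under 'SAME'.

    Assumes the XLA/TensorFlow convention: pad just enough that the output
    has ceil(input_size / stride) elements, with any odd element placed after.
    """
    output_size = math.ceil(input_size / stride)
    total = max((output_size - 1) * stride + kernel_size - input_size, 0)
    pad_before = total // 2
    return pad_before, total - pad_before


def explicit_padding(padding, input_hw, kernel_hw, stride_hw):
    """Normalize any documented padding form to [(h_before, h_after), (w_before, w_after)]."""
    if padding == 'VALID':
        return [(0, 0), (0, 0)]
    if padding == 'SAME':
        return [same_padding(n, k, s)
                for n, k, s in zip(input_hw, kernel_hw, stride_hw)]
    if isinstance(padding, int):
        # A single int pads every side of every dimension equally.
        return [(padding, padding), (padding, padding)]
    # Either (pad_h, pad_w) or a sequence of explicit (before, after) pairs.
    return [(p, p) if isinstance(p, int) else tuple(p) for p in padding]


print(explicit_padding('SAME', (32, 32), (3, 3), (1, 1)))
print(explicit_padding((1, 2), (32, 32), (3, 3), (1, 1)))
```

With an odd total (for example a size-5 kernel with stride 2 on 224 inputs), this sketch places the extra padded element after the data rather than before it.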
lhs_dilation (int|Tuple[int,...]) – The dilation factor for the input (left-hand side). Controls spacing between input elements. A value > 1 inserts zeros between input elements, which is the operation underlying transposed (fractionally strided) convolution. Default: 1.
rhs_dilation (int|Tuple[int,...]) – The dilation factor for the kernel (right-hand side). Also known as atrous convolution or dilated convolution. Increases the receptive field without increasing parameters by inserting zeros between kernel elements. Useful for capturing multi-scale context. Default: 1.
groups (int) – Number of groups for grouped convolution. Must divide both in_channels and out_channels.
groups=1: standard convolution (all-to-all connections)
groups>1: grouped convolution (reduces the kernel parameter count by a factor of groups)
groups=in_channels: depthwise convolution (each input channel convolved separately)
Default: 1.
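To make the grouping concrete, here is a small pure-Python sketch of how channels could be partitioned. The function name `group_channel_ranges` is hypothetical, and the contiguous-block assignment is the convention used by most frameworks; the source does not state which layout this layer uses, so treat it as an assumption.

```python
def group_channel_ranges(in_channels, out_channels, groups):
    """List, per group, the (input channel range, output channel range) it connects.

    Assumes contiguous blocks: group g reads input channels
    [g * cin, (g + 1) * cin) and writes output channels [g * cout, (g + 1) * cout).
    """
    if in_channels % groups or out_channels % groups:
        raise ValueError("groups must divide both in_channels and out_channels")
    cin = in_channels // groups    # input channels seen by each group
    cout = out_channels // groups  # output channels produced by each group
    return [(range(g * cin, (g + 1) * cin), range(g * cout, (g + 1) * cout))
            for g in range(groups)]


# Two groups: each half of the inputs feeds only its own half of the outputs.
for inputs, outputs in group_channel_ranges(4, 8, groups=2):
    print(list(inputs), "->", list(outputs))
```

With groups equal to in_channels, each range of inputs has length one, which is the depthwise case described above.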
w_init (Callable|Array|ndarray|bool|number|bool|int|float|complex|Quantity) – Weight initializer for the convolutional kernel. Can be:
An initializer instance (e.g., braintools.init.XavierNormal())
A callable that returns an array given a shape
A direct array matching the kernel shape
Default: XavierNormal().
b_init (Callable|Array|ndarray|bool|number|bool|int|float|complex|Quantity|None) – Bias initializer. If None, no bias term is added to the output. Default: None.
w_mask (Callable|Array|ndarray|bool|number|bool|int|float|complex|Quantity|None) – Optional weight mask for structured sparsity or custom connectivity. The mask is multiplied element-wise with the kernel weights during the forward pass. Default: None.
name (str) – Name identifier for this module instance. Default: None.
param_type (type) – The parameter state class to use for managing learnable parameters. Default: ParamState.
- weight
The learnable weights (and bias if specified) of the module.
- Type:
Examples
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Create a 2D convolution layer
>>> conv = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=64, kernel_size=3)
>>>
>>> # Apply to input: batch_size=8, height=32, width=32, channels=3
>>> x = jnp.ones((8, 32, 32, 3))
>>> y = conv(x)
>>> print(y.shape)  # (8, 32, 32, 64) with 'SAME' padding
>>>
>>> # Without batch dimension
>>> x_single = jnp.ones((32, 32, 3))
>>> y_single = conv(x_single)
>>> print(y_single.shape)  # (32, 32, 64)
>>>
>>> # With custom kernel size and stride
>>> conv = brainstate.nn.Conv2d(
...     in_size=(224, 224, 3),
...     out_channels=128,
...     kernel_size=(5, 5),
...     stride=2,
...     padding='VALID'
... )
>>>
>>> # Depthwise convolution (groups = in_channels)
>>> conv = brainstate.nn.Conv2d(
...     in_size=(64, 64, 32),
...     out_channels=32,
...     kernel_size=3,
...     groups=32
... )
Notes
Output dimensions:
The output spatial dimensions depend on the padding mode:
‘SAME’: output_size = ceil(input_size / stride)
‘VALID’: output_size = ceil((input_size - kernel_size + 1) / stride)
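The two formulas above can be checked with a few lines of pure Python. This is a sketch for one spatial dimension; `conv_output_size` is a hypothetical helper, not part of brainstate. The `rhs_dilation` argument is an extension beyond the formulas in the source: it assumes the usual effective kernel size of (kernel_size - 1) * rhs_dilation + 1.

```python
import math


def conv_output_size(input_size, kernel_size, stride=1, padding='SAME',
                     rhs_dilation=1):
    """Output length of one spatial dimension for 'SAME' or 'VALID' padding."""
    if padding == 'SAME':
        return math.ceil(input_size / stride)
    if padding == 'VALID':
        # Assumption: a dilated kernel spans (kernel_size - 1) * rhs_dilation + 1
        # input elements; with rhs_dilation=1 this is just kernel_size.
        effective_kernel = (kernel_size - 1) * rhs_dilation + 1
        return math.ceil((input_size - effective_kernel + 1) / stride)
    raise ValueError("this sketch only handles 'SAME' and 'VALID'")


# The 224x224 example with kernel_size=5, stride=2, padding='VALID':
print(conv_output_size(224, 5, stride=2, padding='VALID'))  # 110
# The 32x32 example with the default 'SAME' padding:
print(conv_output_size(32, 3))  # 32
```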
Grouped convolution:
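When groups > 1, the input and output channels are divided into groups. Each group is convolved independently, so each output channel sees only in_channels / groups input channels. This reduces the kernel parameter count and the computational cost by a factor of groups, at the price of no channel mixing across groups.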
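As a rough illustration of the savings, the following sketch counts kernel weights and multiply-accumulate operations for a grouped 2D convolution. The helper `conv2d_cost` is hypothetical; it ignores the bias term and assumes each output channel connects to in_channels // groups input channels.

```python
def conv2d_cost(out_hw, kernel_hw, in_channels, out_channels, groups=1):
    """Return (kernel weights, multiply-accumulates per input sample)."""
    kernel_h, kernel_w = kernel_hw
    out_h, out_w = out_hw
    # Each output channel holds one kernel_h x kernel_w filter per visible input channel.
    weights = kernel_h * kernel_w * (in_channels // groups) * out_channels
    # Every weight is used once per output spatial position.
    macs = out_h * out_w * weights
    return weights, macs


# 3x3 convolution, 32 -> 32 channels, 64x64 output ('SAME' padding, stride 1).
standard = conv2d_cost((64, 64), (3, 3), 32, 32, groups=1)
depthwise = conv2d_cost((64, 64), (3, 3), 32, 32, groups=32)
print(standard)   # (9216, 37748736)
print(depthwise)  # (288, 1179648)
```

In this configuration the depthwise layer uses 32 times fewer weights and multiply-accumulates than the standard one.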