Conv3d
- class brainstate.nn.Conv3d(in_size, out_channels, kernel_size, stride=1, padding='SAME', lhs_dilation=1, rhs_dilation=1, groups=1, w_init=XavierNormal(scale=1.0, mode='fan_avg', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([ 900 9244]), unit=Unit("1")), b_init=None, w_mask=None, channel_first=False, name=None, param_type=<class 'brainstate.ParamState'>)
Three-dimensional convolution layer.
Applies a 3D convolution over an input signal composed of several input planes. The input should be a 5D array with the shape [B, H, W, D, C], where B is batch size, H is height, W is width, D is depth, and C is the number of input channels (channels-last format).
This layer is commonly used for processing 3D data such as video sequences or volumetric medical imaging data.
- Parameters:
in_size (Sequence[int]) – The input shape without the batch dimension. For Conv3d: (H, W, D, C), where H is height, W is width, D is depth, and C is the number of input channels. This argument matters because the output shape is computed from it.
out_channels (int) – The number of output channels (also called filters or feature maps). This sets the size of the channel dimension of the output.
kernel_size (int | Tuple[int, ...]) – The shape of the convolutional kernel. Can be:
An integer (e.g., 3): creates a cubic kernel (3, 3, 3)
A tuple of three integers (e.g., (3, 5, 5)): creates a (height, width, depth) kernel
stride (int | Tuple[int, ...]) – The stride of the convolution, i.e. how far the kernel moves at each step. Can be:
An integer: same stride for all dimensions
A tuple of three integers: (stride_h, stride_w, stride_d)
Default: 1.
padding (str | int | Tuple[int, int] | Sequence[Tuple[int, int]]) – The padding strategy. Options:
'SAME': output spatial size equals input size when stride=1
'VALID': no padding; each spatial dimension shrinks by kernel_size - 1 when stride=1 and rhs_dilation=1
int: same symmetric padding for all dimensions
(pad_h, pad_w, pad_d): different symmetric padding for each dimension
[(pad_h_before, pad_h_after), (pad_w_before, pad_w_after), (pad_d_before, pad_d_after)]: explicit padding
Default: 'SAME'.
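The padding forms listed above can all be reduced to explicit (before, after) pairs per spatial dimension. The following is a minimal sketch of that reduction, not brainstate's implementation: `resolve_padding` is a hypothetical helper, and the 'SAME' branch assumes the common XLA/TensorFlow convention in which any odd amount of padding places the extra zero at the end.

```python
import math


def resolve_padding(padding, in_spatial, kernel, stride):
    """Return one (before, after) padding pair per spatial dimension.

    Hypothetical helper for illustration only; not part of brainstate.
    """
    ndim = len(in_spatial)
    if isinstance(padding, str):
        mode = padding.upper()
        if mode == "VALID":
            return [(0, 0)] * ndim
        if mode == "SAME":
            pairs = []
            for n, k, s in zip(in_spatial, kernel, stride):
                out = math.ceil(n / s)                 # target output size
                total = max((out - 1) * s + k - n, 0)  # total zeros required
                pairs.append((total // 2, total - total // 2))
            return pairs
        raise ValueError(f"unknown padding mode: {padding!r}")
    if isinstance(padding, int):
        # One integer: same symmetric padding everywhere
        return [(padding, padding)] * ndim
    padding = list(padding)
    if all(isinstance(p, int) for p in padding):
        # One integer per dimension: symmetric padding per dimension
        return [(p, p) for p in padding]
    # Already explicit (before, after) pairs
    return [tuple(p) for p in padding]


print(resolve_padding("SAME", (32, 32, 32), (3, 3, 3), (1, 1, 1)))
print(resolve_padding("SAME", (32, 32, 32), (3, 3, 3), (2, 2, 2)))
print(resolve_padding((1, 2, 0), (32, 32, 32), (3, 3, 3), (1, 1, 1)))
```

With stride 2 and an even input size, 'SAME' under this convention pads only one side, which is why explicit pairs are sometimes preferable when exact alignment matters.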
lhs_dilation (int | Tuple[int, ...]) – The dilation factor for the input (left-hand side). Controls spacing between input elements. A value > 1 inserts zeros between input elements, which is how a transposed convolution is expressed. Default: 1.
rhs_dilation (int | Tuple[int, ...]) – The dilation factor for the kernel (right-hand side). Also known as atrous convolution or dilated convolution. Increases the receptive field without increasing parameters by inserting zeros between kernel elements. Particularly useful for 3D data to capture larger temporal/spatial context. Default: 1.
groups (int) – Number of groups for grouped convolution. Must divide both in_channels and out_channels.
groups=1: standard convolution (all-to-all connections)
groups>1: grouped convolution (significantly reduces parameters and computation for 3D)
groups=in_channels: depthwise convolution (each input channel convolved separately)
Default: 1.
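To make the effect of groups concrete, the sketch below counts kernel weights for a few group settings. It assumes the kernel is stored with shape (k_h, k_w, k_d, in_channels // groups, out_channels), which is consistent with the in_axis=-2, out_axis=-1 defaults in the signature but is not confirmed by this page. `conv3d_weight_count` is a hypothetical helper.

```python
def conv3d_weight_count(kernel_size, in_channels, out_channels, groups=1):
    """Number of kernel weights (bias excluded) for a grouped 3D convolution.

    Assumes a kernel of shape (k_h, k_w, k_d, in_channels // groups, out_channels).
    """
    if in_channels % groups or out_channels % groups:
        raise ValueError("groups must divide both in_channels and out_channels")
    k_h, k_w, k_d = kernel_size
    return k_h * k_w * k_d * (in_channels // groups) * out_channels


for g in (1, 4, 32):
    n = conv3d_weight_count((3, 3, 3), in_channels=32, out_channels=64, groups=g)
    print(f"groups={g:>2}: {n} weights")
```

The count scales as 1/groups, so the depthwise case (groups equal to in_channels) uses 32 times fewer kernel weights here than the standard convolution.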
w_init (Callable|Array|ndarray|bool|number|bool|int|float|complex|Quantity) – Weight initializer for the convolutional kernel. Can be:
An initializer instance (e.g., braintools.init.XavierNormal())
A callable that returns an array given a shape
A direct array matching the kernel shape
Default: XavierNormal().
b_init (Callable|Array|ndarray|bool|number|bool|int|float|complex|Quantity|None) – Bias initializer. If None, no bias term is added to the output. Default: None.
w_mask (Callable|Array|ndarray|bool|number|bool|int|float|complex|Quantity|None) – Optional weight mask for structured sparsity or custom connectivity. The mask is multiplied element-wise with the kernel weights during the forward pass. Default: None.
name (str) – Name identifier for this module instance. Default: None.
param_type (type) – The parameter state class to use for managing learnable parameters. Default: ParamState.
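As an illustration of the w_mask parameter described above, the sketch below builds a mask in plain NumPy and applies it by element-wise multiplication. It does not call brainstate, and the kernel layout (k_h, k_w, k_d, in_channels, out_channels) is an assumption. The mask keeps only the kernel taps at or before the centre along the first kernel axis, which would make the convolution causal along that axis if it represents time.

```python
import numpy as np

# Assumed kernel layout: (k_h, k_w, k_d, in_channels, out_channels)
kernel_shape = (3, 3, 3, 2, 4)
rng = np.random.default_rng(0)
weight = rng.normal(size=kernel_shape)

# Ones for taps at or before the centre of the first kernel axis, zeros after
mask = np.zeros(kernel_shape)
mask[: kernel_shape[0] // 2 + 1] = 1.0

# What a masked forward pass would use in place of the raw weights
effective_weight = weight * mask

print(int(mask.sum()), "of", mask.size, "weights remain active")
```

An array like `mask` with the same shape as the kernel is the kind of value one would pass as w_mask; masked entries contribute nothing to the output regardless of the underlying weight values.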
- out_size
The output shape (H_out, W_out, D_out, out_channels) without batch dimension.
- weight
The learnable weights (and bias if specified) of the module.
- Type:
Examples
>>> import brainstate
>>> import braintools
>>> import jax.numpy as jnp
>>>
>>> # Create a 3D convolution layer for video data
>>> conv = brainstate.nn.Conv3d(in_size=(16, 64, 64, 3), out_channels=32, kernel_size=3)
>>>
>>> # Apply to input: batch_size=4, frames=16, height=64, width=64, channels=3
>>> x = jnp.ones((4, 16, 64, 64, 3))
>>> y = conv(x)
>>> print(y.shape)  # (4, 16, 64, 64, 32) with 'SAME' padding
>>>
>>> # Without batch dimension
>>> x_single = jnp.ones((16, 64, 64, 3))
>>> y_single = conv(x_single)
>>> print(y_single.shape)  # (16, 64, 64, 32)
>>>
>>> # For medical imaging with custom parameters
>>> conv = brainstate.nn.Conv3d(
...     in_size=(32, 32, 32, 1),
...     out_channels=64,
...     kernel_size=(3, 3, 3),
...     stride=2,
...     padding='VALID',
...     b_init=braintools.init.Constant(0.1)
... )
Notes
Output dimensions:
The output size of each spatial dimension depends on the padding mode:
'SAME': output_size = ceil(input_size / stride)
'VALID': output_size = ceil((input_size - kernel_size + 1) / stride)
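The two formulas above can be checked with a small shape calculator. This is a sketch rather than the library's own code: `conv3d_out_size` is a hypothetical helper, it handles only the string padding modes, and its treatment of rhs_dilation (replacing kernel_size with the effective size (kernel_size - 1) * rhs_dilation + 1) is an assumption based on how dilated convolutions are usually defined.

```python
import math


def conv3d_out_size(in_size, out_channels, kernel_size, stride=1,
                    padding="SAME", rhs_dilation=1):
    """Predict (H_out, W_out, D_out, out_channels) from a channels-last in_size.

    Hypothetical helper; supports only 'SAME' and 'VALID' padding.
    """
    def triple(v):
        # Expand a single integer to one value per spatial dimension
        return (v,) * 3 if isinstance(v, int) else tuple(v)

    spatial = in_size[:-1]  # drop the channel dimension
    out = []
    for n, k, s, d in zip(spatial, triple(kernel_size), triple(stride),
                          triple(rhs_dilation)):
        if padding == "SAME":
            out.append(math.ceil(n / s))
        elif padding == "VALID":
            k_eff = (k - 1) * d + 1  # kernel extent including dilation gaps
            out.append(math.ceil((n - k_eff + 1) / s))
        else:
            raise ValueError(f"unsupported padding mode: {padding!r}")
    return (*out, out_channels)


# The two layers from the Examples section
print(conv3d_out_size((16, 64, 64, 3), 32, 3))
print(conv3d_out_size((32, 32, 32, 1), 64, (3, 3, 3), stride=2, padding="VALID"))
```

For the second layer, each spatial dimension becomes ceil((32 - 3 + 1) / 2) = 15, so the predicted out_size is (15, 15, 15, 64).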
Performance considerations:
3D convolutions are computationally expensive. Consider using:
Smaller kernel sizes
Grouped convolutions (groups > 1)
Separable convolutions for large-scale applications
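A rough multiply-accumulate (MAC) count shows how much each of these options can save. The sketch below is an estimate under stated assumptions, not a benchmark: `conv3d_macs` is a hypothetical helper, bias and padding effects are ignored, and "separable" is taken to mean a depthwise 3x3x3 convolution (groups equal to in_channels) followed by a pointwise 1x1x1 convolution.

```python
def conv3d_macs(out_spatial, kernel_size, in_channels, out_channels, groups=1):
    """Approximate multiply-accumulates for one forward pass of one sample."""
    voxels = out_spatial[0] * out_spatial[1] * out_spatial[2]
    k_h, k_w, k_d = kernel_size
    return voxels * k_h * k_w * k_d * (in_channels // groups) * out_channels


out_spatial = (16, 64, 64)  # output volume with 'SAME' padding and stride 1
c_in = c_out = 32

dense_3 = conv3d_macs(out_spatial, (3, 3, 3), c_in, c_out)
dense_5 = conv3d_macs(out_spatial, (5, 5, 5), c_in, c_out)
grouped = conv3d_macs(out_spatial, (3, 3, 3), c_in, c_out, groups=4)
# Depthwise step keeps the channel count; pointwise step mixes channels
separable = (conv3d_macs(out_spatial, (3, 3, 3), c_in, c_in, groups=c_in)
             + conv3d_macs(out_spatial, (1, 1, 1), c_in, c_out))

for name, macs in [("5x5x5 dense", dense_5), ("3x3x3 dense", dense_3),
                   ("3x3x3 groups=4", grouped), ("3x3x3 separable", separable)]:
    print(f"{name:<16} {macs:>14,d} MACs  ({macs / dense_3:.2f}x of 3x3x3 dense)")
```

Because cost grows with the cube of the kernel size, going from 5x5x5 to 3x3x3 alone cuts the count by a factor of 125/27, roughly 4.6.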