ScaledWSConv1d

class brainstate.nn.ScaledWSConv1d(in_size, out_channels, kernel_size, stride=1, padding='SAME', lhs_dilation=1, rhs_dilation=1, groups=1, ws_gain=True, eps=0.0001, w_init=XavierNormal(scale=1.0, mode='fan_avg', in_axis=-2, out_axis=-1, distribution='truncated_normal', unit=Unit("1")), b_init=None, w_mask=None, channel_first=False, name=None, param_type=ParamState)

One-dimensional convolution with weight standardization.

This layer applies weight standardization to the convolutional kernel before performing the convolution. Weight standardization normalizes the weights of each output channel to zero mean and unit variance, which can accelerate training and improve model performance, especially when combined with group normalization.

With the default channel_first=False, the input should be a 3D array of shape [B, L, C], where B is the batch size, L is the sequence length, and C is the number of input channels.
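To make the channels-last [B, L, C] layout concrete, here is a minimal NumPy sketch of the underlying 1D cross-correlation for stride=1, 'VALID' padding, and groups=1, without weight standardization. The kernel layout [K, C_in, C_out] is an assumption (consistent with the default initializer's in_axis=-2, out_axis=-1), and conv1d_valid_nlc is a hypothetical helper, not brainstate's implementation.

```python
import numpy as np

def conv1d_valid_nlc(x, w):
    """Cross-correlate x of shape [B, L, C_in] with w of shape [K, C_in, C_out].

    Stride 1, no padding ('VALID'), groups=1. Returns [B, L - K + 1, C_out].
    Assumed kernel layout: [K, C_in, C_out] (hypothetical helper).
    """
    B, L, C_in = x.shape
    K, w_in, C_out = w.shape
    if w_in != C_in:
        raise ValueError("kernel input channels must match the input")
    L_out = L - K + 1
    y = np.zeros((B, L_out, C_out), dtype=np.result_type(x, w))
    for t in range(L_out):
        window = x[:, t:t + K, :]  # [B, K, C_in] slice under the kernel
        # Sum over kernel position k and input channel c for every output channel o.
        y[:, t, :] = np.einsum('bkc,kco->bo', window, w)
    return y

x = np.ones((4, 100, 16))
w = np.ones((5, 16, 32))
y = conv1d_valid_nlc(x, w)
print(y.shape)  # (4, 96, 32)
```

Every output element here sums 5 * 16 ones, so each entry equals 80.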

Parameters:
  • in_size (Sequence[int]) – The input shape without the batch dimension. For 1D convolution: (L, C), where L is the sequence length and C is the number of input channels. It is required because it is used to compute the output shape.

  • out_channels (int) – The number of output channels (also called filters or feature maps). This determines the depth of the output feature map.

  • kernel_size (int | Tuple[int, ...]) –

    The shape of the convolutional kernel. For 1D convolution, can be:

    • An integer (e.g., 5): creates a kernel of size 5

    • A tuple with one integer (e.g., (5,)): equivalent to the above

  • stride (int | Tuple[int, ...]) – The stride of the convolution. Controls how much the kernel moves at each step. Default: 1.

  • padding (str | int | Tuple[int, int] | Sequence[Tuple[int, int]]) –

    The padding strategy. Options:

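    • 'SAME': output length equals input length when stride=1 (ceil(L / stride) in general)

    • 'VALID': no padding; with stride=1 and no dilation the output length is L - kernel_size + 1

    • int: symmetric padding (the same number of zeros before and after the sequence)

    • (pad_before, pad_after): explicit padding for the sequence dimension

    Default: 'SAME'.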

  • lhs_dilation (int | Tuple[int, ...]) – The dilation factor for the input (left-hand side). A value d > 1 inserts d - 1 zeros between adjacent input elements; this is the mechanism behind transposed (fractionally strided) convolution. Default: 1.

  • rhs_dilation (int | Tuple[int, ...]) – The dilation factor for the kernel (right-hand side), also known as atrous or dilated convolution. A value d > 1 inserts d - 1 zeros between kernel elements, enlarging the receptive field to (kernel_size - 1) * d + 1 without increasing the number of parameters. Useful for capturing long-range dependencies. Default: 1.

  • groups (int) –

    Number of groups for grouped convolution. Must divide both in_channels and out_channels.

    • groups=1: standard convolution (all-to-all connections)

    • groups>1: grouped convolution (reduces the number of kernel parameters by a factor of groups)

    • groups=in_channels: depthwise convolution (each input channel convolved separately)

    Default: 1.

  • w_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity) –

    Weight initializer for the convolutional kernel. Can be:

    • An initializer instance (e.g., braintools.init.XavierNormal())

    • A callable that returns an array given a shape

    • A direct array matching the kernel shape

    Default: XavierNormal().

  • b_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity | None) – Bias initializer. If None, no bias term is added to the output. Default: None.

  • ws_gain (bool) – Whether to include a learnable per-channel gain parameter in weight standardization. When True, adds a scaling factor that can be learned during training, improving model expressiveness. Recommended for most applications. Default: True.

  • eps (float) – Small constant for numerical stability in weight standardization. Prevents division by zero when computing weight standard deviation. Typical values: 1e-4 to 1e-5. Default: 1e-4.

  • w_mask (Callable | Array | ndarray | bool | number | int | float | complex | Quantity | None) – Optional weight mask for structured sparsity or custom connectivity. The mask is multiplied element-wise with the standardized kernel weights during the forward pass. Default: None.

  • name (str | None) – Name identifier for this module instance. Default: None.

  • param_type (type) – The parameter state class to use for managing learnable parameters. Default: ParamState.
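Both dilation parameters above can be pictured as zero insertion: rhs_dilation spreads the kernel taps apart, so a kernel of size k spans (k - 1) * rhs_dilation + 1 input positions, while lhs_dilation spreads the input samples apart before the kernel slides over them. The NumPy sketch below illustrates only that bookkeeping; the helpers dilate and effective_kernel_size are hypothetical and do not call brainstate.

```python
import numpy as np

def dilate(a, rate, axis=0):
    """Insert (rate - 1) zeros between consecutive entries of `a` along `axis`."""
    if rate == 1:
        return a
    n = a.shape[axis]
    shape = list(a.shape)
    shape[axis] = (n - 1) * rate + 1  # dilated extent
    out = np.zeros(shape, dtype=a.dtype)
    idx = [slice(None)] * a.ndim
    idx[axis] = slice(None, None, rate)  # original entries land every `rate` steps
    out[tuple(idx)] = a
    return out

def effective_kernel_size(k, rhs_dilation):
    """Number of input positions covered by a dilated kernel of size k."""
    return (k - 1) * rhs_dilation + 1

kernel = np.array([1.0, 2.0, 3.0])
print(dilate(kernel, 2))            # [1. 0. 2. 0. 3.]  (rhs_dilation=2 on the kernel)
print(effective_kernel_size(3, 2))  # 5

signal = np.array([1.0, 2.0, 3.0, 4.0])
print(dilate(signal, 2))            # [1. 0. 2. 0. 3. 0. 4.]  (lhs_dilation=2 on the input)
```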

in_size

The input shape (L, C) without batch dimension.

Type:

tuple of int

out_size

The output shape (L_out, out_channels) without batch dimension.

Type:

tuple of int
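The output length L_out follows from the padding, stride, and kernel dilation. The sketch below assumes the standard XLA/JAX convolution arithmetic ('SAME' gives ceil(L / stride); otherwise floor((L + pads - k_eff) / stride) + 1) and lhs_dilation=1; conv1d_out_length is a hypothetical helper, not a brainstate function, so compare against the layer's out_size before relying on it.

```python
import math

def conv1d_out_length(L, kernel_size, stride=1, padding='SAME', rhs_dilation=1):
    """Output length of a 1D convolution, assuming lhs_dilation=1 (hypothetical helper)."""
    k_eff = (kernel_size - 1) * rhs_dilation + 1  # extent of the dilated kernel
    if padding == 'SAME':
        return math.ceil(L / stride)
    if padding == 'VALID':
        pad_lo, pad_hi = 0, 0
    elif isinstance(padding, int):
        pad_lo, pad_hi = padding, padding  # symmetric padding
    else:
        # (pad_before, pad_after), or a one-element sequence holding that pair
        pair = padding[0] if isinstance(padding[0], (tuple, list)) else padding
        pad_lo, pad_hi = pair
    return (L + pad_lo + pad_hi - k_eff) // stride + 1

print(conv1d_out_length(100, 5))                                   # 100
print(conv1d_out_length(100, 5, padding='VALID'))                  # 96
print(conv1d_out_length(100, 5, stride=2))                         # 50
print(conv1d_out_length(100, 3, padding='VALID', rhs_dilation=2))  # 96
print(conv1d_out_length(100, 5, padding=(2, 2)))                   # 100
```

With in_size=(100, 16), kernel_size=5 and the default 'SAME' padding, this gives L_out = 100, so out_size would be (100, out_channels).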

in_channels

Number of input channels.

Type:

int

out_channels

Number of output channels.

Type:

int

kernel_size

Size of the convolving kernel.

Type:

tuple of int

weight

The learnable weights (and bias if specified) of the module.

Type:

ParamState
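The kernel held in weight shrinks as groups grows. Assuming a channels-last kernel layout of (kernel_size, in_channels // groups, out_channels), which matches the default initializer's in_axis=-2 and out_axis=-1 but is not stated on this page, the following sketch shows the resulting shapes and parameter counts. grouped_kernel_shape is a hypothetical helper; gain and bias parameters are not counted.

```python
import math

def grouped_kernel_shape(kernel_size, in_channels, out_channels, groups=1):
    """Assumed kernel shape for a grouped 1D convolution (hypothetical helper)."""
    if in_channels % groups or out_channels % groups:
        raise ValueError("groups must divide both in_channels and out_channels")
    # Each output channel only sees in_channels // groups input channels.
    return (kernel_size, in_channels // groups, out_channels)

for g in (1, 4, 16):  # standard, grouped, depthwise (groups == in_channels)
    shape = grouped_kernel_shape(5, 16, 32, groups=g)
    print(g, shape, math.prod(shape))
# 1 (5, 16, 32) 2560
# 4 (5, 4, 32) 640
# 16 (5, 1, 32) 160
```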

eps

Small constant for numerical stability in weight standardization.

Type:

float

Examples

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Create a 1D convolution with weight standardization
>>> conv = brainstate.nn.ScaledWSConv1d(
...     in_size=(100, 16),
...     out_channels=32,
...     kernel_size=5
... )
>>>
>>> # Apply to a batch of 4 sequences of shape (100, 16)
>>> x = jnp.ones((4, 100, 16))
>>> y = conv(x)
>>> print(y.shape)
(4, 100, 32)
>>>
>>> # With custom epsilon and no gain
>>> conv = brainstate.nn.ScaledWSConv1d(
...     in_size=(50, 8),
...     out_channels=16,
...     kernel_size=3,
...     ws_gain=False,
...     eps=1e-5
... )
>>> y = conv(jnp.ones((2, 50, 8)))
>>> print(y.shape)
(2, 50, 16)

Notes

Weight standardization formula:

Weight standardization reparameterizes the convolutional weights as:

\[\hat{W} = g \cdot \frac{W - \mu_W}{\sigma_W + \epsilon}\]

where \(\mu_W\) and \(\sigma_W\) are the mean and standard deviation of the weights, computed separately for each output channel, \(g\) is a learnable per-channel gain (present only if ws_gain=True), and \(\epsilon\) is a small constant for numerical stability.
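A minimal NumPy sketch of the formula above, assuming (as in standard weight standardization) that \(\mu_W\) and \(\sigma_W\) are taken per output channel over the kernel-position and input-channel axes of a (K, C_in, C_out) kernel. The name standardize_kernel is hypothetical, and brainstate's own implementation may differ in details such as exactly where eps enters; the optional mask mirrors the w_mask description (applied after standardization).

```python
import numpy as np

def standardize_kernel(w, gain=None, eps=1e-4, mask=None):
    """W_hat = g * (W - mu) / (sigma + eps), statistics per output channel.

    w: kernel of shape (K, C_in, C_out); mu and sigma reduce over all axes
    except the last (output-channel) axis. Hypothetical helper.
    """
    axes = tuple(range(w.ndim - 1))
    mu = w.mean(axis=axes, keepdims=True)
    sigma = w.std(axis=axes, keepdims=True)
    w_hat = (w - mu) / (sigma + eps)
    if gain is not None:
        w_hat = gain * w_hat  # learnable g, used when ws_gain=True
    if mask is not None:
        w_hat = w_hat * mask  # element-wise mask on the standardized kernel
    return w_hat

rng = np.random.default_rng(0)
w = 3.0 * rng.normal(size=(5, 16, 32)) + 1.0  # (K, C_in, C_out) kernel
w_hat = standardize_kernel(w)
print(np.allclose(w_hat.mean(axis=(0, 1)), 0.0, atol=1e-6))  # True
print(np.allclose(w_hat.std(axis=(0, 1)), 1.0, atol=1e-3))   # True
```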

When to use:

This technique is particularly effective when combined with Group Normalization as a replacement for Batch Normalization: neither depends on batch statistics, so the combination remains effective with small batches.

References