Regularization

class brainstate.nn.Regularization(fit_hyper=False)

Abstract base class for parameter regularization.

It defines the interface for regularization terms that are added to the training loss. Subclasses must implement the loss, sample_init, and reset_value methods.
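To make the three-method contract concrete, here is a minimal sketch of an L2 penalty. It does not import brainstate: `RegularizationSketch` is a hypothetical stand-in that only mirrors the documented interface (a `fit_hyper` flag plus abstract `loss`, `sample_init`, and `reset_value`), and `L2Sketch`, `weight`, and `seed` are illustrative names. In real code one would subclass brainstate.nn.Regularization instead. The L2 penalty `weight * sum(x**2)` corresponds to a zero-mean Gaussian prior with variance `1 / (2 * weight)`, which is what `sample_init` draws from and why `reset_value` returns 0.

```python
import abc

import numpy as np


class RegularizationSketch(abc.ABC):
    """Hypothetical stand-in mirroring the documented Regularization interface."""

    def __init__(self, fit_hyper: bool = False):
        # Whether hyperparameters should be trainable (stored but unused in this sketch).
        self.fit_hyper = fit_hyper

    @abc.abstractmethod
    def loss(self, value):
        """Return a scalar regularization loss for `value`."""

    @abc.abstractmethod
    def sample_init(self, shape):
        """Sample an initial value of the given shape from the implied prior."""

    @abc.abstractmethod
    def reset_value(self):
        """Return the value a parameter is reset to (e.g. the prior mean)."""


class L2Sketch(RegularizationSketch):
    """Hypothetical L2 penalty weight * sum(value**2), i.e. a zero-mean Gaussian prior."""

    def __init__(self, weight: float = 1e-2, fit_hyper: bool = False, seed: int = 0):
        super().__init__(fit_hyper=fit_hyper)
        self.weight = weight
        self._rng = np.random.default_rng(seed)

    def loss(self, value):
        return self.weight * np.sum(np.square(value))

    def sample_init(self, shape):
        # weight * x**2 == x**2 / (2 * sigma**2)  =>  sigma = sqrt(1 / (2 * weight))
        sigma = np.sqrt(1.0 / (2.0 * self.weight))
        return self._rng.normal(0.0, sigma, size=shape)

    def reset_value(self):
        # The mean of the implied zero-mean Gaussian prior.
        return 0.0


reg = L2Sketch(weight=0.5)
print(float(reg.loss(np.array([1.0, 2.0]))))  # 2.5
print(reg.sample_init((3, 4)).shape)           # (3, 4)
print(reg.reset_value())                       # 0.0
```

Because the base class declares all three methods abstract, a subclass that forgets one of them cannot be instantiated (Python raises TypeError), which is one way to enforce the "must implement" requirement.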

Parameters:

fit_hyper (bool) – Whether the regularization's hyperparameters are optimized as trainable parameters. Default is False.

fit_hyper

Whether hyperparameters are trainable.

Type:

bool

Notes

Regularization can be used with the Param class to add regularization terms to the training loss.
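How Param wires a regularizer into training is not described on this page, so the sketch below only illustrates the underlying idea by hand: the data loss of a linear model plus `reg.loss(w)` on its weights. `L2Penalty`, `mse`, and `total_loss` are hypothetical helpers, not brainstate APIs; `L2Penalty` merely exposes a `loss(value)` method like Regularization does.

```python
import numpy as np


class L2Penalty:
    """Hypothetical stand-in exposing a `loss(value)` method like Regularization."""

    def __init__(self, weight: float):
        self.weight = weight

    def loss(self, value):
        return self.weight * np.sum(np.square(value))


def mse(pred, target):
    # Mean squared error data term.
    return np.mean(np.square(pred - target))


def total_loss(w, x, y, reg):
    # Training loss = data term of a linear model + regularization term on its weights.
    data_term = mse(x @ w, y)
    return data_term + reg.loss(w)


x = np.array([[1.0, 0.0], [0.0, 1.0]])
y = np.array([1.0, 2.0])
w = np.array([1.0, 1.0])
print(float(total_loss(w, x, y, L2Penalty(weight=0.25))))  # 1.0 (data 0.5 + penalty 0.5)
```

Gradients of this total loss with respect to `w` then include the penalty's contribution, which is what pulls the weights toward the prior mean during training.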

loss(value)

Calculate the regularization loss.

Parameters:

value (Array | ndarray | bool | number | int | float | complex | Quantity) – Parameter values for which to compute the regularization loss.

Returns:

Scalar regularization loss.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity
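The key property of loss is that it reduces a parameter of any shape to a single scalar. A minimal sketch, assuming a hypothetical L1 penalty `weight * sum(|value|)` (`L1Sketch` is illustrative and does not subclass the brainstate class):

```python
import numpy as np


class L1Sketch:
    """Hypothetical L1 penalty: loss(value) = weight * sum(|value|)."""

    def __init__(self, weight: float = 1.0):
        self.weight = weight

    def loss(self, value):
        # Reduce over every element so the result is a scalar whatever the input shape.
        return self.weight * np.sum(np.abs(np.asarray(value)))


reg = L1Sketch(weight=0.5)
print(float(reg.loss([[1.0, -2.0], [3.0, -4.0]])))  # 5.0 (0.5 * 10)
print(np.ndim(reg.loss(3.0)))                       # 0: a Python scalar input also yields a scalar loss
```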

reset_value()

Return the value a parameter is reset to (e.g., the prior mean).

Returns:

Value to reset the parameter to.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity
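For a Gaussian-style penalty the prior mean is the point where the loss is minimal, which makes it a natural reset target. The sketch below uses a hypothetical `GaussianPriorSketch` with illustrative `mean` and `std` attributes. The page does not say whether reset_value returns a scalar or an array of the parameter's shape; this sketch assumes a scalar and fills the parameter with it.

```python
import numpy as np


class GaussianPriorSketch:
    """Hypothetical regularizer whose implied prior is N(mean, std**2)."""

    def __init__(self, mean: float = 0.0, std: float = 1.0):
        self.mean = mean
        self.std = std

    def loss(self, value):
        # Negative log-density of the Gaussian prior, up to an additive constant.
        return np.sum(np.square((np.asarray(value) - self.mean) / self.std)) / 2.0

    def reset_value(self):
        # The prior mean minimizes `loss`, so it is a natural reset target.
        return self.mean


reg = GaussianPriorSketch(mean=0.5, std=2.0)
param = np.array([3.0, -1.0, 7.0])
# Fill an array of the parameter's shape with the (scalar) reset value.
param = np.full_like(param, reg.reset_value())
print(param)                   # [0.5 0.5 0.5]
print(float(reg.loss(param)))  # 0.0
```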

sample_init(shape)

Sample an initial value from the prior distribution implied by the regularization.

Parameters:

shape (int | Sequence[int] | integer | Sequence[integer]) – Shape of the parameter to initialize.

Returns:

Sampled initial value.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity
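"Implied prior" can be made precise for an L2 penalty: a loss of `weight * x**2` is the negative log-density of a zero-mean Gaussian with variance `1 / (2 * weight)`, since `exp(-weight * x**2)` is proportional to that density. The sketch below samples initial values from that Gaussian. `L2PriorSketch`, `prior_std`, `weight`, and `seed` are hypothetical names, not part of brainstate.

```python
import numpy as np


class L2PriorSketch:
    """Hypothetical L2 penalty weight * sum(x**2), read as a zero-mean Gaussian prior."""

    def __init__(self, weight: float = 0.5, seed: int = 0):
        self.weight = weight
        self.rng = np.random.default_rng(seed)

    def prior_std(self) -> float:
        # exp(-weight * x**2) is proportional to N(0, sigma**2) with sigma**2 = 1 / (2 * weight).
        return float(np.sqrt(1.0 / (2.0 * self.weight)))

    def sample_init(self, shape):
        # Accepts an int or a sequence of ints, like the documented signature.
        return self.rng.normal(loc=0.0, scale=self.prior_std(), size=shape)


reg = L2PriorSketch(weight=0.5)   # sigma = 1.0
w0 = reg.sample_init((256, 128))
print(w0.shape)                   # (256, 128)
print(float(w0.std()))            # close to 1.0 for this many samples
```

A stronger penalty (larger `weight`) implies a narrower prior, so the sampled initial values shrink accordingly.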