TruncatedNormal

class braintools.init.TruncatedNormal(mean, std, low=None, high=None, unit=None)

Truncated normal distribution initialization.

Generates values from a normal distribution with the given mean and standard deviation, truncated to the interval [low, high]. Requires scipy to be installed.
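For reference, the standard truncated normal density is shown below, writing μ for mean, σ for std, φ and Φ for the standard normal PDF and CDF, and α = (low − μ)/σ, β = (high − μ)/σ. This is the textbook definition of the distribution, not a statement about how braintools evaluates it internally.

```latex
f(x) =
\begin{cases}
  \dfrac{\phi\!\left(\frac{x-\mu}{\sigma}\right)}{\sigma\left[\Phi(\beta) - \Phi(\alpha)\right]}, & \text{low} \le x \le \text{high},\\[1.5ex]
  0, & \text{otherwise.}
\end{cases}
```

With low = −∞ and high = +∞, the denominator becomes σ and f reduces to the ordinary normal density.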

Parameters:
  • mean (Array | ndarray | bool | number | int | float | complex | Quantity) – Mean of the underlying (untruncated) normal distribution.

  • std (Array | ndarray | bool | number | int | float | complex | Quantity) – Standard deviation of the underlying (untruncated) normal distribution.

  • low (Array | ndarray | bool | number | int | float | complex | Quantity | None) – Lower bound of the truncation interval. Default: None, meaning no lower bound (-inf).

  • high (Array | ndarray | bool | number | int | float | complex | Quantity | None) – Upper bound of the truncation interval. Default: None, meaning no upper bound (+inf).
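The parameters above can be mapped onto scipy.stats.truncnorm, which expects its bounds in standardized units rather than in the units of the data. The following is a minimal sketch under stated assumptions: the helper name sample_truncated_normal is hypothetical, it is an assumption (not confirmed by this page) that braintools delegates to truncnorm in this way, and the sketch handles plain floats only, ignoring Quantity units and the unit argument.

```python
import numpy as np
from scipy.stats import truncnorm


def sample_truncated_normal(mean, std, low=None, high=None, size=1, rng=None):
    """Hypothetical helper: draw samples from N(mean, std**2) restricted to [low, high].

    None bounds are treated as -inf / +inf, mirroring the documented defaults.
    Assumption: plain floats only; Quantity units are not handled here.
    """
    low = -np.inf if low is None else low
    high = np.inf if high is None else high
    # truncnorm takes standardized bounds: a = (low - mean) / std, b = (high - mean) / std.
    a = (low - mean) / std
    b = (high - mean) / std
    # loc/scale shift and stretch the standardized distribution back to mean/std.
    return truncnorm.rvs(a, b, loc=mean, scale=std, size=size, random_state=rng)


# Unit-free analogue of the example on this page.
weights = sample_truncated_normal(0.5, 0.2, low=0.0, high=1.0, size=1000,
                                  rng=np.random.default_rng(0))
```

Standardizing the bounds is what lets one sampler serve every mean and std; passing a NumPy Generator as random_state keeps the draws reproducible.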

Examples

>>> import numpy as np
>>> import brainunit as u
>>> from braintools.init import TruncatedNormal
>>>
>>> init = TruncatedNormal(
...     mean=0.5 * u.siemens,
...     std=0.2 * u.siemens,
...     low=0.0 * u.siemens,
...     high=1.0 * u.siemens
... )
>>> rng = np.random.default_rng(0)
>>> weights = init(1000, rng=rng)