truncated_normal
- class brainstate.random.truncated_normal(lower, upper, size=None, loc=0.0, scale=1.0, key=None, dtype=None, check_valid=True)
Sample random values from a truncated normal distribution (standard normal by default) with the given shape and dtype.
Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
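A common way to sample a truncated normal, and the technique this sketch assumes, is inverse-CDF sampling: draw a uniform value between the CDF values of the two bounds and map it back through the normal quantile function. The sketch below illustrates that idea in plain Python, assuming lower and upper are sample-space bounds (consistent with the Returns entry, which places the output in the open interval (lower, upper)). It is not brainstate's implementation. The helper sample_truncated_normal and its n and seed arguments are hypothetical. It uses the standard-library statistics.NormalDist and math.nextafter (Python 3.9+). It also assumes moderate bounds, since far-tail bounds can push the CDF values to exactly 0 or 1, where inv_cdf is undefined.

```python
import math
import random
from statistics import NormalDist


def sample_truncated_normal(lower, upper, n, loc=0.0, scale=1.0, seed=None):
    """Draw n values from N(loc, scale**2) truncated to the open interval (lower, upper).

    Hypothetical helper: a scalar, pure-Python sketch of inverse-CDF sampling,
    not brainstate's implementation. lower and upper are sample-space bounds.
    """
    if scale <= 0:
        raise ValueError("scale must be positive")
    if not lower < upper:
        raise ValueError("lower must be smaller than upper")
    rng = random.Random(seed)
    std = NormalDist()  # standard normal N(0, 1)
    # Map the standardized bounds to their CDF values.
    p_lo = std.cdf((lower - loc) / scale)
    p_hi = std.cdf((upper - loc) / scale)
    # Nearest floats strictly inside the bounds, used to keep the interval open.
    inner_lo = math.nextafter(lower, math.inf)
    inner_hi = math.nextafter(upper, -math.inf)
    samples = []
    for _ in range(n):
        p = rng.uniform(p_lo, p_hi)       # uniform over the kept CDF range
        x = loc + scale * std.inv_cdf(p)  # inverse CDF, then shift and scale
        samples.append(min(max(x, inner_lo), inner_hi))  # guard against round-off
    return samples


xs = sample_truncated_normal(-1.0, 2.0, 1000, loc=0.5, scale=1.5, seed=0)
print(min(xs) > -1.0 and max(xs) < 2.0)  # True
```

The final clamp is what turns "approximately inside the bounds" into a strict open-interval guarantee; without it, floating-point round-off in cdf and inv_cdf could land a draw exactly on a bound.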
Notes
This distribution is the normal distribution centered on loc (default 0), with standard deviation scale (default 1), and clipped at a, b standard deviations to the left and right (respectively) of loc. If myclip_a and myclip_b are clip values in the sample space (as opposed to numbers of standard deviations), they can be converted to the required form according to:
a, b = (myclip_a - loc) / scale, (myclip_b - loc) / scale
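The conversion above is plain arithmetic. The sketch below applies it and then reuses the standardized bounds a, b in the standard closed form for the mean of a truncated normal, loc + scale * (φ(a) - φ(b)) / (Φ(b) - Φ(a)), where φ and Φ are the standard normal PDF and CDF. This shows why, as the loc parameter description notes, the truncated mean is generally not equal to loc. The helper names to_std_bounds and truncated_mean are hypothetical, and the code uses Python's standard-library statistics.NormalDist rather than any brainstate API.

```python
from statistics import NormalDist


def to_std_bounds(myclip_a, myclip_b, loc=0.0, scale=1.0):
    """Convert sample-space clip values into standard deviations from loc."""
    return (myclip_a - loc) / scale, (myclip_b - loc) / scale


def truncated_mean(myclip_a, myclip_b, loc=0.0, scale=1.0):
    """Mean of N(loc, scale**2) truncated to [myclip_a, myclip_b] (hypothetical helper)."""
    a, b = to_std_bounds(myclip_a, myclip_b, loc, scale)
    std = NormalDist()  # standard normal: mu=0, sigma=1
    mass = std.cdf(b) - std.cdf(a)  # probability kept after truncation
    return loc + scale * (std.pdf(a) - std.pdf(b)) / mass


# Clip N(loc=3, scale=2) to the sample-space interval [1, 9].
print(to_std_bounds(1.0, 9.0, loc=3.0, scale=2.0))  # (-1.0, 3.0)
# Symmetric clipping keeps the mean at loc; one-sided clipping shifts it.
print(truncated_mean(-1.0, 1.0))                     # 0.0
print(truncated_mean(0.0, 2.0) > 0.0)                # True
```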
- Parameters:
lower (float, ndarray) – A float or array of floats representing the lower bound for truncation. Must be broadcast-compatible with upper.
upper (float, ndarray) – A float or array of floats representing the upper bound for truncation. Must be broadcast-compatible with lower.
loc (optional, float, ndarray) – A float or array of floats representing the mean (“centre”) of the distribution before truncating. Default is 0. Note that the mean of the truncated distribution will not be exactly equal to loc.
size (int | Sequence[int] | integer | Sequence[integer] | None) – An integer or tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with lower and upper. The default (None) produces a result shape by broadcasting lower and upper.
scale (float, ndarray) – Standard deviation (spread or “width”) of the distribution. Must be non-negative. Default is 1.
dtype (str | type[Any] | dtype | SupportsDType) – The float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
key (int | Array | ndarray | None) – The key for the random number generator. If not given, the default random number generator is used.
check_valid (bool) – Whether to check the validity of the input parameters. Default is True.
- Returns:
out – A random array with the specified dtype and with shape given by size if size is not None, or else by broadcasting lower and upper. Returns values in the open interval (lower, upper).
- Return type: