truncated_normal

class brainstate.random.truncated_normal(lower, upper, size=None, loc=0.0, scale=1.0, key=None, dtype=None, check_valid=True)

Sample random values from a truncated normal distribution with the given shape and dtype.

Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
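
Inverse-CDF sampling is one common way to draw from a truncated normal distribution. The sketch below illustrates the idea in pure Python with the standard library only. It is not brainstate's implementation, and whether brainstate follows exactly these steps is an assumption. The helper name sample_truncated_normal and the eps guard are hypothetical, and the bounds are taken to be in sample space.

```python
import random
from statistics import NormalDist


def sample_truncated_normal(lower, upper, loc=0.0, scale=1.0, rng=random):
    """Draw one value from N(loc, scale**2) restricted to [lower, upper].

    Hypothetical helper for illustration; bounds are in sample space.
    """
    dist = NormalDist(mu=loc, sigma=scale)
    # CDF values of the untruncated normal at the two bounds.
    cdf_lo = dist.cdf(lower)
    cdf_hi = dist.cdf(upper)
    # Uniform draw between the two CDF values.
    u = cdf_lo + (cdf_hi - cdf_lo) * rng.random()
    # inv_cdf requires 0 < p < 1, so keep u strictly inside that range.
    eps = 1e-12
    u = min(max(u, eps), 1.0 - eps)
    # Map the uniform draw back through the inverse CDF.
    x = dist.inv_cdf(u)
    # Guard against rounding pushing the value past a bound.
    return min(max(x, lower), upper)


rng = random.Random(0)  # fixed seed so the run is reproducible
samples = [
    sample_truncated_normal(-1.0, 2.0, loc=0.5, scale=1.5, rng=rng)
    for _ in range(1000)
]
print(min(samples), max(samples))
```

Every value printed lies between -1.0 and 2.0, because the uniform draw is confined to the CDF range spanned by the two bounds.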

Notes

This is the normal distribution centered on loc (default 0) with standard deviation scale (default 1), truncated a standard deviations to the left and b standard deviations to the right of loc. Here a and b are bounds expressed in numbers of standard deviations. If myclip_a and myclip_b are clip values in the sample space (as opposed to numbers of standard deviations), they can be converted to the required form with:

a, b = (myclip_a - loc) / scale, (myclip_b - loc) / scale
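
The conversion above can be written as a small helper and checked with concrete numbers. The function name to_standardized_bounds is hypothetical and only wraps the formula from the note.

```python
def to_standardized_bounds(myclip_a, myclip_b, loc, scale):
    """Convert sample-space clip values to numbers of standard deviations.

    Hypothetical helper; it applies the formula from the note verbatim.
    """
    a = (myclip_a - loc) / scale
    b = (myclip_b - loc) / scale
    return a, b


# Clip values 2.0 and 8.0 around loc=5.0 with scale=2.0
# sit 1.5 standard deviations on either side of loc.
a, b = to_standardized_bounds(2.0, 8.0, loc=5.0, scale=2.0)
print(a, b)  # -1.5 1.5
```
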
Parameters:
  • lower (float, ndarray) – A float or array of floats representing the lower bound for truncation. Must be broadcast-compatible with upper.

  • upper (float, ndarray) – A float or array of floats representing the upper bound for truncation. Must be broadcast-compatible with lower.

  • size (int | Sequence[int] | integer | Sequence[integer] | None) – An int or tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with lower and upper. The default (None) produces a result shape by broadcasting lower and upper.

  • loc (optional, float, ndarray) – A float or array of floats representing the mean (“centre”) of the distribution before truncating. Note that the mean of the truncated distribution generally differs from loc unless the bounds are symmetric about it. Default is 0.

  • scale (float, ndarray) – Standard deviation (spread or “width”) of the distribution. Must be non-negative. Default is 1.

  • dtype (str | type[Any] | dtype | SupportsDType) – The float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • key (int | Array | ndarray | None) – The key for the random number generator. If not given, the default random number generator is used.

  • check_valid (bool) – Whether to check the validity of the input parameters. Default is True.

Returns:

out – A random array with the specified dtype and with the shape given by size if size is not None, or else by broadcasting lower and upper. Values lie in the open interval (lower, upper).

Return type:

Array