dirichlet
- class brainstate.random.dirichlet(alpha, size=None, key=None, dtype=None)
Draw samples from the Dirichlet distribution.
Draw samples of dimension k from a Dirichlet distribution, with the number and arrangement of samples set by size. A Dirichlet-distributed random variable can be seen as a multivariate generalization of a Beta distribution. The Dirichlet distribution is the conjugate prior of the multinomial distribution in Bayesian inference.
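Conjugacy means that a Dirichlet(α) prior combined with observed multinomial category counts n yields a Dirichlet(α + n) posterior. The following is a minimal NumPy sketch of that update. The helper names dirichlet_posterior and dirichlet_mean are hypothetical illustrations, not part of brainstate.

```python
import numpy as np


def dirichlet_posterior(alpha, counts):
    """Hypothetical helper: Dirichlet(alpha) prior + multinomial counts -> Dirichlet(alpha + counts)."""
    alpha = np.asarray(alpha, dtype=float)
    counts = np.asarray(counts, dtype=float)
    if alpha.shape != counts.shape:
        raise ValueError("alpha and counts must have the same length k")
    return alpha + counts


def dirichlet_mean(alpha):
    """Hypothetical helper: the mean of Dirichlet(alpha) is alpha_i / sum(alpha)."""
    alpha = np.asarray(alpha, dtype=float)
    return alpha / alpha.sum()


prior = [1.0, 1.0, 1.0]   # uniform prior over the probability simplex for k = 3
observed = [3, 0, 2]      # category counts from a multinomial experiment
posterior = dirichlet_posterior(prior, observed)   # concentration (4, 1, 3)
posterior_mean = dirichlet_mean(posterior)         # (0.5, 0.125, 0.375)
print(posterior, posterior_mean)
```

Note that a category never observed (count 0) still keeps a positive posterior concentration, inherited from the prior.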
- Parameters:
alpha (sequence of floats, length k) – Concentration parameters of the distribution. A length-k alpha yields samples of length k.
size (int | Sequence[int] | integer | Sequence[integer] | None) – Output shape. If the given shape is, e.g., (m, n), then m * n samples, each a vector of length k, are drawn (m * n * k values in total). Default is None, in which case a single vector of length k is returned.
key (int | Array | ndarray | None) – The key for the random number generator. If not given, the default random number generator is used.
- Returns:
samples – The drawn samples, of shape (size, k).
- Return type:
ndarray
- Raises:
ValueError – If any value in alpha is less than or equal to zero.
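The size and shape conventions above can be checked with a short NumPy sketch. It uses numpy.random.Generator.dirichlet as a stand-in and assumes, as the parameter descriptions suggest, that brainstate.random.dirichlet follows the same (size, k) output convention. It does not call brainstate itself.

```python
import numpy as np

rng = np.random.default_rng(0)
alpha = np.array([10.0, 5.0, 3.0])   # k = 3

single = rng.dirichlet(alpha)               # size=None -> one vector of shape (k,)
batch = rng.dirichlet(alpha, size=(4, 2))   # size=(m, n) -> shape (m, n, k)

print(single.shape)   # (3,)
print(batch.shape)    # (4, 2, 3)

# Every length-k vector lies on the simplex: positive entries that sum to 1.
print(np.allclose(batch.sum(axis=-1), 1.0))
```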
Notes
The Dirichlet distribution is a distribution over vectors \(x\) that fulfil the conditions \(x_i>0\) and \(\sum_{i=1}^k x_i = 1\).
The probability density function \(p\) of a Dirichlet-distributed random vector \(X\) is proportional to
\[p(x) \propto \prod_{i=1}^{k} x_i^{\alpha_i - 1},\]
where \(\alpha\) is a vector containing the positive concentration parameters.
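The constant omitted from the proportionality is \(1/B(\alpha)\), where \(B(\alpha) = \prod_{i=1}^{k}\Gamma(\alpha_i) / \Gamma(\sum_{i=1}^{k}\alpha_i)\) is the multivariate Beta function. The following is a minimal pure-Python sketch of the normalized log-density. The function name dirichlet_logpdf is hypothetical and is not a brainstate API.

```python
import math


def dirichlet_logpdf(x, alpha):
    """Hypothetical helper: log-density of Dirichlet(alpha) at x on the open simplex."""
    if len(x) != len(alpha):
        raise ValueError("x and alpha must have the same length k")
    if any(a <= 0 for a in alpha):
        raise ValueError("all concentration parameters must be positive")
    if any(xi <= 0 for xi in x) or not math.isclose(sum(x), 1.0):
        raise ValueError("x must have positive entries that sum to 1")
    # log(1 / B(alpha)) = log Gamma(sum alpha_i) - sum log Gamma(alpha_i)
    log_norm = math.lgamma(sum(alpha)) - sum(math.lgamma(a) for a in alpha)
    # log of prod x_i^(alpha_i - 1), the unnormalized kernel from the formula above
    log_kernel = sum((a - 1.0) * math.log(xi) for xi, a in zip(x, alpha))
    return log_norm + log_kernel


# alpha = (1, 1, 1) is uniform on the simplex, with constant density Gamma(3) = 2.
print(math.exp(dirichlet_logpdf([0.2, 0.3, 0.5], [1, 1, 1])))
# k = 2 reduces to a Beta distribution: Beta(2, 2) at 0.5 has density 6 * 0.5 * 0.5 = 1.5.
print(math.exp(dirichlet_logpdf([0.5, 0.5], [2, 2])))
```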
The method uses the following property for computation. Let \(Y\) be a random vector whose components \(Y_i\) are independent and follow standard gamma distributions with shape parameters \(\alpha_i\). Then \(X = \frac{1}{\sum_{i=1}^k Y_i} Y\) is Dirichlet-distributed with parameter \(\alpha\).
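The gamma-normalization property can be turned into a sampler in a few lines. The following is a minimal NumPy sketch of that construction, not brainstate's actual implementation. The function name dirichlet_via_gamma is hypothetical, and NumPy's Generator.standard_gamma supplies the \(\mathrm{Gamma}(\alpha_i, 1)\) draws.

```python
import numpy as np


def dirichlet_via_gamma(rng, alpha, size=None):
    """Hypothetical helper: sample Dirichlet(alpha) by normalizing independent Gamma(alpha_i, 1) draws."""
    alpha = np.asarray(alpha, dtype=float)
    if np.any(alpha <= 0):
        raise ValueError("all values in alpha must be positive")
    if size is None:
        batch = ()
    elif isinstance(size, int):
        batch = (size,)
    else:
        batch = tuple(size)
    # Y has shape batch + (k,); alpha broadcasts along the last axis, so Y_i ~ Gamma(alpha_i, 1).
    y = rng.standard_gamma(alpha, size=batch + alpha.shape)
    # X = Y / sum_i Y_i lies on the probability simplex.
    return y / y.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(42)
x = dirichlet_via_gamma(rng, [10.0, 5.0, 3.0], size=100_000)
print(x.shape)          # (100000, 3)
print(x.mean(axis=0))   # close to alpha / sum(alpha) = (10/18, 5/18, 3/18)
```

Normalizing along the last axis is what keeps the batch dimensions from size independent of the k components.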
Examples
Following an example from Wikipedia, this distribution can be used to cut strings (each of initial length 1.0) into k pieces of different lengths, where each piece has a designated average length but the relative sizes of the pieces are allowed to vary from string to string.
>>> import brainstate
>>> s = brainstate.random.dirichlet((10, 5, 3), 20).transpose()
>>> import matplotlib.pyplot as plt  # noqa
>>> plt.barh(range(20), s[0])
>>> plt.barh(range(20), s[1], left=s[0], color='g')
>>> plt.barh(range(20), s[2], left=s[0]+s[1], color='r')
>>> plt.title("Lengths of Strings")
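For the string-cutting example, the designated average length of piece i is \(\alpha_i / \sum_j \alpha_j\), which is 10/18, 5/18 and 3/18 for alpha = (10, 5, 3). The following headless sketch compares those targets with the observed averages without plotting. It uses NumPy's Generator.dirichlet as a stand-in for brainstate.random.dirichlet.

```python
import numpy as np

rng = np.random.default_rng(0)
alpha = np.array([10.0, 5.0, 3.0])

# 20 strings of length 1.0, each cut into k = 3 pieces (one row per string).
pieces = rng.dirichlet(alpha, size=20)

# Designated average length of each piece: alpha_i / sum(alpha).
designated = alpha / alpha.sum()
# Sample average of each piece over the 20 strings.
observed = pieces.mean(axis=0)

for i, (d, o) in enumerate(zip(designated, observed)):
    print(f"piece {i}: designated {d:.3f}, observed {o:.3f}")

# Cutting never changes a string's total length of 1.0.
print(np.allclose(pieces.sum(axis=1), 1.0))
```

Scaling all of alpha by a common factor keeps the designated averages the same but reduces the variation between strings.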