categorical
- class brainstate.random.categorical(logits, axis=-1, size=None, key=None)
Sample random values from categorical distributions.
- Parameters:
logits – Unnormalized log probabilities of the categorical distribution(s) to sample from, so that softmax(logits, axis) gives the corresponding probabilities.
axis (int) – Axis along which logits belong to the same categorical distribution.
size – Optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with np.delete(logits.shape, axis). The default (None) produces a result shape equal to np.delete(logits.shape, axis).
key (int | Array | ndarray | None) – A PRNG key used as the random key.
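The output-shape rule for `axis` and `size` can be made concrete with a small NumPy sketch. `categorical_output_shape` is a hypothetical helper, not part of brainstate. It assumes, as in `jax.random.categorical`, that "broadcast-compatible" means the batch shape must broadcast *to* `size`, so `size` cannot be smaller than the batch shape.

```python
import numpy as np


def categorical_output_shape(logits_shape, axis=-1, size=None):
    """Hypothetical helper: result shape of a categorical draw.

    Batch shape = logits shape with the category axis removed. If `size`
    is given, the batch shape must broadcast to `size` (an assumption
    about what "broadcast-compatible" means here), and `size` is returned.
    """
    # np.delete accepts negative indices, so axis=-1 drops the last dim.
    batch_shape = tuple(int(d) for d in np.delete(logits_shape, axis))
    if size is None:
        return batch_shape
    size = tuple(size)
    # np.broadcast_shapes raises ValueError for incompatible shapes.
    broadcast = np.broadcast_shapes(size, batch_shape)
    if broadcast != size:
        # Assumed rule: the batch shape may not enlarge the requested size.
        raise ValueError(
            f"size {size} is not compatible with batch shape {batch_shape}"
        )
    return size
```

For example, logits of shape (3, 4, 5) with the default axis=-1 give samples of shape (3, 4), while axis=1 gives (3, 5).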
- Returns:
A random array with int dtype and shape given by size if size is not None, or else np.delete(logits.shape, axis).
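The statement that softmax(logits, axis) gives the sampling probabilities can be checked empirically. The sketch below uses the Gumbel-max trick in plain NumPy: adding independent standard Gumbel noise to the logits and taking the argmax along the category axis yields a draw from softmax(logits). It is an illustrative stand-in, not brainstate's implementation. `sample_categorical` and `softmax` are hypothetical helpers, a NumPy `Generator` replaces the PRNG `key`, and no validation of `size` against the batch shape is done here.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    z = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)


def sample_categorical(rng, logits, axis=-1, size=None):
    """Hypothetical NumPy sketch of categorical sampling (Gumbel-max trick)."""
    logits = np.asarray(logits, dtype=float)
    # Move the category axis last so the remaining dims form the batch shape.
    moved = np.moveaxis(logits, axis, -1)
    batch_shape, n = moved.shape[:-1], moved.shape[-1]
    size = batch_shape if size is None else tuple(size)
    # One standard Gumbel(0, 1) draw per (sample, category); the logits
    # broadcast against the leading `size` dims.
    gumbel = rng.gumbel(size=size + (n,))
    # argmax(logits + Gumbel noise) is distributed as softmax(logits).
    return np.argmax(moved + gumbel, axis=-1)


rng = np.random.default_rng(0)
probs = np.array([0.2, 0.5, 0.3])
samples = sample_categorical(rng, np.log(probs), size=(20000,))
freqs = np.bincount(samples, minlength=3) / samples.size
print(freqs)  # roughly [0.2, 0.5, 0.3]
```

Because the logits here are log-probabilities that already sum to one, softmax(logits) recovers `probs` exactly, so the empirical frequencies should approach them as the sample count grows.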