one_hot

brainstate.nn.one_hot(x, num_classes, *, dtype=jnp.float_, axis=-1)

One-hot encode the given indices.

Each index in the input x is encoded as a vector of length num_classes that is zero everywhere except for a one at the position given by the index.

Indices outside the range [0, num_classes) will be encoded as zeros.
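The encoding rule above can be sketched in NumPy as a broadcasted comparison of each index against the class labels 0, 1, ..., num_classes - 1. This is a minimal sketch, not brainstate's implementation: the helper name `one_hot_reference` is hypothetical, it defaults to float32 to match the example output on this page, and it assumes the new one-hot dimension is placed at `axis` of the output, as `jax.nn.one_hot` does.

```python
import numpy as np


def one_hot_reference(x, num_classes, dtype=np.float32, axis=-1):
    """Hypothetical NumPy sketch of the documented one-hot semantics.

    Not brainstate's implementation; assumes the one-hot dimension is
    placed at `axis` of the output, as jax.nn.one_hot does.
    """
    x = np.asarray(x)
    classes = np.arange(num_classes)  # 0, 1, ..., num_classes - 1
    # Compare every index with every class label. An index outside
    # [0, num_classes), such as -1 or num_classes, matches no label,
    # so its encoding is all zeros.
    encoded = (x[..., None] == classes).astype(dtype)
    # The comparison puts the one-hot dimension last; move it to `axis`.
    return np.moveaxis(encoded, -1, axis)


print(one_hot_reference([0, 1, 2], 3))  # 3x3 identity matrix
print(one_hot_reference([-1, 3], 3))    # two all-zero rows
```

Because out-of-range indices simply fail every comparison, no explicit bounds check is needed to produce the all-zero rows.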

Parameters:
  • x (Any) – An array of integer indices.

  • num_classes (int) – Number of classes in the one-hot dimension.

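  • dtype (Any) – The dtype of the returned values. Default is jnp.float_ (64-bit float); when JAX's 64-bit mode is disabled, which is JAX's default, the result is float32, as in the examples below.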

  • axis (int | Sequence[int]) – The axis (or axes) of the output at which the new one-hot dimension of size num_classes is placed. Default is -1, i.e. the last axis.
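For a 1-D input, axis only decides where the size-num_classes dimension lands in the output. The NumPy sketch below illustrates this for axis=-1 and axis=0, assuming brainstate follows the `jax.nn.one_hot` convention of inserting the new dimension at position axis of the output; it builds both layouts directly with broadcasting rather than calling brainstate.

```python
import numpy as np

x = np.array([0, 2, 1, 0])  # four indices
classes = np.arange(3)      # num_classes = 3

# axis=-1 (default): one-hot dimension last, shape (4, 3);
# row i encodes x[i].
last = (x[:, None] == classes[None, :]).astype(np.float32)

# axis=0: one-hot dimension first, shape (3, 4);
# column i encodes x[i], i.e. the transpose of the axis=-1 result.
first = (x[None, :] == classes[:, None]).astype(np.float32)

print(last.shape)   # (4, 3)
print(first.shape)  # (3, 4)
```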

Returns:

One-hot encoded array with one more dimension than x; the new dimension has size num_classes and is located at axis.

Return type:

Array | Quantity

Examples

>>> import jax.numpy as jnp
>>> import brainstate
>>> brainstate.nn.one_hot(jnp.array([0, 1, 2]), 3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

>>> # Indices outside the range are encoded as zeros
>>> brainstate.nn.one_hot(jnp.array([-1, 3]), 3)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)