one_hot
- class brainstate.nn.one_hot(x, num_classes, *, dtype=<class 'jax.numpy.float64'>, axis=-1)
One-hot encode the given indices.
Each index in the input x is encoded as a vector of zeros of length num_classes, with the element at that index set to one. Indices outside the range [0, num_classes) are encoded as all-zero vectors.
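To make the encoding rule concrete, here is a minimal sketch of one way to compute it in JAX: compare every index against `jnp.arange(num_classes)` and let broadcasting build the extra class dimension. `one_hot_sketch` is a hypothetical helper written for illustration, not brainstate's actual implementation, and it only places the class dimension last (the default `axis=-1`).

```python
import jax.numpy as jnp


def one_hot_sketch(x, num_classes, dtype=jnp.float32):
    """Hypothetical illustration of one-hot encoding; not brainstate's source."""
    x = jnp.asarray(x)
    classes = jnp.arange(num_classes)  # [0, 1, ..., num_classes - 1]
    # x[..., None] has shape x.shape + (1,); comparing it with `classes`
    # broadcasts to x.shape + (num_classes,). An index outside
    # [0, num_classes) matches no class, so its row is all False (zeros).
    return (x[..., None] == classes).astype(dtype)


# Index 3 is out of range for 3 classes, so its row is all zeros:
# rows are [1, 0, 0], [0, 0, 1], [0, 0, 0].
print(one_hot_sketch(jnp.array([0, 2, 3]), 3))
```

Because the encoding is just an equality test, out-of-range and negative indices need no special casing: they simply never compare equal to any class.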
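- Parameters:
  - x: Array of integer indices to encode.
  - num_classes: Number of classes, i.e. the length of the one-hot dimension.
  - dtype: Data type of the returned array. The float64 default is truncated to float32 unless JAX 64-bit mode is enabled, which is why the examples below show float32.
  - axis: Position of the new one-hot dimension in the output; defaults to the last axis.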
- Returns:
One-hot encoded array.
- Return type:
Array | Quantity
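The effect of `axis` and `dtype` is easiest to see on a 2-D input. The sketch below calls `jax.nn.one_hot`, which accepts the same `num_classes`, `dtype` and `axis` arguments; it assumes that `brainstate.nn.one_hot` places the class dimension the same way, which this page does not state explicitly.

```python
import jax
import jax.numpy as jnp

labels = jnp.array([[0, 1, 2],
                    [3, 0, 1]])  # shape (2, 3)

# Default axis=-1: the class dimension is appended -> shape (2, 3, 4).
last = jax.nn.one_hot(labels, 4)
# axis=1: the class dimension is inserted at position 1 -> shape (2, 4, 3).
middle = jax.nn.one_hot(labels, 4, axis=1)
# dtype controls the element type, e.g. integer masks instead of floats.
ints = jax.nn.one_hot(labels, 4, dtype=jnp.int32)

print(last.shape, middle.shape, ints.dtype)
# Same values in both layouts; only the position of the class axis differs.
print(jnp.array_equal(middle, jnp.moveaxis(last, -1, 1)))
```

Choosing `axis` up front avoids a separate transpose when a downstream layer expects the class dimension somewhere other than last.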
Examples
>>> import jax.numpy as jnp
>>> import brainstate
>>> brainstate.nn.one_hot(jnp.array([0, 1, 2]), 3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)
>>> # Indices outside the range are encoded as zeros
>>> brainstate.nn.one_hot(jnp.array([-1, 3]), 3)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)