Embedding
- class brainstate.nn.Embedding(num_embeddings, embedding_size, embedding_init=LecunUniform(scale=1.0, mode='fan_in', in_axis=-2, out_axis=-1, distribution='uniform', rng=RandomState([ 900 9244]), unit=Unit("1")), padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, freeze=False, name=None, param_type=<class 'brainstate.ParamState'>)
A simple lookup table that stores embeddings of a fixed size.
This module is commonly used to store word embeddings and retrieve them using indices. The input to the module is a list of indices, and the output is the corresponding word embeddings.
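The lookup described above amounts to selecting rows of a table by integer index. The following is a minimal NumPy sketch of that idea, not the library's implementation; the `table` values and the `lookup` helper are hypothetical and exist only for illustration.

```python
import numpy as np

# Hypothetical 10 x 3 table: row i holds the embedding vector for index i.
table = np.arange(30, dtype=np.float32).reshape(10, 3)

def lookup(table, indices):
    """Return the rows of `table` selected by integer `indices` (any index shape)."""
    return table[np.asarray(indices)]

single = lookup(table, [1, 3, 5])                     # one sequence of 3 indices
batch = lookup(table, [[1, 2, 3, 4], [5, 6, 7, 8]])   # 2 sequences of 4 indices

print(single.shape)  # (3, 3)
print(batch.shape)   # (2, 4, 3)
```

The output shape is the index shape followed by the embedding shape, which matches the shapes reported in the examples on this page.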
- Parameters:
  - num_embeddings (int) – Size of embedding dictionary. Must be non-negative.
  - embedding_size (int | Sequence[int] | integer | Sequence[integer]) – Size of each embedding vector. Can be an int or a sequence of ints, and must contain only non-negative values.
  - embedding_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity) – The initializer for the embedding lookup table, of shape (num_embeddings, embedding_size). Default is LecunUniform().
  - padding_idx (int | None) – If specified, the entries at padding_idx do not contribute to the gradient; therefore, the embedding vector at padding_idx is not updated during training, i.e., it remains as a fixed “pad”. For a newly constructed Embedding, the embedding vector at padding_idx will default to all zeros. Default is None.
  - max_norm (float | None) – If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm. Default is None.
  - norm_type (float) – The p of the p-norm to compute for the max_norm option. Default is 2.0.
  - scale_grad_by_freq (bool) – If True, scales gradients by the inverse frequency of the words in the mini-batch. Default is False.
  - param_type (type) – The parameter state type to use. Default is ParamState.
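To make the scale_grad_by_freq description concrete, here is a hedged NumPy sketch of how a gradient with respect to the table could be accumulated and then divided by per-index counts. It assumes that "inverse frequency" means dividing each row's gradient by the number of times its index occurs in the mini-batch; the `table_grad` helper is hypothetical and is not taken from brainstate.

```python
import numpy as np

def table_grad(indices, upstream, num_embeddings, scale_by_freq=False):
    """Accumulate per-token upstream gradients into a table-shaped gradient."""
    indices = np.asarray(indices).reshape(-1)
    upstream = np.asarray(upstream, dtype=np.float64).reshape(indices.size, -1)
    grad = np.zeros((num_embeddings, upstream.shape[1]))
    # Scatter-add: repeated indices accumulate into the same row.
    np.add.at(grad, indices, upstream)
    if scale_by_freq:
        # Assumption: divide each row by how often its index occurred.
        counts = np.bincount(indices, minlength=num_embeddings)
        grad /= np.maximum(counts, 1)[:, None]
    return grad

indices = [2, 2, 2, 5]        # index 2 occurs three times, index 5 once
upstream = np.ones((4, 3))    # hypothetical gradient of the loss w.r.t. each output row

plain = table_grad(indices, upstream, num_embeddings=10)
scaled = table_grad(indices, upstream, num_embeddings=10, scale_by_freq=True)
print(plain[2], scaled[2])
```

Without scaling, the row for the frequent index receives three times the gradient of the rare one; with scaling, both rows receive gradients of the same magnitude.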
- weight
The learnable weights of the module, of shape (num_embeddings, *embedding_size).
- Type:
Examples
Create an embedding layer with 10 words and 3-dimensional embeddings:
>>> import brainstate
>>> embedding = brainstate.nn.Embedding(num_embeddings=10, embedding_size=3)
>>> embedding.weight.value.shape
(10, 3)
Retrieve embeddings for specific indices:
>>> import jax.numpy as jnp
>>> indices = jnp.array([1, 3, 5])
>>> output = embedding(indices)
>>> output.shape
(3, 3)
Use with a batch of sequences:
>>> # Batch of 2 sequences, each with 4 tokens
>>> batch_indices = jnp.array([[1, 2, 3, 4],
...                            [5, 6, 7, 8]])
>>> output = embedding(batch_indices)
>>> output.shape
(2, 4, 3)
Use padding_idx to keep padding embeddings fixed:
>>> embedding = brainstate.nn.Embedding(num_embeddings=10, embedding_size=3, padding_idx=0)
>>> # The embedding at index 0 will remain zeros and not be updated during training
>>> indices = jnp.array([0, 2, 0, 5])
>>> output = embedding(indices)
>>> output[0]  # Will be zeros
Array([0., 0., 0.], dtype=float32)
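Why the padding row stays fixed can be sketched by looking at the gradient with respect to the table: the row at padding_idx receives no gradient, so an optimizer step leaves it unchanged. The NumPy code below is an illustrative sketch under that assumption, with a hypothetical helper name; it does not show how brainstate implements the masking.

```python
import numpy as np

def table_grad_with_padding(indices, upstream, num_embeddings, padding_idx=None):
    """Table-shaped gradient in which the padding row gets no contribution."""
    indices = np.asarray(indices).reshape(-1)
    upstream = np.asarray(upstream, dtype=np.float64).reshape(indices.size, -1)
    grad = np.zeros((num_embeddings, upstream.shape[1]))
    np.add.at(grad, indices, upstream)  # accumulate gradients per looked-up row
    if padding_idx is not None:
        grad[padding_idx] = 0.0         # padding entries do not contribute
    return grad

# Same indices as the example above; index 0 is the padding index.
grad = table_grad_with_padding([0, 2, 0, 5], np.ones((4, 3)), 10, padding_idx=0)
print(grad[0])  # [0. 0. 0.]
print(grad[2])  # [1. 1. 1.]
```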
Use max_norm to constrain embedding norms:
>>> embedding = brainstate.nn.Embedding(num_embeddings=10, embedding_size=3, max_norm=1.0)
>>> # All embeddings accessed in a forward pass are renormalized to have norm <= 1.0
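The renormalization rule for max_norm and norm_type can be illustrated with a small NumPy sketch: rows whose p-norm exceeds max_norm are rescaled to have exactly that norm, and all other rows are left alone. The `renorm_rows` helper is hypothetical and only mirrors the parameter description; the library's actual code may differ.

```python
import numpy as np

def renorm_rows(rows, max_norm, norm_type=2.0):
    """Rescale each row whose p-norm exceeds `max_norm` down to norm `max_norm`."""
    rows = np.asarray(rows, dtype=np.float64)
    norms = np.linalg.norm(rows, ord=norm_type, axis=-1, keepdims=True)
    # Rows within the limit keep a scale of 1; the floor avoids division by zero.
    scale = np.where(norms > max_norm, max_norm / np.maximum(norms, 1e-12), 1.0)
    return rows * scale

rows = np.array([[3.0, 4.0, 0.0],    # L2 norm 5.0, above the limit
                 [0.3, 0.4, 0.0]])   # L2 norm 0.5, within the limit
out = renorm_rows(rows, max_norm=1.0)
print(np.linalg.norm(out, axis=-1))
```

The first row is scaled to [0.6, 0.8, 0.0] and the second row is returned unchanged.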
Load pretrained embeddings:
>>> import brainstate
>>> import jax.numpy as jnp
>>> pretrained = jnp.array([[1.0, 2.0, 3.0],
...                         [4.0, 5.0, 6.0],
...                         [7.0, 8.0, 9.0]])
>>> embedding = brainstate.nn.Embedding.from_pretrained(pretrained, param_type=brainstate.FakeState)
>>> embedding.weight.value.shape
(3, 3)
- classmethod from_pretrained(embeddings, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, freeze=True, name=None, param_type=<class 'brainstate.ParamState'>)
Create an Embedding instance from a given 2-dimensional array.
- Parameters:
  - embeddings (Array | ndarray | bool | number | int | float | complex | Quantity) – Array containing weights for the Embedding. First dimension is passed to Embedding as num_embeddings, remaining dimensions as embedding_size.
  - padding_idx (int | None) – If specified, the entries at padding_idx do not contribute to the gradient. Default is None.
  - max_norm (float | None) – See module initialization documentation. Default is None.
  - norm_type (float) – See module initialization documentation. Default is 2.0.
  - scale_grad_by_freq (bool) – See module initialization documentation. Default is False.
  - freeze (bool) – If True, embeddings are frozen (no gradients). Default is True.
- Returns:
An Embedding module with pretrained weights.
- Return type:
Examples
Load pretrained word embeddings:
>>> import jax.numpy as jnp
>>> import brainstate
>>> pretrained = jnp.array([[1.0, 2.0, 3.0],
...                         [4.0, 5.0, 6.0],
...                         [7.0, 8.0, 9.0]])
>>> embedding = brainstate.nn.Embedding.from_pretrained(pretrained)
>>> embedding.weight.value.shape
(3, 3)
>>> indices = jnp.array([1])
>>> embedding(indices)
Array([[4., 5., 6.]], dtype=float32)