Embedding

class brainstate.nn.Embedding(num_embeddings, embedding_size, embedding_init=LecunUniform(scale=1.0, mode='fan_in', in_axis=-2, out_axis=-1, distribution='uniform'), padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, freeze=False, name=None, param_type=brainstate.ParamState)

A simple lookup table that stores embeddings of a fixed size.

This module is commonly used to store word embeddings and retrieve them by index. The input to the module is an array of integer indices, and the output contains the corresponding embedding vectors.
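
Conceptually, the lookup is plain integer indexing into the weight table. The NumPy sketch below is independent of brainstate and uses a made-up table; it shows that selecting rows by index gives the same result as multiplying a one-hot encoding of the indices by the table, which is why an embedding layer is often described as a linear layer over one-hot inputs. It illustrates the idea only and is not the module's implementation.

```python
import numpy as np

# Hypothetical lookup table: 5 entries, each a 3-dimensional vector.
num_embeddings, dim = 5, 3
table = np.arange(num_embeddings * dim, dtype=np.float32).reshape(num_embeddings, dim)

indices = np.array([1, 3, 3, 0])

# Direct lookup: integer indexing (a "gather") picks rows of the table.
gathered = table[indices]                                    # shape (4, 3)

# Equivalent formulation: one-hot encode the indices, then multiply by the table.
one_hot = np.eye(num_embeddings, dtype=np.float32)[indices]  # shape (4, 5)
via_matmul = one_hot @ table                                 # shape (4, 3)

assert np.array_equal(gathered, via_matmul)
```

The gather form is what makes embeddings cheap: it never materializes the one-hot matrix.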

Parameters:
  • num_embeddings (int) – Size of embedding dictionary. Must be non-negative.

  • embedding_size (int | Sequence[int] | integer | Sequence[integer]) – Size of each embedding vector. Can be an int or a sequence of ints, and must contain only non-negative values.

  • embedding_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity) – The initializer for the embedding lookup table, of shape (num_embeddings, *embedding_size). Default is LecunUniform().

  • padding_idx (int | None) – If specified, the entries at padding_idx do not contribute to the gradient; therefore, the embedding vector at padding_idx is not updated during training, i.e., it remains as a fixed “pad”. For a newly constructed Embedding, the embedding vector at padding_idx will default to all zeros. Default is None.

  • max_norm (float | None) – If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm. Default is None.

  • norm_type (float) – The p of the p-norm to compute for the max_norm option. Default is 2.0.

  • scale_grad_by_freq (bool) – If True, gradients are scaled by the inverse frequency of each index in the mini-batch. Default is False.

  • freeze (bool) – If True, the embedding weights are frozen (no gradients). Default is False.

  • name (str | None) – The name of the module. Default is None.

  • param_type (type) – The parameter state type to use. Default is ParamState.
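
To make the padding_idx behaviour concrete: the gradient of an embedding lookup with respect to the table is a scatter-add of the upstream gradient into the rows that were read. The NumPy sketch below uses a hypothetical helper, lookup_grad (not part of brainstate), and assumes 1-D embedding vectors. It computes that gradient by hand and then zeroes the padding row, which is one way to obtain the described effect; brainstate's own mechanism may differ.

```python
import numpy as np

def lookup_grad(grad_out, indices, num_embeddings, padding_idx=None):
    """Hypothetical helper: gradient of `table[indices]` with respect to `table`.

    Each row of `grad_out` is scatter-added into the table row it was read from;
    rows that were never looked up receive zero gradient.
    """
    grad_table = np.zeros((num_embeddings, grad_out.shape[-1]), dtype=grad_out.dtype)
    np.add.at(grad_table, indices, grad_out)  # unbuffered: repeated indices accumulate
    if padding_idx is not None:
        # Assumed padding behaviour: the padding row never receives a gradient.
        grad_table[padding_idx] = 0.0
    return grad_table

indices = np.array([0, 2, 0, 5])
grad_out = np.ones((4, 3), dtype=np.float32)  # upstream gradient, one row per index
grad = lookup_grad(grad_out, indices, num_embeddings=10, padding_idx=0)
```

Without padding_idx, row 0 would accumulate a gradient of 2 in every component because index 0 appears twice; with padding_idx=0 it stays at zero, so an optimizer step leaves the pad vector unchanged.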

num_embeddings

Size of the embedding dictionary.

Type:

int

embedding_size

Size of each embedding vector.

Type:

tuple[int, …]

out_size

Output size, same as embedding_size.

Type:

tuple[int, …]

weight

The learnable embedding table, of shape (num_embeddings, *embedding_size).

Type:

ParamState

padding_idx

Index of the padding token.

Type:

int or None

max_norm

Maximum norm for embedding vectors.

Type:

float or None

norm_type

The p of the p-norm used for the max_norm option.

Type:

float

scale_grad_by_freq

Whether gradients are scaled by the inverse frequency of each index in the mini-batch.

Type:

bool
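
The arithmetic behind scale_grad_by_freq can be sketched in NumPy. In this sketch, lookup_grad_scaled is a hypothetical helper, not brainstate code, and 1-D embedding vectors are assumed. It scatter-adds the upstream gradient into the table rows that were read and then divides each row by the number of times its index occurs in the mini-batch, so frequent indices do not dominate the update.

```python
import numpy as np

def lookup_grad_scaled(grad_out, indices, num_embeddings):
    """Hypothetical helper: scatter-add gradient of `table[indices]`,
    scaled by the inverse frequency of each index in the mini-batch."""
    grad_table = np.zeros((num_embeddings, grad_out.shape[-1]), dtype=grad_out.dtype)
    np.add.at(grad_table, indices, grad_out)
    # Count how often each index occurs in this mini-batch.
    counts = np.bincount(indices.ravel(), minlength=num_embeddings)
    # Divide each row by its count; rows that never occur (count 0) stay zero.
    scale = np.where(counts > 0, 1.0 / np.maximum(counts, 1), 0.0)
    return grad_table * scale[:, None].astype(grad_out.dtype)

indices = np.array([1, 1, 1, 4])
grad_out = np.ones((4, 2), dtype=np.float32)
grad = lookup_grad_scaled(grad_out, indices, num_embeddings=6)
```

Index 1 appears three times, so its accumulated gradient of 3 is divided by 3; index 4 appears once and is left as is.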

freeze

Whether the embedding weights are frozen.

Type:

bool

Examples

Create an embedding layer with 10 words and 3-dimensional embeddings:

>>> import brainstate
>>> embedding = brainstate.nn.Embedding(num_embeddings=10, embedding_size=3)
>>> embedding.weight.value.shape
(10, 3)

Retrieve embeddings for specific indices:

>>> import jax.numpy as jnp
>>> indices = jnp.array([1, 3, 5])
>>> output = embedding(indices)
>>> output.shape
(3, 3)

Use with a batch of sequences:

>>> # Batch of 2 sequences, each with 4 tokens
>>> batch_indices = jnp.array([[1, 2, 3, 4],
...                            [5, 6, 7, 8]])
>>> output = embedding(batch_indices)
>>> output.shape
(2, 4, 3)

Use padding_idx to keep padding embeddings fixed:

>>> embedding = brainstate.nn.Embedding(num_embeddings=10, embedding_size=3, padding_idx=0)
>>> # The embedding at index 0 will remain zeros and not be updated during training
>>> indices = jnp.array([0, 2, 0, 5])
>>> output = embedding(indices)
>>> output[0]  # Will be zeros
Array([0., 0., 0.], dtype=float32)

Use max_norm to constrain embedding norms:

>>> embedding = brainstate.nn.Embedding(num_embeddings=10, embedding_size=3, max_norm=1.0)
>>> # All embeddings accessed in a forward pass are renormalized to have norm <= 1.0
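
The renormalization controlled by max_norm and norm_type can be written out explicitly. The NumPy sketch below uses a hypothetical helper, renorm_rows, not brainstate's code. It rescales only the rows whose p-norm exceeds max_norm, leaving the others unchanged, as it would apply to the vectors retrieved in a forward pass. Whether brainstate also writes the rescaled rows back into the stored table is not stated here.

```python
import numpy as np

def renorm_rows(vectors, max_norm, norm_type=2.0):
    """Hypothetical helper: rescale every row whose p-norm exceeds `max_norm`
    so that its norm equals `max_norm`; other rows are left unchanged."""
    norms = np.linalg.norm(vectors, ord=norm_type, axis=-1, keepdims=True)
    # Factor is 1 for rows within the limit, max_norm / norm for the rest.
    scale = np.minimum(1.0, max_norm / np.maximum(norms, 1e-12))
    return vectors * scale

rows = np.array([[3.0, 4.0, 0.0],   # L2 norm 5.0 -> rescaled to norm 1.0
                 [0.6, 0.0, 0.0]])  # L2 norm 0.6 -> left unchanged
out = renorm_rows(rows, max_norm=1.0)
```

Passing a different norm_type (for example 1.0) changes only how the norm is measured; the rescaling rule stays the same.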

Load pretrained embeddings:

>>> import brainstate
>>> import jax.numpy as jnp
>>> pretrained = jnp.array([[1.0, 2.0, 3.0],
...                         [4.0, 5.0, 6.0],
...                         [7.0, 8.0, 9.0]])
>>> embedding = brainstate.nn.Embedding.from_pretrained(pretrained, param_type=brainstate.FakeState)
>>> embedding.weight.value.shape
(3, 3)

classmethod from_pretrained(embeddings, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, freeze=True, name=None, param_type=brainstate.ParamState)

Create an Embedding instance from an existing weight array with at least two dimensions.

Parameters:
  • embeddings (Array | ndarray | bool | number | int | float | complex | Quantity) – Array containing the weights for the Embedding. The first dimension is passed to Embedding as num_embeddings, and the remaining dimensions as embedding_size.

  • padding_idx (int | None) – If specified, the entries at padding_idx do not contribute to the gradient. Default is None.

  • max_norm (float | None) – See module initialization documentation. Default is None.

  • norm_type (float) – See module initialization documentation. Default is 2.0.

  • scale_grad_by_freq (bool) – See module initialization documentation. Default is False.

  • freeze (bool) – If True, embeddings are frozen (no gradients). Default is True.

  • name (str | None) – The name of the module. Default is None.

  • param_type (type) – The parameter state type to use. Default is ParamState.

Returns:

An Embedding module with pretrained weights.

Return type:

Embedding

Examples

Load pretrained word embeddings:

>>> import jax.numpy as jnp
>>> import brainstate
>>> pretrained = jnp.array([[1.0, 2.0, 3.0],
...                         [4.0, 5.0, 6.0],
...                         [7.0, 8.0, 9.0]])
>>> embedding = brainstate.nn.Embedding.from_pretrained(pretrained)
>>> embedding.weight.value.shape
(3, 3)
>>> indices = jnp.array([1])
>>> embedding(indices)
Array([[4., 5., 6.]], dtype=float32)

update(indices)

Retrieve embeddings for the given indices.

Parameters:

indices (Array | ndarray | bool | number | int | float | complex | Quantity) – Integer indices of the embeddings to retrieve. Can have any shape.

Returns:

Embeddings corresponding to the indices, with shape (*indices.shape, *embedding_size).

Return type:

ArrayLike
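
The shape rule (*indices.shape, *embedding_size) is the one NumPy integer indexing follows when an index array selects along the first axis. The sketch below uses a plain NumPy array as a stand-in for embedding.weight.value, with a made-up multi-dimensional embedding_size, to show the rule for a 2-D batch of indices.

```python
import numpy as np

num_embeddings = 10
embedding_size = (2, 3)  # hypothetical multi-dimensional embedding shape

# Stand-in for `embedding.weight.value`, shape (num_embeddings, *embedding_size).
table = np.arange(num_embeddings * 2 * 3, dtype=np.float32).reshape(num_embeddings, *embedding_size)

indices = np.array([[1, 2, 3, 4, 5],
                    [6, 7, 8, 9, 0]])  # shape (2, 5)

out = table[indices]  # integer indexing along the first axis
print(out.shape)      # (2, 5, 2, 3) == (*indices.shape, *embedding_size)
```

Each position in indices is replaced by a full (2, 3) block from the table, so the index shape is kept and the embedding shape is appended.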