anthe.sevenants

Getting BERT token embeddings WITHOUT position/segment embeddings

2022-04-07

BERT hidden states do not simply consist of a vector meaning representation for each token. Even at the input level ("layer 0"), the BERT input vector is the sum of three components: a token embedding, a segment embedding and a position embedding:

BERT input schematic

This is different from earlier approaches such as word2vec, which, because of their different training procedure and architecture, contained only the so-called "token embedding".
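You can verify that these three components exist as separate lookup tables in the Transformers library. Below is a minimal sketch; the bert-base-uncased checkpoint is just a stand-in (any BERT-style model works), and the printed sizes are specific to that checkpoint:

from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")

# The embedding layer holds three separate lookup tables
print(model.embeddings.word_embeddings)        # token embeddings
print(model.embeddings.position_embeddings)    # position embeddings
print(model.embeddings.token_type_embeddings)  # segment embeddings

>> Embedding(30522, 768, padding_idx=0)
>> Embedding(512, 768)
>> Embedding(2, 768)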

For use in NLP tasks, the segment and position embeddings are vital to the workings of the BERT model and should therefore never be stripped. For research, however, it can be interesting to get the "good old" uncontextualised, untampered word piece vectors. This article outlines how to get the 768-dimensional "token embedding" with the HuggingFace Transformers library and PyTorch.

Because the hidden states of BERT already contain position and segment information (even at layer 0), we have to extract the token embeddings from the embedding matrix instead. This is the matrix of size (vocabulary size x embedding dimension) which is used to convert token ids into their pretrained vectors. To get this embedding matrix, simply call the get_input_embeddings() method on your BERT model object:

embedding_matrix = model.get_input_embeddings()

>> Embedding(40000, 768, padding_idx=1)

This returns a PyTorch Embedding module whose weight matrix has size (vocabulary size x embedding dimension), just as expected.
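If you do want the raw matrix itself, it is available as the module's weight attribute (the sizes below continue the example output above):

# The underlying lookup table is the Embedding module's weight tensor
embedding_matrix.weight.shape

>> torch.Size([40000, 768])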

We could convert this weight matrix to a NumPy array and index it directly. However, that would be a waste of memory (and time), because lookup is already implemented in PyTorch itself! You can simply call the embedding module as a function, with your word piece indices (in tensor form) as input:

import torch

input_ids = torch.tensor([256, 10896, 742])
# Calling the embedding module looks up each index in the weight matrix
embeddings = model.get_input_embeddings()(input_ids)
embeddings.shape

>> torch.Size([3, 768])
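In practice, you would not hard-code the word piece indices, but obtain them from the tokenizer that belongs to your model. A minimal sketch (the checkpoint name is again a placeholder):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# add_special_tokens=False keeps [CLS] and [SEP] out of the lookup
input_ids = tokenizer("an example sentence", add_special_tokens=False, return_tensors="pt")["input_ids"][0]

These indices can then be fed to the embedding module exactly as above.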

You can then convert this output to a NumPy matrix and apply dimensionality reduction for research purposes.
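The only catch is that the output is still attached to PyTorch's computation graph, so it has to be detached first. As an illustration, the sketch below (continuing from the three example word pieces) projects the vectors to two dimensions with scikit-learn's PCA, just one possible dimensionality reduction technique:

from sklearn.decomposition import PCA

# Detach from the autograd graph before converting to NumPy
embeddings_np = embeddings.detach().numpy()

# Project the 768-dimensional vectors down to two dimensions
pca = PCA(n_components=2)
reduced = pca.fit_transform(embeddings_np)
reduced.shape

>> (3, 2)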