Monday, August 1, 2022

What tokens are most similar to <s>, <pad>, </s>, and <mask> in roberta-large?

>>> import torch

>>> roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')

>>> def topk_similar_tokens(roberta, index, k, normalize=False):
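...   # Score every vocabulary entry by its (optionally cosine-normalized) inner
...   # product with the embedding of token `index`, then return the top-k indices.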
...   embed_tokens = roberta.get_parameter('model.encoder.sentence_encoder.embed_tokens.weight')
...   if normalize:
...     embed_tokens = embed_tokens / embed_tokens.norm(dim=1, keepdim=True)
...   _, indices = (embed_tokens[index] @ embed_tokens.T).topk(k)
...   return indices
... 

>>> roberta.decode(topk_similar_tokens(roberta, roberta.task.mask_idx, 10))
' TM JC CSI Zeus Karma CG BG GG MM Harmony'

>>> roberta.decode(topk_similar_tokens(roberta, roberta.task.source_dictionary.bos(), 10))
'.?*.,. *.!.,,.*+.-.'

>>> roberta.decode(topk_similar_tokens(roberta, roberta.task.source_dictionary.pad(), 10))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/jason_chou/.cache/torch/hub/pytorch_fairseq_main/fairseq/models/roberta/hub_interface.py", line 75, in decode
    sentences = [
  File "/Users/jason_chou/.cache/torch/hub/pytorch_fairseq_main/fairseq/models/roberta/hub_interface.py", line 76, in <listcomp>
    self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences
  File "/Users/jason_chou/.cache/torch/hub/pytorch_fairseq_main/fairseq/data/encoders/gpt2_bpe.py", line 41, in decode
    [int(tok) if tok not in {"<unk>", "<mask>"} else tok for tok in x.split()]
  File "/Users/jason_chou/.cache/torch/hub/pytorch_fairseq_main/fairseq/data/encoders/gpt2_bpe.py", line 41, in <listcomp>
    [int(tok) if tok not in {"<unk>", "<mask>"} else tok for tok in x.split()]
ValueError: invalid literal for int() with base 10: '<pad>'

>>> roberta.decode(topk_similar_tokens(roberta, roberta.task.source_dictionary.pad(), 10)[1:])
'channelAvailability\x05EngineDebug<|endoftext|>PsyNetMessage 裏覚醒'
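# Aside: decode() fails on the unsliced result because the nearest neighbour of
# '<pad>' is '<pad>' itself, and fairseq's GPT-2 BPE decoder cannot handle special
# symbols, hence the [1:] workaround above. An untested sketch of an alternative
# that skips the BPE round trip and shows the raw dictionary symbols instead
# (assumes fairseq's Dictionary supports integer indexing into its symbol list):
indices = topk_similar_tokens(roberta, roberta.task.source_dictionary.pad(), 10)
print(' '.join(roberta.task.source_dictionary[int(i)] for i in indices))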

>>> roberta.decode(topk_similar_tokens(roberta, roberta.task.source_dictionary.eos(), 10))
' \u200b,, .... TM ..….. \u200e…… MM'

>>> roberta.decode(topk_similar_tokens(roberta, roberta.task.mask_idx, 10, normalize=True))
'<mask> the and, to. that in GG'

>>> roberta.decode(topk_similar_tokens(roberta, roberta.task.source_dictionary.bos(), 10, normalize=True))
'<mask>. the, a!。?'

>>> roberta.decode(topk_similar_tokens(roberta, roberta.task.source_dictionary.pad(), 10, normalize=True))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/jason_chou/.cache/torch/hub/pytorch_fairseq_main/fairseq/models/roberta/hub_interface.py", line 75, in decode
    sentences = [
  File "/Users/jason_chou/.cache/torch/hub/pytorch_fairseq_main/fairseq/models/roberta/hub_interface.py", line 76, in <listcomp>
    self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences
  File "/Users/jason_chou/.cache/torch/hub/pytorch_fairseq_main/fairseq/data/encoders/gpt2_bpe.py", line 41, in decode
    [int(tok) if tok not in {"<unk>", "<mask>"} else tok for tok in x.split()]
  File "/Users/jason_chou/.cache/torch/hub/pytorch_fairseq_main/fairseq/data/encoders/gpt2_bpe.py", line 41, in <listcomp>
    [int(tok) if tok not in {"<unk>", "<mask>"} else tok for tok in x.split()]
ValueError: invalid literal for int() with base 10: '<pad>'

>>> roberta.decode(topk_similar_tokens(roberta, roberta.task.source_dictionary.pad(), 10, normalize=True)[1:])
'channelAvailabilityPsyNetMessage guiIconNetMessage\x05?????-?????-\x1bEStreamFrame'

>>> roberta.decode(topk_similar_tokens(roberta, roberta.task.source_dictionary.eos(), 10, normalize=True))
'.<mask>, ( " and the The-'


Conclusion: in terms of their embeddings, <s>, <pad>, </s>, and <mask> are most similar to meaningless tokens.
(The unnormalized inner product is informative here since untie_weights_roberta=False, i.e., the LM head's output projection is tied to the token embedding matrix, so the row norms of the embeddings are not arbitrary.)
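One way to check the tying directly (a minimal, untested sketch, assuming the fairseq RobertaLMHead keeps the shared output projection in lm_head.weight):

embed = roberta.get_parameter('model.encoder.sentence_encoder.embed_tokens.weight')
lm_head_weight = roberta.model.encoder.lm_head.weight
print(embed.data_ptr() == lm_head_weight.data_ptr())  # should print True if the LM head reuses the embedding matrix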
