>>> import torch
>>> roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
>>> def topk_similar_tokens(roberta, index, k, normalize=False):
...     embed_tokens = roberta.get_parameter('model.encoder.sentence_encoder.embed_tokens.weight')
...     if normalize:
...         embed_tokens = embed_tokens / embed_tokens.norm(dim=1, keepdim=True)
...     _, indices = (embed_tokens[index] @ embed_tokens.T).topk(k)
...     return indices
...
>>> roberta.decode(topk_similar_tokens(roberta, roberta.task.mask_idx, 10))
' TM JC CSI Zeus Karma CG BG GG MM Harmony'
>>> roberta.decode(topk_similar_tokens(roberta, roberta.task.source_dictionary.bos(), 10))
'.?*.,. *.!.,,.*+.-.'
>>> roberta.decode(topk_similar_tokens(roberta, roberta.task.source_dictionary.pad(), 10))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/jason_chou/.cache/torch/hub/pytorch_fairseq_main/fairseq/models/roberta/hub_interface.py", line 75, in decode
sentences = [
File "/Users/jason_chou/.cache/torch/hub/pytorch_fairseq_main/fairseq/models/roberta/hub_interface.py", line 76, in <listcomp>
self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences
File "/Users/jason_chou/.cache/torch/hub/pytorch_fairseq_main/fairseq/data/encoders/gpt2_bpe.py", line 41, in decode
[int(tok) if tok not in {"<unk>", "<mask>"} else tok for tok in x.split()]
File "/Users/jason_chou/.cache/torch/hub/pytorch_fairseq_main/fairseq/data/encoders/gpt2_bpe.py", line 41, in <listcomp>
[int(tok) if tok not in {"<unk>", "<mask>"} else tok for tok in x.split()]
ValueError: invalid literal for int() with base 10: '<pad>'
>>> roberta.decode(topk_similar_tokens(roberta, roberta.task.source_dictionary.pad(), 10)[1:])
'channelAvailability\x05EngineDebug<|endoftext|>PsyNetMessage 裏覚醒'
>>> roberta.decode(topk_similar_tokens(roberta, roberta.task.source_dictionary.eos(), 10))
' \u200b,, .... TM ..….. \u200e…… MM'
>>> roberta.decode(topk_similar_tokens(roberta, roberta.task.mask_idx, 10, normalize=True))
'<mask> the and, to. that in GG'
>>> roberta.decode(topk_similar_tokens(roberta, roberta.task.source_dictionary.bos(), 10, normalize=True))
'<mask>. the, a!。?'
>>> roberta.decode(topk_similar_tokens(roberta, roberta.task.source_dictionary.pad(), 10, normalize=True))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/jason_chou/.cache/torch/hub/pytorch_fairseq_main/fairseq/models/roberta/hub_interface.py", line 75, in decode
sentences = [
File "/Users/jason_chou/.cache/torch/hub/pytorch_fairseq_main/fairseq/models/roberta/hub_interface.py", line 76, in <listcomp>
self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences
File "/Users/jason_chou/.cache/torch/hub/pytorch_fairseq_main/fairseq/data/encoders/gpt2_bpe.py", line 41, in decode
[int(tok) if tok not in {"<unk>", "<mask>"} else tok for tok in x.split()]
File "/Users/jason_chou/.cache/torch/hub/pytorch_fairseq_main/fairseq/data/encoders/gpt2_bpe.py", line 41, in <listcomp>
[int(tok) if tok not in {"<unk>", "<mask>"} else tok for tok in x.split()]
ValueError: invalid literal for int() with base 10: '<pad>'
>>> roberta.decode(topk_similar_tokens(roberta, roberta.task.source_dictionary.pad(), 10, normalize=True)[1:])
'channelAvailabilityPsyNetMessage guiIconNetMessage\x05?????-?????-\x1bEStreamFrame'
>>> roberta.decode(topk_similar_tokens(roberta, roberta.task.source_dictionary.eos(), 10, normalize=True))
'.<mask>, ( " and the The-'
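Aside: the two tracebacks above are just roberta.decode() handing the literal '<pad>' symbol to the GPT-2 BPE decoder, which only accepts integer token ids; slicing off the first hit works because the query token is always its own top match. A minimal workaround sketch (not part of the session above, and assuming fairseq's Dictionary supports plain index lookup) is to read the raw dictionary symbols instead of round-tripping through roberta.decode():

>>> # sketch: raw symbols, special tokens included; non-special entries appear as GPT-2 BPE ids
>>> indices = topk_similar_tokens(roberta, roberta.task.source_dictionary.pad(), 10)
>>> [roberta.task.source_dictionary[i] for i in indices.tolist()]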
Conclusion: judging by their embeddings, RoBERTa's special tokens (<s>, <pad>, </s>, <mask>) are most similar to rare, essentially meaningless tokens.
(The unnormalized inner product is informative here because untie_weights_roberta=False: the embedding matrix doubles as the LM head's output projection, so embedding magnitudes carry meaning, not just directions.)
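As a quick sanity check on that last point (a sketch, not part of the original session, assuming the standard fairseq RoBERTa module layout), the embedding matrix and the LM head projection should be the very same tensor when untie_weights_roberta=False:

>>> # sketch: with tied weights, the LM head reuses the embedding matrix object itself
>>> embed = roberta.model.encoder.sentence_encoder.embed_tokens.weight
>>> embed is roberta.model.encoder.lm_head.weight  # expected True when untie_weights_roberta=False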