---
layout: hub_detail
background-class: hub-background
body-class: hub
title: RoBERTa
summary: RoBERTa, a robustly optimized BERT pretraining approach
category: researchers
image: fairseq_logo.png
author: Facebook AI (fairseq Team)
tags: [nlp]
github-link:
github-id: pytorch/fairseq
accelerator: cuda-optional
order: 10
demo-model-link:
---

๋ชจ๋ธ ์„ค๋ช…

Bidirectional Encoder Representations from Transformers, or BERT, is a powerful self-supervised pretraining technique that learns to predict intentionally hidden (masked) sections of text. Crucially, the representations learned by BERT have been shown to generalize well to downstream tasks, and when BERT was first released in 2018 it achieved state-of-the-art results on many NLP benchmark datasets.

RoBERTa builds on BERT's language masking strategy, but differs in a few ways. It modifies key hyperparameters, removing the next-sentence pretraining objective and training with much larger mini-batches and learning rates. RoBERTa was also trained on far more data than BERT, and for a longer amount of time. This allows RoBERTa's representations to generalize to downstream tasks even better than BERT's.
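To see the masked-token objective described above in action, fairseq's RoBERTa hub interface exposes a `fill_mask` helper that scores candidate completions for a `<mask>` placeholder. A minimal sketch (the exact completions and scores depend on the checkpoint):

```python
import torch

roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
roberta.eval()

# Ask the pretrained masked-LM head for the top-3 completions of the <mask> token.
# fill_mask returns (filled sentence, score, predicted token) tuples.
for filled_sentence, score, predicted_token in roberta.fill_mask('The capital of France is <mask>.', topk=3):
    print(f'{score:.3f}\t{filled_sentence}')
```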

### Requirements

Extra Python dependencies are required:

```bash
pip install regex requests hydra-core omegaconf
```

### Example

#### Load RoBERTa

```python
import torch
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
roberta.eval()  # disable dropout (or disable training mode)
```
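This model lists CUDA as optional; if a GPU is available, the loaded module can be moved there with standard PyTorch device handling (a small optional sketch; fairseq's hub interface moves input tensors to the model's device internally):

```python
# Optional: run on GPU when one is available.
if torch.cuda.is_available():
    roberta.cuda()
```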
#### Apply Byte-Pair Encoding (BPE) to input text

```python
tokens = roberta.encode('Hello world!')
assert tokens.tolist() == [0, 31414, 232, 328, 2]  # 0 = <s> (BOS), 2 = </s> (EOS)
assert roberta.decode(tokens) == 'Hello world!'
```
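RoBERTa uses GPT-2's byte-level BPE, so arbitrary text can be encoded without out-of-vocabulary tokens, and decoding recovers the original string. A quick illustrative check (no specific token IDs asserted):

```python
# Byte-level BPE falls back to raw bytes, so non-ASCII input round-trips cleanly.
tokens = roberta.encode('Héllo wörld!')
assert roberta.decode(tokens) == 'Héllo wörld!'
```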
#### Extract features from RoBERTa

```python
# Extract the last layer's features
last_layer_features = roberta.extract_features(tokens)
assert last_layer_features.size() == torch.Size([1, 5, 1024])

# Extract all layers' features (layer 0 is the embedding layer)
all_layers = roberta.extract_features(tokens, return_all_hiddens=True)
assert len(all_layers) == 25
assert torch.all(all_layers[-1] == last_layer_features)
```
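The extracted features are per-token; to get a single fixed-size sentence vector, a common recipe (an illustrative sketch, not part of the hub API) is to take the final-layer feature of the leading `<s>` token, or to mean-pool over positions:

```python
# Two simple pooling strategies over the [1, seq_len, 1024] feature tensor.
cls_embedding = last_layer_features[:, 0, :]      # feature of the leading <s> token
mean_embedding = last_layer_features.mean(dim=1)  # average over all token positions
assert cls_embedding.size() == torch.Size([1, 1024])
assert mean_embedding.size() == torch.Size([1, 1024])
```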
#### Use RoBERTa for sentence-pair classification tasks

```python
# Download RoBERTa already finetuned on MNLI
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')
roberta.eval()  # disable dropout for evaluation

with torch.no_grad():
    # Encode a pair of sentences and make a prediction
    tokens = roberta.encode('Roberta is a heavily optimized version of BERT.', 'Roberta is not very optimized.')
    prediction = roberta.predict('mnli', tokens).argmax().item()
    assert prediction == 0  # contradiction

    # Encode another pair of sentences and make a prediction
    tokens = roberta.encode('Roberta is a heavily optimized version of BERT.', 'Roberta is based on BERT.')
    prediction = roberta.predict('mnli', tokens).argmax().item()
    assert prediction == 2  # entailment
```
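The MNLI head returns log-probabilities over three classes. Given the indices asserted above (0 = contradiction, 2 = entailment), the remaining index 1 is neutral; a small sketch that maps predictions to readable labels (the label list is inferred from those indices, not taken from the fairseq API):

```python
# Label order inferred from the asserted indices above.
mnli_labels = ['contradiction', 'neutral', 'entailment']

with torch.no_grad():
    tokens = roberta.encode('Roberta is a heavily optimized version of BERT.',
                            'Roberta was trained on more data than BERT.')
    logprobs = roberta.predict('mnli', tokens)
    print(mnli_labels[logprobs.argmax().item()])
```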
#### Register a new classification head

```python
roberta.register_classification_head('new_task', num_classes=3)
logprobs = roberta.predict('new_task', tokens)  # tensor([[-1.1050, -1.0672, -1.1245]], grad_fn=<LogSoftmaxBackward>)
```
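A newly registered head is randomly initialized (note the near-uniform log-probabilities above), so it must be finetuned before its predictions mean anything. A minimal sketch of a single training step, with a toy input and a hypothetical label (real finetuning needs a labeled dataset, batching, and a learning-rate schedule):

```python
import torch.nn.functional as F

optimizer = torch.optim.Adam(roberta.parameters(), lr=1e-5)
roberta.train()  # re-enable dropout for training

tokens = roberta.encode('Hello world!')
target = torch.tensor([0])  # hypothetical gold label for 'new_task'

logprobs = roberta.predict('new_task', tokens)  # differentiable log-probabilities
loss = F.nll_loss(logprobs, target)             # NLL loss pairs with log_softmax outputs
loss.backward()
optimizer.step()
```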

### References