layout | background-class | body-class | title | summary | category | image | author | tags | github-link | github-id | featured_image_1 | featured_image_2 | accelerator | order | demo-model-link | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
hub_detail |
hub-background |
hub |
Transformer (NMT) |
์์ด-ํ๋์ค์ด ๋ฒ์ญ๊ณผ ์์ด-๋
์ผ์ด ๋ฒ์ญ์ ์ํ ํธ๋์คํฌ๋จธ ๋ชจ๋ธ |
researchers |
fairseq_logo.png |
Facebook AI (fairseq Team) |
|
pytorch/fairseq |
no-image |
no-image |
cuda-optional |
2 |
๋
ผ๋ฌธ Attention Is All You Need์ ์๊ฐ๋์๋ ํธ๋์คํฌ๋จธ(Transformer)๋
๊ฐ๋ ฅํ ์ํ์ค-ํฌ-์ํ์ค ๋ชจ๋ธ๋ง ์ํคํ
์ฒ๋ก ์ต์ ๊ธฐ๊ณ ์ ๊ฒฝ๋ง ๋ฒ์ญ ์์คํ
์ ๊ฐ๋ฅํ๊ฒ ํฉ๋๋ค.
์ต๊ทผ, fairseq
ํ์ ์ญ๋ฒ์ญ๋ ๋ฐ์ดํฐ๋ฅผ ํ์ฉํ
ํธ๋์คํฌ๋จธ์ ๋๊ท๋ชจ ์ค์ง๋ ํ์ต์ ํตํด ๋ฒ์ญ ์์ค์ ๊ธฐ์กด๋ณด๋ค ํฅ์์์ผฐ์ต๋๋ค.
๋ ์์ธํ ๋ด์ฉ์ ๋ธ๋ก๊ทธ ํฌ์คํธ๋ฅผ ํตํด ์ฐพ์ผ์ค ์ ์์ต๋๋ค.
์ ์ฒ๋ฆฌ ๊ณผ์ ์ ์ํด ๋ช ๊ฐ์ง python ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ํ์ํฉ๋๋ค:
pip install bitarray fastBPE hydra-core omegaconf regex requests sacremoses subword_nmt
์์ด๋ฅผ ํ๋์ค์ด๋ก ๋ฒ์ญํ๊ธฐ ์ํด Scaling Neural Machine Translation ๋ ผ๋ฌธ์ ๋ชจ๋ธ์ ํ์ฉํฉ๋๋ค:
import torch
# WMT'14 data์์ ํ์ต๋ ์์ด โก๏ธ ํ๋์ค์ด ํธ๋์คํฌ๋จธ ๋ชจ๋ธ ๋ถ๋ฌ์ค๊ธฐ:
en2fr = torch.hub.load('pytorch/fairseq', 'transformer.wmt14.en-fr', tokenizer='moses', bpe='subword_nmt')
# GPU ์ฌ์ฉ (์ ํ์ฌํญ):
en2fr.cuda()
# beam search๋ฅผ ํตํ ๋ฒ์ญ:
fr = en2fr.translate('Hello world!', beam=5)
assert fr == 'Bonjour ร tous !'
# ํ ํฐํ:
en_toks = en2fr.tokenize('Hello world!')
assert en_toks == 'Hello world !'
# BPE ์ ์ฉ:
en_bpe = en2fr.apply_bpe(en_toks)
assert en_bpe == 'H@@ ello world !'
# ์ด์งํ:
en_bin = en2fr.binarize(en_bpe)
assert en_bin.tolist() == [329, 14044, 682, 812, 2]
# top-k sampling์ ํตํด ๋ค์ฏ ๋ฒ์ญ ์ฌ๋ก ์์ฑ:
fr_bin = en2fr.generate(en_bin, beam=5, sampling=True, sampling_topk=20)
assert len(fr_bin) == 5
# ์์์ค ํ๋๋ฅผ ๋ฌธ์์ด๋ก ๋ณํํ๊ณ ๋นํ ํฐํ
fr_sample = fr_bin[0]['tokens']
fr_bpe = en2fr.string(fr_sample)
fr_toks = en2fr.remove_bpe(fr_bpe)
fr = en2fr.detokenize(fr_toks)
assert fr == en2fr.decode(fr_sample)
์ญ๋ฒ์ญ์ ๋ํ ์ค์ง๋ํ์ต์ ๋ฒ์ญ ์์คํ ์ ํฅ์์ํค๋๋ฐ ํจ์จ์ ์ธ ๋ฐฉ๋ฒ์ ๋๋ค. ๋ ผ๋ฌธ Understanding Back-Translation at Scale์์, ์ถ๊ฐ์ ์ธ ํ์ต ๋ฐ์ดํฐ๋ก ์ฌ์ฉํ๊ธฐ ์ํด 2์ต๊ฐ ์ด์์ ๋ ์ผ์ด ๋ฌธ์ฅ์ ์ญ๋ฒ์ญํฉ๋๋ค. ์ด ๋ค์ฏ ๋ชจ๋ธ๋ค์ ์์๋ธ์ WMT'18 English-German news translation competition์ ์์์์ ๋๋ค.
noisy-channel reranking์ ํตํด ์ด ์ ๊ทผ๋ฒ์ ๋ ํฅ์์ํฌ ์ ์์ต๋๋ค. ๋ ์์ธํ ๋ด์ฉ์ ๋ธ๋ก๊ทธ ํฌ์คํธ์์ ๋ณผ ์ ์์ต๋๋ค. ์ด๋ฌํ ๋ ธํ์ฐ๋ก ํ์ต๋ ๋ชจ๋ธ๋ค์ ์์๋ธ์ WMT'19 English-German news translation competition์ ์์์์ ๋๋ค.
์์ ์๊ฐ๋ ๋ํ ์์ ๋ชจ๋ธ ์ค ํ๋๋ฅผ ์ฌ์ฉํ์ฌ ์์ด๋ฅผ ๋ ์ผ์ด๋ก ๋ฒ์ญํด๋ณด๊ฒ ์ต๋๋ค:
import torch
# WMT'19 data์์ ํ์ต๋ ์์ด โก๏ธ ๋
์ผ์ด ํธ๋์คํฌ๋จธ ๋ชจ๋ธ ๋ถ๋ฌ์ค๊ธฐ:
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model', tokenizer='moses', bpe='fastbpe')
# ๊ธฐ๋ณธ ํธ๋์คํฌ๋จธ ๋ชจ๋ธ์ ์ ๊ทผ
assert isinstance(en2de.models[0], torch.nn.Module)
# ์์ด โก๏ธ ๋
์ผ์ด ๋ฒ์ญ
de = en2de.translate('PyTorch Hub is a pre-trained model repository designed to facilitate research reproducibility.')
assert de == 'PyTorch Hub ist ein vorgefertigtes Modell-Repository, das die Reproduzierbarkeit der Forschung erleichtern soll.'
๊ต์ฐจ๋ฒ์ญ์ผ๋ก ๊ฐ์ ๋ฌธ์ฅ์ ๋ํ ์์ญ์ ๋ง๋ค ์๋ ์์ต๋๋ค:
# ์์ด โ๏ธ ๋
์ผ์ด ๊ต์ฐจ๋ฒ์ญ:
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model', tokenizer='moses', bpe='fastbpe')
de2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.de-en.single_model', tokenizer='moses', bpe='fastbpe')
paraphrase = de2en.translate(en2de.translate('PyTorch Hub is an awesome interface!'))
assert paraphrase == 'PyTorch Hub is a fantastic interface!'
# ์์ด โ๏ธ ๋ฌ์์์ด ๊ต์ฐจ๋ฒ์ญ๊ณผ ๋น๊ต:
en2ru = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-ru.single_model', tokenizer='moses', bpe='fastbpe')
ru2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.ru-en.single_model', tokenizer='moses', bpe='fastbpe')
paraphrase = ru2en.translate(en2ru.translate('PyTorch Hub is an awesome interface!'))
assert paraphrase == 'PyTorch is a great interface!'