---
layout: hub_detail
background-class: hub-background
body-class: hub
title: Transformer (NMT)
summary: Transformer models for English-French and English-German translation.
category: researchers
image: fairseq_logo.png
author: Facebook AI (fairseq Team)
tags:
github-link: https://github.com/pytorch/fairseq/
github-id: pytorch/fairseq
featured_image_1: no-image
featured_image_2: no-image
accelerator: cuda-optional
order: 2
demo-model-link: https://huggingface.co/spaces/pytorch/Transformer_NMT
---
Introduced in the paper Attention Is All You Need, the Transformer is a powerful sequence-to-sequence modeling architecture that enables state-of-the-art neural machine translation systems.
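As a minimal illustration of the encoder-decoder architecture (separate from fairseq's own implementation), PyTorch ships a generic `torch.nn.Transformer` module; the dimensions below are arbitrary toy values:

```python
import torch
import torch.nn as nn

# A small encoder-decoder Transformer; d_model must be divisible by nhead.
model = nn.Transformer(d_model=64, nhead=4,
                       num_encoder_layers=2, num_decoder_layers=2,
                       dim_feedforward=128, batch_first=True)

src = torch.rand(1, 10, 64)  # (batch, source length, d_model)
tgt = torch.rand(1, 7, 64)   # (batch, target length, d_model)

# The decoder attends over the encoded source and emits one vector
# per target position:
out = model(src, tgt)
print(out.shape)
```

In a real translation model these inputs would be token embeddings plus positional encodings, and a linear layer would map the outputs to vocabulary logits.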
Recently, the fairseq team has used large-scale semi-supervised training of Transformers on back-translated data to further improve translation quality over prior work. More details can be found in this blog post.
A few additional Python dependencies are required for preprocessing:

```bash
pip install bitarray fastBPE hydra-core omegaconf regex requests sacremoses subword_nmt
```

To translate English into French, we use a model from the Scaling Neural Machine Translation paper:
```python
import torch

# Load an English-to-French Transformer model trained on WMT'14 data:
en2fr = torch.hub.load('pytorch/fairseq', 'transformer.wmt14.en-fr', tokenizer='moses', bpe='subword_nmt')

# Use the GPU (optional):
en2fr.cuda()

# Translate with beam search:
fr = en2fr.translate('Hello world!', beam=5)
assert fr == 'Bonjour à tous !'

# Manually tokenize:
en_toks = en2fr.tokenize('Hello world!')
assert en_toks == 'Hello world !'

# Manually apply BPE:
en_bpe = en2fr.apply_bpe(en_toks)
assert en_bpe == 'H@@ ello world !'

# Manually binarize:
en_bin = en2fr.binarize(en_bpe)
assert en_bin.tolist() == [329, 14044, 682, 812, 2]

# Generate five translations with top-k sampling:
fr_bin = en2fr.generate(en_bin, beam=5, sampling=True, sampling_topk=20)
assert len(fr_bin) == 5

# Convert one of the samples to a string and detokenize:
fr_sample = fr_bin[0]['tokens']
fr_bpe = en2fr.string(fr_sample)
fr_toks = en2fr.remove_bpe(fr_bpe)
fr = en2fr.detokenize(fr_toks)
assert fr == en2fr.decode(fr_sample)
```

Semi-supervised training with back-translation is an effective way to improve translation systems. In the paper Understanding Back-Translation at Scale, we back-translate over 200 million German sentences and use them as additional training data. An ensemble of these five models was the winning submission to the WMT'18 English-German news translation competition.
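The back-translation idea can be sketched with a toy stand-in for the reverse model: monolingual target-language sentences are translated back into the source language, and each synthetic source is paired with its original target as extra training data. The function and sentences below are hypothetical, not fairseq APIs:

```python
# Hypothetical stand-in for a trained German-to-English reverse model.
def translate_de_to_en(sentence):
    toy_lookup = {'Hallo Welt !': 'Hello world !'}
    return toy_lookup.get(sentence, sentence)

# Monolingual German data (normally hundreds of millions of sentences).
monolingual_de = ['Hallo Welt !']

# Pair each synthetic English source with the original German target;
# these (source, target) pairs augment the parallel training corpus.
synthetic_pairs = [(translate_de_to_en(de), de) for de in monolingual_de]
print(synthetic_pairs)
```

The key property is that the target side of each synthetic pair is genuine human-written text, so the forward model still learns to produce fluent output.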
We can further improve this approach through noisy-channel reranking; more details can be found in this blog post. An ensemble of models trained with this technique was the winning submission to the WMT'19 English-German news translation competition.
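Noisy-channel reranking rescores candidate translations y of a source x by combining the direct model log p(y|x) with a channel model log p(x|y) and a language model log p(y). A minimal sketch with made-up scores and illustrative weights (not fairseq's implementation):

```python
# Hypothetical log-probabilities for two candidate French translations
# of the same English source; the numbers are invented for illustration.
candidates = {
    'Bonjour le monde !':     {'direct': -1.2, 'channel': -0.9, 'lm': -2.0},
    'Salut tout le monde !':  {'direct': -1.0, 'channel': -2.5, 'lm': -4.0},
}

def rerank(cands, w_channel=1.0, w_lm=0.3):
    # score(y) = log p(y|x) + w_channel * log p(x|y) + w_lm * log p(y)
    def score(sentence):
        c = cands[sentence]
        return c['direct'] + w_channel * c['channel'] + w_lm * c['lm']
    return max(cands, key=score)

best = rerank(candidates)
print(best)
```

Note that the channel term rewards candidates that explain the source well, which can overturn the direct model's original ranking, as it does here.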
Let's use one of the single models introduced above to translate English into German:
```python
import torch

# Load an English-to-German Transformer model trained on WMT'19 data:
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model', tokenizer='moses', bpe='fastbpe')

# Access the underlying Transformer model:
assert isinstance(en2de.models[0], torch.nn.Module)

# Translate from English to German:
de = en2de.translate('PyTorch Hub is a pre-trained model repository designed to facilitate research reproducibility.')
assert de == 'PyTorch Hub ist ein vorgefertigtes Modell-Repository, das die Reproduzierbarkeit der Forschung erleichtern soll.'
```

We can also create a paraphrase of the same sentence through round-trip translation:
```python
# Round-trip translation between English and German:
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model', tokenizer='moses', bpe='fastbpe')
de2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.de-en.single_model', tokenizer='moses', bpe='fastbpe')

paraphrase = de2en.translate(en2de.translate('PyTorch Hub is an awesome interface!'))
assert paraphrase == 'PyTorch Hub is a fantastic interface!'

# Compare with round-trip translation between English and Russian:
en2ru = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-ru.single_model', tokenizer='moses', bpe='fastbpe')
ru2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.ru-en.single_model', tokenizer='moses', bpe='fastbpe')

paraphrase = ru2en.translate(en2ru.translate('PyTorch Hub is an awesome interface!'))
assert paraphrase == 'PyTorch is a great interface!'
```