---
layout: hub_detail
background-class: hub-background
body-class: hub
title: Transformer (NMT)
summary: Transformer models for English-French and English-German translation
category: researchers
image: fairseq_logo.png
author: Facebook AI (fairseq Team)
tags: [nlp]
github-link:
github-id: pytorch/fairseq
featured_image_1: no-image
featured_image_2: no-image
accelerator: cuda-optional
order: 2
demo-model-link:
---

๋ชจ๋ธ ์„ค๋ช…

The Transformer, introduced in the paper Attention Is All You Need, is a powerful sequence-to-sequence modeling architecture that enables state-of-the-art neural machine translation systems.

Recently, the fairseq team improved translation quality beyond the original model through large-scale semi-supervised training of Transformers using back-translated data. More details can be found in this blog post.

### Requirements

A few Python libraries are required for preprocessing:

```bash
pip install bitarray fastBPE hydra-core omegaconf regex requests sacremoses subword_nmt
```

์˜์–ด โžก๏ธ ํ”„๋ž‘์Šค์–ด ๋ฒˆ์—ญ

์˜์–ด๋ฅผ ํ”„๋ž‘์Šค์–ด๋กœ ๋ฒˆ์—ญํ•˜๊ธฐ ์œ„ํ•ด Scaling Neural Machine Translation ๋…ผ๋ฌธ์˜ ๋ชจ๋ธ์„ ํ™œ์šฉํ•ฉ๋‹ˆ๋‹ค:

```python
import torch

# Load an English-to-French Transformer model trained on WMT'14 data:
en2fr = torch.hub.load('pytorch/fairseq', 'transformer.wmt14.en-fr', tokenizer='moses', bpe='subword_nmt')

# Use the GPU (optional):
en2fr.cuda()

# Translate with beam search:
fr = en2fr.translate('Hello world!', beam=5)
assert fr == 'Bonjour à tous !'

# Manually tokenize:
en_toks = en2fr.tokenize('Hello world!')
assert en_toks == 'Hello world !'

# Manually apply BPE:
en_bpe = en2fr.apply_bpe(en_toks)
assert en_bpe == 'H@@ ello world !'

# Manually binarize:
en_bin = en2fr.binarize(en_bpe)
assert en_bin.tolist() == [329, 14044, 682, 812, 2]

# Generate five translation samples with top-k sampling:
fr_bin = en2fr.generate(en_bin, beam=5, sampling=True, sampling_topk=20)
assert len(fr_bin) == 5

# Convert one of the samples to a string and detokenize
fr_sample = fr_bin[0]['tokens']
fr_bpe = en2fr.string(fr_sample)
fr_toks = en2fr.remove_bpe(fr_bpe)
fr = en2fr.detokenize(fr_toks)
assert fr == en2fr.decode(fr_sample)
```
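The `sampling_topk` argument above restricts sampling at each decoding step to the k most probable tokens. A minimal, framework-free sketch of the idea (the vocabulary and probabilities here are made up for illustration, not fairseq internals):

```python
import math
import random

def top_k_sample(logprobs, k, rng=random):
    """Sample an index from the k highest-scoring entries,
    renormalizing their probabilities to sum to 1."""
    # Keep the k best (index, logprob) pairs.
    top = sorted(enumerate(logprobs), key=lambda p: p[1], reverse=True)[:k]
    probs = [math.exp(lp) for _, lp in top]
    total = sum(probs)
    # Draw from the renormalized top-k distribution.
    r = rng.random() * total
    for (idx, _), p in zip(top, probs):
        r -= p
        if r <= 0:
            return idx
    return top[-1][0]

# Toy example: 5-token vocabulary, sample only from the top 2.
logprobs = [math.log(p) for p in [0.5, 0.3, 0.1, 0.05, 0.05]]
choice = top_k_sample(logprobs, k=2)
assert choice in (0, 1)
```

Truncating to the top k tokens trades a little diversity for fluency: low-probability tokens that would derail the translation can never be sampled.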

์˜์–ด โžก๏ธ ๋…์ผ์–ด ๋ฒˆ์—ญ

์—ญ๋ฒˆ์—ญ์— ๋Œ€ํ•œ ์ค€์ง€๋„ํ•™์Šต์€ ๋ฒˆ์—ญ ์‹œ์Šคํ…œ์„ ํ–ฅ์ƒ์‹œํ‚ค๋Š”๋ฐ ํšจ์œจ์ ์ธ ๋ฐฉ๋ฒ•์ž…๋‹ˆ๋‹ค. ๋…ผ๋ฌธ Understanding Back-Translation at Scale์—์„œ, ์ถ”๊ฐ€์ ์ธ ํ•™์Šต ๋ฐ์ดํ„ฐ๋กœ ์‚ฌ์šฉํ•˜๊ธฐ ์œ„ํ•ด 2์–ต๊ฐœ ์ด์ƒ์˜ ๋…์ผ์–ด ๋ฌธ์žฅ์„ ์—ญ๋ฒˆ์—ญํ•ฉ๋‹ˆ๋‹ค. ์ด ๋‹ค์„ฏ ๋ชจ๋ธ๋“ค์˜ ์•™์ƒ๋ธ”์€ WMT'18 English-German news translation competition์˜ ์ˆ˜์ƒ์ž‘์ž…๋‹ˆ๋‹ค.

This approach can be further improved with noisy-channel reranking; more details are available in this blog post. An ensemble of models trained with these techniques was the winning submission to the WMT'19 English-German news translation competition.
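Noisy-channel reranking rescores each candidate translation y of a source x by combining a direct model, a channel (reverse) model, and a language model, roughly log P(y|x) + w1 · log P(x|y) + w2 · log P(y). A toy sketch with made-up scores (the weights and score tables are illustrative, not the published configuration):

```python
def rerank(candidates, direct_lp, channel_lp, lm_lp, w_channel=0.5, w_lm=0.5):
    """Pick the candidate maximizing the noisy-channel score:
    log P(y|x) + w_channel * log P(x|y) + w_lm * log P(y)."""
    def score(y):
        return direct_lp[y] + w_channel * channel_lp[y] + w_lm * lm_lp[y]
    return max(candidates, key=score)

# Toy log-probabilities for two hypothetical candidate translations.
cands = ['Hallo Welt !', 'Hallo , Welt !']
direct  = {'Hallo Welt !': -1.0, 'Hallo , Welt !': -0.9}
channel = {'Hallo Welt !': -0.5, 'Hallo , Welt !': -2.0}
lm      = {'Hallo Welt !': -0.8, 'Hallo , Welt !': -1.0}

best = rerank(cands, direct, channel, lm)
assert best == 'Hallo Welt !'
```

Even though the second candidate scores slightly higher under the direct model alone, the channel and language models overrule it, which is the point of the reranking step.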

์•ž์„œ ์†Œ๊ฐœ๋œ ๋Œ€ํšŒ ์ˆ˜์ƒ ๋ชจ๋ธ ์ค‘ ํ•˜๋‚˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์˜์–ด๋ฅผ ๋…์ผ์–ด๋กœ ๋ฒˆ์—ญํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค:

```python
import torch

# Load an English-to-German Transformer model trained on WMT'19 data:
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model', tokenizer='moses', bpe='fastbpe')

# Access the underlying Transformer model
assert isinstance(en2de.models[0], torch.nn.Module)

# Translate from English to German
de = en2de.translate('PyTorch Hub is a pre-trained model repository designed to facilitate research reproducibility.')
assert de == 'PyTorch Hub ist ein vorgefertigtes Modell-Repository, das die Reproduzierbarkeit der Forschung erleichtern soll.'
```

๊ต์ฐจ๋ฒˆ์—ญ์œผ๋กœ ๊ฐ™์€ ๋ฌธ์žฅ์— ๋Œ€ํ•œ ์˜์—ญ์„ ๋งŒ๋“ค ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค:

# ์˜์–ด โ†”๏ธ ๋…์ผ์–ด ๊ต์ฐจ๋ฒˆ์—ญ:
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model', tokenizer='moses', bpe='fastbpe')
de2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.de-en.single_model', tokenizer='moses', bpe='fastbpe')

paraphrase = de2en.translate(en2de.translate('PyTorch Hub is an awesome interface!'))
assert paraphrase == 'PyTorch Hub is a fantastic interface!'

# ์˜์–ด โ†”๏ธ ๋Ÿฌ์‹œ์•„์–ด ๊ต์ฐจ๋ฒˆ์—ญ๊ณผ ๋น„๊ต:
en2ru = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-ru.single_model', tokenizer='moses', bpe='fastbpe')
ru2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.ru-en.single_model', tokenizer='moses', bpe='fastbpe')

paraphrase = ru2en.translate(en2ru.translate('PyTorch Hub is an awesome interface!'))
assert paraphrase == 'PyTorch is a great interface!'

### References