update readme and fix dataloader bug
Embedding committed Aug 6, 2019
1 parent 3ef8005 commit 9d4e5f8
Showing 3 changed files with 41 additions and 38 deletions.
75 changes: 39 additions & 36 deletions README.md
@@ -5,11 +5,8 @@

<img src="uer-logo.jpg" width="390" height="390" align=left />

Pre-training has become an essential part of NLP tasks and has led to remarkable improvements. UER-py is a toolkit for pre-training on general-domain corpora and fine-tuning on downstream tasks. UER-py maintains model modularity and supports research extensibility. It facilitates the use of different pre-training models (e.g. BERT, GPT, ELMO) and provides interfaces for users to extend further. With UER-py, a model zoo is built which contains pre-trained models based on different corpora, encoders, and targets.
Pre-training has become an essential part of NLP tasks and has led to remarkable improvements. UER-py (Universal Encoder Representations) is a toolkit for pre-training on general-domain corpora and fine-tuning on downstream tasks. UER-py maintains model modularity and supports research extensibility. It facilitates the use of different pre-training models (e.g. BERT, GPT, ELMO) and provides interfaces for users to extend further. With UER-py, we build a model zoo that contains pre-trained models based on different corpora, encoders, and targets.

**Update: A [pretrained GPT model (512 length)](https://share.weiyun.com/51nTP8V) is now available. One can use *generate.py* to generate text.
ELMO (bilstm encoder + bilm target) is supported by UER.
A pre-trained word-based BERT is available. Context-dependent word embeddings (trained by BERT) are particularly suitable for polysemous words.**

<br>

@@ -138,21 +135,21 @@ We can achieve 86.5 accuracy on testset, which is also a competitive result. Usi
Besides classification, UER-py also provides scripts for other downstream tasks. We could use run_ner.py for named entity recognition:
```
python3 run_ner.py --pretrained_model_path models/google_model.bin --vocab_path models/google_vocab.txt \
--train_path datasets/msra/train.tsv --dev_path datasets/msra/dev.tsv --test_path datasets/msra/test.tsv \
--train_path datasets/msra_ner/train.tsv --dev_path datasets/msra_ner/dev.tsv --test_path datasets/msra_ner/test.tsv \
--epochs_num 5 --batch_size 16 --encoder bert
```
We could download [a model pre-trained on RenMinRiBao (also known as People's Daily, a news corpus)](https://share.weiyun.com/5JWVjSE) and fine-tune on it:
```
python3 run_ner.py --pretrained_model_path models/rmrb_model.bin --vocab_path models/google_vocab.txt \
--train_path datasets/msra/train.tsv --dev_path datasets/msra/dev.tsv --test_path datasets/msra/test.tsv \
--train_path datasets/msra_ner/train.tsv --dev_path datasets/msra_ner/dev.tsv --test_path datasets/msra_ner/test.tsv \
--epochs_num 5 --batch_size 16 --encoder bert
```
It turns out that the result of Google's model is 92.6, while the result of *rmrb_model.bin* is 94.4.

<br/>

## Datasets
This project includes a range of Chinese datasets: XNLI, LCQMC, MSRA-NER, ChnSentiCorp, and NLPCC-DBQA are obtained from [Baidu ERNIE](https://github.com/PaddlePaddle/LARK/tree/develop/ERNIE); the Douban book review dataset is obtained from [BNU](https://embedding.github.io/evaluation/); the online shopping review dataset is organized by ourselves. Large-scale datasets can be found in [glyph's github project](https://github.com/zhangxiangxiao/glyph).
This project includes a range of Chinese datasets: XNLI, LCQMC, MSRA-NER, ChnSentiCorp, and NLPCC-DBQA are obtained from [Baidu ERNIE](https://github.com/PaddlePaddle/LARK/tree/develop/ERNIE); the Douban book review dataset is obtained from [BNU](https://embedding.github.io/evaluation/); the online shopping review dataset is organized by ourselves; THUCNews is obtained from [here](https://github.com/gaussic/text-classification-cnn-rnn); the Sina Weibo review dataset is obtained from [here](https://github.com/SophonPlus/ChineseNlpCorpus). More large-scale datasets can be found in [glyph's github project](https://github.com/zhangxiangxiao/glyph).

<table>
<tr align="center"><td> Dataset <td> Link
@@ -206,23 +203,24 @@ usage: preprocess.py [-h] --corpus_path CORPUS_PATH --vocab_path VOCAB_PATH
[--dataset_path DATASET_PATH]
[--tokenizer {bert,char,space}]
[--processes_num PROCESSES_NUM]
[--target {bert,lm,cls,mlm,nsp,s2s}]
[--target {bert,lm,cls,mlm,nsp,s2s,bilm}]
[--docs_buffer_size DOCS_BUFFER_SIZE]
[--instances_buffer_size INSTANCES_BUFFER_SIZE]
[--seq_length SEQ_LENGTH] [--dup_factor DUP_FACTOR]
[--short_seq_prob SHORT_SEQ_PROB] [--seed SEED]
```
*--docs_buffer_size* and *--instances_buffer_size* can be used to control memory consumption in the pre-processing and pre-training stages. *--processes_num n* denotes that n processes are used for pre-processing. An example of pre-processing on a single machine is as follows:
```
python3 preprocess.py --corpus_path corpora/book_review_bert.txt --vocab_path models/google_vocab.txt \
--dataset_path dataset.pt --processes_num 8 --target bert
python3 preprocess.py --corpus_path corpora/book_review_bert.txt --vocab_path models/google_vocab.txt --dataset_path dataset.pt \
--processes_num 8 --target bert
```
We need to specify the model's target at the pre-processing stage since different targets require different data formats. Currently, UER-py includes the following target modules (a minimal sketch of what a target module does follows this list):
- lm_target.py: language model
- mlm_target.py: masked language model (cloze test)
- nsp_target.py: next sentence prediction
- cls_target.py: classification
- s2s_target.py: supports autoencoder and machine translation
- bilm_target.py: bi-directional language model
- bert_target.py: masked language model + next sentence prediction
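
To make the notion of a target concrete, below is a minimal, hypothetical sketch of an MLM-style target module: it takes the encoder output and computes a cross-entropy loss only at masked positions. The class and argument names are invented for illustration and are not UER-py's actual implementation.
```python
import torch.nn as nn

class ToyMlmTarget(nn.Module):
    """Illustrative MLM target: predict the original token at each masked position."""
    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        self.output_layer = nn.Linear(hidden_size, vocab_size)
        # Positions that were not masked carry the label -100 and are ignored by the loss.
        self.criterion = nn.CrossEntropyLoss(ignore_index=-100)

    def forward(self, memory, tgt_mlm):
        # memory:  [batch, seq_len, hidden_size] encoder output
        # tgt_mlm: [batch, seq_len] original token ids at masked positions, -100 elsewhere
        logits = self.output_layer(memory)
        return self.criterion(logits.view(-1, logits.size(-1)), tgt_mlm.view(-1))
```
Since bert_target.py combines the masked language model with next sentence prediction, it also needs sentence- and document-level structure in the corpus, which is why different targets expect different corpus formats.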

If multiple machines are available, each machine contains a part of the corpus. The command is identical to the single-machine case.
@@ -236,21 +234,26 @@ usage: pretrain.py [-h] [--dataset_path DATASET_PATH] --vocab_path VOCAB_PATH
[--save_checkpoint_steps SAVE_CHECKPOINT_STEPS]
[--report_steps REPORT_STEPS]
[--accumulation_steps ACCUMULATION_STEPS]
[--batch_size BATCH_SIZE]
[--emb_size EMB_SIZE] [--hidden_size HIDDEN_SIZE]
[--batch_size BATCH_SIZE] [--emb_size EMB_SIZE]
[--hidden_size HIDDEN_SIZE]
[--feedforward_size FEEDFORWARD_SIZE]
[--kernel_size KERNEL_SIZE] [--heads_num HEADS_NUM]
[--layers_num LAYERS_NUM] [--dropout DROPOUT] [--seed SEED]
[--encoder {bert,lstm,gru,cnn,gatedcnn,attn,rcnn,crnn,gpt}]
[--bidirectional] [--target {bert,lm,cls,mlm,nsp,s2s}]
[--kernel_size KERNEL_SIZE] [--block_size BLOCK_SIZE]
[--heads_num HEADS_NUM] [--layers_num LAYERS_NUM]
[--dropout DROPOUT] [--seed SEED]
[--encoder {bert,lstm,gru,cnn,gatedcnn,attn,rcnn,crnn,gpt,bilstm}]
[--bidirectional] [--target {bert,lm,cls,mlm,nsp,s2s,bilm}]
[--labels_num LABELS_NUM] [--learning_rate LEARNING_RATE]
[--warmup WARMUP] [--world_size WORLD_SIZE]
[--warmup WARMUP] [--subword_type {none,char}]
[--sub_vocab_path SUB_VOCAB_PATH]
[--subencoder {avg,lstm,gru,cnn}]
[--sub_layers_num SUB_LAYERS_NUM] [--world_size WORLD_SIZE]
[--gpu_ranks GPU_RANKS [GPU_RANKS ...]]
[--master_ip MASTER_IP] [--backend {nccl,gloo}]
```

Notice that it is recommended to explicitly specify the model's encoder and target. UER-py consists of the following encoder modules:
- rnn_encoder.py: contains (bi-)LSTM and (bi-)GRU
- birnn_encoder.py: contains bi-LSTM and bi-GRU (different from rnn_encoder.py with --bidirectional, see [here](https://github.com/pytorch/pytorch/issues/4930) for more details and the sketch after this list)
- cnn_encoder.py: contains CNN and gatedCNN
- attn_encoder.py: contains attentionNN
- gpt_encoder.py: contains GPT encoder
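
The distinction mentioned for birnn_encoder.py can be sketched as follows. This is an illustrative comparison in plain PyTorch, with invented variable names and sizes rather than UER-py code: a single multi-layer `nn.LSTM` with `bidirectional=True` lets each upper layer see both directions of the layer below, whereas two independent unidirectional LSTMs keep the forward and backward passes fully separate.
```python
import torch
import torch.nn as nn

batch, seq_len, dim = 2, 5, 8
x = torch.randn(batch, seq_len, dim)

# Fused variant (conceptually, rnn_encoder.py with --bidirectional):
# one stacked LSTM where upper layers consume the concatenated
# forward+backward outputs of the layer below.
fused = nn.LSTM(dim, dim, num_layers=2, bidirectional=True, batch_first=True)
out_fused, _ = fused(x)  # [batch, seq_len, 2*dim]

# Split variant (conceptually, birnn_encoder.py): two independent
# unidirectional LSTMs, one reading left-to-right, one right-to-left.
fwd = nn.LSTM(dim, dim, num_layers=2, batch_first=True)
bwd = nn.LSTM(dim, dim, num_layers=2, batch_first=True)
out_fwd, _ = fwd(x)
out_bwd, _ = bwd(torch.flip(x, dims=[1]))
out_split = torch.cat([out_fwd, torch.flip(out_bwd, dims=[1])], dim=-1)  # [batch, seq_len, 2*dim]
```
Both yield a [batch, seq_len, 2*hidden] output, but the internal wiring (and therefore the learned representation) differs.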
@@ -321,8 +324,9 @@ python3 preprocess.py --corpus_path corpora/book_review.txt --vocab_path models/
python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_vocab.txt --pretrained_model_path models/google_model.bin --output_model_path models/output_model.bin \
--world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 --total_steps 20000 --save_checkpoint_steps 5000 --encoder bert --target mlm
```
Notice that different targets correspond to different corpus formats. It is important to select the proper format for a target.
If we want to change the encoder, the only thing we need to do is to specify --encoder in pretrain.py. Here is an example of using LSTM for pre-training.
*book_review.txt* (instead of *book_review_bert.txt*) is used as the training corpus when we use the MLM target. Different targets correspond to different corpus formats, so it is important to select the proper format for a target.
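
As an illustration of what "different corpus formats" means, here is a toy example with invented sentences; the layout shown is an assumption about the expected formats, not the contents of the actual corpus files. A BERT-style corpus keeps one sentence per line and separates documents with a blank line, while an LM/MLM-style corpus keeps one document per line. The `#` lines are labels, not part of the corpus.
```
# BERT-style corpus (bert target): one sentence per line, blank line between documents
This book is great.
I finished it in one night.

The plot is slow.
Not recommended.

# LM/MLM-style corpus (lm/mlm target): one document per line
This book is great. I finished it in one night.
The plot is slow. Not recommended.
```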

If we want to change the encoder, we need to specify *--encoder* and *--config_path* in pretrain.py. Here is an example of using LSTM for pre-training.
```
python3 preprocess.py --corpus_path corpora/book_review.txt --vocab_path models/google_vocab.txt --dataset_path dataset.pt --processes_num 8 --target lm
@@ -333,7 +337,7 @@ python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_vocab.t


### Fine-tune on downstream tasks
Currently, UER-py consists of 4 downstream tasks, i.e. classification, sequence labeling, cloze test, and feature extraction. The encoder of the downstream task should be consistent with the pre-trained model.
Currently, UER-py consists of the following downstream tasks: text classification, pair classification, document-based question answering, sequence labeling, and machine reading comprehension. The encoder of the downstream task should be consistent with the pre-trained model.

#### Classification
run_classifier.py adds two feedforward layers on top of the encoder layer.
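As a rough illustration of such a head (a hypothetical sketch with invented names, not the project's actual code), the encoder output is pooled according to *--pooling* and then passed through two linear layers:
```python
import torch
import torch.nn as nn

class ToyClassificationHead(nn.Module):
    """Two feedforward layers on top of a pooled encoder output (illustrative)."""
    def __init__(self, hidden_size, labels_num, pooling="first"):
        super().__init__()
        self.pooling = pooling
        self.output_layer_1 = nn.Linear(hidden_size, hidden_size)
        self.output_layer_2 = nn.Linear(hidden_size, labels_num)

    def forward(self, encoder_output):
        # encoder_output: [batch, seq_len, hidden_size]
        if self.pooling == "mean":
            pooled = encoder_output.mean(dim=1)
        elif self.pooling == "max":
            pooled = encoder_output.max(dim=1).values
        elif self.pooling == "last":
            pooled = encoder_output[:, -1, :]
        else:  # "first", e.g. the [CLS] position for BERT
            pooled = encoder_output[:, 0, :]
        hidden = torch.tanh(self.output_layer_1(pooled))
        return self.output_layer_2(hidden)  # logits over labels
```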
@@ -344,13 +348,13 @@ usage: run_classifier.py [-h] [--pretrained_model_path PRETRAINED_MODEL_PATH]
--dev_path DEV_PATH --test_path TEST_PATH
[--config_path CONFIG_PATH] [--batch_size BATCH_SIZE]
[--seq_length SEQ_LENGTH]
[--encoder {bert,lstm,gru,cnn,gatedcnn,attn,rcnn,crnn,gpt}]
[--encoder {bert,lstm,gru,cnn,gatedcnn,attn,rcnn,crnn,gpt,bilstm}]
[--bidirectional] [--pooling {mean,max,first,last}]
[--subword_type {none,char}]
[--sub_vocab_path SUB_VOCAB_PATH]
[--subencoder {avg,lstm,gru,cnn}]
[--sub_layers_num SUB_LAYERS_NUM]
[--tokenizer {bert,char,word,space}]
[--tokenizer {bert,char,space}]
[--learning_rate LEARNING_RATE] [--warmup WARMUP]
[--dropout DROPOUT] [--epochs_num EPOCHS_NUM]
[--report_steps REPORT_STEPS] [--seed SEED]
@@ -379,24 +383,23 @@ python3 run_classifier.py --pretrained_model_path models/google_model.bin --voca
run_ner.py adds two feedforward layers on top of the encoder layer.
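For sequence labeling the head is applied per token rather than to a pooled vector. A minimal, hypothetical sketch (invented names, not UER-py's actual code):
```python
import torch
import torch.nn as nn

class ToyTaggingHead(nn.Module):
    """Illustrative sequence-labeling head: per-token logits instead of pooling."""
    def __init__(self, hidden_size, labels_num):
        super().__init__()
        self.feedforward = nn.Linear(hidden_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size, labels_num)

    def forward(self, encoder_output):
        # encoder_output: [batch, seq_len, hidden_size]
        # returns logits of shape [batch, seq_len, labels_num], one label per token
        return self.output_layer(torch.relu(self.feedforward(encoder_output)))
```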
```
usage: run_ner.py [-h] [--pretrained_model_path PRETRAINED_MODEL_PATH]
[--output_model_path OUTPUT_MODEL_PATH]
[--vocab_path VOCAB_PATH] [--train_path TRAIN_PATH]
[--dev_path DEV_PATH] [--test_path TEST_PATH]
[--config_path CONFIG_PATH] [--batch_size BATCH_SIZE]
[--seq_length SEQ_LENGTH]
[--encoder {bert,lstm,gru,cnn,gatedcnn,attn,rcnn,crnn,gpt}]
[--bidirectional] [--subword_type {none,char}]
[--sub_vocab_path SUB_VOCAB_PATH]
[--subencoder {avg,lstm,gru,cnn}]
[--sub_layers_num SUB_LAYERS_NUM]
[--learning_rate LEARNING_RATE] [--warmup WARMUP]
[--dropout DROPOUT] [--epochs_num EPOCHS_NUM]
[--report_steps REPORT_STEPS] [--seed SEED]
[--output_model_path OUTPUT_MODEL_PATH]
[--vocab_path VOCAB_PATH] --train_path TRAIN_PATH --dev_path
DEV_PATH --test_path TEST_PATH [--config_path CONFIG_PATH]
[--batch_size BATCH_SIZE] [--seq_length SEQ_LENGTH]
[--encoder {bert,lstm,gru,cnn,gatedcnn,attn,rcnn,crnn,gpt,bilstm}]
[--bidirectional] [--subword_type {none,char}]
[--sub_vocab_path SUB_VOCAB_PATH]
[--subencoder {avg,lstm,gru,cnn}]
[--sub_layers_num SUB_LAYERS_NUM]
[--learning_rate LEARNING_RATE] [--warmup WARMUP]
[--dropout DROPOUT] [--epochs_num EPOCHS_NUM]
[--report_steps REPORT_STEPS] [--seed SEED]
```
An example of using run_ner.py:
```
python3 run_ner.py --pretrained_model_path models/google_model.bin --vocab_path models/google_vocab.txt \
--train_path datasets/msra/train.tsv --dev_path datasets/msra/dev.tsv --test_path datasets/msra/test.tsv \
--train_path datasets/msra_ner/train.tsv --dev_path datasets/msra_ner/dev.tsv --test_path datasets/msra_ner/test.tsv \
--epochs_num 5 --batch_size 32 --encoder bert
```

2 changes: 1 addition & 1 deletion run_classifier.py
@@ -98,7 +98,7 @@ def main():
parser.add_argument("--sub_layers_num", type=int, default=2, help="The number of subencoder layers.")

# Tokenizer options.
parser.add_argument("--tokenizer", choices=["bert", "char", "word", "space"], default="bert",
parser.add_argument("--tokenizer", choices=["bert", "char", "space"], default="bert",
help="Specify the tokenizer."
"Original Google BERT uses bert tokenizer on Chinese corpus."
"Char tokenizer segments sentences into characters."
2 changes: 1 addition & 1 deletion uer/trainer.py
@@ -60,7 +60,7 @@ def worker(gpu_id, gpu_ranks, args, model):
set_seed(args.seed)

if gpu_ranks is None:
train_loader = globals()[args.target.capitalize() + "DataLoader"](args, args.dataset_path, args.batch_size, gpu_id, 1, True)
train_loader = globals()[args.target.capitalize() + "DataLoader"](args, args.dataset_path, args.batch_size, 0, 1, True)
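        # Single-process case: the loader reads the dataset as shard 0 of 1 shard.
        # Passing gpu_id as the shard index here was the dataloader bug fixed in this
        # commit; gpu_id is used as the shard index only in the multi-process branch below.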
else:
train_loader = globals()[args.target.capitalize() + "DataLoader"](args, args.dataset_path, args.batch_size, gpu_id, len(gpu_ranks), True)

