Commit 6d884c5

first commit


52 files changed: +209,388 -0 lines

Diff for: .gitignore

+5
@@ -0,0 +1,5 @@
.idea
__pycache__
THUCNews/log/
THUCNews/saved_dict/*.ckpt

Diff for: ERNIE_predict.py

+85
@@ -0,0 +1,85 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
import os

import torch
import torch.nn as nn
from pytorch_pretrained import BertModel, BertTokenizer

# Mapping from label index to category name
key = {
    0: 'finance',
    1: 'realty',
    2: 'stocks',
    3: 'education',
    4: 'science',
    5: 'society',
    6: 'politics',
    7: 'sports',
    8: 'game',
    9: 'entertainment'
}


class Config:
    """Configuration parameters"""

    def __init__(self):
        cur = os.path.dirname(__file__)
        self.class_list = [str(i) for i in range(len(key))]  # class names
        self.save_path = os.path.join(cur, 'THUCNews/saved_dict/ERNIE.ckpt')
        self.device = torch.device('cpu')
        self.require_improvement = 1000  # stop training early if there is no improvement for 1000 batches
        self.num_classes = len(self.class_list)  # number of classes
        self.num_epochs = 3  # number of epochs
        self.batch_size = 128  # mini-batch size
        self.pad_size = 32  # sequence length: shorter inputs are padded, longer ones truncated
        self.learning_rate = 5e-5  # learning rate
        self.bert_path = os.path.join(cur, 'THUCNews/saved_dict/bert')
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.hidden_size = 768

    def build_dataset(self, text):
        lin = text.strip()
        token = self.tokenizer.tokenize(lin)
        token = ['[CLS]'] + token
        token_ids = self.tokenizer.convert_tokens_to_ids(token)
        # Build the attention mask from the actual token count so the two
        # returned tensors always have matching lengths
        mask = [1] * len(token_ids)
        return torch.tensor([token_ids], dtype=torch.long), torch.tensor([mask])


class Model(nn.Module):

    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path)
        for param in self.bert.parameters():
            param.requires_grad = True
        self.fc = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self, x):
        context = x[0]  # token ids, shape (batch, seq_len)
        mask = x[1]     # attention mask, shape (batch, seq_len)
        _, pooled = self.bert(context, attention_mask=mask, output_all_encoded_layers=False)
        out = self.fc(pooled)
        return out


config = Config()
model = Model(config).to(config.device)
model.load_state_dict(torch.load(config.save_path, map_location='cpu'))


def prediction_model(text):
    """Predict the category of a single input sentence."""
    data = config.build_dataset(text)
    with torch.no_grad():
        outputs = model(data)
        num = torch.argmax(outputs)
    return key[int(num)]


if __name__ == '__main__':
    print(prediction_model("备考2012高考作文必读美文50篇(一)"))
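
As a usage note, `prediction_model` returns only the top label. A small variant that also reports the softmax confidence might look like the sketch below; it builds on the module above and is not part of the commit:

```python
import torch.nn.functional as F

def prediction_with_confidence(text):
    """Sketch: same pipeline as prediction_model, but also return the
    softmax probability of the predicted class."""
    data = config.build_dataset(text)
    with torch.no_grad():
        outputs = model(data)                # shape: (1, num_classes)
        probs = F.softmax(outputs, dim=-1)
        conf, idx = torch.max(probs, dim=-1)
    return key[int(idx)], float(conf)
```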

Diff for: ERNIE_pretrain/README.md

+7
@@ -0,0 +1,7 @@
## Pretrained ERNIE model files go here:
pytorch_model.bin
bert_config.json
vocab.txt

## Download:
http://image.nghuyong.top/ERNIE.zip

Diff for: README.md

+210
@@ -0,0 +1,210 @@
# Chinese-Text-Classification

Chinese text classification, based on PyTorch, ready to use out of the box.

- Neural network models: TextCNN, TextRNN, FastText, TextRCNN, BiLSTM_Attention, DPCNN, Transformer

- Pretrained models: Bert, ERNIE

## Introduction

### Neural network models

Model descriptions and data flow: [reference](https://zhuanlan.zhihu.com/p/73176084)

Data is fed into the models character by character. The pretrained word vectors are [Sogou News Word+Character 300d](https://github.com/Embedding/Chinese-Word-Vectors), [download here](https://pan.baidu.com/s/14k-9jsspp43ZhMxqPmsWMQ)

| Model | Description |
| ----------- | --------------------------------- |
| TextCNN | Kim 2014, the classic CNN for text classification |
| TextRNN | BiLSTM |
| TextRNN_Att | BiLSTM + attention |
| TextRCNN | BiLSTM + pooling |
| FastText | bow + bigram + trigram; surprisingly effective |
| DPCNN | deep pyramid CNN |
| Transformer | relatively poor results |

### Pretrained models

| Model | Description | Note |
| ---------- | ------------------------------------------------------------ | ------------ |
| bert | the original BERT | |
| ERNIE | ERNIE | |
| bert_CNN | BERT as the embedding layer, followed by a CNN with three kernel sizes | bert + CNN |
| bert_RNN | BERT as the embedding layer, followed by an LSTM | bert + RNN |
| bert_RCNN | BERT as the embedding layer; the LSTM output is concatenated with the BERT output and passed through a max-pooling layer | bert + RCNN |
| bert_DPCNN | BERT as the embedding layer, followed by a region-embedding layer built from three different convolutional feature extractors (its output can be seen as a richer embedding), then two equal-length convolutions that widen the receptive field for the feature extraction that follows. A residual block with 1/2 pooling is then applied repeatedly: each pooling step raises the semantic level of the word positions while the number of feature maps stays fixed, and the residual connections counter vanishing and exploding gradients during training (see the sketch below). | bert + DPCNN |
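
To make the bert_DPCNN description concrete, here is a minimal PyTorch sketch of the repeated 1/2-pooling residual block. It illustrates the idea only and is not the repository's actual implementation; the class name, the `num_filters` default, and the 1-D convolution layout are assumptions:

```python
import torch.nn as nn
import torch.nn.functional as F


class DPCNNBlock(nn.Module):
    """One DPCNN stage: 1/2 max pooling, then two equal-length
    convolutions with a residual connection (illustrative sketch)."""

    def __init__(self, num_filters=250):  # feature maps stay fixed across stages
        super(DPCNNBlock, self).__init__()
        # stride-2 pooling halves the sequence length, so each surviving
        # position summarizes a wider span of the input
        self.pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        # "equal-length" convolutions: padding=1 keeps the length unchanged
        self.conv1 = nn.Conv1d(num_filters, num_filters, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(num_filters, num_filters, kernel_size=3, padding=1)

    def forward(self, x):  # x: (batch, num_filters, seq_len)
        x = self.pool(x)
        residual = x
        x = self.conv1(F.relu(x))
        x = self.conv2(F.relu(x))
        return x + residual  # shortcut counters vanishing/exploding gradients
```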
References:

- [ERNIE explained](https://baijiahao.baidu.com/s?id=1648169054540877476)
- [DPCNN explained](https://zhuanlan.zhihu.com/p/372904980)
- [From the classic TextCNN to the deep DPCNN](https://zhuanlan.zhihu.com/p/35457093)

## Environment
python 3.7
pytorch 1.1
tqdm
sklearn
tensorboardX
~~pytorch_pretrained_bert~~ (the pretraining code is bundled with this repo, so this library is no longer needed)

## Chinese dataset
I extracted 200,000 news headlines from [THUCNews](http://thuctc.thunlp.org/) (already uploaded to GitHub); each headline is 20 to 30 characters long. There are 10 categories with 20,000 headlines each. Data is fed into the models character by character.

Categories: finance, realty, stocks, education, science, society, politics, sports, game, entertainment.

Dataset split:

Dataset|Size
--|--
Train|180,000
Validation|10,000
Test|10,000

### Using your own dataset
- Format your own Chinese dataset the same way as the THUCNews dataset; a conversion sketch follows this list.
- For the neural network models:
  - If working at the character level, format your data exactly like the provided dataset.
  - If working at the word level, segment the text in advance with words separated by spaces, then run `python run.py --model TextCNN --word True`.
- To use pretrained word vectors: the main function in utils.py can extract the pretrained vectors that match your vocabulary.
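
As an illustration, each line of the THUCNews-style data files pairs a sentence with its numeric label. The tab-separated "text, then label index" layout and the file and column names below are assumptions for this sketch, with label indices following THUCNews/data/class.txt:

```python
import csv

# Category order as listed in THUCNews/data/class.txt
classes = ['finance', 'realty', 'stocks', 'education', 'science',
           'society', 'politics', 'sports', 'game', 'entertainment']
label_to_index = {name: i for i, name in enumerate(classes)}

# Hypothetical input: a CSV with "text" and "label_name" columns.
# Output: one "<text>\t<label index>" line per sample.
with open('my_data.csv', encoding='utf-8') as src, \
        open('train.txt', 'w', encoding='utf-8') as dst:
    for row in csv.DictReader(src):
        dst.write(f"{row['text']}\t{label_to_index[row['label_name']]}\n")
```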
## Results

Hardware: a single 2080Ti; training time: 30 minutes.

Model|Accuracy|Note
--|--|--
TextCNN|91.22%|Kim 2014, the classic CNN for text classification
TextRNN|91.12%|BiLSTM
TextRNN_Att|90.90%|BiLSTM + attention
TextRCNN|91.54%|BiLSTM + pooling
FastText|92.23%|bow + bigram + trigram; surprisingly effective
DPCNN|91.25%|deep pyramid CNN
Transformer|89.91%|relatively poor results
bert|94.83%|plain BERT
ERNIE|94.61%|so much for crushing BERT on Chinese
bert_CNN|94.44%|bert + CNN
bert_RNN|94.57%|bert + RNN
bert_RCNN|94.51%|bert + RCNN
bert_DPCNN|94.47%|bert + DPCNN

Plain BERT already performs very well; feeding BERT into the other models as an embedding layer actually lowers accuracy. A comparison on long texts will be tried next.

## Pretrained language models
The bert model goes in the bert_pretain directory and the ERNIE model in the ERNIE_pretrain directory; each directory contains three files:
- pytorch_model.bin
- bert_config.json
- vocab.txt

Pretrained model download links:
bert_Chinese: model https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz
vocabulary https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt
from [here](https://github.com/huggingface/pytorch-transformers)
Backup netdisk link for the model: https://pan.baidu.com/s/1qSAD5gwClq7xlgzl_4W3Pw

ERNIE_Chinese: https://pan.baidu.com/s/1lEPdDN1-YQJmKEd_g9rLgw

from [here](https://github.com/nghuyong/ERNIE-Pytorch)

After unzipping, place the files in the corresponding directories as described above and make sure the file names match.

## Usage

### Neural network models

```
# Train and test:
# TextCNN
python run.py --model TextCNN

# TextRNN
python run.py --model TextRNN

# TextRNN_Att
python run.py --model TextRNN_Att

# TextRCNN
python run.py --model TextRCNN

# FastText, the embedding layer is randomly initialized
python run.py --model FastText --embedding random

# DPCNN
python run.py --model DPCNN

# Transformer
python run.py --model Transformer
```

### Pretrained models

Once the pretrained models are downloaded, you can run:
```
# Train and test with a pretrained model:
# bert
python pretrain_run.py --model bert

# bert + other
python pretrain_run.py --model bert_CNN

# ERNIE
python pretrain_run.py --model ERNIE
```

### Prediction

Pretrained models:

```
# bert (+ variants)
python bert_predict.py

# ERNIE
python ERNIE_predict.py
```

Neural network models:

```
Todo:
```

### Parameters
All models live in the models directory; hyperparameters are defined in the same file as the model itself.

## References

### Papers

[1] Convolutional Neural Networks for Sentence Classification

[2] Recurrent Neural Network for Text Classification with Multi-Task Learning

[3] Attention-Based Bidirectional Long Short-Term Memory Networks for Relation Classification

[4] Recurrent Convolutional Neural Networks for Text Classification

[5] Bag of Tricks for Efficient Text Classification

[6] Deep Pyramid Convolutional Neural Networks for Text Categorization

[7] Attention Is All You Need

[8] BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding

[9] ERNIE: Enhanced Representation through Knowledge Integration

### Repositories

This project builds on and extends the following repositories:

- https://github.com/649453932/Chinese-Text-Classification-Pytorch
- https://github.com/649453932/Bert-Chinese-Text-Classification-Pytorch

Diff for: THUCNews/data/class.txt

+10
@@ -0,0 +1,10 @@
finance
realty
stocks
education
science
society
politics
sports
game
entertainment
