-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
9,065 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# 基于bert4keras的抽取式MRC基准代码 | ||
## 简介 | ||
本仓库是基于BERT4keras的抽取式MRC问答基础代码。详细介绍请看博客:https://kexue.fm/archives/8739 | ||
|
||
## 文件介绍 | ||
### datasets | ||
|
||
该文件夹存放的是抽取式MRC数据集,分为训练集(train.json)、验证集(dev.json)以及测试集(test.json)。训练集和验证集是对模型在 下游任务中进行微调,使其可以学习到该领域的数据特征,模型在训练集和验证集上训练完成后,会生成一个模型的权重信息即xxx.weights。然后通过使用模型生成的权重信息,在测试集上进行相应的测试。 | ||
|
||
训练集的格式如下图所示: | ||
|
||
 | ||
|
||
### model | ||
|
||
该文件存放的是预训练模型,可根据自己需要选择相应的预训练模型,其中有bert、Roberta以及wwm等。 | ||
|
||
### src | ||
|
||
该文件下存放的是源代码,其中`cmrc2018.py`是实现抽取式MRC的源代码,主要包含以下几部分: | ||
|
||
- 加载数据集,生成每个batch_size的数据,主要由`load_data`函数和`data_generator`类进行实现 | ||
|
||
- 构建模型,该部分在代码有注释说明,一般情况下不改变模型的网络结构就不需要进行更改。 | ||
- 开始训练模型,并保存验证集中准确度最好的模型 | ||
- 测试模型效果 | ||
|
||
`snippets.py`是模型的配置文件,用来配置数据集的路径、模型的通用参数以及预训练模型的路径等。 | ||
|
||
`cmrc2018_evaluate.py`是用来测试模型生成答案的EM指标和F1指标。 | ||
|
||
`weights`文件夹用来保存模型生成的权重信息。 | ||
|
||
`results`文件夹用来保存模型测试生成的答案。 | ||
|
||
## 使用步骤 | ||
|
||
- 配置环境。需要的包已经列在`pip_requirements.txt`和`conda_requirements.txt`中。 | ||
|
||
``` | ||
conda create -n bert4keras python=3.6 | ||
source activate bert4keras | ||
``` | ||
|
||
- 下载预训练模型。这里提供一个基础的[BERT模型下载链接](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip)。将其下载放入model文件夹中。您也可以下载其他的预训练模型。 | ||
|
||
- 准备好数据集。请注意,如果您需要在下游任务中进行微调,请准备好train.json文件和dev.json文件。 | ||
|
||
例如,您需要在反恐领域运行该代码,并且期望模型表现较好,您需要首先准备好反恐领域的文本,每一条文本数据的长度介于150字到900字之间,然后针对于每条文本,提出三到五个问题,并在文本中找出相应的答案,同时给出答案首次出现在文中的序号。所有的文本都标注完成之后,将其处理成相应的数据集格式即可。 | ||
|
||
如果您只想让模型通过问题和文本,预测出相应的答案,您只需要准备好test.json文件夹,然后在通用领域抽取式阅读理解数据集(如:cmrc2018)上进行微雕,保存相应的权重信息。最后使用模型进行预测即可。 | ||
|
||
- 运行代码 | ||
|
||
``` | ||
python cmrc2018.py | ||
``` | ||
|
||
## 环境 | ||
- 软件:bert4keras>=0.10.8,具体请看`pip_requirements.txt`和`conda_requirements.txt` | ||
- 硬件:显存不够,可以适当降低batch_size,如果有多GPU,可以开启多GPU进行训练 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import pandas as pd | ||
import json | ||
import torch | ||
from transformers import AutoTokenizer,AutoModelForQuestionAnswering,pipeline | ||
from flask import Flask, request | ||
import multiprocessing as mp | ||
import concurrent.futures | ||
|
||
app = Flask(__name__) | ||
|
||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
|
||
# @app.route('/qa', methods=['POST']) | ||
# def predict(): | ||
# 加载预训练模型 | ||
model = AutoModelForQuestionAnswering.from_pretrained('./model/chinese_pretrain_mrc_roberta_wwm_ext_large') | ||
tokenizer = AutoTokenizer.from_pretrained('./model/chinese_pretrain_mrc_roberta_wwm_ext_large') | ||
qa_pipeline = pipeline("question-answering",model=model,tokenizer=tokenizer,device=device) | ||
|
||
# sentence = request.form.get("key") | ||
# 加载数据集 | ||
df = pd.read_json("datasets/contend.json") | ||
|
||
# 转换数据集格式 | ||
dataset = [] | ||
|
||
# 创建一个生成器表达式 | ||
|
||
# dataset = ({'context': p['context'], 'question': q['question']} for p in df['data'] for q in p['paragraphs'] for q in p['qas']) | ||
|
||
dataset = ({'context': q['context'], 'question': k['question']} for p in df['data'] for q in p['paragraphs'] for k in q['qas']) | ||
|
||
|
||
# 使用pipeline方法进行预测 | ||
def qa(data): | ||
results = qa_pipeline(data) | ||
return results | ||
|
||
# 使用进程池并行处理数据集中的每个问题并设置chunsize使用较小的批处理大小 | ||
# with mp.Pool() as pool: | ||
# results = pool.imap(qa, dataset,chunksize=10) | ||
|
||
# 使用线程池并行处理数据集中的每个问题并设置chunsize使用较小的批处理大小 | ||
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool: | ||
results = pool.map(qa, dataset,chunksize=10) | ||
|
||
# 将结果保存到文件 | ||
output = [] | ||
for idx,result in enumerate(results): | ||
output.append({"idx":idx+1,"answer":result["answer"],"score":result["score"]}) | ||
|
||
with open("results.json",'w',encoding="utf-8") as f: | ||
json.dump(output,f,ensure_ascii=False,indent=4) | ||
|
||
# if __name__ == '__main__': | ||
# app.run( | ||
# host='0.0.0.0', | ||
# port=6000 | ||
# ) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
#-*- coding:utf-8 -*- | ||
import pandas as pd | ||
import json | ||
import torch | ||
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline | ||
from flask import Flask, request | ||
import multiprocessing as mp | ||
import concurrent.futures | ||
import logging | ||
|
||
app = Flask(__name__) | ||
app.logger.setLevel(logging.DEBUG) | ||
|
||
# 创建一个日志处理器,输出到控制台 | ||
handler = logging.StreamHandler() | ||
handler.setLevel(logging.DEBUG) | ||
# 设置日志格式 | ||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') | ||
handler.setFormatter(formatter) | ||
|
||
# 将日志处理器添加到应用程序日志处理器列表中 | ||
app.logger.addHandler(handler) | ||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
|
||
# @app.route('/qa', methods=['POST']) | ||
# def predict(): | ||
# 加载预训练模型 | ||
model = AutoModelForQuestionAnswering.from_pretrained('./model/chinese_pretrain_mrc_roberta_wwm_ext_large') | ||
tokenizer = AutoTokenizer.from_pretrained('./model/chinese_pretrain_mrc_roberta_wwm_ext_large') | ||
qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer, device=device) | ||
|
||
# sentence = request.form.get("key") | ||
# 加载数据集 | ||
|
||
|
||
|
||
# 使用pipeline方法进行预测 | ||
def qa(data): | ||
results = qa_pipeline(data) | ||
return results | ||
|
||
|
||
# 使用进程池并行处理数据集中的每个问题并设置chunsize使用较小的批处理大小 | ||
# with mp.Pool() as pool: | ||
# results = pool.imap(qa, dataset,chunksize=10) | ||
|
||
# 使用线程池并行处理数据集中的每个问题并设置chunsize使用较小的批处理大小 | ||
@app.route("/qarob",methods=["POST"]) | ||
def run(): | ||
file_path = request.form.get("in_file") | ||
df = pd.read_json(file_path) | ||
# 转换数据集格式 | ||
dataset = ({'context': q['context'], 'question': k['question']} for p in df['data'] for q in p['paragraphs'] for k | ||
in | ||
q['qas']) | ||
# 创建一个包含所有id的列表 | ||
ids = (k['id'] for p in df['data'] for q in p['paragraphs'] for k in q['qas']) | ||
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool: | ||
results = pool.map(qa, dataset, chunksize=10) | ||
return_result = {} | ||
for idx, result in zip(ids, results): | ||
if result["answer"] is not None: | ||
if result["score"] > 0.00: | ||
return_result[idx] = result["answer"] | ||
else: | ||
return_result[idx] = "未找到相关信息" | ||
app.logger.info(result["score"]) | ||
app.logger.info(idx) | ||
app.logger.info(return_result) | ||
return json.dumps(return_result,ensure_ascii=False) | ||
|
||
|
||
|
||
# 将结果保存到文件 | ||
|
||
def result_(results): | ||
output = [] | ||
for idx, result in enumerate(results): | ||
output.append({"idx": idx + 1, "answer": result["answer"], "score": result["score"]}) | ||
|
||
with open("results.json", 'w', encoding="utf-8") as f: | ||
json.dump(output, f, ensure_ascii=False, indent=4) | ||
|
||
|
||
if __name__ == '__main__': | ||
app.run( | ||
host='0.0.0.0', | ||
port=6005 | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import pandas as pd | ||
import json | ||
import torch | ||
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline | ||
from flask import Flask, request | ||
import multiprocessing as mp | ||
import concurrent.futures | ||
|
||
app = Flask(__name__) | ||
|
||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
|
||
# @app.route('/qa', methods=['POST']) | ||
# def predict(): | ||
# 加载预训练模型 | ||
model = AutoModelForQuestionAnswering.from_pretrained('./model/chinese_pretrain_mrc_roberta_wwm_ext_large') | ||
tokenizer = AutoTokenizer.from_pretrained('./model/chinese_pretrain_mrc_roberta_wwm_ext_large') | ||
qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer, device=device) | ||
|
||
# sentence = request.form.get("key") | ||
# 加载数据集 | ||
|
||
|
||
|
||
# 使用pipeline方法进行预测 | ||
def qa(data): | ||
results = qa_pipeline(data) | ||
return results | ||
|
||
|
||
# 使用进程池并行处理数据集中的每个问题并设置chunsize使用较小的批处理大小 | ||
# with mp.Pool() as pool: | ||
# results = pool.imap(qa, dataset,chunksize=10) | ||
|
||
# 使用线程池并行处理数据集中的每个问题并设置chunsize使用较小的批处理大小 | ||
@app.route("/qaroberta") | ||
def run(): | ||
file_path = request.form.get("file_path") | ||
df = pd.read_json(file_path) | ||
# 转换数据集格式 | ||
dataset = ({'context': q['context'], 'question': k['question']} for p in df['data'] for q in p['paragraphs'] for k | ||
in | ||
q['qas']) | ||
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool: | ||
results = pool.map(qa, dataset, chunksize=10) | ||
return_result = {} | ||
for idx, result in enumerate(results): | ||
if result["answer"] is not None: | ||
return_result["answer"] = result["answer"] | ||
return json.dumps(return_result) | ||
|
||
|
||
|
||
# 将结果保存到文件 | ||
|
||
def result(results): | ||
output = [] | ||
for idx, result in enumerate(results): | ||
output.append({"idx": idx + 1, "answer": result["answer"], "score": result["score"]}) | ||
|
||
with open("results.json", 'w', encoding="utf-8") as f: | ||
json.dump(output, f, ensure_ascii=False, indent=4) | ||
|
||
|
||
if __name__ == '__main__': | ||
app.run( | ||
host='0.0.0.0', | ||
port=6005 | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
{ | ||
"version" : null, | ||
"data" : [ { | ||
"paragraphs" : [ { | ||
"id" : null, | ||
"context": "名称湖北景深安全技术有限公司注册资本壹纤万圆整类型有限责任公司(自然人投资或控股)成立日期2007年07月13日法定代表人黄兆云营业期限长期经营范围净可项目:安全评价业务:安全生产检验检测,检验检测服务:住职业卫生技术课务:空内环填检测:消防技术服务(依法须经械所宜昌市西陵区滑河四路86号准的项目,经相关密门设准后方可开展经营活动,其体经我项目以相关部准文件或许可证件为准)股项目:安全咨询服务,标准化服务,平境保护监润:环保咨淘服务:主第环现污染防治里务:大利相关咨询服务:信息系统集成服务:网终与信息安全款样开发:款件开发:安全技本防范系统设计工服务:信易安全设销售:电子产品销售(除许司)业务外,可自主依法经营法神法规非禁止或限制的项门)登记机关2021年5月18日营业执照统一社会信用代码二作科骨泵国91420500662296752N全业EN7家更便记各室,奔用速能信2,名称湖北景深安全技术有限公司注册资本壹竹万圆整类型有限责任公司(自然人投资或控股)成立日期2007年07月13日法定代表人黄兆云营业期限长期经营范n本可项目:安全评价业务:安全生产检验检测:检验检测服务:住职业卫生技术票务:室内环境检测:消防技术服务(依法须经批所宜昌市西陵区滑河四路86号准的项目,经相关密门批准品方可开展经营活动,具体经营项目以相关部(批准文件或许可证件为准)股项目:安全咨询赚务:标准化服务:平境保护监润:环保咨询服务:主靠环境污染防治服务:水利相关咨询服务:信息系统集成服务:网终与信息安全软件开发:软性开发:安全技本防范系统设计族工服务:信息安全设备销售:电子产品销售(除许可业务外,可自主依法经营法神法规非禁止或限制的项门)登记机关2021年0月18日N国家市场监督产理总监制", | ||
"qas" : [ { | ||
"question" : "申请日期是什么时候?用yyyy-MM-dd的格式回答,比如2022-01-28,不按格式给零分", | ||
"id" : "q1", | ||
"answers" : [ { | ||
"text" : "", | ||
"answer_start" : 0 | ||
} ] | ||
}, | ||
{ | ||
"question" : "这段文本含有长期两个字吗", | ||
"id" : "q2", | ||
"answers" : [ { | ||
"text" : "", | ||
"answer_start" : 0 | ||
} ] | ||
} | ||
|
||
] | ||
} ], | ||
"id" : null, | ||
"title" : null | ||
} ] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
certifi @ file:///croot/certifi_1671487769961/work/certifi | ||
charset-normalizer==3.1.0 | ||
click==8.1.3 | ||
filelock==3.10.7 | ||
Flask==2.2.3 | ||
huggingface-hub==0.13.3 | ||
idna==3.4 | ||
importlib-metadata==6.1.0 | ||
itsdangerous==2.1.2 | ||
Jinja2==3.1.2 | ||
MarkupSafe==2.1.2 | ||
numpy==1.24.2 | ||
packaging==23.0 | ||
pandas==1.5.3 | ||
Pillow==9.4.0 | ||
python-dateutil==2.8.2 | ||
pytz==2023.3 | ||
PyYAML==6.0 | ||
regex==2023.3.23 | ||
requests==2.28.2 | ||
sentencepiece==0.1.97 | ||
six==1.16.0 | ||
tokenizers==0.13.2 | ||
torch==1.8.1+cu102 | ||
torchaudio==0.8.1 | ||
torchvision==0.9.1+cu102 | ||
tqdm==4.65.0 | ||
transformers==4.27.3 | ||
typing_extensions==4.5.0 | ||
urllib3==1.26.15 | ||
Werkzeug==2.2.3 | ||
zipp==3.15.0 |
Oops, something went wrong.