-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict_qa_roberta.py.bak
70 lines (52 loc) · 2.13 KB
/
predict_qa_roberta.py.bak
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import pandas as pd
import json
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline
from flask import Flask, request
import multiprocessing as mp
import concurrent.futures
# Flask application object; the HTTP routes in this file are registered on it.
app = Flask(__name__)

# Run inference on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Disabled earlier endpoint, kept for reference:
# @app.route('/qa', methods=['POST'])
# def predict():
#     sentence = request.form.get("key")

# Load the pretrained model and its tokenizer once at import time; both come
# from the same local checkpoint directory.
_MODEL_DIR = './model/chinese_pretrain_mrc_roberta_wwm_ext_large'
model = AutoModelForQuestionAnswering.from_pretrained(_MODEL_DIR)
tokenizer = AutoTokenizer.from_pretrained(_MODEL_DIR)

# Question-answering pipeline shared by every request handled by this module.
qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer, device=device)
def qa(data):
    """Run the module-level question-answering pipeline on ``data``.

    ``data`` is forwarded to ``qa_pipeline`` unchanged; the caller in this
    file passes a dict with ``'context'`` and ``'question'`` keys.

    Returns:
        Whatever ``qa_pipeline`` returns for ``data``.
    """
    return qa_pipeline(data)
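# Process every question in the dataset in parallel with a process pool, using chunksize to hand out small batches (disabled):
# with mp.Pool() as pool:
# results = pool.imap(qa, dataset,chunksize=10)
# Process every question in the dataset in parallel with a thread pool, using chunksize to hand out small batches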
# POST is accepted in addition to the default GET because the handler reads
# its input from form data, which clients normally send with POST.
@app.route("/qaroberta", methods=["GET", "POST"])
def run():
    """Answer every question found in a SQuAD-style JSON file.

    Reads the ``file_path`` form field, loads that JSON file with pandas,
    builds one ``{'context', 'question'}`` input per entry under
    ``data[*].paragraphs[*].qas[*]``, and runs ``qa`` on each input using a
    pool of four worker threads.

    Returns:
        A JSON string. ``"answers"`` lists every non-``None`` answer in input
        order. ``"answer"`` holds the last of them (the value this endpoint
        has always returned) and is omitted when there is no answer at all.
        When ``file_path`` is missing, a JSON error message with HTTP status
        400 is returned instead.
    """
    file_path = request.form.get("file_path")
    if not file_path:
        # Fail with a clear client error instead of letting pandas raise on None.
        return json.dumps({"error": "missing required form field 'file_path'"}), 400

    # NOTE(security): file_path comes straight from the client and is passed
    # to pandas unchecked, so a caller can make the server read any location
    # pandas is able to open. Restrict it to a trusted directory before
    # exposing this endpoint to untrusted callers.
    df = pd.read_json(file_path)

    # Flatten the nested article -> paragraph -> question structure into
    # pipeline inputs.
    dataset = [
        {'context': paragraph['context'], 'question': question['question']}
        for article in df['data']
        for paragraph in article['paragraphs']
        for question in paragraph['qas']
    ]

    # chunksize is not passed: it has no effect on ThreadPoolExecutor.map.
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool:
        predictions = list(pool.map(qa, dataset))

    answers = [
        prediction["answer"]
        for prediction in predictions
        if prediction["answer"] is not None
    ]

    return_result = {"answers": answers}
    if answers:
        # Keep the original key (last answer wins) so existing clients still work.
        return_result["answer"] = answers[-1]
    return json.dumps(return_result)
# Save the results to a file
def result(results, output_path="results.json"):
    """Write question-answering predictions to a JSON file.

    Args:
        results: Iterable of mappings that each provide ``"answer"`` and
            ``"score"`` entries.
        output_path: Destination file. Defaults to ``"results.json"`` in the
            current working directory, which is where this function has
            always written.

    The file holds a JSON list with one object per prediction, in input
    order, with keys ``"idx"`` (1-based position), ``"answer"`` and
    ``"score"``. Non-ASCII text is written as is rather than ``\\u``-escaped,
    and an existing file at ``output_path`` is overwritten.
    """
    output = [
        {"idx": position, "answer": item["answer"], "score": item["score"]}
        for position, item in enumerate(results, start=1)
    ]
    with open(output_path, 'w', encoding="utf-8") as f:
        json.dump(output, f, ensure_ascii=False, indent=4)
if __name__ == '__main__':
    # When executed as a script, start Flask's built-in server bound to all
    # network interfaces on port 6005.
    app.run(host='0.0.0.0', port=6005)