-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathget_question_entity.py
More file actions
58 lines (44 loc) · 2.08 KB
/
get_question_entity.py
File metadata and controls
58 lines (44 loc) · 2.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
from typing import List, Dict
from JudgeAgent import *
from JudgeAgent.label_entity import label_entity_for_texts
def _label_question(qdata: Dict, client) -> Dict:
    """Return a shallow copy of *qdata* with an ``"entities"`` key added.

    Reads ``qdata["question"]`` and ``qdata["area"]``, labels that single
    question text with ``label_entity_for_texts`` and stores the first
    (only) element of the result under ``"entities"``.
    """
    labeled_entities = label_entity_for_texts([qdata["question"]], client, qdata["area"])
    return {**qdata, "entities": labeled_entities[0]}


def _save_checkpoint(records: List[Dict], save_path: str) -> None:
    """Write *records* to *save_path* without leaving a truncated file behind.

    The list is dumped to a sibling temporary file and then moved over
    *save_path* with ``os.replace``. An interruption during the dump
    therefore leaves the previous checkpoint intact, instead of a
    half-written file that ``load_json`` could not parse on resume.
    """
    tmp_path = save_path + ".tmp"
    # NOTE(review): assumes dump_json writes to exactly the path it is given.
    dump_json(records, tmp_path)
    # Same directory => same filesystem, so the replace is a single rename.
    os.replace(tmp_path, save_path)


if __name__ == "__main__":
    # Label entities on every question in processed_data/<data>/questions.json
    # and store the results in question_with_entities.json, checkpointing
    # after each item so an interrupted run resumes where it stopped.
    import argparse

    from tqdm import tqdm

    parser = argparse.ArgumentParser()
    parser.add_argument("--data", type=str, default="MedQA")
    parser.add_argument("--model", type=str, default="gpt")
    args = parser.parse_args()

    data_name: str = args.data
    data_dir = os.path.join("processed_data", data_name)
    save_path = os.path.join(data_dir, "question_with_entities.json")

    questions = load_json(os.path.join(data_dir, "questions.json"))
    # Resume support: items already present in save_path are not re-labeled.
    question_with_entities: List[Dict] = load_json(save_path) if os.path.exists(save_path) else []
    client = LLMClient(MODEL_PARAMS[args.model])

    start = len(question_with_entities)
    is_quality = data_name.lower() == "quality"
    with tqdm(total=len(questions), initial=start, desc="label entity on question") as pbar:
        for qdata in questions[start:]:
            if is_quality:
                # "quality" items hold an "article" plus a list of "questions";
                # each question in the list is labeled individually.
                record = {
                    # NOTE(review): "quetions" looks like a typo of "questions";
                    # kept as-is because existing output files / downstream
                    # readers may depend on this key. Confirm before renaming.
                    "quetions": [_label_question(q, client) for q in qdata["questions"]],
                    "article": qdata["article"],
                }
            else:
                record = _label_question(qdata, client)
            question_with_entities.append(record)
            # Rewrite the whole file after every item so progress survives a crash.
            _save_checkpoint(question_with_entities, save_path)
            pbar.update(1)