-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconstruct_graph.py
More file actions
84 lines (60 loc) · 3 KB
/
construct_graph.py
File metadata and controls
84 lines (60 loc) · 3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from typing import List, Dict, Set
from JudgeAgent.label_entity import Entity, label_entity_for_batch
from JudgeAgent import *
def load_entity_dict(path: str) -> Dict[int, Dict[int, Set[Entity]]]:
    """Load a previously saved entity dictionary from ``path``.

    The file is expected to hold a two-level mapping whose leaves are lists of
    serialized entities (the layout written by ``save_entity_dict``). Every
    leaf item is rebuilt with ``Entity.from_dict`` and the items of one leaf
    are collected into a set.

    JSON object keys are always strings, so ids that were integers when the
    dictionary was saved come back as ``"0"``, ``"1"``, ... (assuming
    ``load_json`` is a plain JSON loader -- it is defined outside this file).
    The keys are converted back to ``int`` here so that a reloaded dictionary
    matches the declared ``Dict[int, ...]`` type and uses the same key type as
    entries merged in afterwards. Keys that do not parse as integers are kept
    unchanged.

    Args:
        path: Location of the saved entity dictionary.

    Returns:
        The restored mapping, or an empty dict when ``path`` does not exist.
    """

    def restore_key(key):
        # Best effort: undo the int -> str coercion applied by JSON encoding.
        try:
            return int(key)
        except (TypeError, ValueError):
            return key

    entity_dict: Dict[int, Dict[int, Set[Entity]]] = {}
    if not os.path.exists(path):
        # Nothing saved yet: start from an empty dictionary.
        return entity_dict

    data = load_json(path)
    for pid, pdict in data.items():
        entity_dict[restore_key(pid)] = {
            restore_key(sid): {Entity.from_dict(se) for se in sentities}
            for sid, sentities in pdict.items()
        }
    return entity_dict
def save_entity_dict(entity_dict: Dict[int, Dict[int, Set[Entity]]], path: str):
    """Serialize ``entity_dict`` and write it to ``path`` via ``dump_json``.

    Each set of entities is turned into a list of plain dictionaries with
    ``Entity.to_dict``; the two-level key structure is kept as is.

    Args:
        entity_dict: Two-level mapping whose leaves are sets of entities.
        path: Destination handed to ``dump_json``.
    """
    serializable: Dict[int, Dict[int, List[Dict]]] = {
        pid: {
            sid: [entity.to_dict() for entity in entities]
            for sid, entities in inner.items()
        }
        for pid, inner in entity_dict.items()
    }
    dump_json(serializable, path)
def refresh_entity_dict(base_dict: Dict[int, Dict[int, Set[Entity]]], add_dict: Dict[int, Dict[int, List[Entity]]]) -> Dict[int, Dict[int, Set[Entity]]]:
    """Merge ``add_dict`` into ``base_dict`` and return ``base_dict``.

    ``base_dict`` is modified in place. For every ``(pid, sid)`` pair found in
    ``add_dict`` the incoming entities are added to the matching set in
    ``base_dict``; missing outer or inner keys are created on the way.

    Args:
        base_dict: Accumulated mapping; updated in place.
        add_dict: New entries whose leaves are iterables of entities.

    Returns:
        The same ``base_dict`` object that was passed in.
    """
    for pid, incoming in add_dict.items():
        # The outer key is created even when ``incoming`` is empty.
        merged = base_dict.setdefault(pid, {})
        for sid, entities in incoming.items():
            merged.setdefault(sid, set()).update(entities)
    return base_dict
def main() -> None:
    """Label entities for every corpus batch, then build and save the graph.

    Command-line options:
        --data   dataset folder name under ``processed_data`` (default "MedQA")
        --model  key into ``MODEL_PARAMS`` used to build the LLM client
        --bs     batch size used for batching and labelling (default 10)

    The run is resumable: after each batch the accumulated entity dictionary
    and the number of finished batches are written to disk, and a later run
    continues from the stored batch index.
    """
    import argparse

    from tqdm import tqdm

    parser = argparse.ArgumentParser()
    parser.add_argument("--data", type=str, default="MedQA")
    parser.add_argument("--model", type=str, default="gpt")
    parser.add_argument("--bs", type=int, default=10)
    args = parser.parse_args()

    data_name: str = args.data
    batch_size: int = args.bs

    data_dir = os.path.join("processed_data", data_name)
    graph = Graph(data_dir)

    chunks = load_jsonl(os.path.join(data_dir, "corpus_chunks.jsonl"))
    batches = chunk_to_batch(chunks, batch_size=batch_size)

    progress_path = os.path.join("temp_progress", data_name, "label_entity_progress_index.json")
    # Nothing else in this script creates the checkpoint directory. Creating
    # it up front keeps the first progress write from failing in case
    # ``dump_json`` (defined elsewhere) does not create parent directories.
    os.makedirs(os.path.dirname(progress_path), exist_ok=True)
    index = load_json(progress_path) if os.path.exists(progress_path) else {"index": 0}

    entity_dict_path = os.path.join(data_dir, "entity_dict.json")
    entity_dict: Dict[int, Dict[int, Set[Entity]]] = load_entity_dict(entity_dict_path)

    client = LLMClient(MODEL_PARAMS[args.model])

    # ``initial`` starts the bar at the number of batches already finished in
    # an earlier run, so they are not counted into the rate/ETA estimate.
    with tqdm(total=len(batches), initial=index["index"], desc="label entity") as pbar:
        for batch_data in batches[index["index"]:]:
            batch, area = batch_data["batch"], batch_data["area"]
            temp_entity_dict = label_entity_for_batch(batch, client, area, batch_size=batch_size)
            entity_dict = refresh_entity_dict(entity_dict, temp_entity_dict)

            # Persist the entities before advancing the stored index, so the
            # index never points past data that has not been saved.
            save_entity_dict(entity_dict, entity_dict_path)
            index["index"] += 1
            dump_json(index, progress_path)
            pbar.update(1)

    graph.construct_graph_from_entity_dict(entity_dict)
    graph.save_graph()


if __name__ == "__main__":
    main()