-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference_large_graph.py
More file actions
56 lines (48 loc) · 2.14 KB
/
Copy pathinference_large_graph.py
File metadata and controls
56 lines (48 loc) · 2.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from src.dataset.large_graph import LargeGraphDataset
from src.model.text_graph_llm import TextOnlyGraphLLM
from torch.utils.data import DataLoader
from src.utils.collate import collate_fn
import argparse
import json
import os
from tqdm import tqdm
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options for the inference script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse
            falls back to ``sys.argv[1:]``, so existing ``parse_args()``
            callers behave exactly as before.

    Returns:
        The populated ``argparse.Namespace``. ``graph_path``, ``question``,
        ``openai_api_key`` and ``vector_db_path`` are mandatory; the
        remaining options fall back to the defaults declared below.
    """
    parser = argparse.ArgumentParser()

    # Mandatory inputs.
    parser.add_argument("--graph_path", type=str, required=True,
                        help="Graph path handed to the dataset loader")
    parser.add_argument("--question", type=str, required=True,
                        help="Question text")
    parser.add_argument("--openai_api_key", type=str, required=True,
                        help="OpenAI API key")
    parser.add_argument("--vector_db_path", type=str, required=True,
                        help="Path to vector DB")

    # Optional settings.
    parser.add_argument("--vector_db_collection", type=str, default=None,
                        help="Vector DB collection name")
    parser.add_argument("--output_path", type=str,
                        default="outputs/results.json",
                        help="File the JSON results are written to")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="Batch size used by the DataLoader")
    parser.add_argument("--max_txt_len", type=int, default=2048)
    parser.add_argument("--max_new_tokens", type=int, default=512)

    return parser.parse_args(argv)
def main():
    """Run inference over the graph dataset and save the results as JSON.

    Steps, in order:
      1. Parse the command-line arguments with ``parse_args()``.
      2. Create the parent directory of ``args.output_path`` if it has one.
      3. Build a ``LargeGraphDataset`` from the graph path and the question,
         wrapped in an unshuffled ``DataLoader`` that uses ``collate_fn``.
      4. Instantiate ``TextOnlyGraphLLM`` from the parsed arguments and
         switch it to eval mode.
      5. Call ``model.inference`` on every batch and flatten the per-batch
         output into one record per sample.
      6. Write the list of records to ``args.output_path`` as indented JSON.

    ``model.inference`` is expected to return a mapping whose ``'id'``,
    ``'question'``, ``'pred'``, ``'desc'`` and ``'vector_context'`` entries
    are index-aligned sequences with one element per sample in the batch.
    """
    args = parse_args()

    # os.path.dirname() returns "" for a bare filename such as
    # "results.json", and os.makedirs("") raises FileNotFoundError, so only
    # create the directory when the output path actually contains one.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Dataset is built from the graph file plus the question text.
    dataset = LargeGraphDataset(args.graph_path, args.question)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )

    model = TextOnlyGraphLLM(args)
    model.eval()

    results = []
    for batch in tqdm(dataloader):
        outputs = model.inference(batch)
        # One output record per sample id in the batch. The
        # 'original_graph' field is not exported (it was disabled in the
        # original code).
        results.extend(
            {
                'id': sample_id,
                'question': outputs['question'][i],
                'prediction': outputs['pred'][i],
                'subgraph_desc': outputs['desc'][i],
                'vector_context': outputs['vector_context'][i],
            }
            for i, sample_id in enumerate(outputs['id'])
        )

    with open(args.output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)
# Script entry point: run the inference pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()