-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdriver.py
127 lines (96 loc) · 4.23 KB
/
driver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import sys
import grpc
import asyncio
import server_pb2
import server_pb2_grpc
import signal
from subprocess import Popen
import argparse
import time
import os
from multiprocessing.connection import Client
from dotenv import load_dotenv
import pathlib
# Pull settings from a local .env file into the environment (run() reads
# DATA_PATH from os.environ).
load_dotenv()

# Directory named "ColBERT" next to this file; it is echoed and then put on
# sys.path so that the `colbert` import below can resolve.
_colbert_dir = str(pathlib.Path(__file__).parent.resolve()) + "/ColBERT"
print(_colbert_dir)
sys.path.append(_colbert_dir)

from colbert.data import Queries
def save_rankings(rankings, filename):
    """Write ranked results to ``filename`` as tab-separated lines.

    One line is produced per entry of each query's ``topk``, holding
    ``qid``, ``pid``, ``rank`` and ``score`` in that order. Lines are joined
    with newlines and no trailing newline is written.

    Args:
        rankings: Iterable of objects exposing ``qid`` and an iterable
            ``topk`` whose items expose ``pid``, ``rank`` and ``score``.
        filename: Path of the file to create or overwrite.
    """
    lines = [
        "\t".join(str(field) for field in (q.qid, result.pid, result.rank, result.score))
        for q in rankings
        for result in q.topk
    ]
    # The context manager closes the handle even if the write raises.
    with open(filename, "w") as f:
        f.write("\n".join(lines))
async def run_request(stub, request, experiment):
    """Send one request through the RPC chosen by ``experiment`` and time it.

    ``"search"`` uses ``stub.Search``, ``"pisa"`` uses ``stub.Pisa`` and any
    other value uses ``stub.Serve``.

    Returns:
        A ``(response, elapsed_seconds)`` tuple, where the elapsed time is
        measured with ``time.time()`` around the call.
    """
    started = time.time()

    # Pick the RPC first, then await it in a single place.
    if experiment == "search":
        rpc = stub.Search
    elif experiment == "pisa":
        rpc = stub.Pisa
    else:
        rpc = stub.Serve

    response = await rpc(request)
    elapsed = time.time() - started
    return response, elapsed
async def run(args):
    """Replay the query set against the local gRPC server and record results.

    Loads queries from ``$DATA_PATH/<args.index>/questions.tsv``, sends a
    warmup burst made of the last (up to) 100 queries, then issues every
    query in order, sleeping between submissions for the delays read from
    ``args.timings`` (cycled when there are fewer delays than queries).
    Rankings are written to ``args.ranking_file``; per-request latencies and
    the total time are written to ``args.output``.

    Args:
        args: Parsed command-line namespace; reads ``index``, ``timings``,
            ``experiment``, ``ranking_file`` and ``output``.

    Raises:
        KeyError: If the ``DATA_PATH`` environment variable is not set.
        ValueError: If the timings file contains no delays or a
            non-numeric line.
    """
    queries = Queries(path=f"{os.environ['DATA_PATH']}/{args.index}/questions.tsv")
    qvals = list(queries.items())

    # Read the delays with a context manager so the handle is closed, and
    # skip blank or whitespace-only lines instead of crashing in float().
    with open(args.timings) as timings_file:
        inter_request_time = [float(line) for line in timings_file if line.strip()]
    if not inter_request_time:
        # Fail with a clear message rather than a ZeroDivisionError on `i % length`.
        raise ValueError(f"No inter-request times found in {args.timings}")
    length = len(inter_request_time)

    # `async with` closes the channel when the run finishes or fails.
    async with grpc.aio.insecure_channel('localhost:50050') as channel:
        stub = server_pb2_grpc.ServerStub(channel)

        # Warmup. Slicing (rather than range(len - 100, len)) avoids negative
        # indices wrapping around when there are fewer than 100 queries, which
        # would send each of them twice.
        warmup_tasks = []
        for entry in qvals[-100:]:
            request = server_pb2.Query(query=entry[1], qid=entry[0], k=100)
            warmup_tasks.append(asyncio.ensure_future(run_request(stub, request, args.experiment)))
            # Yield to the event loop so the task just created can start.
            await asyncio.sleep(0)
        await asyncio.gather(*warmup_tasks)
        await stub.DumpScores(server_pb2.Empty())

        tasks = []
        t = time.time()
        for i, entry in enumerate(qvals):
            request = server_pb2.Query(query=entry[1], qid=entry[0], k=100)
            tasks.append(asyncio.ensure_future(run_request(stub, request, args.experiment)))
            # Space submissions out according to the timings file.
            await asyncio.sleep(inter_request_time[i % length])
        await asyncio.sleep(0)
        results = await asyncio.gather(*tasks)
        # Stop the clock before any file I/O so the reported total covers only
        # the request phase, not the time spent writing the rankings file.
        total_time = str(time.time() - t)

        # Split the (response, latency) pairs; unlike zip(*results) followed
        # by indexing, this also works when there are no queries at all.
        responses = [response for response, _ in results]
        latencies = [latency for _, latency in results]

        save_rankings(responses, args.ranking_file)
        with open(args.output, "w") as out_file:
            out_file.write("\n".join(str(x) for x in latencies) + f"\nTotal time: {total_time}")
        print(f"Total time for {len(qvals)} requests:", total_time)
        await stub.DumpScores(server_pb2.Empty())
# Script entry point: start server.py as a child process, wait for its
# readiness handshake, run the benchmark, then stop the child.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluator for ColBERT')
    parser.add_argument('-w', '--num_workers', type=int, required=True,
                        help='Number of worker threads per server')
    parser.add_argument('-o', '--output', type=str, required=True,
                        help='Output file to save results')
    parser.add_argument('-r', '--ranking_file', type=str, required=True,
                        help='Output file to save rankings')
    parser.add_argument('-t', '--timings', type=str, required=True,
                        help='Input file for inter request wait times')
    parser.add_argument('-e', '--experiment', type=str, default="search", choices=["search", "pisa", "serve"],
                        help='search or pisa or serve (pisa + rerank)')
    parser.add_argument('-i', '--index', type=str, required=True, help='Index to run (use "wiki", "msmarco", "lifestyle" to repro the paper, or specify your own index name)')
    parser.add_argument('-m', '--mmap', action="store_true", help='If the index is memory mapped')

    args = parser.parse_args()

    # Build the child's argv as a list instead of splitting a formatted
    # string, so an index name containing spaces stays a single argument.
    server_args = ["-w", str(args.num_workers), "-i", args.index, "-r", "driver"]
    if args.mmap:
        server_args.append("-m")

    process = Popen(["python", "server.py"] + server_args)

    try:
        # Poll the server's readiness listener: up to `times` attempts,
        # 5 seconds apart.
        times = 10
        for attempt in range(times):
            try:
                connection = Client(('localhost', 50040), authkey=b'password')
            except ConnectionRefusedError:
                if attempt == times - 1:
                    print("Failed to receive connection for child server. Terminating!")
                    # The `finally` below stops the child. The child shares
                    # this process's group (it was not started in a new
                    # session), so signalling its group would also hit the
                    # driver and any sibling processes.
                    sys.exit(-1)
                time.sleep(5)
                continue

            try:
                message = connection.recv()
            finally:
                connection.close()
            # Explicit check instead of `assert`, which is stripped (together
            # with the recv() call inside it) when Python runs with -O.
            if message != "Done":
                raise RuntimeError(f"Unexpected handshake message from server: {message!r}")
            break

        asyncio.run(run(args))
        print("Killing processing after completion")
    finally:
        # Always stop the server, including when the handshake or the run
        # fails, and reap it so no zombie process is left behind.
        if process.poll() is None:
            process.kill()
        process.wait()