-
Notifications
You must be signed in to change notification settings - Fork 45
/
Copy pathrouter.py
159 lines (131 loc) · 5.33 KB
/
router.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import random
from typing import Any, Callable, Dict, List
import numpy as np
from ranx import Qrels, Run, evaluate
from redisvl.extensions.router.semantic import SemanticRouter
from redisvl.utils.optimize.base import BaseThresholdOptimizer, EvalMetric
from redisvl.utils.optimize.schema import LabeledData
from redisvl.utils.optimize.utils import NULL_RESPONSE_KEY, _format_qrels
def _generate_run_router(test_data: List[LabeledData], router: SemanticRouter) -> Run:
    """Build a ranx ``Run`` from the router's answers to the labeled queries.

    Every test case contributes exactly one entry keyed by its ``id``. That
    entry credits the expected route (``query_match``) with a score of 1 only
    when the router returns a truthy match whose ``name`` equals it; any other
    outcome (no match, or a match on a different route) is recorded under
    ``NULL_RESPONSE_KEY`` instead.
    """
    results: Dict[Any, Any] = {}

    for case in test_data:
        matched = router(case.query)
        # Pick the single label this query is credited with.
        label = (
            case.query_match
            if matched and matched.name == case.query_match
            else NULL_RESPONSE_KEY
        )
        results[case.id] = {label: np.int64(1)}

    return Run(results)
def _eval_router(
    router: SemanticRouter, test_data: List[LabeledData], qrels: Qrels, eval_metric: str
) -> float:
    """Score the router, with its current thresholds, on the labeled data.

    The test queries are sent through ``router`` to produce a ranx ``Run``,
    which is then handed to ranx's ``evaluate`` together with ``qrels`` and
    the metric name. ``make_comparable=True`` is forwarded to ``evaluate``.

    Returns:
        The value ``evaluate`` produces for ``eval_metric``.
    """
    return evaluate(
        qrels,
        _generate_run_router(test_data, router),
        eval_metric,
        make_comparable=True,
    )
def _router_random_search(
    route_names: List[str], route_thresholds: dict, search_step=0.10
):
    """Draw one random candidate threshold for each route.

    For every name in ``route_names`` a grid of 100 evenly spaced values is
    laid out over ``[max(current - search_step, 0), current + search_step]``,
    where ``current`` is that route's entry in ``route_thresholds``. One grid
    point per route is then picked with ``random.choice``.

    Returns:
        dict mapping each route name to its sampled threshold as a ``float``.
    """
    # All grids are built before any sampling happens, so a route name that
    # is missing from ``route_thresholds`` fails before randomness is used.
    candidate_grids = [
        np.linspace(
            start=max(route_thresholds[name] - search_step, 0),
            stop=route_thresholds[name] + search_step,
            num=100,
        )
        for name in route_names
    ]

    sampled = {}
    for name, grid in zip(route_names, candidate_grids):
        sampled[name] = float(random.choice(grid))
    return sampled
def _random_search_opt_router(
    router: SemanticRouter,
    test_data: List[LabeledData],
    qrels: Qrels,
    eval_metric: EvalMetric,
    **kwargs: Any,
):
    """Tune a router's per-route thresholds with a random search.

    The router is scored once with its current thresholds to establish a
    baseline. Each iteration then samples candidate thresholds around the
    values the router holds at that moment (i.e. the previous candidates),
    applies them, re-scores the router, and remembers the candidates whenever
    the score strictly improves. Afterwards the best thresholds seen are
    written back to the router and a summary line is printed.

    Args:
        router: Router whose thresholds are updated in place.
        test_data: Labeled queries used for scoring.
        qrels: ranx relevance judgments that the runs are scored against.
        eval_metric: Metric enum; its ``value`` is the metric name passed to
            the scorer.
        **kwargs: Optional tuning knobs:
            max_iterations (int): Number of candidate sets to try.
                Defaults to 20.
            search_step (float): Half-width of the sampling window around
                each current threshold. Defaults to 0.10.
    """
    max_iterations = kwargs.get("max_iterations", 20)
    search_step = kwargs.get("search_step", 0.10)
    metric_name = eval_metric.value

    start_score = _eval_router(router, test_data, qrels, metric_name)
    best_score = start_score
    # Snapshot the baseline so later router updates cannot alter it, in case
    # ``route_thresholds`` hands back a mapping the router keeps mutating.
    best_thresholds = dict(router.route_thresholds)

    for _ in range(max_iterations):
        thresholds = _router_random_search(
            route_names=router.route_names,
            route_thresholds=router.route_thresholds,
            search_step=search_step,
        )
        router.update_route_thresholds(thresholds)
        score = _eval_router(router, test_data, qrels, metric_name)
        if score > best_score:
            best_score = score
            best_thresholds = thresholds

    # Restore the best thresholds BEFORE reporting: the router still holds
    # the last candidate tried, which is not necessarily the best one, so
    # printing first would show the wrong "Ending thresholds".
    router.update_route_thresholds(best_thresholds)
    print(
        f"Eval metric {eval_metric.value.upper()}: start {round(start_score, 3)}, end {round(best_score, 3)} \nEnding thresholds: {router.route_thresholds}"
    )
class RouterThresholdOptimizer(BaseThresholdOptimizer):
    """
    Optimizer that tunes the route thresholds of a SemanticRouter.

    .. code-block:: python

        from redisvl.extensions.router import Route, SemanticRouter
        from redisvl.utils.vectorize import HFTextVectorizer
        from redisvl.utils.optimize import RouterThresholdOptimizer

        routes = [
            Route(
                name="greeting",
                references=["hello", "hi"],
                metadata={"type": "greeting"},
                distance_threshold=0.5,
            ),
            Route(
                name="farewell",
                references=["bye", "goodbye"],
                metadata={"type": "farewell"},
                distance_threshold=0.5,
            ),
        ]

        router = SemanticRouter(
            name="greeting-router",
            vectorizer=HFTextVectorizer(),
            routes=routes,
            redis_url="redis://localhost:6379",
            overwrite=True # Blow away any other routing index with this name
        )

        test_data = [
            {"query": "hello", "query_match": "greeting"},
            {"query": "goodbye", "query_match": "farewell"},
            ...
        ]

        optimizer = RouterThresholdOptimizer(router, test_data)
        optimizer.optimize()
    """

    def __init__(
        self,
        router: SemanticRouter,
        test_dict: List[Dict[str, Any]],
        opt_fn: Callable = _random_search_opt_router,
        eval_metric: str = "f1",
    ):
        """Initialize the router optimizer.

        All arguments are passed straight through to the base optimizer.

        Args:
            router (SemanticRouter): The RedisVL SemanticRouter instance to optimize.
            test_dict (List[Dict[str, Any]]): List of test cases.
            opt_fn (Callable): Function to perform optimization. Defaults to
                a random search over the route thresholds.
            eval_metric (str): Evaluation metric for threshold optimization.
                Defaults to "f1" score.

        Raises:
            ValueError: If the test_dict not in LabeledData format.
        """
        super().__init__(router, test_dict, opt_fn, eval_metric)

    def optimize(self, **kwargs: Any):
        """Run the configured optimization function against the router.

        Relevance judgments are derived from the stored test data, then
        ``opt_fn`` is called with the optimizable object, the test data, those
        judgments and the evaluation metric. Any keyword arguments are
        forwarded to ``opt_fn`` unchanged.
        """
        # NOTE(review): optimizable / test_data / eval_metric / opt_fn are
        # presumably set by BaseThresholdOptimizer.__init__ - confirm there.
        relevance = _format_qrels(self.test_data)
        self.opt_fn(
            self.optimizable,
            self.test_data,
            relevance,
            self.eval_metric,
            **kwargs,
        )