Skip to content

Commit 47821b5

Browse files
author
Changxf5
committed
update file format
1 parent 0b86437 commit 47821b5

2 files changed

Lines changed: 11 additions & 19 deletions

File tree

router_inference/router/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@
88
from router_inference.router.vllm_sr import VLLMSR
99
from router_inference.router.api_router import APIRouter
1010
from router_inference.router.auto_router import auto_router
11+
1112
__all__ = ["BaseRouter", "ExampleRouter", "VLLMSR", "APIRouter", "auto_router"]

router_inference/router/auto_router.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -38,37 +38,28 @@ def __init__(self, router_name: str):
3838

3939
def get_routing_result(self, query):
4040
"""Test the routing API with the specified request"""
41-
url = 'http://10.109.4.26:8501/api/v1/routing'
42-
43-
headers = {
44-
'Content-Type': 'application/json'
45-
}
41+
url = "http://10.109.4.26:8501/api/v1/routing"
42+
43+
headers = {"Content-Type": "application/json"}
4644

4745
data = {
4846
"default": 287,
4947
"models": self.id_list,
50-
"messages": [
51-
{
52-
"role": "user",
53-
"content": query
54-
}
55-
],
48+
"messages": [{"role": "user", "content": query}],
5649
"strategy": 1,
5750
"qualityPercentage": 60,
5851
"option": {
5952
"multiRoundJudgeMode": 0,
6053
"enableRagWebJudge": False,
6154
"enableDomainJudge": False,
62-
}
55+
},
56+
}
6357

64-
}
65-
6658
# Send POST request
6759
response = httpx.post(url, headers=headers, json=data)
68-
6960

70-
model_name = self.id_to_modelname[response.json()['data']['id']]
71-
61+
model_name = self.id_to_modelname[response.json()["data"]["id"]]
62+
7263
return model_name
7364

7465
def _get_prediction(self, query: str) -> str:
@@ -85,9 +76,9 @@ def _get_prediction(self, query: str) -> str:
8576
Returns:
8677
Name of the selected model
8778
"""
88-
79+
8980
# Simple example: cycle through models
9081
model_name = self.get_routing_result(query)
91-
print('----------------model_name:',model_name)
82+
print("----------------model_name:", model_name)
9283
self.counter += 1
9384
return model_name

0 commit comments

Comments
 (0)