Skip to content

Commit 64e5618

Browse files
authored
Merge pull request #35 from predictionguard/jacob/rerank
adding rerank function
2 parents a4678a7 + 855ca09 commit 64e5618

File tree

8 files changed

+166
-6
lines changed

8 files changed

+166
-6
lines changed

.github/workflows/main.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ jobs:
2525
TEST_TEXT_EMBEDDINGS_MODEL: ${{ secrets.TEST_TEXT_EMBEDDINGS_MODEL }}
2626
TEST_MULTIMODAL_EMBEDDINGS_MODEL: ${{ secrets.TEST_MULTIMODAL_EMBEDDINGS_MODEL }}
2727
TEST_VISION_MODEL: ${{ secrets.TEST_VISION_MODEL }}
28+
TEST_RERANK_MODEL: ${{ secrets.TEST_RERANK_MODEL }}
2829

2930
- name: To PyPI using Flit
3031
uses: AsifArmanRahman/to-pypi-using-flit@v1

.github/workflows/pr.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,5 @@ jobs:
2828
TEST_MODEL_NAME: ${{ secrets.TEST_MODEL_NAME }}
2929
TEST_TEXT_EMBEDDINGS_MODEL: ${{ secrets.TEST_TEXT_EMBEDDINGS_MODEL }}
3030
TEST_MULTIMODAL_EMBEDDINGS_MODEL: ${{ secrets.TEST_MULTIMODAL_EMBEDDINGS_MODEL }}
31-
TEST_VISION_MODEL: ${{ secrets.TEST_VISION_MODEL }}
31+
TEST_VISION_MODEL: ${{ secrets.TEST_VISION_MODEL }}
32+
TEST_RERANK_MODEL: ${{ secrets.TEST_RERANK_MODEL }}

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,6 @@ venv.bak/
104104
# mypy
105105
.mypy_cache/
106106

107+
# JetBrains Folder
108+
.idea
109+

predictionguard/client.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .src.chat import Chat
77
from .src.completions import Completions
88
from .src.embeddings import Embeddings
9+
from .src.rerank import Rerank
910
from .src.tokenize import Tokenize
1011
from .src.translate import Translate
1112
from .src.factuality import Factuality
@@ -16,8 +17,9 @@
1617
from .version import __version__
1718

1819
__all__ = [
19-
"PredictionGuard", "Chat", "Completions", "Embeddings", "Tokenize",
20-
"Translate", "Factuality", "Toxicity", "Pii", "Injection", "Models"
20+
"PredictionGuard", "Chat", "Completions", "Embeddings", "Rerank",
21+
"Tokenize", "Translate", "Factuality", "Toxicity", "Pii", "Injection",
22+
"Models"
2123
]
2224

2325
class PredictionGuard:
@@ -63,6 +65,9 @@ def __init__(
6365
self.embeddings: Embeddings = Embeddings(self.api_key, self.url)
6466
"""Embedding generates chat completions based on a conversation history."""
6567

68+
self.rerank: Rerank = Rerank(self.api_key, self.url)
69+
"""Rerank sorts text inputs by semantic relevance to a specified query."""
70+
6671
self.translate: Translate = Translate(self.api_key, self.url)
6772
"""Translate converts text from one language to another."""
6873

predictionguard/src/rerank.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import json
2+
3+
import requests
4+
from typing import Any, Dict, List, Optional
5+
6+
from ..version import __version__
7+
8+
9+
class Rerank:
    """Rerank sorts text inputs by semantic relevance to a specified query.

    Usage::

        import os
        import json

        from predictionguard import PredictionGuard

        # Set your Prediction Guard token as an environmental variable.
        os.environ["PREDICTIONGUARD_API_KEY"] = "<api key>"

        client = PredictionGuard()

        response = client.rerank.create(
            model="bge-reranker-v2-m3",
            query="What is Deep Learning?",
            documents=[
                "Deep Learning is pizza.",
                "Deep Learning is not pizza."
            ],
            return_documents=True
        )

        print(json.dumps(response, sort_keys=True, indent=4, separators=(",", ": ")))
    """

    def __init__(self, api_key, url):
        # Credentials and endpoint base URL are supplied by the
        # PredictionGuard client that owns this instance.
        self.api_key = api_key
        self.url = url

    def create(
        self,
        model: str,
        query: str,
        documents: List[str],
        return_documents: Optional[bool] = True
    ) -> Dict[str, Any]:
        """
        Creates a rerank request in the Prediction Guard /rerank API.

        :param model: The model to use for reranking.
        :param query: The query to rank against.
        :param documents: The documents to rank.
        :param return_documents: Whether to return documents with score.
        :return: A dictionary containing the ranked results and their
            relevance scores.
        :raises ValueError: If the API request fails or is rate limited.
        """

        # Run _create_rerank
        choices = self._create_rerank(model, query, documents, return_documents)
        return choices

    def _create_rerank(self, model, query, documents, return_documents):
        """
        Function to rank text via a POST to the /rerank endpoint.
        """

        headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + self.api_key,
            "User-Agent": "Prediction Guard Python Client: " + __version__,
        }

        payload = {
            "model": model,
            "query": query,
            "documents": documents,
            "return_documents": return_documents
        }

        payload = json.dumps(payload)

        response = requests.request(
            "POST", self.url + "/rerank", headers=headers, data=payload
        )

        if response.status_code == 200:
            ret = response.json()
            return ret
        elif response.status_code == 429:
            raise ValueError(
                "Could not connect to Prediction Guard API. "
                "Too many requests, rate limit or quota exceeded."
            )
        else:
            # Check if there is a json body in the response. Read that in,
            # surface the error field in the json body, and raise an exception.
            err = ""
            try:
                err = response.json()["error"]
            except Exception:
                pass
            # str() guards against a non-string "error" field, which would
            # otherwise turn this ValueError into a TypeError.
            raise ValueError("Could not rank documents. " + str(err))

    def list_models(self):
        """
        Get the list of rerank models currently supported by the API.

        :return: A list of model id strings.
        :raises ValueError: If the API request fails or is rate limited.
        """

        headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + self.api_key,
            "User-Agent": "Prediction Guard Python Client: " + __version__
        }

        response = requests.request("GET", self.url + "/models/rerank", headers=headers)

        # Fail with the same ValueError style as _create_rerank instead of
        # raising an opaque KeyError on the missing "data" field when the
        # API returns a non-200 response.
        if response.status_code != 200:
            err = ""
            try:
                err = response.json()["error"]
            except Exception:
                pass
            raise ValueError("Could not list rerank models. " + str(err))

        response_list = []
        for model in response.json()["data"]:
            response_list.append(model["id"])

        return response_list

predictionguard/src/tokenize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(self, api_key, url):
3636

3737
def create(self, model: str, input: str) -> Dict[str, Any]:
3838
"""
39-
Creates a prompt injection check request in the Prediction Guard /injection API.
39+
Creates a tokenization request in the Prediction Guard /tokenize API.
4040
4141
:param model: The model to use for generating tokens.
4242
:param input: The text to convert into tokens.
@@ -49,7 +49,7 @@ def create(self, model: str, input: str) -> Dict[str, Any]:
4949
"Model %s is not supported by this endpoint." % model
5050
)
5151

52-
# Run _check_injection
52+
# Run _create_tokens
5353
choices = self._create_tokens(model, input)
5454
return choices
5555

predictionguard/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# Setting the package version
2-
__version__ = "2.6.0"
2+
__version__ = "2.7.0"

tests/test_rerank.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import os
2+
3+
from predictionguard import PredictionGuard
4+
5+
6+
def test_rerank_create():
    """Rerank two documents against a query and validate the result schema."""
    client = PredictionGuard()

    rerank_response = client.rerank.create(
        model=os.environ["TEST_RERANK_MODEL"],
        query="What is Deep Learning?",
        documents=[
            "Deep Learning is pizza.",
            "Deep Learning is not pizza."
        ],
        return_documents=True,
    )

    assert len(rerank_response) > 0
    # With return_documents=True every result carries an index, a
    # relevance score, and the original document text.
    top_result = rerank_response["results"][0]
    assert type(top_result["index"]) is int
    assert type(top_result["relevance_score"]) is float
    assert type(top_result["text"]) is str
23+
24+
25+
def test_rerank_list():
    """List the available rerank models and check the response shape."""
    client = PredictionGuard()

    model_ids = client.rerank.list_models()

    # The endpoint should report at least one model, each as an id string.
    assert len(model_ids) > 0
    assert type(model_ids[0]) is str

0 commit comments

Comments
 (0)