Skip to content

Commit 3998fc6

Browse files
authored
Merge pull request #34 from predictionguard/jacob/models-endpoint
Adding `/models` functionality and updating old `list_models()` functions to work with API changes
2 parents e165ffb + 4736b43 commit 3998fc6

File tree

11 files changed

+233
-15
lines changed

11 files changed

+233
-15
lines changed

predictionguard/client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212
from .src.toxicity import Toxicity
1313
from .src.pii import Pii
1414
from .src.injection import Injection
15+
from .src.models import Models
1516
from .version import __version__
1617

1718
__all__ = [
1819
"PredictionGuard", "Chat", "Completions", "Embeddings", "Tokenize",
19-
"Translate", "Factuality", "Toxicity", "Pii", "Injection"
20+
"Translate", "Factuality", "Toxicity", "Pii", "Injection", "Models"
2021
]
2122

2223
class PredictionGuard:
@@ -80,6 +81,9 @@ def __init__(
8081
self.tokenize: Tokenize = Tokenize(self.api_key, self.url)
8182
"""Tokenize generates tokens for input text."""
8283

84+
self.models: Models = Models(self.api_key, self.url)
85+
"""Models lists all of the models available in the Prediction Guard API."""
86+
8387
def _connect_client(self) -> None:
8488

8589
# Prepare the proper headers.

predictionguard/src/chat.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from typing import Any, Dict, List, Optional, Union
88
import urllib.request
99
import urllib.parse
10-
from warnings import warn
1110
import uuid
1211

1312
from ..version import __version__
@@ -272,14 +271,25 @@ def stream_generator(url, headers, payload, stream):
272271
else:
273272
return return_dict(self.url, headers, payload)
274273

275-
def list_models(self, capability: Optional[str] = "chat-completion") -> List[str]:
    """Return the ids of models that support the given chat capability.

    :param capability: Capability to filter on; must be either
        "chat-completion" or "chat-with-image".
    :return: A list of model id strings.
    :raises ValueError: If an unknown capability is supplied.
    """

    # Validate the capability before doing any work so a typo fails fast,
    # without issuing a doomed HTTP request.
    if capability not in ("chat-completion", "chat-with-image"):
        raise ValueError(
            "Please enter a valid model type (chat-completion or chat-with-image)."
        )

    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + self.api_key,
        "User-Agent": "Prediction Guard Python Client: " + __version__
    }

    response = requests.request(
        "GET", self.url + "/models/" + capability, headers=headers
    )

    # The /models endpoint wraps its results in a "data" list of model
    # objects; callers only want the ids.
    return [model["id"] for model in response.json()["data"]]

predictionguard/src/completions.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import requests
44
from typing import Any, Dict, List, Optional, Union
5-
from warnings import warn
65

76
from ..version import __version__
87

@@ -110,6 +109,10 @@ def list_models(self) -> List[str]:
110109
"User-Agent": "Prediction Guard Python Client: " + __version__,
111110
}
112111

113-
response = requests.request("GET", self.url + "/completions", headers=headers)
112+
response = requests.request("GET", self.url + "/models/completion", headers=headers)
114113

115-
return list(response.json())
114+
response_list = []
115+
for model in response.json()["data"]:
116+
response_list.append(model["id"])
117+
118+
return response_list

predictionguard/src/embeddings.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import base64
55

66
import requests
7-
from typing import Any, Dict, List, Union
7+
from typing import Any, Dict, List, Union, Optional
88
import urllib.request
99
import urllib.parse
1010
import uuid
@@ -174,14 +174,26 @@ def _generate_embeddings(self, model, input, truncate, truncation_direction):
174174
pass
175175
raise ValueError("Could not generate embeddings. " + err)
176176

177-
def list_models(self, capability: Optional[str] = "embedding") -> List[str]:
    """Return the ids of models that support the given embedding capability.

    :param capability: Capability to filter on; must be either
        "embedding" or "embedding-with-image".
    :return: A list of model id strings.
    :raises ValueError: If an unknown capability is supplied.
    """

    # Validate the capability before doing any work so a typo fails fast,
    # without issuing a doomed HTTP request.
    if capability not in ("embedding", "embedding-with-image"):
        raise ValueError(
            "Please enter a valid models type "
            "(embedding or embedding-with-image)."
        )

    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + self.api_key,
        "User-Agent": "Prediction Guard Python Client: " + __version__,
    }

    response = requests.request(
        "GET", self.url + "/models/" + capability, headers=headers
    )

    # The /models endpoint wraps its results in a "data" list of model
    # objects; callers only want the ids.
    return [model["id"] for model in response.json()["data"]]

predictionguard/src/models.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import requests
2+
from typing import Any, Dict, Optional
3+
4+
from ..version import __version__
5+
6+
7+
class Models:
    """Models lists all the models available in the Prediction Guard Platform.

    Usage::

        import os
        import json

        from predictionguard import PredictionGuard

        # Set your Prediction Guard token as an environmental variable.
        os.environ["PREDICTIONGUARD_API_KEY"] = "<api key>"

        client = PredictionGuard()

        response = client.models.list()

        print(json.dumps(response, sort_keys=True, indent=4, separators=(",", ": ")))
    """

    def __init__(self, api_key, url):
        # Bearer token and base URL for the Prediction Guard REST API.
        self.api_key = api_key
        self.url = url

    def list(self, capability: Optional[str] = "") -> Dict[str, Any]:
        """
        Creates a models list request in the Prediction Guard REST API.

        :param capability: The capability of models to list; an empty string
            lists models of every capability.
        :return: A dictionary containing the metadata of all the models.
        :raises ValueError: If the capability is unknown or the request fails.
        """

        return self._list_models(capability)

    def _list_models(self, capability):
        """
        Function to list available models.

        Issues GET /models (or /models/<capability>) and returns the decoded
        JSON body.
        """

        capabilities = [
            "chat-completion", "chat-with-image", "completion",
            "embedding", "embedding-with-image", "tokenize"
        ]

        # Validate the capability before touching the network so a typo
        # fails fast with a helpful message.
        models_path = "/models"
        if capability != "":
            if capability not in capabilities:
                raise ValueError(
                    "If specifying a capability, please use one of the following: "
                    + ", ".join(capabilities)
                )
            models_path += "/" + capability

        headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + self.api_key,
            "User-Agent": "Prediction Guard Python Client: " + __version__,
        }

        response = requests.request(
            "GET", self.url + models_path, headers=headers
        )

        if response.status_code == 200:
            return response.json()
        if response.status_code == 429:
            raise ValueError(
                "Could not connect to Prediction Guard API. "
                "Too many requests, rate limit or quota exceeded."
            )

        # Check if there is a json body in the response. Read that in,
        # surface the error field in the json body, and raise an exception.
        err = ""
        try:
            err = response.json()["error"]
        except Exception:
            pass
        # Fixed: the original message ("Could not check for injection.") was
        # copy-pasted from the injection module; report the failed operation.
        raise ValueError("Could not list models. " + err)

predictionguard/src/tokenize.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,19 @@ def _create_tokens(self, model, input):
8989
except Exception:
9090
pass
9191
raise ValueError("Could not generate tokens. " + err)
92+
93+
def list_models(self):
    """Return the ids of models that support tokenization.

    :return: A list of model id strings from GET /models/tokenize.
    """

    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + self.api_key,
        "User-Agent": "Prediction Guard Python Client: " + __version__
    }

    response = requests.request(
        "GET", self.url + "/models/tokenize", headers=headers
    )

    # The /models endpoint wraps its results in a "data" list of model
    # objects; callers only want the ids.
    return [model["id"] for model in response.json()["data"]]

tests/test_chat.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,4 @@ def test_chat_completions_list_models():
193193
response = test_client.chat.completions.list_models()
194194

195195
assert len(response) > 0
196+
assert type(response[0]) is str

tests/test_completions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,4 @@ def test_completions_list_models():
3232
response = test_client.completions.list_models()
3333

3434
assert len(response) > 0
35+
assert type(response[0]) is str

tests/test_embeddings.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
import os
2-
import re
32
import base64
43

5-
import pytest
6-
74
from predictionguard import PredictionGuard
85

96

@@ -215,3 +212,4 @@ def test_embeddings_list_models():
215212
response = test_client.embeddings.list_models()
216213

217214
assert len(response) > 0
215+
assert type(response[0]) is str

tests/test_models.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from predictionguard import PredictionGuard
2+
3+
4+
def _assert_models_listed(capability=None):
    """Helper: list models (optionally filtered by capability) and
    sanity-check the response payload shape."""
    test_client = PredictionGuard()

    if capability is None:
        response = test_client.models.list()
    else:
        response = test_client.models.list(capability=capability)

    # Every successful response carries a non-empty "data" list of
    # model objects keyed by string ids.
    assert len(response["data"]) > 0
    assert isinstance(response["data"][0]["id"], str)


def test_models_list():
    _assert_models_listed()


def test_models_list_chat_completion():
    _assert_models_listed("chat-completion")


def test_models_list_chat_with_image():
    _assert_models_listed("chat-with-image")


def test_models_list_completion():
    _assert_models_listed("completion")


def test_models_list_embedding():
    _assert_models_listed("embedding")


def test_models_list_embedding_with_image():
    _assert_models_listed("embedding-with-image")


def test_models_list_tokenize():
    _assert_models_listed("tokenize")

0 commit comments

Comments (0)