Skip to content

Commit d2a63b7

Browse files
committed
adding new models function and updating old functions
1 parent e165ffb commit d2a63b7

File tree

10 files changed

+265
-9
lines changed

10 files changed

+265
-9
lines changed

predictionguard/src/chat.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,14 +272,24 @@ def stream_generator(url, headers, payload, stream):
272272
else:
273273
return return_dict(self.url, headers, payload)
274274

275-
def list_models(self, type: Optional[str] = None) -> List[str]:
    """List the models available for chat completions.

    :param type: Optional model endpoint type, either "completion-chat"
        or "vision". Defaults to "completion-chat" when omitted.
    :return: A list of model names.
    :raises ValueError: If an unsupported type is supplied.
    """
    # NOTE(review): `type` shadows the builtin but is kept for backward
    # compatibility with callers using the keyword.
    # Original annotation `Optional[str, None]` is invalid typing syntax
    # (raises TypeError at import time) and lacked the default the
    # `if type is None` branch relies on.
    # Validate first so bad input fails before any other work.
    if type is None:
        models_path = "/models/completion-chat"
    elif type in ("completion-chat", "vision"):
        models_path = "/models/" + type
    else:
        raise ValueError(
            "Please enter a valid models type (completion-chat or vision)."
        )

    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + self.api_key,
        "User-Agent": "Prediction Guard Python Client: " + __version__
    }

    # Use the computed path: the draft misspelled `models_path` as
    # `model_path` and then hardcoded "/models/completion-chat" in the
    # request, silently ignoring the requested type.
    response = requests.request("GET", self.url + models_path, headers=headers)

    return list(response.json())

predictionguard/src/completions.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ def list_models(self) -> List[str]:
110110
"User-Agent": "Prediction Guard Python Client: " + __version__,
111111
}
112112

113-
response = requests.request("GET", self.url + "/completions", headers=headers)
113+
response = requests.request("GET", self.url + "/models/completion", headers=headers)
114114

115-
return list(response.json())
115+
response_list = []
116+
for model in response.json()["data"]:
117+
response_list.append(model)
118+
119+
return response_list

predictionguard/src/embeddings.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import base64
55

66
import requests
7-
from typing import Any, Dict, List, Union
7+
from typing import Any, Dict, List, Union, Optional
88
import urllib.request
99
import urllib.parse
1010
import uuid
@@ -174,14 +174,29 @@ def _generate_embeddings(self, model, input, truncate, truncation_direction):
174174
pass
175175
raise ValueError("Could not generate embeddings. " + err)
176176

177-
def list_models(self, type: Optional[str] = None) -> List[str]:
    """List the models available for embeddings.

    :param type: Optional model endpoint type, either "text-embeddings"
        or "image-embeddings". Defaults to "text-embeddings" when omitted.
    :return: A list of model names.
    :raises ValueError: If an unsupported type is supplied.
    """
    # NOTE(review): `type` shadows the builtin but is kept for backward
    # compatibility with callers using the keyword. Original annotation
    # `Optional[str, None]` is invalid typing syntax (TypeError at import).
    # Validate first so bad input fails before any other work.
    if type is None:
        models_path = "/models/text-embeddings"
    elif type in ("text-embeddings", "image-embeddings"):
        models_path = "/models/" + type
    else:
        raise ValueError(
            "Please enter a valid models type "
            "(text-embeddings or image-embeddings)."
        )

    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + self.api_key,
        "User-Agent": "Prediction Guard Python Client: " + __version__,
    }

    response = requests.request("GET", self.url + models_path, headers=headers)

    # The API wraps the model list in a "data" field.
    return [model for model in response.json()["data"]]

predictionguard/src/models.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import requests
2+
from typing import Any, Dict, Optional
3+
4+
from ..version import __version__
5+
6+
7+
class Models:
    """Models lists all the models available in the Prediction Guard Platform.

    Usage::

        import os
        import json

        from predictionguard import PredictionGuard

        # Set your Prediction Guard token as an environmental variable.
        os.environ["PREDICTIONGUARD_API_KEY"] = "<api key>"

        client = PredictionGuard()

        response = client.models.list()

        print(json.dumps(response, sort_keys=True, indent=4, separators=(",", ": ")))
    """

    def __init__(self, api_key, url):
        # Credentials and API root used by every request.
        self.api_key = api_key
        self.url = url

    def list(self, endpoint: Optional[str] = None) -> Dict[str, Any]:
        """
        Creates a models list request in the Prediction Guard REST API.

        :param endpoint: The endpoint of models to list.
        :return: A dictionary containing the metadata of all the models.
        :raises ValueError: If the endpoint is invalid or the request fails.
        """
        # NOTE: original annotation `Optional[str, None]` is invalid typing
        # syntax and would raise TypeError at import time.
        choices = self._list_models(endpoint)
        return choices

    def _list_models(self, endpoint):
        """
        Function to list available models.
        """

        endpoints = [
            "completion-chat", "completion", "vision",
            "text-embeddings", "image-embeddings", "tokenize"
        ]

        # Validate the endpoint first so bad input fails fast, before any
        # header construction or network work.
        models_path = "/models"
        if endpoint is not None:
            if endpoint not in endpoints:
                raise ValueError(
                    "If specifying an endpoint, please use one of the following: "
                    + ", ".join(endpoints)
                )
            models_path += "/" + endpoint

        headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + self.api_key,
            "User-Agent": "Prediction Guard Python Client: " + __version__,
        }

        response = requests.request(
            "GET", self.url + models_path, headers=headers
        )

        if response.status_code == 200:
            return response.json()
        elif response.status_code == 429:
            raise ValueError(
                "Could not connect to Prediction Guard API. "
                "Too many requests, rate limit or quota exceeded."
            )
        else:
            # Check if there is a json body in the response. Read that in,
            # print out the error field in the json body, and raise an exception.
            err = ""
            try:
                err = response.json()["error"]
            except Exception:
                pass
            # Original message "Could not check for injection." was a
            # copy-paste from the injection-check client.
            raise ValueError("Could not list models. " + err)

predictionguard/src/tokenize.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,19 @@ def _create_tokens(self, model, input):
8989
except Exception:
9090
pass
9191
raise ValueError("Could not generate tokens. " + err)
92+
93+
def list_models(self):
    """Return the list of models available for tokenization."""
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + self.api_key,
        "User-Agent": "Prediction Guard Python Client: " + __version__
    }

    response = requests.request("GET", self.url + "/models/tokenize", headers=headers)

    # The API wraps the model list in a "data" field.
    return [model for model in response.json()["data"]]

tests/test_chat.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,15 @@ def test_chat_completions_list_models():
193193
response = test_client.chat.completions.list_models()
194194

195195
assert len(response) > 0
196+
assert type(response[0]) == str
197+
198+
199+
def test_chat_completions_list_models_fail():
    """list_models must reject an unsupported type value."""
    test_client = PredictionGuard()

    # `pytest.raises(match=...)` treats the string as a regex; the full
    # message contains metacharacters ("(", ")", "."), so the literal
    # string never matched. Use a metacharacter-free substring instead.
    models_error = "valid models type"

    with pytest.raises(ValueError, match=models_error):
        test_client.chat.completions.list_models(
            type="fail"
        )

tests/test_completions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,4 @@ def test_completions_list_models():
3232
response = test_client.completions.list_models()
3333

3434
assert len(response) > 0
35+
assert type(response[0]) == str

tests/test_embeddings.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import os
2-
import re
32
import base64
43

54
import pytest
@@ -215,3 +214,15 @@ def test_embeddings_list_models():
215214
response = test_client.embeddings.list_models()
216215

217216
assert len(response) > 0
217+
assert type(response[0]) is str
218+
219+
220+
def test_embeddings_list_models_fail():
    """list_models must reject an unsupported type value."""
    test_client = PredictionGuard()

    # An empty `match` pattern matches any message and asserts nothing;
    # use a metacharacter-free substring of the expected error instead.
    models_error = "valid models type"

    with pytest.raises(ValueError, match=models_error):
        test_client.embeddings.list_models(
            type="fail"
        )

tests/test_models.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Fixed imports: `from jedi.plugins import pytest` and the unused
# `uaclient ... import endpoint` were spurious IDE auto-imports that
# would fail outside the author's machine.
import pytest

from predictionguard import PredictionGuard


def test_models_list():
    test_client = PredictionGuard()

    response = test_client.models.list()

    assert len(response["data"]) > 0
    assert type(response["data"][0]["id"]) is str


def test_models_list_completion_chat():
    test_client = PredictionGuard()

    response = test_client.models.list(
        endpoint="completion-chat"
    )

    assert len(response["data"]) > 0
    assert type(response["data"][0]["id"]) is str


def test_models_list_completion():
    test_client = PredictionGuard()

    response = test_client.models.list(
        endpoint="completion"
    )

    assert len(response["data"]) > 0
    assert type(response["data"][0]["id"]) is str


def test_models_list_vision():
    test_client = PredictionGuard()

    response = test_client.models.list(
        endpoint="vision"
    )

    assert len(response["data"]) > 0
    assert type(response["data"][0]["id"]) is str


def test_models_list_text_embeddings():
    test_client = PredictionGuard()

    response = test_client.models.list(
        endpoint="text-embeddings"
    )

    assert len(response["data"]) > 0
    assert type(response["data"][0]["id"]) is str


def test_models_list_image_embeddings():
    test_client = PredictionGuard()

    response = test_client.models.list(
        endpoint="image-embeddings"
    )

    assert len(response["data"]) > 0
    assert type(response["data"][0]["id"]) is str


def test_models_list_tokenize():
    test_client = PredictionGuard()

    response = test_client.models.list(
        endpoint="tokenize"
    )

    assert len(response["data"]) > 0
    assert type(response["data"][0]["id"]) is str


def test_models_list_fail():
    test_client = PredictionGuard()

    # An empty `match` pattern matches any message; pin a substring of the
    # expected validation error instead (metacharacter-free, since `match`
    # is a regex).
    models_error = "of the following"

    with pytest.raises(ValueError, match=models_error):
        test_client.models.list(
            endpoint="fail"
        )

tests/test_tokenize.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,11 @@ def test_tokenize_create():
1414
assert len(response) > 0
1515
assert type(response["tokens"][0]["id"]) is int
1616

17+
18+
def test_tokenize_list():
    """tokenize.list_models returns a non-empty list of model names."""
    test_client = PredictionGuard()

    response = test_client.tokenize.list_models()

    assert len(response) > 0
    # isinstance is the idiomatic type check (`type(x) == str` flagged by
    # linters, and sibling tests use `is str` / isinstance).
    assert isinstance(response[0], str)

0 commit comments

Comments
 (0)