Skip to content

Commit 4736b43

Browse files
committed
adding new models endpoint function and updating old ones
1 parent d2a63b7 commit 4736b43

File tree

11 files changed

+55
-93
lines changed

11 files changed

+55
-93
lines changed

predictionguard/client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212
from .src.toxicity import Toxicity
1313
from .src.pii import Pii
1414
from .src.injection import Injection
15+
from .src.models import Models
1516
from .version import __version__
1617

1718
__all__ = [
1819
"PredictionGuard", "Chat", "Completions", "Embeddings", "Tokenize",
19-
"Translate", "Factuality", "Toxicity", "Pii", "Injection"
20+
"Translate", "Factuality", "Toxicity", "Pii", "Injection", "Models"
2021
]
2122

2223
class PredictionGuard:
@@ -80,6 +81,9 @@ def __init__(
8081
self.tokenize: Tokenize = Tokenize(self.api_key, self.url)
8182
"""Tokenize generates tokens for input text."""
8283

84+
self.models: Models = Models(self.api_key, self.url)
85+
"""Models lists all of the models available in the Prediction Guard API."""
86+
8387
def _connect_client(self) -> None:
8488

8589
# Prepare the proper headers.

predictionguard/src/chat.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from typing import Any, Dict, List, Optional, Union
88
import urllib.request
99
import urllib.parse
10-
from warnings import warn
1110
import uuid
1211

1312
from ..version import __version__
@@ -272,24 +271,25 @@ def stream_generator(url, headers, payload, stream):
272271
else:
273272
return return_dict(self.url, headers, payload)
274273

275-
def list_models(self, type: Optional[str, None]) -> List[str]:
274+
def list_models(self, capability: Optional[str] = "chat-completion") -> List[str]:
276275
# Get the list of current models.
277276
headers = {
278277
"Content-Type": "application/json",
279278
"Authorization": "Bearer " + self.api_key,
280279
"User-Agent": "Prediction Guard Python Client: " + __version__
281280
}
282281

283-
if type is None:
284-
models_path = "/models/completion-chat"
282+
if capability != "chat-completion" and capability != "chat-with-image":
283+
raise ValueError(
284+
"Please enter a valid model type (chat-completion or chat-with-image)."
285+
)
285286
else:
286-
if type != "completion-chat" and type != "vision":
287-
raise ValueError(
288-
"Please enter a valid models type (completion-chat or vision)."
289-
)
290-
else:
291-
model_path = "/models/" + type
287+
model_path = "/models/" + capability
288+
289+
response = requests.request("GET", self.url + model_path, headers=headers)
292290

293-
response = requests.request("GET", self.url + "/models/completion-chat", headers=headers)
291+
response_list = []
292+
for model in response.json()["data"]:
293+
response_list.append(model["id"])
294294

295-
return list(response.json())
295+
return response_list

predictionguard/src/completions.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import requests
44
from typing import Any, Dict, List, Optional, Union
5-
from warnings import warn
65

76
from ..version import __version__
87

@@ -114,6 +113,6 @@ def list_models(self) -> List[str]:
114113

115114
response_list = []
116115
for model in response.json()["data"]:
117-
response_list.append(model)
116+
response_list.append(model["id"])
118117

119118
return response_list

predictionguard/src/embeddings.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -174,29 +174,26 @@ def _generate_embeddings(self, model, input, truncate, truncation_direction):
174174
pass
175175
raise ValueError("Could not generate embeddings. " + err)
176176

177-
def list_models(self, type: Optional[str, None] = None) -> List[str]:
177+
def list_models(self, capability: Optional[str] = "embedding") -> List[str]:
178178
# Get the list of current models.
179179
headers = {
180180
"Content-Type": "application/json",
181181
"Authorization": "Bearer " + self.api_key,
182182
"User-Agent": "Prediction Guard Python Client: " + __version__,
183183
}
184184

185-
if type is None:
186-
models_path = "/models/text-embeddings"
185+
if capability != "embedding" and capability != "embedding-with-image":
186+
raise ValueError(
187+
"Please enter a valid models type "
188+
"(embedding or embedding-with-image)."
189+
)
187190
else:
188-
if type != "text-embeddings" and type != "image-embeddings":
189-
raise ValueError(
190-
"Please enter a valid models type "
191-
"(text-embeddings or image-embeddings)."
192-
)
193-
else:
194-
models_path = "/models/" + type
191+
model_path = "/models/" + capability
195192

196-
response = requests.request("GET", self.url + models_path, headers=headers)
193+
response = requests.request("GET", self.url + model_path, headers=headers)
197194

198195
response_list = []
199196
for model in response.json()["data"]:
200-
response_list.append(model)
197+
response_list.append(model["id"])
201198

202199
return response_list

predictionguard/src/models.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,26 +28,26 @@ def __init__(self, api_key, url):
2828
self.api_key = api_key
2929
self.url = url
3030

31-
def list(self, endpoint: Optional[str, None] = None) -> Dict[str, Any]:
31+
def list(self, capability: Optional[str] = "") -> Dict[str, Any]:
3232
"""
3333
Creates a models list request in the Prediction Guard REST API.
3434
35-
:param endpoint: The endpoint of models to list.
35+
:param capability: The capability of models to list.
3636
:return: A dictionary containing the metadata of all the models.
3737
"""
3838

3939
# Run _list_models
40-
choices = self._list_models(endpoint)
40+
choices = self._list_models(capability)
4141
return choices
4242

43-
def _list_models(self, endpoint):
43+
def _list_models(self, capability):
4444
"""
4545
Function to list available models.
4646
"""
4747

48-
endpoints = [
49-
"completion-chat", "completion", "vision",
50-
"text-embeddings", "image-embeddings", "tokenize"
48+
capabilities = [
49+
"chat-completion", "chat-with-image", "completion",
50+
"embedding", "embedding-with-image", "tokenize"
5151
]
5252

5353
headers = {
@@ -57,14 +57,14 @@ def _list_models(self, endpoint):
5757
}
5858

5959
models_path = "/models"
60-
if endpoint is not None:
61-
if endpoint not in endpoints:
60+
if capability != "":
61+
if capability not in capabilities:
6262
raise ValueError(
63-
"If specifying an endpoint, please use on of the following: "
64-
+ ", ".join(endpoints)
63+
"If specifying a capability, please use one of the following: "
64+
+ ", ".join(capabilities)
6565
)
6666
else:
67-
models_path += "/" + endpoint
67+
models_path += "/" + capability
6868

6969
response = requests.request(
7070
"GET", self.url + models_path, headers=headers

predictionguard/src/tokenize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,6 @@ def list_models(self):
102102

103103
response_list = []
104104
for model in response.json()["data"]:
105-
response_list.append(model)
105+
response_list.append(model["id"])
106106

107107
return response_list

tests/test_chat.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -193,15 +193,4 @@ def test_chat_completions_list_models():
193193
response = test_client.chat.completions.list_models()
194194

195195
assert len(response) > 0
196-
assert type(response[0]) == str
197-
198-
199-
def test_chat_completions_list_models_fail():
200-
test_client = PredictionGuard()
201-
202-
models_error = "Please enter a valid models type (completion-chat or vision)."
203-
204-
with pytest.raises(ValueError, match=models_error):
205-
test_client.chat.completions.list_models(
206-
type="fail"
207-
)
196+
assert type(response[0]) is str

tests/test_completions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,4 @@ def test_completions_list_models():
3232
response = test_client.completions.list_models()
3333

3434
assert len(response) > 0
35-
assert type(response[0]) == str
35+
assert type(response[0]) is str

tests/test_embeddings.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import os
22
import base64
33

4-
import pytest
5-
64
from predictionguard import PredictionGuard
75

86

@@ -214,15 +212,4 @@ def test_embeddings_list_models():
214212
response = test_client.embeddings.list_models()
215213

216214
assert len(response) > 0
217-
assert type(response[0]) is str
218-
219-
220-
def test_embeddings_list_models_fail():
221-
test_client = PredictionGuard()
222-
223-
models_error = ""
224-
225-
with pytest.raises(ValueError, match=models_error):
226-
test_client.embeddings.list_models(
227-
type="fail"
228-
)
215+
assert type(response[0]) is str

tests/test_models.py

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
from jedi.plugins import pytest
2-
from uaclient.api.u.pro.security.fix.cve.plan.v1 import endpoint
3-
41
from predictionguard import PredictionGuard
52

63

@@ -13,55 +10,55 @@ def test_models_list():
1310
assert type(response["data"][0]["id"]) is str
1411

1512

16-
def test_models_list_completion_chat():
13+
def test_models_list_chat_completion():
1714
test_client = PredictionGuard()
1815

1916
response = test_client.models.list(
20-
endpoint="completion-chat"
17+
capability="chat-completion"
2118
)
2219

2320
assert len(response["data"]) > 0
2421
assert type(response["data"][0]["id"]) is str
2522

2623

27-
def test_models_list_completion():
24+
def test_models_list_chat_with_image():
2825
test_client = PredictionGuard()
2926

3027
response = test_client.models.list(
31-
endpoint="completion"
28+
capability="chat-with-image"
3229
)
3330

3431
assert len(response["data"]) > 0
3532
assert type(response["data"][0]["id"]) is str
3633

3734

38-
def test_models_list_vision():
35+
def test_models_list_completion():
3936
test_client = PredictionGuard()
4037

4138
response = test_client.models.list(
42-
endpoint="vision"
39+
capability="completion"
4340
)
4441

4542
assert len(response["data"]) > 0
4643
assert type(response["data"][0]["id"]) is str
4744

4845

49-
def test_models_list_text_embeddings():
46+
def test_models_list_embedding():
5047
test_client = PredictionGuard()
5148

5249
response = test_client.models.list(
53-
endpoint="text-embeddings"
50+
capability="embedding"
5451
)
5552

5653
assert len(response["data"]) > 0
5754
assert type(response["data"][0]["id"]) is str
5855

5956

60-
def test_models_list_image_embeddings():
57+
def test_models_list_embedding_with_image():
6158
test_client = PredictionGuard()
6259

6360
response = test_client.models.list(
64-
endpoint="image-embeddings"
61+
capability="embedding-with-image"
6562
)
6663

6764
assert len(response["data"]) > 0
@@ -72,19 +69,8 @@ def test_models_list_tokenize():
7269
test_client = PredictionGuard()
7370

7471
response = test_client.models.list(
75-
endpoint="tokenize"
72+
capability="tokenize"
7673
)
7774

7875
assert len(response["data"]) > 0
7976
assert type(response["data"][0]["id"]) is str
80-
81-
82-
def test_models_list_fail():
83-
test_client = PredictionGuard()
84-
85-
models_error = ""
86-
87-
with pytest.raises(ValueError, match=models_error):
88-
test_client.models.list(
89-
endpoint="fail"
90-
)

0 commit comments

Comments
 (0)