Skip to content

Commit 2f13742

Browse files
committed
all endpoints working
1 parent 22e3bda commit 2f13742

File tree

4 files changed

+119
-11
lines changed

4 files changed

+119
-11
lines changed

predictionguard/src/chat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def create(
9595
str, Dict[
9696
str, Dict[str, str]
9797
]
98-
]] = None,
98+
]] = "none",
9999
tools: Optional[List[Dict[str, Union[str, Dict[str, str]]]]] = None,
100100
top_p: Optional[float] = 0.99,
101101
top_k: Optional[float] = 50,

predictionguard/src/completions.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ def create(
8787

8888
return choices
8989

90-
# TODO: Fix stream response engine
9190
def _generate_completion(
9291
self,
9392
model,
@@ -153,7 +152,7 @@ def stream_generator(url, headers, payload, stream):
153152
pass
154153
else:
155154
try:
156-
dict_return["data"]["choices"][0]["delta"]["content"]
155+
dict_return["data"]["choices"][0]["text"]
157156
except KeyError:
158157
pass
159158
else:

predictionguard/src/translate.py

Lines changed: 105 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,122 @@
1+
import json
2+
3+
import requests
14
from typing import Any, Dict, Optional
25

6+
from ..version import __version__
7+
38

49
class Translate:
5-
"""No longer supported.
10+
# UNCOMMENT WHEN DEPRECATED
11+
# """No longer supported.
12+
# """
13+
#
14+
# def __init__(self, api_key, url):
15+
# self.api_key = api_key
16+
# self.url = url
17+
#
18+
# def create(
19+
# self,
20+
# text: Optional[str],
21+
# source_lang: Optional[str],
22+
# target_lang: Optional[str],
23+
# use_third_party_engine: Optional[bool] = False
24+
# ) -> Dict[str, Any]:
25+
# """
26+
# No longer supported
27+
# """
28+
#
29+
# raise ValueError(
30+
# "The translate functionality is no longer supported."
31+
# )
32+
"""Translate converts text from one language to another.
33+
34+
Usage::
35+
36+
from predictionguard import PredictionGuard
37+
38+
# Set your Prediction Guard token as an environmental variable.
39+
os.environ["PREDICTIONGUARD_API_KEY"] = "<api key>"
40+
41+
client = PredictionGuard()
42+
43+
response = client.translate.create(
44+
text="The sky is blue.",
45+
source_lang="eng",
46+
target_lang="fra",
47+
use_third_party_engine=True
48+
)
49+
50+
print(json.dumps(response, sort_keys=True, indent=4, separators=(",", ": ")))
651
"""
752

53+
# REMOVE BELOW HERE FOR DEPRECATION
854
def __init__(self, api_key, url):
955
self.api_key = api_key
1056
self.url = url
1157

1258
def create(
1359
self,
14-
text: Optional[str],
15-
source_lang: Optional[str],
16-
target_lang: Optional[str],
60+
text: str,
61+
source_lang: str,
62+
target_lang: str,
1763
use_third_party_engine: Optional[bool] = False
1864
) -> Dict[str, Any]:
1965
"""
20-
No longer supported
66+
Creates a translate request to the Prediction Guard /translate API.
67+
68+
:param text: The text to be translated.
69+
:param source_lang: The language the text is currently in.
70+
:param target_lang: The language the text will be translated to.
71+
:param use_third_party_engine: A boolean for enabling translations with third party APIs.
72+
:result: A dictionary containing the translate response.
73+
"""
74+
75+
# Create a list of tuples, each containing all the parameters for
76+
# a call to _generate_translation
77+
args = (text, source_lang, target_lang, use_third_party_engine)
78+
79+
# Run _generate_translation
80+
choices = self._generate_translation(*args)
81+
return choices
82+
83+
def _generate_translation(self, text, source_lang, target_lang, use_third_party_engine):
84+
"""
85+
Function to generate a translation response.
2186
"""
2287

23-
raise ValueError(
24-
"The translate functionality is no longer supported."
25-
)
88+
headers = {
89+
"Content-Type": "application/json",
90+
"Authorization": "Bearer " + self.api_key,
91+
"User-Agent": "Prediction Guard Python Client: " + __version__,
92+
}
93+
94+
payload_dict = {
95+
"text": text,
96+
"source_lang": source_lang,
97+
"target_lang": target_lang,
98+
"use_third_party_engine": use_third_party_engine
99+
}
100+
payload = json.dumps(payload_dict)
101+
response = requests.request(
102+
"POST", self.url + "/translate", headers=headers, data=payload
103+
)
104+
105+
# If the request was successful, print the proxies.
106+
if response.status_code == 200:
107+
ret = response.json()
108+
return ret
109+
elif response.status_code == 429:
110+
raise ValueError(
111+
"Could not connect to Prediction Guard API. "
112+
"Too many requests, rate limit or quota exceeded."
113+
)
114+
else:
115+
# Check if there is a json body in the response. Read that in,
116+
# print out the error field in the json body, and raise an exception.
117+
err = ""
118+
try:
119+
err = response.json()["error"]
120+
except Exception:
121+
pass
122+
raise ValueError("Could not make translation. " + err)

tests/test_translate.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from predictionguard import PredictionGuard
2+
3+
4+
def test_translate_create():
5+
test_client = PredictionGuard()
6+
7+
response = test_client.translate.create(
8+
text="The sky is blue", source_lang="eng", target_lang="fra"
9+
)
10+
11+
assert type(response["best_score"]) is float
12+
assert len(response["best_translation"])

0 commit comments

Comments
 (0)