Skip to content

Commit 47f2078

Browse files
authored
adding support for timeouts (#44)
1 parent 21abbae commit 47f2078

File tree

16 files changed

+105
-65
lines changed

16 files changed

+105
-65
lines changed

predictionguard/client.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22

33
import requests
4-
from typing import Optional
4+
from typing import Optional, Union
55

66
from .src.audio import Audio
77
from .src.chat import Chat
@@ -30,11 +30,15 @@ class PredictionGuard:
3030
"""PredictionGuard provides access to the Prediction Guard API."""
3131

3232
def __init__(
33-
self, api_key: Optional[str] = None, url: Optional[str] = None
33+
self,
34+
api_key: Optional[str] = None,
35+
url: Optional[str] = None,
36+
timeout: Optional[Union[int, float]] = None
3437
) -> None:
3538
"""
3639
:param api_key: api_key represents PG api key.
3740
:param url: url represents the transport and domain:port
41+
:param timeout: request timeout in seconds.
3842
"""
3943

4044
# Get the access api_key.
@@ -56,50 +60,67 @@ def __init__(
5660
url = "https://api.predictionguard.com"
5761
self.url = url
5862

63+
if not timeout:
64+
timeout = os.environ.get("TIMEOUT")
65+
if not timeout:
66+
timeout = None
67+
if timeout:
68+
try:
69+
timeout = float(timeout)
70+
except ValueError:
71+
raise ValueError(
72+
"Timeout must be of type integer or float, not %s." % (type(timeout).__name__,)
73+
)
74+
except TypeError:
75+
raise TypeError(
76+
"Timeout should be of type integer or float, not %s." % (type(timeout).__name__,)
77+
)
78+
self.timeout = timeout
79+
5980
# Connect to Prediction Guard and set the access api_key.
6081
self._connect_client()
6182

6283
# Pass Prediction Guard class variables to inner classes
63-
self.chat: Chat = Chat(self.api_key, self.url)
84+
self.chat: Chat = Chat(self.api_key, self.url, self.timeout)
6485
"""Chat generates chat completions based on a conversation history"""
6586

66-
self.completions: Completions = Completions(self.api_key, self.url)
87+
self.completions: Completions = Completions(self.api_key, self.url, self.timeout)
6788
"""Completions generates text completions based on the provided input"""
6889

69-
self.embeddings: Embeddings = Embeddings(self.api_key, self.url)
90+
self.embeddings: Embeddings = Embeddings(self.api_key, self.url, self.timeout)
7091
"""Embeddings generates embeddings for the provided input."""
7192

72-
self.audio: Audio = Audio(self.api_key, self.url)
93+
self.audio: Audio = Audio(self.api_key, self.url, self.timeout)
7394
"""Audio allows for the transcription of audio files."""
7495

75-
self.documents: Documents = Documents(self.api_key, self.url)
96+
self.documents: Documents = Documents(self.api_key, self.url, self.timeout)
7697
"""Documents allows you to extract text from various document file types."""
7798

78-
self.rerank: Rerank = Rerank(self.api_key, self.url)
99+
self.rerank: Rerank = Rerank(self.api_key, self.url, self.timeout)
79100
"""Rerank sorts text inputs by semantic relevance to a specified query."""
80101

81-
self.translate: Translate = Translate(self.api_key, self.url)
102+
self.translate: Translate = Translate(self.api_key, self.url, self.timeout)
82103
"""Translate converts text from one language to another."""
83104

84-
self.factuality: Factuality = Factuality(self.api_key, self.url)
105+
self.factuality: Factuality = Factuality(self.api_key, self.url, self.timeout)
85106
"""Factuality checks the factuality of a given text compared to a reference."""
86107

87-
self.toxicity: Toxicity = Toxicity(self.api_key, self.url)
108+
self.toxicity: Toxicity = Toxicity(self.api_key, self.url, self.timeout)
88109
"""Toxicity checks the toxicity of a given text."""
89110

90-
self.pii: Pii = Pii(self.api_key, self.url)
111+
self.pii: Pii = Pii(self.api_key, self.url, self.timeout)
91112
"""Pii replaces personal information such as names, SSNs, and emails in a given text."""
92113

93-
self.injection: Injection = Injection(self.api_key, self.url)
114+
self.injection: Injection = Injection(self.api_key, self.url, self.timeout)
94115
"""Injection detects potential prompt injection attacks in a given prompt."""
95116

96-
self.tokenize: Tokenize = Tokenize(self.api_key, self.url)
117+
self.tokenize: Tokenize = Tokenize(self.api_key, self.url, self.timeout)
97118
"""Tokenize generates tokens for input text."""
98119

99-
self.detokenize: Detokenize = Detokenize(self.api_key, self.url)
120+
self.detokenize: Detokenize = Detokenize(self.api_key, self.url, self.timeout)
100121
"""Detokenize generates text from input tokens."""
101122

102-
self.models: Models = Models(self.api_key, self.url)
123+
self.models: Models = Models(self.api_key, self.url, self.timeout)
103124
"""Models lists all of the models available in the Prediction Guard API."""
104125

105126
def _connect_client(self) -> None:
@@ -112,7 +133,7 @@ def _connect_client(self) -> None:
112133
}
113134

114135
# Send a GET request to the /completions endpoint to make sure we can connect.
115-
response = requests.request("GET", self.url + "/completions", headers=headers)
136+
response = requests.request("GET", self.url + "/completions", headers=headers, timeout=self.timeout)
116137

117138
# If the connection was unsuccessful, raise an exception.
118139
if response.status_code == 200:

predictionguard/src/audio.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,18 @@ class Audio:
3939
))
4040
"""
4141

42-
def __init__(self, api_key, url):
42+
def __init__(self, api_key, url, timeout):
4343
self.api_key = api_key
4444
self.url = url
45+
self.timeout = timeout
4546

46-
self.transcriptions: AudioTranscriptions = AudioTranscriptions(self.api_key, self.url)
47+
self.transcriptions: AudioTranscriptions = AudioTranscriptions(self.api_key, self.url, self.timeout)
4748

4849
class AudioTranscriptions:
49-
def __init__(self, api_key, url):
50+
def __init__(self, api_key, url, timeout):
5051
self.api_key = api_key
5152
self.url = url
53+
self.timeout = timeout
5254

5355
def create(
5456
self,
@@ -164,7 +166,7 @@ def _transcribe_audio(
164166
}
165167

166168
response = requests.request(
167-
"POST", self.url + "/audio/transcriptions", headers=headers, files=files, data=data
169+
"POST", self.url + "/audio/transcriptions", headers=headers, files=files, data=data, timeout=self.timeout
168170
)
169171

170172
# If the request was successful, print the proxies.

predictionguard/src/chat.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,17 +66,19 @@ class Chat:
6666
))
6767
"""
6868

69-
def __init__(self, api_key, url):
69+
def __init__(self, api_key, url, timeout):
7070
self.api_key = api_key
7171
self.url = url
72+
self.timeout = timeout
7273

73-
self.completions: ChatCompletions = ChatCompletions(self.api_key, self.url)
74+
self.completions: ChatCompletions = ChatCompletions(self.api_key, self.url, self.timeout)
7475

7576

7677
class ChatCompletions:
77-
def __init__(self, api_key, url):
78+
def __init__(self, api_key, url, timeout):
7879
self.api_key = api_key
7980
self.url = url
81+
self.timeout = timeout
8082

8183
def create(
8284
self,
@@ -192,9 +194,9 @@ def _generate_chat(
192194
Function to generate a single chat response.
193195
"""
194196

195-
def return_dict(url, headers, payload):
197+
def return_dict(url, headers, payload, timeout):
196198
response = requests.request(
197-
"POST", url + "/chat/completions", headers=headers, data=payload
199+
"POST", url + "/chat/completions", headers=headers, data=payload, timeout=timeout
198200
)
199201
# If the request was successful, print the proxies.
200202
if response.status_code == 200:
@@ -215,12 +217,13 @@ def return_dict(url, headers, payload):
215217
pass
216218
raise ValueError("Could not make prediction. " + err)
217219

218-
def stream_generator(url, headers, payload, stream):
220+
def stream_generator(url, headers, payload, stream, timeout):
219221
with requests.post(
220222
url + "/chat/completions",
221223
headers=headers,
222224
data=payload,
223225
stream=stream,
226+
timeout=timeout,
224227
) as response:
225228
response.raise_for_status()
226229

@@ -356,10 +359,10 @@ def stream_generator(url, headers, payload, stream):
356359
payload = json.dumps(payload_dict)
357360

358361
if stream:
359-
return stream_generator(self.url, headers, payload, stream)
362+
return stream_generator(self.url, headers, payload, stream, self.timeout)
360363

361364
else:
362-
return return_dict(self.url, headers, payload)
365+
return return_dict(self.url, headers, payload, self.timeout)
363366

364367
def list_models(self, capability: Optional[str] = "chat-completion") -> List[str]:
365368
# Get the list of current models.
@@ -376,7 +379,7 @@ def list_models(self, capability: Optional[str] = "chat-completion") -> List[str
376379
else:
377380
model_path = "/models/" + capability
378381

379-
response = requests.request("GET", self.url + model_path, headers=headers)
382+
response = requests.request("GET", self.url + model_path, headers=headers, timeout=self.timeout)
380383

381384
response_list = []
382385
for model in response.json()["data"]:

predictionguard/src/completions.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ class Completions:
4141
))
4242
"""
4343

44-
def __init__(self, api_key, url):
44+
def __init__(self, api_key, url, timeout):
4545
self.api_key = api_key
4646
self.url = url
47+
self.timeout = timeout
4748

4849
def create(
4950
self,
@@ -132,9 +133,9 @@ def _generate_completion(
132133
Function to generate a single completion.
133134
"""
134135

135-
def return_dict(url, headers, payload):
136+
def return_dict(url, headers, payload, timeout):
136137
response = requests.request(
137-
"POST", url + "/completions", headers=headers, data=payload
138+
"POST", url + "/completions", headers=headers, data=payload, timeout=timeout
138139
)
139140
# If the request was successful, print the proxies.
140141
if response.status_code == 200:
@@ -155,12 +156,13 @@ def return_dict(url, headers, payload):
155156
pass
156157
raise ValueError("Could not make prediction. " + err)
157158

158-
def stream_generator(url, headers, payload, stream):
159+
def stream_generator(url, headers, payload, stream, timeout):
159160
with requests.post(
160161
url + "/completions",
161162
headers=headers,
162163
data=payload,
163164
stream=stream,
165+
timeout=timeout
164166
) as response:
165167
response.raise_for_status()
166168

@@ -215,10 +217,10 @@ def stream_generator(url, headers, payload, stream):
215217
payload = json.dumps(payload_dict)
216218

217219
if stream:
218-
return stream_generator(self.url, headers, payload, stream)
220+
return stream_generator(self.url, headers, payload, stream, self.timeout)
219221

220222
else:
221-
return return_dict(self.url, headers, payload)
223+
return return_dict(self.url, headers, payload, self.timeout)
222224

223225
def list_models(self) -> List[str]:
224226
# Get the list of current models.
@@ -228,7 +230,7 @@ def list_models(self) -> List[str]:
228230
"User-Agent": "Prediction Guard Python Client: " + __version__,
229231
}
230232

231-
response = requests.request("GET", self.url + "/models/completion", headers=headers)
233+
response = requests.request("GET", self.url + "/models/completion", headers=headers, timeout=self.timeout)
232234

233235
response_list = []
234236
for model in response.json()["data"]:

predictionguard/src/detokenize.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ class Detokenize:
4141
"""
4242

4343

44-
def __init__(self, api_key, url):
44+
def __init__(self, api_key, url, timeout):
4545
self.api_key = api_key
4646
self.url = url
47+
self.timeout = timeout
4748

4849
def create(self, model: str, tokens: List[int]) -> Dict[str, Any]:
4950
"""
@@ -85,7 +86,7 @@ def _create_tokens(self, model, tokens):
8586
payload = json.dumps(payload)
8687

8788
response = requests.request(
88-
"POST", self.url + "/detokenize", headers=headers, data=payload
89+
"POST", self.url + "/detokenize", headers=headers, data=payload, timeout=self.timeout
8990
)
9091

9192
if response.status_code == 200:
@@ -114,7 +115,7 @@ def list_models(self):
114115
"User-Agent": "Prediction Guard Python Client: " + __version__
115116
}
116117

117-
response = requests.request("GET", self.url + "/models/detokenize", headers=headers)
118+
response = requests.request("GET", self.url + "/models/detokenize", headers=headers, timeout=self.timeout)
118119

119120
response_list = []
120121
for model in response.json()["data"]:

predictionguard/src/documents.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,18 @@ class Documents:
3737
))
3838
"""
3939

40-
def __init__(self, api_key, url):
40+
def __init__(self, api_key, url, timeout):
4141
self.api_key = api_key
4242
self.url = url
43+
self.timeout = timeout
4344

44-
self.extract: DocumentsExtract = DocumentsExtract(self.api_key, self.url)
45+
self.extract: DocumentsExtract = DocumentsExtract(self.api_key, self.url, self.timeout)
4546

4647
class DocumentsExtract:
47-
def __init__(self, api_key, url):
48+
def __init__(self, api_key, url, timeout):
4849
self.api_key = api_key
4950
self.url = url
51+
self.timeout = timeout
5052

5153
def create(
5254
self,
@@ -117,7 +119,7 @@ def _extract_documents(
117119

118120
response = requests.request(
119121
"POST", self.url + "/documents/extract",
120-
headers=headers, files=files, data=data
122+
headers=headers, files=files, data=data, timeout=self.timeout
121123
)
122124

123125
# If the request was successful, print the proxies.

predictionguard/src/embeddings.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,10 @@ class Embeddings:
4646
))
4747
"""
4848

49-
def __init__(self, api_key, url):
49+
def __init__(self, api_key, url, timeout):
5050
self.api_key = api_key
5151
self.url = url
52+
self.timeout = timeout
5253

5354
def create(
5455
self,
@@ -166,7 +167,7 @@ def _generate_embeddings(self, model, input, truncate, truncation_direction):
166167

167168
payload = json.dumps(payload_dict)
168169
response = requests.request(
169-
"POST", self.url + "/embeddings", headers=headers, data=payload
170+
"POST", self.url + "/embeddings", headers=headers, data=payload, timeout=self.timeout
170171
)
171172

172173
# If the request was successful, print the proxies.
@@ -204,7 +205,7 @@ def list_models(self, capability: Optional[str] = "embedding") -> List[str]:
204205
else:
205206
model_path = "/models/" + capability
206207

207-
response = requests.request("GET", self.url + model_path, headers=headers)
208+
response = requests.request("GET", self.url + model_path, headers=headers, timeout=self.timeout)
208209

209210
response_list = []
210211
for model in response.json()["data"]:

predictionguard/src/factuality.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ class Factuality:
4141
))
4242
"""
4343

44-
def __init__(self, api_key, url):
44+
def __init__(self, api_key, url, timeout):
4545
self.api_key = api_key
4646
self.url = url
47+
self.timeout = timeout
4748

4849
def check(self, reference: str, text: str) -> Dict[str, Any]:
4950
"""
@@ -72,7 +73,7 @@ def _generate_score(self, reference, text):
7273
payload_dict = {"reference": reference, "text": text}
7374
payload = json.dumps(payload_dict)
7475
response = requests.request(
75-
"POST", self.url + "/factuality", headers=headers, data=payload
76+
"POST", self.url + "/factuality", headers=headers, data=payload, timeout=self.timeout
7677
)
7778

7879
# If the request was successful, print the proxies.

0 commit comments

Comments
 (0)