11import os
22
33import requests
4- from typing import Optional
4+ from typing import Optional , Union
55
66from .src .audio import Audio
77from .src .chat import Chat
@@ -30,11 +30,15 @@ class PredictionGuard:
3030 """PredictionGuard provides access the Prediction Guard API."""
3131
3232 def __init__ (
33- self , api_key : Optional [str ] = None , url : Optional [str ] = None
33+ self ,
34+ api_key : Optional [str ] = None ,
35+ url : Optional [str ] = None ,
36+ timeout : Optional [Union [int , float ]] = None
3437 ) -> None :
3538 """
3639 :param api_key: api_key represents PG api key.
3740 :param url: url represents the transport and domain:port
41+ :param timeout: request timeout in seconds.
3842 """
3943
4044 # Get the access api_key.
@@ -56,50 +60,67 @@ def __init__(
5660 url = "https://api.predictionguard.com"
5761 self .url = url
5862
63+ if not timeout :
64+ timeout = os .environ .get ("TIMEOUT" )
65+ if not timeout :
66+ timeout = None
67+ if timeout :
68+ try :
69+ timeout = float (timeout )
70+ except ValueError :
71+ raise ValueError (
72+ "Timeout must be of type integer or float, not %s." % (type (timeout ).__name__ ,)
73+ )
74+ except TypeError :
75+ raise TypeError (
76+ "Timeout should be of type integer or float, not %s." % (type (timeout ).__name__ ,)
77+ )
78+ self .timeout = timeout
79+
5980 # Connect to Prediction Guard and set the access api_key.
6081 self ._connect_client ()
6182
6283 # Pass Prediction Guard class variables to inner classes
63- self .chat : Chat = Chat (self .api_key , self .url )
84+ self .chat : Chat = Chat (self .api_key , self .url , self . timeout )
6485 """Chat generates chat completions based on a conversation history"""
6586
66- self .completions : Completions = Completions (self .api_key , self .url )
87+ self .completions : Completions = Completions (self .api_key , self .url , self . timeout )
6788 """Completions generates text completions based on the provided input"""
6889
69- self .embeddings : Embeddings = Embeddings (self .api_key , self .url )
90+ self .embeddings : Embeddings = Embeddings (self .api_key , self .url , self . timeout )
7091 """Embedding generates chat completions based on a conversation history."""
7192
72- self .audio : Audio = Audio (self .api_key , self .url )
93+ self .audio : Audio = Audio (self .api_key , self .url , self . timeout )
7394 """Audio allows for the transcription of audio files."""
7495
75- self .documents : Documents = Documents (self .api_key , self .url )
96+ self .documents : Documents = Documents (self .api_key , self .url , self . timeout )
7697 """Documents allows you to extract text from various document file types."""
7798
78- self .rerank : Rerank = Rerank (self .api_key , self .url )
99+ self .rerank : Rerank = Rerank (self .api_key , self .url , self . timeout )
79100 """Rerank sorts text inputs by semantic relevance to a specified query."""
80101
81- self .translate : Translate = Translate (self .api_key , self .url )
102+ self .translate : Translate = Translate (self .api_key , self .url , self . timeout )
82103 """Translate converts text from one language to another."""
83104
84- self .factuality : Factuality = Factuality (self .api_key , self .url )
105+ self .factuality : Factuality = Factuality (self .api_key , self .url , self . timeout )
85106 """Factuality checks the factuality of a given text compared to a reference."""
86107
87- self .toxicity : Toxicity = Toxicity (self .api_key , self .url )
108+ self .toxicity : Toxicity = Toxicity (self .api_key , self .url , self . timeout )
88109 """Toxicity checks the toxicity of a given text."""
89110
90- self .pii : Pii = Pii (self .api_key , self .url )
111+ self .pii : Pii = Pii (self .api_key , self .url , self . timeout )
91112 """Pii replaces personal information such as names, SSNs, and emails in a given text."""
92113
93- self .injection : Injection = Injection (self .api_key , self .url )
114+ self .injection : Injection = Injection (self .api_key , self .url , self . timeout )
94115 """Injection detects potential prompt injection attacks in a given prompt."""
95116
96- self .tokenize : Tokenize = Tokenize (self .api_key , self .url )
117+ self .tokenize : Tokenize = Tokenize (self .api_key , self .url , self . timeout )
97118 """Tokenize generates tokens for input text."""
98119
99- self .detokenize : Detokenize = Detokenize (self .api_key , self .url )
120+ self .detokenize : Detokenize = Detokenize (self .api_key , self .url , self . timeout )
100121 """Detokenizes generates text for input tokens."""
101122
102- self .models : Models = Models (self .api_key , self .url )
123+ self .models : Models = Models (self .api_key , self .url , self . timeout )
103124 """Models lists all of the models available in the Prediction Guard API."""
104125
105126 def _connect_client (self ) -> None :
@@ -112,7 +133,7 @@ def _connect_client(self) -> None:
112133 }
113134
114135 # Try listing models to make sure we can connect.
115- response = requests .request ("GET" , self .url + "/completions" , headers = headers )
136+ response = requests .request ("GET" , self .url + "/completions" , headers = headers , timeout = self . timeout )
116137
117138 # If the connection was unsuccessful, raise an exception.
118139 if response .status_code == 200 :
0 commit comments