
Commit 7f97c57

Merge pull request #37 from predictionguard/jacob/vllm-support
Adding support for new parameters in vLLM
2 parents: 1c0d720 + 2f13742

File tree: 16 files changed, +478 −40 lines changed

fixtures/test_audio.wav

344 KB
Binary file not shown.

fixtures/test_csv.csv

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+John,Doe,120 jefferson st.,Riverside, NJ, 08075
+Jack,McGinnis,220 hobo Av.,Phila, PA,09119
+"John ""Da Man""",Repici,120 Jefferson St.,Riverside, NJ,08075
+Stephen,Tyler,"7452 Terrace ""At the Plaza"" road",SomeTown,SD, 91234
+,Blankman,,SomeTown, SD, 00298
+"Joan ""the bone"", Anne",Jet,"9th, at Terrace plc",Desert City,CO,00123

fixtures/test_pdf.pdf

18.4 KB
Binary file not shown.

predictionguard/client.py

Lines changed: 11 additions & 3 deletions
@@ -3,8 +3,10 @@
 import requests
 from typing import Optional

+from .src.audio import Audio
 from .src.chat import Chat
 from .src.completions import Completions
+from .src.documents import Documents
 from .src.embeddings import Embeddings
 from .src.rerank import Rerank
 from .src.tokenize import Tokenize
@@ -17,9 +19,9 @@
 from .version import __version__

 __all__ = [
-    "PredictionGuard", "Chat", "Completions", "Embeddings", "Rerank",
-    "Tokenize", "Translate", "Factuality", "Toxicity", "Pii", "Injection",
-    "Models"
+    "PredictionGuard", "Chat", "Completions", "Embeddings",
+    "Audio", "Documents", "Rerank", "Tokenize", "Translate",
+    "Factuality", "Toxicity", "Pii", "Injection", "Models"
 ]

 class PredictionGuard:
@@ -65,6 +67,12 @@ def __init__(
         self.embeddings: Embeddings = Embeddings(self.api_key, self.url)
         """Embedding generates chat completions based on a conversation history."""

+        self.audio: Audio = Audio(self.api_key, self.url)
+        """Audio allows for the transcription of audio files."""
+
+        self.documents: Documents = Documents(self.api_key, self.url)
+        """Documents allows you to extract text from various document file types."""
+
         self.rerank: Rerank = Rerank(self.api_key, self.url)
         """Rerank sorts text inputs by semantic relevance to a specified query."""
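With these registrations in place, the new services hang directly off a PredictionGuard client. A minimal sketch (the audio call mirrors the usage shown in audio.py below; the Documents API surface is not visible in this excerpt, so only audio is exercised):

import os
from predictionguard import PredictionGuard

# Set your Prediction Guard token as an environment variable.
os.environ["PREDICTIONGUARD_API_KEY"] = "<api key>"

client = PredictionGuard()

# New in this commit: client.audio and client.documents sit alongside
# chat, completions, embeddings, rerank, and the other services.
result = client.audio.transcriptions.create(
    model="whisper-3-large-instruct", file="sample_audio.wav"
)
print(result)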

predictionguard/src/audio.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+import json
+
+import requests
+from typing import Any, Dict, Optional
+
+from ..version import __version__
+
+
+class Audio:
+    """Audio generates a response based on audio data.
+
+    Usage::
+
+        import os
+        import json
+
+        from predictionguard import PredictionGuard
+
+        # Set your Prediction Guard token as an environment variable.
+        os.environ["PREDICTIONGUARD_API_KEY"] = "<api key>"
+
+        client = PredictionGuard()
+
+        result = client.audio.transcriptions.create(
+            model="whisper-3-large-instruct", file="sample_audio.wav"
+        )
+
+        print(json.dumps(result, sort_keys=True, indent=4, separators=(",", ": ")))
+    """
+
+    def __init__(self, api_key, url):
+        self.api_key = api_key
+        self.url = url
+
+        self.transcriptions: AudioTranscriptions = AudioTranscriptions(self.api_key, self.url)
+
+class AudioTranscriptions:
+    def __init__(self, api_key, url):
+        self.api_key = api_key
+        self.url = url
+
+    def create(
+        self,
+        model: str,
+        file: str
+    ) -> Dict[str, Any]:
+        """
+        Creates an audio transcription request to the Prediction Guard /audio/transcriptions API.
+
+        :param model: The model to use.
+        :param file: Path to the audio file to be transcribed.
+        :return: A dictionary containing the transcribed text.
+        """
+
+        # Bundle the parameters for a single
+        # call to _transcribe_audio
+        args = (model, file)
+
+        # Run _transcribe_audio
+        choices = self._transcribe_audio(*args)
+        return choices
+
+    def _transcribe_audio(self, model, file):
+        """
+        Function to transcribe an audio file.
+        """
+
+        headers = {
+            "Authorization": "Bearer " + self.api_key,
+            "User-Agent": "Prediction Guard Python Client: " + __version__,
+        }
+
+        with open(file, "rb") as audio_file:
+            files = {"file": (file, audio_file, "audio/wav")}
+            data = {"model": model}
+
+            response = requests.request(
+                "POST", self.url + "/audio/transcriptions", headers=headers, files=files, data=data
+            )
+
+        # If the request was successful, return the transcription.
+        if response.status_code == 200:
+            ret = response.json()
+            return ret
+        elif response.status_code == 429:
+            raise ValueError(
+                "Could not connect to Prediction Guard API. "
+                "Too many requests, rate limit or quota exceeded."
+            )
+        else:
+            # Check if there is a json body in the response. Read that in,
+            # pull the error field out of the json body, and raise an exception.
+            err = ""
+            try:
+                err = response.json()["error"]
+            except Exception:
+                pass
+            raise ValueError("Could not transcribe the audio file. " + err)

predictionguard/src/chat.py

Lines changed: 57 additions & 8 deletions
@@ -44,7 +44,7 @@ class Chat:
             {
                 "role": "user",
                 "content": "Haha. Good one."
-            },
+            }
         ]

         result = client.chat.completions.create(
@@ -69,15 +69,36 @@ def __init__(self, api_key, url):
     def create(
         self,
         model: str,
-        messages: Union[str, List[Dict[str, Any]]],
+        messages: Union[
+            str, List[
+                Dict[str, Any]
+            ]
+        ],
         input: Optional[Dict[str, Any]] = None,
         output: Optional[Dict[str, Any]] = None,
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[
+            Dict[str, int]
+        ] = None,
         max_completion_tokens: Optional[int] = 100,
         max_tokens: Optional[int] = None,
+        parallel_tool_calls: Optional[bool] = None,
+        presence_penalty: Optional[float] = None,
+        stop: Optional[
+            Union[
+                str, List[str]
+            ]
+        ] = None,
+        stream: Optional[bool] = False,
         temperature: Optional[float] = 1.0,
+        tool_choice: Optional[Union[
+            str, Dict[
+                str, Dict[str, str]
+            ]
+        ]] = "none",
+        tools: Optional[List[Dict[str, Union[str, Dict[str, str]]]]] = None,
         top_p: Optional[float] = 0.99,
         top_k: Optional[float] = 50,
-        stream: Optional[bool] = False,
     ) -> Dict[str, Any]:
         """
         Creates a chat request for the Prediction Guard /chat API.
@@ -86,11 +107,18 @@ def create(
         :param messages: The content of the call, an array of dictionaries containing a role and content.
         :param input: A dictionary containing the PII and injection arguments.
         :param output: A dictionary containing the consistency, factuality, and toxicity arguments.
+        :param frequency_penalty: The frequency penalty to use.
+        :param logit_bias: The logit bias to use.
         :param max_completion_tokens: The maximum amount of tokens the model should return.
+        :param parallel_tool_calls: Whether to allow parallel tool calls.
+        :param presence_penalty: The presence penalty to use.
+        :param stop: The completion stopping criteria.
+        :param stream: Option to stream the API response.
         :param temperature: The consistency of the model responses to the same prompt. The higher the more consistent.
+        :param tool_choice: Controls which (if any) tool the model should use.
+        :param tools: A list of tools the model may call.
         :param top_p: The sampling for the model to use.
         :param top_k: The Top-K sampling for the model to use.
-        :param stream: Option to stream the API response
         :return: A dictionary containing the chat response.
         """

@@ -110,11 +138,18 @@ def create(
             messages,
             input,
             output,
+            frequency_penalty,
+            logit_bias,
             max_completion_tokens,
+            parallel_tool_calls,
+            presence_penalty,
+            stop,
+            stream,
             temperature,
+            tool_choice,
+            tools,
             top_p,
-            top_k,
-            stream,
+            top_k
         )

         # Run _generate_chat
@@ -128,11 +163,18 @@ def _generate_chat(
         messages,
         input,
         output,
+        frequency_penalty,
+        logit_bias,
         max_completion_tokens,
+        parallel_tool_calls,
+        presence_penalty,
+        stop,
+        stream,
         temperature,
+        tool_choice,
+        tools,
         top_p,
         top_k,
-        stream,
    ):
        """
        Function to generate a single chat response.
@@ -257,11 +299,18 @@ def stream_generator(url, headers, payload, stream):
         payload_dict = {
             "model": model,
             "messages": messages,
+            "frequency_penalty": frequency_penalty,
+            "logit_bias": logit_bias,
             "max_completion_tokens": max_completion_tokens,
+            "parallel_tool_calls": parallel_tool_calls,
+            "presence_penalty": presence_penalty,
+            "stop": stop,
+            "stream": stream,
             "temperature": temperature,
+            "tool_choice": tool_choice,
+            "tools": tools,
             "top_p": top_p,
             "top_k": top_k,
-            "stream": stream,
         }

         if input:
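Taken together, the new vLLM-backed sampling and tool parameters pass straight through create into the request payload. A hedged sketch of calling them (the model name and parameter values are illustrative, not prescribed by this commit):

import os
from predictionguard import PredictionGuard

# Set your Prediction Guard token as an environment variable.
os.environ["PREDICTIONGUARD_API_KEY"] = "<api key>"

client = PredictionGuard()

# Illustrative values for the parameters added in this commit.
result = client.chat.completions.create(
    model="Hermes-3-Llama-3.1-8B",  # illustrative model name
    messages=[{"role": "user", "content": "Tell me a joke."}],
    frequency_penalty=0.5,   # discourage verbatim repetition
    presence_penalty=0.3,    # nudge the model toward new topics
    stop=["\n\n"],           # halt generation at the first blank line
    temperature=0.7,
)
print(result["choices"][0]["message"]["content"])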
