23 changes: 18 additions & 5 deletions superpipe/clients.py
@@ -1,6 +1,7 @@
 import requests
 import os
 from openai import OpenAI
+from openai import AzureOpenAI
 from anthropic import Anthropic
 from superpipe.models import *
@@ -10,10 +11,17 @@
 openrouter_models = []


-def init_openai(api_key, base_url=None):
-    openai_client = OpenAI(api_key=api_key, base_url=base_url)
-    client_for_model[gpt35] = openai_client
-    client_for_model[gpt4] = openai_client
+def init_openai(api_key, base_url=None, api_version=None):
+    if base_url and api_version:
+        openai_client = AzureOpenAI(api_key=api_key,
+                                    azure_endpoint=base_url,
+                                    api_version=api_version)
+        client_for_model[gpt35] = openai_client
+        client_for_model[gpt4] = openai_client
+    else:
+        openai_client = OpenAI(api_key=api_key)
+        client_for_model[gpt35] = openai_client
+        client_for_model[gpt4] = openai_client


 def init_anthropic(api_key):
@@ -45,8 +53,13 @@ def get_client(model):
     if client_for_model.get(gpt35) is None or \
             client_for_model.get(gpt4) is None:
         api_key = os.getenv("OPENAI_API_KEY")
+        base_url, api_version = None, None
+        if client_for_model.get("OPENAI_API_BASE") is None:
+            base_url = os.getenv("OPENAI_API_BASE")
+        if client_for_model.get("OPENAI_API_VERSION") is None:
+            api_version = os.getenv("OPENAI_API_VERSION")
         if api_key is not None:
-            init_openai(api_key)
+            init_openai(api_key, base_url, api_version)
     if client_for_model.get(claude3_haiku) is None or \
             client_for_model.get(claude3_sonnet) is None or \
             client_for_model.get(claude3_opus) is None:
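What the clients.py change does in practice: `init_openai` now dispatches on whether an Azure endpoint and API version are supplied, and `get_client` picks both up from the environment, so routing the GPT models through Azure OpenAI becomes a configuration-only change. A minimal usage sketch, not part of the diff: the endpoint and version strings are placeholders, and it assumes `get_client` returns the cached client for a model.

```python
import os

# Placeholder credentials -- substitute real values. With only
# OPENAI_API_KEY set, the stock OpenAI client is used instead.
os.environ["OPENAI_API_KEY"] = "..."
os.environ["OPENAI_API_BASE"] = "https://my-resource.openai.azure.com"  # Azure endpoint (placeholder)
os.environ["OPENAI_API_VERSION"] = "2024-02-01"  # Azure API version (placeholder)

from superpipe.clients import get_client
from superpipe.models import gpt4

# With all three variables present, init_openai constructs an AzureOpenAI
# client and registers it for both gpt35 and gpt4.
client = get_client(gpt4)
```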
4 changes: 4 additions & 0 deletions superpipe/llm.py
@@ -193,7 +193,11 @@ def get_structured_llm_response_openai(
         args: CompletionCreateParamsNonStreaming = {}) -> StructuredLLMResponse:
     system = "You are a helpful assistant designed to output JSON."
     updated_args = {**args, "response_format": {"type": "json_object"}}
+
     response = get_llm_response_openai(prompt, model, updated_args, system)
+    if response.error:  # models before 1106 do not support the response_format param
+        response = get_llm_response_openai(prompt, model, args, system)
+
     return StructuredLLMResponse(
         input_tokens=response.input_tokens,
         output_tokens=response.output_tokens,
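The llm.py change is a graceful-degradation pattern: request JSON mode first, and only if the call errors (OpenAI models earlier than the 1106 releases reject `response_format`) reissue it without the parameter. A standalone sketch of the same pattern under the diff's assumptions; the helper name is hypothetical, and `call` stands in for any function with `get_llm_response_openai`'s signature that returns a response exposing an `error` attribute.

```python
def with_json_mode_fallback(call, prompt, model, args, system):
    # First attempt: ask for JSON mode explicitly.
    updated_args = {**args, "response_format": {"type": "json_object"}}
    response = call(prompt, model, updated_args, system)
    # Fallback: models predating JSON mode reject response_format,
    # so retry the identical request without it.
    if response.error:
        response = call(prompt, model, args, system)
    return response
```

One trade-off worth noting: the retry fires on any error, not just the unsupported-parameter one, so a transient failure costs a second request.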