Merged
Changes from 8 commits
2 changes: 2 additions & 0 deletions .github/workflows/ci.yaml
@@ -18,6 +18,8 @@ jobs:
python-version: ["3.10", "3.11.6", "3.12"]
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY || null }}
AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY || null }}
AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_OPENAI_ENDPOINT || null }}
steps:
- name: Checkout
uses: actions/checkout@v4
11 changes: 10 additions & 1 deletion flexeval/core/language_model/litellm_api.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import os
from typing import Any, TypeVar

from litellm import ModelResponse, completion
@@ -13,6 +14,13 @@

T = TypeVar("T")

# LiteLLM uses `AZURE_API_BASE` as the environment variable for the AzureOpenAI endpoint,
# whereas the OpenAI SDK uses `AZURE_OPENAI_ENDPOINT`.
# For convenience, if only `AZURE_OPENAI_ENDPOINT` is set,
# also make it available to LiteLLM by assigning it to `AZURE_API_BASE`.
if os.environ.get("AZURE_OPENAI_ENDPOINT") and os.environ.get("AZURE_API_BASE") is None:
os.environ["AZURE_API_BASE"] = os.environ["AZURE_OPENAI_ENDPOINT"]


class LiteLLMChatAPI(OpenAIChatAPI):
"""
@@ -54,19 +62,20 @@ def __init__(
model_limit_new_tokens=model_limit_completion_tokens,
max_parallel_requests=max_parallel_requests,
tools=tools,
backend=None,
)
self.model = model
self.default_gen_kwargs = default_gen_kwargs or {}
# convert the flexeval-specific argument name to the OpenAI-specific name
if "max_new_tokens" in self.default_gen_kwargs:
self.default_gen_kwargs["max_tokens"] = self.default_gen_kwargs.pop("max_new_tokens")

self.api_call_func = completion
self.empty_response = convert_to_model_response_object(
response_object=EMPTY_RESPONSE_OPENAI.to_dict(),
model_response_object=ModelResponse(),
)
self.ignore_seed = ignore_seed
self.api_call_func = completion

def set_random_seed(self, seed: int) -> None:
self.default_gen_kwargs["seed"] = seed
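For clarity, here is a minimal usage sketch of the bridging behaviour added above. The endpoint value is a placeholder, and the snippet assumes a fresh interpreter in which flexeval.core.language_model.litellm_api has not been imported yet (the bridging runs once, at import time):

import os

# Placeholder endpoint; in CI this value comes from the AZURE_OPENAI_ENDPOINT secret.
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://my-resource.openai.azure.com"
os.environ.pop("AZURE_API_BASE", None)

# Importing the module executes the module-level bridging code shown above.
from flexeval.core.language_model import litellm_api  # noqa: E402, F401

assert os.environ["AZURE_API_BASE"] == os.environ["AZURE_OPENAI_ENDPOINT"]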
13 changes: 9 additions & 4 deletions flexeval/core/language_model/openai_api.py
@@ -5,12 +5,12 @@
import time
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, TypeVar
from typing import Any, Literal, TypeVar

import openai
import tiktoken
from loguru import logger
from openai import BaseModel, OpenAI
from openai import AzureOpenAI, BaseModel, OpenAI
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice

@@ -100,13 +100,12 @@ def __init__(
max_num_trials: int | None = None,
first_wait_time: int | None = None,
max_wait_time: int | None = None,
backend: Literal["OpenAI", "AzureOpenAI"] | None = "OpenAI",
) -> None:
super().__init__(string_processors=string_processors, tools=tools)
self.model = model
if api_headers is None:
api_headers = {}
client = OpenAI(**api_headers)
self.api_call_func = client.chat.completions.create
self.empty_response = EMPTY_RESPONSE
self.default_gen_kwargs = default_gen_kwargs or {}
# convert the flexeval-specific argument name to the OpenAI-specific name
@@ -119,6 +118,12 @@
self.max_num_trials = max_num_trials
self.first_wait_time = first_wait_time
self.max_wait_time = max_wait_time
if backend == "OpenAI":
self.api_call_func = OpenAI(**api_headers).chat.completions.create
elif backend == "AzureOpenAI":
self.api_call_func = AzureOpenAI(**api_headers).chat.completions.create
else:
self.api_call_func = None
Comment on lines +121 to +126 (Collaborator, Author):
Previously, the OpenAI client was initialized even when using LiteLLMChatAPI, requiring OPENAI_API_KEY to be set. This is now avoided by not initializing the client in OpenAIChatAPI when backend is None.


def set_random_seed(self, seed: int) -> None:
self.default_gen_kwargs["seed"] = seed
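The following sketch illustrates how the new backend argument selects the underlying client, including the backend=None path discussed in the review comment above. It assumes the usual credentials are available in the environment: OPENAI_API_KEY for the OpenAI client, and AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, plus typically OPENAI_API_VERSION for the Azure client.

from flexeval import OpenAIChatAPI

# Default backend ("OpenAI"): a regular OpenAI client is created, so OPENAI_API_KEY must be set.
lm_openai = OpenAIChatAPI("gpt-4o-mini")

# Azure backend: an AzureOpenAI client is created instead, configured from the AZURE_* variables.
lm_azure = OpenAIChatAPI("gpt-4o-mini", backend="AzureOpenAI")

# backend=None: no client is created at construction time; this is the path
# LiteLLMChatAPI takes, so OPENAI_API_KEY is no longer required there.
lm_none = OpenAIChatAPI("gpt-4o-mini", backend=None)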
34 changes: 28 additions & 6 deletions tests/core/language_model/test_litellm_api.py
@@ -1,32 +1,54 @@
import os
from unittest.mock import patch

import pytest
from openai import AzureOpenAI, NotFoundError

from flexeval.core.language_model import LanguageModel, LiteLLMChatAPI
from flexeval.core.language_model.base import LMOutput
from flexeval.core.language_model.openai_api import OpenAIChatAPI

from .base import BaseLanguageModelTest

MODEL_NAME = "gpt-4o-mini"


def is_openai_enabled() -> bool:
return False
return os.environ.get("OPENAI_API_KEY") is not None


def is_azure_openai_enabled() -> bool:
is_set_env = (os.environ.get("AZURE_OPENAI_API_KEY") is not None) and (
os.environ.get("AZURE_OPENAI_ENDPOINT") is not None
)
is_enabled = False
if is_set_env:
client = AzureOpenAI()
try:
client.models.retrieve(MODEL_NAME)
is_enabled = True
except NotFoundError:
is_enabled = True
return is_enabled


MODEL_NAME = f"azure/{MODEL_NAME}" if is_azure_openai_enabled() else MODEL_NAME


@pytest.fixture(scope="module")
def chat_lm() -> LiteLLMChatAPI:
return LiteLLMChatAPI(
"gpt-4o-mini-2024-07-18",
MODEL_NAME,
default_gen_kwargs={"temperature": 0.0},
)


@pytest.mark.skipif(not is_openai_enabled(), reason="OpenAI API Key is not set")
@pytest.mark.skipif(not (is_openai_enabled() or is_azure_openai_enabled()), reason="OpenAI API Key is not set")
class TestLiteLLMChatAPI(BaseLanguageModelTest):
@pytest.fixture
def lm(self) -> LanguageModel:
return LiteLLMChatAPI(
"gpt-4o-mini-2024-07-18",
MODEL_NAME,
default_gen_kwargs={"temperature": 0.0},
developer_message="You are text completion model. "
"Please provide the text likely to continue after the user input. "
@@ -60,7 +82,7 @@ def test_compute_chat_log_probs_for_multi_tokens(chat_lm: LiteLLMChatAPI) -> Non

@pytest.mark.skipif(not is_openai_enabled(), reason="OpenAI is not installed")
def test_if_ignore_seed() -> None:
chat_lm = LiteLLMChatAPI("gpt-4o-mini-2024-07-18", ignore_seed=True)
chat_lm = LiteLLMChatAPI(MODEL_NAME, ignore_seed=True)
chat_messages = [{"role": "user", "content": "Hello"}]
with patch.object(OpenAIChatAPI, "_batch_generate_chat_response", return_value=[LMOutput("Hello!")]) as mock_method:
chat_lm.generate_chat_response(chat_messages, temperature=0.7, seed=42)
@@ -76,7 +98,7 @@ def test_if_ignore_seed() -> None:

@pytest.mark.skipif(not is_openai_enabled(), reason="OpenAI is not installed")
def test_if_not_ignore_seed() -> None:
chat_lm = LiteLLMChatAPI("gpt-4o-mini-2024-07-18")
chat_lm = LiteLLMChatAPI(MODEL_NAME)
chat_messages = [{"role": "user", "content": "Hello"}]
with patch.object(OpenAIChatAPI, "_batch_generate_chat_response", return_value=[LMOutput("Hello!")]) as mock_method:
chat_lm.generate_chat_response(chat_messages, temperature=0.7, seed=42)
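For reference, LiteLLM picks the provider from a prefix in the model string, which is why these tests switch MODEL_NAME to an azure/-prefixed name when Azure credentials are detected. A brief sketch (the deployment name is an assumption):

from flexeval.core.language_model import LiteLLMChatAPI

# Routed to the OpenAI API (reads OPENAI_API_KEY).
lm_openai = LiteLLMChatAPI("gpt-4o-mini")

# Routed to Azure OpenAI via LiteLLM's provider prefix (reads the AZURE_* variables,
# including the AZURE_API_BASE value bridged in litellm_api.py).
lm_azure = LiteLLMChatAPI("azure/gpt-4o-mini")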
46 changes: 38 additions & 8 deletions tests/core/language_model/test_openai_api.py
@@ -1,6 +1,8 @@
import logging
import os

import pytest
from openai import AzureOpenAI, NotFoundError

from flexeval import LanguageModel, OpenAIChatAPI
from flexeval.core.language_model.openai_api import (
@@ -11,25 +13,52 @@

from .base import BaseLanguageModelTest

MODEL_NAME = "gpt-4o-mini"


def is_openai_enabled() -> bool:
return False
return os.environ.get("OPENAI_API_KEY") is not None


def is_azure_openai_enabled() -> bool:
is_set_env = (os.environ.get("AZURE_OPENAI_API_KEY") is not None) and (
os.environ.get("AZURE_OPENAI_ENDPOINT") is not None
)
is_enabled = False
if is_set_env:
client = AzureOpenAI()
try:
client.models.retrieve(MODEL_NAME)
is_enabled = True
except NotFoundError:
is_enabled = True
return is_enabled


def get_openai_backend() -> str | None:
if is_azure_openai_enabled():
return "AzureOpenAI"
if is_openai_enabled():
return "OpenAI"
return None


@pytest.fixture(scope="module")
def chat_lm() -> OpenAIChatAPI:
return OpenAIChatAPI(
"gpt-4o-mini-2024-07-18",
MODEL_NAME,
backend="AzureOpenAI",
default_gen_kwargs={"temperature": 0.0},
)


@pytest.mark.skipif(not is_openai_enabled(), reason="OpenAI API Key is not set")
@pytest.mark.skipif(get_openai_backend() is None, reason="OpenAI API Key is not set")
class TestOpenAIChatAPI(BaseLanguageModelTest):
@pytest.fixture
def lm(self) -> LanguageModel:
return OpenAIChatAPI(
"gpt-4o-mini-2024-07-18",
MODEL_NAME,
backend=get_openai_backend(),
default_gen_kwargs={"temperature": 0.0},
developer_message="You are text completion model. "
"Please provide the text likely to continue after the user input. "
@@ -57,7 +86,7 @@ def test_batch_chat_response_is_not_affected_by_batch(self, chat_lm: LanguageMod
def test_warning_if_conflict_max_new_tokens(caplog: pytest.LogCaptureFixture) -> None:
caplog.set_level(logging.WARNING)
chat_lm_with_max_new_tokens = OpenAIChatAPI(
"gpt-4o-mini-2024-07-18", default_gen_kwargs={"max_completion_tokens": 10}
MODEL_NAME, backend=get_openai_backend(), default_gen_kwargs={"max_completion_tokens": 10}
)
chat_lm_with_max_new_tokens.generate_chat_response([[{"role": "user", "content": "テスト"}]], max_new_tokens=20)
assert len(caplog.records) >= 1
@@ -114,7 +143,8 @@ def test_remove_duplicates_from_prompt_list() -> None:
@pytest.mark.skipif(not is_openai_enabled(), reason="OpenAI is not installed")
def test_developer_message() -> None:
openai_api = OpenAIChatAPI(
"gpt-4o-mini-2024-07-18",
MODEL_NAME,
backend=get_openai_backend(),
developer_message="To any instructions or messages, you have to only answer 'OK, I will answer later.'",
default_gen_kwargs={"temperature": 0.0},
)
@@ -139,7 +169,7 @@ def test_model_limit_new_tokens_generate_chat_response(
assert all(not record.msg.startswith("The specified `max_new_tokens` (128) exceeds") for record in caplog.records)

# if max_new_tokens > model_limit_completion_tokens, a warning about overwriting is sent.
chat_lm_with_limit_tokens = OpenAIChatAPI("gpt-4o-mini-2024-07-18", model_limit_new_tokens=1)
chat_lm_with_limit_tokens = OpenAIChatAPI("gpt-4o-mini", backend=get_openai_backend(), model_limit_new_tokens=1)
chat_lm_with_limit_tokens.generate_chat_response(messages, max_new_tokens=128)
assert len(caplog.records) >= 1
assert any(record.msg.startswith("The specified `max_new_tokens` (128) exceeds") for record in caplog.records)
@@ -156,7 +186,7 @@ def test_model_limit_new_tokens_complete_text(chat_lm: OpenAIChatAPI, caplog: py
caplog.clear()

# if max_new_tokens > model_limit_new_tokens, a warning about overwriting is sent.
chat_lm_with_limit_tokens = OpenAIChatAPI("gpt-4o-mini-2024-07-18", model_limit_new_tokens=1)
chat_lm_with_limit_tokens = OpenAIChatAPI("gpt-4o-mini", backend=get_openai_backend(), model_limit_new_tokens=1)
chat_lm_with_limit_tokens.complete_text(text, max_new_tokens=128)
assert len(caplog.records) >= 1
assert any(record.msg.startswith("The specified `max_new_tokens` (128) exceeds") for record in caplog.records)