diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 86e81d28..00b7b102 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -18,6 +18,8 @@ jobs:
         python-version: ["3.10", "3.11.6", "3.12"]
     env:
       OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY || null }}
+      AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY || null }}
+      AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_OPENAI_ENDPOINT || null }}
     steps:
       - name: Checkout
         uses: actions/checkout@v4
diff --git a/flexeval/core/language_model/litellm_api.py b/flexeval/core/language_model/litellm_api.py
index 15301e63..37f8ac55 100644
--- a/flexeval/core/language_model/litellm_api.py
+++ b/flexeval/core/language_model/litellm_api.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import os
 from typing import Any, TypeVar
 
 from litellm import ModelResponse, completion
@@ -13,6 +14,13 @@
 
 T = TypeVar("T")
 
+# LiteLLM uses `AZURE_API_BASE` as the environment variable for the AzureOpenAI endpoint,
+# whereas the OpenAI SDK uses `AZURE_OPENAI_ENDPOINT`.
+# For convenience, if only `AZURE_OPENAI_ENDPOINT` is set,
+# also make it available to LiteLLM by assigning it to `AZURE_API_BASE`.
+if os.environ.get("AZURE_OPENAI_ENDPOINT") and os.environ.get("AZURE_API_BASE") is None:
+    os.environ["AZURE_API_BASE"] = os.environ["AZURE_OPENAI_ENDPOINT"]
+
 
 class LiteLLMChatAPI(OpenAIChatAPI):
     """
@@ -54,6 +62,7 @@ def __init__(
             model_limit_new_tokens=model_limit_completion_tokens,
             max_parallel_requests=max_parallel_requests,
             tools=tools,
+            backend=None,
         )
         self.model = model
         self.default_gen_kwargs = default_gen_kwargs or {}
@@ -61,12 +70,12 @@ def __init__(
         if "max_new_tokens" in self.default_gen_kwargs:
             self.default_gen_kwargs["max_tokens"] = self.default_gen_kwargs.pop("max_new_tokens")
 
-        self.api_call_func = completion
         self.empty_response = convert_to_model_response_object(
             response_object=EMPTY_RESPONSE_OPENAI.to_dict(),
             model_response_object=ModelResponse(),
         )
         self.ignore_seed = ignore_seed
+        self.api_call_func = completion
 
     def set_random_seed(self, seed: int) -> None:
         self.default_gen_kwargs["seed"] = seed
diff --git a/flexeval/core/language_model/openai_api.py b/flexeval/core/language_model/openai_api.py
index b7b8ad51..bbe10276 100644
--- a/flexeval/core/language_model/openai_api.py
+++ b/flexeval/core/language_model/openai_api.py
@@ -5,12 +5,12 @@
 import time
 from collections.abc import Callable
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Any, TypeVar
+from typing import Any, Literal, TypeVar
 
 import openai
 import tiktoken
 from loguru import logger
-from openai import BaseModel, OpenAI
+from openai import AzureOpenAI, BaseModel, OpenAI
 from openai.types.chat import ChatCompletion, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
 
@@ -100,13 +100,12 @@ def __init__(
         max_num_trials: int | None = None,
         first_wait_time: int | None = None,
         max_wait_time: int | None = None,
+        backend: Literal["OpenAI", "AzureOpenAI"] | None = "OpenAI",
     ) -> None:
         super().__init__(string_processors=string_processors, tools=tools)
         self.model = model
         if api_headers is None:
             api_headers = {}
-        client = OpenAI(**api_headers)
-        self.api_call_func = client.chat.completions.create
         self.empty_response = EMPTY_RESPONSE
         self.default_gen_kwargs = default_gen_kwargs or {}
         # convert the flexeval-specific argument name to the OpenAI-specific name
@@ -119,6 +118,12 @@ def __init__(
         self.max_num_trials = max_num_trials
         self.first_wait_time = first_wait_time
         self.max_wait_time = max_wait_time
+        if backend == "OpenAI":
+            self.api_call_func = OpenAI(**api_headers).chat.completions.create
+        elif backend == "AzureOpenAI":
+            self.api_call_func = AzureOpenAI(**api_headers).chat.completions.create
+        else:
+            self.api_call_func = None
 
     def set_random_seed(self, seed: int) -> None:
         self.default_gen_kwargs["seed"] = seed
diff --git a/tests/core/language_model/test_litellm_api.py b/tests/core/language_model/test_litellm_api.py
index 67295ef3..8786f801 100644
--- a/tests/core/language_model/test_litellm_api.py
+++ b/tests/core/language_model/test_litellm_api.py
@@ -1,6 +1,8 @@
+import os
 from unittest.mock import patch
 
 import pytest
+from openai import AzureOpenAI, NotFoundError
 
 from flexeval.core.language_model import LanguageModel, LiteLLMChatAPI
 from flexeval.core.language_model.base import LMOutput
@@ -8,25 +10,46 @@
 
 from .base import BaseLanguageModelTest
 
+MODEL_NAME = "gpt-4o-mini"
+
 
 def is_openai_enabled() -> bool:
-    return False
+    return os.environ.get("OPENAI_API_KEY") is not None
+
+
+def is_azure_openai_enabled() -> bool:
+    is_set_env = (os.environ.get("AZURE_OPENAI_API_KEY") is not None) and (
+        os.environ.get("AZURE_OPENAI_ENDPOINT") is not None
+    )
+    is_enabled = False
+    if is_set_env:
+        os.environ["OPENAI_API_VERSION"] = "2024-12-01-preview"
+        client = AzureOpenAI()
+        try:
+            client.models.retrieve(MODEL_NAME)
+            is_enabled = True
+        except NotFoundError:
+            is_enabled = True
+    return is_enabled
+
+
+MODEL_NAME = f"azure/{MODEL_NAME}" if is_azure_openai_enabled() else MODEL_NAME
 
 
 @pytest.fixture(scope="module")
 def chat_lm() -> LiteLLMChatAPI:
     return LiteLLMChatAPI(
-        "gpt-4o-mini-2024-07-18",
+        MODEL_NAME,
         default_gen_kwargs={"temperature": 0.0},
     )
 
 
-@pytest.mark.skipif(not is_openai_enabled(), reason="OpenAI API Key is not set")
+@pytest.mark.skipif(not (is_openai_enabled() or is_azure_openai_enabled()), reason="OpenAI API Key is not set")
 class TestLiteLLMChatAPI(BaseLanguageModelTest):
     @pytest.fixture
     def lm(self) -> LanguageModel:
         return LiteLLMChatAPI(
-            "gpt-4o-mini-2024-07-18",
+            MODEL_NAME,
            default_gen_kwargs={"temperature": 0.0},
            developer_message="You are text completion model. "
            "Please provide the text likely to continue after the user input. "
@@ -60,7 +83,7 @@ def test_compute_chat_log_probs_for_multi_tokens(chat_lm: LiteLLMChatAPI) -> Non
 
 @pytest.mark.skipif(not is_openai_enabled(), reason="OpenAI is not installed")
 def test_if_ignore_seed() -> None:
-    chat_lm = LiteLLMChatAPI("gpt-4o-mini-2024-07-18", ignore_seed=True)
+    chat_lm = LiteLLMChatAPI(MODEL_NAME, ignore_seed=True)
     chat_messages = [{"role": "user", "content": "Hello"}]
     with patch.object(OpenAIChatAPI, "_batch_generate_chat_response", return_value=[LMOutput("Hello!")]) as mock_method:
         chat_lm.generate_chat_response(chat_messages, temperature=0.7, seed=42)
@@ -76,7 +99,7 @@ def test_if_ignore_seed() -> None:
 
 @pytest.mark.skipif(not is_openai_enabled(), reason="OpenAI is not installed")
 def test_if_not_ignore_seed() -> None:
-    chat_lm = LiteLLMChatAPI("gpt-4o-mini-2024-07-18")
+    chat_lm = LiteLLMChatAPI(MODEL_NAME)
     chat_messages = [{"role": "user", "content": "Hello"}]
     with patch.object(OpenAIChatAPI, "_batch_generate_chat_response", return_value=[LMOutput("Hello!")]) as mock_method:
         chat_lm.generate_chat_response(chat_messages, temperature=0.7, seed=42)
diff --git a/tests/core/language_model/test_openai_api.py b/tests/core/language_model/test_openai_api.py
index cd21e438..5872a10e 100644
--- a/tests/core/language_model/test_openai_api.py
+++ b/tests/core/language_model/test_openai_api.py
@@ -1,6 +1,8 @@
 import logging
+import os
 
 import pytest
+from openai import AzureOpenAI, NotFoundError
 
 from flexeval import LanguageModel, OpenAIChatAPI
 from flexeval.core.language_model.openai_api import (
@@ -11,25 +13,53 @@
 
 from .base import BaseLanguageModelTest
 
+MODEL_NAME = "gpt-4o-mini"
+
 
 def is_openai_enabled() -> bool:
-    return False
+    return os.environ.get("OPENAI_API_KEY") is not None
+
+
+def is_azure_openai_enabled() -> bool:
+    is_set_env = (os.environ.get("AZURE_OPENAI_API_KEY") is not None) and (
+        os.environ.get("AZURE_OPENAI_ENDPOINT") is not None
+    )
+    is_enabled = False
+    if is_set_env:
+        os.environ["OPENAI_API_VERSION"] = "2024-12-01-preview"
+        client = AzureOpenAI()
+        try:
+            client.models.retrieve(MODEL_NAME)
+            is_enabled = True
+        except NotFoundError:
+            is_enabled = True
+    return is_enabled
+
+
+def get_openai_backend() -> str | None:
+    if is_azure_openai_enabled():
+        return "AzureOpenAI"
+    if is_openai_enabled():
+        return "OpenAI"
+    return None
 
 
 @pytest.fixture(scope="module")
 def chat_lm() -> OpenAIChatAPI:
     return OpenAIChatAPI(
-        "gpt-4o-mini-2024-07-18",
+        MODEL_NAME,
+        backend=get_openai_backend(),
         default_gen_kwargs={"temperature": 0.0},
     )
 
 
-@pytest.mark.skipif(not is_openai_enabled(), reason="OpenAI API Key is not set")
+@pytest.mark.skipif(get_openai_backend() is None, reason="OpenAI API Key is not set")
 class TestOpenAIChatAPI(BaseLanguageModelTest):
     @pytest.fixture
     def lm(self) -> LanguageModel:
         return OpenAIChatAPI(
-            "gpt-4o-mini-2024-07-18",
+            MODEL_NAME,
+            backend=get_openai_backend(),
            default_gen_kwargs={"temperature": 0.0},
            developer_message="You are text completion model. "
            "Please provide the text likely to continue after the user input. "
@@ -57,7 +87,7 @@ def test_batch_chat_response_is_not_affected_by_batch(self, chat_lm: LanguageMod
 def test_warning_if_conflict_max_new_tokens(caplog: pytest.LogCaptureFixture) -> None:
     caplog.set_level(logging.WARNING)
     chat_lm_with_max_new_tokens = OpenAIChatAPI(
-        "gpt-4o-mini-2024-07-18", default_gen_kwargs={"max_completion_tokens": 10}
+        MODEL_NAME, backend=get_openai_backend(), default_gen_kwargs={"max_completion_tokens": 10}
     )
     chat_lm_with_max_new_tokens.generate_chat_response([[{"role": "user", "content": "ใƒ†ใ‚นใƒˆ"}]], max_new_tokens=20)
     assert len(caplog.records) >= 1
@@ -114,7 +144,8 @@ def test_remove_duplicates_from_prompt_list() -> None:
 @pytest.mark.skipif(not is_openai_enabled(), reason="OpenAI is not installed")
 def test_developer_message() -> None:
     openai_api = OpenAIChatAPI(
-        "gpt-4o-mini-2024-07-18",
+        MODEL_NAME,
+        backend=get_openai_backend(),
         developer_message="To any instructions or messages, you have to only answer 'OK, I will answer later.'",
         default_gen_kwargs={"temperature": 0.0},
     )
@@ -139,7 +170,7 @@ def test_model_limit_new_tokens_generate_chat_response(
     assert all(not record.msg.startswith("The specified `max_new_tokens` (128) exceeds") for record in caplog.records)
 
     # if max_new_tokens > model_limit_completion_tokens, a warning about overwriting is sent.
-    chat_lm_with_limit_tokens = OpenAIChatAPI("gpt-4o-mini-2024-07-18", model_limit_new_tokens=1)
+    chat_lm_with_limit_tokens = OpenAIChatAPI("gpt-4o-mini", backend=get_openai_backend(), model_limit_new_tokens=1)
     chat_lm_with_limit_tokens.generate_chat_response(messages, max_new_tokens=128)
     assert len(caplog.records) >= 1
     assert any(record.msg.startswith("The specified `max_new_tokens` (128) exceeds") for record in caplog.records)
@@ -156,7 +187,7 @@ def test_model_limit_new_tokens_complete_text(chat_lm: OpenAIChatAPI, caplog: py
     caplog.clear()
 
     # if max_new_tokens > model_limit_new_tokens, a warning about overwriting is sent.
-    chat_lm_with_limit_tokens = OpenAIChatAPI("gpt-4o-mini-2024-07-18", model_limit_new_tokens=1)
+    chat_lm_with_limit_tokens = OpenAIChatAPI("gpt-4o-mini", backend=get_openai_backend(), model_limit_new_tokens=1)
     chat_lm_with_limit_tokens.complete_text(text, max_new_tokens=128)
     assert len(caplog.records) >= 1
     assert any(record.msg.startswith("The specified `max_new_tokens` (128) exceeds") for record in caplog.records)