Skip to content

Commit

Permalink
fix: Fix for byod without data issue (#1355)
Browse files Browse the repository at this point in the history
Co-authored-by: Pavan Kumar <v-kupavan.microsoft.com>
  • Loading branch information
Pavan-Microsoft authored Sep 27, 2024
1 parent 9e8da75 commit 6f03aab
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 20 deletions.
9 changes: 0 additions & 9 deletions code/backend/batch/utilities/helpers/env_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,15 +254,6 @@ def __load_config(self, **kwargs) -> None:

self.PROMPT_FLOW_DEPLOYMENT_NAME = os.getenv("PROMPT_FLOW_DEPLOYMENT_NAME", "")

def should_use_data(self) -> bool:
if (
self.AZURE_SEARCH_SERVICE
and self.AZURE_SEARCH_INDEX
and (self.AZURE_SEARCH_KEY or self.AZURE_AUTH_TYPE == "rbac")
):
return True
return False

def is_chat_model(self):
if "gpt-4" in self.AZURE_OPENAI_MODEL_NAME.lower():
return True
Expand Down
17 changes: 16 additions & 1 deletion code/create_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dotenv import load_dotenv
from urllib.parse import quote
from backend.batch.utilities.helpers.env_helper import EnvHelper
from backend.batch.utilities.helpers.azure_search_helper import AzureSearchHelper
from backend.batch.utilities.helpers.orchestrator_helper import Orchestrator
from backend.batch.utilities.helpers.config.config_helper import ConfigHelper
from backend.batch.utilities.helpers.config.conversation_flow import ConversationFlow
Expand Down Expand Up @@ -69,6 +70,19 @@ def get_citations(citation_list):
return citations_dict


def should_use_data(
env_helper: EnvHelper, azure_search_helper: AzureSearchHelper
) -> bool:
if (
env_helper.AZURE_SEARCH_SERVICE
and env_helper.AZURE_SEARCH_INDEX
and (env_helper.AZURE_SEARCH_KEY or env_helper.AZURE_AUTH_TYPE == "rbac")
and not azure_search_helper._index_not_exists(env_helper.AZURE_SEARCH_INDEX)
):
return True
return False


def stream_with_data(response: Stream[ChatCompletionChunk]):
"""This function streams the response from Azure OpenAI with data."""
response_obj = {
Expand Down Expand Up @@ -371,6 +385,7 @@ def create_app():

app = Flask(__name__)
env_helper: EnvHelper = EnvHelper()
azure_search_helper: AzureSearchHelper = AzureSearchHelper()

logger.debug("Starting web app")

Expand All @@ -385,7 +400,7 @@ def health():

def conversation_azure_byod():
try:
if env_helper.should_use_data():
if should_use_data(env_helper, azure_search_helper):
return conversation_with_data(request, env_helper)
else:
return conversation_without_data(request, env_helper)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,20 @@ def setup_default_mocking(httpserver: HTTPServer, app_config: AppConfig):
httpserver.check()


@patch(
"backend.batch.utilities.search.azure_search_handler.AzureSearchHelper._index_not_exists"
)
@patch(
"backend.batch.utilities.helpers.config.config_helper.ConfigHelper.get_active_config_or_default"
)
def test_azure_byod_responds_successfully_when_streaming(
get_active_config_or_default_mock,
index_not_exists_mock,
app_url: str,
app_config: AppConfig,
):
get_active_config_or_default_mock.return_value.prompts.conversational_flow = "byod"

index_not_exists_mock.return_value = False
# when
response = requests.post(f"{app_url}{path}", json=body)

Expand Down Expand Up @@ -92,17 +96,21 @@ def test_azure_byod_responds_successfully_when_streaming(
}


@patch(
"backend.batch.utilities.search.azure_search_handler.AzureSearchHelper._index_not_exists"
)
@patch(
"backend.batch.utilities.helpers.config.config_helper.ConfigHelper.get_active_config_or_default"
)
def test_post_makes_correct_call_to_azure_openai(
get_active_config_or_default_mock,
index_not_exists_mock,
app_url: str,
app_config: AppConfig,
httpserver: HTTPServer,
):
get_active_config_or_default_mock.return_value.prompts.conversational_flow = "byod"

index_not_exists_mock.return_value = False
# when
requests.post(f"{app_url}{path}", json=body)

Expand Down
48 changes: 40 additions & 8 deletions code/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def env_helper_mock():
)
env_helper.SHOULD_STREAM = True
env_helper.is_auth_type_keys.return_value = True
env_helper.should_use_data.return_value = True
env_helper.CONVERSATION_FLOW = ConversationFlow.CUSTOM.value

yield env_helper
Expand Down Expand Up @@ -599,6 +598,9 @@ def setup_method(self):
),
]

@patch(
"backend.batch.utilities.search.azure_search_handler.AzureSearchHelper._index_not_exists"
)
@patch("create_app.AzureOpenAI")
@patch(
"backend.batch.utilities.helpers.config.config_helper.ConfigHelper.get_active_config_or_default"
Expand All @@ -611,6 +613,7 @@ def test_conversation_azure_byod_returns_correct_response_when_streaming_with_da
generate_container_sas_mock: MagicMock,
get_active_config_or_default_mock,
azure_openai_mock: MagicMock,
index_not_exists_mock,
env_helper_mock: MagicMock,
client: FlaskClient,
):
Expand All @@ -625,6 +628,7 @@ def test_conversation_azure_byod_returns_correct_response_when_streaming_with_da
"byod"
)
generate_container_sas_mock.return_value = "mock-sas"
index_not_exists_mock.return_value = False

# when
response = client.post(
Expand Down Expand Up @@ -694,6 +698,9 @@ def test_conversation_azure_byod_returns_correct_response_when_streaming_with_da
},
)

@patch(
"backend.batch.utilities.search.azure_search_handler.AzureSearchHelper._index_not_exists"
)
@patch("create_app.AzureOpenAI")
@patch(
"backend.batch.utilities.helpers.config.config_helper.ConfigHelper.get_active_config_or_default"
Expand All @@ -706,6 +713,7 @@ def test_conversation_azure_byod_returns_correct_response_when_streaming_with_da
generate_container_sas_mock: MagicMock,
get_active_config_or_default_mock,
azure_openai_mock: MagicMock,
index_not_exists_mock,
env_helper_mock: MagicMock,
client: FlaskClient,
):
Expand All @@ -720,6 +728,7 @@ def test_conversation_azure_byod_returns_correct_response_when_streaming_with_da
openai_client_mock.chat.completions.create.return_value = (
self.mock_streamed_response
)
index_not_exists_mock.return_value = False

# when
response = client.post(
Expand Down Expand Up @@ -755,6 +764,9 @@ def test_conversation_azure_byod_returns_correct_response_when_streaming_with_da
"type": "system_assigned_managed_identity",
}

@patch(
"backend.batch.utilities.search.azure_search_handler.AzureSearchHelper._index_not_exists"
)
@patch("create_app.AzureOpenAI")
@patch(
"backend.batch.utilities.helpers.config.config_helper.ConfigHelper.get_active_config_or_default"
Expand All @@ -767,6 +779,7 @@ def test_conversation_azure_byod_returns_correct_response_when_not_streaming_wit
generate_container_sas_mock: MagicMock,
get_active_config_or_default_mock,
azure_openai_mock: MagicMock,
index_not_exists_mock,
env_helper_mock: MagicMock,
client: FlaskClient,
):
Expand All @@ -777,7 +790,7 @@ def test_conversation_azure_byod_returns_correct_response_when_not_streaming_wit
"byod"
)
generate_container_sas_mock.return_value = "mock-sas"

index_not_exists_mock.return_value = False
openai_client_mock = azure_openai_mock.return_value
openai_client_mock.chat.completions.create.return_value = self.mock_response

Expand Down Expand Up @@ -878,6 +891,9 @@ def test_conversation_azure_byod_returns_500_when_internalservererror_occurs(
"administrator."
}

@patch(
"backend.batch.utilities.search.azure_search_handler.AzureSearchHelper._index_not_exists"
)
@patch("create_app.conversation_with_data")
@patch(
"backend.batch.utilities.helpers.config.config_helper.ConfigHelper.get_active_config_or_default"
Expand All @@ -886,6 +902,7 @@ def test_conversation_azure_byod_returns_429_on_rate_limit_error(
self,
get_active_config_or_default_mock,
conversation_with_data_mock,
index_not_exists_mock,
client,
):
"""Test that a 429 response is returned on RateLimitError for BYOD conversation."""
Expand All @@ -908,6 +925,7 @@ def test_conversation_azure_byod_returns_429_on_rate_limit_error(
get_active_config_or_default_mock.return_value.prompts.conversational_flow = (
"byod"
)
index_not_exists_mock.return_value = False

# when
response = client.post(
Expand All @@ -923,6 +941,9 @@ def test_conversation_azure_byod_returns_429_on_rate_limit_error(
"Please wait a moment and try again."
}

@patch(
"backend.batch.utilities.search.azure_search_handler.AzureSearchHelper._index_not_exists"
)
@patch("create_app.AzureOpenAI")
@patch(
"backend.batch.utilities.helpers.config.config_helper.ConfigHelper.get_active_config_or_default"
Expand All @@ -931,17 +952,17 @@ def test_conversation_azure_byod_returns_correct_response_when_not_streaming_wit
self,
get_active_config_or_default_mock,
azure_openai_mock,
index_not_exists_mock,
env_helper_mock,
client,
):
"""Test that the Azure BYOD conversation endpoint returns the correct response."""
# given
env_helper_mock.should_use_data.return_value = False
env_helper_mock.SHOULD_STREAM = False
get_active_config_or_default_mock.return_value.prompts.conversational_flow = (
"byod"
)

index_not_exists_mock.return_value = True
openai_client_mock = MagicMock()
azure_openai_mock.return_value = openai_client_mock

Expand Down Expand Up @@ -997,6 +1018,9 @@ def test_conversation_azure_byod_returns_correct_response_when_not_streaming_wit
stream=False,
)

@patch(
"backend.batch.utilities.search.azure_search_handler.AzureSearchHelper._index_not_exists"
)
@patch("create_app.AzureOpenAI")
@patch(
"backend.batch.utilities.helpers.config.config_helper.ConfigHelper.get_active_config_or_default"
Expand All @@ -1005,18 +1029,19 @@ def test_conversation_azure_byod_returns_correct_response_when_not_streaming_wit
self,
get_active_config_or_default_mock,
azure_openai_mock,
index_not_exists_mock,
env_helper_mock,
client,
):
"""Test that the Azure BYOD conversation endpoint returns the correct response."""
# given
env_helper_mock.should_use_data.return_value = False
env_helper_mock.SHOULD_STREAM = False
env_helper_mock.AZURE_AUTH_TYPE = "rbac"
env_helper_mock.AZURE_OPENAI_STOP_SEQUENCE = ""
get_active_config_or_default_mock.return_value.prompts.conversational_flow = (
"byod"
)
index_not_exists_mock.return_value = True

openai_client_mock = MagicMock()
azure_openai_mock.return_value = openai_client_mock
Expand Down Expand Up @@ -1073,6 +1098,9 @@ def test_conversation_azure_byod_returns_correct_response_when_not_streaming_wit
stream=False,
)

@patch(
"backend.batch.utilities.search.azure_search_handler.AzureSearchHelper._index_not_exists"
)
@patch("create_app.AzureOpenAI")
@patch(
"backend.batch.utilities.helpers.config.config_helper.ConfigHelper.get_active_config_or_default"
Expand All @@ -1081,16 +1109,16 @@ def test_conversation_azure_byod_returns_correct_response_when_streaming_without
self,
get_active_config_or_default_mock,
azure_openai_mock,
index_not_exists_mock,
env_helper_mock,
client,
):
"""Test that the Azure BYOD conversation endpoint returns the correct response."""
# given
env_helper_mock.should_use_data.return_value = False
get_active_config_or_default_mock.return_value.prompts.conversational_flow = (
"byod"
)

index_not_exists_mock.return_value = True
openai_client_mock = MagicMock()
azure_openai_mock.return_value = openai_client_mock

Expand Down Expand Up @@ -1120,6 +1148,9 @@ def test_conversation_azure_byod_returns_correct_response_when_streaming_without
== '{"id": "response.id", "model": "mock-openai-model", "created": 0, "object": "response.object", "choices": [{"messages": [{"role": "assistant", "content": "mock content"}]}]}\n'
)

@patch(
"backend.batch.utilities.search.azure_search_handler.AzureSearchHelper._index_not_exists"
)
@patch("create_app.AzureOpenAI")
@patch(
"backend.batch.utilities.helpers.config.config_helper.ConfigHelper.get_active_config_or_default"
Expand All @@ -1132,6 +1163,7 @@ def test_conversation_azure_byod_uses_semantic_config(
generate_container_sas_mock: MagicMock,
get_active_config_or_default_mock,
azure_openai_mock: MagicMock,
index_not_exists_mock,
client: FlaskClient,
):
"""Test that the Azure BYOD conversation endpoint uses the semantic configuration."""
Expand All @@ -1144,7 +1176,7 @@ def test_conversation_azure_byod_uses_semantic_config(
openai_client_mock.chat.completions.create.return_value = (
self.mock_streamed_response
)

index_not_exists_mock.return_value = False
# when
response = client.post(
"/api/conversation",
Expand Down

0 comments on commit 6f03aab

Please sign in to comment.