Add MLflow tracing for langchain and openai implementations of VectorSearchRetrieverTool #43
@@ -2,6 +2,7 @@
 from typing import Any, Dict, List, Optional
 from unittest.mock import MagicMock, Mock, patch

+import mlflow
 import pytest
 from databricks_ai_bridge.test_utils.vector_search import (  # noqa: F401
     ALL_INDEX_NAMES,
@@ -10,6 +11,7 @@
     mock_vs_client,
     mock_workspace_client,
 )
+from mlflow.entities import SpanType
 from openai.types.chat import (
     ChatCompletion,
     ChatCompletionMessage,
@@ -170,3 +172,42 @@ def test_open_ai_client_from_env(
         openai_client=self_managed_embeddings_test.open_ai_client,
     )
     assert response is not None
+
+
+@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
+@pytest.mark.parametrize("columns", [None, ["id", "text"]])
+@pytest.mark.parametrize("tool_name", [None, "test_tool"])
+@pytest.mark.parametrize("tool_description", [None, "Test tool for vector search"])
+def test_vector_search_retriever_tool_init(
+    index_name: str,
+    columns: Optional[List[str]],
+    tool_name: Optional[str],
+    tool_description: Optional[str],
+) -> None:
+    if index_name == DELTA_SYNC_INDEX:
+        self_managed_embeddings_test = SelfManagedEmbeddingsTest()
+    else:
+        from openai import OpenAI
+
+        self_managed_embeddings_test = SelfManagedEmbeddingsTest(
+            "text", "text-embedding-3-small", OpenAI(api_key="your-api-key")
+        )
+
+    vector_search_tool = init_vector_search_tool(
+        index_name=index_name,
+        columns=columns,
+        tool_name=tool_name,
+        tool_description=tool_description,
+        text_column=self_managed_embeddings_test.text_column,
+    )
+    assert isinstance(vector_search_tool, BaseModel)
+    # simulate call to openai.chat.completions.create
+    chat_completion_resp = get_chat_completion_response(tool_name, index_name)
+    vector_search_tool.execute_calls(
+        chat_completion_resp,
+        embedding_model_name=self_managed_embeddings_test.embedding_model_name,
+        openai_client=self_managed_embeddings_test.open_ai_client,
+    )
+    trace = mlflow.get_last_active_trace()
+    spans = trace.search_spans(name=tool_name or index_name, span_type=SpanType.RETRIEVER)
+    assert len(spans) == 1
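
The new assertions expect a single RETRIEVER span whose name falls back from `tool_name` to `index_name`. Below is a minimal, hypothetical sketch of how such a span could be emitted around the index lookup using MLflow's fluent tracing API; the helper name, the `index` object, and the search arguments are illustrative assumptions, not the PR's actual implementation.

```python
# Hypothetical sketch (not the PR's code): emit a RETRIEVER span named after the
# tool, falling back to the index name, which is what the test above searches for.
import mlflow
from mlflow.entities import SpanType


def traced_similarity_search(index, query_text, tool_name=None, index_name=None, **kwargs):
    span_name = tool_name or index_name  # same fallback the test asserts on
    with mlflow.start_span(name=span_name, span_type=SpanType.RETRIEVER) as span:
        span.set_inputs({"query_text": query_text, **kwargs})
        results = index.similarity_search(query_text=query_text, **kwargs)  # assumed index API
        span.set_outputs(results)
        return results
```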
Review comment: Could we also validate input/output here?

Reply: I'll leave this as is for now. Discussed offline, but as it stands the current span does not have the correct input/output until this PR merges. I will wait to update the unit test.

Review comment: Could we validate that the span logs the correct/expected input and output?
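
Once the upstream fix that populates the span's input/output lands, the validation the reviewers ask for could look roughly like the lines below, appended to the test after the span-count assertion. MLflow `Span` objects expose `inputs` and `outputs` properties; the exact payload shape is not guaranteed here, so these are deliberately loose checks.

```python
# Hypothetical follow-up assertions (deferred until the span carries correct I/O):
span = spans[0]
assert span.inputs is not None   # the tool call arguments should be recorded
assert span.outputs is not None  # the retrieved documents should be recorded
```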