Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MLflow tracing for langchain and openai implementations of VectorSearchRetrieverTool #43

Merged
merged 4 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from databricks_ai_bridge.vector_search_retriever_tool import (
VectorSearchRetrieverToolInput,
VectorSearchRetrieverToolMixin,
vector_search_retriever_tool_trace,
)
from langchain_core.embeddings import Embeddings
from langchain_core.tools import BaseTool
Expand Down Expand Up @@ -54,6 +55,7 @@ def _validate_tool_inputs(self):

return self

@vector_search_retriever_tool_trace
def _run(self, query: str) -> str:
return self._vector_store.similarity_search(
query, k=self.num_results, filter=self.filters, query_type=self.query_type
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import json
from typing import Any, Dict, List, Optional

import mlflow
import pytest
from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401
ALL_INDEX_NAMES,
DELTA_SYNC_INDEX,
INPUT_TEXTS,
mock_vs_client,
mock_workspace_client,
)
from langchain_core.embeddings import Embeddings
from langchain_core.tools import BaseTool
from mlflow.entities import SpanType

from databricks_langchain import ChatDatabricks, VectorSearchRetrieverTool
from tests.utils.chat_models import llm, mock_client # noqa: F401
Expand Down Expand Up @@ -103,3 +107,17 @@ def test_vector_search_retriever_tool_description_generation(index_name: str) ->
"The string used to query the index with and identify the most similar "
"vectors and return the associated documents."
)


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
@pytest.mark.parametrize("tool_name", [None, "test_tool"])
def test_vs_tool_tracing(index_name: str, tool_name: Optional[str]) -> None:
vector_search_tool = init_vector_search_tool(index_name, tool_name=tool_name)
vector_search_tool._run("Databricks Agent Framework")
trace = mlflow.get_last_active_trace()
spans = trace.search_spans(name=tool_name or index_name, span_type=SpanType.RETRIEVER)
assert len(spans) == 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we validate that the span logs the correct/expected input and output?

inputs = json.loads(trace.to_dict()["data"]["spans"][0]["attributes"]["mlflow.spanInputs"])
assert inputs["query"] == "Databricks Agent Framework"
outputs = json.loads(trace.to_dict()["data"]["spans"][0]["attributes"]["mlflow.spanOutputs"])
assert [d["page_content"] in INPUT_TEXTS for d in outputs]
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from databricks_ai_bridge.vector_search_retriever_tool import (
VectorSearchRetrieverToolInput,
VectorSearchRetrieverToolMixin,
vector_search_retriever_tool_trace,
)
from pydantic import Field, PrivateAttr, model_validator

Expand Down Expand Up @@ -84,6 +85,7 @@ def rewrite_index_name(index_name: str):
)
return self

@vector_search_retriever_tool_trace
def execute_calls(
self,
response: ChatCompletion,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict, List, Optional
from unittest.mock import MagicMock, Mock, patch

import mlflow
import pytest
from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401
ALL_INDEX_NAMES,
Expand All @@ -10,6 +11,7 @@
mock_vs_client,
mock_workspace_client,
)
from mlflow.entities import SpanType
from openai.types.chat import (
ChatCompletion,
ChatCompletionMessage,
Expand Down Expand Up @@ -170,3 +172,42 @@ def test_open_ai_client_from_env(
openai_client=self_managed_embeddings_test.open_ai_client,
)
assert response is not None


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
@pytest.mark.parametrize("columns", [None, ["id", "text"]])
@pytest.mark.parametrize("tool_name", [None, "test_tool"])
@pytest.mark.parametrize("tool_description", [None, "Test tool for vector search"])
def test_vector_search_retriever_tool_init(
index_name: str,
columns: Optional[List[str]],
tool_name: Optional[str],
tool_description: Optional[str],
) -> None:
if index_name == DELTA_SYNC_INDEX:
self_managed_embeddings_test = SelfManagedEmbeddingsTest()
else:
from openai import OpenAI

self_managed_embeddings_test = SelfManagedEmbeddingsTest(
"text", "text-embedding-3-small", OpenAI(api_key="your-api-key")
)

vector_search_tool = init_vector_search_tool(
index_name=index_name,
columns=columns,
tool_name=tool_name,
tool_description=tool_description,
text_column=self_managed_embeddings_test.text_column,
)
assert isinstance(vector_search_tool, BaseModel)
# simulate call to openai.chat.completions.create
chat_completion_resp = get_chat_completion_response(tool_name, index_name)
vector_search_tool.execute_calls(
chat_completion_resp,
embedding_model_name=self_managed_embeddings_test.embedding_model_name,
openai_client=self_managed_embeddings_test.open_ai_client,
)
trace = mlflow.get_last_active_trace()
spans = trace.search_spans(name=tool_name or index_name, span_type=SpanType.RETRIEVER)
assert len(spans) == 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we also validate input/output here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll leave this as is for now. Discussed offline, but as it stands the current span does not have the correct input/output until this PR merges. I will wait to update the unit test.

20 changes: 20 additions & 0 deletions src/databricks_ai_bridge/vector_search_retriever_tool.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,32 @@
from functools import wraps
from typing import Any, Dict, List, Optional

import mlflow
from mlflow.entities import SpanType
from pydantic import BaseModel, Field

from databricks_ai_bridge.utils.vector_search import IndexDetails

DEFAULT_TOOL_DESCRIPTION = "A vector search-based retrieval tool for querying indexed embeddings."


def vector_search_retriever_tool_trace(func):
"""
Decorator factory to trace VectorSearchRetrieverTool with the tool name
"""

@wraps(func)
def wrapper(self, *args, **kwargs):
# Create a new decorator with the instance's name
traced_func = mlflow.trace(
name=self.tool_name or self.index_name, span_type=SpanType.RETRIEVER
)(func)
# Call the traced function with self
return traced_func(self, *args, **kwargs)

return wrapper


class VectorSearchRetrieverToolInput(BaseModel):
query: str = Field(
description="The string used to query the index with and identify the most similar "
Expand Down
Loading