diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index 1f17ba1..4df50f0 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -4,6 +4,7 @@ from databricks_ai_bridge.vector_search_retriever_tool import ( VectorSearchRetrieverToolInput, VectorSearchRetrieverToolMixin, + vector_search_retriever_tool_trace, ) from langchain_core.embeddings import Embeddings from langchain_core.tools import BaseTool @@ -54,6 +55,7 @@ def _validate_tool_inputs(self): return self + @vector_search_retriever_tool_trace def _run(self, query: str) -> str: return self._vector_store.similarity_search( query, k=self.num_results, filter=self.filters, query_type=self.query_type diff --git a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py index ca9cbdc..e59d099 100644 --- a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py @@ -1,14 +1,18 @@ +import json from typing import Any, Dict, List, Optional +import mlflow import pytest from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401 ALL_INDEX_NAMES, DELTA_SYNC_INDEX, + INPUT_TEXTS, mock_vs_client, mock_workspace_client, ) from langchain_core.embeddings import Embeddings from langchain_core.tools import BaseTool +from mlflow.entities import SpanType from databricks_langchain import ChatDatabricks, VectorSearchRetrieverTool from tests.utils.chat_models import llm, mock_client # noqa: F401 @@ -103,3 +107,17 @@ def test_vector_search_retriever_tool_description_generation(index_name: str) -> "The string used to query the index with and identify the most similar " "vectors and 
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
@pytest.mark.parametrize("tool_name", [None, "test_tool"])
def test_vs_tool_tracing(index_name: str, tool_name: Optional[str]) -> None:
    """Check that ``_run`` records exactly one RETRIEVER span with query and results.

    The span is looked up under ``tool_name`` when one is supplied and under
    ``index_name`` otherwise.
    """
    vector_search_tool = init_vector_search_tool(index_name, tool_name=tool_name)
    vector_search_tool._run("Databricks Agent Framework")

    trace = mlflow.get_last_active_trace()
    spans = trace.search_spans(name=tool_name or index_name, span_type=SpanType.RETRIEVER)
    assert len(spans) == 1

    # Serialize the trace once and read both attributes from the same span dict.
    span_attributes = trace.to_dict()["data"]["spans"][0]["attributes"]

    inputs = json.loads(span_attributes["mlflow.spanInputs"])
    assert inputs["query"] == "Databricks Agent Framework"

    outputs = json.loads(span_attributes["mlflow.spanOutputs"])
    # Asserting a list of booleans only checks that the list is non-empty, so a
    # list of all-False values would still pass. Keep the non-empty check
    # explicit and verify every returned document with all().
    assert outputs
    assert all(d["page_content"] in INPUT_TEXTS for d in outputs)
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
@pytest.mark.parametrize("columns", [None, ["id", "text"]])
@pytest.mark.parametrize("tool_name", [None, "test_tool"])
@pytest.mark.parametrize("tool_description", [None, "Test tool for vector search"])
def test_vector_search_retriever_tool_init(
    index_name: str,
    columns: Optional[List[str]],
    tool_name: Optional[str],
    tool_description: Optional[str],
) -> None:
    """Build the tool, run ``execute_calls`` once, and expect one RETRIEVER span.

    The span is searched for under ``tool_name`` when it is set, otherwise
    under ``index_name``.
    """
    # NOTE(review): the name says "init" but the final assertions are about
    # tracing; confirm it does not shadow another test with the same name in
    # this module.
    if index_name != DELTA_SYNC_INDEX:
        from openai import OpenAI

        embeddings_setup = SelfManagedEmbeddingsTest(
            "text", "text-embedding-3-small", OpenAI(api_key="your-api-key")
        )
    else:
        embeddings_setup = SelfManagedEmbeddingsTest()

    tool = init_vector_search_tool(
        index_name=index_name,
        columns=columns,
        tool_name=tool_name,
        tool_description=tool_description,
        text_column=embeddings_setup.text_column,
    )
    assert isinstance(tool, BaseModel)

    # Stand-in for the response of openai.chat.completions.create.
    completion = get_chat_completion_response(tool_name, index_name)
    tool.execute_calls(
        completion,
        embedding_model_name=embeddings_setup.embedding_model_name,
        openai_client=embeddings_setup.open_ai_client,
    )

    expected_span_name = tool_name or index_name
    last_trace = mlflow.get_last_active_trace()
    matching_spans = last_trace.search_spans(
        name=expected_span_name, span_type=SpanType.RETRIEVER
    )
    assert len(matching_spans) == 1
def vector_search_retriever_tool_trace(func):
    """
    Decorator that runs a tool method under an MLflow trace named after the tool.

    The wrapped method must be an instance method whose ``self`` exposes
    ``tool_name`` and ``index_name``; the trace name is ``self.tool_name``
    when it is truthy and ``self.index_name`` otherwise. The span type is
    ``SpanType.RETRIEVER``.
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # The name depends on the instance, so the mlflow.trace wrapper has to
        # be built at call time rather than once at decoration time.
        span_name = self.tool_name or self.index_name
        apply_trace = mlflow.trace(name=span_name, span_type=SpanType.RETRIEVER)
        return apply_trace(func)(self, *args, **kwargs)

    return wrapper