Skip to content

Commit d19fe9e

Browse files
authored
VectorSearchRetrieverTool Openai integration (#39)
* Create VectorSearchRetrieverTool class for OpenAI * Intermediate commit * Initial implementation * Working e2e delta sync index happy case * Add unit tests and some validations * Undo line * Remove extra changes * Fix embedding * Remove double field * Lint * Minor cleanup * PR feedback * Lint * Rename tool call
1 parent a8880d1 commit d19fe9e

File tree

8 files changed

+448
-0
lines changed

8 files changed

+448
-0
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# test
22
.pytest_cache/
3+
mlruns/
34

45
# Byte-compiled files
56
__pycache__

integrations/openai/README.md

Whitespace-only changes.

integrations/openai/__init__.py

Whitespace-only changes.

integrations/openai/pyproject.toml

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
[project]
2+
name = "databricks-openai"
3+
version = "0.1.0"
4+
description = "OpenAI integration support for Databricks AI"
5+
authors = [
6+
{ name="Leon Bi", email="[email protected]" },
7+
]
8+
readme = "README.md"
9+
license = { text="Apache-2.0" }
10+
requires-python = ">=3.9"
11+
dependencies = [
12+
"databricks-vectorsearch>=0.40",
13+
"databricks-ai-bridge>=0.1.0",
14+
"openai>=1.46.1",
15+
]
16+
17+
[project.optional-dependencies]
18+
dev = [
19+
"pytest",
20+
"typing_extensions",
21+
"databricks-sdk>=0.34.0",
22+
"ruff==0.6.4",
23+
]
24+
25+
integration = [
26+
"pytest-timeout>=2.3.1",
27+
]
28+
29+
[build-system]
30+
requires = ["hatchling"]
31+
build-backend = "hatchling.build"
32+
33+
[tool.hatch.build]
34+
include = [
35+
"src/databricks_openai/*"
36+
]
37+
38+
[tool.hatch.build.targets.wheel]
39+
packages = ["src/databricks_openai"]
40+
41+
[tool.ruff]
42+
line-length = 100
43+
target-version = "py39"
44+
45+
[tool.ruff.lint]
46+
select = [
47+
# isort
48+
"I",
49+
# bugbear rules
50+
"B",
51+
# remove unused imports
52+
"F401",
53+
# bare except statements
54+
"E722",
55+
# print statements
56+
"T201",
57+
"T203",
58+
# misuse of typing.TYPE_CHECKING
59+
"TCH004",
60+
# import rules
61+
"TID251",
62+
# undefined-local-with-import-star
63+
"F403",
64+
]
65+
66+
[tool.ruff.format]
67+
docstring-code-format = true
68+
docstring-code-line-length = 88
69+
70+
[tool.ruff.lint.pydocstyle]
71+
convention = "google"

integrations/openai/src/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"""Public entry point for the ``databricks-openai`` integration package."""

from databricks_openai.vector_search_retriever_tool import VectorSearchRetrieverTool

# Expose all integrations to users under databricks-openai
__all__ = [
    "VectorSearchRetrieverTool",
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
import json
from typing import Any, Dict, List, Optional, Set, Tuple

from databricks.vector_search.client import VectorSearchIndex
from databricks_ai_bridge.utils.vector_search import (
    IndexDetails,
    parse_vector_search_response,
    validate_and_get_return_columns,
    validate_and_get_text_column,
)
from databricks_ai_bridge.vector_search_retriever_tool import (
    VectorSearchRetrieverToolInput,
    VectorSearchRetrieverToolMixin,
)
from openai import OpenAI, pydantic_function_tool
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall, ChatCompletionToolParam
from pydantic import Field, PrivateAttr, model_validator
19+
20+
21+
class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin):
    """
    A utility class to create a vector search-based retrieval tool for querying indexed embeddings.
    This class integrates with Databricks Vector Search and provides a convenient interface
    for tool calling using the OpenAI SDK.

    Example:
        dbvs_tool = VectorSearchRetrieverTool("index_name")
        tools = [dbvs_tool.tool, ...]
        response = openai.chat.completions.create(
            model="gpt-4o",
            messages=initial_messages,
            tools=tools,
        )
        retriever_call_message = dbvs_tool.execute_calls(response)

        ### If needed, execute potential remaining tool calls here ###
        remaining_tool_call_messages = execute_remaining_tool_calls(response)

        final_response = openai.chat.completions.create(
            model="gpt-4o",
            messages=initial_messages + retriever_call_message + remaining_tool_call_messages,
            tools=tools,
        )
        final_response.choices[0].message.content
    """

    # Only needed when the index does not manage its own embeddings
    # (direct-access index or delta-sync index with self-managed embeddings).
    text_column: Optional[str] = Field(
        None,
        description="The name of the text column to use for the embeddings. "
        "Required for direct-access index or delta-sync index with "
        "self-managed embeddings.",
    )

    # Populated by _validate_tool_inputs below; the None default is always overwritten.
    tool: ChatCompletionToolParam = Field(
        None, description="The tool input used in the OpenAI chat completion SDK"
    )
    _index: VectorSearchIndex = PrivateAttr()
    _index_details: IndexDetails = PrivateAttr()

    @model_validator(mode="after")
    def _validate_tool_inputs(self):
        """Resolve the index, validate column configuration, and build the OpenAI tool spec."""
        from databricks.vector_search.client import (
            VectorSearchClient,  # import here so we can mock in tests
        )

        self._index = VectorSearchClient().get_index(index_name=self.index_name)
        self._index_details = IndexDetails(self._index)
        self.text_column = validate_and_get_text_column(self.text_column, self._index_details)
        self.columns = validate_and_get_return_columns(
            self.columns or [], self.text_column, self._index_details
        )

        # OpenAI tool names must match the pattern '^[a-zA-Z0-9_-]+$'."
        # The '.' from the index name are not allowed
        def rewrite_index_name(index_name: str):
            return index_name.replace(".", "_")

        self.tool = pydantic_function_tool(
            VectorSearchRetrieverToolInput,
            name=self.tool_name or rewrite_index_name(self.index_name),
            description=self.tool_description
            or self._get_default_tool_description(self._index_details),
        )
        return self

    def execute_calls(
        self,
        response: ChatCompletion,
        choice_index: int = 0,
        embedding_model_name: Optional[str] = None,
        openai_client: Optional[OpenAI] = None,
    ) -> List[Dict[str, Any]]:
        """
        Execute the VectorSearchIndex tool calls from the ChatCompletions response that correspond to the
        self.tool VectorSearchRetrieverToolInput and attach the retrieved documents into tool call messages.

        Args:
            response: The chat completion response object returned by the OpenAI API.
            choice_index: The index of the choice to process. Defaults to 0. Note that multiple
                choices are not supported yet.
            embedding_model_name: The name of the embedding model to use for embedding the query text.
                Required for direct access indexes or delta-sync indexes with self-managed embeddings.
            openai_client: The OpenAI client object used to generate embeddings for retrieval queries. If not provided,
                the default OpenAI client in the current environment will be used.

        Returns:
            A list of messages containing the assistant message and the retriever call results
            that correspond to the self.tool VectorSearchRetrieverToolInput.

        Raises:
            ValueError: If ``response`` is not a ChatCompletion; if ``embedding_model_name`` is
                passed for a Databricks-managed-embeddings index, or missing (along with a
                usable OpenAI API key) for a self-managed one; or if the generated embedding's
                dimension disagrees with the index schema.
        """

        def get_query_text_vector(
            tool_call: ChatCompletionMessageToolCall,
        ) -> Tuple[Optional[str], Optional[List[float]]]:
            # Returns (query_text, query_vector). Databricks-managed embeddings only need
            # the text; self-managed indexes need a client-side embedding vector, plus the
            # raw text when the query type is HYBRID.
            query = json.loads(tool_call.function.arguments)["query"]
            if self._index_details.is_databricks_managed_embeddings():
                if embedding_model_name:
                    raise ValueError(
                        f"The index '{self._index_details.name}' uses Databricks-managed embeddings. "
                        "Do not pass the `embedding_model_name` parameter when executing retriever calls."
                    )
                return query, None

            # For non-Databricks-managed embeddings: embed the query ourselves via OpenAI.
            oai_client = openai_client or OpenAI()
            if not oai_client.api_key:
                raise ValueError(
                    "OpenAI API key is required to generate embeddings for retrieval queries."
                )
            if not embedding_model_name:
                raise ValueError(
                    "The embedding model name is required for non-Databricks-managed "
                    "embeddings Vector Search indexes in order to generate embeddings for retrieval queries."
                )

            text = query if self.query_type and self.query_type.upper() == "HYBRID" else None
            vector = (
                oai_client.embeddings.create(input=query, model=embedding_model_name)
                .data[0]
                .embedding
            )
            # Guard against an embedding model whose output dimension does not match the index.
            if (
                index_embedding_dimension := self._index_details.embedding_vector_column.get(
                    "embedding_dimension"
                )
            ) and len(vector) != index_embedding_dimension:
                raise ValueError(
                    f"Expected embedding dimension {index_embedding_dimension} but got {len(vector)}"
                )
            return text, vector

        def is_tool_call_for_index(tool_call: ChatCompletionMessageToolCall) -> bool:
            # Match on both the tool name and the exact argument set so we never execute
            # an unrelated tool call that happens to share a name.
            tool_call_arguments: Set[str] = set(json.loads(tool_call.function.arguments).keys())
            vs_index_arguments: Set[str] = set(
                self.tool["function"]["parameters"]["properties"].keys()
            )
            return (
                tool_call.function.name == self.tool["function"]["name"]
                and tool_call_arguments == vs_index_arguments
            )

        # Fix: isinstance instead of an exact `type(...) is` comparison so ChatCompletion
        # subclasses (e.g. SDK wrapper types) are accepted as well.
        if not isinstance(response, ChatCompletion):
            raise ValueError("response must be an instance of ChatCompletion")
        message = response.choices[choice_index].message
        llm_tool_calls = message.tool_calls
        function_calls = []
        if llm_tool_calls:
            for llm_tool_call in llm_tool_calls:
                # Only process tool calls that correspond to the self.tool VectorSearchRetrieverToolInput
                if not is_tool_call_for_index(llm_tool_call):
                    continue

                query_text, query_vector = get_query_text_vector(llm_tool_call)
                search_resp = self._index.similarity_search(
                    columns=self.columns,
                    query_text=query_text,
                    query_vector=query_vector,
                    filters=self.filters,
                    num_results=self.num_results,
                    query_type=self.query_type,
                )
                docs_with_score: List[Tuple[Dict, float]] = parse_vector_search_response(
                    search_resp=search_resp,
                    index_details=self._index_details,
                    text_column=self.text_column,
                    document_class=dict,
                )

                # One OpenAI "tool" role message per retriever call, keyed by the call id.
                function_call_result_message = {
                    "role": "tool",
                    "content": json.dumps({"content": docs_with_score}),
                    "tool_call_id": llm_tool_call.id,
                }
                function_calls.append(function_call_result_message)
        assistant_message = message.to_dict()
        return [assistant_message, *function_calls]

0 commit comments

Comments
 (0)