-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
VectorSearchRetrieverTool Openai integration #39
Merged
Merged
Changes from 12 commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
2c816ed
Create VectorSearchRetrieverTool class for OpenAI
leonbi100 c27a665
Intermediate commit
leonbi100 db69782
Initial implementation
leonbi100 005a533
Working e2e delta sync index happy case
leonbi100 ff18663
Add unit tests and some validations
leonbi100 d27c68f
Undo line
leonbi100 f248245
Merge master
leonbi100 5cc9b46
Remove extra changes
leonbi100 88ae604
Fix embedding
leonbi100 bb9f109
Remove double field
leonbi100 9041017
Lint
leonbi100 e954e9e
Minor cleanup
leonbi100 e94cc9e
PR feedback
leonbi100 1fcc397
Lint
leonbi100 2b3b9d6
Rename tool call
leonbi100 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
# test | ||
.pytest_cache/ | ||
mlruns/ | ||
|
||
# Byte-compiled files | ||
__pycache__ | ||
|
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
[project] | ||
name = "databricks-openai" | ||
version = "0.1.0" | ||
description = "Support for Databricks AI support with OpenAI" | ||
authors = [ | ||
{ name="Leon Bi", email="[email protected]" }, | ||
] | ||
readme = "README.md" | ||
license = { text="Apache-2.0" } | ||
requires-python = ">=3.9" | ||
dependencies = [ | ||
"databricks-vectorsearch>=0.40", | ||
"databricks-ai-bridge>=0.1.0", | ||
"openai>=1.46.1", | ||
] | ||
|
||
[project.optional-dependencies] | ||
dev = [ | ||
"pytest", | ||
"typing_extensions", | ||
"databricks-sdk>=0.34.0", | ||
"ruff==0.6.4", | ||
] | ||
|
||
integration = [ | ||
"pytest-timeout>=2.3.1", | ||
] | ||
|
||
[build-system] | ||
requires = ["hatchling"] | ||
build-backend = "hatchling.build" | ||
|
||
[tool.hatch.build] | ||
include = [ | ||
"src/databricks_openai/*" | ||
] | ||
|
||
[tool.hatch.build.targets.wheel] | ||
packages = ["src/databricks_openai"] | ||
|
||
[tool.ruff] | ||
line-length = 100 | ||
target-version = "py39" | ||
|
||
[tool.ruff.lint] | ||
select = [ | ||
# isort | ||
"I", | ||
# bugbear rules | ||
"B", | ||
# remove unused imports | ||
"F401", | ||
# bare except statements | ||
"E722", | ||
# print statements | ||
"T201", | ||
"T203", | ||
# misuse of typing.TYPE_CHECKING | ||
"TCH004", | ||
# import rules | ||
"TID251", | ||
# undefined-local-with-import-star | ||
"F403", | ||
] | ||
|
||
[tool.ruff.format] | ||
docstring-code-format = true | ||
docstring-code-line-length = 88 | ||
|
||
[tool.ruff.lint.pydocstyle] | ||
convention = "google" |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from databricks_openai.vector_search_retriever_tool import VectorSearchRetrieverTool | ||
|
||
# Expose all integrations to users under databricks-langchain | ||
__all__ = [ | ||
"VectorSearchRetrieverTool", | ||
] |
199 changes: 199 additions & 0 deletions
199
integrations/openai/src/databricks_openai/vector_search_retriever_tool.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,199 @@ | ||||||
import json | ||||||
from typing import Any, Dict, List, Optional, Tuple | ||||||
|
||||||
from databricks.vector_search.client import VectorSearchIndex | ||||||
from databricks_ai_bridge.utils.vector_search import ( | ||||||
IndexDetails, | ||||||
parse_vector_search_response, | ||||||
validate_and_get_return_columns, | ||||||
validate_and_get_text_column, | ||||||
) | ||||||
from databricks_ai_bridge.vector_search_retriever_tool import ( | ||||||
VectorSearchRetrieverToolInput, | ||||||
VectorSearchRetrieverToolMixin, | ||||||
) | ||||||
from pydantic import Field, PrivateAttr, model_validator | ||||||
|
||||||
from openai import OpenAI, pydantic_function_tool | ||||||
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall, ChatCompletionToolParam | ||||||
|
||||||
|
||||||
class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): | ||||||
""" | ||||||
A utility class to create a vector search-based retrieval tool for querying indexed embeddings. | ||||||
This class integrates with Databricks Vector Search and provides a convenient interface | ||||||
for tool calling using the OpenAI SDK. | ||||||
|
||||||
Example: | ||||||
dbvs_tool = VectorSearchRetrieverTool("index_name") | ||||||
tools = [dbvs_tool.tool, ...] | ||||||
response = openai.chat.completions.create( | ||||||
model="gpt-4o", | ||||||
messages=initial_messages, | ||||||
tools=tools, | ||||||
) | ||||||
retriever_call_message = dbvs_tool.execute_retriever_calls(response) | ||||||
|
||||||
### If needed, execute potential remaining tool calls here ### | ||||||
remaining_tool_call_messages = execute_remaining_tool_calls(response) | ||||||
|
||||||
final_response = openai.chat.completions.create( | ||||||
model="gpt-4o", | ||||||
messages=initial_messages + retriever_call_message + remaining_tool_call_messages, | ||||||
tools=tools, | ||||||
) | ||||||
final_response.choices[0].message.content | ||||||
""" | ||||||
|
||||||
text_column: Optional[str] = Field( | ||||||
None, | ||||||
description="The name of the text column to use for the embeddings. " | ||||||
"Required for direct-access index or delta-sync index with " | ||||||
"self-managed embeddings. Used for direct access indexes or " | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: do we need the description to be duplicated (required for ... used for ...)? |
||||||
"delta-sync indexes with self-managed embeddings", | ||||||
) | ||||||
|
||||||
tool: ChatCompletionToolParam = Field( | ||||||
None, description="The tool input used in the OpenAI chat completion SDK" | ||||||
) | ||||||
_index: VectorSearchIndex = PrivateAttr() | ||||||
_index_details: IndexDetails = PrivateAttr() | ||||||
|
||||||
@model_validator(mode="after") | ||||||
def _validate_tool_inputs(self): | ||||||
from databricks.vector_search.client import ( | ||||||
VectorSearchClient, # import here so we can mock in tests | ||||||
) | ||||||
|
||||||
self._index = VectorSearchClient().get_index(index_name=self.index_name) | ||||||
self._index_details = IndexDetails(self._index) | ||||||
self.text_column = validate_and_get_text_column(self.text_column, self._index_details) | ||||||
self.columns = validate_and_get_return_columns( | ||||||
self.columns or [], self.text_column, self._index_details | ||||||
) | ||||||
|
||||||
# OpenAI tool names must match the pattern '^[a-zA-Z0-9_-]+$'." | ||||||
# The '.' from the index name are not allowed | ||||||
def rewrite_index_name(index_name: str): | ||||||
return index_name.replace(".", "_") | ||||||
|
||||||
self.tool = pydantic_function_tool( | ||||||
VectorSearchRetrieverToolInput, | ||||||
name=self.tool_name or rewrite_index_name(self.index_name), | ||||||
description=self.tool_description | ||||||
or self._get_default_tool_description(self._index_details), | ||||||
) | ||||||
return self | ||||||
|
||||||
def execute_retriever_calls( | ||||||
self, | ||||||
response: ChatCompletion, | ||||||
choice_index: int = 0, | ||||||
embedding_model_name: str = None, | ||||||
openai_client: OpenAI = None, | ||||||
) -> List[Dict[str, Any]]: | ||||||
""" | ||||||
Execute the VectorSearchIndex tool calls from the ChatCompletions response that correspond to the | ||||||
self.tool VectorSearchRetrieverToolInput and attach the retrieved documents into toll call messages. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: typo
Suggested change
|
||||||
|
||||||
Args: | ||||||
response: The chat completion response object returned by the OpenAI API. | ||||||
choice_index: The index of the choice to process. Defaults to 0. Note that multiple | ||||||
choices are not supported yet. | ||||||
embedding_model_name: The name of the embedding model to use for embedding the query text. | ||||||
Required for direct access indexes or delta-sync indexes with self-managed embeddings. | ||||||
openai_client: The OpenAI client object used to generate embeddings for retrieval queries. If not provided, | ||||||
the default OpenAI client in the current environment will be used. | ||||||
|
||||||
Returns: | ||||||
A list of messages containing the assistant message and the retriever call results | ||||||
that correspond to the self.tool VectorSearchRetrieverToolInput. | ||||||
""" | ||||||
|
||||||
def get_query_text_vector( | ||||||
tool_call: ChatCompletionMessageToolCall, | ||||||
) -> Tuple[Optional[str], Optional[List[float]]]: | ||||||
query = json.loads(tool_call.function.arguments)["query"] | ||||||
if self._index_details.is_databricks_managed_embeddings(): | ||||||
if embedding_model_name: | ||||||
raise ValueError( | ||||||
f"The index '{self._index_details.name}' uses Databricks-managed embeddings. " | ||||||
"Do not pass the `embedding_model_name` parameter when executing retriever calls." | ||||||
) | ||||||
return query, None | ||||||
|
||||||
# For non-Databricks-managed embeddings | ||||||
from openai import OpenAI | ||||||
|
||||||
oai_client = openai_client or OpenAI() | ||||||
if not oai_client.api_key: | ||||||
raise ValueError( | ||||||
"OpenAI API key is required to generate embeddings for retrieval queries." | ||||||
) | ||||||
if not embedding_model_name: | ||||||
raise ValueError( | ||||||
"The embedding model name is required for non-Databricks-managed " | ||||||
"embeddings Vector Search indexes in order to generate embeddings for retrieval queries." | ||||||
) | ||||||
|
||||||
text = query if self.query_type and self.query_type.upper() == "HYBRID" else None | ||||||
vector = ( | ||||||
oai_client.embeddings.create(input=query, model=embedding_model_name) | ||||||
.data[0] | ||||||
.embedding | ||||||
) | ||||||
if ( | ||||||
index_embedding_dimension := self._index_details.embedding_vector_column.get( | ||||||
"embedding_dimension" | ||||||
) | ||||||
) and len(vector) != index_embedding_dimension: | ||||||
raise ValueError( | ||||||
f"Expected embedding dimension {index_embedding_dimension} but got {len(vector)}" | ||||||
) | ||||||
return text, vector | ||||||
|
||||||
def is_tool_call_for_index(tool_call: ChatCompletionMessageToolCall) -> bool: | ||||||
tool_call_arguments: Set[str] = set(json.loads(tool_call.function.arguments).keys()) | ||||||
vs_index_arguments: Set[str] = set( | ||||||
self.tool["function"]["parameters"]["properties"].keys() | ||||||
) | ||||||
return ( | ||||||
tool_call.function.name == self.tool["function"]["name"] | ||||||
and tool_call_arguments == vs_index_arguments | ||||||
) | ||||||
|
||||||
if type(response) is not ChatCompletion: | ||||||
raise ValueError("response must be an instance of ChatCompletion") | ||||||
message = response.choices[choice_index].message | ||||||
llm_tool_calls = message.tool_calls | ||||||
function_calls = [] | ||||||
if llm_tool_calls: | ||||||
for llm_tool_call in llm_tool_calls: | ||||||
# Only process tool calls that correspond to the self.tool VectorSearchRetrieverToolInput | ||||||
if not is_tool_call_for_index(llm_tool_call): | ||||||
continue | ||||||
|
||||||
query_text, query_vector = get_query_text_vector(llm_tool_call) | ||||||
search_resp = self._index.similarity_search( | ||||||
columns=self.columns, | ||||||
query_text=query_text, | ||||||
query_vector=query_vector, | ||||||
filters=self.filters, | ||||||
num_results=self.num_results, | ||||||
query_type=self.query_type, | ||||||
) | ||||||
docs_with_score: List[Tuple[Dict, float]] = parse_vector_search_response( | ||||||
search_resp=search_resp, | ||||||
index_details=self._index_details, | ||||||
text_column=self.text_column, | ||||||
document_class=dict, | ||||||
) | ||||||
|
||||||
function_call_result_message = { | ||||||
"role": "tool", | ||||||
"content": json.dumps({"content": docs_with_score}), | ||||||
"tool_call_id": llm_tool_call.id, | ||||||
} | ||||||
function_calls.append(function_call_result_message) | ||||||
assistant_message = message.to_dict() | ||||||
return [assistant_message, *function_calls] |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm assuming this is supposed to be databricks-openai?