|
| 1 | +import json |
| 2 | +from typing import Any, Dict, List, Optional, Tuple |
| 3 | + |
| 4 | +from databricks.vector_search.client import VectorSearchIndex |
| 5 | +from databricks_ai_bridge.utils.vector_search import ( |
| 6 | + IndexDetails, |
| 7 | + parse_vector_search_response, |
| 8 | + validate_and_get_return_columns, |
| 9 | + validate_and_get_text_column, |
| 10 | +) |
| 11 | +from databricks_ai_bridge.vector_search_retriever_tool import ( |
| 12 | + VectorSearchRetrieverToolInput, |
| 13 | + VectorSearchRetrieverToolMixin, |
| 14 | +) |
| 15 | +from pydantic import Field, PrivateAttr, model_validator |
| 16 | + |
| 17 | +from openai import OpenAI, pydantic_function_tool |
| 18 | +from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall, ChatCompletionToolParam |
| 19 | + |
| 20 | + |
class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin):
    """
    A utility class to create a vector search-based retrieval tool for querying indexed embeddings.
    This class integrates with Databricks Vector Search and provides a convenient interface
    for tool calling using the OpenAI SDK.

    Example:
        dbvs_tool = VectorSearchRetrieverTool("index_name")
        tools = [dbvs_tool.tool, ...]
        response = openai.chat.completions.create(
            model="gpt-4o",
            messages=initial_messages,
            tools=tools,
        )
        retriever_call_message = dbvs_tool.execute_calls(response)

        ### If needed, execute potential remaining tool calls here ###
        remaining_tool_call_messages = execute_remaining_tool_calls(response)

        final_response = openai.chat.completions.create(
            model="gpt-4o",
            messages=initial_messages + retriever_call_message + remaining_tool_call_messages,
            tools=tools,
        )
        final_response.choices[0].message.content
    """

    text_column: Optional[str] = Field(
        None,
        description="The name of the text column to use for the embeddings. "
        "Required for direct-access index or delta-sync index with "
        "self-managed embeddings.",
    )

    tool: ChatCompletionToolParam = Field(
        None, description="The tool input used in the OpenAI chat completion SDK"
    )
    # Handle to the Databricks Vector Search index and its parsed metadata.
    # Both are populated by the model validator below, never by callers.
    _index: VectorSearchIndex = PrivateAttr()
    _index_details: IndexDetails = PrivateAttr()

    @model_validator(mode="after")
    def _validate_tool_inputs(self):
        """Resolve the index, validate the column configuration, and build the OpenAI tool spec.

        Runs automatically after pydantic field validation. Raises whatever the
        underlying validate_* helpers raise on a misconfigured index/columns.
        """
        from databricks.vector_search.client import (
            VectorSearchClient,  # import here so we can mock in tests
        )

        self._index = VectorSearchClient().get_index(index_name=self.index_name)
        self._index_details = IndexDetails(self._index)
        self.text_column = validate_and_get_text_column(self.text_column, self._index_details)
        self.columns = validate_and_get_return_columns(
            self.columns or [], self.text_column, self._index_details
        )

        # OpenAI tool names must match the pattern '^[a-zA-Z0-9_-]+$'."
        # The '.' from the index name are not allowed
        def rewrite_index_name(index_name: str) -> str:
            return index_name.replace(".", "_")

        self.tool = pydantic_function_tool(
            VectorSearchRetrieverToolInput,
            name=self.tool_name or rewrite_index_name(self.index_name),
            description=self.tool_description
            or self._get_default_tool_description(self._index_details),
        )
        return self

    def execute_calls(
        self,
        response: ChatCompletion,
        choice_index: int = 0,
        # FIX: was `str = None` / `OpenAI = None` — implicit Optional (PEP 484).
        embedding_model_name: Optional[str] = None,
        openai_client: Optional[OpenAI] = None,
    ) -> List[Dict[str, Any]]:
        """
        Execute the VectorSearchIndex tool calls from the ChatCompletions response that correspond to the
        self.tool VectorSearchRetrieverToolInput and attach the retrieved documents into tool call messages.

        Args:
            response: The chat completion response object returned by the OpenAI API.
            choice_index: The index of the choice to process. Defaults to 0. Note that multiple
                choices are not supported yet.
            embedding_model_name: The name of the embedding model to use for embedding the query text.
                Required for direct access indexes or delta-sync indexes with self-managed embeddings.
            openai_client: The OpenAI client object used to generate embeddings for retrieval queries. If not provided,
                the default OpenAI client in the current environment will be used.

        Returns:
            A list of messages containing the assistant message and the retriever call results
            that correspond to the self.tool VectorSearchRetrieverToolInput.

        Raises:
            ValueError: If `response` is not a ChatCompletion, if `embedding_model_name` is
                passed for (or missing from) the wrong kind of index, if no OpenAI API key is
                available, or if the generated embedding's dimension does not match the index.
        """

        def get_query_text_vector(
            tool_call: ChatCompletionMessageToolCall,
        ) -> Tuple[Optional[str], Optional[List[float]]]:
            # Extract the LLM-chosen query and decide whether to search by text
            # (Databricks-managed embeddings) or by a locally computed vector.
            query = json.loads(tool_call.function.arguments)["query"]
            if self._index_details.is_databricks_managed_embeddings():
                if embedding_model_name:
                    raise ValueError(
                        f"The index '{self._index_details.name}' uses Databricks-managed embeddings. "
                        "Do not pass the `embedding_model_name` parameter when executing retriever calls."
                    )
                return query, None

            # For non-Databricks-managed embeddings
            from openai import OpenAI

            oai_client = openai_client or OpenAI()
            if not oai_client.api_key:
                raise ValueError(
                    "OpenAI API key is required to generate embeddings for retrieval queries."
                )
            if not embedding_model_name:
                raise ValueError(
                    "The embedding model name is required for non-Databricks-managed "
                    "embeddings Vector Search indexes in order to generate embeddings for retrieval queries."
                )

            # HYBRID search needs both the raw text and the vector; otherwise vector only.
            text = query if self.query_type and self.query_type.upper() == "HYBRID" else None
            vector = (
                oai_client.embeddings.create(input=query, model=embedding_model_name)
                .data[0]
                .embedding
            )
            # Guard against sending a wrong-sized vector to the index.
            if (
                index_embedding_dimension := self._index_details.embedding_vector_column.get(
                    "embedding_dimension"
                )
            ) and len(vector) != index_embedding_dimension:
                raise ValueError(
                    f"Expected embedding dimension {index_embedding_dimension} but got {len(vector)}"
                )
            return text, vector

        def is_tool_call_for_index(tool_call: ChatCompletionMessageToolCall) -> bool:
            # Match on both the tool name and the exact argument set so we never
            # execute an unrelated tool that merely shares a name.
            # FIX: dropped the `Set[str]` local annotations — `Set` was never imported.
            tool_call_arguments = set(json.loads(tool_call.function.arguments).keys())
            vs_index_arguments = set(
                self.tool["function"]["parameters"]["properties"].keys()
            )
            return (
                tool_call.function.name == self.tool["function"]["name"]
                and tool_call_arguments == vs_index_arguments
            )

        # FIX: idiomatic isinstance check instead of `type(...) is not ...`
        # (backward-compatible: also accepts ChatCompletion subclasses).
        if not isinstance(response, ChatCompletion):
            raise ValueError("response must be an instance of ChatCompletion")
        message = response.choices[choice_index].message
        llm_tool_calls = message.tool_calls
        function_calls = []
        if llm_tool_calls:
            for llm_tool_call in llm_tool_calls:
                # Only process tool calls that correspond to the self.tool VectorSearchRetrieverToolInput
                if not is_tool_call_for_index(llm_tool_call):
                    continue

                query_text, query_vector = get_query_text_vector(llm_tool_call)
                search_resp = self._index.similarity_search(
                    columns=self.columns,
                    query_text=query_text,
                    query_vector=query_vector,
                    filters=self.filters,
                    num_results=self.num_results,
                    query_type=self.query_type,
                )
                docs_with_score: List[Tuple[Dict, float]] = parse_vector_search_response(
                    search_resp=search_resp,
                    index_details=self._index_details,
                    text_column=self.text_column,
                    document_class=dict,
                )

                function_call_result_message = {
                    "role": "tool",
                    "content": json.dumps({"content": docs_with_score}),
                    "tool_call_id": llm_tool_call.id,
                }
                function_calls.append(function_call_result_message)
        # The assistant message (with its tool_calls) must precede the tool
        # results in the conversation the caller sends back to the API.
        assistant_message = message.to_dict()
        return [assistant_message, *function_calls]
0 commit comments