else:
    palm = PaLM(api_key=API_KEY, api_endpoint=PALM_API_ENDPOINT)

-embedding_function_gemini_retrieval = embedding_functions.GoogleGenerativeAiEmbeddingFunction(
-    api_key=API_KEY, model_name="models/embedding-001",
-    task_type="RETRIEVAL_QUERY")
+embedding_function_gemini_retrieval = (
+    embedding_functions.GoogleGenerativeAiEmbeddingFunction(
+        api_key=API_KEY, model_name="models/embedding-001", task_type="RETRIEVAL_QUERY"
+    )
+)
+

class DocsAgent:
    """DocsAgent class"""
@@ -107,8 +110,9 @@ def __init__(self):
        )
        self.chroma = Chroma(LOCAL_VECTOR_DB_DIR)
        self.collection = self.chroma.get_collection(
-            COLLECTION_NAME, embedding_model=EMBEDDING_MODEL,
-            embedding_function=embedding_function_gemini_retrieval
+            COLLECTION_NAME,
+            embedding_model=EMBEDDING_MODEL,
+            embedding_function=embedding_function_gemini_retrieval,
        )
        # Update PaLM's custom prompt strings
        self.prompt_condition = CONDITION_TEXT
@@ -121,12 +125,21 @@ def __init__(self):
        self.is_aqa_used = IS_AQA_USED
        self.db_type = DB_TYPE
        # AQA model setup
-        self.generative_service_client = glm.GenerativeServiceClient()
-        self.retriever_service_client = glm.RetrieverServiceClient()
-        self.permission_service_client = glm.PermissionServiceClient()
+        self.generative_service_client = {}
+        self.retriever_service_client = {}
+        self.permission_service_client = {}
        self.corpus_display = PRODUCT_NAME + " documentation"
        self.corpus_name = "corpora/" + PRODUCT_NAME.lower().replace(" ", "-")
        self.aqa_response_buffer = ""
+        self.set_up_aqa_model_environment()
+
+    # Set up the AQA model environment
+    def set_up_aqa_model_environment(self):
+        if IS_AQA_USED == "YES":
+            self.generative_service_client = glm.GenerativeServiceClient()
+            self.retriever_service_client = glm.RetrieverServiceClient()
+            self.permission_service_client = glm.PermissionServiceClient()
+        return

    # Use this method for talking to a PaLM text model
    def ask_text_model_with_context(self, context, question):
@@ -203,7 +216,11 @@ def ask_aqa_model_using_local_vector_store(self, question):
        elif LOG_LEVEL == "DEBUG":
            self.print_the_prompt(verbose_prompt)
            print(aqa_response)
-        return aqa_response.answer.content.parts[0].text
+        try:
+            return aqa_response.answer.content.parts[0].text
+        except:
+            self.aqa_response_buffer = ""
+            return self.model_error_message

    # Use this method for talking to Gemini's AQA model using a corpus
    def ask_aqa_model_using_corpora(self, question):
@@ -243,7 +260,11 @@ def ask_aqa_model_using_corpora(self, question):
            self.print_the_prompt(verbose_prompt)
        elif LOG_LEVEL == "DEBUG":
            print(aqa_response)
-        return aqa_response.answer.content.parts[0].text
+        try:
+            return aqa_response.answer.content.parts[0].text
+        except:
+            self.aqa_response_buffer = ""
+            return self.model_error_message

    def ask_aqa_model(self, question):
        response = ""
@@ -323,6 +344,27 @@ def check_if_aqa_is_used(self):
    def get_saved_aqa_response_json(self):
        return self.aqa_response_buffer

+    # Retrieve the URL metadata from the AQA model's response
+    def get_aqa_response_url(self):
+        url = ""
+        try:
+            # Get the metadata from the first attributed passage for the source
+            chunk_resource_name = (
+                self.aqa_response_buffer.answer.grounding_attributions[
+                    0
+                ].source_id.semantic_retriever_chunk.chunk
+            )
+            get_chunk_response = self.retriever_service_client.get_chunk(
+                name=chunk_resource_name
+            )
+            metadata = get_chunk_response.custom_metadata
+            for m in metadata:
+                if m.key == "url":
+                    url = m.string_value
+        except:
+            url = "URL unknown"
+        return url
+
    # Print the prompt on the terminal for debugging
    def print_the_prompt(self, prompt):
        print("#########################################")