|
39 | 39 | generate_tts__after,
|
40 | 40 | )
|
41 | 41 | from ..llm_call.utils import (
|
| 42 | + LLMCallException, |
42 | 43 | append_message_content_to_chat_history,
|
43 | 44 | get_chat_response,
|
44 | 45 | init_chat_history,
|
@@ -131,21 +132,32 @@ async def chat(
|
131 | 132 | QueryResponse | JSONResponse
|
132 | 133 | The query response object or an appropriate JSON response.
|
133 | 134 | """
|
| 135 | + try: |
| 136 | + # 1. |
| 137 | + user_query = await init_user_query_and_chat_histories( |
| 138 | + redis_client=request.app.state.redis, |
| 139 | + reset_chat_history=reset_chat_history, |
| 140 | + user_query=user_query, |
| 141 | + ) |
134 | 142 |
|
135 |
| - # 1. |
136 |
| - user_query = await init_user_query_and_chat_histories( |
137 |
| - redis_client=request.app.state.redis, |
138 |
| - reset_chat_history=reset_chat_history, |
139 |
| - user_query=user_query, |
140 |
| - ) |
| 143 | + # 2 |
141 | 144 |
|
142 |
| - # 2. |
143 |
| - return await search( |
144 |
| - user_query=user_query, |
145 |
| - request=request, |
146 |
| - asession=asession, |
147 |
| - workspace_db=workspace_db, |
148 |
| - ) |
| 145 | + response = await search( |
| 146 | + user_query=user_query, |
| 147 | + request=request, |
| 148 | + asession=asession, |
| 149 | + workspace_db=workspace_db, |
| 150 | + ) |
| 151 | + return response |
| 152 | + except LLMCallException: |
| 153 | + return JSONResponse( |
| 154 | + status_code=status.HTTP_502_BAD_GATEWAY, |
| 155 | + content={ |
| 156 | + "error_message": ( |
| 157 | + "LLM call returned an error: Please check LLM configuration" |
| 158 | + ) |
| 159 | + }, |
| 160 | + ) |
149 | 161 |
|
150 | 162 |
|
151 | 163 | @router.post(
|
@@ -186,63 +198,74 @@ async def search(
|
186 | 198 | QueryResponse | JSONResponse
|
187 | 199 | The query response object or an appropriate JSON response.
|
188 | 200 | """
|
| 201 | + try: |
| 202 | + workspace_id = workspace_db.workspace_id |
| 203 | + user_query_db, user_query_refined_template, response_template = ( |
| 204 | + await get_user_query_and_response( |
| 205 | + asession=asession, |
| 206 | + generate_tts=False, |
| 207 | + user_query=user_query, |
| 208 | + workspace_id=workspace_id, |
| 209 | + ) |
| 210 | + ) |
| 211 | + assert isinstance(user_query_db, QueryDB) |
189 | 212 |
|
190 |
| - workspace_id = workspace_db.workspace_id |
191 |
| - user_query_db, user_query_refined_template, response_template = ( |
192 |
| - await get_user_query_and_response( |
| 213 | + response = await get_search_response( |
193 | 214 | asession=asession,
|
194 |
| - generate_tts=False, |
195 |
| - user_query=user_query, |
| 215 | + exclude_archived=True, |
| 216 | + n_similar=int(N_TOP_CONTENT), |
| 217 | + n_to_crossencoder=int(N_TOP_CONTENT_TO_CROSSENCODER), |
| 218 | + query_refined=user_query_refined_template, |
| 219 | + request=request, |
| 220 | + response=response_template, |
196 | 221 | workspace_id=workspace_id,
|
197 | 222 | )
|
198 |
| - ) |
199 |
| - assert isinstance(user_query_db, QueryDB) |
200 | 223 |
|
201 |
| - response = await get_search_response( |
202 |
| - asession=asession, |
203 |
| - exclude_archived=True, |
204 |
| - n_similar=int(N_TOP_CONTENT), |
205 |
| - n_to_crossencoder=int(N_TOP_CONTENT_TO_CROSSENCODER), |
206 |
| - query_refined=user_query_refined_template, |
207 |
| - request=request, |
208 |
| - response=response_template, |
209 |
| - workspace_id=workspace_id, |
210 |
| - ) |
| 224 | + if user_query.generate_llm_response: |
| 225 | + response = await get_generation_response( |
| 226 | + query_refined=user_query_refined_template, response=response |
| 227 | + ) |
211 | 228 |
|
212 |
| - if user_query.generate_llm_response: |
213 |
| - response = await get_generation_response( |
214 |
| - query_refined=user_query_refined_template, response=response |
| 229 | + await save_query_response_to_db( |
| 230 | + asession=asession, |
| 231 | + response=response, |
| 232 | + user_query_db=user_query_db, |
| 233 | + workspace_id=workspace_id, |
| 234 | + ) |
| 235 | + await increment_query_count( |
| 236 | + asession=asession, |
| 237 | + contents=response.search_results, |
| 238 | + workspace_id=workspace_id, |
| 239 | + ) |
| 240 | + await save_content_for_query_to_db( |
| 241 | + asession=asession, |
| 242 | + contents=response.search_results, |
| 243 | + query_id=response.query_id, |
| 244 | + session_id=user_query.session_id, |
| 245 | + workspace_id=workspace_id, |
215 | 246 | )
|
216 | 247 |
|
217 |
| - await save_query_response_to_db( |
218 |
| - asession=asession, |
219 |
| - response=response, |
220 |
| - user_query_db=user_query_db, |
221 |
| - workspace_id=workspace_id, |
222 |
| - ) |
223 |
| - await increment_query_count( |
224 |
| - asession=asession, contents=response.search_results, workspace_id=workspace_id |
225 |
| - ) |
226 |
| - await save_content_for_query_to_db( |
227 |
| - asession=asession, |
228 |
| - contents=response.search_results, |
229 |
| - query_id=response.query_id, |
230 |
| - session_id=user_query.session_id, |
231 |
| - workspace_id=workspace_id, |
232 |
| - ) |
| 248 | + if isinstance(response, QueryResponseError): |
| 249 | + return JSONResponse( |
| 250 | + status_code=status.HTTP_400_BAD_REQUEST, content=response.model_dump() |
| 251 | + ) |
| 252 | + |
| 253 | + if isinstance(response, QueryResponse): |
| 254 | + return response |
233 | 255 |
|
234 |
| - if isinstance(response, QueryResponseError): |
235 | 256 | return JSONResponse(
|
236 |
| - status_code=status.HTTP_400_BAD_REQUEST, content=response.model_dump() |
| 257 | + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
| 258 | + content={"error_message": "Internal server error"}, |
| 259 | + ) |
| 260 | + except LLMCallException: |
| 261 | + return JSONResponse( |
| 262 | + status_code=status.HTTP_502_BAD_GATEWAY, |
| 263 | + content={ |
| 264 | + "error_message": ( |
| 265 | + "LLM call returned an error: Please check LLM configuration" |
| 266 | + ) |
| 267 | + }, |
237 | 268 | )
|
238 |
| - |
239 |
| - if isinstance(response, QueryResponse): |
240 |
| - return response |
241 |
| - |
242 |
| - return JSONResponse( |
243 |
| - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
244 |
| - content={"message": "Internal server error"}, |
245 |
| - ) |
246 | 269 |
|
247 | 270 |
|
248 | 271 | @router.post(
|
|
0 commit comments