@@ -1,11 +1,16 @@
+import datetime
+import os
+import tempfile
 from abc import ABC, abstractmethod
+from pathlib import Path
 from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union
 
 import structlog
 from fastapi import APIRouter, Request
 from litellm import ModelResponse
 from litellm.types.llms.openai import ChatCompletionRequest
 
+from codegate.codegate_logging import setup_logging
 from codegate.db.connection import DbRecorder
 from codegate.pipeline.base import (
     PipelineContext,
@@ -19,8 +24,14 @@
 from codegate.providers.normalizer.completion import CompletionNormalizer
 from codegate.utils.utils import get_tool_name_from_messages
 
+setup_logging()
 logger = structlog.get_logger("codegate")
 
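+# Keep the dump directory after exit (delete=False, available since Python 3.12)
+# so the dumped files can be inspected later.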
+TEMPDIR = None
+if os.getenv("CODEGATE_DUMP_DIR"):
+    basedir = os.getenv("CODEGATE_DUMP_DIR")
+    TEMPDIR = tempfile.TemporaryDirectory(prefix="codegate-", dir=basedir, delete=False)
+
 StreamGenerator = Callable[[AsyncIterator[Any]], AsyncIterator[str]]
 
 
@@ -205,6 +216,26 @@ async def _cleanup_after_streaming(
         if context.sensitive:
             context.sensitive.secure_cleanup()
 
+    def _dump_request_response(self, prefix: str, data: Any) -> None:
+        """Dump request or response data to a file if CODEGATE_DUMP_DIR is set"""
+        if not TEMPDIR:
+            return
+
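+        # Timestamp to microsecond precision so each dump gets a unique filename.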
+        ts = datetime.datetime.now()
+        fname = (
+            Path(TEMPDIR.name)
+            / f"{prefix}-{self.provider_route_name}-{ts.strftime('%Y%m%dT%H%M%S%f')}.json"
+        )
+
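+        # dict/list payloads are pretty-printed as JSON; anything else is written via str().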
+        if isinstance(data, (dict, list)):
+            import json
+
+            with open(fname, "w") as f:
+                json.dump(data, f, indent=2)
+        else:
+            with open(fname, "w") as f:
+                f.write(str(data))
+
     async def complete(
         self, data: Dict, api_key: Optional[str], is_fim_request: bool
     ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
@@ -219,7 +250,11 @@ async def complete(
         - Execute the completion and translate the response back to the
           provider-specific format
         """
+        # Dump the incoming request
+        self._dump_request_response("request", data)
         normalized_request = self._input_normalizer.normalize(data)
+        # Dump the normalized request
+        self._dump_request_response("normalized-request", normalized_request)
         streaming = normalized_request.get("stream", False)
         input_pipeline_result = await self._run_input_pipeline(
             normalized_request,
@@ -237,6 +272,8 @@ async def complete(
         if is_fim_request:
             provider_request = self._fim_normalizer.denormalize(provider_request)  # type: ignore
 
+        self._dump_request_response("provider-request", provider_request)
+
         # Execute the completion and translate the response
         # This gives us either a single response or a stream of responses
         # based on the streaming flag
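
For reference, a minimal usage sketch (the paths below are illustrative
assumptions, not part of this change). TEMPDIR is created at import time, so
CODEGATE_DUMP_DIR must be set before codegate's modules are imported:

    import os
    os.environ["CODEGATE_DUMP_DIR"] = "/tmp/codegate-dumps"  # hypothetical path

    # After a request is handled, dumps land in a codegate-* subdirectory, e.g.:
    #   /tmp/codegate-dumps/codegate-<random>/request-<provider>-<timestamp>.json
    #   /tmp/codegate-dumps/codegate-<random>/normalized-request-<provider>-<timestamp>.json
    #   /tmp/codegate-dumps/codegate-<random>/provider-request-<provider>-<timestamp>.json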