Skip to content

Commit fc00a53

Browse files
authored
Optionally dump requests in providers (#847)
This is useful to create test data and reproduce issues.
1 parent a107f47 commit fc00a53

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

src/codegate/providers/base.py

+37
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
1+
import datetime
2+
import os
3+
import tempfile
14
from abc import ABC, abstractmethod
5+
from pathlib import Path
26
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union
37

48
import structlog
59
from fastapi import APIRouter, Request
610
from litellm import ModelResponse
711
from litellm.types.llms.openai import ChatCompletionRequest
812

13+
from codegate.codegate_logging import setup_logging
914
from codegate.db.connection import DbRecorder
1015
from codegate.pipeline.base import (
1116
PipelineContext,
@@ -19,8 +24,14 @@
1924
from codegate.providers.normalizer.completion import CompletionNormalizer
2025
from codegate.utils.utils import get_tool_name_from_messages
2126

27+
setup_logging()
2228
logger = structlog.get_logger("codegate")
2329

30+
TEMPDIR = None
31+
if os.getenv("CODEGATE_DUMP_DIR"):
32+
basedir = os.getenv("CODEGATE_DUMP_DIR")
33+
TEMPDIR = tempfile.TemporaryDirectory(prefix="codegate-", dir=basedir, delete=False)
34+
2435
StreamGenerator = Callable[[AsyncIterator[Any]], AsyncIterator[str]]
2536

2637

@@ -205,6 +216,26 @@ async def _cleanup_after_streaming(
205216
if context.sensitive:
206217
context.sensitive.secure_cleanup()
207218

219+
def _dump_request_response(self, prefix: str, data: Any) -> None:
220+
"""Dump request or response data to a file if CODEGATE_DUMP_DIR is set"""
221+
if not TEMPDIR:
222+
return
223+
224+
ts = datetime.datetime.now()
225+
fname = (
226+
Path(TEMPDIR.name)
227+
/ f"{prefix}-{self.provider_route_name}-{ts.strftime('%Y%m%dT%H%M%S%f')}.json"
228+
)
229+
230+
if isinstance(data, (dict, list)):
231+
import json
232+
233+
with open(fname, "w") as f:
234+
json.dump(data, f, indent=2)
235+
else:
236+
with open(fname, "w") as f:
237+
f.write(str(data))
238+
208239
async def complete(
209240
self, data: Dict, api_key: Optional[str], is_fim_request: bool
210241
) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
@@ -219,7 +250,11 @@ async def complete(
219250
- Execute the completion and translate the response back to the
220251
provider-specific format
221252
"""
253+
# Dump the incoming request
254+
self._dump_request_response("request", data)
222255
normalized_request = self._input_normalizer.normalize(data)
256+
# Dump the normalized request
257+
self._dump_request_response("normalized-request", normalized_request)
223258
streaming = normalized_request.get("stream", False)
224259
input_pipeline_result = await self._run_input_pipeline(
225260
normalized_request,
@@ -237,6 +272,8 @@ async def complete(
237272
if is_fim_request:
238273
provider_request = self._fim_normalizer.denormalize(provider_request) # type: ignore
239274

275+
self._dump_request_response("provider-request", provider_request)
276+
240277
# Execute the completion and translate the response
241278
# This gives us either a single response or a stream of responses
242279
# based on the streaming flag

0 commit comments

Comments (0)