diff --git a/gai-backend/inference_adapters.py b/gai-backend/inference_adapters.py
new file mode 100644
index 000000000..df713b7de
--- /dev/null
+++ b/gai-backend/inference_adapters.py
@@ -0,0 +1,143 @@
+from typing import Dict, Any, Tuple, Optional
+import uuid
+import logging
+from inference_models import ChatCompletionRequest, ChatCompletion, ChatChoice, Message, Usage, InferenceAPIError
+from datetime import datetime
+
+logger = logging.getLogger(__name__)
+
+class ModelAdapter:
+    @staticmethod
+    def parse_openai_response(response: Dict[str, Any], model: str, response_id: str) -> ChatCompletion:
+        choice = response['choices'][0]
+        return ChatCompletion(
+            id=response_id,
+            model=model,
+            choices=[
+                ChatChoice(
+                    index=0,
+                    message=Message(
+                        role="assistant",
+                        content=choice['message']['content'],
+                        function_call=choice['message'].get('function_call')
+                    ),
+                    finish_reason=choice.get('finish_reason')
+                )
+            ],
+            usage=Usage(
+                prompt_tokens=response['usage']['prompt_tokens'],
+                completion_tokens=response['usage']['completion_tokens'],
+                total_tokens=response['usage']['total_tokens']
+            )
+        )
+
+    @staticmethod
+    def parse_anthropic_response(response: Dict[str, Any], model: str, response_id: str) -> ChatCompletion:
+        stop_reason_map = {
+            "max_tokens": "length",
+            "stop_sequence": "stop"
+        }
+
+        # Add debug logging
+        logger.info(f"Parsing Anthropic response: {response}")
+
+        chat_completion = ChatCompletion(
+            id=response_id,
+            model=model,
+            choices=[
+                ChatChoice(
+                    index=0,
+                    message=Message(
+                        role="assistant",
+                        content=response['content'][0]['text']
+                    ),
+                    finish_reason=stop_reason_map.get(response.get('stop_reason'), "stop")
+                )
+            ],
+            usage=Usage(
+                prompt_tokens=response['usage']['input_tokens'],
+                completion_tokens=response['usage']['output_tokens'],
+                total_tokens=response['usage']['input_tokens'] + response['usage']['output_tokens']
+            )
+        )
+
+        # Log the parsed response
+        logger.info(f"Parsed Anthropic response: {chat_completion.dict()}")
+        return chat_completion
+
+    @classmethod
+    def parse_response(cls, api_type: str, response: Dict[str, Any], model: str, request_id: Optional[str] = None) -> ChatCompletion:
+        try:
+            # Generate a request ID if none was provided
+            rid = request_id or str(uuid.uuid4())
+            response_id = f"chatcmpl-{rid}"
+
+            if api_type == 'openai':
+                return cls.parse_openai_response(response, model, response_id)
+            elif api_type == 'anthropic':
+                return cls.parse_anthropic_response(response, model, response_id)
+            elif api_type == 'openrouter':
+                return cls.parse_openai_response(response, model, response_id)
+            else:
+                raise InferenceAPIError(500, f"Unsupported API type: {api_type}")
+        except KeyError as e:
+            logger.error(f"Failed to parse {api_type} response: {e}")
+            logger.debug(f"Raw response: {response}")
+            raise InferenceAPIError(502, f"Invalid {api_type} response format")
+
+    @staticmethod
+    def prepare_openai_request(model_config: Dict[str, Any], api_key: str, request: ChatCompletionRequest) -> Tuple[Dict[str, Any], Dict[str, str]]:
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {api_key}"
+        }
+
+        data = {
+            'model': model_config['id'],
+            'messages': [msg.dict(exclude_none=True) for msg in request.messages]
+        }
+
+        # Add OpenAI parameters
+        for param in ['temperature', 'top_p', 'frequency_penalty', 'presence_penalty', 'max_tokens']:
+            if (value := getattr(request, param)) is not None:
+                data[param] = value
+
+        return data, headers
+
+    @staticmethod
+    def prepare_anthropic_request(model_config: Dict[str, Any], api_key: str,
+                                  request: ChatCompletionRequest) -> Tuple[Dict[str, Any], Dict[str, str]]:
+        headers = {
+            "Content-Type": "application/json",
+            "x-api-key": api_key,
+            "anthropic-version": "2023-06-01"
+        }
+
+        system_message = next((msg.content for msg in request.messages if msg.role == "system"), None)
+        conversation = [msg for msg in request.messages if msg.role != "system"]
+
+        data = {
+            'model': model_config['id'],
+            'messages': [{'role': msg.role, 'content': msg.content} for msg in conversation],
+            'max_tokens': request.max_tokens or 4096
+        }
+
+        if system_message:
+            data['system'] = system_message
+
+        return data, headers
+
+    @classmethod
+    def prepare_request(cls, endpoint_config: Dict[str, Any], model_config: Dict[str, Any], request: ChatCompletionRequest) -> Tuple[Dict[str, Any], Dict[str, str]]:
+        api_key = endpoint_config.get('api_key')
+        if not api_key:
+            raise InferenceAPIError(500, "Backend authentication not configured")
+
+        api_type = endpoint_config['api_type']
+        if api_type == 'openai':
+            return cls.prepare_openai_request(model_config, api_key, request)
+        elif api_type == 'anthropic':
+            return cls.prepare_anthropic_request(model_config, api_key, request)
+        elif api_type == 'openrouter':
+            return cls.prepare_openai_request(model_config, api_key, request)
+        else:
+            raise InferenceAPIError(500, f"Unsupported API type: {api_type}")
diff --git a/gai-backend/inference_api.py b/gai-backend/inference_api.py
index 655479d5b..e1f4323c5 100644
--- a/gai-backend/inference_api.py
+++ b/gai-backend/inference_api.py
@@ -1,7 +1,8 @@
-from fastapi import FastAPI, HTTPException, Depends
+from contextlib import asynccontextmanager
+from fastapi import FastAPI, HTTPException, Depends, Request
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse
+from fastapi.responses import StreamingResponse, JSONResponse
 from redis.asyncio import Redis
 from typing import Dict, Any, AsyncGenerator
 import os
@@ -9,6 +10,7 @@ import requests
 import logging
 from datetime import datetime
+import uuid
 
 from config_manager import ConfigManager
 from billing import StrictRedisBilling, BillingError
@@ -22,12 +24,19 @@
     ModelInfo,
     OpenAIModel,
     OpenAIModelList,
-    InferenceAPIError,
-    PricingError
+    InferenceAPIError
 )
 
 from inference_adapters import ModelAdapter
 
-app = FastAPI()
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    # Startup
+    await api.init()
+    yield
+    # Shutdown (if needed)
+    await redis.close()
+
+app = FastAPI(lifespan=lifespan)
 security = HTTPBearer()
 logger = logging.getLogger(__name__)
@@ -51,35 +60,27 @@ async def init(self):
     async def validate_session(self, credentials: HTTPAuthorizationCredentials) -> str:
         if not credentials:
-            logger.error("No credentials provided")
             raise InferenceAPIError(401, "Missing session ID")
 
         session_id = credentials.credentials
-        logger.info(f"Validating session: {session_id}")
-
+
         try:
             balance = await self.billing.balance(session_id)
-            logger.info(f"Session {session_id} balance: {balance}")
 
             if balance is None:
-                logger.error(f"No balance found for session {session_id}")
                 raise InferenceAPIError(401, "Invalid session")
 
             min_balance = await self.billing.min_balance()
-            logger.info(f"Minimum balance required: {min_balance}")
 
             if balance < min_balance:
-                logger.warning(f"Insufficient balance: {balance} < {min_balance}")
                 raise InferenceAPIError(402, "Insufficient balance")
 
             return session_id
 
         except BillingError as e:
-            logger.error(f"Billing error during validation: {e}")
             raise InferenceAPIError(500, "Internal service error")
         except Exception as e:
-            logger.error(f"Unexpected error during validation: {e}")
             raise InferenceAPIError(500, f"Internal service error: {e}")
 
     async def get_model_config(self, model_id: str) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -94,97 +95,66 @@ async def get_model_config(self, model_id: str) -> tuple[dict[str, Any], dict[st
         raise InferenceAPIError(400, f"Unknown model: {model_id}")
 
     def get_token_prices(self, pricing_config: Dict[str, Any]) -> tuple[float, float]:
-        try:
-            pricing_type = pricing_config['type']
-
-            if pricing_type == 'fixed':
-                return (
-                    pricing_config['input_price'],
-                    pricing_config['output_price']
-                )
-
-            elif pricing_type == 'cost_plus':
-                return (
-                    pricing_config['backend_input'] + pricing_config['input_markup'],
-                    pricing_config['backend_output'] + pricing_config['output_markup']
-                )
-
-            elif pricing_type == 'multiplier':
-                return (
-                    pricing_config['backend_input'] * pricing_config['input_multiplier'],
-                    pricing_config['backend_output'] * pricing_config['output_multiplier']
-                )
-
-            else:
-                raise PricingError(f"Unknown pricing type: {pricing_type}")
-
-        except KeyError as e:
-            raise PricingError(f"Missing required pricing field: {e}")
+        pricing_type = pricing_config['type']
+
+        if pricing_type == 'fixed':
+            return (pricing_config['input_price'], pricing_config['output_price'])
+        elif pricing_type == 'cost_plus':
+            return (
+                pricing_config['backend_input'] + pricing_config['input_markup'],
+                pricing_config['backend_output'] + pricing_config['output_markup']
+            )
+        elif pricing_type == 'multiplier':
+            return (
+                pricing_config['backend_input'] * pricing_config['input_multiplier'],
+                pricing_config['backend_output'] * pricing_config['output_multiplier']
+            )
+        else:
+            raise InferenceAPIError(500, f"Unknown pricing type: {pricing_type}")
 
     def calculate_cost(self, pricing_config: Dict[str, Any], input_tokens: int, output_tokens: int) -> float:
-        try:
-            input_price, output_price = self.get_token_prices(pricing_config)
-
-            total_cost = (
-                (input_tokens * input_price) +
-                (output_tokens * output_price)
-            ) / 1_000_000  # Convert to millions of tokens
-
-            return total_cost
-
-        except Exception as e:
-            raise PricingError(f"Failed to calculate cost: {e}")
+        input_price, output_price = self.get_token_prices(pricing_config)
+        return ((input_tokens * input_price) + (output_tokens * output_price)) / 1_000_000
 
     def estimate_max_cost(self, pricing_config: Dict[str, Any], input_tokens: int, max_output_tokens: int) -> float:
         return self.calculate_cost(pricing_config, input_tokens, max_output_tokens)
 
     def count_input_tokens(self, request: ChatCompletionRequest) -> int:
-        # TODO: Implement proper tokenization based on model
-        return sum(len(msg.content or "") // 4 for msg in request.messages)  # Rough estimate
+        return sum(len(msg.content or "") // 4 for msg in request.messages)
 
     def query_backend(self, endpoint_config: Dict[str, Any], model_config: Dict[str, Any], request: ChatCompletionRequest) -> ChatCompletion:
+        try:
+            data, headers = ModelAdapter.prepare_request(endpoint_config, model_config, request)
+
+            response = requests.post(
+                endpoint_config['url'],
+                headers=headers,
+                json=data,
+                timeout=30
+            )
+
             try:
-            data, headers = ModelAdapter.prepare_request(endpoint_config, model_config, request)
-
-            logger.info(f"Sending request to backend: {endpoint_config['url']}")
-            logger.debug(f"Request headers: {headers}")
-            logger.debug(f"Request data: {data}")
-
-            response = requests.post(
-                endpoint_config['url'],
-                headers=headers,
-                json=data,
-                timeout=30
-            )
-
-            try:
-                response.raise_for_status()
-            except requests.exceptions.HTTPError as e:
-                if response.status_code == 400:
-                    if endpoint_config['api_type'] == 'openrouter':
-                        error_body = response.json()
-                        if 'error' in error_body and 'message' in error_body['error']:
-                            raise InferenceAPIError(400, error_body['error']['message'])
-                raise
-
-            result = response.json()
-            logger.info(f"Raw backend response: {result}")
-
-            completion = ModelAdapter.parse_response(
-                api_type=endpoint_config['api_type'],
-                response=result,
-                model=request.model,
-                request_id=request.request_id
-            )
-
-            # Log the final response we're sending back
-            logger.info(f"Sending completion response: {completion.dict()}")
-
-            return completion
+                response.raise_for_status()
+            except requests.exceptions.HTTPError as e:
+                if response.status_code == 400 and endpoint_config['api_type'] == 'openrouter':
+                    error_body = response.json()
+                    if 'error' in error_body and 'message' in error_body['error']:
+                        raise InferenceAPIError(400, error_body['error']['message'])
+                raise
+
+            result = response.json()
+
+            completion = ModelAdapter.parse_response(
+                api_type=endpoint_config['api_type'],
+                response=result,
+                model=request.model,
+                request_id=request.request_id
+            )
+
+            return completion
-        except requests.exceptions.RequestException as e:
-            logger.error(f"Backend request failed: {e}")
-            raise InferenceAPIError(502, "Backend service error")
+        except requests.exceptions.RequestException as e:
+            raise InferenceAPIError(502, "Backend service error")
 
     async def list_models(self) -> Dict[str, ModelInfo]:
@@ -202,70 +172,69 @@ async def list_models(self) -> Dict[str, ModelInfo]:
         return models
 
     async def list_openai_models(self) -> OpenAIModelList:
-        """OpenAI-compatible /v1/models endpoint"""
         config = await self.config_manager.load_config()
         models = []
         created = int(datetime.now().timestamp())
 
         for endpoint_id, endpoint in config['inference']['endpoints'].items():
             for model in endpoint['models']:
-                models.append({
-                    "id": model['id'],
-                    "created": model.get('created', created),
-                    "owned_by": endpoint.get('provider', 'orchid-labs')
-                })
+                models.append(OpenAIModel(
+                    id=model['id'],
+                    created=model.get('created', created),
+                    owned_by=endpoint.get('provider', 'orchid-labs')
+                ))
 
         return OpenAIModelList(data=models)
 
     async def create_stream_chunks(self, completion: ChatCompletion) -> AsyncGenerator[str, None]:
-        """Convert a completion into a stream of chunks"""
-        # Create delta chunk
-        chunk = ChatCompletionChunk(
+        first_chunk = ChatCompletionChunk(
             id=completion.id,
             model=completion.model,
             choices=[
                 ChatChoice(
                     index=0,
-                    message=Message(
-                        role="assistant",
-                        content=completion.choices[0].message.content
-                    ),
-                    finish_reason=None  # Will be included in final chunk
+                    delta={"role": "assistant"},
+                    finish_reason=None
                 )
             ]
         )
+        yield f"data: {json.dumps(first_chunk.dict(exclude_none=True))}\n\n"
 
-        # Send the chunk
-        yield f"data: {json.dumps(chunk.dict())}\n\n"
+        content = completion.choices[0].message.content
+        if content:
+            content_chunk = ChatCompletionChunk(
+                id=completion.id,
+                model=completion.model,
+                choices=[
+                    ChatChoice(
+                        index=0,
+                        delta={"content": content},
+                        finish_reason=None
+                    )
+                ]
+            )
+            yield f"data: {json.dumps(content_chunk.dict(exclude_none=True))}\n\n"
 
-        # Send final chunk with finish_reason
         final_chunk = ChatCompletionChunk(
             id=completion.id,
             model=completion.model,
             choices=[
                 ChatChoice(
                     index=0,
-                    message=Message(
-                        role="assistant",
-                        content=None  # Content is empty in final chunk
-                    ),
+                    delta={},
                     finish_reason=completion.choices[0].finish_reason
                 )
             ]
         )
-        yield f"data: {json.dumps(final_chunk.dict())}\n\n"
-
-        # Send the final [DONE] message
+        yield f"data: {json.dumps(final_chunk.dict(exclude_none=True))}\n\n"
         yield "data: [DONE]\n\n"
 
     async def stream_inference(self, request: ChatCompletionRequest, session_id: str) -> AsyncGenerator[str, None]:
-        """Handle streaming inference requests"""
         try:
             completion = await self.handle_inference(request, session_id)
             async for chunk in self.create_stream_chunks(completion):
                 yield chunk
         except Exception as e:
-            logger.error(f"Error during streaming: {e}")
             error_payload = {
                 "error": {
                     "message": str(e),
@@ -293,7 +262,6 @@ async def handle_inference(self, request: ChatCompletionRequest, session_id: str
             balance = await self.billing.balance(session_id)
 
             if balance < max_cost:
-                logger.warning(f"Insufficient balance for max cost: {balance} < {max_cost}")
                 await self.redis.publish(
                     f"billing:balance:updates:{session_id}",
                     str(balance)
@@ -301,7 +269,6 @@ async def handle_inference(self, request: ChatCompletionRequest, session_id: str
                 raise InferenceAPIError(402, "Insufficient balance")
 
             await self.billing.debit(session_id, amount=max_cost)
-            logger.info(f"Reserved {max_cost} tokens from balance")
 
             try:
                 result = self.query_backend(endpoint_config, model_config, request)
@@ -312,56 +279,49 @@ async def handle_inference(self, request: ChatCompletionRequest, session_id: str
                     result.usage.completion_tokens
                 )
 
-                logger.info(f"Actual cost: {actual_cost} (reserved: {max_cost})")
-
                 if actual_cost < max_cost:
                     refund = max_cost - actual_cost
                     await self.billing.credit(session_id, amount=refund)
-                    logger.info(f"Refunded excess reservation: {refund}")
 
                 return result
 
             except Exception as e:
-                logger.error(f"Error during inference: {e}")
                 await self.billing.credit(session_id, amount=max_cost)
                 raise
 
-        except BillingError as e:
-            logger.error(f"Billing error: {e}")
-            raise InferenceAPIError(500, "Internal service error")
-        except PricingError as e:
-            logger.error(f"Pricing error: {e}")
-            raise InferenceAPIError(500, f"Pricing configuration error: {e}")
         except Exception as e:
-            logger.error(f"Unexpected error: {e}")
+            if isinstance(e, InferenceAPIError):
+                raise
             raise InferenceAPIError(500, str(e))
 
-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
+@app.exception_handler(HTTPException)
+async def http_exception_handler(request: Request, exc: HTTPException):
+    return JSONResponse(
+        status_code=exc.status_code,
+        content={
+            "error": {
+                "message": exc.detail,
+                "type": "invalid_request_error",
+                "code": str(exc.status_code)
+            }
+        }
+    )
+
+logging.basicConfig(level=logging.INFO)
 
 redis_url = os.environ.get('REDIS_URL', 'redis://localhost:6379')
 redis = Redis.from_url(redis_url, decode_responses=True)
 api = InferenceAPI(redis)
 
-@app.on_event("startup")
-async def startup():
-    await api.init()
-
 @app.post("/v1/chat/completions")
 async def chat_completion(
     request: ChatCompletionRequest,
     credentials: HTTPAuthorizationCredentials = Depends(security)
 ):
-    logger.info(f"Received chat completion request: {request.dict()}")
-    logger.info(f"Auth: {credentials.scheme} {credentials.credentials}")
-
     try:
         session_id = await api.validate_session(credentials)
 
         if request.stream:
-            logger.info("Streaming response requested")
             return StreamingResponse(
                 api.stream_inference(request, session_id),
                 media_type="text/event-stream",
                 headers={
@@ -372,16 +332,12 @@
                     "Transfer-Encoding": "chunked",
                 }
             )
-        else:
-            logger.info("Normal response requested")
-            result = await api.handle_inference(request, session_id)
-            logger.info(f"Final API response: {result.dict()}")
-            return result
+
+        result = await api.handle_inference(request, session_id)
+        return result.dict(exclude_none=True)
     except InferenceAPIError as e:
-        logger.error(f"Inference API error: {e.status_code} - {e.detail}")
         raise HTTPException(status_code=e.status_code, detail=e.detail)
     except Exception as e:
-        logger.error(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
 
 @app.post("/v1/inference")
@@ -393,19 +349,15 @@ async def inference(
 
 @app.get("/v1/models")
 async def list_openai_models():
-    """OpenAI-compatible models list endpoint"""
     try:
         return await api.list_openai_models()
     except Exception as e:
-        logger.error(f"Failed to list OpenAI-compatible models: {e}")
         raise HTTPException(status_code=500, detail=str(e))
 
 @app.get("/v1/inference/models")
 async def list_inference_models():
-    """Detailed models list for inference"""
     try:
         models = await api.list_models()
         return models
     except Exception as e:
-        logger.error(f"Failed to list inference models: {e}")
         raise HTTPException(status_code=500, detail=str(e))
diff --git a/gai-backend/inference_models.py b/gai-backend/inference_models.py
new file mode 100644
index 000000000..0f9ed5047
--- /dev/null
+++ b/gai-backend/inference_models.py
@@ -0,0 +1,80 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Dict, Any, List, Literal, Union
+from datetime import datetime
+import uuid
+
+class Message(BaseModel):
+    role: Literal["system", "user", "assistant", "function"]
+    content: Optional[str] = None
+    name: Optional[str] = None
+    function_call: Optional[Dict[str, Any]] = None
+
+class ChatCompletionRequest(BaseModel):
+    model: str
+    messages: List[Message]
+    temperature: Optional[float] = 1.0
+    top_p: Optional[float] = 1.0
+    n: Optional[int] = 1
+    stream: Optional[bool] = False
+    stop: Optional[Union[str, List[str]]] = None
+    max_tokens: Optional[int] = None
+    presence_penalty: Optional[float] = 0
+    frequency_penalty: Optional[float] = 0
+    user: Optional[str] = None
+    request_id: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4()))
+    tools: Optional[List[Dict[str, Any]]] = None
+    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
+    functions: Optional[List[Dict[str, Any]]] = None
+    function_call: Optional[Union[str, Dict[str, Any]]] = None
+
+class ChatChoice(BaseModel):
+    index: int
+    message: Optional[Message] = None
+    delta: Optional[Dict[str, Any]] = None
+    finish_reason: Optional[str] = None
+
+class Usage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+class ChatCompletion(BaseModel):
+    id: str
+    object: str = "chat.completion"
+    created: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
+    model: str
+    choices: List[ChatChoice]
+    usage: Usage
+    system_fingerprint: Optional[str] = None
+
+class ChatCompletionChunk(BaseModel):
+    id: str
+    object: str = "chat.completion.chunk"
+    created: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
+    model: str
+    choices: List[ChatChoice]
+    system_fingerprint: Optional[str] = None
+
+class ModelInfo(BaseModel):
+    id: str
+    name: str
+    api_type: Literal["openai", "anthropic", "openrouter"]
+    endpoint: str
+
+class OpenAIModel(BaseModel):
+    id: str
+    object: str = "model"
+    created: int
+    owned_by: str
+
+class OpenAIModelList(BaseModel):
+    object: str = "list"
+    data: List[OpenAIModel]
+
+class InferenceAPIError(Exception):
+    def __init__(self, status_code: int, detail: str):
+        self.status_code = status_code
+        self.detail = detail
+
+class PricingError(Exception):
+    pass
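Usage sketch (reviewer note, not part of the patch): a minimal client call against the OpenAI-compatible endpoint wired up in inference_api.py. The host, port, model id, and session token below are placeholders; the session must already exist in Redis with a balance above the configured minimum.

    # Hypothetical smoke test -- adjust URL, model id, and session token to your deployment.
    import requests

    SESSION_ID = "example-session-id"  # placeholder: a funded billing session
    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",  # placeholder host/port
        headers={"Authorization": f"Bearer {SESSION_ID}"},
        json={
            "model": "gpt-4o",  # any model id configured under inference.endpoints
            "messages": [{"role": "user", "content": "Hello"}],
            "max_tokens": 64,
        },
        timeout=30,
    )
    resp.raise_for_status()
    print(resp.json()["choices"][0]["message"]["content"])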