From 563ef9ed740512482ef922c9fb998097b82a9f5d Mon Sep 17 00:00:00 2001 From: Dan Montgomery Date: Fri, 8 Nov 2024 15:36:27 -0500 Subject: [PATCH] Split inference into separate http endpoint. --- gai-backend/account.py | 141 +++++++++++ gai-backend/billing.py | 139 ++++++++--- gai-backend/config_manager.py | 140 +++++++++++ gai-backend/inference_api.py | 415 +++++++++++++++++++++++++++++++++ gai-backend/jobs.py | 100 -------- gai-backend/lottery.py | 243 ++++++++++++------- gai-backend/lottery0.abi | 1 + gai-backend/lottery1.abi | 1 + gai-backend/payment_handler.py | 74 ++++++ gai-backend/server.py | 363 +++++++++++++++++----------- gai-backend/test.py | 307 ++++++++++++++++++++++++ gai-backend/ticket.py | 227 ++++++++---------- 12 files changed, 1680 insertions(+), 471 deletions(-) create mode 100644 gai-backend/account.py create mode 100644 gai-backend/config_manager.py create mode 100644 gai-backend/inference_api.py delete mode 100644 gai-backend/jobs.py create mode 100644 gai-backend/lottery0.abi create mode 100644 gai-backend/lottery1.abi create mode 100644 gai-backend/payment_handler.py create mode 100644 gai-backend/test.py diff --git a/gai-backend/account.py b/gai-backend/account.py new file mode 100644 index 000000000..3d46876f2 --- /dev/null +++ b/gai-backend/account.py @@ -0,0 +1,141 @@ +from web3 import Web3 +from decimal import Decimal +import secrets +from eth_account.messages import encode_defunct +from lottery import Lottery +from typing import Optional, Dict, Tuple + +class OrchidAccountError(Exception): + """Base class for Orchid account errors""" + pass + +class InvalidAddressError(OrchidAccountError): + """Invalid Ethereum address""" + pass + +class InvalidAmountError(OrchidAccountError): + """Invalid payment amount""" + pass + +class SigningError(OrchidAccountError): + """Error signing transaction or message""" + pass + +class OrchidAccount: + def __init__(self, + lottery: Lottery, + funder_address: str, + private_key: str): + try: + self.lottery = lottery + self.web3 = lottery.web3 + self.funder = self.web3.to_checksum_address(funder_address) + self.key = private_key + self.signer = self.web3.eth.account.from_key(private_key).address + except ValueError as e: + raise InvalidAddressError(f"Invalid address format: {e}") + except Exception as e: + raise OrchidAccountError(f"Failed to initialize account: {e}") + + def create_ticket(self, + amount: int, + recipient: str, + commitment: str, + token_addr: str = "0x0000000000000000000000000000000000000000" + ) -> str: + """ + Create signed nanopayment ticket + + Args: + amount: Payment amount in wei + recipient: Recipient address + commitment: Random commitment hash + token_addr: Token contract address + + Returns: + Serialized ticket string + """ + try: + if amount <= 0: + raise InvalidAmountError("Amount must be positive") + + recipient = self.web3.to_checksum_address(recipient) + token_addr = self.web3.to_checksum_address(token_addr) + + # Random nonce + nonce = secrets.randbits(128) + + # Pack ticket data + packed0 = amount | (nonce << 128) + ratio = 0xffffffffffffffff # Always create winning tickets for testing + packed1 = (ratio << 161) | (0 << 160) # v=0 + + # Sign ticket + message_hash = self._get_ticket_hash( + token_addr, + recipient, + commitment, + packed0, + packed1 + ) + + sig = self.web3.eth.account.sign_message( + encode_defunct(message_hash), + private_key=self.key + ) + + # Adjust v and update packed1 + v = sig.v - 27 + packed1 = packed1 | v + + # Format as hex strings + return ( + hex(packed0)[2:].zfill(64) + + hex(packed1)[2:].zfill(64) + + hex(sig.r)[2:].zfill(64) + + hex(sig.s)[2:].zfill(64) + ) + + except OrchidAccountError: + raise + except Exception as e: + raise SigningError(f"Failed to create ticket: {e}") + + def _get_ticket_hash(self, + token_addr: str, + recipient: str, + commitment: str, + packed0: int, + packed1: int) -> bytes: + try: + return Web3.solidity_keccak( + ['bytes1', 'bytes1', 'address', 'bytes32', 'address', 'address', + 'bytes32', 'uint256', 'uint256', 'bytes32'], + [b'\x19', b'\x00', + self.lottery.contract_addr, + b'\x00' * 31 + b'\x64', # Chain ID + token_addr, + recipient, + Web3.solidity_keccak(['bytes32'], [commitment]), + packed0, + packed1 >> 1, # Remove v + b'\x00' * 32] # Empty data field + ) + except Exception as e: + raise SigningError(f"Failed to create message hash: {e}") + + async def get_balance(self, + token_addr: str = "0x0000000000000000000000000000000000000000" + ) -> Tuple[float, float]: + try: + balance, escrow = await self.lottery.check_balance( + token_addr, + self.funder, + self.signer + ) + return ( + self.lottery.wei_to_token(balance), + self.lottery.wei_to_token(escrow) + ) + except Exception as e: + raise OrchidAccountError(f"Failed to get balance: {e}") diff --git a/gai-backend/billing.py b/gai-backend/billing.py index 5b9421d31..b4d410bd0 100644 --- a/gai-backend/billing.py +++ b/gai-backend/billing.py @@ -1,33 +1,116 @@ import json +from redis.asyncio import Redis +import redis +from decimal import Decimal +from typing import Optional, Dict +import asyncio -disconnect_threshold = -0.002 +class BillingError(Exception): + """Base class for billing errors that should terminate the connection""" + pass -def invoice(amt): - return json.dumps({'type': 'invoice', 'amount': amt}) +class RedisConnectionError(BillingError): + """Redis connection or operation failed""" + pass -class Billing: - def __init__(self, prices): - self.ledger = {} - self.prices = prices - - def credit(self, id, type=None, amount=0): - self.adjust(id, type, amount, 1) +class InconsistentStateError(BillingError): + """Billing state became inconsistent""" + pass - def debit(self, id, type=None, amount=0): - self.adjust(id, type, amount, -1) - - def adjust(self, id, type, amount, sign): - amount_ = self.prices[type] if type is not None else amount - if id in self.ledger: - self.ledger[id] = self.ledger[id] + sign * amount_ - else: - self.ledger[id] = sign * amount_ - - def min_balance(self): - return 2 * (self.prices['invoice'] + self.prices['payment']) - - def balance(self, id): - if id in self.ledger: - return self.ledger[id] - else: - return 0 \ No newline at end of file +class StrictRedisBilling: + def __init__(self, redis: Redis): + self.redis = redis + + async def init(self): + try: + await self.redis.ping() + except Exception as e: + raise RedisConnectionError(f"Failed to connect to Redis: {e}") + + def _get_client_key(self, client_id: str) -> str: + return f"billing:balance:{client_id}" + + def _get_update_channel(self, client_id: str) -> str: + return f"billing:balance:updates:{client_id}" + + async def credit(self, id: str, type: Optional[str] = None, amount: float = 0): + await self.adjust(id, type, amount, 1) + + async def debit(self, id: str, type: Optional[str] = None, amount: float = 0): + await self.adjust(id, type, amount, -1) + + async def adjust(self, id: str, type: Optional[str], amount: float, sign: int): + key = self._get_client_key(id) + channel = self._get_update_channel(id) + + # Get amount from pricing if type is provided + amount_ = amount + if type is not None: + # Get price from config + config_data = await self.redis.get("config:data") + if not config_data: + raise BillingError("No configuration found") + config = json.loads(config_data) + price = config['billing']['prices'].get(type) + if price is None: + raise BillingError(f"Unknown price type: {type}") + amount_ = price + + try: + async with self.redis.pipeline() as pipe: + while True: + try: + await pipe.watch(key) + current = await self.redis.get(key) + try: + current_balance = Decimal(current) if current else Decimal('0') + except (TypeError, ValueError) as e: + raise InconsistentStateError(f"Invalid balance format in Redis: {e}") + + new_balance = current_balance + Decimal(str(sign * amount_)) + + pipe.multi() + await pipe.set(key, str(new_balance)) + await pipe.publish(channel, str(new_balance)) + await pipe.execute() + return + + except redis.WatchError: + continue + + except Exception as e: + raise RedisConnectionError(f"Redis transaction failed: {e}") + + except BillingError: + raise + except Exception as e: + raise RedisConnectionError(f"Unexpected Redis error: {e}") + + async def balance(self, id: str) -> float: + try: + key = self._get_client_key(id) + balance = await self.redis.get(key) + + if balance is None: + return 0 + + try: + return float(Decimal(balance)) + except (TypeError, ValueError) as e: + raise InconsistentStateError(f"Invalid balance format: {e}") + + except BillingError: + raise + except Exception as e: + raise RedisConnectionError(f"Failed to get balance: {e}") + + async def min_balance(self) -> float: + try: + config_data = await self.redis.get("config:data") + if not config_data: + raise BillingError("No configuration found") + config = json.loads(config_data) + prices = config['billing']['prices'] + return 2 * (prices['invoice'] + prices['payment']) + except Exception as e: + raise BillingError(f"Failed to calculate minimum balance: {e}") diff --git a/gai-backend/config_manager.py b/gai-backend/config_manager.py new file mode 100644 index 000000000..b1122bbb8 --- /dev/null +++ b/gai-backend/config_manager.py @@ -0,0 +1,140 @@ +from typing import Dict, Any, Optional +from redis.asyncio import Redis +import json +import time +import os + +class ConfigError(Exception): + """Raised when config operations fail""" + pass + +class ConfigManager: + def __init__(self, redis: Redis): + self.redis = redis + self.last_load_time = 0 + self.current_config = {} + + async def load_from_file(self, config_path: str) -> Dict[str, Any]: + try: + with open(config_path, 'r') as f: + return json.load(f) + except FileNotFoundError: + raise ConfigError(f"Config file not found: {config_path}") + except json.JSONDecodeError as e: + raise ConfigError(f"Invalid JSON in config file: {e}") + except Exception as e: + raise ConfigError(f"Failed to load config file: {e}") + + def process_config(self, config: Dict[str, Any]) -> Dict[str, Any]: + if 'inference' not in config: + raise ConfigError("Missing required 'inference' section") + + if 'endpoints' not in config['inference']: + raise ConfigError("Missing required 'endpoints' in inference config") + + endpoints = config['inference']['endpoints'] + if not endpoints: + raise ConfigError("No inference endpoints configured") + + total_models = 0 + + for endpoint_id, endpoint in endpoints.items(): + required_fields = ['api_type', 'url', 'api_key', 'models'] + missing = [field for field in required_fields if field not in endpoint] + if missing: + raise ConfigError(f"Endpoint {endpoint_id} missing required fields: {', '.join(missing)}") + if not isinstance(endpoint['models'], list): + raise ConfigError(f"Endpoint {endpoint_id} 'models' must be a list") + if not endpoint['models']: + raise ConfigError(f"Endpoint {endpoint_id} has no models configured") + + total_models += len(endpoint['models']) + + for model in endpoint['models']: + required_model_fields = ['id', 'pricing'] + missing = [field for field in required_model_fields if field not in model] + if missing: + raise ConfigError(f"Model in endpoint {endpoint_id} missing required fields: {', '.join(missing)}") + if 'params' not in model: + model['params'] = {} + + pricing = model['pricing'] + if 'type' not in pricing: + raise ConfigError(f"Model {model['id']} missing required pricing type") + + required_pricing_fields = { + 'fixed': ['input_price', 'output_price'], + 'cost_plus': ['backend_input', 'backend_output', 'input_markup', 'output_markup'], + 'multiplier': ['backend_input', 'backend_output', 'input_multiplier', 'output_multiplier'] + } + + if pricing['type'] not in required_pricing_fields: + raise ConfigError(f"Invalid pricing type for model {model['id']}: {pricing['type']}") + + missing = [field for field in required_pricing_fields[pricing['type']] + if field not in pricing] + if missing: + raise ConfigError(f"Model {model['id']} pricing missing required fields: {', '.join(missing)}") + + if total_models == 0: + raise ConfigError("No models configured across all endpoints") + + return config + + async def write_config(self, config: Dict[str, Any], force: bool = False): + try: + config = self.process_config(config) + + async with self.redis.pipeline() as pipe: + if not force: + current_time = await self.redis.get("config:last_update") + if current_time and float(current_time) > self.last_load_time: + raise ValueError("Config was updated more recently by another server") + + timestamp = time.time() + await pipe.set("config:data", json.dumps(config)) + await pipe.set("config:last_update", str(timestamp)) + await pipe.execute() + + self.current_config = config + self.last_load_time = timestamp + + except Exception as e: + raise ConfigError(f"Failed to write config: {e}") + + async def load_config(self, config_path: Optional[str] = None, force_reload: bool = False) -> Dict[str, Any]: + try: + if config_path: + config = await self.load_from_file(config_path) + await self.write_config(config, force=True) + return config + + timestamp = await self.redis.get("config:last_update") + + if not force_reload and timestamp and self.last_load_time >= float(timestamp): + return self.current_config + + config_data = await self.redis.get("config:data") + if not config_data: + raise ConfigError("No configuration found in Redis") + + config = json.loads(config_data) + config = self.process_config(config) + + self.current_config = config + self.last_load_time = float(timestamp) if timestamp else time.time() + return config + + except Exception as e: + raise ConfigError(f"Failed to load config: {e}") + + async def check_for_updates(self) -> bool: + try: + timestamp = await self.redis.get("config:last_update") + if timestamp and float(timestamp) > self.last_load_time: + await self.load_config(force_reload=True) + return True + return False + + except Exception as e: + raise ConfigError(f"Failed to check for updates: {e}") diff --git a/gai-backend/inference_api.py b/gai-backend/inference_api.py new file mode 100644 index 000000000..b5195b3ab --- /dev/null +++ b/gai-backend/inference_api.py @@ -0,0 +1,415 @@ +from fastapi import FastAPI, HTTPException, Depends +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from fastapi.middleware.cors import CORSMiddleware +from redis.asyncio import Redis +from pydantic import BaseModel, Field +from typing import Optional, Dict, Any, Tuple, List, Literal +import json +import os +import requests +from config_manager import ConfigManager +from billing import StrictRedisBilling, BillingError +import logging + +app = FastAPI() +security = HTTPBearer() +logger = logging.getLogger(__name__) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +class Message(BaseModel): + role: Literal["system", "user", "assistant"] + content: str + name: Optional[str] = None + +class InferenceRequest(BaseModel): + messages: List[Message] + model: str + params: Optional[Dict[str, Any]] = None + request_id: Optional[str] = None + +class ModelInfo(BaseModel): + id: str + name: str + api_type: Literal["openai", "anthropic", "openrouter"] + endpoint: str + +class InferenceAPIError(Exception): + def __init__(self, status_code: int, detail: str): + self.status_code = status_code + self.detail = detail + +class PricingError(Exception): + """Raised when pricing calculation fails""" + pass + +class InferenceAPI: + def __init__(self, redis: Redis): + self.redis = redis + self.config_manager = ConfigManager(redis) + self.billing = StrictRedisBilling(redis) + + async def init(self): + await self.billing.init() + await self.config_manager.load_config() + + async def list_models(self) -> Dict[str, ModelInfo]: + config = await self.config_manager.load_config() + models = {} + + for endpoint_id, endpoint in config['inference']['endpoints'].items(): + for model in endpoint['models']: + models[model['id']] = ModelInfo( + id=model['id'], + name=model.get('display_name', model['id']), + api_type=endpoint['api_type'], + endpoint=endpoint_id + ) + + return models + + async def validate_session(self, credentials: HTTPAuthorizationCredentials) -> str: + if not credentials: + logger.error("No credentials provided") + raise InferenceAPIError(401, "Missing session ID") + + session_id = credentials.credentials + logger.info(f"Validating session: {session_id}") + + try: + balance = await self.billing.balance(session_id) + logger.info(f"Session {session_id} balance: {balance}") + + if balance is None: + logger.error(f"No balance found for session {session_id}") + raise InferenceAPIError(401, "Invalid session") + + min_balance = await self.billing.min_balance() + logger.info(f"Minimum balance required: {min_balance}") + + if balance < min_balance: + logger.warning(f"Insufficient balance: {balance} < {min_balance}") + raise InferenceAPIError(402, "Insufficient balance") + + return session_id + + except BillingError as e: + logger.error(f"Billing error during validation: {e}") + raise InferenceAPIError(500, "Internal service error") + + except Exception as e: + logger.error(f"Unexpected error during validation: {e}") + raise InferenceAPIError(500, f"Internal service error: {e}") + + async def get_model_config(self, model_id: str) -> Tuple[Dict[str, Any], Dict[str, Any]]: + config = await self.config_manager.load_config() + endpoints = config['inference']['endpoints'] + + for endpoint_id, endpoint in endpoints.items(): + for model in endpoint['models']: + if model['id'] == model_id: + return endpoint, model + + raise InferenceAPIError(400, f"Unknown model: {model_id}") + + def prepare_request(self, endpoint_config: Dict[str, Any], model_config: Dict[str, Any], request: InferenceRequest) -> Tuple[Dict[str, Any], Dict[str, str]]: + params = { + **(endpoint_config.get('params', {})), + **(model_config.get('params', {})), + **(request.params or {}) + } + + headers = {"Content-Type": "application/json"} + api_type = endpoint_config['api_type'] + + if not (api_key := endpoint_config.get('api_key')): + logger.error("No API key configured for endpoint") + raise InferenceAPIError(500, "Backend authentication not configured") + + data: Dict[str, Any] = {} + + if api_type == 'openai': + headers["Authorization"] = f"Bearer {api_key}" + data = { + 'model': model_config['id'], + 'messages': [msg.dict(exclude_none=True) for msg in request.messages] + } + if 'max_tokens' in (request.params or {}): + data['max_tokens'] = params['max_tokens'] + + elif api_type == 'openrouter': + headers["Authorization"] = f"Bearer {api_key}" + data = { + 'model': model_config['id'], + 'messages': [msg.dict(exclude_none=True) for msg in request.messages] + } + + if 'max_tokens' in (request.params or {}): + user_max_tokens = params['max_tokens'] + config_max_tokens = model_config.get('params', {}).get('max_tokens') + + if config_max_tokens and user_max_tokens > config_max_tokens: + raise InferenceAPIError(400, f"Requested max_tokens {user_max_tokens} exceeds model limit {config_max_tokens}") + + prompt_tokens = self.count_input_tokens(request) + if config_max_tokens and (prompt_tokens + user_max_tokens) > config_max_tokens: + raise InferenceAPIError(400, + f"Combined prompt ({prompt_tokens}) and max_tokens ({user_max_tokens}) " + f"exceeds model context limit {config_max_tokens}") + + data['max_tokens'] = user_max_tokens + + elif api_type == 'anthropic': + headers["x-api-key"] = api_key + headers["anthropic-version"] = "2023-06-01" + system_message = next((msg.content for msg in request.messages if msg.role == "system"), None) + conversation = [msg for msg in request.messages if msg.role != "system"] + + data = { + 'model': model_config['id'], + 'messages': [{'role': msg.role, 'content': msg.content} for msg in conversation], + 'max_tokens': params.get('max_tokens', 4096) + } + if system_message: + data['system'] = system_message + + else: + raise InferenceAPIError(500, f"Unsupported API type: {api_type}") + + for k, v in params.items(): + if k != 'max_tokens' and k not in data: + data[k] = v + + return data, headers + + def parse_response(self, api_type: str, response: Dict[str, Any], request_id: Optional[str] = None) -> Dict[str, Any]: + try: + base_response = { + 'request_id': request_id, + } + + if api_type in ['openai', 'openrouter']: # OpenRouter follows OpenAI response format + return { + **base_response, + 'response': response['choices'][0]['message']['content'], + 'usage': response['usage'] + } + elif api_type == 'anthropic': + return { + **base_response, + 'response': response['content'][0]['text'], + 'usage': { + 'prompt_tokens': response['usage']['input_tokens'], + 'completion_tokens': response['usage']['output_tokens'], + 'total_tokens': response['usage']['input_tokens'] + response['usage']['output_tokens'] + } + } + else: + raise InferenceAPIError(500, f"Unsupported API type: {api_type}") + except KeyError as e: + logger.error(f"Failed to parse {api_type} response: {e}") + logger.error(f"Response: {response}") + logger.debug(f"Raw response: {response}") + raise InferenceAPIError(502, f"Invalid {api_type} response format") + + def query_backend(self, endpoint_config: Dict[str, Any], model_config: Dict[str, Any], request: InferenceRequest) -> Dict[str, Any]: + try: + data, headers = self.prepare_request(endpoint_config, model_config, request) + + logger.info(f"Sending request to backend: {endpoint_config['url']}") + logger.debug(f"Request headers: {headers}") + logger.debug(f"Request data: {data}") + + response = requests.post( + endpoint_config['url'], + headers=headers, + json=data, + timeout=30 + ) + + try: + response.raise_for_status() + except requests.exceptions.HTTPError as e: + if response.status_code == 400: + if endpoint_config['api_type'] == 'openrouter': + error_body = response.json() + if 'error' in error_body and 'message' in error_body['error']: + raise InferenceAPIError(400, error_body['error']['message']) + raise + + result = response.json() + logger.debug(f"Raw backend response: {result}") + + return self.parse_response(endpoint_config['api_type'], result, request.request_id) + + except requests.exceptions.RequestException as e: + logger.error(f"Backend request failed: {e}") + raise InferenceAPIError(502, "Backend service error") + + def get_token_prices(self, pricing_config: Dict[str, Any]) -> Tuple[float, float]: + try: + pricing_type = pricing_config['type'] + + if pricing_type == 'fixed': + return ( + pricing_config['input_price'], + pricing_config['output_price'] + ) + + elif pricing_type == 'cost_plus': + return ( + pricing_config['backend_input'] + pricing_config['input_markup'], + pricing_config['backend_output'] + pricing_config['output_markup'] + ) + + elif pricing_type == 'multiplier': + return ( + pricing_config['backend_input'] * pricing_config['input_multiplier'], + pricing_config['backend_output'] * pricing_config['output_multiplier'] + ) + + else: + raise PricingError(f"Unknown pricing type: {pricing_type}") + + except KeyError as e: + raise PricingError(f"Missing required pricing field: {e}") + + def calculate_cost(self, pricing_config: Dict[str, Any], input_tokens: int, output_tokens: int) -> float: + try: + input_price, output_price = self.get_token_prices(pricing_config) + + total_cost = ( + (input_tokens * input_price) + + (output_tokens * output_price) + ) / 1_000_000 # Convert to millions of tokens + + return total_cost + + except Exception as e: + raise PricingError(f"Failed to calculate cost: {e}") + + def estimate_max_cost(self, pricing_config: Dict[str, Any], input_tokens: int, max_output_tokens: int) -> float: + return self.calculate_cost(pricing_config, input_tokens, max_output_tokens) + + def count_input_tokens(self, request: InferenceRequest) -> int: + # TODO: Implement proper tokenization based on model + return sum(len(msg.content) // 4 for msg in request.messages) # Rough estimate + + async def handle_inference( + self, + request: InferenceRequest, + session_id: str + ) -> Dict[str, Any]: + try: + # Get endpoint and model configs - model is now required + endpoint_config, model_config = await self.get_model_config(request.model) + + # Calculate maximum possible cost + input_tokens = self.count_input_tokens(request) + max_output_tokens = model_config.get('params', {}).get( + 'max_tokens', + endpoint_config.get('params', {}).get('max_tokens', 4096) + ) + + max_cost = self.estimate_max_cost( + model_config['pricing'], + input_tokens, + max_output_tokens + ) + + balance = await self.billing.balance(session_id) + if balance < max_cost: + logger.warning(f"Insufficient balance for max cost: {balance} < {max_cost}") + await self.redis.publish( + f"billing:balance:updates:{session_id}", + str(balance) + ) + raise InferenceAPIError(402, "Insufficient balance") + + await self.billing.debit(session_id, amount=max_cost) + logger.info(f"Reserved {max_cost} tokens from balance") + + try: + result = self.query_backend(endpoint_config, model_config, request) + + actual_cost = self.calculate_cost( + model_config['pricing'], + result['usage']['prompt_tokens'], + result['usage']['completion_tokens'] + ) + + logger.info(f"Actual cost: {actual_cost} (reserved: {max_cost})") + + if actual_cost < max_cost: + refund = max_cost - actual_cost + await self.billing.credit(session_id, amount=refund) + logger.info(f"Refunded excess reservation: {refund}") + + return result + + except Exception as e: + logger.error(f"Error during inference: {e}") + await self.billing.credit(session_id, amount=max_cost) + raise + + except BillingError as e: + logger.error(f"Billing error: {e}") + raise InferenceAPIError(500, "Internal service error") + except PricingError as e: + logger.error(f"Pricing error: {e}") + raise InferenceAPIError(500, f"Pricing configuration error: {e}") + except Exception as e: + logger.error(f"Unexpected error: {e}") + raise InferenceAPIError(500, str(e)) + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) + +redis_url = os.environ.get('REDIS_URL', 'redis://localhost:6379') +redis = Redis.from_url(redis_url, decode_responses=True) +api = InferenceAPI(redis) + +@app.on_event("startup") +async def startup(): + await api.init() + +@app.post("/v1/chat/completions") +async def chat_completion( + request: InferenceRequest, + credentials: HTTPAuthorizationCredentials = Depends(security) +): + logger.info(f"Received chat completion request with auth: {credentials.scheme} {credentials.credentials}") + try: + session_id = await api.validate_session(credentials) + return await api.handle_inference(request, session_id) + except InferenceAPIError as e: + logger.error(f"Inference API error: {e.status_code} - {e.detail}") + raise HTTPException(status_code=e.status_code, detail=e.detail) + except Exception as e: + logger.error(f"Unexpected error: {e}") + raise HTTPException(status_code=500, detail=str(e)) + +@app.post("/v1/inference") +async def inference( + request: InferenceRequest, + credentials: HTTPAuthorizationCredentials = Depends(security) +): + return await chat_completion(request, credentials) + +@app.get("/v1/models") +async def list_models(): + """List available inference models""" + try: + models = await api.list_models() + return models + except Exception as e: + logger.error(f"Failed to list models: {e}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/gai-backend/jobs.py b/gai-backend/jobs.py deleted file mode 100644 index 6ec93091f..000000000 --- a/gai-backend/jobs.py +++ /dev/null @@ -1,100 +0,0 @@ -import asyncio -import random -import datetime -import json -import requests - -import websockets - - -class jobs: - def __init__(self, model, url, llmkey, llmparams, api='openai'): - self.queue = asyncio.PriorityQueue() - self.sessions = {} - self.model = model - self.url = url - self.llmkey = llmkey - self.llmparams = llmparams - self.api = api - - def get_queues(self, id): - squeue = asyncio.Queue(maxsize=10) - rqueue = asyncio.Queue(maxsize=10) - self.sessions[id] = {'send': squeue, 'recv': rqueue} - return [rqueue, squeue] - - async def add_job(self, id, bid, job): - print(f'add_job({id}, {bid}, {job})') - priority = 0.1 - await self.queue.put((priority, [id, bid, job])) - - async def process_jobs(self): - print(f"Starting process_jobs() with api: {self.api}") - while True: - priority, job_params = await self.queue.get() - id, bid, job = job_params - await self.sessions[id]['send'].put(json.dumps({'type': 'started'})) - response, reason, usage = apis[self.api](job['prompt'], self.llmparams, self.model, self.url, self.llmkey) - if response == None: - continue - await self.sessions[id]['send'].put(json.dumps({'type': 'complete', 'response': response, - 'model': self.model, 'reason': reason, - 'usage': usage})) - - -def query_openai(prompt, params, model, url, llmkey): - result = "" - data = {'model': model, 'messages': [{'role': 'user', 'content': prompt}]} - data = {**data, **params} - headers = {"Content-Type": "application/json"} - if not llmkey is None: - headers["Authorization"] = f"Bearer {llmkey}" - r = requests.post(url, data=json.dumps(data), headers=headers) - result = r.json() - if result['object'] != 'chat.completion': - print('*** process_jobs: Error from llm') - print(f"from job: {job_params}") - print(result) - return None, None, None - response = result['choices'][0]['message']['content'] - model = result['model'] - reason = result['choices'][0]['finish_reason'] - usage = result['usage'] - return response, reason, usage - -def query_gemini(prompt, params, model, url, llmkey): - data = {'contents': [{'parts': [{'text': prompt}]}]} - data = {**data, **params} - print(f"query_gemini(): data: {data}") - headers = {"Content-Type": "application/json"} - url_ = url + f"?key={llmkey}" - r = requests.post(url_, data=json.dumps(data), headers=headers) - print(r.json()) - result = r.json()['candidates'][0] - response = result['content']['parts'][0]['text'] - reason = result['finishReason'] - usage = 0 - return response, reason, usage - -def query_anthropic(prompt, params, model, url, llmkey): - data = {'model': model, 'messages': [{'role': 'user', 'content': prompt}]} - data = {**data, **params} - headers = {"Content-Type": "application/json", - "anthropic-version": "2023-06-01"} - if not llmkey is None: - headers["x-api-key"] = llmkey - r = requests.post(url, data=json.dumps(data), headers=headers) - result = r.json() - if 'content' not in result: - print('*** process_jobs: Error from llm') - print(f"from job: {job_params}") - print(result) - return None, None, None - response = result['content'][0]['text'] - model = result['model'] - reason = result['stop_reason'] - usage = result['usage'] - return response, reason, usage - - -apis = {'openai': query_openai, 'gemini': query_gemini, 'anthropic': query_anthropic} diff --git a/gai-backend/lottery.py b/gai-backend/lottery.py index c7a219777..73f456bae 100644 --- a/gai-backend/lottery.py +++ b/gai-backend/lottery.py @@ -1,82 +1,161 @@ -import json -import web3 - -from ticket import Ticket - -# Gnosis - default -rpc_url_default = 'https://rpc.gnosischain.com/' -chain_id_default = 100 - -# Polygon - default -# rpc_url_default = 'https://polygon-rpc.com' -# chain_id_default = 137 - -# Gas default -gas_amount_default = 100000 - -uint64 = pow(2, 64) - 1 # 18446744073709551615 -uint128 = pow(2, 128) - 1 # 340282366920938463463374607431768211455 - -def to_32byte_hex(val): - return web3.Web3.to_hex(web3.Web3.to_bytes(hexstr=val).rjust(32, b'\0')) - - -class Lottery: - addr_type = pow(2, 20 * 8) - 1 - contract_addr = '0x6dB8381b2B41b74E17F5D4eB82E8d5b04ddA0a82' - token = '0x' + '0' * 40 - contract_abi_str = "[ { \"inputs\": [ { \"internalType\": \"uint64\", \"name\": \"day\", \"type\": \"uint64\" } ], \"stateMutability\": \"nonpayable\", \"type\": \"constructor\" }, { \"anonymous\": false, \"inputs\": [ { \"indexed\": true, \"internalType\": \"contract IERC20\", \"name\": \"token\", \"type\": \"address\" }, { \"indexed\": true, \"internalType\": \"address\", \"name\": \"funder\", \"type\": \"address\" }, { \"indexed\": true, \"internalType\": \"address\", \"name\": \"signer\", \"type\": \"address\" } ], \"name\": \"Create\", \"type\": \"event\" }, { \"anonymous\": false, \"inputs\": [ { \"indexed\": true, \"internalType\": \"bytes32\", \"name\": \"key\", \"type\": \"bytes32\" }, { \"indexed\": false, \"internalType\": \"uint256\", \"name\": \"unlock_warned\", \"type\": \"uint256\" } ], \"name\": \"Delete\", \"type\": \"event\" }, { \"anonymous\": false, \"inputs\": [ { \"indexed\": true, \"internalType\": \"address\", \"name\": \"funder\", \"type\": \"address\" }, { \"indexed\": true, \"internalType\": \"address\", \"name\": \"recipient\", \"type\": \"address\" } ], \"name\": \"Enroll\", \"type\": \"event\" }, { \"anonymous\": false, \"inputs\": [ { \"indexed\": true, \"internalType\": \"bytes32\", \"name\": \"key\", \"type\": \"bytes32\" }, { \"indexed\": false, \"internalType\": \"uint256\", \"name\": \"escrow_amount\", \"type\": \"uint256\" } ], \"name\": \"Update\", \"type\": \"event\" }, { \"inputs\": [ { \"internalType\": \"contract IERC20\", \"name\": \"token\", \"type\": \"address\" }, { \"internalType\": \"address\", \"name\": \"recipient\", \"type\": \"address\" }, { \"components\": [ { \"internalType\": \"bytes32\", \"name\": \"data\", \"type\": \"bytes32\" }, { \"internalType\": \"bytes32\", \"name\": \"reveal\", \"type\": \"bytes32\" }, { \"internalType\": \"uint256\", \"name\": \"packed0\", \"type\": \"uint256\" }, { \"internalType\": \"uint256\", \"name\": \"packed1\", \"type\": \"uint256\" }, { \"internalType\": \"bytes32\", \"name\": \"r\", \"type\": \"bytes32\" }, { \"internalType\": \"bytes32\", \"name\": \"s\", \"type\": \"bytes32\" } ], \"internalType\": \"struct OrchidLottery1.Ticket[]\", \"name\": \"tickets\", \"type\": \"tuple[]\" }, { \"internalType\": \"bytes32[]\", \"name\": \"refunds\", \"type\": \"bytes32[]\" } ], \"name\": \"claim\", \"outputs\": [], \"stateMutability\": \"nonpayable\", \"type\": \"function\" }, { \"inputs\": [ { \"internalType\": \"contract IERC20\", \"name\": \"token\", \"type\": \"address\" }, { \"internalType\": \"uint256\", \"name\": \"amount\", \"type\": \"uint256\" }, { \"internalType\": \"address\", \"name\": \"signer\", \"type\": \"address\" }, { \"internalType\": \"int256\", \"name\": \"adjust\", \"type\": \"int256\" }, { \"internalType\": \"int256\", \"name\": \"warn\", \"type\": \"int256\" }, { \"internalType\": \"uint256\", \"name\": \"retrieve\", \"type\": \"uint256\" } ], \"name\": \"edit\", \"outputs\": [], \"stateMutability\": \"nonpayable\", \"type\": \"function\" }, { \"inputs\": [ { \"internalType\": \"address\", \"name\": \"signer\", \"type\": \"address\" }, { \"internalType\": \"int256\", \"name\": \"adjust\", \"type\": \"int256\" }, { \"internalType\": \"int256\", \"name\": \"warn\", \"type\": \"int256\" }, { \"internalType\": \"uint256\", \"name\": \"retrieve\", \"type\": \"uint256\" } ], \"name\": \"edit\", \"outputs\": [], \"stateMutability\": \"payable\", \"type\": \"function\" }, { \"inputs\": [ { \"internalType\": \"bool\", \"name\": \"cancel\", \"type\": \"bool\" }, { \"internalType\": \"address[]\", \"name\": \"recipients\", \"type\": \"address[]\" } ], \"name\": \"enroll\", \"outputs\": [], \"stateMutability\": \"nonpayable\", \"type\": \"function\" }, { \"inputs\": [ { \"internalType\": \"address\", \"name\": \"funder\", \"type\": \"address\" }, { \"internalType\": \"address\", \"name\": \"recipient\", \"type\": \"address\" } ], \"name\": \"enrolled\", \"outputs\": [ { \"internalType\": \"uint256\", \"name\": \"\", \"type\": \"uint256\" } ], \"stateMutability\": \"view\", \"type\": \"function\" }, { \"inputs\": [ { \"internalType\": \"contract IERC20\", \"name\": \"token\", \"type\": \"address\" }, { \"internalType\": \"address\", \"name\": \"signer\", \"type\": \"address\" }, { \"internalType\": \"uint64\", \"name\": \"marked\", \"type\": \"uint64\" } ], \"name\": \"mark\", \"outputs\": [], \"stateMutability\": \"nonpayable\", \"type\": \"function\" }, { \"inputs\": [ { \"internalType\": \"address\", \"name\": \"sender\", \"type\": \"address\" }, { \"internalType\": \"uint256\", \"name\": \"amount\", \"type\": \"uint256\" }, { \"internalType\": \"bytes\", \"name\": \"data\", \"type\": \"bytes\" } ], \"name\": \"onTokenTransfer\", \"outputs\": [ { \"internalType\": \"bool\", \"name\": \"\", \"type\": \"bool\" } ], \"stateMutability\": \"nonpayable\", \"type\": \"function\" }, { \"inputs\": [ { \"internalType\": \"contract IERC20\", \"name\": \"token\", \"type\": \"address\" }, { \"internalType\": \"address\", \"name\": \"funder\", \"type\": \"address\" }, { \"internalType\": \"address\", \"name\": \"signer\", \"type\": \"address\" } ], \"name\": \"read\", \"outputs\": [ { \"internalType\": \"uint256\", \"name\": \"\", \"type\": \"uint256\" }, { \"internalType\": \"uint256\", \"name\": \"\", \"type\": \"uint256\" } ], \"stateMutability\": \"view\", \"type\": \"function\" }, { \"inputs\": [ { \"internalType\": \"uint256\", \"name\": \"count\", \"type\": \"uint256\" }, { \"internalType\": \"bytes32\", \"name\": \"seed\", \"type\": \"bytes32\" } ], \"name\": \"save\", \"outputs\": [], \"stateMutability\": \"nonpayable\", \"type\": \"function\" }, { \"inputs\": [ { \"internalType\": \"address\", \"name\": \"sender\", \"type\": \"address\" }, { \"internalType\": \"uint256\", \"name\": \"amount\", \"type\": \"uint256\" }, { \"internalType\": \"bytes\", \"name\": \"data\", \"type\": \"bytes\" } ], \"name\": \"tokenFallback\", \"outputs\": [], \"stateMutability\": \"nonpayable\", \"type\": \"function\" }]" - contract_abi= None - - contract= None - rpc_url= None - chain_id = None - gas_amount = None - - def __init__(self, rpc_url=rpc_url_default, chain_id=chain_id_default, gas_amount=gas_amount_default): - self.rpc_url = rpc_url - self.chain_id = chain_id - self.gas_amount = gas_amount - - self.contract_abi = json.loads(self.contract_abi_str) - - def init_contract(self, web3): - self.web3 = web3 - self.contract = self.web3.eth.contract(address=self.contract_addr, abi=self.contract_abi) - - - @staticmethod - def prepareTicket(tk:Ticket, reveal): - return [tk.data, to_32byte_hex(reveal), tk.packed0, tk.packed1, to_32byte_hex(tk.sig_r), to_32byte_hex(tk.sig_s)] - return [tk.data.hex(), reveal, tk.packed0, tk.packed1, tk.sig_r, tk.sig_s] - - # Ticket object, L1 address & key - def claim_ticket(self, ticket, recipient, executor_key, reveal): - tk = Lottery.prepareTicket(ticket, reveal) - executor_address = self.web3.eth.account.from_key(executor_key).address - l1nonce = self.web3.eth.get_transaction_count(executor_address) - func = self.contract.functions.claim(self.token, recipient, [tk], []) - - tx = func.build_transaction({ - 'chainId': self.chain_id, - 'gas': self.gas_amount, - 'maxFeePerGas': self.web3.to_wei('100', 'gwei'), - 'maxPriorityFeePerGas': self.web3.to_wei('40', 'gwei'), - 'nonce': l1nonce - }) - - # Polygon Estimates - # if (self.chain_id == 137): - # gas_estimate = self.web3.eth.estimate_gas(tx) - # print("gas ", gas_estimate) - # tx.update({'gas': gas_estimate}) - - signed = self.web3.eth.account.sign_transaction(tx, private_key=executor_key) - txhash = self.web3.eth.send_raw_transaction(signed.rawTransaction) - return txhash.hex() - - def check_balance(self, addressL1, addressL2): - escrow_amount = self.contract.functions.read(self.token, addressL1, addressL2).call(block_identifier='latest')[0] - balance = float(escrow_amount & uint128) / pow(10,18) - escrow = float(escrow_amount >> 128) / pow(10,18) - return balance, escrow \ No newline at end of file +import web3 +from web3 import Web3 +from typing import Tuple, Optional, List +import json +from eth_abi.packed import encode_packed +from ticket import Ticket +import os + +class LotteryError(Exception): + """Raised when lottery operations fail""" + pass + +class Lottery: + V1_ADDR = "0x6dB8381b2B41b74E17F5D4eB82E8d5b04ddA0a82" # v1 on all chains + V0_CHAIN_ID = 1 + V0_TOKEN = "0x4575f41308EC1483f3d399aa9a2826d74Da13Deb" # OXT + V0_ADDR = "0xb02396f06CC894834b7934ecF8c8E5Ab5C1d12F1" + + WEI = 10**18 + UINT128_MAX = (1 << 128) - 1 + UINT64_MAX = (1 << 64) - 1 + + def __init__(self, + web3_provider: Web3, + chain_id: int = 100, + addr: str = None, + gas_amount: int = 100000): + self.web3 = web3_provider + self.chain_id = chain_id + self.contract_addr = addr or self.V1_ADDR + self.gas_amount = gas_amount + self.version = self._detect_version() + self.contract = None + self.init_contract() + + def _detect_version(self) -> int: + """Determine if this is a v0 or v1 lottery""" + if (self.chain_id == self.V0_CHAIN_ID and + self.contract_addr.lower() == self.V0_ADDR.lower()): + return 0 + if self.contract_addr.lower() != self.V1_ADDR.lower(): + raise LotteryError(f"Unknown lottery contract address: {self.contract_addr}") + return 1 + + def init_contract(self): + try: + if self.version == 0: + abi = self._load_contract_abi("lottery0.abi") + else: + abi = self._load_contract_abi("lottery1.abi") + + self.contract = self.web3.eth.contract( + address=self.contract_addr, + abi=abi + ) + except Exception as e: + raise LotteryError(f"Failed to initialize contract: {e}") + + def _load_contract_abi(self, filename: str) -> dict: + try: + module_dir = os.path.dirname(os.path.abspath(__file__)) + abi_path = os.path.join(module_dir, filename) + + with open(abi_path, 'r') as f: + return json.load(f) + except Exception as e: + raise LotteryError(f"Failed to load contract ABI from {abi_path}: {e}") + + async def check_balance(self, + token_addr: str, + funder: str, + signer: str) -> Tuple[int, int]: + try: + funder = self.web3.to_checksum_address(funder) + signer = self.web3.to_checksum_address(signer) + token_addr = self.web3.to_checksum_address(token_addr) + + if self.version == 0: + if token_addr.lower() != self.V0_TOKEN.lower(): + raise LotteryError("V0 lottery only supports OXT token") + escrow_amount, unlock_warned = await self.contract.functions.look( + funder, + signer + ).call() + else: + escrow_amount, unlock_warned = await self.contract.functions.read( + token_addr, + funder, + signer + ).call() + + balance = escrow_amount & self.UINT128_MAX + escrow = escrow_amount >> 128 + return balance, escrow + + except Exception as e: + raise LotteryError(f"Failed to check balance: {e}") + + def claim_tickets(self, + recipient: str, + tickets: List[Ticket], + executor_key: str, + token_addr: str = "0x0000000000000000000000000000000000000000" + ) -> str: + try: + recipient = self.web3.to_checksum_address(recipient) + token_addr = self.web3.to_checksum_address(token_addr) + + if self.version == 0 and token_addr.lower() != self.V0_TOKEN.lower(): + raise LotteryError("V0 lottery only supports OXT token") + + executor_address = self.web3.eth.account.from_key(executor_key).address + nonce = self.web3.eth.get_transaction_count(executor_address) + + prepared_tickets = [ + self._prepare_ticket(ticket, ticket.reveal) + for ticket in tickets + ] + + func = self.contract.functions.claim( + token_addr, + recipient, + prepared_tickets, + [] # Empty refunds array + ) + + tx = func.build_transaction({ + 'chainId': self.chain_id, + 'gas': self.gas_amount, + 'maxFeePerGas': self.web3.to_wei('100', 'gwei'), + 'maxPriorityFeePerGas': self.web3.to_wei('40', 'gwei'), + 'nonce': nonce + }) + + signed = self.web3.eth.account.sign_transaction( + tx, + private_key=executor_key + ) + tx_hash = self.web3.eth.send_raw_transaction(signed.rawTransaction) + return tx_hash.hex() + + except Exception as e: + raise LotteryError(f"Failed to claim tickets: {e}") + + def _prepare_ticket(self, ticket: Ticket, reveal: str) -> list: + return [ + ticket.data, + Web3.to_bytes(hexstr=reveal), + ticket.packed0, + ticket.packed1, + Web3.to_bytes(hexstr=ticket.sig_r), + Web3.to_bytes(hexstr=ticket.sig_s) + ] + + @staticmethod + def wei_to_token(wei_amount: int) -> float: + return wei_amount / Lottery.WEI + + @staticmethod + def token_to_wei(token_amount: float) -> int: + return int(token_amount * Lottery.WEI) diff --git a/gai-backend/lottery0.abi b/gai-backend/lottery0.abi new file mode 100644 index 000000000..6b8e78799 --- /dev/null +++ b/gai-backend/lottery0.abi @@ -0,0 +1 @@ +[{"inputs":[{"internalType":"contract IERC20","name":"token","type":"address"}],"payable":false,"stateMutability":"nonpayable","type":"constructor"},{"anonymous":false,"inputs":[{"indexed":true,"internalType":"address","name":"funder","type":"address"},{"indexed":true,"internalType":"address","name":"signer","type":"address"}],"name":"Bound","type":"event"},{"anonymous":false,"inputs":[{"indexed":true,"internalType":"address","name":"funder","type":"address"},{"indexed":true,"internalType":"address","name":"signer","type":"address"}],"name":"Create","type":"event"},{"anonymous":false,"inputs":[{"indexed":true,"internalType":"address","name":"funder","type":"address"},{"indexed":true,"internalType":"address","name":"signer","type":"address"},{"indexed":false,"internalType":"uint128","name":"amount","type":"uint128"},{"indexed":false,"internalType":"uint128","name":"escrow","type":"uint128"},{"indexed":false,"internalType":"uint256","name":"unlock","type":"uint256"}],"name":"Update","type":"event"},{"constant":false,"inputs":[{"internalType":"address","name":"signer","type":"address"},{"internalType":"contract OrchidVerifier","name":"verify","type":"address"},{"internalType":"bytes","name":"shared","type":"bytes"}],"name":"bind","outputs":[],"payable":false,"stateMutability":"nonpayable","type":"function"},{"constant":false,"inputs":[{"internalType":"address","name":"signer","type":"address"},{"internalType":"uint128","name":"escrow","type":"uint128"}],"name":"burn","outputs":[],"payable":false,"stateMutability":"nonpayable","type":"function"},{"constant":false,"inputs":[{"internalType":"address","name":"funder","type":"address"},{"internalType":"address payable","name":"recipient","type":"address"},{"internalType":"uint128","name":"amount","type":"uint128"},{"internalType":"bytes","name":"receipt","type":"bytes"}],"name":"give","outputs":[],"payable":false,"stateMutability":"nonpayable","type":"function"},{"constant":false,"inputs":[{"internalType":"bytes32","name":"reveal","type":"bytes32"},{"internalType":"bytes32","name":"commit","type":"bytes32"},{"internalType":"uint256","name":"issued","type":"uint256"},{"internalType":"bytes32","name":"nonce","type":"bytes32"},{"internalType":"uint8","name":"v","type":"uint8"},{"internalType":"bytes32","name":"r","type":"bytes32"},{"internalType":"bytes32","name":"s","type":"bytes32"},{"internalType":"uint128","name":"amount","type":"uint128"},{"internalType":"uint128","name":"ratio","type":"uint128"},{"internalType":"uint256","name":"start","type":"uint256"},{"internalType":"uint128","name":"range","type":"uint128"},{"internalType":"address","name":"funder","type":"address"},{"internalType":"address payable","name":"recipient","type":"address"},{"internalType":"bytes","name":"receipt","type":"bytes"},{"internalType":"bytes32[]","name":"old","type":"bytes32[]"}],"name":"grab","outputs":[],"payable":false,"stateMutability":"nonpayable","type":"function"},{"constant":true,"inputs":[{"internalType":"address","name":"funder","type":"address"}],"name":"keys","outputs":[{"internalType":"address[]","name":"","type":"address[]"}],"payable":false,"stateMutability":"view","type":"function"},{"constant":false,"inputs":[{"internalType":"address","name":"signer","type":"address"}],"name":"kill","outputs":[],"payable":false,"stateMutability":"nonpayable","type":"function"},{"constant":false,"inputs":[{"internalType":"address","name":"signer","type":"address"}],"name":"lock","outputs":[],"payable":false,"stateMutability":"nonpayable","type":"function"},{"constant":true,"inputs":[{"internalType":"address","name":"funder","type":"address"},{"internalType":"address","name":"signer","type":"address"}],"name":"look","outputs":[{"internalType":"uint128","name":"","type":"uint128"},{"internalType":"uint128","name":"","type":"uint128"},{"internalType":"uint256","name":"","type":"uint256"},{"internalType":"contract OrchidVerifier","name":"","type":"address"},{"internalType":"bytes32","name":"","type":"bytes32"},{"internalType":"bytes","name":"","type":"bytes"}],"payable":false,"stateMutability":"view","type":"function"},{"constant":false,"inputs":[{"internalType":"address","name":"signer","type":"address"},{"internalType":"uint128","name":"amount","type":"uint128"}],"name":"move","outputs":[],"payable":false,"stateMutability":"nonpayable","type":"function"},{"constant":true,"inputs":[{"internalType":"address","name":"funder","type":"address"},{"internalType":"uint256","name":"offset","type":"uint256"},{"internalType":"uint256","name":"count","type":"uint256"}],"name":"page","outputs":[{"internalType":"address[]","name":"","type":"address[]"}],"payable":false,"stateMutability":"view","type":"function"},{"constant":false,"inputs":[{"internalType":"address","name":"signer","type":"address"},{"internalType":"address payable","name":"target","type":"address"},{"internalType":"bool","name":"autolock","type":"bool"},{"internalType":"uint128","name":"amount","type":"uint128"},{"internalType":"uint128","name":"escrow","type":"uint128"}],"name":"pull","outputs":[],"payable":false,"stateMutability":"nonpayable","type":"function"},{"constant":false,"inputs":[{"internalType":"address","name":"signer","type":"address"},{"internalType":"uint128","name":"total","type":"uint128"},{"internalType":"uint128","name":"escrow","type":"uint128"}],"name":"push","outputs":[],"payable":false,"stateMutability":"nonpayable","type":"function"},{"constant":true,"inputs":[{"internalType":"address","name":"funder","type":"address"},{"internalType":"uint256","name":"offset","type":"uint256"}],"name":"seek","outputs":[{"internalType":"address","name":"","type":"address"}],"payable":false,"stateMutability":"view","type":"function"},{"constant":true,"inputs":[{"internalType":"address","name":"funder","type":"address"}],"name":"size","outputs":[{"internalType":"uint256","name":"","type":"uint256"}],"payable":false,"stateMutability":"view","type":"function"},{"constant":false,"inputs":[{"internalType":"address","name":"signer","type":"address"}],"name":"warn","outputs":[],"payable":false,"stateMutability":"nonpayable","type":"function"},{"constant":true,"inputs":[],"name":"what","outputs":[{"internalType":"contract IERC20","name":"","type":"address"}],"payable":false,"stateMutability":"view","type":"function"},{"constant":false,"inputs":[{"internalType":"address","name":"signer","type":"address"},{"internalType":"address payable","name":"target","type":"address"},{"internalType":"bool","name":"autolock","type":"bool"}],"name":"yank","outputs":[],"payable":false,"stateMutability":"nonpayable","type":"function"}] diff --git a/gai-backend/lottery1.abi b/gai-backend/lottery1.abi new file mode 100644 index 000000000..3306342f4 --- /dev/null +++ b/gai-backend/lottery1.abi @@ -0,0 +1 @@ +[{"inputs":[{"internalType":"uint64","name":"day","type":"uint64"}],"stateMutability":"nonpayable","type":"constructor"},{"anonymous":false,"inputs":[{"indexed":true,"internalType":"contract IERC20","name":"token","type":"address"},{"indexed":true,"internalType":"address","name":"funder","type":"address"},{"indexed":true,"internalType":"address","name":"signer","type":"address"}],"name":"Create","type":"event"},{"anonymous":false,"inputs":[{"indexed":true,"internalType":"bytes32","name":"key","type":"bytes32"},{"indexed":false,"internalType":"uint256","name":"unlock_warned","type":"uint256"}],"name":"Delete","type":"event"},{"anonymous":false,"inputs":[{"indexed":true,"internalType":"address","name":"funder","type":"address"},{"indexed":true,"internalType":"address","name":"recipient","type":"address"}],"name":"Enroll","type":"event"},{"anonymous":false,"inputs":[{"indexed":true,"internalType":"bytes32","name":"key","type":"bytes32"},{"indexed":false,"internalType":"uint256","name":"escrow_amount","type":"uint256"}],"name":"Update","type":"event"},{"inputs":[{"internalType":"contract IERC20","name":"token","type":"address"},{"internalType":"address","name":"recipient","type":"address"},{"components":[{"internalType":"bytes32","name":"data","type":"bytes32"},{"internalType":"bytes32","name":"reveal","type":"bytes32"},{"internalType":"uint256","name":"packed0","type":"uint256"},{"internalType":"uint256","name":"packed1","type":"uint256"},{"internalType":"bytes32","name":"r","type":"bytes32"},{"internalType":"bytes32","name":"s","type":"bytes32"}],"internalType":"struct OrchidLottery1.Ticket[]","name":"tickets","type":"tuple[]"},{"internalType":"bytes32[]","name":"refunds","type":"bytes32[]"}],"name":"claim","outputs":[],"stateMutability":"nonpayable","type":"function"},{"inputs":[{"internalType":"contract IERC20","name":"token","type":"address"},{"internalType":"uint256","name":"amount","type":"uint256"},{"internalType":"address","name":"signer","type":"address"},{"internalType":"int256","name":"adjust","type":"int256"},{"internalType":"int256","name":"warn","type":"int256"},{"internalType":"uint256","name":"retrieve","type":"uint256"}],"name":"edit","outputs":[],"stateMutability":"nonpayable","type":"function"},{"inputs":[{"internalType":"address","name":"signer","type":"address"},{"internalType":"int256","name":"adjust","type":"int256"},{"internalType":"int256","name":"warn","type":"int256"},{"internalType":"uint256","name":"retrieve","type":"uint256"}],"name":"edit","outputs":[],"stateMutability":"payable","type":"function"},{"inputs":[{"internalType":"bool","name":"cancel","type":"bool"},{"internalType":"address[]","name":"recipients","type":"address[]"}],"name":"enroll","outputs":[],"stateMutability":"nonpayable","type":"function"},{"inputs":[{"internalType":"address","name":"funder","type":"address"},{"internalType":"address","name":"recipient","type":"address"}],"name":"enrolled","outputs":[{"internalType":"uint256","name":"","type":"uint256"}],"stateMutability":"view","type":"function"},{"inputs":[{"internalType":"contract IERC20","name":"token","type":"address"},{"internalType":"address","name":"signer","type":"address"},{"internalType":"uint64","name":"marked","type":"uint64"}],"name":"mark","outputs":[],"stateMutability":"nonpayable","type":"function"},{"inputs":[{"internalType":"address","name":"sender","type":"address"},{"internalType":"uint256","name":"amount","type":"uint256"},{"internalType":"bytes","name":"data","type":"bytes"}],"name":"onTokenTransfer","outputs":[{"internalType":"bool","name":"","type":"bool"}],"stateMutability":"nonpayable","type":"function"},{"inputs":[{"internalType":"contract IERC20","name":"token","type":"address"},{"internalType":"address","name":"funder","type":"address"},{"internalType":"address","name":"signer","type":"address"}],"name":"read","outputs":[{"internalType":"uint256","name":"","type":"uint256"},{"internalType":"uint256","name":"","type":"uint256"}],"stateMutability":"view","type":"function"},{"inputs":[{"internalType":"uint256","name":"count","type":"uint256"},{"internalType":"bytes32","name":"seed","type":"bytes32"}],"name":"save","outputs":[],"stateMutability":"nonpayable","type":"function"},{"inputs":[{"internalType":"address","name":"sender","type":"address"},{"internalType":"uint256","name":"amount","type":"uint256"},{"internalType":"bytes","name":"data","type":"bytes"}],"name":"tokenFallback","outputs":[],"stateMutability":"nonpayable","type":"function"}] diff --git a/gai-backend/payment_handler.py b/gai-backend/payment_handler.py new file mode 100644 index 000000000..7d8196387 --- /dev/null +++ b/gai-backend/payment_handler.py @@ -0,0 +1,74 @@ +import web3 +from decimal import Decimal +import random +import ethereum +from typing import Tuple, Optional +import json +import sys +import traceback +import logging + +from ticket import Ticket +from lottery import Lottery + +logger = logging.getLogger(__name__) + +wei = pow(10, 18) + +class PaymentError(Exception): + """Base class for payment processing errors""" + pass + +class PaymentHandler: + def __init__(self, lottery_address: str, recipient_key: str, rpc_url: str = 'https://rpc.gnosischain.com/'): + self.lottery_address = lottery_address + self.recipient_key = recipient_key + self.w3 = web3.Web3(web3.Web3.HTTPProvider(rpc_url)) + self.recipient_addr = web3.Account.from_key(recipient_key).address + self.lottery = Lottery(self.w3) + + def new_reveal(self) -> Tuple[str, str]: + num = hex(random.randrange(pow(2,256)))[2:] + reveal = '0x' + num[2:].zfill(64) + try: + commit = ethereum.utils.sha3(bytes.fromhex(reveal[2:])).hex() + return reveal, commit + except Exception as e: + logger.error(f"Failed to generate reveal/commit pair: {e}") + raise PaymentError("Failed to generate payment credentials") + + def create_invoice(self, amount: float, commit: str) -> str: + return json.dumps({ + 'type': 'invoice', + 'amount': int(wei * amount), + 'commit': '0x' + str(commit), + 'recipient': self.recipient_addr + }) + + async def process_ticket(self, ticket_data: str, reveal: str, commit: str) -> Tuple[float, str, str]: + try: + ticket = Ticket.deserialize( + ticket_data, + reveal=reveal, + commitment=commit, + recipient=self.recipient_addr, + lottery_addr=self.lottery_address + ) + + if ticket.is_winner(): + logger.info( + f"Winner found! Face value: {ticket.face_value() / wei}, " + "Adding to claim queue (stubbed)" + ) + + new_reveal, new_commit = self.new_reveal() + return ticket.face_value() / wei, new_reveal, new_commit + + except Exception as e: + logger.error("Failed to process ticket") + logger.error(traceback.format_exc()) + raise PaymentError(f"Ticket processing failed: {e}") + + async def queue_claim(self, ticket: Ticket): + logger.info(f"Queued ticket claim for {ticket.face_value() / wei} tokens (stubbed)") + pass diff --git a/gai-backend/server.py b/gai-backend/server.py index c73633e3e..777a94cff 100644 --- a/gai-backend/server.py +++ b/gai-backend/server.py @@ -4,163 +4,252 @@ import json import hashlib import random - -import web3 -import ethereum - -import billing -import jobs -import ticket -import lottery +from redis.asyncio import Redis +import redis +from decimal import Decimal +import uuid +import time import os -import traceback import sys +import traceback +from typing import Optional, Dict -uint256 = pow(2,256) - 1 -uint64 = pow(2,64) - 1 -wei = pow(10, 18) - -prices = { - 'invoice': 0.0001, - 'payment': 0.0001, - 'connection': 0.0001, - 'error': 0.0001, - 'job': 0.01, - 'complete': 0.001, - 'started': 0.0001 -} - -lottery_address = '0x6dB8381b2B41b74E17F5D4eB82E8d5b04ddA0a82' +import billing +from config_manager import ConfigManager, ConfigError +from payment_handler import PaymentHandler, PaymentError -internal_messages = ['charge'] +# Configuration +LOTTERY_ADDRESS = '0x6dB8381b2B41b74E17F5D4eB82E8d5b04ddA0a82' disconnect_threshold = -25 -def invoice(amt, commit, recipient): - return json.dumps({'type': 'invoice', 'amount': int(pow(10,18) * amt), 'commit': '0x' + str(commit), 'recipient': recipient}) - -def process_tickets(tix, recip, reveal, commit, lotto, key): - try: -# print(f'Got ticket: {tix[0]}') - tk = ticket.Ticket.deserialize_ticket(tix[0], reveal, commit, recip, lotaddr=lottery_address) -# tk.print_ticket() - if tk.is_winner(reveal): - hash = lotto.claim_ticket(tk, recip, key, reveal) - print(f"Claim tx: {hash}") - reveal, commit = new_reveal() - tk.print_ticket() - return tk.value() / wei, reveal, commit - except Exception: - print('process_ticket() failed') - exc_type, exc_value, exc_traceback = sys.exc_info() - traceback.print_exception(exc_type, exc_value, exc_traceback, limit=20, file=sys.stdout) - return 0, reveal, commit - async def send_error(ws, code): await ws.send(json.dumps({'type': 'error', 'code': code})) -def new_reveal(): - num = hex(random.randrange(pow(2,256)))[2:] - reveal = '0x' + num[2:].zfill(64) -# print(f'new_reveal: {reveal}') - try: - commit = ethereum.utils.sha3(bytes.fromhex(reveal[2:])).hex() - except: - exc_type, exc_value, exc_traceback = sys.exc_info() - traceback.print_exception(exc_type, exc_value, exc_traceback, limit=20, file=sys.stdout) - return reveal, commit +class BalanceMonitor: + def __init__(self, redis: Redis, bills: billing.StrictRedisBilling): + self.redis = redis + self.bills = bills + self._monitors = {} + self.pubsub = self.redis.pubsub() + + def _get_channel(self, client_id: str) -> str: + return f"billing:balance:updates:{client_id}" + + async def start_monitoring(self, client_id: str, websocket, payment_handler: PaymentHandler, commit: str): + if client_id in self._monitors: + await self.stop_monitoring(client_id) + + channel = self._get_channel(client_id) + await self.pubsub.subscribe(channel) + + self._monitors[client_id] = asyncio.create_task( + self._monitor_balance(client_id, channel, websocket, payment_handler, commit) + ) + + async def stop_monitoring(self, client_id: str): + if client_id in self._monitors: + channel = self._get_channel(client_id) + await self.pubsub.unsubscribe(channel) + self._monitors[client_id].cancel() + try: + await self._monitors[client_id] + except asyncio.CancelledError: + pass + del self._monitors[client_id] + + async def _monitor_balance(self, client_id: str, channel: str, websocket, payment_handler: PaymentHandler, commit: str): + try: + last_invoice_time = 0 + MIN_INVOICE_INTERVAL = 1.0 # Minimum seconds between invoices + + while True: + message = await self.pubsub.get_message(ignore_subscribe_messages=True) + if message is None: + await asyncio.sleep(0.01) + continue + + try: + # Wait for in-flight payments to process + await asyncio.sleep(0.1) + + current_time = time.time() + if current_time - last_invoice_time < MIN_INVOICE_INTERVAL: + continue + + balance = await self.bills.balance(client_id) + min_balance = await self.bills.min_balance() + + if balance < min_balance: + await self.bills.debit(client_id, type='invoice') + invoice_amount = 2 * min_balance - balance + await websocket.send( + payment_handler.create_invoice(invoice_amount, commit) + ) + last_invoice_time = current_time + + except Exception as e: + print(f"Error processing balance update for {client_id}: {e}") + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Balance monitor error for {client_id}: {e}") -async def session(websocket, bills=None, job=None, recipient='0x0', key=''): +async def session( + websocket, + bills=None, + payment_handler=None, + config_manager=None +): print("New client connection") - reserve_price = 0.00006 - lotto = lottery.Lottery() - w3 = web3.Web3(web3.Web3.HTTPProvider('https://rpc.gnosischain.com/')) - lotto.init_contract(w3) - id = websocket.id - bills.debit(id, type='invoice') - send_queue, recv_queue = job.get_queues(id) - reveal, commit = new_reveal() - await websocket.send(invoice(2 * bills.min_balance(), commit, recipient)) - sources = [websocket.recv, recv_queue.get] - tasks = [None, None] - while True: - if bills.balance(id) < disconnect_threshold: - await websocket.close(reason='Balance too low') - break + try: + id = websocket.id + balance_monitor = BalanceMonitor(bills.redis, bills) + inference_url = None + + if config_manager: + config = await config_manager.load_config() + inference_url = config.get('inference', {}).get('api_url') + if not inference_url: + print("No inference URL configured") + await websocket.close(reason='Configuration error') + return + + await bills.debit(id, type='invoice') + reveal, commit = payment_handler.new_reveal() + await websocket.send( + payment_handler.create_invoice(2 * await bills.min_balance(), commit) + ) + + await balance_monitor.start_monitoring(id, websocket, payment_handler, commit) + try: - for i in range(2): - if tasks[i] is None: - tasks[i] = asyncio.create_task(sources[i]()) - done, pending = await asyncio.wait(tasks, return_when = asyncio.FIRST_COMPLETED) - for i, task in enumerate(tasks): - if task in done: - tasks[i] = None - for task in done: - message_ = task.result() - message = json.loads(message_) - if message['type'] == 'payment': + while True: + message = await websocket.recv() + try: + msg = json.loads(message) + except json.JSONDecodeError: + print(f"Failed to parse message: {message}") + continue + + if msg['type'] == 'request_token': try: - amt, reveal, commit = process_tickets(message['tickets'], recipient, reveal, commit, lotto, key) - print(f'Got ticket worth {amt}') - bills.credit(id, amount=amt) - except: - print('outer failure in processing payment') - exc_type, exc_value, exc_traceback = sys.exc_info() - traceback.print_tb(exc_traceback, limit=1, file=sys.stdout) - bills.debit(id, type='error') - await send_error(websocket, -6001) - if bills.balance(id) < bills.min_balance(): - bills.debit(id, type='invoice') - await websocket.send(invoice(2 * bills.min_balance() - bills.balance(id), commit, recipient)) - if message['type'] not in internal_messages: - bills.debit(id, type=message['type']) - if message['type'] == 'job': - jid = hashlib.sha256(bytes(message['prompt'], 'utf-8')).hexdigest() - if reserve_price != 0 and float(message['bid']) < reserve_price: - await websocket.send(json.dumps({'type': 'bid_low'})) - continue - await job.add_job(id, message['bid'], - {'id': jid, 'prompt': message['prompt']}) - if message['type'] == 'charge': + await bills.debit(id, type='auth_token') + print(f"Using inference URL: {inference_url}") + await websocket.send(json.dumps({ + 'type': 'auth_token', + 'session_id': str(id), + 'inference_url': inference_url + })) + except billing.BillingError as e: + print(f"Auth token billing failed: {e}") + await send_error(websocket, -6002) + continue + except Exception as e: + print(f"Auth token error: {e}") + await send_error(websocket, -6002) + continue + + elif msg['type'] == 'payment': try: - bills.debit(id, amount=message['amount']) - await send_queue.put(True) - except: - print('exception in charge handler') - if message['type'] == 'complete': - await websocket.send(json.dumps({'type': 'job_complete', "output": message['response'], - 'model': message['model'], 'reason': message['reason'], - 'usage': message['usage']})) - if message['type'] == 'started': - await websocket.send(json.dumps({'type': 'job_started'})) - except (websockets.exceptions.ConnectionClosedOK, websockets.exceptions.ConnectionClosedError): - print('connection closed') - break - + amount, reveal, commit = await payment_handler.process_ticket( + msg['tickets'][0], reveal, commit + ) + print(f'Got ticket worth {amount}') + await bills.credit(id, amount=amount) + except PaymentError as e: + print(f'Payment processing failed: {e}') + await bills.debit(id, type='error') + await send_error(websocket, -6001) + continue + except Exception as e: + print(f'Unexpected payment error: {e}') + await bills.debit(id, type='error') + await send_error(websocket, -6001) + continue + + except websockets.exceptions.ConnectionClosed: + print('Connection closed normally') + except Exception as e: + print(f"Error processing message: {e}") + await websocket.close(reason='Internal server error') + finally: + await balance_monitor.stop_monitoring(id) + + except Exception as e: + print(f"Fatal error in session: {e}") + await websocket.close(reason='Internal server error') + +async def main(bind_addr, bind_port, recipient_key, redis_url, config_path: Optional[str] = None): + redis = Redis.from_url(redis_url, decode_responses=True) + + try: + config_manager = ConfigManager(redis) + config = await config_manager.load_config(config_path) + except ConfigError as e: + print(f"Configuration error: {e}") + return + except Exception as e: + print(f"Unexpected error loading config: {e}") + return + + try: + bills = billing.StrictRedisBilling(redis) + await bills.init() + except billing.BillingError as e: + print(f"Billing initialization error: {e}") + return + except Exception as e: + print(f"Unexpected error initializing billing: {e}") + return + + payment_handler = PaymentHandler(LOTTERY_ADDRESS, recipient_key) -async def main(model, url, bind_addr, bind_port, recipient_key, llmkey, llmparams, api): - recipient_addr = web3.Account.from_key(recipient_key).address - bills = billing.Billing(prices) - job = jobs.jobs(model, url, llmkey, llmparams, api) print("\n*****") print(f"* Server starting up at {bind_addr} {bind_port}") - print(f"* Connecting to back end at {url}") - print(f"* With model {model}") - print(f"* Using wallet at {recipient_addr}") + print(f"* Using wallet at {payment_handler.recipient_addr}") + print(f"* Connected to Redis at {redis_url}") print("******\n\n") - async with websockets.serve(functools.partial(session, bills=bills, job=job, - recipient=recipient_addr, key=recipient_key), - bind_addr, bind_port): - await asyncio.wait([asyncio.create_task(job.process_jobs())]) + + async with websockets.serve( + functools.partial( + session, + bills=bills, + payment_handler=payment_handler, + config_manager=config_manager + ), + bind_addr, + bind_port + ): + await asyncio.Future() # Run forever if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description='Start the billing server') + parser.add_argument('--config', type=str, help='Path to config file (optional)') + + args = parser.parse_args() + + required_env = { + 'ORCHID_GENAI_ADDR': "Bind address", + 'ORCHID_GENAI_PORT': "Bind port", + 'ORCHID_GENAI_RECIPIENT_KEY': "Recipient key", + 'ORCHID_GENAI_REDIS_URL': "Redis connection URL", + } + + # Check required environment variables + missing = [name for name in required_env if name not in os.environ] + if missing: + print("Missing required environment variables:") + for name in missing: + print(f" {name}: {required_env[name]}") + sys.exit(1) + bind_addr = os.environ['ORCHID_GENAI_ADDR'] bind_port = os.environ['ORCHID_GENAI_PORT'] recipient_key = os.environ['ORCHID_GENAI_RECIPIENT_KEY'] - url = os.environ['ORCHID_GENAI_LLM_URL'] - model = os.environ['ORCHID_GENAI_LLM_MODEL'] - api = 'openai' if 'ORCHID_GENAI_API_TYPE' not in os.environ else os.environ['ORCHID_GENAI_API_TYPE'] - llmkey = None if 'ORCHID_GENAI_LLM_AUTH_KEY' not in os.environ else os.environ['ORCHID_GENAI_LLM_AUTH_KEY'] - llmparams = {} - if 'ORCHID_GENAI_LLM_PARAMS' in os.environ: - llmparams = json.loads(os.environ['ORCHID_GENAI_LLM_PARAMS']) - asyncio.run(main(model, url, bind_addr, bind_port, recipient_key, llmkey, llmparams, api)) + redis_url = os.environ['ORCHID_GENAI_REDIS_URL'] + + asyncio.run(main(bind_addr, bind_port, recipient_key, redis_url, args.config)) + diff --git a/gai-backend/test.py b/gai-backend/test.py new file mode 100644 index 000000000..469f75084 --- /dev/null +++ b/gai-backend/test.py @@ -0,0 +1,307 @@ +import asyncio +import json +import logging +import os +import secrets +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, List +from web3 import Web3 +import websockets +import aiohttp +import time + +from account import OrchidAccount +from lottery import Lottery +from ticket import Ticket + +@dataclass +class InferenceConfig: + provider: str + funder: str + secret: str + chainid: int + currency: str + rpc: str + +@dataclass +class ProviderLocation: + billing_url: str + +@dataclass +class LocationConfig: + providers: Dict[str, ProviderLocation] + + @classmethod + def from_dict(cls, data: Dict) -> 'LocationConfig': + providers = { + k: ProviderLocation(**v) for k, v in data.items() + } + return cls(providers=providers) + +@dataclass +class Message: + role: str + content: str + name: Optional[str] = None + +@dataclass +class TestConfig: + messages: List[Message] + model: str + params: Dict + retry_delay: float = 1.5 + + @classmethod + def from_dict(cls, data: Dict) -> 'TestConfig': + messages = [] + if 'prompt' in data: + # Handle legacy config with single prompt + messages = [Message(role="user", content=data['prompt'])] + elif 'messages' in data: + messages = [Message(**msg) for msg in data['messages']] + else: + raise ValueError("Config must contain either 'prompt' or 'messages'") + + return cls( + messages=messages, + model=data['model'], + params=data.get('params', {}), + retry_delay=data.get('retry_delay', 1.5) + ) + +@dataclass +class LoggingConfig: + level: str + file: Optional[str] + +@dataclass +class ClientConfig: + inference: InferenceConfig + location: LocationConfig + test: TestConfig + logging: LoggingConfig + + @classmethod + def from_file(cls, config_path: str) -> 'ClientConfig': + with open(config_path) as f: + data = json.load(f) + return cls( + inference=InferenceConfig(**data['inference']), + location=LocationConfig.from_dict(data['location']), + test=TestConfig.from_dict(data['test']), + logging=LoggingConfig(**data['logging']) + ) + +class OrchidLLMTestClient: + def __init__(self, config_path: str, prompt: Optional[str] = None): + self.config = ClientConfig.from_file(config_path) + if prompt: + self.config.test.messages = [Message(role="user", content=prompt)] + + self._setup_logging() + self.logger = logging.getLogger(__name__) + + self.web3 = Web3(Web3.HTTPProvider(self.config.inference.rpc)) + self.lottery = Lottery( + self.web3, + chain_id=self.config.inference.chainid + ) + self.account = OrchidAccount( + self.lottery, + self.config.inference.funder, + self.config.inference.secret + ) + + self.ws = None + self.session_id = None + self.inference_url = None + self.message_queue = asyncio.Queue() + self._handler_task = None + + def _setup_logging(self): + logging.basicConfig( + level=getattr(logging, self.config.logging.level.upper()), + filename=self.config.logging.file, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + + async def _handle_invoice(self, invoice_data: Dict) -> None: + try: + amount = int(invoice_data['amount']) + recipient = invoice_data['recipient'] + commit = invoice_data['commit'] + + ticket_str = self.account.create_ticket( + amount=amount, + recipient=recipient, + commitment=commit + ) + + await self.ws.send(json.dumps({ + 'type': 'payment', + 'tickets': [ticket_str] + })) + self.logger.info(f"Sent payment ticket for {amount/1e18} tokens") + + except Exception as e: + self.logger.error(f"Failed to handle invoice: {e}") + raise + + async def _billing_handler(self) -> None: + try: + async for message in self.ws: + msg = json.loads(message) + self.logger.debug(f"Received WS message: {msg['type']}") + + if msg['type'] == 'invoice': + await self._handle_invoice(msg) + elif msg['type'] == 'auth_token': + self.session_id = msg['session_id'] + self.inference_url = msg['inference_url'] + await self.message_queue.put(('auth_received', self.session_id)) + elif msg['type'] == 'error': + await self.message_queue.put(('error', msg['code'])) + + except websockets.exceptions.ConnectionClosed: + self.logger.info("Billing WebSocket closed") + except Exception as e: + self.logger.error(f"Billing handler error: {e}") + await self.message_queue.put(('error', str(e))) + + async def connect(self) -> None: + try: + provider = self.config.inference.provider + provider_config = self.config.location.providers.get(provider) + if not provider_config: + raise Exception(f"No configuration found for provider: {provider}") + + self.logger.info(f"Connecting to provider {provider} at {provider_config.billing_url}") + self.ws = await websockets.connect(provider_config.billing_url) + + self._handler_task = asyncio.create_task(self._billing_handler()) + + await self.ws.send(json.dumps({ + 'type': 'request_token', + 'orchid_account': self.config.inference.funder + })) + + msg_type, session_id = await self.message_queue.get() + if msg_type != 'auth_received': + raise Exception(f"Authentication failed: {session_id}") + + self.logger.info("Successfully authenticated") + + except Exception as e: + self.logger.error(f"Connection failed: {e}") + raise + + async def send_inference_request(self, retry_count: int = 0) -> Dict: + if not self.session_id: + raise Exception("Not authenticated") + + if not self.inference_url: + raise Exception("No inference URL received") + + try: + async with aiohttp.ClientSession() as session: + self.logger.debug(f"Using session ID: {self.session_id}") + headers = { + 'Authorization': f'Bearer {self.session_id}', + 'Content-Type': 'application/json' + } + + data = { + 'messages': [ + { + 'role': msg.role, + 'content': msg.content, + **(({'name': msg.name} if msg.name else {})) + } + for msg in self.config.test.messages + ], + 'model': self.config.test.model, + 'params': self.config.test.params + } + + self.logger.info(f"Sending inference request (attempt {retry_count + 1})") + self.logger.debug(f"Request URL: {self.inference_url}") + self.logger.debug(f"Request data: {data}") + + async with session.post( + self.inference_url, + headers=headers, + json=data, + timeout=30 + ) as response: + if response.status == 402: + retry_delay = self.config.test.retry_delay + self.logger.info(f"Insufficient balance, waiting {retry_delay}s for payment processing...") + await asyncio.sleep(retry_delay) + return await self.send_inference_request(retry_count + 1) + elif response.status == 401: + error_text = await response.text() + self.logger.error(f"Authentication failed: {error_text}") + raise Exception(f"Authentication failed: {error_text}") + elif response.status != 200: + error_text = await response.text() + self.logger.error(f"Inference request failed (status {response.status}): {error_text}") + raise Exception(f"Inference request failed: {error_text}") + + result = await response.json() + self.logger.info(f"Inference complete: {result['usage']} tokens used") + return result + + except Exception as e: + self.logger.error(f"Inference request failed: {e}") + raise + + async def close(self) -> None: + if self.ws: + try: + await self.ws.close() + self.logger.info("Connection closed") + except Exception as e: + self.logger.error(f"Error closing connection: {e}") + + if self._handler_task: + self._handler_task.cancel() + try: + await self._handler_task + except asyncio.CancelledError: + pass + + async def __aenter__(self): + """Async context manager support""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Ensure cleanup on context exit""" + await self.close() + +async def main(config_path: str, prompt: Optional[str] = None): + async with OrchidLLMTestClient(config_path, prompt) as client: + try: + await client.connect() + result = await client.send_inference_request() + + print("\nInference Results:") + messages = client.config.test.messages + print(f"Messages:") + for msg in messages: + print(f" {msg.role}: {msg.content}") + print(f"Response: {result['response']}") + print(f"Usage: {json.dumps(result['usage'], indent=2)}") + + except Exception as e: + print(f"Test failed: {e}") + raise + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("config", help="Path to config file") + parser.add_argument("prompt", nargs="*", help="Optional prompt to override config") + args = parser.parse_args() + + prompt = " ".join(args.prompt) if args.prompt else None + asyncio.run(main(args.config, prompt)) diff --git a/gai-backend/ticket.py b/gai-backend/ticket.py index f63c500ce..33cb98040 100644 --- a/gai-backend/ticket.py +++ b/gai-backend/ticket.py @@ -1,124 +1,103 @@ -import datetime -import web3 -from eth_abi.packed import encode_packed -import eth_account -import ethereum - -uint64 = pow(2, 64) - 1 # 18446744073709551615 -uint128 = pow(2, 128) - 1 # 340282366920938463463374607431768211455 -addrtype = pow(2, 20 * 8) - 1 - - -class Ticket: - def __init__(self, web3_provider, data, lot_addr, token_addr, amount, ratio, funder, recipient, commitment, key): - - self.web3 = web3_provider - self.data = data - self.lot_addr = lot_addr - self.token_addr = token_addr - self.amount = amount - self.ratio = ratio - self.commitment = commitment - self.funder = funder - self.recipient = recipient - self.key = key - - # Check if we have all the variables - if all(v is not None for v in [web3_provider, recipient, commitment, ratio, funder, amount, lot_addr, token_addr, key]): - issued = int(datetime.datetime.now().timestamp()) - l2nonce = int(web3.Web3.keccak(text=(f'{datetime.datetime.now()}')).hex(), base=16) & (pow(2, 64) - 1) - expire = pow(2, 31) - 1 - packed0 = issued << 192 | l2nonce << 128 | amount - packed1 = expire << 224 | ratio << 160 | int(funder, base=16) - - digest = web3.Web3.solidity_keccak( - ['bytes1', 'bytes1', 'address', 'bytes32', 'address', 'address', - 'bytes32', 'uint256', 'uint256', 'bytes32'], - [b'\x19', b'\x00', - self.lot_addr, b'\x00' * 31 + b'\x64', - self.token_addr, recipient, - web3.Web3.solidity_keccak(['bytes32'], [self.commitment]), packed0, - packed1, self.data]) - - sig = self.web3.eth.account.signHash(digest, private_key=key.hex()) - packed1 = packed1 << 1 | ((sig.v - 27) & 1) - - self.packed0 = packed0 - self.packed1 = packed1 - self.sig_r = Ticket.to_32byte_hex(sig.r) - self.sig_s = Ticket.to_32byte_hex(sig.s) - self.sig_v = (sig.v - 27) & 1 - - def digest(self, packed0 = None, packed1 = None): - _packed0 = self.packed0 if packed0 is None else packed0 - _packed1 = self.packed1 if packed1 is None else packed1 - _packed1 = _packed1 >> 1 - types = ['bytes1', 'bytes1', 'address', 'bytes32', 'address', 'address', - 'bytes32', 'uint256', 'uint256', 'bytes32'] - vals = [b'\x19', b'\x00', - self.lot_addr, b'\x00' * 31 + b'\x64', - self.token_addr, self.recipient, - bytes.fromhex(self.commitment[2:]), _packed0, - _packed1, self.data] - packed = encode_packed(types, vals) - return ethereum.utils.sha3(packed) - - - @staticmethod - def to_32byte_hex(val): - return web3.Web3.to_hex(web3.Web3.to_bytes(hexstr=val).rjust(32, b'\0')) - - def serialize_ticket(self): - return Ticket.to_32byte_hex(self.packed0)[2:] + Ticket.to_32byte_hex(self.packed1)[2:] + self.sig_r[2:] + self.sig_s[2:] - - @staticmethod - def deserialize_ticket(tstr, reveal = None, commitment = None, recipient = None, - lotaddr = '0x6dB8381b2B41b74E17F5D4eB82E8d5b04ddA0a82', - tokenaddr = '0x0000000000000000000000000000000000000000'): - tk = [tstr[i:i+64] for i in range(0, len(tstr), 64)] - print(tk) - tk_temp = Ticket(None, None, None, None, None, None, None, None, None, None) - tk_temp.packed0 = int(tk[0], base=16) - tk_temp.packed1 = int(tk[1], base=16) - tk_temp.amount = tk_temp.packed0 & uint128 - tk_temp.ratio = (tk_temp.packed1 >> 161) & uint64 - tk_temp.sig_r = tk[2] - tk_temp.sig_s = tk[3] - tk_temp.sig_v = tk_temp.packed1 & 1 - tk_temp.data = b'\x00' * 32 - tk_temp.reveal = Ticket.to_32byte_hex(reveal) - tk_temp.commitment = Ticket.to_32byte_hex(commitment) - tk_temp.lot_addr = lotaddr - tk_temp.token_addr = tokenaddr - tk_temp.recipient = recipient - digest = tk_temp.digest() - signer = ethereum.utils.checksum_encode(ethereum.utils.sha3(ethereum.utils.ecrecover_to_pub(digest, - tk_temp.sig_v, - bytes.fromhex(tk_temp.sig_r[2:]), - bytes.fromhex(tk_temp.sig_s[2:]) - ))[-20:]) - return tk_temp - - def is_winner(self, reveal): - ratio = uint64 & (self.packed1 >> 161) - issued_nonce = (self.packed0 >> 128) - hash = ethereum.utils.sha3(bytes.fromhex(reveal[2:]) + - issued_nonce.to_bytes(length=16, byteorder='big')) - comp = uint64 & int(hash.hex(), base=16) - if ratio < comp: - return False - return True - - def value(self): - return self.amount * self.ratio / uint64 - - def print_ticket(self): - amount = self.packed0 & uint128 - nonce = (self.packed0 >> 128) & uint64 - funder = addrtype & (self.packed1 >> 1) - ratio = uint64 & (self.packed1 >> 161) - print('Print_ticket():') - print(f'Face Value: {amount}') - print(f'Funder: {funder}') - print(f'Ratio: {ratio}') - +import datetime +from web3 import Web3 +from typing import Optional, Tuple + +class TicketError(Exception): + pass + +class Ticket: + def __init__(self, + packed0: int, + packed1: int, + sig_r: str, + sig_s: str, + reveal: Optional[str] = None, + commitment: Optional[str] = None, + recipient: Optional[str] = None, + lottery_addr: Optional[str] = None, + token_addr: str = "0x0000000000000000000000000000000000000000"): + self.packed0 = packed0 + self.packed1 = packed1 + self.sig_r = sig_r + self.sig_s = sig_s + self.sig_v = packed1 & 1 + self.reveal = reveal + self.commitment = commitment + self.recipient = recipient + self.lottery_addr = lottery_addr + self.token_addr = token_addr + self.data = b'\x00' * 32 # Fixed empty data field + + @classmethod + def deserialize(cls, + ticket_str: str, + reveal: Optional[str] = None, + commitment: Optional[str] = None, + recipient: Optional[str] = None, + lottery_addr: Optional[str] = None, + token_addr: str = "0x0000000000000000000000000000000000000000" + ) -> 'Ticket': + try: + if len(ticket_str) != 256: # 4 x 64 hex chars + raise TicketError("Invalid ticket format") + + parts = [ticket_str[i:i+64] for i in range(0, 256, 64)] + return cls( + packed0=int(parts[0], 16), + packed1=int(parts[1], 16), + sig_r=parts[2], + sig_s=parts[3], + reveal=reveal, + commitment=commitment, + recipient=recipient, + lottery_addr=lottery_addr, + token_addr=token_addr + ) + except Exception as e: + raise TicketError(f"Failed to deserialize ticket: {e}") + + def is_winner(self) -> bool: + if not self.reveal: + raise TicketError("No reveal value available") + + try: + ratio = (self.packed1 >> 161) & ((1 << 64) - 1) + issued_nonce = (self.packed0 >> 128) + hash_val = Web3.keccak( + Web3.to_bytes(hexstr=self.reveal[2:]) + + issued_nonce.to_bytes(length=16, byteorder='big') + ) + comp = ((1 << 64) - 1) & int(hash_val.hex(), 16) + return ratio >= comp + except Exception as e: + raise TicketError(f"Failed to check winning status: {e}") + + def face_value(self) -> int: + return self.packed0 & ((1 << 128) - 1) + + def verify_signature(self, expected_signer: str) -> bool: + if not all([self.commitment, self.recipient, self.lottery_addr]): + raise TicketError("Missing required fields for signature verification") + + try: + digest = Web3.solidity_keccak( + ['bytes1', 'bytes1', 'address', 'bytes32', 'address', 'address', + 'bytes32', 'uint256', 'uint256', 'bytes32'], + [b'\x19', b'\x00', + self.lottery_addr, + b'\x00' * 31 + b'\x64', + self.token_addr, + self.recipient, + Web3.solidity_keccak(['bytes32'], [self.commitment]), + self.packed0, + self.packed1 >> 1, + self.data] + ) + + recovered = Web3.eth.account.recover_message( + eth_message_hash=digest, + vrs=(self.sig_v + 27, int(self.sig_r, 16), int(self.sig_s, 16)) + ) + return recovered.lower() == expected_signer.lower() + except Exception as e: + raise TicketError(f"Failed to verify signature: {e}")