Skip to content

Commit

Permalink
Add pay-per-request rpc service
Browse files Browse the repository at this point in the history
  • Loading branch information
danopato committed Nov 25, 2024
1 parent 68753d1 commit 69a702d
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 54 deletions.
122 changes: 68 additions & 54 deletions gai-backend/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,60 +26,74 @@ async def load_from_file(self, config_path: str) -> Dict[str, Any]:
raise ConfigError(f"Failed to load config file: {e}")

def process_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
if 'inference' not in config:
raise ConfigError("Missing required 'inference' section")

if 'endpoints' not in config['inference']:
raise ConfigError("Missing required 'endpoints' in inference config")

endpoints = config['inference']['endpoints']
if not endpoints:
raise ConfigError("No inference endpoints configured")

total_models = 0

for endpoint_id, endpoint in endpoints.items():
required_fields = ['api_type', 'url', 'api_key', 'models']
missing = [field for field in required_fields if field not in endpoint]
if missing:
raise ConfigError(f"Endpoint {endpoint_id} missing required fields: {', '.join(missing)}")
if not isinstance(endpoint['models'], list):
raise ConfigError(f"Endpoint {endpoint_id} 'models' must be a list")
if not endpoint['models']:
raise ConfigError(f"Endpoint {endpoint_id} has no models configured")

total_models += len(endpoint['models'])

for model in endpoint['models']:
required_model_fields = ['id', 'pricing']
missing = [field for field in required_model_fields if field not in model]
if missing:
raise ConfigError(f"Model in endpoint {endpoint_id} missing required fields: {', '.join(missing)}")
if 'params' not in model:
model['params'] = {}

pricing = model['pricing']
if 'type' not in pricing:
raise ConfigError(f"Model {model['id']} missing required pricing type")

required_pricing_fields = {
'fixed': ['input_price', 'output_price'],
'cost_plus': ['backend_input', 'backend_output', 'input_markup', 'output_markup'],
'multiplier': ['backend_input', 'backend_output', 'input_multiplier', 'output_multiplier']
}

if pricing['type'] not in required_pricing_fields:
raise ConfigError(f"Invalid pricing type for model {model['id']}: {pricing['type']}")

missing = [field for field in required_pricing_fields[pricing['type']]
if field not in pricing]
if missing:
raise ConfigError(f"Model {model['id']} pricing missing required fields: {', '.join(missing)}")

if total_models == 0:
raise ConfigError("No models configured across all endpoints")

return config
if 'inference' not in config and 'rpc' not in config:
raise ConfigError("Config must contain either 'inference' or 'rpc' section")

if 'inference' in config:
if 'endpoints' not in config['inference']:
raise ConfigError("Missing required 'endpoints' in inference config")

endpoints = config['inference']['endpoints']
if not endpoints:
raise ConfigError("No inference endpoints configured")

total_models = 0

for endpoint_id, endpoint in endpoints.items():
required_fields = ['api_type', 'url', 'api_key', 'models']
missing = [field for field in required_fields if field not in endpoint]
if missing:
raise ConfigError(f"Endpoint {endpoint_id} missing required fields: {', '.join(missing)}")
if not isinstance(endpoint['models'], list):
raise ConfigError(f"Endpoint {endpoint_id} 'models' must be a list")
if not endpoint['models']:
raise ConfigError(f"Endpoint {endpoint_id} has no models configured")

total_models += len(endpoint['models'])

for model in endpoint['models']:
required_model_fields = ['id', 'pricing']
missing = [field for field in required_model_fields if field not in model]
if missing:
raise ConfigError(f"Model in endpoint {endpoint_id} missing required fields: {', '.join(missing)}")
if 'params' not in model:
model['params'] = {}

pricing = model['pricing']
if 'type' not in pricing:
raise ConfigError(f"Model {model['id']} missing required pricing type")

required_pricing_fields = {
'fixed': ['input_price', 'output_price'],
'cost_plus': ['backend_input', 'backend_output', 'input_markup', 'output_markup'],
'multiplier': ['backend_input', 'backend_output', 'input_multiplier', 'output_multiplier']
}

if pricing['type'] not in required_pricing_fields:
raise ConfigError(f"Invalid pricing type for model {model['id']}: {pricing['type']}")

missing = [field for field in required_pricing_fields[pricing['type']]
if field not in pricing]
if missing:
raise ConfigError(f"Model {model['id']} pricing missing required fields: {', '.join(missing)}")

if total_models == 0:
raise ConfigError("No models configured across all endpoints")

if 'rpc' in config:
rpc = config['rpc']
required_fields = ['provider_url', 'provider_key', 'prices', 'pricing']
missing = [f for f in required_fields if f not in rpc]
if missing:
raise ConfigError(f"RPC config missing required fields: {', '.join(missing)}")

pricing = rpc['pricing']
required_pricing = ['base_unit', 'credit_to_usd', 'min_usd_charge']
missing = [f for f in required_pricing if f not in pricing]
if missing:
raise ConfigError(f"RPC pricing missing required fields: {', '.join(missing)}")

return config

async def write_config(self, config: Dict[str, Any], force: bool = False):
try:
Expand Down
104 changes: 104 additions & 0 deletions gai-backend/rpc_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from redis.asyncio import Redis
from contextlib import asynccontextmanager
import httpx
import os

from config_manager import ConfigManager
from billing import StrictRedisBilling

@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
redis_url = os.environ['ORCHID_GENAI_REDIS_URL']
redis = Redis.from_url(redis_url, decode_responses=True)
app.state.api = RPCAPI(redis)
await app.state.api.init()
yield
# Shutdown
await redis.close()

app = FastAPI(lifespan=lifespan)

app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Adjust in production
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

class RPCAPI:
def __init__(self, redis: Redis):
self.redis = redis
self.config_manager = ConfigManager(redis)
self.billing = StrictRedisBilling(redis)

async def init(self):
await self.billing.init()
self.config = await self.config_manager.load_config()
if 'rpc' not in self.config:
raise ValueError("RPC config required")

async def handle_rpc(self, request: dict, session_id: str):
method = request.get('method')

if method not in self.config['rpc']['prices']:
raise HTTPException(400, "Method not supported")

credits = self.config['rpc']['prices'][method]
credit_to_usd = self.config['rpc']['pricing']['credit_to_usd']
price_usd = credits * credit_to_usd

print(f"Method {method}: {credits} credits = ${price_usd}")

# Check balance and debit
balance = await self.billing.balance(session_id)
print(f"Current balance: ${balance}")

if balance < price_usd:
raise HTTPException(402, f"Insufficient balance: ${balance} < ${price_usd}")

await self.billing.debit(session_id, amount=price_usd)

provider_url = f"{self.config['rpc']['provider_url']}/{self.config['rpc']['provider_key']}"
print(f"Sending request to provider: {provider_url}")
print(f"Request body: {request}")

# Create new client for each request
async with httpx.AsyncClient() as client:
resp = await client.post(
provider_url,
json=request
)
print(f"Provider response status: {resp.status_code}")
print(f"Provider response content: {resp.content}")
if resp.content:
return resp.json()
else:
raise HTTPException(502, "Empty response from provider")

@app.api_route("/", methods=["GET", "POST", "OPTIONS"])
async def rpc_endpoint(request: Request, token: str):
if request.method == "OPTIONS":
return {}

# For GET requests (like when MetaMask checks chainId)
if request.method == "GET":
# Return a simple JSON-RPC response
return {
"jsonrpc": "2.0",
"id": None,
"result": {
"chainId": "0x1" # Mainnet
}
}

# Existing POST handling
body = await request.json()
return await app.state.api.handle_rpc(body, token)

if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

0 comments on commit 69a702d

Please sign in to comment.