diff --git a/config/config.example.yaml b/config/config.example.yaml index ee69193..101ae9f 100644 --- a/config/config.example.yaml +++ b/config/config.example.yaml @@ -33,6 +33,16 @@ providers: responses: false base_url: https://openrouter.ai/api/v1 api_key: sk-openrouter-api-key + groq1: + type: openai + base_url: https://api.groq.com/openai/v1 + api_key: sk-groq-api-key + groq2: + type: openai + base_url: https://api.groq.com/openai/v1 + keys: + - api_key: sk-groq-api-key-1 + - api_key: sk-groq-api-key-2 models: openai/gpt-5: diff --git a/src/types/config.ts b/src/types/config.ts index bfaf928..ed70aea 100644 --- a/src/types/config.ts +++ b/src/types/config.ts @@ -80,6 +80,12 @@ export interface LMRouterConfigProvider { responses?: boolean; base_url?: string; api_key: string; + keys?: LMRouterConfigProviderKey[]; +} + +export interface LMRouterConfigProviderKey { + api_key: string; + weight?: number; } export interface LMRouterConfigModelProviderPricingFixed { @@ -113,6 +119,7 @@ export type LMRouterConfigModelProviderPricing = export interface LMRouterConfigModelProvider { provider: string; model: string; + weight?: number; context_window?: number; max_tokens?: number; responses_only?: boolean; diff --git a/src/utils/load-balancing.ts b/src/utils/load-balancing.ts new file mode 100644 index 0000000..0dfd61d --- /dev/null +++ b/src/utils/load-balancing.ts @@ -0,0 +1,103 @@ +import type { + LMRouterConfigModelProvider, + LMRouterConfigProviderKey, +} from "../types/config.js"; + +/** + * Implements a smooth weighted round-robin load balancer. + * This ensures a smooth and predictable distribution of requests based on provider weights, + * avoiding the potential for request bursts that can occur with purely random selection. + * + * The state (current weights) is stored in a global map, making it suitable for + * single-process environments. + */ +export class LoadBalancer { + // state_key -> current_weight + private static providerWeights = new Map(); + private static keyWeights = new Map(); + + public static getOrderedProviders( + providers: LMRouterConfigModelProvider[], + ): LMRouterConfigModelProvider[] { + if (!providers || providers.length === 0) { + return []; + } + + if (providers.length === 1) { + return providers; + } + + const totalWeight = providers.reduce((acc, p) => acc + (p.weight ?? 1), 0); + + // Find the provider with the highest current weight + let bestProvider: LMRouterConfigModelProvider | null = null; + let maxWeight = -Infinity; + + for (const provider of providers) { + const providerId = `${provider.provider}:${provider.model}`; + const currentWeight = this.providerWeights.get(providerId) ?? 0; + const newWeight = currentWeight + (provider.weight ?? 1); + this.providerWeights.set(providerId, newWeight); + + if (newWeight > maxWeight) { + maxWeight = newWeight; + bestProvider = provider; + } + } + + if (bestProvider) { + const providerId = `${bestProvider.provider}:${bestProvider.model}`; + // Decrease the best provider's weight by the total weight + this.providerWeights.set(providerId, maxWeight - totalWeight); + + // Sort providers to try the best one first, then the rest + return [...providers].sort((a, b) => + a === bestProvider ? -1 : b === bestProvider ? 1 : 0, + ); + } + + // Fallback to the original list if something goes wrong + return providers; + } + + public static getApiKey( + providerName: string, + keys?: LMRouterConfigProviderKey[], + ): string | undefined { + if (!keys || keys.length === 0) { + return undefined; + } + + if (keys.length === 1) { + return keys[0].api_key; + } + + const totalWeight = keys.reduce((acc, key) => acc + (key.weight ?? 1), 0); + + let bestKey: LMRouterConfigProviderKey | null = null; + let maxWeight = -Infinity; + + for (const key of keys) { + // Use a unique ID for each key within a provider + const keyId = `${providerName}:${key.api_key.slice(-4)}`; + const currentWeight = this.keyWeights.get(keyId) ?? 0; + const newWeight = currentWeight + (key.weight ?? 1); + this.keyWeights.set(keyId, newWeight); + + if (newWeight > maxWeight) { + maxWeight = newWeight; + bestKey = key; + } + } + + if (bestKey) { + const keyId = `${providerName}:${bestKey.api_key.slice(-4)}`; + // Decrease the best key's weight by the total weight + this.keyWeights.set(keyId, maxWeight - totalWeight); + return bestKey.api_key; + } + + // Fallback to the last key if something goes wrong + return keys[keys.length - 1].api_key; + } +} diff --git a/src/utils/utils.ts b/src/utils/utils.ts index b42e516..527ca14 100644 --- a/src/utils/utils.ts +++ b/src/utils/utils.ts @@ -9,6 +9,7 @@ import { getConnInfo as getConnInfoNode } from "@hono/node-server/conninfo"; import { recordApiCall } from "./billing.js"; import { TimeKeeper } from "./chrono.js"; import { getConfig } from "./config.js"; +import { LoadBalancer } from "./load-balancing.js"; import type { LMRouterConfigModel, LMRouterConfigModelProvider, @@ -92,6 +93,7 @@ export const iterateModelProviders = async ( ): Promise => { const cfg = getConfig(c); let error: any = null; + let result: any = null; if (!c.var.model) { return c.json( @@ -104,20 +106,30 @@ export const iterateModelProviders = async ( ); } - for (const providerCfg of c.var.model.providers) { + // Get the providers list ordered by the smooth weighted round-robin algorithm. + // This provides both load balancing and a predictable order for failover. + const orderedProviders = LoadBalancer.getOrderedProviders( + c.var.model.providers, + ); + + for (const providerCfg of orderedProviders) { const provider = cfg.providers[providerCfg.provider]; if (!provider) { continue; } const hydratedProvider = { ...provider }; - hydratedProvider.api_key = - c.var.auth?.type === "byok" ? c.var.auth.byok : provider.api_key; + const byok = c.var.auth?.type === "byok" ? c.var.auth.byok : undefined; + hydratedProvider.api_key = byok + ? byok + : (LoadBalancer.getApiKey(providerCfg.provider, provider.keys) ?? + provider.api_key); const timeKeeper = new TimeKeeper(); try { timeKeeper.record(); - return await cb(providerCfg, hydratedProvider); + result = await cb(providerCfg, hydratedProvider); + break; } catch (e) { timeKeeper.record(); await recordApiCall( @@ -138,6 +150,10 @@ export const iterateModelProviders = async ( } } + if (result) { + return result; + } + if (error) { return c.json( {