Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions config/config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ providers:
responses: false
base_url: https://openrouter.ai/api/v1
api_key: sk-openrouter-api-key
groq1:
type: openai
base_url: https://api.groq.com/openai/v1
api_key: sk-groq-api-key
groq2:
type: openai
base_url: https://api.groq.com/openai/v1
keys:
- api_key: sk-groq-api-key-1
- api_key: sk-groq-api-key-2

models:
openai/gpt-5:
Expand Down
7 changes: 7 additions & 0 deletions src/types/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ export interface LMRouterConfigProvider {
responses?: boolean;
base_url?: string;
api_key: string;
keys?: LMRouterConfigProviderKey[];
}

/**
 * One entry of a provider's optional `keys` list.
 * `LoadBalancer.getApiKey` chooses among these entries by weight.
 */
export interface LMRouterConfigProviderKey {
// API key string returned by the load balancer when this entry is selected.
api_key: string;
// Relative share of selections for this key; the load balancer falls back
// to a weight of 1 when this field is omitted.
weight?: number;
}

export interface LMRouterConfigModelProviderPricingFixed {
Expand Down Expand Up @@ -113,6 +119,7 @@ export type LMRouterConfigModelProviderPricing =
export interface LMRouterConfigModelProvider {
provider: string;
model: string;
weight?: number;
context_window?: number;
max_tokens?: number;
responses_only?: boolean;
Expand Down
103 changes: 103 additions & 0 deletions src/utils/load-balancing.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import type {
LMRouterConfigModelProvider,
LMRouterConfigProviderKey,
} from "../types/config.js";

/**
 * Implements a smooth weighted round-robin (SWRR) load balancer.
 *
 * On every selection each candidate's running weight is raised by its
 * configured weight, the candidate with the highest running weight wins, and
 * the winner's running weight is then lowered by the sum of all configured
 * weights. This yields a smooth, deterministic distribution proportional to
 * the weights, avoiding the bursts that purely random selection can produce.
 *
 * The running weights live in static maps shared by the whole process, so the
 * balancer is only suitable for single-process environments.
 *
 * NOTE(review): provider state is keyed by `provider:model`, so two model
 * entries that list the same provider/model pair share one running weight.
 */
export class LoadBalancer {
  // state_key -> current_weight
  private static providerWeights = new Map<string, number>();
  private static keyWeights = new Map<string, number>();

  /**
   * Returns the providers ordered for one request: the provider selected by
   * SWRR comes first, followed by the remaining providers in their original
   * relative order (used for failover).
   *
   * @param providers Candidate providers; a missing `weight` counts as 1.
   * @returns An empty array for empty/missing input, the input array itself
   *   when it has a single element, otherwise a new reordered array.
   */
  public static getOrderedProviders(
    providers: LMRouterConfigModelProvider[],
  ): LMRouterConfigModelProvider[] {
    if (!providers || providers.length === 0) {
      return [];
    }

    if (providers.length === 1) {
      return providers;
    }

    const winner = LoadBalancer.pickIndex(
      LoadBalancer.providerWeights,
      providers.map((candidate) => ({
        id: `${candidate.provider}:${candidate.model}`,
        weight: LoadBalancer.normalizeWeight(candidate.weight),
      })),
    );

    // Move the winner to the front without disturbing the order of the rest.
    return [
      ...providers.slice(winner, winner + 1),
      ...providers.slice(0, winner),
      ...providers.slice(winner + 1),
    ];
  }

  /**
   * Selects one API key from a provider's key pool using SWRR.
   *
   * @param providerName Name of the provider owning the pool; scopes the state.
   * @param keys Key pool; a missing `weight` counts as 1.
   * @returns The selected API key, or `undefined` when no keys are configured.
   */
  public static getApiKey(
    providerName: string,
    keys?: LMRouterConfigProviderKey[],
  ): string | undefined {
    if (!keys || keys.length === 0) {
      return undefined;
    }

    if (keys.length === 1) {
      return keys[0]?.api_key;
    }

    const winner = LoadBalancer.pickIndex(
      LoadBalancer.keyWeights,
      keys.map((key, index) => ({
        // Identify a key by its position in the pool. Deriving the id from the
        // key text (e.g. its last characters) lets distinct keys with a common
        // suffix collide on one state slot, which starves all but one of them.
        id: `${providerName}:${index}`,
        weight: LoadBalancer.normalizeWeight(key.weight),
      })),
    );

    return keys[winner]?.api_key;
  }

  /**
   * Converts a configured weight into a value that is safe for SWRR.
   * Missing or non-finite values fall back to the default of 1; negative
   * values are clamped to 0 (never preferred, still available for failover).
   * Without this, a NaN/Infinity weight would permanently turn the stored
   * running weights into NaN.
   */
  private static normalizeWeight(weight: number | undefined): number {
    if (typeof weight !== "number" || !Number.isFinite(weight)) {
      return 1;
    }
    return Math.max(0, weight);
  }

  /**
   * Performs one SWRR step over `entries`, updating `state` in place.
   *
   * @param state Running weights, keyed by entry id.
   * @param entries Candidates with their normalized weights (non-empty).
   * @returns Index of the selected entry; ties go to the earliest entry.
   */
  private static pickIndex(
    state: Map<string, number>,
    entries: ReadonlyArray<{ id: string; weight: number }>,
  ): number {
    let totalWeight = 0;
    let maxWeight = -Infinity;
    let winner = 0;

    entries.forEach(({ id, weight }, index) => {
      totalWeight += weight;
      const raised = (state.get(id) ?? 0) + weight;
      state.set(id, raised);

      // Strict comparison keeps the earliest entry on ties.
      if (raised > maxWeight) {
        maxWeight = raised;
        winner = index;
      }
    });

    const selected = entries[winner];
    if (selected) {
      // Lower the winner by the total so the other entries catch up.
      state.set(selected.id, maxWeight - totalWeight);
    }

    return winner;
  }
}
24 changes: 20 additions & 4 deletions src/utils/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { getConnInfo as getConnInfoNode } from "@hono/node-server/conninfo";
import { recordApiCall } from "./billing.js";
import { TimeKeeper } from "./chrono.js";
import { getConfig } from "./config.js";
import { LoadBalancer } from "./load-balancing.js";
import type {
LMRouterConfigModel,
LMRouterConfigModelProvider,
Expand Down Expand Up @@ -92,6 +93,7 @@ export const iterateModelProviders = async (
): Promise<any> => {
const cfg = getConfig(c);
let error: any = null;
let result: any = null;

if (!c.var.model) {
return c.json(
Expand All @@ -104,20 +106,30 @@ export const iterateModelProviders = async (
);
}

for (const providerCfg of c.var.model.providers) {
// Get the providers list ordered by the smooth weighted round-robin algorithm.
// This provides both load balancing and a predictable order for failover.
const orderedProviders = LoadBalancer.getOrderedProviders(
c.var.model.providers,
);

for (const providerCfg of orderedProviders) {
const provider = cfg.providers[providerCfg.provider];
if (!provider) {
continue;
}

const hydratedProvider = { ...provider };
hydratedProvider.api_key =
c.var.auth?.type === "byok" ? c.var.auth.byok : provider.api_key;
const byok = c.var.auth?.type === "byok" ? c.var.auth.byok : undefined;
hydratedProvider.api_key = byok
? byok
: (LoadBalancer.getApiKey(providerCfg.provider, provider.keys) ??
provider.api_key);

const timeKeeper = new TimeKeeper();
try {
timeKeeper.record();
return await cb(providerCfg, hydratedProvider);
result = await cb(providerCfg, hydratedProvider);
break;
} catch (e) {
timeKeeper.record();
await recordApiCall(
Expand All @@ -138,6 +150,10 @@ export const iterateModelProviders = async (
}
}

if (result) {
return result;
}

if (error) {
return c.json(
{
Expand Down