Skip to content

Commit

Permalink
add middleware support for deepseek
Browse files Browse the repository at this point in the history
Signed-off-by: oilbeater <[email protected]>
  • Loading branch information
oilbeater committed Oct 11, 2024
1 parent 8bb13a4 commit 169b234
Show file tree
Hide file tree
Showing 13 changed files with 240 additions and 21 deletions.
5 changes: 3 additions & 2 deletions src/middlewares/analytics.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Context, MiddlewareHandler, Next } from 'hono';
import { AppContext } from '.';
import { AppContext, setMiddlewares } from '.';

export function recordAnalytics(
c: Context<AppContext>,
Expand All @@ -13,7 +13,7 @@ export function recordAnalytics(
const getTokenCount = c.get('getTokenCount');
const { input_tokens, output_tokens } = typeof getTokenCount === 'function' ? getTokenCount(c) : { input_tokens: 0, output_tokens: 0 };

// console.log(endpoint, c.req.path, modelName, input_tokens, output_tokens, c.get('malacca-cache-status') || 'miss', c.res.status);
console.log(endpoint, c.req.path, modelName, input_tokens, output_tokens, c.get('malacca-cache-status') || 'miss', c.res.status);

if (c.env.MALACCA) {
c.env.MALACCA.writeDataPoint({
Expand All @@ -25,6 +25,7 @@ export function recordAnalytics(
}

export const metricsMiddleware: MiddlewareHandler = async (c: Context<AppContext>, next: Next) => {
setMiddlewares(c, 'metrics');
const startTime = Date.now();
await next();

Expand Down
3 changes: 2 additions & 1 deletion src/middlewares/buffer.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { Context, MiddlewareHandler, Next } from 'hono'
import { AppContext } from '.';
import { AppContext, setMiddlewares } from '.';

export const bufferMiddleware: MiddlewareHandler = async (c: Context<AppContext>, next: Next) => {
setMiddlewares(c, 'buffer');
let buffer = ''
let resolveBuffer!: () => void
const bufferPromise = new Promise<void>((resolve) => {
Expand Down
3 changes: 2 additions & 1 deletion src/middlewares/cache.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Context, MiddlewareHandler, Next } from "hono";
import { AppContext } from '.';
import { AppContext, setMiddlewares } from '.';

export async function generateCacheKey(urlWithQueryParams: string, body: string): Promise<string> {
const cacheKey = await crypto.subtle.digest(
Expand All @@ -12,6 +12,7 @@ export async function generateCacheKey(urlWithQueryParams: string, body: string)
}

export const cacheMiddleware: MiddlewareHandler = async (c: Context<AppContext>, next: Next) => {
setMiddlewares(c, 'cache');
const cacheKeyHex = await generateCacheKey(c.req.url, await c.req.text());
const response = await c.env.MALACCA_CACHE.get(cacheKeyHex, "stream");
if (response) {
Expand Down
3 changes: 2 additions & 1 deletion src/middlewares/fallback.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { Context, Next } from 'hono';
import { AppContext } from '.';
import { AppContext, setMiddlewares } from '.';

export const fallbackMiddleware = async (c: Context<AppContext>, next: Next) => {
setMiddlewares(c, 'fallback');
try {
await next();

Expand Down
3 changes: 2 additions & 1 deletion src/middlewares/guard.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Context, MiddlewareHandler, Next } from "hono";
import { AppContext } from ".";
import { AppContext, setMiddlewares } from ".";

const denyRequestPatterns = [
'password',
Expand All @@ -13,6 +13,7 @@ const denyResponsePatterns = [
// The guard middleware protects the API by checking whether the request body matches any of the deny patterns (case-insensitive regexes).
// If it does, it returns the message "Rejected due to inappropriate content" with a 403 status code.
export const guardMiddleware: MiddlewareHandler = async (c: Context<AppContext>, next: Next) => {
setMiddlewares(c, 'guard');
const requestText = await c.req.text();
if (denyRequestPatterns.some(pattern => new RegExp(pattern, 'i').test(requestText))) {
return c.text('Rejected due to inappropriate content', 403);
Expand Down
10 changes: 10 additions & 0 deletions src/middlewares/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export { fallbackMiddleware } from './fallback';
export interface AppContext {
Bindings: Env,
Variables: {
middlewares: string[],
endpoint: string,
'malacca-cache-status': string,
bufferPromise: Promise<any>,
Expand All @@ -19,5 +20,14 @@ export interface AppContext {
realKey: string,
getModelName: (c: Context) => string,
getTokenCount: (c: Context) => { input_tokens: number, output_tokens: number },
getVirtualKey: (c: Context) => string,
}
}

/**
 * Records that the middleware called `name` has run for the current request.
 *
 * The names are kept in the `middlewares` context variable. A new array is
 * stored on every call; the previously stored array is never mutated.
 *
 * @param c - Request context whose `middlewares` variable is updated.
 * @param name - Identifier of the middleware to append.
 */
export function setMiddlewares(c: Context, name: string) {
  const applied = c.get('middlewares');
  // First middleware of the request starts the list; later ones extend a copy.
  const updated = applied ? [...applied, name] : [name];
  c.set('middlewares', updated);
}
3 changes: 2 additions & 1 deletion src/middlewares/logging.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { Context, Next } from 'hono';
import { AppContext } from '.';
import { AppContext, setMiddlewares } from '.';

export const loggingMiddleware = async (c: Context<AppContext>, next: Next) => {
setMiddlewares(c, 'logging');
await next();

// Log request and response
Expand Down
3 changes: 2 additions & 1 deletion src/middlewares/rateLimiter.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { Context, Next } from "hono";
import { AppContext } from '.';
import { AppContext, setMiddlewares } from '.';

export const rateLimiterMiddleware = async (c: Context<AppContext>, next: Next) => {
setMiddlewares(c, 'rateLimiter');
const key = c.req.header('api-key') || '';
const { success } = await c.env.MY_RATE_LIMITER.limit({ key: key })
if (!success) {
Expand Down
5 changes: 3 additions & 2 deletions src/middlewares/virtualKey.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { Context, Next } from "hono";
import { AppContext } from '.';
import { AppContext, setMiddlewares } from '.';

export const virtualKeyMiddleware = async (c: Context<AppContext>, next: Next) => {
const apiKey = c.req.header('api-key') || '';
setMiddlewares(c, 'virtualKey');
const apiKey = c.get('getVirtualKey')(c);
const realKey = await c.env.MALACCA_USER.get(apiKey);
if (!realKey) {
return c.text('Unauthorized', 401);
Expand Down
30 changes: 26 additions & 4 deletions src/providers/azureOpenAI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,29 @@ const ProviderName = 'azure-openai';
const azureOpenAIRoute = new Hono();

/**
 * First middleware of the Azure OpenAI route.
 *
 * Appends 'init' to the `middlewares` context variable and stores the
 * provider name plus the provider-specific callbacks (`getModelName`,
 * `getTokenCount`, `getVirtualKey`) on the context before handing over to
 * the next middleware.
 */
const initMiddleware = async (c: Context, next: Next) => {
  // Start the list if nothing ran before us, otherwise extend a copy of it.
  const applied = c.get('middlewares');
  c.set('middlewares', applied ? [...applied, 'init'] : ['init']);

  // Provider-specific values and callbacks stored on the context.
  c.set('endpoint', ProviderName);
  c.set('getModelName', getModelName);
  c.set('getTokenCount', getTokenCount);
  c.set('getVirtualKey', getVirtualKey);

  await next();
};

// Register the middleware chain for the Azure OpenAI route exactly once.
// Registering the same chain twice would run every middleware twice per
// request. Order matters: initMiddleware comes first because it stores the
// provider callbacks (e.g. `getVirtualKey`) on the context.
azureOpenAIRoute.use(
  initMiddleware,
  metricsMiddleware,
  loggingMiddleware,
  bufferMiddleware,
  virtualKeyMiddleware,
  rateLimiterMiddleware,
  guardMiddleware,
  cacheMiddleware,
  fallbackMiddleware
);

azureOpenAIRoute.post('/*', async (c: Context) => {
return azureOpenAIProvider.handleRequest(c);
Expand All @@ -43,9 +59,11 @@ export const azureOpenAIProvider: AIProvider = {
const urlWithQueryParams = `${azureEndpoint}?${queryParams}`;

const headers = new Headers(c.req.header());
const apiKey: string = c.get('realKey');
if (apiKey) {
headers.set('api-key', apiKey);
if (c.get('middlewares')?.includes('virtualKey')) {
const apiKey: string = c.get('realKey');
if (apiKey) {
headers.set('api-key', apiKey);
}
}
const response = await fetch(urlWithQueryParams, {
method: c.req.method,
Expand Down Expand Up @@ -107,4 +125,8 @@ function getTokenCount(c: Context): { input_tokens: number, output_tokens: numbe
}
}
return { input_tokens: 0, output_tokens: 0 }
}

/**
 * Reads the caller-supplied key from the `api-key` request header.
 *
 * @param c - Request context.
 * @returns The header value, or an empty string when the header is absent
 *          or empty.
 */
function getVirtualKey(c: Context): string {
  const headerValue = c.req.header('api-key');
  if (headerValue) {
    return headerValue;
  }
  return '';
}
53 changes: 46 additions & 7 deletions src/providers/deepseek.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,22 @@ const initMiddleware = async (c: Context, next: Next) => {
c.set('endpoint', ProviderName);
c.set('getModelName', getModelName);
c.set('getTokenCount', getTokenCount);
c.set('getVirtualKey', getVirtualKey);
await next();
};

// Register the middleware chain for the DeepSeek route, in the order listed.
// initMiddleware is first: it stores the provider callbacks (for example
// `getVirtualKey`) on the context, and virtualKeyMiddleware reads
// `getVirtualKey` from the context when it runs.
deepseekRoute.use(
initMiddleware,
metricsMiddleware,
loggingMiddleware,
bufferMiddleware,
virtualKeyMiddleware,
rateLimiterMiddleware,
guardMiddleware,
cacheMiddleware,
fallbackMiddleware
);

// Every POST under this route is delegated to the DeepSeek provider.
deepseekRoute.post('/*', async (c: Context) => {
return deepseekProvider.handleRequest(c);
});
Expand All @@ -39,28 +52,37 @@ export const deepseekProvider: AIProvider = {
/**
 * Forwards the incoming request to the DeepSeek API and returns the
 * upstream response unchanged.
 *
 * The part of the path after `/deepseek/` is appended to
 * `https://api.deepseek.com/`. The request headers are copied; when the
 * `virtualKey` middleware is recorded in the `middlewares` context variable
 * and a `realKey` is present, the `Authorization` header is replaced with
 * `Bearer <realKey>`.
 *
 * @param c - Request context.
 * @returns The response returned by `fetch` for the upstream call.
 */
handleRequest: async (c: Context) => {
  const functionName = c.req.path.slice(`/deepseek/`.length);
  const deepseekEndpoint = `https://api.deepseek.com/${functionName}`;
  console.log(`DeepSeek endpoint: ${deepseekEndpoint}`);

  const headers = new Headers(c.req.header());
  if (c.get('middlewares')?.includes('virtualKey')) {
    const apiKey: string = c.get('realKey');
    if (apiKey) {
      headers.set('Authorization', `Bearer ${apiKey}`);
    }
  }

  const response = await fetch(deepseekEndpoint, {
    method: c.req.method,
    body: JSON.stringify(await c.req.json()),
    // Send the rebuilt headers, not the raw incoming ones, so the
    // Authorization header set above actually reaches the upstream API.
    headers: headers
  });

  return response;
}
};

/**
 * Extracts the model name from the JSON request body held in the
 * `reqBuffer` context variable.
 *
 * @param c - Request context.
 * @returns The body's `model` field when it is a non-empty string;
 *          otherwise `"unknown"` (missing buffer, malformed JSON, a JSON
 *          value without a `model` field, or a non-string `model`).
 */
function getModelName(c: Context): string {
  const body = c.get('reqBuffer') || '{}';
  try {
    // `?.` covers bodies such as `null` that parse to a non-object value.
    const model: unknown = JSON.parse(body)?.model;
    return typeof model === 'string' && model ? model : "unknown";
  } catch {
    // A malformed body must not throw from here; fall back to "unknown".
    return "unknown";
  }
}

function getTokenCount(c: Context): { input_tokens: number, output_tokens: number } {
const buf = c.get('buffer') || "";
if (c.res.status === 200) {
try {
const jsonResponse = JSON.parse(buf);
if (c.res.headers.get('content-type') === 'application/json') {
try {
const jsonResponse = JSON.parse(buf);
const usage = jsonResponse.usage;
if (usage) {
return {
Expand All @@ -69,8 +91,25 @@ function getTokenCount(c: Context): { input_tokens: number, output_tokens: numbe
};
}
} catch (error) {
console.error("Error parsing response:", error);
console.error("Error parsing response:", error);
}
}
}
else {
const output = buf.trim().split('\n\n').at(-2);
if (output && output.startsWith('data: ')) {
const usage_message = JSON.parse(output.slice('data: '.length));
return {
input_tokens: usage_message.usage.prompt_tokens || 0,
output_tokens: usage_message.usage.completion_tokens || 0
};
}
}
}
return { input_tokens: 0, output_tokens: 0 };
}

/**
 * Extracts the caller-supplied key from an `Authorization: Bearer <key>`
 * request header.
 *
 * The `Bearer` scheme is matched case-insensitively, because HTTP
 * authentication scheme names are case-insensitive. Whitespace around the
 * key is dropped.
 *
 * @param c - Request context.
 * @returns The bearer token, or an empty string when the header is missing,
 *          uses another scheme, or carries no token.
 */
function getVirtualKey(c: Context): string {
  const authHeader = c.req.header('Authorization') || '';
  const match = /^Bearer +(.*)$/i.exec(authHeader.trim());
  return match?.[1]?.trim() ?? '';
}

4 changes: 4 additions & 0 deletions src/providers/workersAI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,7 @@ function getModelName(c: Context) {
/**
 * Token usage reporter for this provider: always reports zero input and
 * zero output tokens, regardless of the request or response.
 *
 * @param c - Request context (not inspected).
 * @returns A fresh `{ input_tokens: 0, output_tokens: 0 }` object.
 */
function getTokenCount(c: Context) {
  const zeroUsage = { input_tokens: 0, output_tokens: 0 };
  return zeroUsage;
}

/**
 * Virtual-key extractor for this provider: always yields an empty string,
 * whatever the request contains.
 *
 * @param c - Request context (not inspected).
 * @returns An empty string.
 */
function getVirtualKey(c: Context) {
  const noVirtualKey = '';
  return noVirtualKey;
}
Loading

0 comments on commit 169b234

Please sign in to comment.