Skip to content

Commit

Permalink
add guard middleware
Browse files Browse the repository at this point in the history
Signed-off-by: oilbeater <[email protected]>
  • Loading branch information
oilbeater committed Oct 3, 2024
1 parent 0c4240e commit 479fb6f
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 8 deletions.
29 changes: 29 additions & 0 deletions src/middlewares/guard.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import { Context, MiddlewareHandler, Next } from "hono";
import { AppContext } from ".";

const denyRequestPatterns = [
'password',
];

const denyResponsePatterns = [
'password',
];


// The guard middleware is used to protect the API by checking if the request match the specific regex.
// If so it returns message "Rejected due to inappropriate content" with 403 status code.
export const guardMiddleware: MiddlewareHandler = async (c: Context<AppContext>, next: Next) => {
const requestText = await c.req.text();
if (denyRequestPatterns.some(pattern => new RegExp(pattern, 'i').test(requestText))) {
return c.text('Rejected due to inappropriate content', 403);
}

await next();

if (c.res.status === 200 && c.res.headers.get('Content-Type')?.includes('application/json')) {
const responseText = await c.res.clone().text();
if (denyResponsePatterns.some(pattern => new RegExp(pattern, 'i').test(responseText))) {
return c.text('Rejected due to inappropriate content', 403);
}
}
}
2 changes: 1 addition & 1 deletion src/middlewares/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ export { bufferMiddleware } from './buffer';
export { loggingMiddleware } from './logging';
export { virtualKeyMiddleware } from './virtualKey';
export { rateLimiterMiddleware } from './rateLimiter';

export { guardMiddleware } from './guard';
export type AppContext = {
Bindings: Env,
Variables: {
Expand Down
5 changes: 3 additions & 2 deletions src/providers/azureOpenAI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ import {
bufferMiddleware,
loggingMiddleware,
virtualKeyMiddleware,
rateLimiterMiddleware
rateLimiterMiddleware,
guardMiddleware
} from '../middlewares';

const BasePath = '/azure-openai/:resource_name/deployments/:deployment_name';
Expand All @@ -21,7 +22,7 @@ const initMiddleware = async (c: Context, next: Next) => {
};


azureOpenAIRoute.use(initMiddleware, metricsMiddleware, loggingMiddleware, bufferMiddleware, virtualKeyMiddleware, rateLimiterMiddleware, cacheMiddleware);
azureOpenAIRoute.use(initMiddleware, metricsMiddleware, loggingMiddleware, bufferMiddleware, virtualKeyMiddleware, rateLimiterMiddleware, guardMiddleware, cacheMiddleware);

azureOpenAIRoute.post('/*', async (c: Context) => {
return azureOpenAIProvider.handleRequest(c);
Expand Down
24 changes: 19 additions & 5 deletions test/index.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ describe('Welcome to Malacca worker', () => {
});

const url = `https://example.com/azure-openai/${import.meta.env.VITE_AZURE_RESOURCE_NAME}/deployments/${import.meta.env.VITE_AZURE_DEPLOYMENT_NAME}/chat/completions?api-version=2024-07-01-preview`;
const createRequestBody = (stream: boolean) => `
const createRequestBody = (stream: boolean, placeholder: string) => `
{
"messages": [
{
Expand All @@ -37,7 +37,7 @@ const createRequestBody = (stream: boolean) => `
"content": [
{
"type": "text",
"text": "Tell me a very short story about Malacca"
"text": "Tell me a very short story about ${placeholder}"
}
]
}
Expand All @@ -50,7 +50,7 @@ const createRequestBody = (stream: boolean) => `

describe('Test Cache', () => {
it('with cache first response should with no header malacca-cache-status and following response with hit', async () => {
const body = createRequestBody(false);
const body = createRequestBody(false, 'Malacca');
let start = Date.now();
let response = await SELF.fetch(url, { method: 'POST', body: body, headers: { 'Content-Type': 'application/json', 'api-key': 'oilbeater' } });
const value = await response.json()
Expand All @@ -73,7 +73,7 @@ describe('Test Cache', () => {
});

it('Test stream with cache', async () => {
const body = createRequestBody(true);
const body = createRequestBody(true, 'Malacca');
let start = Date.now();
let response = await SELF.fetch(url, { method: 'POST', body: body, headers: { 'Content-Type': 'application/json', 'api-key': 'oilbeater' } });
const value = await response.text()
Expand Down Expand Up @@ -131,10 +131,24 @@ describe('Test Virtual Key', () => {
it('should return 401 for invalid api key', async () => {
const response = await SELF.fetch(url, {
method: 'POST',
body: createRequestBody(true),
body: createRequestBody(true, 'Malacca'),
headers: { 'Content-Type': 'application/json', 'api-key': 'invalid-key' }
});

expect(response.status).toBe(401);
});
});

describe('Test Guard', () => {
it('should return 403 for deny request', async () => {
const response = await SELF.fetch(url, {
method: 'POST',
body: createRequestBody(true, 'password'),
headers: { 'Content-Type': 'application/json', 'api-key': 'oilbeater' }
});

expect(response.status).toBe(403);
});
});


0 comments on commit 479fb6f

Please sign in to comment.