diff --git a/scripts/generate-acl.mjs b/scripts/generate-acl.mjs index f9ca249..831dbeb 100644 --- a/scripts/generate-acl.mjs +++ b/scripts/generate-acl.mjs @@ -46,16 +46,16 @@ function extractAclMapping() { if (acl && Array.isArray(acl.admin)) { aclMapping.acl.admin = [...aclMapping.acl.admin, ...acl.admin]; - const recordIds = getAlgorithmRecordIds(provider.name); - // Assign each algorithm record id with the provider ACL - for (const recordId of recordIds) { - if (aclMapping.records[recordId]) { - aclMapping.records[recordId] = [ - ...aclMapping.records[recordId], + const scenarioIds = getBenchmarkScenarioIds(provider.name); + // Assign each benchmark scenario id with the provider ACL + for (const scenarioId of scenarioIds) { + if (aclMapping.records[scenarioId]) { + aclMapping.records[scenarioId] = [ + ...aclMapping.records[scenarioId], ...acl.admin, ]; } else { - aclMapping.records[recordId] = acl.admin; + aclMapping.records[scenarioId] = acl.admin; } } } else { @@ -76,32 +76,39 @@ function extractAclMapping() { } } -function getAlgorithmRecordIds(providerDir) { +function getBenchmarkScenarioIds(providerDir) { try { const targetDir = path.join(ALGORITHM_CATALOG_DIR, providerDir); - const records = fs + const scenarios = fs .readdirSync(targetDir, { recursive: true }) .map((file) => file.toString()) .filter( (file) => file.endsWith(".json") && - (file.includes("/records/") || file.includes("\\records\\")), // support linux and windows based path + (file.includes("/benchmark_scenarios/") || + file.includes("\\benchmark_scenarios\\")), // support linux and windows based path ); - const recordIds = []; + const scenarioIds = []; - for (const recordFile of records) { - const recordPath = path.join(targetDir, recordFile); - const recordContent = fs.readFileSync(recordPath, "utf-8"); - const recordJson = JSON.parse(recordContent); - recordIds.push(recordJson.id); + for (const scenarioFile of scenarios) { + const scenarioPath = path.join(targetDir, scenarioFile); + const scenarioContent = fs.readFileSync(scenarioPath, "utf-8"); + const scenarioJson = JSON.parse(scenarioContent); + if (Array.isArray(scenarioJson)) { + for (const scenario of scenarioJson) { + scenarioIds.push(scenario.id); + } + } else { + scenarioIds.push(scenarioJson.id); + } } - return recordIds; + return scenarioIds; } catch (error) { console.error( - `Error reading records for provider ${providerDir}:`, + `Error reading benchmark scenarios for provider ${providerDir}:`, error.message, ); return []; diff --git a/src/env.d.ts b/src/env.d.ts index 697804b..c2eaf05 100644 --- a/src/env.d.ts +++ b/src/env.d.ts @@ -8,6 +8,7 @@ declare namespace App { username?: string; email?: string | null; roles?: string[]; + emailDomain?: string | null; }; } } diff --git a/src/middleware.ts b/src/middleware.ts index b28fcb6..5467d2f 100644 --- a/src/middleware.ts +++ b/src/middleware.ts @@ -4,7 +4,7 @@ import { config as authConfig } from "../auth.config"; import { isFeatureEnabled } from "./lib/featureflag"; import aclMapping from "./acl-mapping.json"; -const protectedPaths = ["/api/admin/services/benchmarks.json", "/dashboard"]; +const protectedPaths = ["/api/admin/services/", "/dashboard"]; /** * Check if the request is for an API endpoint @@ -52,18 +52,17 @@ export const onRequest = defineMiddleware(async (context, next) => { const session = await getSession(context.request, authConfig); if (session?.user) { + const emailDomain = `@${session.user.email?.split("@").pop()}`; + context.locals.user = { name: session.user.name, username: session.user.username, email: session.user.email, roles: session.user.roles || [], + emailDomain, }; - const emailDomain = `@${context.locals.user.email?.split("@").pop()}`; - if ( - context.locals.user.roles?.includes("administrator") || - aclMapping.acl.admin.includes(emailDomain) - ) { + if (context.locals.user.roles?.includes("administrator") || aclMapping.acl.admin.includes(emailDomain)) { return next(); } diff --git a/src/pages/api/admin/services/[id]/benchmarks.json.ts b/src/pages/api/admin/services/[id]/benchmarks.json.ts index 218600f..02ff079 100644 --- a/src/pages/api/admin/services/[id]/benchmarks.json.ts +++ b/src/pages/api/admin/services/[id]/benchmarks.json.ts @@ -5,6 +5,7 @@ import { PARQUET_MONTH_COVERAGE, getUrlsFromRequest, } from "@/lib/parquet-datasource"; +import aclMapping from "@/acl-mapping.json"; /** * @openapi @@ -92,7 +93,7 @@ import { * - Benchmark * - Scenario */ -export const GET: APIRoute = async ({ params, request }) => { +export const GET: APIRoute = async ({ params, request, locals }) => { const scenario = params.id; if (!scenario) { @@ -102,6 +103,14 @@ export const GET: APIRoute = async ({ params, request }) => { ); } + // @ts-expect-error + if (!aclMapping.records[scenario]?.includes(locals.user?.emailDomain)) { + return new Response( + JSON.stringify({ message: "Scenario not found." }), + { status: 404, headers: { "Content-Type": "application/json" } }, + ); + } + try { const urlResponse = await getUrlsFromRequest(request); if (urlResponse instanceof Response) { diff --git a/src/pages/api/admin/services/benchmarks.json.ts b/src/pages/api/admin/services/benchmarks.json.ts index 63e9563..02eba12 100644 --- a/src/pages/api/admin/services/benchmarks.json.ts +++ b/src/pages/api/admin/services/benchmarks.json.ts @@ -5,6 +5,7 @@ import { PARQUET_MONTH_COVERAGE, getUrlsFromRequest, } from "@/lib/parquet-datasource"; +import aclMapping from "@/acl-mapping.json"; /** * @openapi @@ -71,7 +72,7 @@ import { * - Admin * - Benchmark */ -export const GET: APIRoute = async ({ request }) => { +export const GET: APIRoute = async ({ request, locals }) => { try { const urlResponse = await getUrlsFromRequest(request); if (urlResponse instanceof Response) { @@ -102,7 +103,11 @@ export const GET: APIRoute = async ({ request }) => { ORDER BY "scenario_id"; `; - const data = (await executeQuery(query)) as BenchmarkSummary[]; + let data = (await executeQuery(query)) as BenchmarkSummary[]; + data = data.filter((benchmark) => { + // @ts-expect-error + return aclMapping.records[benchmark.scenario_id]?.includes(locals.user?.emailDomain) + }); return Response.json(data); } catch (error) {