Skip to content

Commit 9cecf25

Browse files
feat: add MS Defender option
1 parent 689bc5a commit 9cecf25

File tree

3 files changed

+121
-0
lines changed

3 files changed

+121
-0
lines changed

README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,20 @@ There are multiple ways to run this sample: locally using Ollama or Azure OpenAI
114114

115115
See the [cost estimation](./docs/cost.md) details for running this sample on Azure.
116116

117+
#### (Optional) Enable additional user context to Microsoft Defender for Cloud
118+
In case you have Microsoft Defender for Cloud protection on your Azure OpenAI resource and you want to have additional user context on the alerts, run this command:
119+
120+
```bash
121+
azd env set MS_DEFENDER_ENABLED true
122+
```
123+
124+
To customize the application name of the user context, run this command:
125+
```bash
126+
azd env set APPLICATION_NAME <your application name>
127+
```
128+
129+
For more details, refer to the [Microsoft Defender for Cloud documentation](https://learn.microsoft.com/azure/defender-for-cloud/gain-end-user-context-ai).
130+
117131
#### Deploy the sample
118132

119133
1. Open a terminal and navigate to the root of the project.

packages/api/src/functions/chat-post.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import {
1010
import { AzureOpenAI, OpenAI } from 'openai';
1111
import 'dotenv/config';
1212
import { type ChatCompletionChunk } from 'openai/resources';
13+
import { getMsDefenderUserJson, type UserSecurityContext } from './security/ms-defender-utils.js';
1314

1415
const azureOpenAiScope = 'https://cognitiveservices.azure.com/.default';
1516
const systemPrompt = `Assistant helps the user with cooking questions. Be brief in your answers. Answer only plain text, DO NOT use Markdown.
@@ -61,12 +62,19 @@ export async function postChat(
6162
throw new Error('No OpenAI API key or Azure OpenAI deployment provided');
6263
}
6364

65+
let userSecurityContext: UserSecurityContext | undefined;
66+
if (process.env.MS_DEFENDER_ENABLED) {
67+
userSecurityContext = getMsDefenderUserJson(request);
68+
}
69+
6470
if (stream) {
71+
// @ts-expect-error user_security_context field is unsupported via openai client
6572
const responseStream = await openai.chat.completions.create({
6673
messages: [{ role: 'system', content: systemPrompt }, ...messages],
6774
temperature: 0.7,
6875
model,
6976
stream: true,
77+
user_security_context: userSecurityContext,
7078
});
7179
const jsonStream = Readable.from(createJsonStream(responseStream));
7280

@@ -83,6 +91,8 @@ export async function postChat(
8391
messages: [{ role: 'system', content: systemPrompt }, ...messages],
8492
temperature: 0.7,
8593
model,
94+
// @ts-expect-error user_security_context field is unsupported via openai client
95+
user_security_context: userSecurityContext,
8696
});
8797

8898
return {
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import process from 'node:process';
2+
import { type HttpRequest } from '@azure/functions';
3+
4+
/**
5+
* Generates the user security context which contains several parameters that describe the AI application itself, and the end user that interacts with the AI application.
6+
* These fields assist your security operations teams to investigate and mitigate security incidents by providing a comprehensive approach to protecting your AI applications.
7+
* [Learn more](https://learn.microsoft.com/azure/defender-for-cloud/gain-end-user-context-ai) about protecting AI applications using Microsoft Defender for Cloud.
8+
* @param request - The HTTP request
9+
* @returns A json string which represents the user context
10+
*/
11+
export function getMsDefenderUserJson(request: HttpRequest): UserSecurityContext {
12+
const sourceIp = getSourceIp(request);
13+
const authenticatedUserDetails = getAuthenticatedUserDetails(request);
14+
15+
const userSecurityContext = {
16+
end_user_tenant_id: authenticatedUserDetails.get('tenantId'),
17+
end_user_id: authenticatedUserDetails.get('userId'),
18+
source_ip: sourceIp,
19+
application_name: process.env.APPLICATION_NAME,
20+
} as UserSecurityContext;
21+
22+
return userSecurityContext;
23+
}
24+
25+
/**
26+
* Extracts user authentication details from the 'X-Ms-Client-Principal' header.
27+
* This is based on [Azure Static Web App documentation](https://learn.microsoft.com/en-us/azure/static-web-apps/user-information)
28+
* @param request - The HTTP request
29+
* @returns A dictionary containing authentication details of the user
30+
*/
31+
function getAuthenticatedUserDetails(request: HttpRequest): Map<string, string> {
32+
const authenticatedUserDetails = new Map<string, string>();
33+
const principalHeader = request.headers.get('X-Ms-Client-Principal');
34+
if (principalHeader === null) {
35+
return authenticatedUserDetails;
36+
}
37+
38+
const principal = parsePrincipal(principalHeader);
39+
if (principal === null) {
40+
return authenticatedUserDetails;
41+
}
42+
43+
const tenantId = process.env.AZURE_TENANT_ID;
44+
if (principal!.identityProvider === 'aad') {
45+
// TODO: add only when userId represents actual IDP user id
46+
// authenticatedUserDetails.set('userId', principal['userId']);
47+
authenticatedUserDetails.set('tenantId', tenantId!);
48+
}
49+
50+
return authenticatedUserDetails;
51+
}
52+
53+
function parsePrincipal(principal: string | undefined): Principal | undefined {
54+
if (principal === null) {
55+
return undefined;
56+
}
57+
58+
try {
59+
return JSON.parse(Buffer.from(principal!, 'base64').toString('utf8')) as Principal;
60+
} catch {
61+
return undefined;
62+
}
63+
}
64+
65+
function getSourceIp(request: HttpRequest) {
66+
const xForwardFor = request.headers.get('X-Forwarded-For');
67+
if (xForwardFor === null) {
68+
return null;
69+
}
70+
71+
const ip = xForwardFor.split(',')[0];
72+
const colonIndex = ip.lastIndexOf(':');
73+
74+
// Case of ipv4
75+
if (colonIndex !== -1 && ip.indexOf(':') === colonIndex) {
76+
return ip.slice(0, Math.max(0, colonIndex));
77+
}
78+
79+
// Case of ipv6
80+
if (ip.startsWith('[') && ip.includes(']:')) {
81+
return ip.slice(0, Math.max(0, ip.indexOf(']:') + 1));
82+
}
83+
84+
return ip;
85+
}
86+
87+
type Principal = {
88+
identityProvider: string;
89+
userId: string;
90+
};
91+
92+
export type UserSecurityContext = {
93+
application_name: string;
94+
end_user_id: string;
95+
end_user_tenant_id: string;
96+
source_ip: string;
97+
};

0 commit comments

Comments
 (0)