Skip to content

Commit

Permalink
Refactor Prompt Node Options Structure (#967)
Browse files Browse the repository at this point in the history
  • Loading branch information
pelikhan authored Dec 25, 2024
1 parent 83d87b7 commit 7090d29
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 39 deletions.
27 changes: 12 additions & 15 deletions packages/core/src/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1087,20 +1087,17 @@ export function tracePromptResult(
export function appendUserMessage(
messages: ChatCompletionMessageParam[],
content: string,
options?: { ephemeral?: boolean }
options?: ContextExpansionOptions
) {
if (!content) return
const { ephemeral } = options || {}
const { cacheControl } = options || {}
let last = messages.at(-1) as ChatCompletionUserMessageParam
if (
last?.role !== "user" ||
!!ephemeral !== (last?.cacheControl === "ephemeral")
) {
if (last?.role !== "user" || options?.cacheControl !== last?.cacheControl) {
last = {
role: "user",
content: "",
} satisfies ChatCompletionUserMessageParam
if (ephemeral) last.cacheControl = "ephemeral"
if (cacheControl) last.cacheControl = cacheControl
messages.push(last)
}
if (last.content) {
Expand All @@ -1112,20 +1109,20 @@ export function appendUserMessage(
export function appendAssistantMessage(
messages: ChatCompletionMessageParam[],
content: string,
options?: { ephemeral?: boolean }
options?: ContextExpansionOptions
) {
if (!content) return
const { ephemeral } = options || {}
const { cacheControl } = options || {}
let last = messages.at(-1) as ChatCompletionAssistantMessageParam
if (
last?.role !== "assistant" ||
!!ephemeral !== (last?.cacheControl === "ephemeral")
options?.cacheControl !== last?.cacheControl
) {
last = {
role: "assistant",
content: "",
} satisfies ChatCompletionAssistantMessageParam
if (ephemeral) last.cacheControl = "ephemeral"
if (cacheControl) last.cacheControl = cacheControl
messages.push(last)
}
if (last.content) {
Expand All @@ -1137,21 +1134,21 @@ export function appendAssistantMessage(
export function appendSystemMessage(
messages: ChatCompletionMessageParam[],
content: string,
options?: { ephemeral?: boolean }
options?: ContextExpansionOptions
) {
if (!content) return
const { ephemeral } = options || {}
const { cacheControl } = options || {}

let last = messages[0] as ChatCompletionSystemMessageParam
if (
last?.role !== "system" ||
!!ephemeral !== (last?.cacheControl === "ephemeral")
options?.cacheControl !== last?.cacheControl
) {
last = {
role: "system",
content: "",
} as ChatCompletionSystemMessageParam
if (ephemeral) last.cacheControl = "ephemeral"
if (cacheControl) last.cacheControl = cacheControl
messages.unshift(last)
}
if (last.content) {
Expand Down
40 changes: 22 additions & 18 deletions packages/core/src/promptdom.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,6 @@ export interface PromptNode extends ContextExpansionOptions {
children?: PromptNode[] // Child nodes for hierarchical structure
error?: unknown // Error information if present
tokens?: number // Token count for the node
/**
* Definte a prompt caching breakpoint.
* This prompt prefix (including this text) is cacheable for a short amount of time.
*/
ephemeral?: boolean

/**
* Rendered markdown preview of the node
Expand Down Expand Up @@ -237,6 +232,15 @@ export function createDef(
return { type: "def", name, value, ...(options || {}) }
}

function cloneContextFields(n: PromptNode): Partial<PromptNode> {
const r = {} as Partial<PromptNode>
r.maxTokens = n.maxTokens
r.priority = n.priority
r.flex = n.flex
r.cacheControl = n.cacheControl
return r
}

export function createDefDiff(
name: string,
left: string | WorkspaceFile,
Expand Down Expand Up @@ -327,7 +331,7 @@ function renderDefNode(def: PromptDefNode): string {
}

async function renderDefDataNode(n: PromptDefDataNode): Promise<string> {
const { name, headers, priority, ephemeral, query } = n
const { name, headers, priority, cacheControl, query } = n
let data = n.resolved
let format = n.format
if (
Expand Down Expand Up @@ -680,7 +684,7 @@ async function resolvePromptNode(
const rendered = renderDefNode(n)
n.preview = rendered
n.tokens = estimateTokens(rendered, encoder)
n.children = [createTextNode(rendered)]
n.children = [createTextNode(rendered, cloneContextFields(n))]
} catch (e) {
n.error = e
}
Expand All @@ -693,7 +697,7 @@ async function resolvePromptNode(
const rendered = await renderDefDataNode(n)
n.preview = rendered
n.tokens = estimateTokens(rendered, encoder)
n.children = [createTextNode(rendered)]
n.children = [createTextNode(rendered, cloneContextFields(n))]
} catch (e) {
n.error = e
}
Expand Down Expand Up @@ -929,7 +933,7 @@ async function truncatePromptNode(
n.tokens = estimateTokens(n.resolved.content, encoder)
const rendered = renderDefNode(n)
n.preview = rendered
n.children = [createTextNode(rendered)]
n.children = [createTextNode(rendered, cloneContextFields(n))]
truncated = true
trace.log(
`truncated def ${n.name} to ${n.tokens} tokens (max ${n.maxTokens})`
Expand Down Expand Up @@ -1062,14 +1066,14 @@ async function validateSafetyPromptNode(
def: async (n) => {
if (!n.detectPromptInjection || !n.resolved?.content) return

const detectPromptInjection = await resolveContentSafety()
const detectPromptInjectionFn = await resolveContentSafety()
if (
(!detectPromptInjection && n.detectPromptInjection === true) ||
(!detectPromptInjectionFn && n.detectPromptInjection === true) ||
n.detectPromptInjection === "always"
)
throw new Error("content safety service not available")
const { attackDetected } =
(await detectPromptInjection?.(n.resolved)) || {}
(await detectPromptInjectionFn?.(n.resolved)) || {}
if (attackDetected) {
mod = true
n.resolved = {
Expand All @@ -1087,14 +1091,14 @@ async function validateSafetyPromptNode(
defData: async (n) => {
if (!n.detectPromptInjection || !n.preview) return

const detectPromptInjection = await resolveContentSafety()
const detectPromptInjectionFn = await resolveContentSafety()
if (
(!detectPromptInjection && n.detectPromptInjection === true) ||
(!detectPromptInjectionFn && n.detectPromptInjection === true) ||
n.detectPromptInjection === "always"
)
throw new Error("content safety service not available")
const { attackDetected } =
(await detectPromptInjection?.(n.preview)) || {}
(await detectPromptInjectionFn?.(n.preview)) || {}
if (attackDetected) {
mod = true
n.children = []
Expand Down Expand Up @@ -1164,13 +1168,13 @@ export async function renderPromptNode(
if (safety) await tracePromptNode(trace, node, { label: "safety" })

const messages: ChatCompletionMessageParam[] = []
const appendSystem = (content: string, options: { ephemeral?: boolean }) =>
const appendSystem = (content: string, options: ContextExpansionOptions) =>
appendSystemMessage(messages, content, options)
const appendUser = (content: string, options: { ephemeral?: boolean }) =>
const appendUser = (content: string, options: ContextExpansionOptions) =>
appendUserMessage(messages, content, options)
const appendAssistant = (
content: string,
options: { ephemeral?: boolean }
options: ContextExpansionOptions
) => appendAssistantMessage(messages, content, options)

const images: PromptImage[] = []
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/runpromptcontext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ export function createChatTurnGenerationContext(
return res
},
cacheControl: (cc) => {
current.ephemeral = cc === "ephemeral"
current.cacheControl = cc
return res
},
} satisfies PromptTemplateString)
Expand Down
5 changes: 0 additions & 5 deletions packages/core/src/types/prompt_template.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -951,11 +951,6 @@ interface ContextExpansionOptions {
*/
flex?: number

/**
* @deprecated use cacheControl instead
*/
ephemeral?: boolean

/**
* Caching policy for this text. `ephemeral` means the prefix can be cached for a short amount of time.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
script({
title: "summarize all files with caching",
files: "src/rag/markdown.md",
model: "small",
tests: [
{
files: "src/rag/markdown.md",
Expand Down

0 comments on commit 7090d29

Please sign in to comment.