refactor run cache #976

Merged: 9 commits, Jan 5, 2025
15 changes: 7 additions & 8 deletions packages/cli/src/nodehost.ts
@@ -24,9 +24,6 @@ import {
AZURE_AI_INFERENCE_TOKEN_SCOPES,
MODEL_PROVIDER_AZURE_SERVERLESS_OPENAI,
DOT_ENV_FILENAME,
LARGE_MODEL_ID,
SMALL_MODEL_ID,
VISION_MODEL_ID,
} from "../../core/src/constants"
import { tryReadText } from "../../core/src/fs"
import {
@@ -41,7 +38,7 @@ import {
ModelConfiguration,
} from "../../core/src/host"
import { TraceOptions } from "../../core/src/trace"
import { deleteEmptyValues, logError, logVerbose } from "../../core/src/util"
import { logError, logVerbose } from "../../core/src/util"
import { parseModelIdentifier } from "../../core/src/models"
import { LanguageModel } from "../../core/src/chat"
import { errorMessage, NotSupportedError } from "../../core/src/error"
@@ -50,7 +47,11 @@ import { shellConfirm, shellInput, shellSelect } from "./input"
import { shellQuote } from "../../core/src/shell"
import { uniq } from "es-toolkit"
import { PLimitPromiseQueue } from "../../core/src/concurrency"
import { LanguageModelConfiguration, Project, ResponseStatus } from "../../core/src/server/messages"
import {
LanguageModelConfiguration,
Project,
ResponseStatus,
} from "../../core/src/server/messages"
import { createAzureTokenResolver } from "./azuretoken"
import {
createAzureContentSafetyClient,
@@ -396,15 +397,13 @@ export class NodeHost implements RuntimeHost {
}

const {
trace,
label,
cwd,
timeout = SHELL_EXEC_TIMEOUT,
stdin: input,
} = options || {}
const trace = options?.trace?.startTraceDetails(label || command)
try {
trace?.startDetails(label || command)

// python3 on windows -> python
if (command === "python3" && process.platform === "win32")
command = "python"
22 changes: 13 additions & 9 deletions packages/cli/src/scripts.ts
@@ -8,7 +8,7 @@ import {
fixPromptDefinitions,
createScript as coreCreateScript,
} from "../../core/src/scripts"
import { logInfo, logVerbose } from "../../core/src/util"
import { deleteEmptyValues, logInfo, logVerbose } from "../../core/src/util"
import { runtimeHost } from "../../core/src/host"
import { RUNTIME_ERROR_CODE } from "../../core/src/constants"
import {
@@ -27,14 +27,18 @@ export async function listScripts(options?: ScriptFilterOptions) {
const prj = await buildProject() // Build the project to get script templates
const scripts = filterScripts(prj.scripts, options) // Filter scripts based on options
console.log(
YAMLStringify(
scripts.map(({ id, title, group, filename, system: isSystem }) => ({
id,
title,
group,
filename,
isSystem,
}))
JSON.stringify(
scripts.map(({ id, title, group, filename, system: isSystem }) =>
deleteEmptyValues({
id,
title,
group,
filename,
isSystem,
})
),
null,
2
)
)
}
9 changes: 6 additions & 3 deletions packages/core/bundleprompts.js
@@ -11,10 +11,13 @@ async function main() {
const promptMap = {}
const prompts = readdirSync(dir)
for (const prompt of prompts) {
if (!/\.m?js$/.test(prompt)) continue
if (!/\.mjs$/.test(prompt)) continue
const text = readFileSync(`${dir}/${prompt}`, "utf-8")
if (/\.genai\.m?js$/.test(prompt))
promptMap[prompt.replace(/\.genai\.m?js$/i, "")] = text
if (/^system\./.test(prompt)) {
const id = prompt.replace(/\.m?js$/i, "")
if (promptMap[id]) throw new Error(`duplicate prompt ${id}`)
promptMap[id] = text
}
}
console.log(`found ${Object.keys(promptMap).length} prompts`)
console.debug(Object.keys(promptMap).join("\n"))
41 changes: 0 additions & 41 deletions packages/core/src/anthropic.ts
@@ -30,10 +30,6 @@ import {

import { deleteUndefinedValues, logError, logVerbose } from "./util"
import { resolveHttpProxyAgent } from "./proxy"
import {
ChatCompletionRequestCacheKey,
getChatCompletionCache,
} from "./chatcache"
import { HttpsProxyAgent } from "https-proxy-agent"
import { MarkdownTrace } from "./trace"
import { createFetch, FetchType } from "./fetch"
@@ -289,40 +285,6 @@ const completerFactory = (
const { model } = parseModelIdentifier(req.model)
const { encode: encoder } = await resolveTokenEncoder(model)

const cache = !!cacheOrName || !!cacheName
const cacheStore = getChatCompletionCache(
typeof cacheOrName === "string" ? cacheOrName : cacheName
)
const cachedKey = cache
? <ChatCompletionRequestCacheKey>{
...req,
...cfgNoToken,
model: req.model,
temperature: req.temperature,
top_p: req.top_p,
max_tokens: req.max_tokens,
logit_bias: req.logit_bias,
}
: undefined
trace.itemValue(`caching`, cache)
trace.itemValue(`cache`, cacheStore?.name)
const { text: cached, finishReason: cachedFinishReason } =
(await cacheStore.get(cachedKey)) || {}
if (cached !== undefined) {
partialCb?.({
tokensSoFar: estimateTokens(cached, encoder),
responseSoFar: cached,
responseChunk: cached,
inner,
})
trace.itemValue(`cache hit`, await cacheStore.getKeySHA(cachedKey))
return {
text: cached,
finishReason: cachedFinishReason,
cached: true,
}
}

const fetch = await createFetch({
trace,
retries: retry,
@@ -441,9 +403,6 @@ const completerFactory = (
`${usage.total_tokens} total, ${usage.prompt_tokens} prompt, ${usage.completion_tokens} completion`
)
}

if (finishReason === "stop")
await cacheStore.set(cachedKey, { text: chatResp, finishReason })
return {
text: chatResp,
finishReason,
22 changes: 22 additions & 0 deletions packages/core/src/cache.test.ts
@@ -36,4 +36,26 @@ describe("Cache", () => {
assert.ok(sha)
assert.strictEqual(typeof sha, "string")
})
test("JSONLineCache getOrUpdate retrieves existing value", async () => {
const cache = JSONLineCache.byName<string, number>("testCache")
await cache.set("existingKey", 42)
const value = await cache.getOrUpdate(
"existingKey",
async () => 99,
() => true
)
assert.strictEqual(value.value, 42)
})

test("JSONLineCache getOrUpdate updates with new value if key does not exist", async () => {
const cache = JSONLineCache.byName<string, number>("testCache")
const value = await cache.getOrUpdate(
"newKey",
async () => 99,
() => true
)
assert.strictEqual(value.value, 99)
const cachedValue = await cache.get("newKey")
assert.strictEqual(cachedValue, 99)
})
})
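
The new tests cover the hit and miss paths. The pending map added to MemoryCache (in the cache.ts diff below) also shares a single in-flight updater between concurrent callers; a hedged sketch of a test for that path, assuming the same JSONLineCache API and the surrounding test imports (not part of this PR):

    test("JSONLineCache getOrUpdate shares one updater across concurrent calls", async () => {
        const cache = JSONLineCache.byName<string, number>("testCache")
        // unique key so a previously persisted entry cannot satisfy the lookup
        const key = "concurrentKey-" + Date.now()
        let calls = 0
        const updater = async () => {
            calls++
            return 7
        }
        // both callers race on the same missing key; the second should await
        // the first caller's pending promise instead of running its own updater
        const [a, b] = await Promise.all([
            cache.getOrUpdate(key, updater, () => true),
            cache.getOrUpdate(key, updater, () => true),
        ])
        assert.strictEqual(a.value, 7)
        assert.strictEqual(b.value, 7)
        assert.strictEqual(calls, 1)
    })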
29 changes: 27 additions & 2 deletions packages/core/src/cache.ts
@@ -25,6 +25,7 @@ export class MemoryCache<K, V>
implements WorkspaceFileCache<any, any>
{
protected _entries: Record<string, CacheEntry<K, V>>
private _pending: Record<string, Promise<V>>

// Constructor is private to enforce the use of byName factory method
protected constructor(public readonly name: string) {
@@ -53,6 +54,7 @@
protected async initialize() {
if (this._entries) return
this._entries = {}
this._pending = {}
}

/**
@@ -104,6 +106,29 @@
return this._entries[sha]?.val
}

async getOrUpdate(
key: K,
updater: () => Promise<V>,
validator: (val: V) => boolean
): Promise<{ key: string; value: V; cached?: boolean }> {
await this.initialize()
const sha = await keySHA(key)
if (this._entries[sha])
return { key: sha, value: this._entries[sha].val, cached: true }
if (this._pending[sha])
return { key: sha, value: await this._pending[sha], cached: true }

try {
const p = updater()
this._pending[sha] = p
const value = await p
if (validator(value)) await this.set(key, value)
return { key: sha, value, cached: false }
} finally {
delete this._pending[sha]
}
}

protected async appendEntry(entry: CacheEntry<K, V>) {}

/**
@@ -177,7 +202,7 @@ export class JSONLineCache<K, V> extends MemoryCache<K, V> {
*/
override async initialize() {
if (this._entries) return
this._entries = {}
super.initialize()
await host.createDirectory(this.folder()) // Ensure directory exists
const content = await tryReadText(this.path())
const objs: CacheEntry<K, V>[] = (await JSONLTryParse(content)) ?? []
Expand All @@ -201,7 +226,7 @@ export class JSONLineCache<K, V> extends MemoryCache<K, V> {
}

/**
* Compute the SHA1 hash of a key for uniqueness.
* Compute the hash of a key for uniqueness.
* Normalizes the key by converting it to a string and appending the core version.
* @param key - The key to hash
* @returns A promise resolving to the SHA256 hash string
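
With the per-provider cache removed from anthropic.ts, run caching presumably flows through the new getOrUpdate instead. A minimal sketch of how a caller might wrap a provider call with it, assuming a JSONLineCache keyed by the request and a validator mirroring the old finishReason === "stop" check; cachedChatCompletion, requestKey, and runCompletion are hypothetical names, not code from this PR:

    import { JSONLineCache } from "./cache"

    // Hypothetical shape, for illustration only.
    type CachedRun = { text: string; finishReason: string }

    async function cachedChatCompletion(
        requestKey: object, // hypothetical: the request plus config, minus secrets
        runCompletion: () => Promise<CachedRun> // hypothetical: the uncached provider call
    ) {
        const cache = JSONLineCache.byName<object, CachedRun>("chat")
        // on a miss, getOrUpdate runs the provider call, shares the in-flight
        // promise with concurrent callers, and persists only results the
        // validator accepts
        const { value, cached } = await cache.getOrUpdate(
            requestKey,
            runCompletion,
            (res) => res.finishReason === "stop"
        )
        return { ...value, cached }
    }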