feat: ✨ add support for pulling local OpenAI-compatible models
pelikhan committed Dec 26, 2024
1 parent f47877a commit 60ea681
Showing 3 changed files with 86 additions and 4 deletions.
12 changes: 11 additions & 1 deletion packages/cli/src/info.ts
@@ -18,6 +18,7 @@ import {
ModelConnectionInfo,
resolveModelConnectionInfo,
} from "../../core/src/models"
import { deleteEmptyValues } from "../../core/src/util"
import { CORE_VERSION } from "../../core/src/version"
import { YAMLStringify } from "../../core/src/yaml"
import { buildProject } from "./build"
@@ -59,6 +60,7 @@ export async function envInfo(
models?: LanguageModelInfo[]
} = await parseTokenFromEnv(env, `${modelProvider.id}:*`)
if (conn) {
console.log(modelProvider.id + " " + conn)
// Mask the token if the option is set
if (!token && conn.token) conn.token = "***"
if (models) {
@@ -68,7 +70,15 @@
if (ms?.length) conn.models = ms
}
}
res.providers.push(conn)
res.providers.push(
deleteEmptyValues({
provider: conn.provider,
source: conn.source,
base: conn.base,
type: conn.type,
models: conn.models,
})
)
}
} catch (e) {
if (error)
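
Note: the envInfo change above stops pushing the raw connection object and instead reports a trimmed record per provider. deleteEmptyValues comes from packages/core/src/util and its implementation is not shown in this diff; the sketch below assumes it simply drops entries whose values are missing, so only populated fields end up in the provider listing. The provider values are illustrative, not taken from the commit.

// Assumption: deleteEmptyValues removes keys whose values are undefined, null, or empty.
function deleteEmptyValues(o: Record<string, unknown>): Record<string, unknown> {
    for (const k of Object.keys(o))
        if (o[k] === undefined || o[k] === null || o[k] === "") delete o[k]
    return o
}

// Illustrative only: a local provider with no explicit type or model list
// reports just the fields that are actually set.
const entry = deleteEmptyValues({
    provider: "jan",
    source: "env",
    base: "http://localhost:1337/v1",
    type: undefined,
    models: undefined,
})
// -> { provider: "jan", source: "env", base: "http://localhost:1337/v1" }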
7 changes: 6 additions & 1 deletion packages/core/src/lm.ts
@@ -7,12 +7,13 @@ import {
MODEL_PROVIDER_ANTHROPIC,
MODEL_PROVIDER_ANTHROPIC_BEDROCK,
MODEL_PROVIDER_CLIENT,
MODEL_PROVIDER_JAN,
MODEL_PROVIDER_OLLAMA,
MODEL_PROVIDER_TRANSFORMERS,
} from "./constants"
import { host } from "./host"
import { OllamaModel } from "./ollama"
import { OpenAIModel } from "./openai"
import { OpenAIModel, LocalOpenAICompatibleModel } from "./openai"
import { TransformersModel } from "./transformers"

export function resolveLanguageModel(provider: string): LanguageModel {
@@ -27,5 +28,9 @@ export function resolveLanguageModel(provider: string): LanguageModel {
if (provider === MODEL_PROVIDER_ANTHROPIC_BEDROCK)
return AnthropicBedrockModel
if (provider === MODEL_PROVIDER_TRANSFORMERS) return TransformersModel

if (provider === MODEL_PROVIDER_JAN)
return LocalOpenAICompatibleModel(MODEL_PROVIDER_JAN)

return OpenAIModel
}
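
Note: with this change, resolveLanguageModel routes the Jan provider through the LocalOpenAICompatibleModel factory instead of falling back to OpenAIModel, so the returned LanguageModel carries the provider's own id plus the new pullModel hook. A rough usage sketch follows; it assumes the imports already shown in this file (resolveLanguageModel, MODEL_PROVIDER_JAN, logError), a caller-provided trace and cancellationToken, and an illustrative model name.

// Sketch only: ensure a locally served model is available before using it.
const lm = resolveLanguageModel(MODEL_PROVIDER_JAN)
if (lm.pullModel) {
    const res = await lm.pullModel("jan:llama3.2-1b", { trace, cancellationToken })
    if (!res.ok)
        logError(`model pull failed${res.status ? ` (status ${res.status})` : ""}`)
}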
71 changes: 69 additions & 2 deletions packages/core/src/openai.ts
@@ -1,5 +1,6 @@
import {
deleteUndefinedValues,
logError,
logVerbose,
normalizeInt,
trimTrailingSlash,
@@ -18,9 +19,14 @@ import {
TOOL_URL,
} from "./constants"
import { estimateTokens } from "./tokens"
import { ChatCompletionHandler, LanguageModel, LanguageModelInfo } from "./chat"
import {
ChatCompletionHandler,
LanguageModel,
LanguageModelInfo,
PullModelFunction,
} from "./chat"
import { RequestError, errorMessage, serializeError } from "./error"
import { createFetch, traceFetchPost } from "./fetch"
import { createFetch, iterateBody, traceFetchPost } from "./fetch"
import { parseModelIdentifier } from "./models"
import { JSON5TryParse } from "./json5"
import {
@@ -449,8 +455,69 @@ async function listModels(
)
}

const pullModel: PullModelFunction = async (modelId, options) => {
const { trace, cancellationToken } = options || {}
const { provider, model } = parseModelIdentifier(modelId)
const fetch = await createFetch({ retries: 0, ...options })
const conn = await host.getLanguageModelConfiguration(modelId, {
token: true,
cancellationToken,
trace,
})
try {
// test if model is present
const resTags = await fetch(`${conn.base}/models`, {
retries: 0,
method: "GET",
headers: {
"User-Agent": TOOL_ID,
"Content-Type": "application/json",
},
})
if (resTags.ok) {
const { models }: { models: { model: string }[] } =
await resTags.json()
if (models.find((m) => m.model === model)) return { ok: true }
}

// pull
logVerbose(`${provider}: pull ${model}`)
const resPull = await fetch(`${conn.base}/models/pull`, {
method: "POST",
headers: {
"User-Agent": TOOL_ID,
"Content-Type": "application/json",
},
body: JSON.stringify({ model }),
})
if (!resPull.ok) {
logError(`${provider}: failed to pull model ${model}`)
logVerbose(resPull.statusText)
return { ok: false, status: resPull.status }
}
for await (const chunk of iterateBody(resPull, { cancellationToken }))
process.stderr.write(".")
process.stderr.write("\n")
return { ok: true }
} catch (e) {
logError(e)
        trace?.error(e)
return { ok: false, error: serializeError(e) }
}
}

export const OpenAIModel = Object.freeze<LanguageModel>({
completer: OpenAIChatCompletion,
id: MODEL_PROVIDER_OPENAI,
listModels,
})

export function LocalOpenAICompatibleModel(providerId: string) {
return Object.freeze<LanguageModel>({
completer: OpenAIChatCompletion,
id: providerId,
listModels,
pullModel,
})
}
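
Note: LocalOpenAICompatibleModel is a small factory: any provider it wraps reuses the OpenAI chat completer and model listing and gains the pull flow above, which checks GET {base}/models for the requested model and, if it is absent, POSTs {base}/models/pull and streams the response while printing progress dots. For reference, a standalone sketch of those two requests is below; it uses the global fetch available in Node 18+, the base URL and model name are illustrative, and /models/pull is the provider-specific route this commit targets rather than a standard OpenAI endpoint.

// Standalone sketch mirroring the two requests pullModel issues (assumed values).
const base = "http://localhost:1337/v1" // illustrative local server base URL
const model = "llama3.2-1b" // illustrative model name

const res = await fetch(`${base}/models`, {
    headers: { "Content-Type": "application/json" },
})
const { models } = (await res.json()) as { models: { model: string }[] }
if (!models?.some((m) => m.model === model)) {
    const pull = await fetch(`${base}/models/pull`, {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({ model }),
    })
    if (!pull.ok) throw new Error(`pull failed with status ${pull.status}`)
}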
