diff --git a/packages/cli/src/info.ts b/packages/cli/src/info.ts
index 7054076593..c34fe314b3 100644
--- a/packages/cli/src/info.ts
+++ b/packages/cli/src/info.ts
@@ -18,6 +18,7 @@ import {
     ModelConnectionInfo,
    resolveModelConnectionInfo,
 } from "../../core/src/models"
+import { deleteEmptyValues } from "../../core/src/util"
 import { CORE_VERSION } from "../../core/src/version"
 import { YAMLStringify } from "../../core/src/yaml"
 import { buildProject } from "./build"
@@ -68,7 +69,15 @@ export async function envInfo(
                     if (ms?.length) conn.models = ms
                 }
             }
-            res.providers.push(conn)
+            res.providers.push(
+                deleteEmptyValues({
+                    provider: conn.provider,
+                    source: conn.source,
+                    base: conn.base,
+                    type: conn.type,
+                    models: conn.models,
+                })
+            )
         } catch (e) {
             if (error)
diff --git a/packages/core/src/lm.ts b/packages/core/src/lm.ts
index 78546a62a6..69bc530edb 100644
--- a/packages/core/src/lm.ts
+++ b/packages/core/src/lm.ts
@@ -7,12 +7,13 @@ import {
     MODEL_PROVIDER_ANTHROPIC,
     MODEL_PROVIDER_ANTHROPIC_BEDROCK,
     MODEL_PROVIDER_CLIENT,
+    MODEL_PROVIDER_JAN,
     MODEL_PROVIDER_OLLAMA,
     MODEL_PROVIDER_TRANSFORMERS,
 } from "./constants"
 import { host } from "./host"
 import { OllamaModel } from "./ollama"
-import { OpenAIModel } from "./openai"
+import { OpenAIModel, LocalOpenAICompatibleModel } from "./openai"
 import { TransformersModel } from "./transformers"
 
 export function resolveLanguageModel(provider: string): LanguageModel {
@@ -27,5 +28,9 @@ export function resolveLanguageModel(provider: string): LanguageModel {
     if (provider === MODEL_PROVIDER_ANTHROPIC_BEDROCK)
         return AnthropicBedrockModel
     if (provider === MODEL_PROVIDER_TRANSFORMERS) return TransformersModel
+
+    if (provider === MODEL_PROVIDER_JAN)
+        return LocalOpenAICompatibleModel(MODEL_PROVIDER_JAN)
+
     return OpenAIModel
 }
diff --git a/packages/core/src/openai.ts b/packages/core/src/openai.ts
index 73b324cdc6..0be4acd89e 100644
--- a/packages/core/src/openai.ts
+++ b/packages/core/src/openai.ts
@@ -1,5 +1,6 @@
 import {
     deleteUndefinedValues,
+    logError,
     logVerbose,
     normalizeInt,
     trimTrailingSlash,
@@ -18,9 +19,14 @@ import {
     TOOL_URL,
 } from "./constants"
 import { estimateTokens } from "./tokens"
-import { ChatCompletionHandler, LanguageModel, LanguageModelInfo } from "./chat"
+import {
+    ChatCompletionHandler,
+    LanguageModel,
+    LanguageModelInfo,
+    PullModelFunction,
+} from "./chat"
 import { RequestError, errorMessage, serializeError } from "./error"
-import { createFetch, traceFetchPost } from "./fetch"
+import { createFetch, iterateBody, traceFetchPost } from "./fetch"
 import { parseModelIdentifier } from "./models"
 import { JSON5TryParse } from "./json5"
 import {
@@ -449,8 +455,68 @@ async function listModels(
     )
 }
 
+const pullModel: PullModelFunction = async (modelId, options) => {
+    const { trace, cancellationToken } = options || {}
+    const { provider, model } = parseModelIdentifier(modelId)
+    const fetch = await createFetch({ retries: 0, ...options })
+    const conn = await host.getLanguageModelConfiguration(modelId, {
+        token: true,
+        cancellationToken,
+        trace,
+    })
+    try {
+        // test if model is present
+        const resTags = await fetch(`${conn.base}/models`, {
+            retries: 0,
+            method: "GET",
+            headers: {
+                "User-Agent": TOOL_ID,
+                "Content-Type": "application/json",
+            },
+        })
+        if (resTags.ok) {
+            const { models }: { models: { model: string }[] } =
+                await resTags.json()
+            if (models.find((m) => m.model === model)) return { ok: true }
+        }
+
+        // pull
+        logVerbose(`${provider}: pull ${model}`)
+        const resPull = await fetch(`${conn.base}/models/pull`, {
+            method: "POST",
+            headers: {
+                "User-Agent": TOOL_ID,
+                "Content-Type": "application/json",
+            },
+            body: JSON.stringify({ model }),
+        })
+        if (!resPull.ok) {
+            logError(`${provider}: failed to pull model ${model}`)
+            logVerbose(resPull.statusText)
+            return { ok: false, status: resPull.status }
+        }
+        for await (const chunk of iterateBody(resPull, { cancellationToken }))
+            process.stderr.write(".")
+        process.stderr.write("\n")
+        return { ok: true }
+    } catch (e) {
+        logError(e)
+        trace?.error(e)
+        return { ok: false, error: serializeError(e) }
+    }
+}
+
 export const OpenAIModel = Object.freeze({
     completer: OpenAIChatCompletion,
     id: MODEL_PROVIDER_OPENAI,
     listModels,
 })
+
+export function LocalOpenAICompatibleModel(providerId: string) {
+    return Object.freeze({
+        completer: OpenAIChatCompletion,
+        id: providerId,
+        listModels,
+        pullModel,
+    })
+}