diff --git a/src/app/(dashboard)/dashboard/providers/[id]/page.js b/src/app/(dashboard)/dashboard/providers/[id]/page.js index b9077e480..f7d2bb77d 100644 --- a/src/app/(dashboard)/dashboard/providers/[id]/page.js +++ b/src/app/(dashboard)/dashboard/providers/[id]/page.js @@ -40,6 +40,7 @@ export default function ProviderDetailPage() { const [thinkingMode, setThinkingMode] = useState("auto"); const [suggestedModels, setSuggestedModels] = useState([]); const [kiloFreeModels, setKiloFreeModels] = useState([]); + const [togglingModelId, setTogglingModelId] = useState(null); const { copied, copy } = useCopyToClipboard(); const providerInfo = providerNode @@ -563,6 +564,70 @@ export default function ProviderDetailPage() { } }; + // All connections for a provider carry identical disabledModels (Task 1 invariant), so first is representative. + const disabledModels = connections[0]?.providerSpecificData?.disabledModels || []; + const disabledModelsSet = new Set(disabledModels); + + const handleDisableModel = async (modelId) => { + if (togglingModelId) return; + const connectionId = connections.find((c) => c.isActive !== false)?.id || connections[0]?.id; + if (!connectionId) return; + setTogglingModelId(modelId); + const next = [...new Set([...disabledModels, modelId])]; + setConnections((prev) => + prev.map((c) => ({ + ...c, + providerSpecificData: { ...(c.providerSpecificData || {}), disabledModels: next }, + })) + ); + try { + const res = await fetch(`/api/providers/${connectionId}`, { + method: "PATCH", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ disableModel: modelId }), + }); + if (!res.ok) { + console.log("Error disabling model:", res.status); + } + await fetchConnections(); + } catch (error) { + console.log("Error disabling model:", error); + await fetchConnections(); + } finally { + setTogglingModelId(null); + } + }; + + const handleEnableModel = async (modelId) => { + if (togglingModelId) return; + const connectionId = connections.find((c) => c.isActive !== false)?.id || connections[0]?.id; + if (!connectionId) return; + setTogglingModelId(modelId); + const next = disabledModels.filter((m) => m !== modelId); + setConnections((prev) => + prev.map((c) => ({ + ...c, + providerSpecificData: { ...(c.providerSpecificData || {}), disabledModels: next }, + })) + ); + try { + const res = await fetch(`/api/providers/${connectionId}`, { + method: "PATCH", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ enableModel: modelId }), + }); + if (!res.ok) { + console.log("Error enabling model:", res.status); + } + await fetchConnections(); + } catch (error) { + console.log("Error enabling model:", error); + await fetchConnections(); + } finally { + setTogglingModelId(null); + } + }; + const renderModelsSection = () => { if (isCompatible) { return ( @@ -576,6 +641,10 @@ export default function ProviderDetailPage() { onDeleteAlias={handleDeleteAlias} connections={connections} isAnthropic={isAnthropicCompatible} + disabledModels={disabledModels} + onDisableModel={connections.length > 0 ? handleDisableModel : undefined} + onEnableModel={connections.length > 0 ? handleEnableModel : undefined} + togglingModelId={togglingModelId} /> ); } @@ -624,6 +693,10 @@ export default function ProviderDetailPage() { onTest={connections.length > 0 || isFreeNoAuth ? () => handleTestModel(model.id) : undefined} isTesting={testingModelId === model.id} isFree={model.isFree} + isDisabled={disabledModelsSet.has(model.id)} + onDisable={connections.length > 0 ? () => handleDisableModel(model.id) : undefined} + onEnable={connections.length > 0 ? () => handleEnableModel(model.id) : undefined} + isToggling={togglingModelId === model.id} /> ); })} @@ -644,6 +717,10 @@ export default function ProviderDetailPage() { isTesting={testingModelId === model.id} isCustom isFree={false} + isDisabled={disabledModelsSet.has(model.id)} + onDisable={connections.length > 0 ? () => handleDisableModel(model.id) : undefined} + onEnable={connections.length > 0 ? () => handleEnableModel(model.id) : undefined} + isToggling={togglingModelId === model.id} /> ))} @@ -1045,33 +1122,39 @@ export default function ProviderDetailPage() { ); } -function ModelRow({ model, fullModel, alias, copied, onCopy, testStatus, isCustom, isFree, onDeleteAlias, onTest, isTesting }) { - const borderColor = testStatus === "ok" +function ModelRow({ model, fullModel, alias, copied, onCopy, testStatus, isCustom, isFree, onDeleteAlias, onTest, isTesting, isDisabled, onDisable, onEnable, isToggling }) { + const borderColor = isDisabled + ? "border-black/[0.06] dark:border-white/[0.06]" + : testStatus === "ok" ? "border-green-500/40" : testStatus === "error" ? "border-red-500/40" : "border-border"; - const iconColor = testStatus === "ok" + const iconColor = isDisabled + ? undefined + : testStatus === "ok" ? "#22c55e" : testStatus === "error" ? "#ef4444" : undefined; return ( -
+
- {testStatus === "ok" ? "check_circle" : testStatus === "error" ? "cancel" : "smart_toy"} + {isDisabled ? "block" : testStatus === "ok" ? "check_circle" : testStatus === "error" ? "cancel" : "smart_toy"} -
- {fullModel} - {model.name && {model.name}} -
- {onTest && ( + {fullModel} + {isDisabled && ( + + disabled + + )} + {!isDisabled && onTest && (
- {isCustom && ( + {(onDisable || onEnable) && ( +
+ + + {isDisabled ? "Enable" : "Disable"} + +
+ )} + {isCustom && !isDisabled && (
- {onTest && ( + {!isDisabled && onTest && (
- {/* Delete button */} + {(onDisable || onEnable) && ( +
+ + + {isDisabled ? "Enable" : "Disable"} + +
+ )} +
@@ -1485,6 +1623,10 @@ CompatibleModelsSection.propTypes = { isActive: PropTypes.bool, })).isRequired, isAnthropic: PropTypes.bool, + disabledModels: PropTypes.arrayOf(PropTypes.string), + onDisableModel: PropTypes.func, + onEnableModel: PropTypes.func, + togglingModelId: PropTypes.string, }; function CooldownTimer({ until }) { @@ -2207,4 +2349,3 @@ AddCustomModelModal.propTypes = { onSave: PropTypes.func.isRequired, onClose: PropTypes.func.isRequired, }; - diff --git a/src/app/api/providers/[id]/route.js b/src/app/api/providers/[id]/route.js index 6ab517975..6d45092b3 100644 --- a/src/app/api/providers/[id]/route.js +++ b/src/app/api/providers/[id]/route.js @@ -4,6 +4,7 @@ import { getProxyPoolById, updateProviderConnection, deleteProviderConnection, + updateProviderDisabledModels, } from "@/models"; function normalizeProxyConfig(body = {}) { @@ -187,3 +188,97 @@ export async function DELETE(request, { params }) { return NextResponse.json({ error: "Failed to delete connection" }, { status: 500 }); } } + +// PATCH /api/providers/[id] - Provider-wide model disable mutations +// Body (exactly one variant): +// { disabledModels: string[] } — replace the full provider-wide disabled list +// { disableModel: string } — idempotent add of a single bare model ID +// { enableModel: string } — remove a single bare model ID from the disabled list +export async function PATCH(request, { params }) { + try { + const { id } = await params; + const body = await request.json(); + + const connection = await getProviderConnectionById(id); + if (!connection) { + return NextResponse.json({ error: "Connection not found" }, { status: 404 }); + } + + const providerId = connection.provider; + if (!providerId) { + return NextResponse.json({ error: "Connection has no provider" }, { status: 400 }); + } + + // Detect which variant(s) the caller provided + const hasDisabledModels = Object.prototype.hasOwnProperty.call(body, "disabledModels"); + const hasDisableModel = Object.prototype.hasOwnProperty.call(body, "disableModel"); + const hasEnableModel = Object.prototype.hasOwnProperty.call(body, "enableModel"); + const variantCount = [hasDisabledModels, hasDisableModel, hasEnableModel].filter(Boolean).length; + + if (variantCount === 0) { + return NextResponse.json( + { error: "Body must contain exactly one of: disabledModels (array), disableModel (string), enableModel (string)" }, + { status: 400 } + ); + } + if (variantCount > 1) { + return NextResponse.json( + { error: "Body must contain exactly one of: disabledModels, disableModel, enableModel — not multiple" }, + { status: 400 } + ); + } + + // Current provider-wide disabled list (any one connection is representative per Task 1 invariant) + const currentDisabled = Array.isArray(connection.providerSpecificData?.disabledModels) + ? [...connection.providerSpecificData.disabledModels] + : []; + + let nextDisabled; + + if (hasDisabledModels) { + if (!Array.isArray(body.disabledModels)) { + return NextResponse.json({ error: "disabledModels must be an array" }, { status: 400 }); + } + // Trim, reject blanks, deduplicate + nextDisabled = [...new Set( + body.disabledModels + .filter((m) => typeof m === "string") + .map((m) => m.trim()) + .filter((m) => m) + )]; + } else if (hasDisableModel) { + if (typeof body.disableModel !== "string") { + return NextResponse.json({ error: "disableModel must be a string" }, { status: 400 }); + } + const modelId = body.disableModel.trim(); + if (!modelId) { + return NextResponse.json({ error: "disableModel must not be blank" }, { status: 400 }); + } + // Idempotent add + nextDisabled = currentDisabled.includes(modelId) + ? currentDisabled + : [...currentDisabled, modelId]; + } else { + // hasEnableModel + if (typeof body.enableModel !== "string") { + return NextResponse.json({ error: "enableModel must be a string" }, { status: 400 }); + } + const modelId = body.enableModel.trim(); + if (!modelId) { + return NextResponse.json({ error: "enableModel must not be blank" }, { status: 400 }); + } + nextDisabled = currentDisabled.filter((m) => m !== modelId); + } + + const updatedCount = await updateProviderDisabledModels(providerId, nextDisabled); + + return NextResponse.json({ + providerId, + disabledModels: nextDisabled, + updatedConnections: updatedCount, + }); + } catch (error) { + console.log("Error updating provider disabled models:", error); + return NextResponse.json({ error: "Failed to update provider disabled models" }, { status: 500 }); + } +} diff --git a/src/app/api/v1/models/route.js b/src/app/api/v1/models/route.js index 5017eac51..bd641f16b 100644 --- a/src/app/api/v1/models/route.js +++ b/src/app/api/v1/models/route.js @@ -156,6 +156,11 @@ export async function GET() { ).trim(); const providerModels = PROVIDER_MODELS[staticAlias] || []; const enabledModels = conn?.providerSpecificData?.enabledModels; + const disabledModelsSet = new Set( + Array.isArray(conn?.providerSpecificData?.disabledModels) + ? conn.providerSpecificData.disabledModels + : [], + ); const hasExplicitEnabledModels = Array.isArray(enabledModels) && enabledModels.length > 0; const isCompatibleProvider = @@ -191,7 +196,8 @@ export async function GET() { } return modelId; }) - .filter((modelId) => typeof modelId === "string" && modelId.trim() !== ""); + .filter((modelId) => typeof modelId === "string" && modelId.trim() !== "") + .filter((modelId) => !disabledModelsSet.has(modelId)); for (const modelId of modelIds) { models.push({ diff --git a/src/lib/localDb.js b/src/lib/localDb.js index 08848083f..e2c45e90f 100644 --- a/src/lib/localDb.js +++ b/src/lib/localDb.js @@ -742,6 +742,39 @@ export async function reorderProviderConnections(providerId) { await safeWrite(db); } +/** + * Update provider-wide disabled models list. + * Writes disabledModels (array of bare model IDs) into providerSpecificData + * for ALL connections belonging to the given provider (active and inactive), then does a + * single safeWrite. + * + * @param {string} providerId - provider identifier (e.g. "gemini", "openai") + * @param {string[]} disabledModels - bare model IDs to disable, e.g. ["gemini-2.0-flash"] + * @returns {Promise} number of connections that were updated + */ +export async function updateProviderDisabledModels(providerId, disabledModels) { + const db = await getDb(); + const now = new Date().toISOString(); + let updatedCount = 0; + + db.data.providerConnections.forEach((c, index) => { + if (c.provider !== providerId) return; + + db.data.providerConnections[index] = { + ...c, + providerSpecificData: { + ...(c.providerSpecificData || {}), + disabledModels: Array.isArray(disabledModels) ? [...disabledModels] : [], + }, + updatedAt: now, + }; + updatedCount++; + }); + + await safeWrite(db); + return updatedCount; +} + // ============ Model Aliases ============ /** diff --git a/src/models/index.js b/src/models/index.js index e61129fe2..6664bdcca 100644 --- a/src/models/index.js +++ b/src/models/index.js @@ -5,6 +5,7 @@ export { createProviderConnection, updateProviderConnection, deleteProviderConnection, + updateProviderDisabledModels, getProviderNodes, getProviderNodeById, createProviderNode, diff --git a/src/shared/components/ModelSelectModal.js b/src/shared/components/ModelSelectModal.js index e56ae09ac..5de81de91 100644 --- a/src/shared/components/ModelSelectModal.js +++ b/src/shared/components/ModelSelectModal.js @@ -88,6 +88,12 @@ export default function ModelSelectModal({ const alias = PROVIDER_ID_TO_ALIAS[providerId] || providerId; const providerInfo = allProviders[providerId] || { name: providerId, color: "#666" }; const isCustomProvider = isOpenAICompatibleProvider(providerId) || isAnthropicCompatibleProvider(providerId); + const connection = activeProviders.find(p => p.provider === providerId); + const disabledModels = new Set( + Array.isArray(connection?.providerSpecificData?.disabledModels) + ? connection.providerSpecificData.disabledModels + : [] + ); if (providerInfo.passthroughModels) { const aliasModels = Object.entries(modelAliases) @@ -96,7 +102,8 @@ export default function ModelSelectModal({ id: fullModel.replace(`${alias}/`, ""), name: aliasName, value: fullModel, - })); + })) + .filter((model) => !disabledModels.has(model.id)); if (aliasModels.length > 0) { // Check for custom name from providerNodes (for compatible providers) @@ -112,7 +119,6 @@ export default function ModelSelectModal({ } } else if (isCustomProvider) { // Find connection object to get prefix synchronously without waiting for providerNodes fetch - const connection = activeProviders.find(p => p.provider === providerId); const matchedNode = providerNodes.find(node => node.id === providerId); const displayName = connection?.name || matchedNode?.name || providerInfo.name; const nodePrefix = connection?.providerSpecificData?.prefix || matchedNode?.prefix || providerId; @@ -125,7 +131,8 @@ export default function ModelSelectModal({ id: fullModel.replace(`${providerId}/`, ""), name: aliasName, value: `${nodePrefix}/${fullModel.replace(`${providerId}/`, "")}`, - })); + })) + .filter((model) => !disabledModels.has(model.id)); // Always show compatible providers that are connected, even with no aliases. // When no aliases exist, show a placeholder so users know it's available. @@ -160,10 +167,13 @@ export default function ModelSelectModal({ .map(([aliasName, fullModel]) => { const modelId = fullModel.replace(`${alias}/`, ""); return { id: modelId, name: aliasName, value: fullModel, isCustom: true }; - }); + }) + .filter((model) => !disabledModels.has(model.id)); const allModels = [ - ...hardcodedModels.map((m) => ({ id: m.id, name: m.name, value: `${alias}/${m.id}` })), + ...hardcodedModels + .filter((model) => !disabledModels.has(model.id)) + .map((m) => ({ id: m.id, name: m.name, value: `${alias}/${m.id}` })), ...customModels, ]; @@ -361,4 +371,3 @@ ModelSelectModal.propTypes = { title: PropTypes.string, modelAliases: PropTypes.object, }; - diff --git a/src/sse/handlers/chat.js b/src/sse/handlers/chat.js index b2e64a1e6..aaf7beeab 100644 --- a/src/sse/handlers/chat.js +++ b/src/sse/handlers/chat.js @@ -8,7 +8,7 @@ import { isValidApiKey, } from "../services/auth.js"; import { cacheClaudeHeaders } from "open-sse/utils/claudeHeaderCache.js"; -import { getSettings } from "@/lib/localDb"; +import { getSettings, getProviderConnections } from "@/lib/localDb"; import { getModelInfo, getComboModels } from "../services/model.js"; import { handleChatCore } from "open-sse/handlers/chatCore.js"; import { errorResponse, unavailableResponse } from "open-sse/utils/error.js"; @@ -138,6 +138,16 @@ async function handleSingleModelChat(body, modelStr, clientRawRequest = null, re const { provider, model } = modelInfo; + // Guard: reject disabled models before routing (provider-wide, covers aliases) + const activeConnections = await getProviderConnections({ provider, isActive: true }); + if (activeConnections.length > 0) { + const disabledModels = activeConnections[0].providerSpecificData?.disabledModels; + if (Array.isArray(disabledModels) && disabledModels.includes(model)) { + log.warn("CHAT", `Model ${model} is disabled for provider: ${provider}`); + return errorResponse(HTTP_STATUS.NOT_FOUND, `No active credentials for provider: ${provider}`); + } + } + // Log model routing (alias → actual model) if (modelStr !== `${provider}/${model}`) { log.info("ROUTING", `${modelStr} → ${provider}/${model}`);