diff --git a/package-lock.json b/package-lock.json index e7621361..845f2858 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "node-banana", - "version": "1.1.1", + "version": "1.1.2", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "node-banana", - "version": "1.1.1", + "version": "1.1.2", "dependencies": { "@ai-sdk/google": "^3.0.13", "@ai-sdk/react": "^3.0.51", diff --git a/src/app/api/generate/providers/gemini.ts b/src/app/api/generate/providers/gemini.ts index 6c333b9c..2f1b10fa 100644 --- a/src/app/api/generate/providers/gemini.ts +++ b/src/app/api/generate/providers/gemini.ts @@ -18,6 +18,25 @@ export const MODEL_MAP: Record = { "nano-banana-2": "gemini-3.1-flash-image-preview", }; +/** + * Convert a base64 data URL image to Gemini inlineData format + */ +function imageToInlineData( + requestId: string, + image: string, + label: string +): { inlineData: { mimeType: string; data: string } } { + if (image.includes("base64,")) { + const [header, data] = image.split("base64,"); + const mimeMatch = header.match(/data:([^;]+)/); + const mimeType = mimeMatch ? 
mimeMatch[1] : "image/png"; + console.log(`[API:${requestId}] Image ${label}: ${mimeType}, ${(data.length / 1024).toFixed(1)}KB`); + return { inlineData: { mimeType, data } }; + } + console.log(`[API:${requestId}] Image ${label}: raw, ${(image.length / 1024).toFixed(1)}KB`); + return { inlineData: { mimeType: "image/png", data: image } }; +} + /** * Generate image using Gemini API (legacy/default path) */ @@ -30,37 +49,34 @@ export async function generateWithGemini( aspectRatio?: string, resolution?: string, useGoogleSearch?: boolean, - useImageSearch?: boolean + useImageSearch?: boolean, + multimodalParts?: Array<{ type: string; value: string; name?: string }> ): Promise> { - console.log(`[API:${requestId}] Gemini generation - Model: ${model}, Images: ${images?.length || 0}, Prompt: ${prompt?.length || 0} chars`); - - // Extract base64 data and MIME types from data URLs - const imageData = (images || []).map((image, idx) => { - if (image.includes("base64,")) { - const [header, data] = image.split("base64,"); - // Extract MIME type from header (e.g., "data:image/png;" -> "image/png") - const mimeMatch = header.match(/data:([^;]+)/); - const mimeType = mimeMatch ? 
mimeMatch[1] : "image/png"; - console.log(`[API:${requestId}] Image ${idx + 1}: ${mimeType}, ${(data.length / 1024).toFixed(1)}KB`); - return { data, mimeType }; - } - console.log(`[API:${requestId}] Image ${idx + 1}: raw, ${(image.length / 1024).toFixed(1)}KB`); - return { data: image, mimeType: "image/png" }; - }); + console.log(`[API:${requestId}] Gemini generation - Model: ${model}, Images: ${images?.length || 0}, Prompt: ${prompt?.length || 0} chars, Parts: ${multimodalParts?.length || 0}`); // Initialize Gemini client const ai = new GoogleGenAI({ apiKey }); - // Build request parts array with prompt and all images - const requestParts: Array<{ text: string } | { inlineData: { mimeType: string; data: string } }> = [ - { text: prompt }, - ...imageData.map(({ data, mimeType }) => ({ - inlineData: { - mimeType, - data, - }, - })), - ]; + // Build request parts array — use multimodal parts if provided, otherwise legacy prompt+images + type GeminiPart = { text: string } | { inlineData: { mimeType: string; data: string } }; + let requestParts: GeminiPart[]; + + if (multimodalParts && multimodalParts.length > 0) { + // Build interleaved multimodal request from image variable parts + requestParts = multimodalParts.map((part) => { + if (part.type === "image" && part.value) { + return imageToInlineData(requestId, part.value, part.name || "var"); + } + return { text: part.value }; + }); + } else { + // Legacy: prompt text + all images appended + const imageData = (images || []).map((image, idx) => imageToInlineData(requestId, image, `${idx + 1}`)); + requestParts = [ + { text: prompt }, + ...imageData, + ]; + } // Build config object based on model capabilities const config: Record = { diff --git a/src/app/api/generate/route.ts b/src/app/api/generate/route.ts index 29d15a54..65624473 100644 --- a/src/app/api/generate/route.ts +++ b/src/app/api/generate/route.ts @@ -100,7 +100,8 @@ export async function POST(request: NextRequest) { parameters, dynamicInputs, mediaType, - 
} = body; + parts, + } = body as MultiProviderGenerateRequest & { parts?: Array<{ type: string; value: string; name?: string }> }; // Prompt is required unless: // - Provided via dynamicInputs @@ -524,7 +525,8 @@ export async function POST(request: NextRequest) { aspectRatio, resolution, useGoogleSearch, - useImageSearch + useImageSearch, + parts as Array<{ type: string; value: string; name?: string }> | undefined ); } catch (error) { // Extract error information diff --git a/src/app/api/llm/route.ts b/src/app/api/llm/route.ts index 260614ef..15f5b3da 100644 --- a/src/app/api/llm/route.ts +++ b/src/app/api/llm/route.ts @@ -5,6 +5,15 @@ import { logger } from "@/utils/logger"; export const maxDuration = 60; // 1 minute timeout +// Convert a base64 data URL image to Gemini inlineData format +function imageDataToInlinePart(img: string): { inlineData: { mimeType: string; data: string } } { + const matches = img.match(/^data:(.+?);base64,(.+)$/); + if (matches) { + return { inlineData: { mimeType: matches[1], data: matches[2] } }; + } + return { inlineData: { mimeType: "image/png", data: img } }; +} + // Generate a unique request ID for tracking function generateRequestId(): string { return `llm-${Date.now()}-${Math.random().toString(36).substring(2, 9)}`; @@ -36,7 +45,8 @@ async function generateWithGoogle( maxTokens: number, images?: string[], requestId?: string, - userApiKey?: string | null + userApiKey?: string | null, + parts?: Array<{ type: string; value: string; name?: string }> ): Promise { // User-provided key takes precedence over env variable const apiKey = userApiKey || process.env.GEMINI_API_KEY; @@ -55,31 +65,24 @@ async function generateWithGoogle( maxTokens, imageCount: images?.length || 0, promptLength: prompt.length, + partsCount: parts?.length || 0, }); - // Build multimodal content if images are provided - let contents: string | Array<{ inlineData: { mimeType: string; data: string } } | { text: string }>; - if (images && images.length > 0) { + // 
Build multimodal content + type GeminiPart = { inlineData: { mimeType: string; data: string } } | { text: string }; + let contents: string | GeminiPart[]; + + if (parts && parts.length > 0) { + // Interleaved multimodal parts from image variable resolution + contents = parts.map((part): GeminiPart => { + if (part.type === "image" && part.value) { + return imageDataToInlinePart(part.value); + } + return { text: part.value }; + }); + } else if (images && images.length > 0) { contents = [ - ...images.map((img) => { - // Extract base64 data and mime type from data URL - const matches = img.match(/^data:(.+?);base64,(.+)$/); - if (matches) { - return { - inlineData: { - mimeType: matches[1], - data: matches[2], - }, - }; - } - // Fallback: assume PNG if no data URL prefix - return { - inlineData: { - mimeType: "image/png", - data: img, - }, - }; - }), + ...images.map((img) => imageDataToInlinePart(img)), { text: prompt }, ]; } else { @@ -298,14 +301,15 @@ export async function POST(request: NextRequest) { const openaiApiKey = request.headers.get("X-OpenAI-API-Key"); const anthropicApiKey = request.headers.get("X-Anthropic-API-Key"); - const body: LLMGenerateRequest = await request.json(); + const body = await request.json() as LLMGenerateRequest & { parts?: Array<{ type: string; value: string; name?: string }> }; const { prompt, images, provider, model, temperature = 0.7, - maxTokens = 1024 + maxTokens = 1024, + parts, } = body; logger.info('api.llm', 'LLM generation request received', { @@ -330,7 +334,7 @@ export async function POST(request: NextRequest) { let text: string; if (provider === "google") { - text = await generateWithGoogle(prompt, model, temperature, maxTokens, images, requestId, geminiApiKey); + text = await generateWithGoogle(prompt, model, temperature, maxTokens, images, requestId, geminiApiKey, parts); } else if (provider === "openai") { text = await generateWithOpenAI(prompt, model, temperature, maxTokens, images, requestId, openaiApiKey); } else if 
(provider === "anthropic") { diff --git a/src/components/ConnectionDropMenu.tsx b/src/components/ConnectionDropMenu.tsx index e2730b35..c9c11acb 100644 --- a/src/components/ConnectionDropMenu.tsx +++ b/src/components/ConnectionDropMenu.tsx @@ -42,6 +42,15 @@ const IMAGE_TARGET_OPTIONS: MenuOption[] = [ ), }, + { + type: "inpaint", + label: "Inpaint", + icon: ( + + + + ), + }, { type: "splitGrid", label: "Split Grid Node", @@ -109,6 +118,15 @@ const IMAGE_TARGET_OPTIONS: MenuOption[] = [ ), }, + { + type: "promptConstructor", + label: "Prompt Constructor", + icon: ( + + + + ), + }, ]; const TEXT_TARGET_OPTIONS: MenuOption[] = [ @@ -254,6 +272,15 @@ const IMAGE_SOURCE_OPTIONS: MenuOption[] = [ ), }, + { + type: "inpaint", + label: "Inpaint", + icon: ( + + + + ), + }, { type: "router", label: "Router", diff --git a/src/components/Header.tsx b/src/components/Header.tsx index baf1ebda..10a54143 100644 --- a/src/components/Header.tsx +++ b/src/components/Header.tsx @@ -68,6 +68,7 @@ export function Header() { isSaving, setWorkflowMetadata, saveToFile, + saveWorkflow, loadWorkflow, previousWorkflowSnapshot, revertToSnapshot, @@ -83,6 +84,7 @@ export function Header() { isSaving: state.isSaving, setWorkflowMetadata: state.setWorkflowMetadata, saveToFile: state.saveToFile, + saveWorkflow: state.saveWorkflow, loadWorkflow: state.loadWorkflow, previousWorkflowSnapshot: state.previousWorkflowSnapshot, revertToSnapshot: state.revertToSnapshot, @@ -94,6 +96,7 @@ export function Header() { const [showProjectModal, setShowProjectModal] = useState(false); const [projectModalMode, setProjectModalMode] = useState<"new" | "settings">("new"); const fileInputRef = useRef(null); + const uploadInputRef = useRef(null); const isProjectConfigured = !!workflowName; const canSave = !!(workflowId && workflowName && saveDirectoryPath); @@ -119,6 +122,35 @@ export function Header() { fileInputRef.current?.click(); }; + const handleDownloadWorkflow = () => { + saveWorkflow(workflowName || 
undefined); + }; + + const handleUploadWorkflow = () => { + uploadInputRef.current?.click(); + }; + + const handleUploadFileChange = (e: React.ChangeEvent) => { + const file = e.target.files?.[0]; + if (!file) return; + + const reader = new FileReader(); + reader.onload = async (event) => { + try { + const workflow = JSON.parse(event.target?.result as string) as WorkflowFile; + if (workflow.version && workflow.nodes && workflow.edges) { + await loadWorkflow(workflow); + } else { + alert("Invalid workflow file format"); + } + } catch { + alert("Failed to parse workflow file"); + } + }; + reader.readAsText(file); + e.target.value = ""; + }; + const handleFileChange = (e: React.ChangeEvent) => { const file = e.target.files?.[0]; if (!file) return; @@ -189,6 +221,49 @@ export function Header() { } }, [revertToSnapshot]); + const clientWorkflowButtons = ( +
+ + +
+ ); + const settingsButtons = (
{settingsButtons} + {clientWorkflowButtons} )} diff --git a/src/components/InpaintMaskModal.tsx b/src/components/InpaintMaskModal.tsx new file mode 100644 index 00000000..782d062b --- /dev/null +++ b/src/components/InpaintMaskModal.tsx @@ -0,0 +1,299 @@ +"use client"; + +import { useCallback, useEffect, useRef, useState } from "react"; +import { Stage, Layer, Image as KonvaImage, Line } from "react-konva"; +import Konva from "konva"; + +interface InpaintMaskModalProps { + isOpen: boolean; + sourceImage: string; + existingMask: string | null; + brushSize: number; + onClose: () => void; + onSave: (maskDataUrl: string, brushSize: number) => void; +} + +const BRUSH_SIZES = [10, 20, 40, 60, 100]; + +export function InpaintMaskModal({ + isOpen, + sourceImage, + existingMask, + brushSize: initialBrushSize, + onClose, + onSave, +}: InpaintMaskModalProps) { + const stageRef = useRef(null); + const containerRef = useRef(null); + const [image, setImage] = useState(null); + const [maskImage, setMaskImage] = useState(null); + const [stageSize, setStageSize] = useState({ width: 800, height: 600 }); + const [scale, setScale] = useState(1); + const isDrawingRef = useRef(false); + const [tool, setTool] = useState<"brush" | "eraser">("brush"); + const [brushSize, setBrushSize] = useState(initialBrushSize); + const [lines, setLines] = useState<{ points: number[]; stroke: string; strokeWidth: number }[]>([]); + const [cursorPos, setCursorPos] = useState<{ x: number; y: number } | null>(null); + + // Load source image + useEffect(() => { + if (!sourceImage) return; + const img = new window.Image(); + img.onload = () => { + setImage(img); + if (containerRef.current) { + const cw = containerRef.current.clientWidth - 40; + const ch = containerRef.current.clientHeight - 40; + const s = Math.min(cw / img.width, ch / img.height, 1); + setScale(s); + setStageSize({ width: Math.round(img.width * s), height: Math.round(img.height * s) }); + } + }; + img.src = sourceImage; + }, [sourceImage]); + 
+ // Load existing mask + useEffect(() => { + if (!existingMask) { + setMaskImage(null); + setLines([]); + return; + } + const img = new window.Image(); + img.onload = () => setMaskImage(img); + img.src = existingMask; + }, [existingMask]); + + const getPointerPos = useCallback(() => { + const stage = stageRef.current; + if (!stage) return null; + const pos = stage.getPointerPosition(); + if (!pos) return null; + return { x: pos.x / scale, y: pos.y / scale }; + }, [scale]); + + const handleMouseDown = useCallback((e: Konva.KonvaEventObject) => { + // Only draw with left mouse button (button 0) or touch + if (e.evt instanceof MouseEvent && e.evt.button !== 0) return; + const pos = getPointerPos(); + if (!pos) return; + isDrawingRef.current = true; + const stroke = tool === "brush" ? "#ffffff" : "#000000"; + setLines((prev) => [...prev, { points: [pos.x, pos.y], stroke, strokeWidth: brushSize }]); + }, [getPointerPos, tool, brushSize]); + + const handleMouseMove = useCallback(() => { + const pos = getPointerPos(); + if (!pos) return; + setCursorPos({ x: pos.x * scale, y: pos.y * scale }); + if (!isDrawingRef.current) return; + setLines((prev) => { + const updated = [...prev]; + const last = updated[updated.length - 1]; + if (last) { + updated[updated.length - 1] = { ...last, points: [...last.points, pos.x, pos.y] }; + } + return updated; + }); + }, [getPointerPos, scale]); + + const handleMouseUp = useCallback(() => { + isDrawingRef.current = false; + }, []); + + const handleClear = useCallback(() => { + setLines([]); + setMaskImage(null); + }, []); + + const handleSave = useCallback(() => { + if (!image) return; + + // Render mask to a separate canvas (black background with white brush strokes) + const tempStage = new Konva.Stage({ + container: document.createElement("div"), + width: image.width, + height: image.height, + }); + const tempLayer = new Konva.Layer(); + tempStage.add(tempLayer); + + // Black background + tempLayer.add( + new Konva.Rect({ + x: 0, + y: 0, 
+ width: image.width, + height: image.height, + fill: "#000000", + }) + ); + + // Draw existing mask if present + if (maskImage) { + tempLayer.add( + new Konva.Image({ + image: maskImage, + x: 0, + y: 0, + width: image.width, + height: image.height, + }) + ); + } + + // Draw brush strokes + for (const line of lines) { + tempLayer.add( + new Konva.Line({ + points: line.points, + stroke: line.stroke, + strokeWidth: line.strokeWidth, + lineCap: "round", + lineJoin: "round", + globalCompositeOperation: line.stroke === "#000000" ? "destination-out" : "source-over", + }) + ); + } + + tempLayer.draw(); + const maskDataUrl = tempStage.toDataURL({ mimeType: "image/png" }); + tempStage.destroy(); + + onSave(maskDataUrl, brushSize); + }, [image, maskImage, lines, brushSize, onSave]); + + // Keyboard shortcuts + useEffect(() => { + const handleKeyDown = (e: KeyboardEvent) => { + if (e.key === "Escape") onClose(); + if (e.key === "b") setTool("brush"); + if (e.key === "e") setTool("eraser"); + if (e.key === "[") setBrushSize((prev) => BRUSH_SIZES[Math.max(0, BRUSH_SIZES.indexOf(prev) - 1)] ?? prev); + if (e.key === "]") setBrushSize((prev) => BRUSH_SIZES[Math.min(BRUSH_SIZES.length - 1, BRUSH_SIZES.indexOf(prev) + 1)] ?? prev); + }; + window.addEventListener("keydown", handleKeyDown); + return () => window.removeEventListener("keydown", handleKeyDown); + }, [onClose]); + + if (!isOpen) return null; + + const hasStrokes = lines.length > 0 || maskImage !== null; + + return ( +
+
e.stopPropagation()} + > + {/* Toolbar */} +
+ Inpaint Mask +
+ + {/* Tool selection */} + + + +
+ + {/* Brush size */} + Size [ ]: + {BRUSH_SIZES.map((s) => ( + + ))} + +
+ + + + +
+ + {/* Canvas */} +
+ { isDrawingRef.current = false; setCursorPos(null); }} + onTouchStart={handleMouseDown} + onTouchMove={handleMouseMove} + onTouchEnd={handleMouseUp} + style={{ cursor: "none" }} + > + + {/* Source image at lower opacity */} + {image && } + + {/* Existing mask overlay */} + {maskImage && } + + {/* Drawn lines */} + {lines.map((line, i) => ( + + ))} + + + + {/* Custom cursor */} + {cursorPos && ( +
+ )} +
+ +
+ Hold mouse button to paint. White = area to regenerate. Shortcuts: B=brush, E=eraser, [/]=brush size, Esc=cancel +
+
+
+ ); +} diff --git a/src/components/SplitGridSettingsModal.tsx b/src/components/SplitGridSettingsModal.tsx index 04267b53..81af8bb8 100644 --- a/src/components/SplitGridSettingsModal.tsx +++ b/src/components/SplitGridSettingsModal.tsx @@ -1,6 +1,6 @@ "use client"; -import { useState, useCallback } from "react"; +import { useState, useCallback, useEffect } from "react"; import { createPortal } from "react-dom"; import { useWorkflowStore } from "@/store/workflowStore"; import { SplitGridNodeData, AspectRatio, Resolution, ModelType } from "@/types"; @@ -12,15 +12,64 @@ interface SplitGridSettingsModalProps { } const LAYOUT_OPTIONS = [ + { rows: 1, cols: 2 }, + { rows: 1, cols: 3 }, + { rows: 1, cols: 4 }, + { rows: 2, cols: 1 }, { rows: 2, cols: 2 }, - { rows: 1, cols: 5 }, { rows: 2, cols: 3 }, { rows: 3, cols: 2 }, { rows: 2, cols: 4 }, { rows: 3, cols: 3 }, - { rows: 2, cols: 5 }, + { rows: 3, cols: 4 }, + { rows: 4, cols: 3 }, + { rows: 4, cols: 4 }, ] as const; +// Cell aspect ratio options for splitting +const CELL_ASPECT_RATIOS = [ + { label: "Auto", value: 0 }, + { label: "1:1", value: 1 }, + { label: "16:9", value: 16 / 9 }, + { label: "9:16", value: 9 / 16 }, + { label: "4:3", value: 4 / 3 }, + { label: "3:4", value: 3 / 4 }, + { label: "3:2", value: 3 / 2 }, + { label: "2:3", value: 2 / 3 }, + { label: "21:9", value: 21 / 9 }, +]; + +/** + * Given a source image aspect ratio and desired cell aspect ratio, + * find the best rows x cols that produce cells closest to the target. 
+ */ +function findBestGrid( + sourceWidth: number, + sourceHeight: number, + cellAR: number, + maxDim: number = 10 +): { rows: number; cols: number } { + const sourceAR = sourceWidth / sourceHeight; + let bestRows = 1; + let bestCols = 1; + let bestError = Infinity; + + for (let r = 1; r <= maxDim; r++) { + for (let c = 1; c <= maxDim; c++) { + if (r === 1 && c === 1) continue; + const actualCellAR = (sourceAR * r) / c; + const error = Math.abs(actualCellAR / cellAR - 1); + if (error < bestError) { + bestError = error; + bestRows = r; + bestCols = c; + } + } + } + + return { rows: bestRows, cols: bestCols }; +} + const BASE_ASPECT_RATIOS: AspectRatio[] = ["1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"]; const EXTENDED_ASPECT_RATIOS: AspectRatio[] = ["1:1", "1:4", "1:8", "2:3", "3:2", "3:4", "4:1", "4:3", "4:5", "5:4", "8:1", "9:16", "16:9", "21:9"]; const RESOLUTIONS_PRO: Resolution[] = ["1K", "2K", "4K"]; @@ -32,8 +81,7 @@ const MODELS: { value: ModelType; label: string }[] = [ ]; const findLayoutIndex = (rows: number, cols: number): number => { - const idx = LAYOUT_OPTIONS.findIndex(l => l.rows === rows && l.cols === cols); - return idx >= 0 ? idx : 2; // default to 2x3 + return LAYOUT_OPTIONS.findIndex(l => l.rows === rows && l.cols === cols); }; export function SplitGridSettingsModal({ @@ -43,17 +91,41 @@ export function SplitGridSettingsModal({ }: SplitGridSettingsModalProps) { const { updateNodeData, addNode, onConnect, addEdgeWithType, getNodeById } = useWorkflowStore(); + const initialPresetIdx = findLayoutIndex(nodeData.gridRows, nodeData.gridCols); const [selectedLayoutIndex, setSelectedLayoutIndex] = useState( - findLayoutIndex(nodeData.gridRows, nodeData.gridCols) + initialPresetIdx >= 0 ? 
initialPresetIdx : -1 ); + const [customRows, setCustomRows] = useState(nodeData.gridRows); + const [customCols, setCustomCols] = useState(nodeData.gridCols); + const [cellAspectRatio, setCellAspectRatio] = useState(0); // 0 = auto + const [sourceDims, setSourceDims] = useState<{ width: number; height: number } | null>(null); const [defaultPrompt, setDefaultPrompt] = useState(nodeData.defaultPrompt); + + // Load source image dimensions + useEffect(() => { + if (!nodeData.sourceImage) return; + const img = new Image(); + img.onload = () => setSourceDims({ width: img.width, height: img.height }); + img.src = nodeData.sourceImage; + }, [nodeData.sourceImage]); + + // Auto-calculate grid when cell aspect ratio changes + const handleCellAspectRatioChange = useCallback((arValue: number) => { + setCellAspectRatio(arValue); + if (arValue === 0 || !sourceDims) return; + const best = findBestGrid(sourceDims.width, sourceDims.height, arValue); + setCustomRows(best.rows); + setCustomCols(best.cols); + setSelectedLayoutIndex(findLayoutIndex(best.rows, best.cols)); + }, [sourceDims]); const [aspectRatio, setAspectRatio] = useState(nodeData.generateSettings.aspectRatio); const [resolution, setResolution] = useState(nodeData.generateSettings.resolution); const [model, setModel] = useState(nodeData.generateSettings.model); const [useGoogleSearch, setUseGoogleSearch] = useState(nodeData.generateSettings.useGoogleSearch); const [useImageSearch, setUseImageSearch] = useState(nodeData.generateSettings.useImageSearch); - const { rows, cols } = LAYOUT_OPTIONS[selectedLayoutIndex]; + const rows = selectedLayoutIndex >= 0 ? LAYOUT_OPTIONS[selectedLayoutIndex].rows : customRows; + const cols = selectedLayoutIndex >= 0 ? LAYOUT_OPTIONS[selectedLayoutIndex].cols : customCols; const targetCount = rows * cols; const isNanoBananaPro = model === "nano-banana-pro" || model === "nano-banana-2"; const aspectRatios = model === "nano-banana-2" ? 
EXTENDED_ASPECT_RATIOS : BASE_ASPECT_RATIOS; @@ -200,15 +272,19 @@ export function SplitGridSettingsModal({ -
+
{LAYOUT_OPTIONS.map((layout, index) => { const count = layout.rows * layout.cols; const isSelected = selectedLayoutIndex === index; return (
{layout.rows}x{layout.cols}
-
{count}
); })}
-

- Grid will be split into {rows}x{cols} = {targetCount} images -

+ + {/* Cell aspect ratio selector */} +
+ Cell ratio: +
+ {CELL_ASPECT_RATIOS.map((ar) => ( + + ))} +
+ {cellAspectRatio !== 0 && !sourceDims && ( + Connect image for auto-grid + )} +
+ + {/* Custom rows/cols input */} +
+ Custom: +
+ { + const v = Math.max(1, Math.min(10, parseInt(e.target.value) || 1)); + setCustomRows(v); + setSelectedLayoutIndex(findLayoutIndex(v, customCols)); + }} + className="w-14 px-2 py-1 bg-neutral-900 border border-neutral-600 rounded text-neutral-100 text-sm text-center focus:outline-none focus:border-neutral-500" + /> + x + { + const v = Math.max(1, Math.min(10, parseInt(e.target.value) || 1)); + setCustomCols(v); + setSelectedLayoutIndex(findLayoutIndex(customRows, v)); + }} + className="w-14 px-2 py-1 bg-neutral-900 border border-neutral-600 rounded text-neutral-100 text-sm text-center focus:outline-none focus:border-neutral-500" + /> +
+ + = {targetCount} images + {sourceDims && ( + + (cell: {Math.round(sourceDims.width / cols)}x{Math.round(sourceDims.height / rows)}px) + + )} + +
{/* Default prompt */} diff --git a/src/components/WorkflowCanvas.tsx b/src/components/WorkflowCanvas.tsx index 209485de..7f8d030f 100644 --- a/src/components/WorkflowCanvas.tsx +++ b/src/components/WorkflowCanvas.tsx @@ -45,6 +45,7 @@ import { RouterNode, SwitchNode, ConditionalSwitchNode, + InpaintNode, } from "./nodes"; // Lazy-load GLBViewerNode to avoid bundling three.js for users who don't use 3D nodes @@ -98,6 +99,7 @@ const nodeTypes: NodeTypes = { switch: SwitchNode, conditionalSwitch: ConditionalSwitchNode, glbViewer: GLBViewerNode, + inpaint: InpaintNode, }; const edgeTypes: EdgeTypes = { @@ -144,7 +146,7 @@ const getNodeHandles = (nodeType: string): { inputs: string[]; outputs: string[] case "array": return { inputs: ["text"], outputs: ["text"] }; case "promptConstructor": - return { inputs: ["text"], outputs: ["text"] }; + return { inputs: ["text", "image"], outputs: ["text"] }; case "nanoBanana": return { inputs: ["image", "text"], outputs: ["image"] }; case "generateVideo": @@ -1473,6 +1475,7 @@ export function WorkflowCanvas() { switch: { width: 220, height: 120 }, conditionalSwitch: { width: 260, height: 180 }, glbViewer: { width: 360, height: 380 }, + inpaint: { width: 320, height: 340 }, }; const dims = defaultDimensions[nodeType]; addNode(nodeType, { x: centerX - dims.width / 2, y: centerY - dims.height / 2 }); @@ -1980,7 +1983,7 @@ export function WorkflowCanvas() { ? false : canvasNavigationSettings.panMode === "always" ? false - : isMacOS && !isModalOpen + : !isModalOpen } selectionKeyCode={ isModalOpen ? null @@ -1995,7 +1998,7 @@ export function WorkflowCanvas() { ? true : canvasNavigationSettings.panMode === "middleMouse" ? 
[2] - : !isMacOS + : [1, 2] } selectNodesOnDrag={false} nodeDragThreshold={5} @@ -2080,6 +2083,8 @@ export function WorkflowCanvas() { return "#06b6d4"; // cyan-500 (distinct from Router gray and Switch violet) case "glbViewer": return "#0ea5e9"; // sky-500 (3D viewport) + case "inpaint": + return "#f472b6"; // pink-400 (mask/edit) default: return "#94a3b8"; } diff --git a/src/components/__tests__/ConnectionDropMenu.test.tsx b/src/components/__tests__/ConnectionDropMenu.test.tsx index 5e4ae0ca..c4e5e5e8 100644 --- a/src/components/__tests__/ConnectionDropMenu.test.tsx +++ b/src/components/__tests__/ConnectionDropMenu.test.tsx @@ -218,7 +218,7 @@ describe("ConnectionDropMenu", () => { fireEvent.keyDown(document, { key: "ArrowUp" }); // Last item should now be highlighted - const lastButton = screen.getByText("Switch").closest("button"); + const lastButton = screen.getByText("Prompt Constructor").closest("button"); expect(lastButton).toHaveClass("bg-neutral-700"); }); diff --git a/src/components/__tests__/SplitGridSettingsModal.test.tsx b/src/components/__tests__/SplitGridSettingsModal.test.tsx index 055d43f2..7197d3ce 100644 --- a/src/components/__tests__/SplitGridSettingsModal.test.tsx +++ b/src/components/__tests__/SplitGridSettingsModal.test.tsx @@ -78,13 +78,18 @@ describe("SplitGridSettingsModal", () => { /> ); + expect(screen.getByText("1x2")).toBeInTheDocument(); + expect(screen.getByText("1x3")).toBeInTheDocument(); + expect(screen.getByText("1x4")).toBeInTheDocument(); + expect(screen.getByText("2x1")).toBeInTheDocument(); expect(screen.getByText("2x2")).toBeInTheDocument(); - expect(screen.getByText("1x5")).toBeInTheDocument(); expect(screen.getByText("2x3")).toBeInTheDocument(); expect(screen.getByText("3x2")).toBeInTheDocument(); expect(screen.getByText("2x4")).toBeInTheDocument(); expect(screen.getByText("3x3")).toBeInTheDocument(); - expect(screen.getByText("2x5")).toBeInTheDocument(); + expect(screen.getByText("3x4")).toBeInTheDocument(); + 
expect(screen.getByText("4x3")).toBeInTheDocument(); + expect(screen.getByText("4x4")).toBeInTheDocument(); }); it("should highlight selected layout", () => { @@ -116,8 +121,8 @@ describe("SplitGridSettingsModal", () => { const threeByThreeButton = buttons.find(btn => btn.textContent?.includes("3x3")); fireEvent.click(threeByThreeButton!); - // The grid description should update - expect(screen.getByText(/3x3 = 9 images/)).toBeInTheDocument(); + // The custom input area should show target count + expect(screen.getByText(/= 9 images/)).toBeInTheDocument(); }); it("should display grid dimensions description", () => { @@ -130,7 +135,7 @@ describe("SplitGridSettingsModal", () => { ); // Default is 2x3 - expect(screen.getByText(/2x3 = 6 images/)).toBeInTheDocument(); + expect(screen.getByText(/= 6 images/)).toBeInTheDocument(); }); it("should allow selecting 3x2 layout (6 images, portrait orientation)", () => { @@ -148,7 +153,7 @@ describe("SplitGridSettingsModal", () => { fireEvent.click(threeByTwoButton!); // Should show 3x2 = 6 images - expect(screen.getByText(/3x2 = 6 images/)).toBeInTheDocument(); + expect(screen.getByText(/= 6 images/)).toBeInTheDocument(); }); }); @@ -453,7 +458,7 @@ describe("SplitGridSettingsModal", () => { // The modal renders via createPortal to document.body, so use document.querySelectorAll const gridPreviews = document.querySelectorAll(".aspect-video"); - expect(gridPreviews.length).toBe(7); // 7 layout options + expect(gridPreviews.length).toBe(12); // 12 layout options }); }); @@ -480,7 +485,7 @@ describe("SplitGridSettingsModal", () => { ); // Check target count - expect(screen.getByText(/3x3 = 9 images/)).toBeInTheDocument(); + expect(screen.getByText(/= 9 images/)).toBeInTheDocument(); // Check prompt const textarea = screen.getByPlaceholderText(/Enter prompt that will be applied/); diff --git a/src/components/modals/PromptConstructorEditorModal.tsx b/src/components/modals/PromptConstructorEditorModal.tsx index 1a0ec22f..22eb1278 
100644 --- a/src/components/modals/PromptConstructorEditorModal.tsx +++ b/src/components/modals/PromptConstructorEditorModal.tsx @@ -240,7 +240,7 @@ export const PromptConstructorEditorModal: React.FC 0 && (
{ e.preventDefault(); + e.stopPropagation(); handleAutocompleteSelect(variable.name); }} className={`w-full px-3 py-2 text-left text-[11px] flex flex-col gap-0.5 transition-colors ${ @@ -259,9 +260,9 @@ export const PromptConstructorEditorModal: React.FC -
@{variable.name}
+
@{variable.name}
- {variable.value || "(empty)"} + {variable.variableType === "image" ? "(image)" : variable.value || "(empty)"}
))} diff --git a/src/components/nodes/ImageInputNode.tsx b/src/components/nodes/ImageInputNode.tsx index 00a27203..4ad4ddac 100644 --- a/src/components/nodes/ImageInputNode.tsx +++ b/src/components/nodes/ImageInputNode.tsx @@ -1,6 +1,7 @@ "use client"; -import { useCallback, useRef } from "react"; +import { useCallback, useRef, useState } from "react"; +import { createPortal } from "react-dom"; import { Handle, Position, NodeProps, Node } from "@xyflow/react"; import { BaseNode } from "./BaseNode"; import { useCommentNavigation } from "@/hooks/useCommentNavigation"; @@ -17,6 +18,26 @@ export function ImageInputNode({ id, data, selected }: NodeProps state.updateNodeData); const fileInputRef = useRef(null); + // Variable naming state + const [showVarDialog, setShowVarDialog] = useState(false); + const [varNameInput, setVarNameInput] = useState(nodeData.variableName || ""); + + const handleSaveVariableName = useCallback(() => { + updateNodeData(id, { variableName: varNameInput || undefined }); + setShowVarDialog(false); + }, [id, varNameInput, updateNodeData]); + + const handleClearVariableName = useCallback(() => { + setVarNameInput(""); + updateNodeData(id, { variableName: undefined }); + setShowVarDialog(false); + }, [id, updateNodeData]); + + const handleVariableNameChange = useCallback((e: React.ChangeEvent) => { + const sanitized = e.target.value.replace(/[^a-zA-Z0-9_]/g, "").slice(0, 30); + setVarNameInput(sanitized); + }, []); + const handleFileChange = useCallback( (e: React.ChangeEvent) => { const file = e.target.files?.[0]; @@ -84,6 +105,7 @@ export function ImageInputNode({ id, data, selected }: NodeProps )} + {/* Variable name badge */} +
+ +
+
+ + {/* Variable Naming Dialog */} + {showVarDialog && createPortal( +
+
+

Set Image Variable Name

+

+ Reference this image as @name in prompts and PromptConstructor templates +

+
+ + { + if (e.key === "Enter" && varNameInput) { + handleSaveVariableName(); + } + }} + placeholder="e.g. photo, reference, style" + className="w-full px-3 py-2 text-sm text-neutral-100 bg-neutral-900 border border-neutral-700 rounded focus:outline-none focus:ring-1 focus:ring-emerald-500" + autoFocus + /> + {varNameInput && ( +
+ Preview: @{varNameInput} +
+ )} +
+
+ {nodeData.variableName && ( + + )} + + +
+
+
, + document.body + )} + ); } diff --git a/src/components/nodes/InpaintNode.tsx b/src/components/nodes/InpaintNode.tsx new file mode 100644 index 00000000..6317bff0 --- /dev/null +++ b/src/components/nodes/InpaintNode.tsx @@ -0,0 +1,191 @@ +"use client"; + +import React, { useCallback, useEffect, useState } from "react"; +import { Handle, Position, NodeProps, Node } from "@xyflow/react"; +import { BaseNode } from "./BaseNode"; +import { useWorkflowStore } from "@/store/workflowStore"; +import { InpaintNodeData, InpaintProvider } from "@/types"; +import { InpaintMaskModal } from "@/components/InpaintMaskModal"; +import { useAdaptiveImageSrc } from "@/hooks/useAdaptiveImageSrc"; +import { getConnectedInputsPure } from "@/store/utils/connectedInputs"; + +type InpaintNodeType = Node; + +const PROVIDER_OPTIONS: { value: InpaintProvider; label: string }[] = [ + { value: "gemini", label: "Gemini" }, + { value: "wavespeed", label: "WaveSpeed" }, +]; + +export function InpaintNode({ id, data, selected }: NodeProps) { + const updateNodeData = useWorkflowStore((state) => state.updateNodeData); + const regenerateNode = useWorkflowStore((state) => state.regenerateNode); + const [maskModalOpen, setMaskModalOpen] = useState(false); + + // Reactively read connected image from upstream nodes + const connectedImage = useWorkflowStore((state) => { + const { images } = getConnectedInputsPure(id, state.nodes, state.edges); + return images[0] ?? 
null; + }); + + const connectedText = useWorkflowStore((state) => { + const { text } = getConnectedInputsPure(id, state.nodes, state.edges); + return text; + }); + + // Sync connected inputs into node data so executor and mask modal can use them + useEffect(() => { + if (connectedImage && connectedImage !== data.inputImage) { + updateNodeData(id, { inputImage: connectedImage }); + } + }, [connectedImage, data.inputImage, id, updateNodeData]); + + useEffect(() => { + if (connectedText !== undefined && connectedText !== data.inputPrompt) { + updateNodeData(id, { inputPrompt: connectedText }); + } + }, [connectedText, data.inputPrompt, id, updateNodeData]); + + const adaptiveOutputImage = useAdaptiveImageSrc(data.outputImage, id); + const adaptiveInputImage = useAdaptiveImageSrc(connectedImage || data.inputImage, id); + + const handleGenerate = useCallback(() => { + regenerateNode(id); + }, [id, regenerateNode]); + + const handleMaskSave = useCallback( + (maskDataUrl: string, brushSize: number) => { + updateNodeData(id, { maskImage: maskDataUrl, maskBrushSize: brushSize }); + setMaskModalOpen(false); + }, + [id, updateNodeData] + ); + + const handleProviderChange = useCallback( + (provider: InpaintProvider) => { + updateNodeData(id, { inpaintProvider: provider }); + }, + [id, updateNodeData] + ); + + const displayImage = adaptiveOutputImage || adaptiveInputImage; + const hasMask = !!data.maskImage; + const hasInput = !!(connectedImage || data.inputImage); + const hasPrompt = !!(connectedText || data.inputPrompt); + // Image + mask required; prompt is optional (will use default if not connected) + const canGenerate = hasInput && hasMask; + + // Status hint for user + let statusHint: string | null = null; + if (!hasInput) statusHint = "Connect an image"; + else if (!hasMask) statusHint = "Draw a mask"; + else if (!hasPrompt) statusHint = "No prompt — will use default"; + + return ( + <> + + {/* Input handles */} + + + + {/* Output handle */} + + +
+ {/* Image preview */} +
+ {displayImage ? ( + Preview + ) : ( + Connect image + )} + + {/* Mask overlay indicator */} + {hasMask && hasInput && ( +
+ Masked +
+ )} + + {/* Status overlay */} + {data.status === "loading" && ( +
+
+
+ )} +
+ + {/* Provider selection */} +
+ {PROVIDER_OPTIONS.map((opt) => ( + + ))} +
+ + {/* Draw mask button */} + + + {/* Generate button */} + + + {/* Status hint */} + {statusHint && data.status !== "loading" && !data.error && ( +
+ {statusHint} +
+ )} + + {/* Error display */} + {data.error && ( +
+ {data.error} +
+ )} +
+ + + {/* Mask drawing modal */} + {maskModalOpen && (connectedImage || data.inputImage) && ( + setMaskModalOpen(false)} + onSave={handleMaskSave} + /> + )} + + ); +} diff --git a/src/components/nodes/PromptConstructorNode.tsx b/src/components/nodes/PromptConstructorNode.tsx index eaf7676d..7efdfcc2 100644 --- a/src/components/nodes/PromptConstructorNode.tsx +++ b/src/components/nodes/PromptConstructorNode.tsx @@ -5,7 +5,7 @@ import { Handle, Position, NodeProps, Node } from "@xyflow/react"; import { BaseNode } from "./BaseNode"; import { usePromptAutocomplete } from "@/hooks/usePromptAutocomplete"; import { useWorkflowStore } from "@/store/workflowStore"; -import { PromptConstructorNodeData, PromptNodeData, LLMGenerateNodeData, AvailableVariable } from "@/types"; +import { PromptConstructorNodeData, PromptNodeData, LLMGenerateNodeData, ImageInputNodeData, AvailableVariable } from "@/types"; import { parseVarTags } from "@/utils/parseVarTags"; type PromptConstructorNodeType = Node; @@ -30,12 +30,18 @@ export function PromptConstructorNode({ id, data, selected }: NodeProps tags) + // and image variables from connected image nodes const availableVariables = useMemo((): AvailableVariable[] => { const connectedTextNodes = edges .filter((e) => e.target === id && e.targetHandle === "text") .map((e) => nodes.find((n) => n.id === e.source)) .filter((n): n is typeof nodes[0] => n !== undefined); + const connectedImageNodes = edges + .filter((e) => e.target === id && e.targetHandle === "image") + .map((e) => nodes.find((n) => n.id === e.source)) + .filter((n): n is typeof nodes[0] => n !== undefined); + const vars: AvailableVariable[] = []; const usedNames = new Set(); @@ -48,13 +54,30 @@ export function PromptConstructorNode({ id, data, selected }: NodeProps tags from all connected text nodes + // Image variables from connected ImageInput nodes with variableName + connectedImageNodes.forEach((node) => { + if (node.type === "imageInput") { + const imgData = node.data as 
ImageInputNodeData; + if (imgData.variableName && !usedNames.has(imgData.variableName)) { + vars.push({ + name: imgData.variableName, + value: imgData.image ? "(image)" : "(no image)", + nodeId: node.id, + variableType: "image", + }); + usedNames.add(imgData.variableName); + } + } + }); + + // Parse inline tags from all connected text nodes connectedTextNodes.forEach((node) => { let text: string | null = null; if (node.type === "prompt") { @@ -74,6 +97,7 @@ export function PromptConstructorNode({ id, data, selected }: NodeProps { - const availableNames = new Set(availableVariables.map(v => v.name)); + const varMap = new Map(availableVariables.map(v => [v.name, v])); const pattern = /@(\w+)/g; const parts: ReactNode[] = []; let lastIndex = 0; @@ -157,9 +181,13 @@ export function PromptConstructorNode({ id, data, selected }: NodeProps lastIndex) { parts.push(localTemplate.slice(lastIndex, idx)); } - const isResolved = availableNames.has(match[1]); + const variable = varMap.get(match[1]); + let bgClass = "bg-red-400/50"; // unresolved + if (variable) { + bgClass = variable.variableType === "image" ? "bg-emerald-400/30" : "bg-blue-400/30"; + } parts.push( - {match[0]} + {match[0]} ); lastIndex = idx + match[0].length; } @@ -195,7 +223,16 @@ export function PromptConstructorNode({ id, data, selected }: NodeProps + + {/* Image input handle - for image variables */} + {/* Template textarea with highlight overlay for @variables */} @@ -226,7 +263,7 @@ export function PromptConstructorNode({ id, data, selected }: NodeProps 0 && (
{ e.preventDefault(); + e.stopPropagation(); handleAutocompleteSelect(variable.name); }} className={`w-full px-3 py-2 text-left text-[11px] flex flex-col gap-0.5 transition-colors ${ @@ -245,9 +283,9 @@ export function PromptConstructorNode({ id, data, selected }: NodeProps -
@{variable.name}
+
@{variable.name}
- {variable.value || "(empty)"} + {variable.variableType === "image" ? "(image)" : variable.value || "(empty)"}
))} diff --git a/src/components/nodes/index.ts b/src/components/nodes/index.ts index 27f0d61c..58957e6c 100644 --- a/src/components/nodes/index.ts +++ b/src/components/nodes/index.ts @@ -21,3 +21,4 @@ export { RouterNode } from "./RouterNode"; export { SwitchNode } from "./SwitchNode"; export { ConditionalSwitchNode } from "./ConditionalSwitchNode"; export { GroupNode } from "./GroupNode"; +export { InpaintNode } from "./InpaintNode"; diff --git a/src/hooks/useInlineParameters.ts b/src/hooks/useInlineParameters.ts index 0c50e5e8..0b26626e 100644 --- a/src/hooks/useInlineParameters.ts +++ b/src/hooks/useInlineParameters.ts @@ -8,15 +8,17 @@ const subscribers = new Set<() => void>(); // Get current value from localStorage function getSnapshot(): boolean { try { - return localStorage.getItem(INLINE_PARAMS_KEY) === "true"; + const value = localStorage.getItem(INLINE_PARAMS_KEY); + // Default to true (inline) when not explicitly set + return value === null ? true : value === "true"; } catch { - return false; + return true; } } -// Server-side snapshot (always false) +// Server-side snapshot (default to true to match client) function getServerSnapshot(): boolean { - return false; + return true; } // Subscribe to changes diff --git a/src/lib/quickstart/validation.ts b/src/lib/quickstart/validation.ts index d1e0a40c..72b99bd5 100644 --- a/src/lib/quickstart/validation.ts +++ b/src/lib/quickstart/validation.ts @@ -64,6 +64,7 @@ const DEFAULT_DIMENSIONS: Record = switch: { width: 220, height: 120 }, conditionalSwitch: { width: 260, height: 180 }, glbViewer: { width: 360, height: 380 }, + inpaint: { width: 320, height: 340 }, }; /** @@ -418,6 +419,19 @@ function createDefaultNodeData(type: NodeType): WorkflowNodeData { filename: null, capturedImage: null, }; + case "inpaint": + return { + inputImage: null, + maskImage: null, + inputPrompt: null, + outputImage: null, + inpaintProvider: "gemini", + maskBrushSize: 40, + status: "idle", + error: null, + imageHistory: [], + 
selectedHistoryIndex: 0, + }; } } diff --git a/src/store/execution/__tests__/generateVideoExecutor.test.ts b/src/store/execution/__tests__/generateVideoExecutor.test.ts index e7e54f7c..0225cf09 100644 --- a/src/store/execution/__tests__/generateVideoExecutor.test.ts +++ b/src/store/execution/__tests__/generateVideoExecutor.test.ts @@ -234,6 +234,9 @@ describe("executeGenerateVideo", () => { dynamicInputs: {}, easeCurve: null, }), + getEdges: vi.fn().mockReturnValue([ + { id: "e1", source: "prompt-1", target: "vid-1", targetHandle: "text" }, + ]), getFreshNode: vi.fn().mockReturnValue(node), }); mockFetch.mockResolvedValueOnce({ diff --git a/src/store/execution/__tests__/llmGenerateExecutor.test.ts b/src/store/execution/__tests__/llmGenerateExecutor.test.ts index cd5b303e..12be62a7 100644 --- a/src/store/execution/__tests__/llmGenerateExecutor.test.ts +++ b/src/store/execution/__tests__/llmGenerateExecutor.test.ts @@ -49,6 +49,7 @@ function makeCtx( videos: [], audio: [], text: "test llm prompt", + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -81,6 +82,7 @@ describe("executeLlmGenerate", () => { videos: [], audio: [], text: null, + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -144,6 +146,7 @@ describe("executeLlmGenerate", () => { videos: [], audio: [], text: "describe this", + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -224,9 +227,13 @@ describe("executeLlmGenerate", () => { videos: [], audio: [], text: null, + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), + getEdges: vi.fn().mockReturnValue([ + { id: "e1", source: "prompt-1", target: "llm-1", targetHandle: "text" }, + ]), }); mockFetch.mockResolvedValueOnce({ ok: true, diff --git a/src/store/execution/__tests__/nanoBananaExecutor.test.ts b/src/store/execution/__tests__/nanoBananaExecutor.test.ts index 8f4d94c2..5aaecdf6 100644 --- a/src/store/execution/__tests__/nanoBananaExecutor.test.ts +++ b/src/store/execution/__tests__/nanoBananaExecutor.test.ts @@ 
-59,6 +59,7 @@ function makeCtx( videos: [], audio: [], text: "test prompt", + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -97,6 +98,7 @@ describe("executeNanoBanana", () => { videos: [], audio: [], text: null, + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -251,6 +253,7 @@ describe("executeNanoBanana", () => { videos: [], audio: [], text: null, + namedImages: {}, dynamicInputs: { prompt: "dynamic prompt" }, easeCurve: null, }), @@ -278,6 +281,7 @@ describe("executeNanoBanana", () => { videos: [], audio: [], text: "with image", + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -304,9 +308,14 @@ describe("executeNanoBanana", () => { videos: [], audio: [], text: null, + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), + // Edges exist (upstream just not re-run) — fallback should be used + getEdges: vi.fn().mockReturnValue([ + { id: "e1", source: "prompt-1", target: "gen-1", targetHandle: "text" }, + ]), }); // Enable regenerate mode: fallback to stored inputs mockFetch.mockResolvedValueOnce({ diff --git a/src/store/execution/__tests__/simpleNodeExecutors.test.ts b/src/store/execution/__tests__/simpleNodeExecutors.test.ts index 242ed9bd..d79c7bdc 100644 --- a/src/store/execution/__tests__/simpleNodeExecutors.test.ts +++ b/src/store/execution/__tests__/simpleNodeExecutors.test.ts @@ -22,6 +22,7 @@ function makeCtx( videos: [], audio: [], text: null, + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -54,6 +55,7 @@ describe("executeAnnotation", () => { videos: [], audio: [], text: null, + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -72,6 +74,7 @@ describe("executeAnnotation", () => { videos: [], audio: [], text: null, + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -90,6 +93,7 @@ describe("executeAnnotation", () => { videos: [], audio: [], text: null, + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -113,6 +117,7 @@ describe("executeAnnotation", () => { videos: [], 
audio: [], text: null, + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -156,6 +161,7 @@ describe("executePrompt", () => { videos: [], audio: [], text: "new prompt", + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -215,6 +221,7 @@ describe("executePromptConstructor", () => { expect(ctx.updateNodeData).toHaveBeenCalledWith("pc", { outputText: "Hello World, welcome to Earth", unresolvedVars: [], + outputParts: null, }); }); @@ -241,6 +248,7 @@ describe("executePromptConstructor", () => { expect(ctx.updateNodeData).toHaveBeenCalledWith("pc", { outputText: "Hello World, welcome to @unknown", unresolvedVars: ["unknown"], + outputParts: null, }); }); @@ -263,6 +271,7 @@ describe("executePromptConstructor", () => { expect(ctx.updateNodeData).toHaveBeenCalledWith("pc", { outputText: "fresh @var", unresolvedVars: ["var"], + outputParts: null, }); }); }); @@ -276,6 +285,7 @@ describe("executeOutput", () => { videos: ["data:video/mp4;base64,abc"], audio: [], text: null, + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -298,6 +308,7 @@ describe("executeOutput", () => { videos: [], audio: [], text: null, + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -320,6 +331,7 @@ describe("executeOutput", () => { videos: [], audio: [], text: null, + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -342,6 +354,7 @@ describe("executeOutput", () => { videos: [], audio: [], text: null, + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -366,6 +379,7 @@ describe("executeOutputGallery", () => { videos: [], audio: [], text: null, + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -386,6 +400,7 @@ describe("executeOutputGallery", () => { videos: [], audio: [], text: null, + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -406,6 +421,7 @@ describe("executeOutputGallery", () => { videos: [], audio: [], text: null, + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -426,6 +442,7 @@ 
describe("executeImageCompare", () => { videos: [], audio: [], text: null, + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -447,6 +464,7 @@ describe("executeImageCompare", () => { videos: [], audio: [], text: null, + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -490,6 +508,7 @@ describe("executeGlbViewer", () => { audio: [], model3d: "https://example.com/model.glb", text: null, + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -519,6 +538,7 @@ describe("executeGlbViewer", () => { audio: [], model3d: "https://example.com/model.glb", text: null, + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -540,6 +560,7 @@ describe("executeGlbViewer", () => { audio: [], model3d: null, text: null, + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), @@ -562,6 +583,7 @@ describe("executeGlbViewer", () => { audio: [], model3d: "https://example.com/model.glb", text: null, + namedImages: {}, dynamicInputs: {}, easeCurve: null, }), diff --git a/src/store/execution/executeNode.ts b/src/store/execution/executeNode.ts index c9b3f143..5c08f425 100644 --- a/src/store/execution/executeNode.ts +++ b/src/store/execution/executeNode.ts @@ -24,6 +24,7 @@ import { executeLlmGenerate } from "./llmGenerateExecutor"; import { executeSplitGrid } from "./splitGridExecutor"; import { executeVideoStitch, executeEaseCurve, executeVideoTrim, executeVideoFrameGrab } from "./videoProcessingExecutors"; import { executeGenerateAudio } from "./generateAudioExecutor"; +import { executeInpaint } from "./inpaintExecutor"; export interface ExecuteNodeOptions { /** When true, executors that support it will fall back to stored inputs. 
*/ @@ -107,5 +108,8 @@ export async function executeNode( case "videoFrameGrab": await executeVideoFrameGrab(ctx); break; + case "inpaint": + await executeInpaint(ctx); + break; } } diff --git a/src/store/execution/generate3dExecutor.ts b/src/store/execution/generate3dExecutor.ts index a6ad2658..5efc0784 100644 --- a/src/store/execution/generate3dExecutor.ts +++ b/src/store/execution/generate3dExecutor.ts @@ -23,6 +23,7 @@ export async function executeGenerate3D( getConnectedInputs, updateNodeData, getFreshNode, + getEdges, signal, providerSettings, addIncurredCost, @@ -43,8 +44,9 @@ export async function executeGenerate3D( let promptText: string | null; if (useStoredFallback) { - images = connectedImages.length > 0 ? connectedImages : nodeData.inputImages; - promptText = connectedText ?? nodeData.inputPrompt; + const hasIncomingEdges = getEdges().some((e) => e.target === node.id); + images = connectedImages.length > 0 ? connectedImages : (hasIncomingEdges ? nodeData.inputImages : []); + promptText = connectedText ?? (hasIncomingEdges ? nodeData.inputPrompt : null); } else { images = connectedImages; const promptFromDynamic = Array.isArray(dynamicInputs.prompt) diff --git a/src/store/execution/generateAudioExecutor.ts b/src/store/execution/generateAudioExecutor.ts index ee3d6351..777094b0 100644 --- a/src/store/execution/generateAudioExecutor.ts +++ b/src/store/execution/generateAudioExecutor.ts @@ -23,6 +23,7 @@ export async function executeGenerateAudio( getConnectedInputs, updateNodeData, getFreshNode, + getEdges, signal, providerSettings, addIncurredCost, @@ -43,7 +44,8 @@ export async function executeGenerateAudio( let text: string | null; if (useStoredFallback) { - text = connectedText ?? nodeData.inputPrompt; + const hasIncomingEdges = getEdges().some((e) => e.target === node.id); + text = connectedText ?? (hasIncomingEdges ? 
nodeData.inputPrompt : null); const hasPrompt = text || dynamicInputs.prompt; if (!hasPrompt) { updateNodeData(node.id, { diff --git a/src/store/execution/generateVideoExecutor.ts b/src/store/execution/generateVideoExecutor.ts index 4bb47222..3851a731 100644 --- a/src/store/execution/generateVideoExecutor.ts +++ b/src/store/execution/generateVideoExecutor.ts @@ -23,6 +23,7 @@ export async function executeGenerateVideo( getConnectedInputs, updateNodeData, getFreshNode, + getEdges, signal, providerSettings, addIncurredCost, @@ -44,8 +45,9 @@ export async function executeGenerateVideo( let text: string | null; if (useStoredFallback) { - images = connectedImages.length > 0 ? connectedImages : nodeData.inputImages; - text = connectedText ?? nodeData.inputPrompt; + const hasIncomingEdges = getEdges().some((e) => e.target === node.id); + images = connectedImages.length > 0 ? connectedImages : (hasIncomingEdges ? nodeData.inputImages : []); + text = connectedText ?? (hasIncomingEdges ? nodeData.inputPrompt : null); // Validate fallback inputs the same way as the regular path const hasPrompt = text || dynamicInputs.prompt || dynamicInputs.negative_prompt; if (!hasPrompt && images.length === 0) { diff --git a/src/store/execution/inpaintExecutor.ts b/src/store/execution/inpaintExecutor.ts new file mode 100644 index 00000000..0aa56514 --- /dev/null +++ b/src/store/execution/inpaintExecutor.ts @@ -0,0 +1,223 @@ +/** + * Inpaint Executor + * + * Handles masked image regeneration using either Gemini (pseudo-inpaint + * via multimodal prompt) or WaveSpeed (real mask-based inpainting). 
+ */ + +import type { InpaintNodeData } from "@/types"; +import { buildGenerateHeaders } from "@/store/utils/buildApiHeaders"; +import type { NodeExecutionContext } from "./types"; + +export async function executeInpaint(ctx: NodeExecutionContext): Promise { + const { + node, + getConnectedInputs, + updateNodeData, + getFreshNode, + getEdges, + signal, + providerSettings, + addIncurredCost, + generationsPath, + trackSaveGeneration, + getNodes, + } = ctx; + + const freshNode = getFreshNode(node.id); + const nodeData = (freshNode?.data || node.data) as InpaintNodeData; + + const { images: connectedImages, text: connectedText } = getConnectedInputs(node.id); + + // Determine source image (connected or stored) + const hasIncomingEdges = getEdges().some((e) => e.target === node.id); + const sourceImage = connectedImages[0] ?? (hasIncomingEdges ? nodeData.inputImage : null); + const promptText = connectedText ?? nodeData.inputPrompt ?? "Regenerate the masked area naturally, matching the surrounding context"; + const maskImage = nodeData.maskImage; + + if (!sourceImage) { + updateNodeData(node.id, { status: "error", error: "Missing source image" }); + throw new Error("Missing source image"); + } + + if (!maskImage) { + updateNodeData(node.id, { status: "error", error: "No mask drawn — draw a mask first" }); + throw new Error("No mask drawn"); + } + + updateNodeData(node.id, { + inputImage: sourceImage, + inputPrompt: promptText, + status: "loading", + error: null, + }); + + const provider = nodeData.inpaintProvider || "gemini"; + + try { + let result: { success: boolean; image?: string; error?: string }; + + if (provider === "gemini") { + result = await inpaintWithGemini(sourceImage, maskImage, promptText, providerSettings, signal); + } else { + result = await inpaintWithWaveSpeed(sourceImage, maskImage, promptText, nodeData, providerSettings, signal); + } + + if (result.success && result.image) { + const timestamp = Date.now(); + const imageId = `${timestamp}`; + const 
newHistoryItem = { + id: imageId, + timestamp, + prompt: promptText, + aspectRatio: "1:1" as const, + model: "inpaint" as const, + }; + const updatedHistory = [newHistoryItem, ...(nodeData.imageHistory || [])].slice(0, 50); + + updateNodeData(node.id, { + outputImage: result.image, + status: "complete", + error: null, + imageHistory: updatedHistory, + selectedHistoryIndex: 0, + }); + + // Auto-save if configured + if (generationsPath) { + const savePromise = fetch("/api/save-generation", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + directoryPath: generationsPath, + image: result.image, + prompt: promptText, + imageId, + }), + }) + .then((res) => res.json()) + .then((saveResult) => { + if (saveResult.success && saveResult.imageId && saveResult.imageId !== imageId) { + const currentNode = getNodes().find((n) => n.id === node.id); + if (currentNode) { + const currentData = currentNode.data as InpaintNodeData; + const histCopy = [...(currentData.imageHistory || [])]; + const entryIndex = histCopy.findIndex((h) => h.id === imageId); + if (entryIndex !== -1) { + histCopy[entryIndex] = { ...histCopy[entryIndex], id: saveResult.imageId }; + updateNodeData(node.id, { imageHistory: histCopy }); + } + } + } + }) + .catch((err) => console.error("Failed to save inpaint generation:", err)); + + trackSaveGeneration(imageId, savePromise); + } + } else { + updateNodeData(node.id, { status: "error", error: result.error || "Inpainting failed" }); + throw new Error(result.error || "Inpainting failed"); + } + } catch (error) { + if (error instanceof DOMException && error.name === "AbortError") throw error; + const errorMessage = error instanceof Error ? error.message : "Inpainting failed"; + updateNodeData(node.id, { status: "error", error: errorMessage }); + throw new Error(errorMessage); + } +} + +/** + * Gemini pseudo-inpaint: sends source image + mask + prompt as multimodal parts. 
+ * The model interprets the mask as guidance for which area to regenerate. + */ +async function inpaintWithGemini( + sourceImage: string, + maskImage: string, + prompt: string, + providerSettings: NodeExecutionContext["providerSettings"], + signal?: AbortSignal +): Promise<{ success: boolean; image?: string; error?: string }> { + const headers = buildGenerateHeaders("gemini", providerSettings); + + const inpaintPrompt = + `You are given two images: the first is the original image, the second is a black-and-white mask ` + + `where white areas indicate the region to edit. Edit ONLY the white masked area of the original image ` + + `according to this instruction: ${prompt}. Keep all non-masked areas exactly the same.`; + + const response = await fetch("/api/generate", { + method: "POST", + headers, + body: JSON.stringify({ + images: [sourceImage, maskImage], + prompt: inpaintPrompt, + model: "nano-banana-pro", + aspectRatio: "1:1", + }), + ...(signal ? { signal } : {}), + }); + + if (!response.ok) { + const errorText = await response.text(); + let errorMessage = `HTTP ${response.status}`; + try { + const errorJson = JSON.parse(errorText); + errorMessage = errorJson.error || errorMessage; + } catch { + if (errorText) errorMessage += ` - ${errorText.substring(0, 200)}`; + } + return { success: false, error: errorMessage }; + } + + const result = await response.json(); + return { success: result.success, image: result.image, error: result.error }; +} + +/** + * WaveSpeed real inpaint: sends source image + mask via the standard generation pipeline. + * WaveSpeed edit models accept `image` and `mask` parameters. 
+ */ +async function inpaintWithWaveSpeed( + sourceImage: string, + maskImage: string, + prompt: string, + nodeData: InpaintNodeData, + providerSettings: NodeExecutionContext["providerSettings"], + signal?: AbortSignal +): Promise<{ success: boolean; image?: string; error?: string }> { + const headers = buildGenerateHeaders("wavespeed", providerSettings); + + const selectedModel = nodeData.selectedModel || { + provider: "wavespeed" as const, + modelId: "flux/inpaint", + displayName: "FLUX Inpaint", + }; + + const response = await fetch("/api/generate", { + method: "POST", + headers, + body: JSON.stringify({ + prompt, + images: [sourceImage], + selectedModel, + dynamicInputs: { + mask: maskImage, + }, + }), + ...(signal ? { signal } : {}), + }); + + if (!response.ok) { + const errorText = await response.text(); + let errorMessage = `HTTP ${response.status}`; + try { + const errorJson = JSON.parse(errorText); + errorMessage = errorJson.error || errorMessage; + } catch { + if (errorText) errorMessage += ` - ${errorText.substring(0, 200)}`; + } + return { success: false, error: errorMessage }; + } + + const result = await response.json(); + return { success: result.success, image: result.image, error: result.error }; +} diff --git a/src/store/execution/llmGenerateExecutor.ts b/src/store/execution/llmGenerateExecutor.ts index c833a993..14c5ef47 100644 --- a/src/store/execution/llmGenerateExecutor.ts +++ b/src/store/execution/llmGenerateExecutor.ts @@ -5,8 +5,9 @@ * Used by both executeWorkflow and regenerateNode. 
*/ -import type { LLMGenerateNodeData } from "@/types"; +import type { LLMGenerateNodeData, PromptPart, PromptConstructorNodeData } from "@/types"; import { buildLlmHeaders } from "@/store/utils/buildApiHeaders"; +import { resolveImageVars, hasImageVarReferences } from "@/utils/resolveImageVars"; import type { NodeExecutionContext } from "./types"; export interface LlmGenerateOptions { @@ -22,6 +23,8 @@ export async function executeLlmGenerate( node, getConnectedInputs, updateNodeData, + getEdges, + getNodes, signal, providerSettings, } = ctx; @@ -36,8 +39,9 @@ export async function executeLlmGenerate( let text: string | null; if (useStoredFallback) { - images = inputs.images.length > 0 ? inputs.images : nodeData.inputImages; - text = inputs.text ?? nodeData.inputPrompt; + const hasIncomingEdges = getEdges().some((e) => e.target === node.id); + images = inputs.images.length > 0 ? inputs.images : (hasIncomingEdges ? nodeData.inputImages : []); + text = inputs.text ?? (hasIncomingEdges ? nodeData.inputPrompt : null); } else { images = inputs.images; text = inputs.text ?? nodeData.inputPrompt; @@ -58,6 +62,30 @@ export async function executeLlmGenerate( error: null, }); + // Check for multimodal parts from image variable references. + let parts: PromptPart[] | undefined; + + // 1. Try direct namedImages (imageInput connected directly to this node) + if (Object.keys(inputs.namedImages).length > 0 && hasImageVarReferences(text, inputs.namedImages)) { + parts = resolveImageVars(text, inputs.namedImages); + } + + // 2. Check upstream PromptConstructor for multimodal parts. + // outputParts are always fresh: during full workflow the PC executes first, + // and during regenerateNode the PC is re-executed before this node. 
+ if (!parts) { + const upstreamPcNode = getEdges() + .filter((e) => e.target === node.id && (e.targetHandle === "text" || e.targetHandle?.startsWith("text"))) + .map((e) => getNodes().find((n) => n.id === e.source)) + .find((n) => n?.type === "promptConstructor"); + if (upstreamPcNode) { + const pcData = upstreamPcNode.data as PromptConstructorNodeData; + if (pcData.outputParts && pcData.outputParts.length > 0) { + parts = pcData.outputParts; + } + } + } + const headers = buildLlmHeaders(nodeData.provider, providerSettings); try { @@ -67,6 +95,7 @@ export async function executeLlmGenerate( body: JSON.stringify({ prompt: text, ...(images.length > 0 && { images }), + ...(parts && { parts }), provider: nodeData.provider, model: nodeData.model, temperature: nodeData.temperature, diff --git a/src/store/execution/nanoBananaExecutor.ts b/src/store/execution/nanoBananaExecutor.ts index bd142993..d34508b6 100644 --- a/src/store/execution/nanoBananaExecutor.ts +++ b/src/store/execution/nanoBananaExecutor.ts @@ -7,9 +7,12 @@ import type { NanoBananaNodeData, + PromptPart, + PromptConstructorNodeData, } from "@/types"; import { calculateGenerationCost } from "@/utils/costCalculator"; import { buildGenerateHeaders } from "@/store/utils/buildApiHeaders"; +import { resolveImageVars, hasImageVarReferences } from "@/utils/resolveImageVars"; import type { NodeExecutionContext } from "./types"; export interface NanoBananaOptions { @@ -40,7 +43,7 @@ export async function executeNanoBanana( const { useStoredFallback = false } = options; - const { images: connectedImages, text: connectedText, dynamicInputs } = getConnectedInputs(node.id); + const { images: connectedImages, text: connectedText, namedImages, dynamicInputs } = getConnectedInputs(node.id); // Get fresh node data from store const freshNode = getFreshNode(node.id); @@ -51,8 +54,11 @@ export async function executeNanoBanana( let promptText: string | null; if (useStoredFallback) { - images = connectedImages.length > 0 ? 
connectedImages : nodeData.inputImages; - promptText = connectedText ?? nodeData.inputPrompt; + // Only fall back to stored values if the node still has incoming edges. + // If all edges are disconnected, stored data is stale and should not be used. + const hasIncomingEdges = getEdges().some((e) => e.target === node.id); + images = connectedImages.length > 0 ? connectedImages : (hasIncomingEdges ? nodeData.inputImages : []); + promptText = connectedText ?? (hasIncomingEdges ? nodeData.inputPrompt : null); } else { images = connectedImages; // For dynamic inputs, check if we have at least a prompt @@ -94,7 +100,32 @@ export async function executeNanoBanana( const sanitizedDynamicInputs = { ...dynamicInputs }; delete sanitizedDynamicInputs.prompt; - const requestPayload = { + // Check for multimodal parts from image variable references. + let parts: PromptPart[] | undefined; + + // 1. Try direct namedImages (imageInput connected directly to this node) + if (Object.keys(namedImages).length > 0 && promptText && hasImageVarReferences(promptText, namedImages)) { + parts = resolveImageVars(promptText, namedImages); + } + + // 2. Check upstream PromptConstructor for multimodal parts. + // outputParts are always fresh here: during full workflow execution the PC + // executes first (topological order), and during regenerateNode the PC is + // re-executed before this node. 
+ if (!parts) { + const upstreamPcNode = getEdges() + .filter((e) => e.target === node.id && (e.targetHandle === "text" || e.targetHandle?.startsWith("text"))) + .map((e) => getNodes().find((n) => n.id === e.source)) + .find((n) => n?.type === "promptConstructor"); + if (upstreamPcNode) { + const pcData = upstreamPcNode.data as PromptConstructorNodeData; + if (pcData.outputParts && pcData.outputParts.length > 0) { + parts = pcData.outputParts; + } + } + } + + const requestPayload: Record = { images, prompt: promptText, aspectRatio: nodeData.aspectRatio, @@ -105,6 +136,7 @@ export async function executeNanoBanana( selectedModel: nodeData.selectedModel, parameters: nodeData.parameters, dynamicInputs: sanitizedDynamicInputs, + ...(parts && { parts }), }; // Final guard: assert that prompt is a string before sending to API diff --git a/src/store/execution/simpleNodeExecutors.ts b/src/store/execution/simpleNodeExecutors.ts index 8adf593a..ae43a9c7 100644 --- a/src/store/execution/simpleNodeExecutors.ts +++ b/src/store/execution/simpleNodeExecutors.ts @@ -13,13 +13,16 @@ import type { PromptConstructorNodeData, PromptNodeData, LLMGenerateNodeData, + ImageInputNodeData, OutputNodeData, OutputGalleryNodeData, + PromptPart, WorkflowNode, } from "@/types"; import type { NodeExecutionContext } from "./types"; import { parseTextToArray } from "@/utils/arrayParser"; import { parseVarTags } from "@/utils/parseVarTags"; +import { resolveImageVars } from "@/utils/resolveImageVars"; /** * Annotation node: receives upstream image as source, passes through if no annotations. 
@@ -115,6 +118,12 @@ export async function executePromptConstructor(ctx: NodeExecutionContext): Promi .map((e) => nodes.find((n) => n.id === e.source)) .filter((n): n is WorkflowNode => n !== undefined); + // Find all connected image nodes + const connectedImageNodes = edges + .filter((e) => e.target === node.id && e.targetHandle === "image") + .map((e) => nodes.find((n) => n.id === e.source)) + .filter((n): n is WorkflowNode => n !== undefined); + // Build variable map: named variables from Prompt nodes take precedence const variableMap: Record = {}; connectedTextNodes.forEach((srcNode) => { @@ -126,6 +135,17 @@ export async function executePromptConstructor(ctx: NodeExecutionContext): Promi } }); + // Build named images map from connected ImageInput nodes with variableName + const namedImages: Record = {}; + connectedImageNodes.forEach((srcNode) => { + if (srcNode.type === "imageInput") { + const imgData = srcNode.data as ImageInputNodeData; + if (imgData.variableName && imgData.image) { + namedImages[imgData.variableName] = imgData.image; + } + } + }); + // Parse inline tags from all connected text nodes connectedTextNodes.forEach((srcNode) => { let text: string | null = null; @@ -153,21 +173,30 @@ export async function executePromptConstructor(ctx: NodeExecutionContext): Promi const unresolvedVars: string[] = []; let resolvedText = template; - // Replace @variables with values or track unresolved + // Replace text @variables with values or track unresolved + // (image @variables are left as-is in outputText — they're resolved in outputParts) const matches = template.matchAll(varPattern); for (const match of matches) { const varName = match[1]; if (variableMap[varName] !== undefined) { resolvedText = resolvedText.replaceAll(`@${varName}`, variableMap[varName]); - } else { + } else if (!namedImages[varName]) { + // Only unresolved if not an image variable either if (!unresolvedVars.includes(varName)) { unresolvedVars.push(varName); } } } + // Build multimodal 
outputParts if there are image variables referenced in the template + let outputParts: PromptPart[] | null = null; + if (Object.keys(namedImages).length > 0) { + outputParts = resolveImageVars(resolvedText, namedImages); + } + updateNodeData(node.id, { outputText: resolvedText, + outputParts, unresolvedVars, }); } catch (err) { diff --git a/src/store/utils/__tests__/connectedInputs.test.ts b/src/store/utils/__tests__/connectedInputs.test.ts index 37da703a..8ffb9f34 100644 --- a/src/store/utils/__tests__/connectedInputs.test.ts +++ b/src/store/utils/__tests__/connectedInputs.test.ts @@ -24,6 +24,7 @@ describe("getConnectedInputsPure", () => { expect(result.videos).toEqual([]); expect(result.audio).toEqual([]); expect(result.text).toBeNull(); + expect(result.namedImages).toEqual({}); expect(result.dynamicInputs).toEqual({}); expect(result.easeCurve).toBeNull(); }); diff --git a/src/store/utils/connectedInputs.ts b/src/store/utils/connectedInputs.ts index a848a380..b64160aa 100644 --- a/src/store/utils/connectedInputs.ts +++ b/src/store/utils/connectedInputs.ts @@ -26,6 +26,7 @@ import { GLBViewerNodeData, SwitchNodeData, ConditionalSwitchNodeData, + InpaintNodeData, MatchMode, } from "@/types"; @@ -38,6 +39,7 @@ export interface ConnectedInputs { audio: string[]; model3d: string | null; text: string | null; + namedImages: Record; // variableName → base64 image data URL dynamicInputs: Record; easeCurve: { bezierHandles: [number, number, number, number]; easingPreset: string | null; outputDuration: number } | null; } @@ -115,6 +117,8 @@ function getSourceOutput( return { type: "image", value: (sourceNode.data as VideoFrameGrabNodeData).outputImage }; } else if (sourceNode.type === "glbViewer") { return { type: "image", value: (sourceNode.data as GLBViewerNodeData).capturedImage }; + } else if (sourceNode.type === "inpaint") { + return { type: "image", value: (sourceNode.data as InpaintNodeData).outputImage }; } return { type: "image", value: null }; } @@ -131,13 
+135,14 @@ export function getConnectedInputsPure( dimmedNodeIds?: Set ): ConnectedInputs { const _visited = visited || new Set(); - if (_visited.has(nodeId)) return { images: [], videos: [], audio: [], model3d: null, text: null, dynamicInputs: {}, easeCurve: null }; + if (_visited.has(nodeId)) return { images: [], videos: [], audio: [], model3d: null, text: null, namedImages: {}, dynamicInputs: {}, easeCurve: null }; _visited.add(nodeId); const images: string[] = []; const videos: string[] = []; const audio: string[] = []; let model3d: string | null = null; let text: string | null = null; + const namedImages: Record = {}; const dynamicInputs: Record = {}; let easeCurve: ConnectedInputs["easeCurve"] = null; @@ -190,6 +195,7 @@ export function getConnectedInputsPure( if (edgeType === "image" || (!edgeType && isImageHandle(edge.sourceHandle))) { images.push(...routerInputs.images); + Object.assign(namedImages, routerInputs.namedImages); } else if (edgeType === "text" || (!edgeType && isTextHandle(edge.sourceHandle))) { if (routerInputs.text) text = routerInputs.text; } else if (edgeType === "video") { @@ -224,6 +230,7 @@ export function getConnectedInputsPure( if (edgeType === "image") { images.push(...switchInputs.images); + Object.assign(namedImages, switchInputs.namedImages); } else if (edgeType === "text") { if (switchInputs.text) text = switchInputs.text; } else if (edgeType === "video") { @@ -289,6 +296,14 @@ export function getConnectedInputsPure( } } + // Track named image variables from imageInput nodes + if (type === "image" && sourceNode.type === "imageInput") { + const imgData = sourceNode.data as ImageInputNodeData; + if (imgData.variableName && value) { + namedImages[imgData.variableName] = value; + } + } + // Route to typed arrays based on source output type if (type === "3d") { model3d = value; @@ -323,7 +338,7 @@ export function getConnectedInputsPure( } } - return { images, videos, audio, model3d, text, dynamicInputs, easeCurve }; + return { images, 
videos, audio, model3d, text, namedImages, dynamicInputs, easeCurve }; } /** diff --git a/src/store/utils/nodeDefaults.ts b/src/store/utils/nodeDefaults.ts index fd450fc6..6b3ffad6 100644 --- a/src/store/utils/nodeDefaults.ts +++ b/src/store/utils/nodeDefaults.ts @@ -23,6 +23,7 @@ import { SwitchNodeData, ConditionalSwitchNodeData, GLBViewerNodeData, + InpaintNodeData, WorkflowNodeData, GroupColor, SelectedModel, @@ -58,6 +59,7 @@ export const defaultNodeDimensions: Record { return { template: "", outputText: null, + outputParts: null, unresolvedVars: [], } as PromptConstructorNodeData; case "nanoBanana": { @@ -322,5 +325,18 @@ export const createDefaultNodeData = (type: NodeType): WorkflowNodeData => { filename: null, capturedImage: null, } as GLBViewerNodeData; + case "inpaint": + return { + inputImage: null, + maskImage: null, + inputPrompt: null, + outputImage: null, + inpaintProvider: "gemini", + maskBrushSize: 40, + status: "idle", + error: null, + imageHistory: [], + selectedHistoryIndex: 0, + } as InpaintNodeData; } }; diff --git a/src/store/workflowStore.ts b/src/store/workflowStore.ts index 3391afde..5a746804 100644 --- a/src/store/workflowStore.ts +++ b/src/store/workflowStore.ts @@ -59,7 +59,7 @@ import { chunk, clearNodeImageRefs, } from "./utils/executionUtils"; -import { getConnectedInputsPure, validateWorkflowPure } from "./utils/connectedInputs"; +import { getConnectedInputsPure, validateWorkflowPure, ConnectedInputs } from "./utils/connectedInputs"; import { evaluateRule } from "./utils/ruleEvaluation"; import { computeDimmedNodes } from "./utils/dimmingUtils"; import { @@ -95,7 +95,7 @@ export { CONCURRENCY_SETTINGS_KEY } from "./utils/executionUtils"; async function evaluateAndExecuteConditionalSwitch( node: WorkflowNode, executionCtx: NodeExecutionContext, - getConnectedInputs: (nodeId: string) => { text: string | null; images: string[]; videos: string[]; audio: string[]; model3d: string | null; dynamicInputs: Record; easeCurve: { 
bezierHandles: [number, number, number, number]; easingPreset: string | null; outputDuration: number } | null }, + getConnectedInputs: (nodeId: string) => ConnectedInputs, updateNodeData: (nodeId: string, data: Partial) => void, ): Promise { const condInputs = getConnectedInputs(node.id); @@ -269,7 +269,7 @@ interface WorkflowStore { // Helpers getNodeById: (id: string) => WorkflowNode | undefined; - getConnectedInputs: (nodeId: string) => { images: string[]; videos: string[]; audio: string[]; model3d: string | null; text: string | null; dynamicInputs: Record; easeCurve: { bezierHandles: [number, number, number, number]; easingPreset: string | null; outputDuration: number } | null }; + getConnectedInputs: (nodeId: string) => ConnectedInputs; validateWorkflow: () => { valid: boolean; errors: string[] }; // Global Image History @@ -418,15 +418,20 @@ export { GROUP_COLORS } from "./utils/nodeDefaults"; /** Node types whose output carries image data */ const IMAGE_SOURCE_NODE_TYPES = new Set([ - "imageInput", "annotation", "nanoBanana", "glbViewer", "videoFrameGrab", + "imageInput", "annotation", "nanoBanana", "glbViewer", "videoFrameGrab", "inpaint", +]); + +const TEXT_SOURCE_NODE_TYPES = new Set([ + "prompt", "promptConstructor", "llmGenerate", "array", ]); /** - * After edges are removed, clear inputImages on any target node that no longer - * has an image-source edge. Prevents stale images from being sent to the API - * when useStoredFallback picks up old node data. + * After edges are removed, clear stale cached input data on target nodes. + * Clears inputImages when no image-source edges remain, inputPrompt when no + * text-source edges remain, and outputParts on PromptConstructor nodes when + * their image inputs are removed. 
*/ -function clearStaleInputImages( +function clearStaleInputData( removedEdges: WorkflowEdge[], get: () => WorkflowStore ): void { @@ -435,14 +440,42 @@ function clearStaleInputImages( const targetIds = new Set(removedEdges.map((e) => e.target)); for (const targetId of targetIds) { const node = nodes.find((n) => n.id === targetId); - if (!node || !("inputImages" in (node.data as Record))) continue; - const hasRemainingImageSource = edges.some((e) => { - if (e.target !== targetId) return false; - const src = nodes.find((n) => n.id === e.source); - return src ? IMAGE_SOURCE_NODE_TYPES.has(src.type ?? "") : false; - }); - if (!hasRemainingImageSource) { - updateNodeData(targetId, { inputImages: [] }); + if (!node) continue; + const nodeData = node.data as Record; + + // Clear inputImages if no image-source edges remain + if ("inputImages" in nodeData) { + const hasRemainingImageSource = edges.some((e) => { + if (e.target !== targetId) return false; + const src = nodes.find((n) => n.id === e.source); + return src ? IMAGE_SOURCE_NODE_TYPES.has(src.type ?? "") : false; + }); + if (!hasRemainingImageSource) { + updateNodeData(targetId, { inputImages: [] }); + } + } + + // Clear inputPrompt if no text-source edges remain + if ("inputPrompt" in nodeData) { + const hasRemainingTextSource = edges.some((e) => { + if (e.target !== targetId) return false; + const src = nodes.find((n) => n.id === e.source); + return src ? TEXT_SOURCE_NODE_TYPES.has(src.type ?? 
"") : false; + }); + if (!hasRemainingTextSource) { + updateNodeData(targetId, { inputPrompt: null }); + } + } + + // Clear outputParts on PromptConstructor when image inputs are removed + if (node.type === "promptConstructor" && "outputParts" in nodeData) { + const hasRemainingImageInput = edges.some((e) => { + if (e.target !== targetId) return false; + return e.targetHandle === "image"; + }); + if (!hasRemainingImageInput) { + updateNodeData(targetId, { outputParts: null }); + } } } } @@ -636,7 +669,7 @@ const workflowStoreImpl: StateCreator = (set, get) => ({ })); if (hasRemoveChange) { - clearStaleInputImages(removedEdges, get); + clearStaleInputData(removedEdges, get); get().incrementManualChangeCount(); } @@ -686,7 +719,7 @@ const workflowStoreImpl: StateCreator = (set, get) => ({ edges: state.edges.filter((edge) => edge.id !== edgeId), hasUnsavedChanges: true, })); - if (removedEdge) clearStaleInputImages([removedEdge], get); + if (removedEdge) clearStaleInputData([removedEdge], get); get().incrementManualChangeCount(); }, @@ -1275,6 +1308,24 @@ const workflowStoreImpl: StateCreator = (set, get) => ({ }); try { + // Re-execute lightweight upstream nodes (prompt, promptConstructor) so their + // outputs are fresh. This ensures multimodal parts, resolved text, etc. are + // up-to-date without needing complex re-resolution logic in each executor. 
+ const { edges: regenEdges, nodes: regenNodes } = get(); + const upstreamEdges = regenEdges.filter((e) => e.target === nodeId); + for (const edge of upstreamEdges) { + const srcNode = regenNodes.find((n) => n.id === edge.source); + if (!srcNode) continue; + if (srcNode.type === "prompt" || srcNode.type === "promptConstructor") { + const srcCtx = get()._buildExecutionContext(srcNode); + if (srcNode.type === "prompt") { + await executePrompt(srcCtx); + } else { + await executePromptConstructor(srcCtx); + } + } + } + const executionCtx = get()._buildExecutionContext(node); const regenOptions = { useStoredFallback: true }; diff --git a/src/types/nodes.ts b/src/types/nodes.ts index 963c41f7..379a4894 100644 --- a/src/types/nodes.ts +++ b/src/types/nodes.ts @@ -45,7 +45,8 @@ export type NodeType = | "switch" | "conditionalSwitch" | "generate3d" - | "glbViewer"; + | "glbViewer" + | "inpaint"; /** * Node execution status @@ -60,6 +61,7 @@ export interface ImageInputNodeData extends BaseNodeData { imageRef?: string; // External image reference for storage optimization filename: string | null; dimensions: { width: number; height: number } | null; + variableName?: string; // Optional variable name for use in PromptConstructor templates (image variable) } /** @@ -104,6 +106,7 @@ export interface ArrayNodeData extends BaseNodeData { export interface PromptConstructorNodeData extends BaseNodeData { template: string; outputText: string | null; + outputParts: PromptPart[] | null; // Multimodal parts with interleaved text and image references unresolvedVars: string[]; } @@ -114,8 +117,16 @@ export interface AvailableVariable { name: string; value: string; nodeId: string; + variableType: "text" | "image"; // Whether this variable holds text or image data } +/** + * A single part of a multimodal prompt (text or image reference) + */ +export type PromptPart = + | { type: "text"; value: string } + | { type: "image"; name: string; value: string }; // value is base64 data URL + /** * 
Image history item for tracking generated images */ @@ -438,6 +449,27 @@ export interface GLBViewerNodeData extends BaseNodeData { capturedImage: string | null; // Base64 PNG snapshot of the 3D viewport } +/** + * Inpaint node - masked image regeneration + */ +export type InpaintProvider = "gemini" | "wavespeed"; + +export interface InpaintNodeData extends BaseNodeData { + inputImage: string | null; // Source image (from connection or stored) + inputImageRef?: string; + maskImage: string | null; // Black/white mask (white = area to regenerate) + inputPrompt: string | null; // What to generate in the masked area + outputImage: string | null; // Result image + outputImageRef?: string; + inpaintProvider: InpaintProvider; // Which provider to use + selectedModel?: SelectedModel; // WaveSpeed model selection + maskBrushSize: number; // Last used brush size + status: NodeStatus; + error: string | null; + imageHistory: CarouselImageItem[]; + selectedHistoryIndex: number; +} + /** * Union of all node data types */ @@ -464,7 +496,8 @@ export type WorkflowNodeData = | RouterNodeData | SwitchNodeData | ConditionalSwitchNodeData - | GLBViewerNodeData; + | GLBViewerNodeData + | InpaintNodeData; /** * Workflow node with typed data (extended with optional groupId) diff --git a/src/utils/resolveImageVars.ts b/src/utils/resolveImageVars.ts new file mode 100644 index 00000000..1ff6d0b3 --- /dev/null +++ b/src/utils/resolveImageVars.ts @@ -0,0 +1,77 @@ +/** + * Resolve image variable references (@varName) in prompt text. + * + * Splits prompt text around @varName references that correspond to named images, + * producing an interleaved array of text and image parts for multimodal API requests. + * + * Text @variables are left as-is (they should already be resolved by the caller). + */ + +import type { PromptPart } from "@/types"; + +/** + * Check if a prompt text contains any @varName references that match named images. 
/**
 * Look up a named image by variable name, consulting only the map's own keys.
 *
 * A plain `namedImages[varName]` read also walks the prototype chain, so a
 * prompt containing `@constructor` or `@toString` would resolve to a built-in
 * `Object.prototype` member and be mistaken for an image.
 *
 * @param namedImages - Map of variable name to image value.
 * @param varName - Variable name without the leading `@`.
 * @returns The image value, or `undefined` when the key is absent or its value is empty.
 */
function lookupNamedImage(
  namedImages: Record<string, string>,
  varName: string
): string | undefined {
  if (!Object.prototype.hasOwnProperty.call(namedImages, varName)) {
    return undefined;
  }
  // An empty string is treated the same as "no image".
  return namedImages[varName] || undefined;
}

/**
 * Check whether a prompt contains at least one `@varName` reference whose
 * name is a key of `namedImages` with a non-empty value.
 *
 * @param text - Prompt text to scan.
 * @param namedImages - Map of variable name to image value.
 * @returns `true` if any reference matches a named image, otherwise `false`.
 */
export function hasImageVarReferences(
  text: string,
  namedImages: Record<string, string>
): boolean {
  // The regex is created per call: a global regex keeps `lastIndex` state
  // between `exec` calls, so it must not be shared across invocations.
  const pattern = /@(\w+)/g;
  let match: RegExpExecArray | null;
  while ((match = pattern.exec(text)) !== null) {
    if (lookupNamedImage(namedImages, match[1]) !== undefined) return true;
  }
  return false;
}

/**
 * Split prompt text into multimodal parts, replacing each `@varName` that
 * matches a key of `namedImages` with an image part.
 *
 * `@` references that do not match a named image stay in the surrounding
 * text untouched. Text segments are trimmed, and segments that are empty
 * after trimming are dropped.
 *
 * @param text - Prompt text to split.
 * @param namedImages - Map of variable name to image value.
 * @returns Interleaved text and image parts in prompt order. When nothing is
 *   produced (no image reference and no non-blank text), a single text part
 *   holding the original, untrimmed `text` is returned.
 */
export function resolveImageVars(
  text: string,
  namedImages: Record<string, string>
): PromptPart[] {
  const pattern = /@(\w+)/g;
  const parts: PromptPart[] = [];
  let lastIndex = 0;
  let match: RegExpExecArray | null;

  while ((match = pattern.exec(text)) !== null) {
    const varName = match[1];
    const imageValue = lookupNamedImage(namedImages, varName);

    // Not an image variable: leave it inside the surrounding text.
    if (imageValue === undefined) continue;

    // Emit the text that precedes this image reference.
    if (match.index > lastIndex) {
      const textBefore = text.slice(lastIndex, match.index).trim();
      if (textBefore) {
        parts.push({ type: "text", value: textBefore });
      }
    }

    parts.push({ type: "image", name: varName, value: imageValue });
    lastIndex = match.index + match[0].length;
  }

  // Emit whatever follows the last image reference.
  if (lastIndex < text.length) {
    const remaining = text.slice(lastIndex).trim();
    if (remaining) {
      parts.push({ type: "text", value: remaining });
    }
  }

  // Nothing was produced: fall back to one text part with the original text.
  if (parts.length === 0) {
    return [{ type: "text", value: text }];
  }

  return parts;
}