Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 179 additions & 34 deletions src/nodes/shared/model-select.tsx
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
"use client";
import { useState, useRef, useEffect } from "react";
import { ChevronDown, Check, Sparkles, Lock, Loader2 } from "lucide-react";
import { useState, useRef, useEffect, useMemo } from "react";
import { ChevronDown, Check, Sparkles, Lock, Loader2, Search } from "lucide-react";
import { cn } from "@/lib/utils";
import { ScrollArea } from "@/components/ui/scroll-area";
import { SubAgentModel, MODEL_DISPLAY_NAMES, MODEL_COST_MULTIPLIER } from "@/nodes/agent/enums";
import { useModels } from "@/hooks/use-models";

Expand Down Expand Up @@ -40,31 +39,13 @@ interface ModelSelectProps {

export function ModelSelect({ value, onChange, hideInherit }: ModelSelectProps) {
const [open, setOpen] = useState(false);
const [query, setQuery] = useState("");
const [highlightedIndex, setHighlightedIndex] = useState(-1);
const containerRef = useRef<HTMLDivElement>(null);
const searchInputRef = useRef<HTMLInputElement>(null);
const listRef = useRef<HTMLDivElement>(null);
const { groups, isLoading, isDisabled } = useModels();

// Close on outside click
useEffect(() => {
if (!open) return;
function handleClick(e: MouseEvent) {
if (containerRef.current && !containerRef.current.contains(e.target as Node)) {
setOpen(false);
}
}
document.addEventListener("mousedown", handleClick);
return () => document.removeEventListener("mousedown", handleClick);
}, [open]);

// Close on Escape
useEffect(() => {
if (!open) return;
function handleKey(e: KeyboardEvent) {
if (e.key === "Escape") setOpen(false);
}
document.addEventListener("keydown", handleKey);
return () => document.removeEventListener("keydown", handleKey);
}, [open]);

// Resolve display name — prefer API data, fall back to static map, then raw value
const resolveDisplayName = (modelValue: string): string => {
if (!modelValue) return MODEL_DISPLAY_NAMES[SubAgentModel.Inherit];
Expand All @@ -84,18 +65,115 @@ export function ModelSelect({ value, onChange, hideInherit }: ModelSelectProps)
};

const displayName = resolveDisplayName(value);
const normalizedQuery = query.trim().toLowerCase();

const filteredGroups = useMemo(() => {
if (!normalizedQuery) return groups;
return groups
.map((group) => ({
...group,
models: group.models.filter((m) => {
const haystack = `${group.label} ${m.displayName} ${m.value}`.toLowerCase();
return haystack.includes(normalizedQuery);
}),
}))
.filter((group) => group.models.length > 0);
}, [groups, normalizedQuery]);

const getVisibleOptionsForQuery = (nextQuery: string) => {
const normalized = nextQuery.trim().toLowerCase();
const options: string[] = [];

const nextShowInherit = !hideInherit && (!normalized || MODEL_DISPLAY_NAMES[SubAgentModel.Inherit].toLowerCase().includes(normalized));
if (nextShowInherit) options.push(SubAgentModel.Inherit);

for (const group of groups) {
for (const model of group.models) {
const haystack = `${group.label} ${model.displayName} ${model.value}`.toLowerCase();
if (!normalized || haystack.includes(normalized)) {
options.push(model.value);
}
}
}

return options;
};

const showInherit = !hideInherit && (!normalizedQuery || MODEL_DISPLAY_NAMES[SubAgentModel.Inherit].toLowerCase().includes(normalizedQuery));

const visibleOptions = getVisibleOptionsForQuery(query);

const closeDropdown = () => {
setOpen(false);
setQuery("");
setHighlightedIndex(-1);
};

const selectModel = (nextValue: string) => {
onChange(nextValue);
closeDropdown();
};

// Find which group the selected model belongs to
const selectedGroup = groups.find((g) =>
g.models.some((m) => m.value === value)
);

// Close on outside click
useEffect(() => {
if (!open) return;
function handleClick(e: MouseEvent) {
if (containerRef.current && !containerRef.current.contains(e.target as Node)) {
closeDropdown();
}
}
document.addEventListener("mousedown", handleClick);
return () => document.removeEventListener("mousedown", handleClick);
}, [open]);

// Close on Escape
useEffect(() => {
if (!open) return;
function handleKey(e: KeyboardEvent) {
if (e.key === "Escape") {
closeDropdown();
}
}
document.addEventListener("keydown", handleKey);
return () => document.removeEventListener("keydown", handleKey);
}, [open]);

useEffect(() => {
if (!open) return;
const id = requestAnimationFrame(() => searchInputRef.current?.focus());
return () => cancelAnimationFrame(id);
}, [open]);

useEffect(() => {
if (!open || highlightedIndex < 0) return;
const id = requestAnimationFrame(() => {
const el = listRef.current?.querySelector<HTMLElement>(`[data-option-index="${highlightedIndex}"]`);
el?.scrollIntoView({ block: "nearest" });
});
return () => cancelAnimationFrame(id);
}, [open, highlightedIndex]);

return (
<div ref={containerRef} className="relative">
{/* Trigger */}
<button
type="button"
onClick={() => !isDisabled && setOpen((p) => !p)}
onClick={() => {
if (isDisabled) return;
if (open) {
closeDropdown();
return;
}
const initialOptions = getVisibleOptionsForQuery("");
const initialIndex = initialOptions.indexOf(value);
setHighlightedIndex(initialIndex >= 0 ? initialIndex : initialOptions.length > 0 ? 0 : -1);
setOpen(true);
}}
disabled={isDisabled}
className={cn(
"w-full flex items-center gap-2.5 rounded-xl px-3 py-2.5",
Expand Down Expand Up @@ -142,18 +220,75 @@ export function ModelSelect({ value, onChange, hideInherit }: ModelSelectProps)
"shadow-2xl shadow-black/50",
)}
>
<ScrollArea className="max-h-95" viewportClassName="py-1">
<div
ref={listRef}
className="max-h-80 overflow-y-auto overscroll-contain py-1"
onWheelCapture={(e) => e.stopPropagation()}
>
<div className="sticky top-0 z-10 px-2 pb-2 bg-zinc-900/95 backdrop-blur-xl">
<div className="flex items-center gap-2 rounded-lg border border-zinc-700/60 bg-zinc-800/70 px-2.5 py-2">
<Search size={13} className="text-zinc-500 shrink-0" />
<input
ref={searchInputRef}
type="text"
value={query}
onChange={(e) => {
const nextQuery = e.target.value;
setQuery(nextQuery);
const nextOptions = getVisibleOptionsForQuery(nextQuery);
const nextIndex = nextOptions.indexOf(value);
setHighlightedIndex(nextIndex >= 0 ? nextIndex : nextOptions.length > 0 ? 0 : -1);
}}
onKeyDown={(e) => {
if (e.key === "ArrowDown") {
e.preventDefault();
e.stopPropagation();
setHighlightedIndex((prev) => {
if (visibleOptions.length === 0) return -1;
return prev < 0 ? 0 : Math.min(prev + 1, visibleOptions.length - 1);
});
return;
}
if (e.key === "ArrowUp") {
e.preventDefault();
e.stopPropagation();
setHighlightedIndex((prev) => {
if (visibleOptions.length === 0) return -1;
return prev <= 0 ? 0 : prev - 1;
});
return;
}
if (e.key === "Enter") {
const nextValue = visibleOptions[highlightedIndex];
if (nextValue) {
e.preventDefault();
e.stopPropagation();
selectModel(nextValue);
}
return;
}
e.stopPropagation();
}}
placeholder="Search models..."
className="w-full bg-transparent text-sm text-zinc-100 placeholder:text-zinc-500 outline-none"
/>
</div>
</div>

{/* Inherit option */}
{!hideInherit && (
{showInherit && (
<button
type="button"
onClick={() => { onChange(SubAgentModel.Inherit); setOpen(false); }}
data-option-index={0}
onMouseEnter={() => setHighlightedIndex(0)}
onClick={() => selectModel(SubAgentModel.Inherit)}
className={cn(
"w-full flex items-center gap-2.5 px-3 py-2.5 text-sm transition-colors",
"hover:bg-violet-500/10",
value === SubAgentModel.Inherit
? "text-violet-300 bg-violet-500/15 border-b border-violet-500/20"
: "text-zinc-400 border-b border-zinc-800/80"
: "text-zinc-400 border-b border-zinc-800/80",
highlightedIndex === 0 && "bg-zinc-800/80"
)}
>
<span className="w-4.5 flex items-center justify-center shrink-0">
Expand All @@ -175,7 +310,7 @@ export function ModelSelect({ value, onChange, hideInherit }: ModelSelectProps)
)}

{/* Grouped models */}
{groups.map((group) => (
{filteredGroups.map((group) => (
<div key={group.label}>
{/* Category header */}
<div className="flex items-center gap-2 px-3 pt-3 pb-1.5">
Expand All @@ -190,17 +325,21 @@ export function ModelSelect({ value, onChange, hideInherit }: ModelSelectProps)
{group.models.map((m) => {
const isSelected = value === m.value;
const cost = MODEL_COST_MULTIPLIER[m.value];
const optionIndex = visibleOptions.indexOf(m.value);
return (
<button
key={m.value}
type="button"
onClick={() => { onChange(m.value); setOpen(false); }}
data-option-index={optionIndex}
onMouseEnter={() => setHighlightedIndex(optionIndex)}
onClick={() => selectModel(m.value)}
className={cn(
"w-full flex items-center gap-2.5 px-3 py-1.75 text-[13px] transition-colors",
"hover:bg-zinc-800/80",
isSelected
? "text-zinc-100 bg-violet-500/8"
: "text-zinc-300"
: "text-zinc-300",
highlightedIndex === optionIndex && "bg-zinc-800/80"
)}
>
{/* Fixed-size alignment spacer for the check area */}
Expand All @@ -225,7 +364,13 @@ export function ModelSelect({ value, onChange, hideInherit }: ModelSelectProps)
No models available
</div>
)}
</ScrollArea>

{!isLoading && groups.length > 0 && filteredGroups.length === 0 && normalizedQuery && (
<div className="px-3 py-4 text-zinc-500 text-sm text-center">
No models match “{query.trim()}”
</div>
)}
</div>
</div>
)}
</div>
Expand Down