diff --git a/apps/webapp/app/api/internal-proxy/steer-chat/route.ts b/apps/webapp/app/api/internal-proxy/steer-chat/route.ts
new file mode 100644
index 000000000..b106de84d
--- /dev/null
+++ b/apps/webapp/app/api/internal-proxy/steer-chat/route.ts
@@ -0,0 +1,54 @@
+import { NextRequest, NextResponse } from 'next/server';
+
+export async function POST(request: NextRequest) {
+ try {
+ const originalPayload = await request.json();
+ const apiKey = request.headers.get('X-Forwarded-Client-API-Key');
+
+ const targetUrl = 'https://www.neuronpedia.org/api/steer-chat';
+ const externalHeaders: HeadersInit = {
+ 'Content-Type': 'application/json',
+ };
+
+ if (apiKey) {
+ externalHeaders['X-API-Key'] = apiKey;
+ }
+
+ const externalResponse = await fetch(targetUrl, {
+ method: 'POST',
+ headers: externalHeaders,
+ body: JSON.stringify(originalPayload),
+ });
+
+ // Try to parse JSON, but if it fails, it might be a non-JSON error response
+ let data;
+ try {
+ data = await externalResponse.json();
+ } catch (e) {
+ // If JSON parsing fails, try to get text, it might be an HTML error page or plain text
+ const textError = await externalResponse.text();
+ // If the original request was not ok, and we couldn't parse JSON, return the text error
+ if (!externalResponse.ok) {
+ return NextResponse.json({ error: 'Failed to proxy request', details: textError }, { status: externalResponse.status || 500 });
+ }
+ // If it was ok, but not JSON (unlikely for an API), this is an unexpected situation
+ return NextResponse.json({ error: 'Unexpected response format from external API', details: textError }, { status: 500 });
+ }
+
+ if (!externalResponse.ok) {
+ // Forward the error response from the external API (already parsed as JSON)
+ return NextResponse.json(data, { status: externalResponse.status });
+ }
+
+ return NextResponse.json(data, { status: externalResponse.status });
+
+ } catch (error: any) {
+ console.error('Proxy error:', error);
+ // Check if the error is a TypeError from request.json() failing (e.g. empty body for GET)
+ // Though for POST, a body is expected. This is more a general guard.
+ if (error instanceof TypeError && error.message.includes('body stream already read') || error.message.includes('JSON Parse error')) {
+ return NextResponse.json({ error: 'Invalid request payload' }, { status: 400 });
+ }
+ return NextResponse.json({ error: 'Internal Server Error during proxying' }, { status: 500 });
+ }
+}
\ No newline at end of file
diff --git a/apps/webapp/app/api/proxy-steer-chat/route.ts b/apps/webapp/app/api/proxy-steer-chat/route.ts
new file mode 100644
index 000000000..25fd421f3
--- /dev/null
+++ b/apps/webapp/app/api/proxy-steer-chat/route.ts
@@ -0,0 +1,38 @@
+import { NextRequest, NextResponse } from 'next/server';
+
+export async function POST(request: NextRequest) {
+ try {
+ const body = await request.json();
+ const apiKey = request.headers.get('X-API-Key');
+
+ const targetUrl = 'https://www.neuronpedia.org/api/steer-chat';
+
+ const headers: HeadersInit = {
+ 'Content-Type': 'application/json',
+ };
+
+ if (apiKey) {
+ headers['X-API-Key'] = apiKey;
+ }
+
+ const externalResponse = await fetch(targetUrl, {
+ method: 'POST',
+ headers,
+ body: JSON.stringify(body),
+ });
+
+ const responseData = await externalResponse.json();
+
+ if (!externalResponse.ok) {
+ return NextResponse.json(responseData, { status: externalResponse.status });
+ }
+
+ return NextResponse.json(responseData, { status: externalResponse.status });
+ } catch (error) {
+ console.error('Proxy error:', error);
+ if (error instanceof Error) {
+ return NextResponse.json({ error: error.message }, { status: 500 });
+ }
+ return NextResponse.json({ error: 'An unknown error occurred' }, { status: 500 });
+ }
+}
\ No newline at end of file
diff --git a/apps/webapp/app/embed/feature-discovery/page.tsx b/apps/webapp/app/embed/feature-discovery/page.tsx
new file mode 100644
index 000000000..e828eca45
--- /dev/null
+++ b/apps/webapp/app/embed/feature-discovery/page.tsx
@@ -0,0 +1,9 @@
+import FeatureDiscoveryAssistant from '../../../components/tools/feature-discovery-assistant';
+
+export default function Page() {
+ return (
+
+
+
+ );
+}
diff --git a/apps/webapp/components/feature-selector/model-selector.tsx b/apps/webapp/components/feature-selector/model-selector.tsx
index a46d0e577..40f7911a4 100644
--- a/apps/webapp/components/feature-selector/model-selector.tsx
+++ b/apps/webapp/components/feature-selector/model-selector.tsx
@@ -12,6 +12,7 @@ export default function ModelSelector({
filterToRelease,
showUnlisted = false,
overrideModels,
+ id,
}: {
modelId: string;
modelIdChangedCallback: (modelId: string) => void;
@@ -19,6 +20,7 @@ export default function ModelSelector({
filterToRelease?: string | undefined;
showUnlisted?: boolean;
overrideModels?: string[];
+ id?: string;
}) {
const { globalModels, getSourceSetsForModelId, getInferenceEnabledForModel } = useGlobalContext();
@@ -33,7 +35,10 @@ export default function ModelSelector({
modelIdChangedCallback(newVal);
}}
>
-
+
MODEL
@@ -88,9 +93,8 @@ export default function ModelSelector({
diff --git a/apps/webapp/components/steer/shared/SimpleSteererLayout.tsx b/apps/webapp/components/steer/shared/SimpleSteererLayout.tsx
new file mode 100644
index 000000000..b2eb86052
--- /dev/null
+++ b/apps/webapp/components/steer/shared/SimpleSteererLayout.tsx
@@ -0,0 +1,382 @@
+import { Button } from '@/components/shadcn/button';
+import SteerChatMessage from '@/components/steer/chat-message';
+import { LoadingSpinner } from '@/components/svg/loading-spinner';
+import { ChatMessage } from '@/lib/utils/steer';
+import { ArrowUp, RotateCcw, EyeOff, Send, MousePointerClick } from 'lucide-react';
+import React, { useState, useEffect } from 'react';
+
+// type ExamplePrompt = { // Removed unused type
+// label: string;
+// text: string;
+// fullLabel?: string;
+// };
+
+interface SimpleSteererLayoutProps {
+ mode?: 'opensource' | 'default';
+ children?: React.ReactNode;
+ cappedHeight?: boolean;
+ normalEndRef: React.Ref;
+ steeredEndRef: React.Ref;
+ normalPanelTitle: string;
+ steeredPanelTitle: string;
+ initialNormalText?: string;
+ initialSteeredText?: string;
+ defaultChatMessages: ChatMessage[];
+ steeredChatMessages: ChatMessage[];
+ isTuning: boolean;
+ showNormalPanelFeature?: boolean;
+ typedInText: string;
+ onTypedInTextChange: (text: string) => void;
+ onSendChat: (overrideTypedInText?: string) => void;
+ onResetChat: () => void;
+ inputPlaceholder: string;
+ isSendDisabled: boolean;
+ isResetDisabled: boolean;
+ // examplePrompts: ExamplePrompt[]; // Removed
+ maxMessageLength: number;
+ showMoreOptions?: boolean;
+ onToggleMoreOptions?: () => void;
+ moreOptionsContent?: React.ReactNode;
+ showOptionsButton?: boolean;
+ apiKey?: string;
+ onApiKeyChange?: (key: string) => void;
+ shouldPulseSendButton?: boolean;
+}
+
+export default function SimpleSteererLayout({
+ mode = 'default',
+ children,
+ cappedHeight,
+ normalEndRef,
+ steeredEndRef,
+ normalPanelTitle,
+ steeredPanelTitle,
+ initialNormalText = "Hey, I'm normal!",
+ initialSteeredText = "Hey, I'm steered!",
+ defaultChatMessages,
+ steeredChatMessages,
+ isTuning,
+ showNormalPanelFeature = true,
+ typedInText,
+ onTypedInTextChange,
+ onSendChat,
+ onResetChat,
+ inputPlaceholder,
+ isSendDisabled,
+ isResetDisabled,
+ // examplePrompts, // Removed
+ maxMessageLength,
+ showMoreOptions,
+ onToggleMoreOptions,
+ moreOptionsContent,
+ showOptionsButton = true,
+ apiKey,
+ onApiKeyChange,
+ shouldPulseSendButton = false,
+}: SimpleSteererLayoutProps) {
+ const [isNormalPanelRevealed, setIsNormalPanelRevealed] = useState(mode === 'opensource' ? false : showNormalPanelFeature);
+
+
+ useEffect(() => {
+ if (mode === 'opensource' && !showNormalPanelFeature) {
+ setIsNormalPanelRevealed(false);
+ } else if (mode === 'default') {
+ setIsNormalPanelRevealed(showNormalPanelFeature);
+ }
+ }, [showNormalPanelFeature, mode]);
+
+ const userMessagesCount = defaultChatMessages.filter(msg => msg.role === 'user').length;
+ const userMessageLimitReached = userMessagesCount >= 3;
+
+ if (mode === 'opensource') {
+ return (
+
+
+ {showNormalPanelFeature && (
+
+
+ {userMessagesCount > 0 && (
+
+ )}
+
+
+ )}
+
+ {showNormalPanelFeature && isNormalPanelRevealed && (
+
+
+
+ {normalPanelTitle}
+
+
+
+ {!isTuning && defaultChatMessages.length === 0 && (
+
+ {initialNormalText}
+
+ )}
+
+ {isTuning && defaultChatMessages.some(msg => msg.role === 'user') &&
}
+
+
+ )}
+
+
+
+
+ {steeredPanelTitle}
+
+
+
+ {!isTuning && steeredChatMessages.length === 0 && (
+
+ {initialSteeredText}
+
+
+ Then, send a message below!
+
+ )}
+
+ {isTuning && steeredChatMessages.some(msg => msg.role === 'user') &&
}
+
+
+
+
+
+
Response cut off? Type "continue"!
+
+
{/* Changed items-center to items-stretch */}
+ {!userMessageLimitReached && (
+ <>
+ {
+ if (e.key === 'Enter' && !e.shiftKey && !isTuning && !isSendDisabled) {
+ onSendChat();
+ e.preventDefault();
+ }
+ }}
+ onChange={(e) => {
+ if (e.target.value.indexOf('\n') === -1) {
+ onTypedInTextChange(e.target.value);
+ }
+ }}
+ required
+ placeholder={inputPlaceholder}
+ // Adjusted padding/rounding to fit better with button
+ // Highlight the input field gold to indicate the user can type
+ // Remove gold border when tuning
+ className={`mt-0 w-full flex-1 resize-none rounded-lg border ${userMessagesCount === 0 ? 'border-amber-300' : 'border-gray-600'} bg-gray-800 px-4 py-2 text-left text-xs font-medium text-slate-200 placeholder-slate-500 shadow transition-all focus:${isTuning ? 'border-red-700' : 'border-amber-700'} focus:shadow focus:outline-none focus:ring-0 disabled:bg-gray-700 disabled:text-slate-500 sm:text-[13px]`}
+ />
+
+ >
+ )}
+
+
+
+ {/* Removed example prompts rendering */}
+
+
+ {
+ showOptionsButton && onToggleMoreOptions && (
+
+
+
+ )
+ }
+ {
+ showMoreOptions && (
+
+ {apiKey !== undefined && onApiKeyChange && (
+ onApiKeyChange && onApiKeyChange(e.target.value)}
+ disabled={isTuning}
+ className="w-full rounded-md border border-gray-600 bg-gray-700 px-3 py-2 text-sm text-slate-200 placeholder-slate-400 focus:border-red-700 focus:outline-none focus:ring-1 focus:ring-red-700 disabled:opacity-50"
+ />
+ )}
+ {children}
+
+ )
+ }
+
+ );
+ }
+
+ return (
+
+
+
+
+ {isNormalPanelRevealed && (
+
+
+
+ {normalPanelTitle}
+
+
+
+ {!isTuning && defaultChatMessages.length === 0 && (
+
{initialNormalText}
Get started below.
+ )}
+
+ {isTuning && defaultChatMessages.some(msg => msg.role === 'user') &&
}
+
+
+ )}
+
+
+
+ {steeredPanelTitle}
+
+
+
+ {!isTuning && steeredChatMessages.length === 0 && (
+
+ {initialSteeredText}
+
+ Get started below.
+
+ )}
+
+ {isTuning && steeredChatMessages.some(msg => msg.role === 'user') &&
}
+
+
+
+
+
+
+
+
{ if (e.key === 'Enter' && !e.shiftKey && !isTuning && !isSendDisabled) { onSendChat(); e.preventDefault(); } }} onChange={(e) => { if (e.target.value.indexOf('\n') === -1) { onTypedInTextChange(e.target.value); } }} required placeholder={inputPlaceholder}
+ // Highlight the input field gold to indicate the user can type
+ // Remove gold border when tuning
+ className={`mt-0 w-full flex-1 resize-none rounded-full border ${isTuning ? 'border-slate-300' : 'border-amber-500'} px-5 py-3.5 pr-10 text-left text-xs font-medium text-slate-800 placeholder-slate-400 shadow transition-all focus:${isTuning ? 'border-slate-300' : 'border-amber-700'} focus:shadow focus:outline-none focus:ring-0 disabled:bg-slate-200 sm:text-[13px]`} />
+
+
{ if (!isTuning && !isSendDisabled) { onSendChat(); } }} className={`h-8 w-8 rounded-full ${isTuning || isSendDisabled ? 'bg-slate-400' : 'bg-gBlue hover:bg-gBlue/80'} p-1.5 text-white`} />
+
+
+
+
+ {/* Removed example prompts rendering */}
+
+
+
+
+ {children && (
+
+ {children}
+
+ )}
+ {showOptionsButton && onToggleMoreOptions && (
+
+ )}
+ {showMoreOptions && (
+ <>
+ {apiKey !== undefined && onApiKeyChange && (
+
+ onApiKeyChange && onApiKeyChange(e.target.value)}
+ disabled={isTuning}
+ className="w-full rounded-md border border-slate-300 bg-white px-3 py-2 text-sm text-slate-700 placeholder-slate-400 focus:border-sky-500 focus:outline-none focus:ring-1 focus:ring-sky-500 disabled:opacity-50 disabled:bg-slate-100"
+ />
+
+ )}
+ {moreOptionsContent}
+ >
+ )}
+
+
+ );
+}
\ No newline at end of file
diff --git a/apps/webapp/components/steer/shared/SteererMoreOptionsContent.tsx b/apps/webapp/components/steer/shared/SteererMoreOptionsContent.tsx
new file mode 100644
index 000000000..a71a2c5ca
--- /dev/null
+++ b/apps/webapp/components/steer/shared/SteererMoreOptionsContent.tsx
@@ -0,0 +1,197 @@
+'use client';
+
+import {
+ STEER_N_COMPLETION_TOKENS_MAX,
+ STEER_STRENGTH_MULTIPLIER_MAX,
+ STEER_TEMPERATURE_MAX,
+} from '@/lib/utils/steer';
+
+interface SteererMoreOptionsContentProps {
+ strMultiple: number;
+ setStrMultiple: (value: number) => void;
+ steerSpecialTokens: boolean;
+ setSteerSpecialTokens: (value: boolean) => void;
+ steerTokens: number;
+ setSteerTokens: (value: number) => void;
+ temperature: number;
+ setTemperature: (value: number) => void;
+ freqPenalty: number;
+ setFreqPenalty: (value: number) => void;
+ seed: number;
+ setSeed: (value: number) => void;
+ randomSeed: boolean;
+}
+
+export default function SteererMoreOptionsContent({
+ strMultiple,
+ setStrMultiple,
+ steerSpecialTokens,
+ setSteerSpecialTokens,
+ steerTokens,
+ setSteerTokens,
+ temperature,
+ setTemperature,
+ freqPenalty,
+ setFreqPenalty,
+ seed,
+ setSeed,
+ randomSeed,
+}: SteererMoreOptionsContentProps) {
+ return (
+ <>
+
+
+ Steering Method
+
+ {/*
{
+ alert('Activation not available yet.');
+ return;
+ setSteeringMethod(value as SteeringMethod);
+ }}
+ aria-label="steering method"
+ >
+
+ Activation
+
+
+ SAE
+
+ */}
+
+
+ Advanced
+
+
+
+
+ Strength Multiple
+
+
{
+ if (
+ parseFloat(e.target.value) < 0 ||
+ parseFloat(e.target.value) > STEER_STRENGTH_MULTIPLIER_MAX
+ ) {
+ alert(`Strength multiplier must be >= 0 and <= ${STEER_STRENGTH_MULTIPLIER_MAX}`);
+ } else {
+ setStrMultiple(parseFloat(e.target.value));
+ }
+ }}
+ className="max-w-[80px] flex-1 rounded-md border-green-400 py-1 text-center text-xs text-green-800"
+ value={strMultiple}
+ />
+
+
+
+ Steer Special Tokens
+
+
+ {
+ setSteerSpecialTokens(e.target.checked);
+ }}
+ type="checkbox"
+ checked={steerSpecialTokens}
+ className="h-5 w-5 cursor-pointer rounded border-green-400 bg-green-100 py-1 text-center text-xs text-green-800 checked:bg-green-600 checked:text-white"
+ />
+
+
+
+
+
+
+
+ Generation Settings
+
+
+
+
+ Num Tokens
+
+
{
+ if (parseInt(e.target.value, 10) > STEER_N_COMPLETION_TOKENS_MAX) {
+ alert(
+ `Due to compute constraints, the current allowed max tokens is: ${STEER_N_COMPLETION_TOKENS_MAX}`,
+ );
+ } else {
+ setSteerTokens(parseInt(e.target.value, 10));
+ }
+ }}
+ className="max-w-[80px] flex-1 rounded-md border-amber-400 py-1 text-center text-xs text-amber-800"
+ value={steerTokens}
+ />
+
+
+
+ Temperature
+
+
{
+ if (parseFloat(e.target.value) > STEER_TEMPERATURE_MAX || parseFloat(e.target.value) < 0) {
+ alert(`Temperature must be >= 0 and <= ${STEER_TEMPERATURE_MAX}`);
+ } else {
+ setTemperature(parseFloat(e.target.value));
+ }
+ }}
+ className="max-w-[80px] flex-1 rounded-md border-amber-400 py-1 text-center text-xs text-amber-800"
+ value={temperature}
+ />
+
+
+
+ Freq Penalty
+
+
{
+ if (parseFloat(e.target.value) > 2 || parseFloat(e.target.value) < -2) {
+ alert('Freq penalty must be >= -2 and <= 2');
+ } else {
+ setFreqPenalty(parseFloat(e.target.value));
+ }
+ }}
+ className="max-w-[80px] flex-1 rounded-md border-amber-400 py-1 text-center text-xs text-amber-800"
+ value={freqPenalty}
+ />
+
+
+
Seed
+
{
+ if (parseInt(e.target.value, 10) > 100000000 || parseInt(e.target.value, 10) < -100000000) {
+ alert('Seed must be >= -100000000 and <= 100000000');
+ } else {
+ setSeed(parseInt(e.target.value, 10));
+ }
+ }}
+ className="max-w-[80px] flex-1 rounded-md border-amber-400 py-1 text-center text-xs text-amber-800 disabled:bg-amber-200 disabled:text-amber-400"
+ value={seed}
+ />
+
+
+
+
+ >
+ );
+}
\ No newline at end of file
diff --git a/apps/webapp/components/steer/steerer-simple.tsx b/apps/webapp/components/steer/steerer-simple.tsx
index 542bf13f5..2cd0e5bcd 100644
--- a/apps/webapp/components/steer/steerer-simple.tsx
+++ b/apps/webapp/components/steer/steerer-simple.tsx
@@ -6,772 +6,402 @@
'use client';
-import { SteerResultChat } from '@/app/api/steer-chat/route';
-import CustomTooltip from '@/components/custom-tooltip';
-import { useGlobalContext } from '@/components/provider/global-provider';
-import { Button } from '@/components/shadcn/button';
-import SteerChatMessage from '@/components/steer/chat-message';
-import { LoadingSpinner } from '@/components/svg/loading-spinner';
-import {
- ChatMessage,
- FeaturePreset,
- STEER_FREQUENCY_PENALTY,
- STEER_MAX_PROMPT_CHARS,
- STEER_N_COMPLETION_TOKENS_MAX,
- STEER_SEED,
- STEER_SPECIAL_TOKENS,
- STEER_STRENGTH_MULTIPLIER,
- STEER_STRENGTH_MULTIPLIER_MAX,
- STEER_TEMPERATURE,
- STEER_TEMPERATURE_MAX,
- SteerFeature,
- SteerPreset,
-} from '@/lib/utils/steer';
+// React
+import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; // useEffect, useRef added
+
+// External Libraries
import * as Select from '@radix-ui/react-select';
import * as Slider from '@radix-ui/react-slider';
-import { ArrowUp, ChevronDown, ChevronDownIcon, ChevronUpIcon, RotateCcw } from 'lucide-react';
-import { useEffect, useRef, useState } from 'react';
+import { ChevronDown, ChevronDownIcon, ChevronUpIcon } from 'lucide-react';
+
+// Application-specific
+import { SteerResultChat } from '@/app/api/steer-chat/route';
+import CustomTooltip from '@/components/custom-tooltip';
+import { useGlobalContext } from '@/components/provider/global-provider'; // Still needed for loadSavedSteerOutput's error
+import SimpleSteererLayout from '@/components/steer/shared/SimpleSteererLayout';
+import SteererMoreOptionsContent from '@/components/steer/shared/SteererMoreOptionsContent';
+import { useSimpleSteererLogic } from '@/hooks/useSimpleSteererLogic'; // Import the hook
+import { ChatMessage, FeaturePreset, SteerFeature } from '@/lib/utils/steer';
const MAX_MESSAGE_LENGTH_CHARS = 128;
+// Define a stable empty array for the default for excludedPresetNames
+const DEFAULT_EXCLUDED_PRESET_NAMES: readonly string[] = Object.freeze([]);
+
export default function SteererSimple({
- initialModelId,
+ initialModelId: initialModelIdFromProps, // Renamed to avoid conflict if we had local modelId state
cappedHeight,
showOptionsButton = true,
- excludedPresetNames = [],
+ excludedPresetNames = DEFAULT_EXCLUDED_PRESET_NAMES, // USE THE STABLE CONSTANT
}: {
initialModelId: string;
cappedHeight?: boolean;
showOptionsButton?: boolean;
- excludedPresetNames?: string[];
+ excludedPresetNames?: readonly string[]; // Ensure type matches constant
}) {
+ const [strengthLevel, setStrengthLevel] = useState(1);
+ const { showToastServerError } = useGlobalContext(); // For loadSavedSteerOutput
- const [isInitialPageLoad, setInitialPageLoad] = useState(true); // track this
+ // transformFeaturesForApi callback specific to this component
+ const transformFeaturesForApiCallback = useCallback(
+ (featuresToTransform: SteerFeature[] /* , strengthConfigIgnored */) =>
+ featuresToTransform.map((f) => ({
+ modelId: f.modelId,
+ layer: f.layer,
+ index: f.index,
+ explanation: f.explanation,
+ strength: (f.strength || 1) * strengthLevel, // Uses strengthLevel from component's state
+ })),
+ [strengthLevel],
+ );
- // eslint-disable-next-line @typescript-eslint/no-unused-vars
- const [modelId, setModelId] = useState(initialModelId);
- const [featurePresets, setFeaturePresets] = useState([]);
+ // Memoize initialStrengthConfig
+ const stableInitialStrengthConfig = useMemo(() => ({ strengthLevel }), [strengthLevel]);
- // INFO: on pageLoad, will crash if no initial model
- if (isInitialPageLoad && !initialModelId) {
- setModelId('gemma-2-2b-it')
- setInitialPageLoad(false)
- }
+ // Ref to hold the actual logic for onPresetFeaturesSelected
+ const onPresetFeaturesSelectedLogicRef = useRef<
+ ((currentSelectedFeatures: SteerFeature[], preset?: FeaturePreset) => Promise<{ defaultChatMessages?: ChatMessage[]; steeredChatMessages?: ChatMessage[] } | void>) | null
+ >(null);
+ const {
+ // modelId, // Use if model can be changed by this component, otherwise initialModelIdFromProps is enough for the hook
+ // setModelId, // If model can be changed
+ featurePresets,
+ selectedFeatures,
+ // setSelectedFeatures, // Managed by handlePresetChange
+ defaultChatMessages,
+ // setDefaultChatMessages, // Managed by hook's functions
+ steeredChatMessages,
+ // setSteeredChatMessages, // Managed by hook's functions
+ typedInText,
+ setTypedInText,
+ isTuning,
+ steerTokens,
+ setSteerTokens,
+ temperature,
+ setTemperature,
+ freqPenalty,
+ setFreqPenalty,
+ strMultiple,
+ setStrMultiple,
+ seed,
+ setSeed,
+ randomSeed,
+ // setRandomSeed, // No longer used in this component
+ steerSpecialTokens,
+ setSteerSpecialTokens,
+ showMoreOptions,
+ normalEndRef,
+ steeredEndRef,
+ isLoadingPresets,
+ // loadPresets, // Called internally by the hook
+ sendChat,
+ resetChatAndMessages,
+ handlePresetChange,
+ toggleMoreOptions,
+ // Destructure setters needed by loadSavedSteerOutput
+ setIsTuning, // Correctly destructure setIsTuning
+ setDefaultChatMessages: hookSetDefaultChatMessages, // Keep aliases for clarity in loadSavedSteerOutputRef if preferred
+ setSteeredChatMessages: hookSetSteeredChatMessages,
+ setTemperature: hookSetTemperature,
+ setSteerTokens: hookSetSteerTokens,
+ setFreqPenalty: hookSetFreqPenalty,
+ setSeed: hookSetSeed,
+ setStrMultiple: hookSetStrMultiple,
+ setSteerSpecialTokens: hookSetSteerSpecialTokens,
+ } = useSimpleSteererLogic<{ strengthLevel: number }>({
+ initialModelId: initialModelIdFromProps,
+ presetsApiEndpoint: '/api/steer/presets',
+ excludedPresetNames, // This will now be stable if defaulted
+ initialStrengthConfig: stableInitialStrengthConfig, // USE THE MEMOIZED OBJECT
+ transformFeaturesForApi: transformFeaturesForApiCallback,
+ // Pass a stable callback that invokes the logic from the ref
+ onPresetFeaturesSelected: useCallback(async (_currentSelectedFeatures: SteerFeature[], preset?: FeaturePreset) => {
+ if (onPresetFeaturesSelectedLogicRef.current) {
+ return onPresetFeaturesSelectedLogicRef.current(_currentSelectedFeatures, preset);
+ }
+ // Default return if ref is not set, though it should be by the time this is called
+ return { defaultChatMessages: [], steeredChatMessages: [] };
+ }, []), // This callback is stable
+ });
- function loadPresets() {
- fetch('/api/steer/presets', {
- method: 'POST',
- headers: { 'Content-Type': 'application/json' },
- body: JSON.stringify({
- modelId,
- }),
- })
- .then((response) => response.json())
- .then((data: SteerPreset) => {
- // Filter out any feature presets that have isUserVector set to true
- const filteredPresets = data.featurePresets.filter((preset) => {
- if (preset.isUserVector) {
- return false;
- }
- if (excludedPresetNames.includes(preset.name)) {
- return false;
- }
- return true;
- });
- setFeaturePresets(filteredPresets);
- })
- .catch((error) => {
- console.error(`error loading presets: ${error}`);
- });
- }
+ // Define loadSavedSteerOutput using setters from the hook
+ const loadSavedSteerOutput = useCallback(
+ async (steerOutputId: string) => {
+ if (!steerOutputId) {
+ hookSetDefaultChatMessages([]);
+ hookSetSteeredChatMessages([]);
+ return { defaultChatMessages: [], steeredChatMessages: [] };
+ }
+ setIsTuning(true);
- const [defaultChatMessages, setDefaultChatMessages] = useState([]);
- const [steeredChatMessages, setSteeredChatMessages] = useState([]);
- const [typedInText, setTypedInText] = useState('');
- const [steerTokens, setSteerTokens] = useState(48);
- const [temperature, setTemperature] = useState(STEER_TEMPERATURE);
- const [freqPenalty, setFreqPenalty] = useState(STEER_FREQUENCY_PENALTY);
- const [strMultiple, setStrMultiple] = useState(STEER_STRENGTH_MULTIPLIER);
- const [seed, setSeed] = useState(STEER_SEED);
- // eslint-disable-next-line @typescript-eslint/no-unused-vars
- const [randomSeed, setRandomSeed] = useState(false);
- const [steerSpecialTokens, setSteerSpecialTokens] = useState(STEER_SPECIAL_TOKENS);
- const [selectedFeatures, setSelectedFeatures] = useState([]);
- const [isTuning, setIsTuning] = useState(false);
- const { showToastServerError } = useGlobalContext();
- const [strengthLevel, setStrengthLevel] = useState(1);
- const normalEndRef = useRef(null);
- const steeredEndRef = useRef(null);
+ try {
+ const response = await fetch(`/api/steer-load`, {
+ method: 'POST',
+ headers: { 'Content-Type': 'application/json' },
+ body: JSON.stringify({ steerOutputId }),
+ });
- useEffect(() => {
- if (featurePresets.length > 0) {
- setSelectedFeatures(featurePresets[0].features);
- }
- }, [featurePresets]);
+ if (!response.ok) {
+ // eslint-disable-next-line no-console
+ console.error(`Error loading saved steer output! Status: ${response.status}`);
+ setIsTuning(false);
+ hookSetDefaultChatMessages([]);
+ hookSetSteeredChatMessages([]);
+ return { defaultChatMessages: [], steeredChatMessages: [] };
+ }
- function reset() {
- setDefaultChatMessages([]);
- setSteeredChatMessages([]);
- setTypedInText('');
- }
+ const resp = (await response.json()) as SteerResultChat | null;
- async function loadSavedSteerOutput(steerOutputId: string) {
- setIsTuning(true);
- reset();
- await fetch(`/api/steer-load`, {
- method: 'POST',
- headers: { 'Content-Type': 'application/json' },
- body: JSON.stringify({
- steerOutputId,
- }),
- })
- .then((response) => {
- if (response.status !== 200) {
- // INFO: DO NOT alert to user; this error only indicates that the
- // hardcoded savedSteerOutput did not exist. Messages still work!
- console.error("Error loading saved steer output!");
- return null;
- }
- return response.json();
- })
- .then((resp: SteerResultChat | null) => {
if (resp === null) {
setIsTuning(false);
- return;
+ hookSetDefaultChatMessages([]);
+ hookSetSteeredChatMessages([]);
+ return { defaultChatMessages: [], steeredChatMessages: [] };
}
+
setIsTuning(false);
if (resp.settings) {
- setTemperature(resp.settings.temperature);
- setSteerTokens(resp.settings.n_tokens);
- setFreqPenalty(resp.settings.freq_penalty);
- setSeed(resp.settings.seed);
- setStrMultiple(resp.settings.strength_multiplier);
- setSteerSpecialTokens(resp.settings.steer_special_tokens);
- }
- // if chat template is null, we need to convert it (it's an old thing)
- if (resp.DEFAULT?.chatTemplate) {
- setDefaultChatMessages(resp.DEFAULT?.chatTemplate || []);
- setSteeredChatMessages(resp.STEERED?.chatTemplate || []);
- }
- })
- .catch((error) => {
- showToastServerError();
- setIsTuning(false);
- console.error(error);
- });
- }
-
- useEffect(() => {
- if (selectedFeatures.length > 0) {
- // find the featurepreset that matches the selected features
- const featPreset = featurePresets.find((p) =>
- p.features.find((f) =>
- selectedFeatures.find((sf) => sf.modelId === f.modelId && sf.layer === f.layer && sf.index === f.index),
- ),
- );
- loadSavedSteerOutput(featPreset?.exampleSteerOutputId || '');
- }
- }, [selectedFeatures]);
-
- function sendChat(overrideTypedInText?: string) {
- setIsTuning(true);
-
- const newDefaultChatMessages: ChatMessage[] = [
- ...defaultChatMessages,
- { content: overrideTypedInText || typedInText, role: 'user' },
- ];
- const newSteeredChatMessages: ChatMessage[] = [
- ...steeredChatMessages,
- { content: overrideTypedInText || typedInText, role: 'user' },
- ];
- // add to the chat messages (it will show up on UI as we load it)
- setDefaultChatMessages(newDefaultChatMessages);
- setSteeredChatMessages(newSteeredChatMessages);
-
- // calculate the number of characters in all the chat messages
- const defaultPromptToSendChars = newDefaultChatMessages.map((m) => m.content).join('').length;
- const steeredPromptToSendChars = newSteeredChatMessages.map((m) => m.content).join('').length;
-
- // check for character limit
- if (defaultPromptToSendChars >= STEER_MAX_PROMPT_CHARS || steeredPromptToSendChars >= STEER_MAX_PROMPT_CHARS) {
- alert('Sorry, we limit the length of each chat conversation.\nPlease click Reset to start a new conversation.');
- setIsTuning(false);
- return;
- }
-
- const selectedFeaturesStrengthOverridden = selectedFeatures.map((f) => ({
- modelId: f.modelId,
- layer: f.layer,
- index: f.index,
- explanation: f.explanation,
- strength: f.strength * strengthLevel,
- }));
- console.log(`steering with: ${JSON.stringify(selectedFeaturesStrengthOverridden)}`);
-
- // send the chat messages to the backend
- fetch(`/api/steer-chat`, {
- method: 'POST',
- headers: { 'Content-Type': 'application/json' },
- body: JSON.stringify({
- defaultChatMessages: newDefaultChatMessages,
- steeredChatMessages: newSteeredChatMessages,
- modelId,
- features: selectedFeaturesStrengthOverridden,
- temperature,
- n_tokens: steerTokens,
- freq_penalty: freqPenalty,
- seed: randomSeed ? Math.floor(Math.random() * 200000000 - 100000000) : seed,
- strength_multiplier: strMultiple,
- steer_special_tokens: steerSpecialTokens,
- }),
- })
- .then((response) => {
- if (response.status === 429 || response.status === 405) {
- alert('Sorry, we are limiting each user to 60 messages per hour. Please try again later.');
- console.log(response);
- return null;
- }
- if (response.status !== 200) {
- alert('Sorry, your message could not be sent at this time. Please try again later.');
- console.log(response);
- return null;
- }
- return response.json();
- })
- // check the response code
- .then((resp: SteerResultChat | null) => {
- if (resp === null) {
- // remove last message from chat messages UI
- setDefaultChatMessages(newDefaultChatMessages.slice(0, -1));
- setSteeredChatMessages(newSteeredChatMessages.slice(0, -1));
- setIsTuning(false);
- } else {
- setDefaultChatMessages(resp.DEFAULT?.chatTemplate || []);
- setSteeredChatMessages(resp.STEERED?.chatTemplate || []);
- setIsTuning(false);
- setTypedInText('');
+ hookSetTemperature(resp.settings.temperature);
+ hookSetSteerTokens(resp.settings.n_tokens);
+ hookSetFreqPenalty(resp.settings.freq_penalty);
+ hookSetSeed(resp.settings.seed);
+ hookSetStrMultiple(resp.settings.strength_multiplier);
+ hookSetSteerSpecialTokens(resp.settings.steer_special_tokens);
}
- })
- .catch((error) => {
+ const newDefaultMessages = resp.DEFAULT?.chatTemplate || [];
+ const newSteeredMessages = resp.STEERED?.chatTemplate || [];
+ hookSetDefaultChatMessages(newDefaultMessages);
+ hookSetSteeredChatMessages(newSteeredMessages);
+ return { defaultChatMessages: newDefaultMessages, steeredChatMessages: newSteeredMessages };
+ } catch (error) {
showToastServerError();
setIsTuning(false);
- setDefaultChatMessages(newDefaultChatMessages.slice(0, -1));
- setSteeredChatMessages(newSteeredChatMessages.slice(0, -1));
+ // eslint-disable-next-line no-console
console.error(error);
- });
- }
-
- useEffect(() => {
- loadPresets();
- setSeed(STEER_SEED);
- }, [modelId]);
-
- const scrollToBottom = () => {
- if (normalEndRef.current) {
- normalEndRef.current?.scrollTo({
- top: normalEndRef.current.scrollHeight,
- behavior: 'smooth',
- });
- }
- if (steeredEndRef.current) {
- steeredEndRef.current?.scrollTo({
- top: steeredEndRef.current.scrollHeight,
- behavior: 'smooth',
- });
- }
- };
+ hookSetDefaultChatMessages([]);
+ hookSetSteeredChatMessages([]);
+ return { defaultChatMessages: [], steeredChatMessages: [] };
+ }
+ },
+ [
+ setIsTuning,
+ hookSetDefaultChatMessages,
+ hookSetSteeredChatMessages,
+ hookSetTemperature,
+ hookSetSteerTokens,
+ hookSetFreqPenalty,
+ hookSetSeed,
+ hookSetStrMultiple,
+ hookSetSteerSpecialTokens,
+ showToastServerError,
+ ],
+ );
+ // useEffect to update the ref with the latest logic
useEffect(() => {
- if (steeredChatMessages.length > 0 || defaultChatMessages.length > 0) {
- scrollToBottom();
- }
- }, [steeredChatMessages, defaultChatMessages]);
-
- const [showMoreOptions, setShowMoreOptions] = useState(false);
-
- enum SteeringMethod {
- Activation = 'Activation',
- SAE = 'SAE',
- }
- // eslint-disable-next-line @typescript-eslint/no-unused-vars
- const [steeringMethod, setSteeringMethod] = useState(SteeringMethod.SAE);
+ onPresetFeaturesSelectedLogicRef.current = async (_currentSelectedFeatures: SteerFeature[], preset?: FeaturePreset) => {
+ if (preset?.exampleSteerOutputId && loadSavedSteerOutput) {
+ return loadSavedSteerOutput(preset.exampleSteerOutputId);
+ }
+ // If no example output, the hook's handlePresetChange already clears messages.
+ // Or, we can explicitly return empty messages.
+ return { defaultChatMessages: [], steeredChatMessages: [] };
+ };
+ }, [loadSavedSteerOutput]); // This effect depends on loadSavedSteerOutput
+
+ // const examplePrompts = [ // Removed as no longer used
+ // { label: 'Tell me about yourself.', text: 'Tell me about yourself.' },
+ // { label: 'Tell me a one line story.', text: 'Tell me a one line story.' },
+ // { label: 'Write a haiku.', text: 'Write a haiku.' },
+ // { label: 'I wish...', text: 'Complete this: I wish...', fullLabel: 'Complete this: I wish...' },
+ // ];
+
+ const steeringControlsJsx = (
+ <>
+
+
+
+ Feature
+
+
+
+ p.features.length === selectedFeatures.length && // Basic check for same number of features
+ p.features.every(pf => selectedFeatures.some(sf => sf.modelId === pf.modelId && sf.layer === pf.layer && sf.index === pf.index))
+ )?.name || (featurePresets.length > 0 && selectedFeatures.length === 0 ? featurePresets[0].name : '') // Default to first if nothing selected
+ }
+ onValueChange={(presetName) => {
+ handlePresetChange(presetName); // Use hook's handler
+ }}
+ >
+
+
+
+
+
+
+
+
+
+
+
+
+ {featurePresets.length > 0 ? (
+ featurePresets.map((preset) => (
+
+
+ {preset.name}
+
+
+ ))
+ ) : (
+
+
+ {isLoadingPresets ? 'Loading...' : 'No presets found'}
+
+
+ )}
+
+
+
+
+
+
+
+
- return (
-
-
-
-
-
-
-
- Normal Gemma
+
+ Strength
+
+
+
{
+ setStrengthLevel(value[0]);
+ }}
+ className="relative mt-1 flex h-5 w-full flex-1 cursor-pointer items-center"
+ >
+
+
-
- {!isTuning && steeredChatMessages.length === 0 && (
-
- Hey, {`I'm normal Gemma!`}
-
- Get started below.
-
- )}
-
- {isTuning &&
}
-
-
-
-
-
-
- {!isTuning && steeredChatMessages.length === 0 && (
-
- Hey, {`I'm steered Gemma!`}
-
- Get started below.
-
- )}
-
- {isTuning &&
}
-
-
-
-
-
-
-
-
{
- if (e.key === 'Enter' && !e.shiftKey && !isTuning) {
- sendChat();
- e.preventDefault();
- }
- }}
- onChange={(e) => {
- // if it's return, submit
- if (e.target.value.indexOf('\n') === -1) {
- setTypedInText(e.target.value);
- }
- }}
- required
- placeholder="Ask Gemma something..."
- className="mt-0 w-full flex-1 resize-none rounded-full border border-slate-300 px-5 py-3.5 pr-10 text-left text-xs font-medium text-slate-800 placeholder-slate-400 shadow transition-all focus:border-slate-300 focus:shadow focus:outline-none focus:ring-0 disabled:bg-slate-200 sm:text-[13px]"
- />
-
-
{
- if (!isTuning) {
- sendChat();
- }
- }}
- className={`h-8 w-8 rounded-full ${isTuning ? 'bg-slate-400' : 'bg-gBlue hover:bg-gBlue/80'
- } p-1.5 text-white`}
- />
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- Feature
-
-
- 0 ? featurePresets[0].name : 'Loading...'}
- value={
- // find where selectedFeature matches a feature in the presets
- featurePresets.find((p) =>
- selectedFeatures.find((f) =>
- p.features.find((fp) => fp.modelId === f.modelId && fp.layer === f.layer && fp.index === f.index),
- ),
- )?.name || 'Loading...'
- }
- onValueChange={(presetName) => {
- const feat = featurePresets.find((p) => p.name === presetName);
- setSelectedFeatures(feat?.features || []);
- }}
- >
-
-
-
-
-
-
-
-
-
-
-
-
- {featurePresets.length > 0 ? (
- featurePresets.map((preset) => (
-
-
- {preset.name}
-
-
- ))
- ) : (
-
-
- Loading...
-
-
- )}
-
-
-
-
-
-
-
-
-
-
- Strength
-
-
-
{
- setStrengthLevel(value[0]);
- }}
- className="relative mt-1 flex h-5 w-full flex-1 cursor-pointer items-center"
- >
-
-
-
-
-
- 0️⃣
}>0️⃣ No Steering
-
-
-
- {' '}
-
-
- ⚖️
}>⚖️ Medium Strength
-
-
{' '}
-
-
- 🤯
}>🤯 Quite A Bit
-
-
{' '}
-
-
-
-
- {`+${strengthLevel.toFixed(1)}x`}
-
-
- {/*
*/}
-
-
- {showMoreOptions && (
-
-
- Steering Method
-
- {/*
{
- alert('Activation not available yet.');
- return;
- setSteeringMethod(value as SteeringMethod);
- }}
- aria-label="steering method"
- >
-
- Activation
-
-
- SAE
-
- */}
-
-
- Advanced
-
-
-
-
- Strength Multiple
+
+
-
{
- if (
- parseFloat(e.target.value) < 0 ||
- parseFloat(e.target.value) > STEER_STRENGTH_MULTIPLIER_MAX
- ) {
- alert(`Strength multiplier must be >= 0 and <= ${STEER_STRENGTH_MULTIPLIER_MAX}`);
- } else {
- setStrMultiple(parseFloat(e.target.value));
- }
- }}
- className="max-w-[80px] flex-1 rounded-md border-green-400 py-1 text-center text-xs text-green-800"
- value={strMultiple}
- />
-
-
- Steer Special Tokens
+
-
- {
- setSteerSpecialTokens(e.target.checked);
- }}
- type="checkbox"
- checked={steerSpecialTokens}
- className="h-5 w-5 cursor-pointer rounded border-green-400 bg-green-100 py-1 text-center text-xs text-green-800 checked:bg-green-600 checked:text-white"
- />
+
{' '}
+
+
+ ⚖️
}>⚖️ Medium Strength
-
-
-
- )}
-
- {showMoreOptions && (
-
-
-
- Generation Settings
-
-
-
-
- Num Tokens
+
{' '}
+
-
{
- if (parseInt(e.target.value, 10) > STEER_N_COMPLETION_TOKENS_MAX) {
- alert(
- `Due to compute constraints, the current allowed max tokens is: ${STEER_N_COMPLETION_TOKENS_MAX}`,
- );
- } else {
- setSteerTokens(parseInt(e.target.value, 10));
- }
- }}
- className="max-w-[80px] flex-1 rounded-md border-amber-400 py-1 text-center text-xs text-amber-800"
- value={steerTokens}
- />
-
-
-
- Temperature
+
+ 🤯
}>🤯 Quite A Bit
-
{
- if (parseFloat(e.target.value) > STEER_TEMPERATURE_MAX || parseFloat(e.target.value) < 0) {
- alert(`Temperature must be >= 0 and <= ${STEER_TEMPERATURE_MAX}`);
- } else {
- setTemperature(parseFloat(e.target.value));
- }
- }}
- className="max-w-[80px] flex-1 rounded-md border-amber-400 py-1 text-center text-xs text-amber-800"
- value={temperature}
- />
-
-
-
- Freq Penalty
-
-
{
- if (parseFloat(e.target.value) > 2 || parseFloat(e.target.value) < -2) {
- alert('Freq penalty must be >= -2 and <= 2');
- } else {
- setFreqPenalty(parseFloat(e.target.value));
- }
- }}
- className="max-w-[80px] flex-1 rounded-md border-amber-400 py-1 text-center text-xs text-amber-800"
- value={freqPenalty}
- />
-
-
-
Seed
-
{
- if (parseInt(e.target.value, 10) > 100000000 || parseInt(e.target.value, 10) < -100000000) {
- alert('Seed must be >= -100000000 and <= 100000000');
- } else {
- setSeed(parseInt(e.target.value, 10));
- }
- }}
- className="max-w-[80px] flex-1 rounded-md border-amber-400 py-1 text-center text-xs text-amber-800 disabled:bg-amber-200 disabled:text-amber-400"
- value={seed}
- />
-
-
-
-
- )}
- {showOptionsButton ? (
- <>
-
- {/*
-
*/}
- >
- ) : (
-
- )}
+
+
+
+ {`+${strengthLevel.toFixed(1)}x`}
+
+
+
-
+ >
+ );
+
+ // const handleToggleMoreOptions = useCallback(() => { // Now from hook
+ // setShowMoreOptions((prev) => !prev);
+ // }, [setShowMoreOptions]);
+
+ return (
+
+ }
+ showOptionsButton={showOptionsButton} // Prop
+ >
+ {steeringControlsJsx}
+
);
}
diff --git a/apps/webapp/components/tools/feature-discovery-assistant.tsx b/apps/webapp/components/tools/feature-discovery-assistant.tsx
new file mode 100644
index 000000000..a33194cc9
--- /dev/null
+++ b/apps/webapp/components/tools/feature-discovery-assistant.tsx
@@ -0,0 +1,417 @@
+/* eslint-disable jsx-a11y/label-has-associated-control -- customComponent doc examples don't work :( */
+
+'use client';
+
+import ModelSelector from '@/components/feature-selector/model-selector';
+import { useGlobalContext } from '@/components/provider/global-provider';
+import { Button } from '@/components/shadcn/button';
+import { SearchExplanationsResponse } from '@/lib/utils/general';
+import {
+ SteerFeature,
+ STEER_TEMPERATURE,
+ STEER_FREQUENCY_PENALTY,
+ STEER_SEED,
+ STEER_STRENGTH_MULTIPLIER,
+ STEER_SPECIAL_TOKENS,
+ ChatMessage as SteerChatMessage, // Renaming to avoid conflict with local ChatMessage
+} from '@/lib/utils/steer';
+import { callSteerChatApi } from '@/lib/utils/steer-api'; // Import the new utility
+// import { ExplanationWithPartialRelations } from '@/prisma/generated/zod'; // No longer directly used in this component's props/state
+import { useState, useEffect, useRef, useCallback } from 'react';
+
+// Define a type for our structured feature suggestion
+type SuggestedSteerFeature = SteerFeature & {
+ originalExplanation: string; // To show the source explanation
+ score?: number; // Placeholder for potential future sorting/relevance score
+ isSelected: boolean; // To manage selection state
+};
+
+// Define FeatureSteeringResult type
+// This should extend SuggestedSteerFeature to include steering-specific results
+type FeatureSteeringResult = SuggestedSteerFeature & {
+ steeredText?: string;
+ steeringError?: string;
+ isSteeringLoading?: boolean;
+};
+
+export default function FeatureDiscoveryAssistant() {
+ const { getDefaultModel, getInferenceEnabledModels, showToastServerError } = useGlobalContext();
+ const [modelId, setModelId] = useState
(getDefaultModel()?.id || '');
+ const [testPrompt, setTestPrompt] = useState('');
+ const [featureQuery, setFeatureQuery] = useState('');
+ const [isLoading, setIsLoading] = useState(false); // For the main search
+ const [suggestedFeatures, setSuggestedFeatures] = useState([]);
+ // steeringResults will store the outcome for each of the 5 suggestedFeatures
+ const [steeringResults, setSteeringResults] = useState([]);
+ const [hasSteered, setHasSteered] = useState(false); // Tracks if "Steer Selected Features" has been clicked
+
+ // State for auto-steer timer
+ const [autoSteerTimerId, setAutoSteerTimerId] = useState(null);
+ const activeTimerIdRef = useRef(null); // To robustly check active timer in setTimeout
+ const [userInteracted, setUserInteracted] = useState(false);
+ const [triggerAutoSteer, setTriggerAutoSteer] = useState(false);
+
+
+ const handleSearch = async () => {
+ // Reset timer and interaction states on new search
+ if (autoSteerTimerId) {
+ clearTimeout(autoSteerTimerId);
+ }
+ setAutoSteerTimerId(null);
+ if (activeTimerIdRef.current) {
+ clearTimeout(activeTimerIdRef.current);
+ activeTimerIdRef.current = null;
+ }
+ setUserInteracted(false);
+ setHasSteered(false); // Also reset hasSteered as per instructions
+
+ if (!modelId || !testPrompt.trim() || !featureQuery.trim()) {
+ console.warn('Please select a model, enter a test prompt, and enter a feature name to search.');
+ return;
+ }
+ setIsLoading(true);
+ setSuggestedFeatures([]);
+ setSteeringResults([]);
+ // setHasSteered(false); // Already reset above
+
+ try {
+ const response = await fetch('/api/explanation/search-model', {
+ method: 'POST',
+ headers: { 'Content-Type': 'application/json' },
+ body: JSON.stringify({ query: featureQuery, modelId, offset: 0 }),
+ });
+
+ if (!response.ok) throw new Error(`API Error: ${response.status} ${response.statusText}`);
+ const data = (await response.json()) as SearchExplanationsResponse; // Type assertion for data
+
+ if (data.results && data.results.length > 0) {
+ const topResults = data.results.slice(0, 5);
+ const formattedFeatures: SuggestedSteerFeature[] = topResults.map((exp, index) => {
+ const { neuron } = exp; // exp is an ExplanationWithPartialRelations here
+ if (!neuron || !neuron.index) {
+ console.warn('Explanation result missing neuron data:', exp); return null;
+ }
+ return {
+ modelId: neuron.modelId,
+ layer: neuron.layer,
+ index: parseInt(neuron.index, 10),
+ explanation: `Feature related to: "${exp.description}" (Original model: ${neuron.modelId}, Layer: ${neuron.layer}, Index: ${neuron.index})`,
+ strength: neuron.maxActApprox || 1,
+ originalExplanation: exp.description,
+ isSelected: index < 2, // Default select top 2
+ };
+ }).filter(Boolean) as SuggestedSteerFeature[];
+ setSuggestedFeatures(formattedFeatures);
+
+ // Initialize steeringResults for all 5 discovered features
+ const initialSteeringData: FeatureSteeringResult[] = formattedFeatures.map(feature => ({
+ ...feature,
+ isSteeringLoading: false, // Not loading initially
+ steeredText: undefined,
+ steeringError: undefined,
+ }));
+ setSteeringResults(initialSteeringData);
+
+ // Start auto-steer timer if features were found
+ if (formattedFeatures.length > 0) {
+ const newTimerId = setTimeout(() => {
+ if (!userInteracted && activeTimerIdRef.current === newTimerId) {
+ console.log('Auto-steering: Timer elapsed without user interaction.');
+ // Auto-select top 3 features
+ setSuggestedFeatures(prevFeatures =>
+ prevFeatures.map((feature, index) => ({
+ ...feature,
+ isSelected: index < 3,
+ }))
+ );
+ setSteeringResults(prevResults =>
+ prevResults.map((result, index) => ({
+ ...result,
+ isSelected: index < 3,
+ }))
+ );
+ setTriggerAutoSteer(true); // Trigger effect to call handleSteerSelectedFeatures
+ }
+ }, 5000);
+ activeTimerIdRef.current = newTimerId;
+ setAutoSteerTimerId(newTimerId);
+ }
+ } else {
+ alert('No features found for your query.');
+ }
+ } catch (error) {
+ console.error('Failed to search features:', error);
+ showToastServerError();
+ alert(`Failed to search features. ${error instanceof Error ? error.message : 'Please check console for details.'}`);
+ } finally {
+ setIsLoading(false);
+ }
+ };
+
+ const handleSteerSelectedFeatures = useCallback(async () => {
+ const selectedForSteering = suggestedFeatures.filter(f => f.isSelected);
+ if (selectedForSteering.length === 0) {
+ alert('Please select at least one feature to steer.');
+ return;
+ }
+ setHasSteered(true);
+
+ // Set loading state only for selected features in the main steeringResults array
+ setSteeringResults(prevGlobalResults =>
+ prevGlobalResults.map(res => {
+ const isSelectedForThisOperation = selectedForSteering.some(
+ sf => sf.modelId === res.modelId && sf.layer === res.layer && sf.index === res.index
+ );
+ return isSelectedForThisOperation
+ ? { ...res, isSteeringLoading: true, steeredText: undefined, steeringError: undefined }
+ : res;
+ })
+ );
+
+ const steeringPromises = selectedForSteering.map(async (featureToSteer): Promise => {
+ try {
+ const apiFeature = {
+ modelId: featureToSteer.modelId,
+ layer: featureToSteer.layer,
+ index: featureToSteer.index,
+ strength: (featureToSteer.strength || 1) * 1.5,
+ explanation: featureToSteer.explanation,
+ };
+ const modelToSteer = (modelId !== 'gemma-2-2b') ? modelId : 'gemma-2-2b-it';
+ const steeringData = await callSteerChatApi({
+ modelId: modelToSteer,
+ defaultChatMessages: [{ role: 'user', content: testPrompt }] as SteerChatMessage[],
+ steeredChatMessages: [{ role: 'user', content: testPrompt }] as SteerChatMessage[],
+ features: [apiFeature],
+ temperature: STEER_TEMPERATURE, nTokens: 96, freqPenalty: STEER_FREQUENCY_PENALTY,
+ seed: STEER_SEED, strengthMultiplier: STEER_STRENGTH_MULTIPLIER, steerSpecialTokens: STEER_SPECIAL_TOKENS,
+ });
+ const assistantMessage = steeringData?.STEERED?.chatTemplate?.findLast(msg => msg.role === 'model' || msg.role === 'assistant')?.content;
+ return { ...featureToSteer, steeredText: assistantMessage || 'No response text.', steeringError: undefined, isSteeringLoading: false };
+ } catch (error) {
+ const errorMessage = error instanceof Error ? error.message : 'Steering request failed';
+ return { ...featureToSteer, steeredText: undefined, steeringError: errorMessage, isSteeringLoading: false };
+ }
+ });
+
+ const settledResults = await Promise.allSettled(steeringPromises);
+
+ const updatedResultsFromPromises: FeatureSteeringResult[] = settledResults.map((settledResult, index) => {
+ if (settledResult.status === 'fulfilled') {
+ return settledResult.value;
+ }
+ const { reason } = settledResult;
+ console.error('Unexpected promise rejection post-settled:', reason);
+ const errorMessage = reason instanceof Error ? reason.message : String(reason);
+ const originalFeatureForRejection = selectedForSteering[index];
+ return {
+ ...originalFeatureForRejection,
+ steeredText: undefined,
+ steeringError: `Unexpected error: ${errorMessage}`,
+ isSteeringLoading: false,
+ };
+ });
+
+ setSteeringResults(prevGlobalResults =>
+ prevGlobalResults.map(existingGlobalFeatureResult => {
+ const newlySteeredData = updatedResultsFromPromises.find(
+ nr => nr.modelId === existingGlobalFeatureResult.modelId &&
+ nr.layer === existingGlobalFeatureResult.layer &&
+ nr.index === existingGlobalFeatureResult.index
+ );
+ if (newlySteeredData) {
+ return {
+ ...existingGlobalFeatureResult,
+ ...newlySteeredData,
+ isSteeringLoading: false,
+ };
+ }
+ return existingGlobalFeatureResult;
+ })
+ );
+ }, [suggestedFeatures, modelId, testPrompt, showToastServerError]);
+
+ useEffect(() => {
+ if (triggerAutoSteer) {
+ handleSteerSelectedFeatures();
+ setTriggerAutoSteer(false); // Reset trigger
+ }
+ }, [triggerAutoSteer, handleSteerSelectedFeatures]);
+
+ // Cleanup timer on component unmount
+ useEffect(() => () => {
+ if (activeTimerIdRef.current) {
+ clearTimeout(activeTimerIdRef.current);
+ }
+ // Also clear state timer if it matches, though ref is primary for active timer
+ if (autoSteerTimerId) {
+ clearTimeout(autoSteerTimerId);
+ }
+ }, [autoSteerTimerId]); // Include autoSteerTimerId to re-run cleanup if it changes, though ref is main
+
+ const handleCheckboxChange = (changedIndex: number) => {
+ if (!userInteracted) {
+ setUserInteracted(true);
+ }
+ if (autoSteerTimerId) {
+ console.log('Auto-steering: User interacted, clearing timer.');
+ clearTimeout(autoSteerTimerId);
+ setAutoSteerTimerId(null);
+ }
+ if (activeTimerIdRef.current) {
+ clearTimeout(activeTimerIdRef.current);
+ activeTimerIdRef.current = null;
+ }
+
+ setSuggestedFeatures(prevFeatures =>
+ prevFeatures.map((feature, idx) =>
+ idx === changedIndex ? { ...feature, isSelected: !feature.isSelected } : feature
+ )
+ );
+ };
+
+ return (
+
+
Feature Discovery Assistant
+
+ {/* Inputs Section */}
+
+
+
+
+
+
+
+
+
+
E.g., "social isolation", "positive sentiment", "code generation"
+
+
+
+
+
+ {/* Discovered Features Section */}
+ {suggestedFeatures.length > 0 && (
+
+
Discovered Features:
+
+ {suggestedFeatures.map((feature, idx) => {
+ const steeringResultForCard = steeringResults.find(
+ sr => sr.modelId === feature.modelId && sr.layer === feature.layer && sr.index === feature.index
+ );
+
+ const cardClasses = [
+ "w-full sm:w-1/2 lg:w-1/3 xl:w-1/4 p-2 mb-4 rounded-md bg-slate-50 shadow-sm transition-all duration-300 ease-in-out",
+ ];
+ if (hasSteered && !feature.isSelected) {
+ cardClasses.push("filter grayscale opacity-50");
+ }
+
+ return (
+
+
+
Feature #{idx + 1}
+ handleCheckboxChange(idx)}
+ className="form-checkbox h-5 w-5 text-emerald-600 transition duration-150 ease-in-out"
+ />
+
+
+ Derived from: "{feature.originalExplanation}"
+
+
+ modelId: '{feature.modelId}', layer: '{feature.layer}', index: {feature.index}, strength: {feature.strength}
+
+
+ {/* Steering Result Display within Card */}
+ {steeringResultForCard && (steeringResultForCard.isSteeringLoading || (hasSteered && feature.isSelected)) && (
+
+ {steeringResultForCard.isSteeringLoading ? (
+
Steering in progress...
+ ) : steeringResultForCard.steeringError ? (
+
+
Steering Error:
+
{steeringResultForCard.steeringError}
+
+ ) : steeringResultForCard.steeredText ? (
+
+
Steered Output:
+
{steeringResultForCard.steeredText}
+
+ ) : (hasSteered && feature.isSelected) ? (
+
Steering completed, but no text was returned.
+ ) : null}
+
+ )}
+
+ );
+ })}
+
+
+ Select features and click "Steer Selected Features" to test their effect.
+
+
+ )}
+
+ {/* "Steer Selected Features" Button */}
+ {suggestedFeatures.length > 0 && (
+
+ )}
+
+ );
+}
diff --git a/apps/webapp/hooks/useSimpleSteererLogic.ts b/apps/webapp/hooks/useSimpleSteererLogic.ts
new file mode 100644
index 000000000..c67cac2b2
--- /dev/null
+++ b/apps/webapp/hooks/useSimpleSteererLogic.ts
@@ -0,0 +1,366 @@
+/* eslint-disable no-console */
+import { useState, useEffect, useRef, useCallback } from 'react';
+import {
+ ChatMessage,
+ FeaturePreset,
+ STEER_FREQUENCY_PENALTY,
+ STEER_MAX_PROMPT_CHARS,
+ STEER_SEED,
+ STEER_SPECIAL_TOKENS,
+ STEER_STRENGTH_MULTIPLIER,
+ STEER_TEMPERATURE,
+ SteerFeature,
+ SteerPreset,
+} from '@/lib/utils/steer';
+// import { SteerResultChat } from '@/app/api/steer-chat/route'; // Will be replaced by SteerChatApiResponse
+import { callSteerChatApi, SteerChatApiResponse } from '@/lib/utils/steer-api'; // Import the new utility
+import { useGlobalContext } from '@/components/provider/global-provider';
+
+// Define a stable empty array reference
+const EMPTY_ARRAY: readonly string[] = Object.freeze([]);
+
+interface UseSimpleSteererLogicProps {
+ initialModelId: string;
+ presetsApiEndpoint: string;
+ excludedPresetNames?: readonly string[];
+ initialStrengthConfig: TStrengthConfig;
+ transformFeaturesForApi: (features: SteerFeature[], strengthConfig: TStrengthConfig) => SteerFeature[];
+ onChatSuccess?: (response: SteerChatApiResponse | null) => void; // Updated to use SteerChatApiResponse
+ onPresetFeaturesSelected?: (
+ features: SteerFeature[],
+ preset?: FeaturePreset,
+ ) => { defaultChatMessages?: ChatMessage[]; steeredChatMessages?: ChatMessage[] } | void | Promise<{ defaultChatMessages?: ChatMessage[]; steeredChatMessages?: ChatMessage[] } | void>;
+ initialShowMoreOptions?: boolean;
+ starterPrompts?: Array<{ label: string; text: string; fullLabel?: string }>; // Added starterPrompts
+}
+
+export function useSimpleSteererLogic({
+ initialModelId, // Assuming this is used to initialize modelId state
+ presetsApiEndpoint,
+ excludedPresetNames = EMPTY_ARRAY,
+ initialStrengthConfig,
+ transformFeaturesForApi,
+ onChatSuccess,
+ onPresetFeaturesSelected,
+ initialShowMoreOptions = false,
+ starterPrompts, // Added starterPrompts
+}: UseSimpleSteererLogicProps) {
+ const [modelId, setModelId] = useState(initialModelId); // Assuming initialModelId prop is used here
+ const [featurePresets, setFeaturePresets] = useState([]);
+ const [selectedFeatures, setSelectedFeatures] = useState([]);
+ const [defaultChatMessages, setDefaultChatMessages] = useState([]);
+ const [steeredChatMessages, setSteeredChatMessages] = useState([]);
+ const [typedInText, setTypedInText] = useState('');
+ const [isTuning, setIsTuning] = useState(false);
+ const [steerTokens, setSteerTokens] = useState(48);
+ const [temperature, setTemperature] = useState(STEER_TEMPERATURE);
+ const [freqPenalty, setFreqPenalty] = useState(STEER_FREQUENCY_PENALTY);
+ const [strMultiple, setStrMultiple] = useState(STEER_STRENGTH_MULTIPLIER);
+ const [seed, setSeed] = useState(STEER_SEED);
+ const [randomSeed, setRandomSeed] = useState(false); // SteererMoreOptionsContent will manage its own checkbox state for this
+ const [steerSpecialTokens, setSteerSpecialTokens] = useState(STEER_SPECIAL_TOKENS);
+ const [showMoreOptions, setShowMoreOptions] = useState(initialShowMoreOptions);
+ const [isLoadingPresets, setIsLoadingPresets] = useState(false);
+ const [hasInitialPresetBeenApplied, setHasInitialPresetBeenApplied] = useState(false);
+
+ const normalEndRef = useRef(null);
+ const steeredEndRef = useRef(null);
+
+ const { showToastServerError } = useGlobalContext();
+
+ const loadPresets = useCallback(async () => {
+ console.log('[useSimpleSteererLogic] loadPresets function instance created/called. modelId:', modelId, 'presetsApiEndpoint:', presetsApiEndpoint, 'excludedPresetNames:', excludedPresetNames);
+ setIsLoadingPresets(true);
+ try {
+ console.log(`[useSimpleSteererLogic] Preparing to fetch presets. modelId is "${modelId}", presetsApiEndpoint is "${presetsApiEndpoint}"`);
+ if (!modelId) {
+ console.error("[useSimpleSteererLogic] loadPresets: modelId is undefined or empty. Aborting fetch.");
+ setFeaturePresets([]); // Clear presets on error or invalid modelId
+ setIsLoadingPresets(false); // Ensure loading state is reset
+ return; // Prevent fetch with undefined modelId
+ }
+ // Presets should always be fetched from the local endpoint
+ const fetchUrl = presetsApiEndpoint;
+ const headers: HeadersInit = { 'Content-Type': 'application/json' };
+
+ const response = await fetch(fetchUrl, {
+ method: 'POST',
+ headers,
+ body: JSON.stringify({ modelId }),
+ });
+ if (!response.ok) {
+ throw new Error(`Failed to fetch presets: ${response.statusText}`);
+ }
+ const data: SteerPreset = await response.json();
+ const filteredPresets = data.featurePresets.filter(
+ (preset) => !preset.isUserVector && !excludedPresetNames.includes(preset.name),
+ );
+ setFeaturePresets(filteredPresets);
+ } catch (error) {
+ // eslint-disable-next-line no-console
+ console.error(`Error loading presets: ${error}`);
+ showToastServerError();
+ setFeaturePresets([]); // Clear presets on error
+ } finally {
+ setIsLoadingPresets(false);
+ }
+ }, [modelId, presetsApiEndpoint, excludedPresetNames, showToastServerError]);
+
+ const resetChatAndMessages = useCallback(() => {
+ setDefaultChatMessages([]);
+ setSteeredChatMessages([]);
+ setTypedInText('');
+ // Note: Component-specific reset logic (like setShowNormalResponse) should be handled in the component
+ }, []);
+
+ useEffect(() => {
+ console.log('[useSimpleSteererLogic] useEffect to load presets triggered. modelId:', modelId);
+ // Reset states when modelId changes to allow re-initialization for the new model
+ resetChatAndMessages();
+ setSelectedFeatures([]);
+ setHasInitialPresetBeenApplied(false); // Allow initial preset to be applied for the new model
+ loadPresets();
+ setSeed(STEER_SEED); // Reset seed when model changes, similar to original
+ }, [modelId, loadPresets, resetChatAndMessages]); // resetChatAndMessages is stable
+
+ useEffect(() => {
+ // Auto-select the first preset's features if none are selected yet
+ // and then call the onPresetFeaturesSelected callback.
+ // Also, auto-fill the first prompt.
+ // Only run if initial auto-selection hasn't occurred AND conditions are met
+ if (!hasInitialPresetBeenApplied && featurePresets.length > 0 && selectedFeatures.length === 0 && !isLoadingPresets) {
+ const firstPreset = featurePresets[0];
+ setSelectedFeatures(firstPreset.features);
+
+ // Auto-fill chat input for the first preset
+ if (starterPrompts && starterPrompts.length > 0) {
+ setTypedInText(starterPrompts[0].text);
+ } else {
+ setTypedInText(''); // Clear if no corresponding prompt
+ }
+
+ if (onPresetFeaturesSelected) {
+ Promise.resolve(onPresetFeaturesSelected(firstPreset.features, firstPreset)).then((messages) => {
+ if (messages) {
+ setDefaultChatMessages(messages.defaultChatMessages || []);
+ setSteeredChatMessages(messages.steeredChatMessages || []);
+ }
+ });
+ }
+ setHasInitialPresetBeenApplied(true); // Mark that initial auto-selection has now happened
+ }
+ }, [
+ hasInitialPresetBeenApplied, // Add as a dependency
+ featurePresets,
+ selectedFeatures, // Still need to check its length
+ isLoadingPresets,
+ onPresetFeaturesSelected,
+ starterPrompts,
+ // Stable setters like setTypedInText, setSelectedFeatures, etc., are often omitted
+ // from deps if not strictly necessary by lint rules, assuming they are stable.
+ // Adding them explicitly for clarity if preferred by project style:
+ // setTypedInText, setSelectedFeatures, setDefaultChatMessages, setSteeredChatMessages
+ ]);
+
+ const sendChat = useCallback(
+ async (overrideTypedInText?: string) => {
+ setIsTuning(true);
+ const originalInput = overrideTypedInText || typedInText;
+ const currentText = `${originalInput} Tell me concisely.`;
+
+ if (!currentText.trim()) {
+ // eslint-disable-next-line no-alert
+ alert('Please enter a message.');
+ setIsTuning(false);
+ return;
+ }
+
+ const newDefaultChatMessages: ChatMessage[] = [...defaultChatMessages, { content: currentText, role: 'user' }];
+ const newSteeredChatMessages: ChatMessage[] = [...steeredChatMessages, { content: currentText, role: 'user' }];
+
+ setDefaultChatMessages(newDefaultChatMessages);
+ setSteeredChatMessages(newSteeredChatMessages);
+
+ const defaultPromptToSendChars = newDefaultChatMessages.map((m) => m.content).join('').length;
+ const steeredPromptToSendChars = newSteeredChatMessages.map((m) => m.content).join('').length;
+
+ if (defaultPromptToSendChars >= STEER_MAX_PROMPT_CHARS || steeredPromptToSendChars >= STEER_MAX_PROMPT_CHARS) {
+ // eslint-disable-next-line no-alert
+ alert('Sorry, we limit the length of each chat conversation.\nPlease click Reset to start a new conversation.');
+ setDefaultChatMessages(newDefaultChatMessages.slice(0, -1));
+ setSteeredChatMessages(newSteeredChatMessages.slice(0, -1));
+ setIsTuning(false);
+ return;
+ }
+
+ const featuresForApi = transformFeaturesForApi(selectedFeatures, initialStrengthConfig);
+ // eslint-disable-next-line no-console
+ console.log(`Steering with: ${JSON.stringify(featuresForApi)}`);
+
+ try {
+ // The apiKey and apiBaseUrl are now handled by callSteerChatApi
+ // The NEXT_PUBLIC_NEURONPEDIA_APIKEY environment variable will be used by callSteerChatApi
+
+ const resp = await callSteerChatApi({
+ modelId,
+ features: featuresForApi,
+ defaultChatMessages: newDefaultChatMessages,
+ steeredChatMessages: newSteeredChatMessages,
+ temperature,
+ nTokens: steerTokens,
+ freqPenalty,
+ seed: randomSeed ? Math.floor(Math.random() * 200000000 - 100000000) : seed,
+ strengthMultiplier: strMultiple,
+ steerSpecialTokens,
+ });
+
+ // Error handling (including 429, 405, and other !response.ok cases)
+ // is now managed within callSteerChatApi.
+ // If callSteerChatApi throws an error, it will be caught by the catch block below.
+
+ // const resp: SteerChatApiResponse | null = await response.json(); // This line is replaced by the callSteerChatApi call
+ if (resp === null) { // Should not happen if response.ok
+ throw new Error('Empty response from API');
+ }
+ setDefaultChatMessages(resp.DEFAULT?.chatTemplate || []);
+ setSteeredChatMessages(resp.STEERED?.chatTemplate || []);
+ setTypedInText('');
+ if (onChatSuccess) {
+ onChatSuccess(resp);
+ }
+ } catch (error) {
+ // eslint-disable-next-line no-console
+ console.error(error);
+ showToastServerError();
+ setDefaultChatMessages(newDefaultChatMessages.slice(0, -1));
+ setSteeredChatMessages(newSteeredChatMessages.slice(0, -1));
+ } finally {
+ setIsTuning(false);
+ }
+ },
+ [
+ typedInText,
+ defaultChatMessages,
+ steeredChatMessages,
+ modelId,
+ selectedFeatures,
+ initialStrengthConfig, // Hook prop, if it changes, sendChat should update
+ transformFeaturesForApi, // Hook prop
+ temperature,
+ steerTokens,
+ freqPenalty,
+ randomSeed,
+ seed,
+ strMultiple,
+ steerSpecialTokens,
+ onChatSuccess, // Hook prop
+ showToastServerError,
+ ],
+ );
+
+ const handlePresetChange = useCallback(
+ async (presetName: string) => {
+ const presetIndex = featurePresets.findIndex((p) => p.name === presetName);
+ const preset = presetIndex !== -1 ? featurePresets[presetIndex] : undefined;
+
+ if (preset) {
+ setSelectedFeatures(preset.features);
+ setHasInitialPresetBeenApplied(true); // A user has made a selection, initial auto-select is no longer primary concern
+
+ // Auto-fill chat input based on the index of the selected preset
+ if (starterPrompts && starterPrompts.length > presetIndex && presetIndex !== -1) {
+ setTypedInText(starterPrompts[presetIndex].text);
+ } else {
+ // Optionally clear or set a default if no corresponding prompt
+ setTypedInText('');
+ }
+
+ if (onPresetFeaturesSelected) {
+ // Reset messages before loading new ones from preset, or let callback decide
+ setDefaultChatMessages([]);
+ setSteeredChatMessages([]);
+ const messages = await Promise.resolve(onPresetFeaturesSelected(preset.features, preset));
+ if (messages) {
+ setDefaultChatMessages(messages.defaultChatMessages || []);
+ setSteeredChatMessages(messages.steeredChatMessages || []);
+ }
+ }
+ }
+ },
+ [
+ featurePresets,
+ onPresetFeaturesSelected,
+ starterPrompts,
+ setTypedInText,
+ setSelectedFeatures,
+ setDefaultChatMessages,
+ setSteeredChatMessages,
+ setHasInitialPresetBeenApplied // Add the new setter to the dependency array
+ ],
+ );
+
+ const toggleMoreOptions = useCallback(() => {
+ setShowMoreOptions((prev) => !prev);
+ }, []);
+
+ const scrollToBottom = useCallback(() => {
+ normalEndRef.current?.scrollTo({
+ top: normalEndRef.current.scrollHeight,
+ behavior: 'smooth',
+ });
+ steeredEndRef.current?.scrollTo({
+ top: steeredEndRef.current.scrollHeight,
+ behavior: 'smooth',
+ });
+ }, []);
+
+ useEffect(() => {
+ // Scroll logic depends on whether the normal panel is shown,
+ // which is managed by the component using the hook.
+ // For now, always try to scroll if messages exist.
+ // Components can refine this by only calling scrollToBottom when appropriate.
+ if (steeredChatMessages.length > 0 || defaultChatMessages.length > 0) {
+ scrollToBottom();
+ }
+ }, [steeredChatMessages, defaultChatMessages, scrollToBottom]);
+
+ return {
+ modelId,
+ setModelId,
+ featurePresets,
+ selectedFeatures,
+ setSelectedFeatures,
+ defaultChatMessages,
+ setDefaultChatMessages,
+ steeredChatMessages,
+ setSteeredChatMessages,
+ typedInText,
+ setTypedInText,
+ isTuning,
+ setIsTuning, // Exposed setIsTuning
+ steerTokens,
+ setSteerTokens,
+ temperature,
+ setTemperature,
+ freqPenalty,
+ setFreqPenalty,
+ strMultiple,
+ setStrMultiple,
+ seed,
+ setSeed,
+ randomSeed,
+ setRandomSeed, // Allow component to control this if needed, though MoreOptionsContent might handle its own
+ steerSpecialTokens,
+ setSteerSpecialTokens,
+ showMoreOptions,
+ normalEndRef,
+ steeredEndRef,
+ isLoadingPresets,
+ loadPresets, // Exposing for potential manual refresh, though it runs on modelId change
+ sendChat,
+ resetChatAndMessages,
+ handlePresetChange,
+ toggleMoreOptions,
+ // scrollToBottom, // Exposing this if components need more fine-grained control
+ };
+}
\ No newline at end of file
diff --git a/apps/webapp/lib/utils/steer-api.ts b/apps/webapp/lib/utils/steer-api.ts
new file mode 100644
index 000000000..5545f700a
--- /dev/null
+++ b/apps/webapp/lib/utils/steer-api.ts
@@ -0,0 +1,89 @@
+import {
+ ChatMessage,
+ SteerFeature,
+ STEER_TEMPERATURE,
+ STEER_N_COMPLETION_TOKENS,
+ STEER_FREQUENCY_PENALTY,
+ STEER_SEED,
+ STEER_STRENGTH_MULTIPLIER,
+ STEER_SPECIAL_TOKENS,
+} from './steer'; // Assuming steer.ts is in the same directory
+
+export interface SteerChatApiRequestPayload {
+ modelId: string; // The model ID for generation
+ features: SteerFeature[];
+ defaultChatMessages: ChatMessage[];
+ steeredChatMessages: ChatMessage[];
+ temperature?: number;
+ nTokens?: number;
+ freqPenalty?: number;
+ seed?: number;
+ strengthMultiplier?: number;
+ steerSpecialTokens: boolean; // Made non-optional as it's required by the API
+}
+
+export interface SteerChatApiResponse {
+ STEERED?: {
+ chatTemplate: ChatMessage[];
+ // other fields might exist but are not needed for this step
+ } | null;
+ DEFAULT?: {
+ chatTemplate: ChatMessage[];
+ // other fields might exist
+ } | null;
+ // Potentially other fields like error messages if the API returns them in a structured way
+}
+
+export async function callSteerChatApi(
+ params: SteerChatApiRequestPayload,
+): Promise {
+ const {
+ modelId,
+ features,
+ defaultChatMessages,
+ steeredChatMessages,
+ temperature = STEER_TEMPERATURE,
+ nTokens = STEER_N_COMPLETION_TOKENS,
+ freqPenalty = STEER_FREQUENCY_PENALTY,
+ seed = STEER_SEED,
+ strengthMultiplier = STEER_STRENGTH_MULTIPLIER,
+ steerSpecialTokens = STEER_SPECIAL_TOKENS,
+ } = params;
+
+ const neuronpediaApiKey = process.env.NEXT_PUBLIC_NEURONPEDIA_APIKEY;
+ let steerChatEndpoint = '/api/steer-chat';
+ const requestHeaders: HeadersInit = { 'Content-Type': 'application/json' };
+
+ if (neuronpediaApiKey && neuronpediaApiKey.trim() !== '') {
+ steerChatEndpoint = '/api/internal-proxy/steer-chat';
+ requestHeaders['X-Forwarded-Client-API-Key'] = neuronpediaApiKey.trim();
+ }
+
+ const requestBodyJson = {
+ modelId,
+ features,
+ defaultChatMessages,
+ steeredChatMessages,
+ temperature, // Assuming API accepts 'temperature' as is
+ n_tokens: nTokens,
+ freq_penalty: freqPenalty,
+ seed, // Assuming API accepts 'seed' as is
+ strength_multiplier: strengthMultiplier,
+ steer_special_tokens: steerSpecialTokens, // Ensure API receives snake_case
+ };
+
+ const response = await fetch(steerChatEndpoint, {
+ method: 'POST',
+ headers: requestHeaders,
+ body: JSON.stringify(requestBodyJson),
+ });
+
+ if (!response.ok) {
+ const errorText = await response.text();
+ throw new Error(
+ `API Error: ${response.status} ${response.statusText}. Details: ${errorText}`,
+ );
+ }
+
+ return response.json() as Promise;
+}
\ No newline at end of file
diff --git a/apps/webapp/lib/utils/steer.ts b/apps/webapp/lib/utils/steer.ts
index 3eeb0c550..d82161734 100644
--- a/apps/webapp/lib/utils/steer.ts
+++ b/apps/webapp/lib/utils/steer.ts
@@ -5,7 +5,7 @@ import { STEER_FORCE_ALLOW_INSTRUCT_MODELS } from '../env';
export const STEER_N_COMPLETION_TOKENS = 64;
export const STEER_N_COMPLETION_TOKENS_THINKING = 512;
-export const STEER_N_COMPLETION_TOKENS_MAX = 128;
+export const STEER_N_COMPLETION_TOKENS_MAX = 256;
export const STEER_N_COMPLETION_TOKENS_MAX_THINKING = 768;
export const STEER_TEMPERATURE = 0.5;
export const STEER_TEMPERATURE_MAX = 2;