From 25e46a6acf95b6daab7edec5dd930e8be9db36ed Mon Sep 17 00:00:00 2001 From: Jagger <634750802@qq.com> Date: Tue, 29 Oct 2024 13:15:22 +0800 Subject: [PATCH 1/2] refactor: implement external stream protocol --- backend/app/rag/chat.py | 19 ++++++- backend/app/rag/chat_config.py | 1 + backend/app/rag/types.py | 2 + backend/requirements-dev.lock | 10 +++- .../src/components/chat/chat-controller.ts | 40 +++++++------ .../chat/chat-message-controller.ts | 56 ++++++++++++++----- .../src/components/chat/chat-stream-state.ts | 13 ++++- .../message-annotation-history-stackvm.tsx | 24 ++++---- 8 files changed, 117 insertions(+), 48 deletions(-) diff --git a/backend/app/rag/chat.py b/backend/app/rag/chat.py index 3dc17940d..a1ca86473 100644 --- a/backend/app/rag/chat.py +++ b/backend/app/rag/chat.py @@ -508,12 +508,22 @@ def _external_chat(self) -> Generator[ChatEvent | str, None, None]: ) stream_chat_api_url = self.chat_engine_config.external_engine_config.stream_chat_api_url - logger.debug(f"Chatting with external chat engine (api_url: {stream_chat_api_url}) to answer for user question: {self.user_question}") + stream_chat_type = self.chat_engine_config.external_engine_config.type + logger.debug(f"Chatting with external chat engine (api_url: {stream_chat_api_url}, type: {stream_chat_type}) to answer for user question: {self.user_question}") chat_params = { "goal": self.user_question } res = requests.post(stream_chat_api_url, json=chat_params, stream=True) + yield ChatEvent( + event_type=ChatEventType.MESSAGE_ANNOTATIONS_PART, + payload=ChatStreamMessagePayload( + state=ChatMessageSate.EXTERNAL_STREAM_START, + display="Using external engine", + context={ 'type': stream_chat_type }, + ), + ) + # Notice: External type chat engine doesn't support non-streaming mode for now. response_text = "" for line in res.iter_lines(): @@ -527,6 +537,13 @@ def _external_chat(self) -> Generator[ChatEvent | str, None, None]: yield line + b'\n' + yield ChatEvent( + event_type=ChatEventType.MESSAGE_ANNOTATIONS_PART, + payload=ChatStreamMessagePayload( + state=ChatMessageSate.FINISHED, + ), + ) + db_assistant_message.content = response_text db_assistant_message.updated_at = datetime.now(UTC) db_assistant_message.finished_at = datetime.now(UTC) diff --git a/backend/app/rag/chat_config.py b/backend/app/rag/chat_config.py index 3d193fd2d..ca3860707 100644 --- a/backend/app/rag/chat_config.py +++ b/backend/app/rag/chat_config.py @@ -75,6 +75,7 @@ class KnowledgeGraphOption(BaseModel): class ExternalChatEngine(BaseModel): stream_chat_api_url: str = None + type: str = 'StackVM' class ChatEngineConfig(BaseModel): diff --git a/backend/app/rag/types.py b/backend/app/rag/types.py index 20e9110d8..bb389c3d7 100644 --- a/backend/app/rag/types.py +++ b/backend/app/rag/types.py @@ -59,3 +59,5 @@ class ChatMessageSate(int, enum.Enum): SEARCH_RELATED_DOCUMENTS = 4 GENERATE_ANSWER = 5 FINISHED = 9 + # See https://github.com/pingcap/tidb.ai/issues/345 + EXTERNAL_STREAM_START = 10 diff --git a/backend/requirements-dev.lock b/backend/requirements-dev.lock index 7f5bca754..b8d26ccd9 100644 --- a/backend/requirements-dev.lock +++ b/backend/requirements-dev.lock @@ -7,6 +7,7 @@ # all-features: false # with-sources: false # generate-hashes: false +# universal: true aiohttp==3.9.5 # via datasets @@ -89,6 +90,11 @@ cohere==5.6.2 # via llama-index-embeddings-cohere # via llama-index-postprocessor-cohere-rerank colorama==0.4.6 + # via click + # via colorlog + # via pytest + # via tqdm + # via uvicorn colorlog==6.8.2 # via optuna cryptography==42.0.8 @@ -573,6 +579,8 @@ python-pptx==1.0.2 pytz==2024.1 # via flower # via pandas +pywin32==308 ; platform_system == 'Windows' + # via portalocker pyyaml==6.0.1 # via datasets # via huggingface-hub @@ -726,7 +734,7 @@ urllib3==2.2.1 # via types-requests uvicorn==0.30.3 # via fastapi -uvloop==0.19.0 +uvloop==0.19.0 ; platform_python_implementation != 'PyPy' and sys_platform != 'cygwin' and sys_platform != 'win32' # via uvicorn vine==5.1.0 # via amqp diff --git a/frontend/app/src/components/chat/chat-controller.ts b/frontend/app/src/components/chat/chat-controller.ts index 5f9947b23..61f3fa490 100644 --- a/frontend/app/src/components/chat/chat-controller.ts +++ b/frontend/app/src/components/chat/chat-controller.ts @@ -1,5 +1,5 @@ import { chat, type Chat, type ChatMessage, type PostChatParams } from '@/api/chats'; -import { BaseChatMessageController, ChatMessageController, LegacyChatMessageController, type OngoingState, StackVMChatMessageController } from '@/components/chat/chat-message-controller'; +import { BaseChatMessageController, ChatMessageController, ExternalChatMessageController, LegacyChatMessageController, type OngoingState } from '@/components/chat/chat-message-controller'; import { AppChatStreamState, type BaseAnnotation, chatDataPartSchema, fixChatInitialData, type StackVMState } from '@/components/chat/chat-stream-state'; import type { GtagFn } from '@/components/gtag-provider'; import { getErrorMessage } from '@/lib/errors'; @@ -34,7 +34,7 @@ export interface ChatControllerEventsMap = BaseAnnotation> extends EventEmitter> { public chat: Chat | undefined; - private _messages: Map = new Map(); + private _messages: Map = new Map(); private _postParams: Omit | undefined = undefined; private _postError: unknown = undefined; @@ -140,7 +140,7 @@ export class ChatController a.message.ordinal - b.message.ordinal); } @@ -163,7 +163,7 @@ export class ChatController['parse']>) { + _processPart (ongoingMessageController: ChatMessageController | ExternalChatMessageController | undefined, part: ReturnType['parse']>) { switch (part.type) { case 'data': // Data part contains chat and chat_message info from server. will be sent twice (beginning and finished). @@ -240,11 +240,12 @@ export class ChatController['parse']>): ChatMessageController | StackVMChatMessageController { + private _processDataPart (ongoingMessageController: ChatMessageController | ExternalChatMessageController | undefined, part: ReturnType['parse']>): ChatMessageController | ExternalChatMessageController { const { chat, user_message, assistant_message } = chatDataPartSchema.parse(fixChatInitialData(part.value[0])); this.updateChat(chat); this.upsertMessage(user_message); @@ -267,34 +268,41 @@ export class ChatController['parse']>) { + private _processMessageAnnotationPart (ongoingMessageController: ChatMessageController | ExternalChatMessageController | undefined, part: ReturnType['parse']>) { assertNonNull(ongoingMessageController, 'Cannot handle chat stream part: no ongoingMessageController', part); const annotation = ongoingMessageController.parseAnnotation(part.value[0]); - ongoingMessageController.applyStreamAnnotation(annotation as never); + if (annotation !== null) { + ongoingMessageController.applyStreamAnnotation(annotation as never); + } } - private _processTextPart (ongoingMessageController: ChatMessageController | StackVMChatMessageController | undefined, part: ReturnType['parse']>) { + private _processTextPart (ongoingMessageController: ChatMessageController | ExternalChatMessageController | undefined, part: ReturnType['parse']>) { if (part.value) { // ignore leading empty chunks. assertNonNull(ongoingMessageController, 'Cannot handle chat stream part: no ongoingMessageController', part); ongoingMessageController.applyDelta(part.value); } } - private _processErrorPart (ongoingMessageController: ChatMessageController | StackVMChatMessageController | undefined, part: ReturnType['parse']>) { + private _processErrorPart (ongoingMessageController: ChatMessageController | ExternalChatMessageController | undefined, part: ReturnType['parse']>) { assertNonNull(ongoingMessageController, 'Cannot handle chat stream part: no ongoingMessageController', part); ongoingMessageController.applyError(part.value); } - private _processToolCallPart (ongoingMessageController: ChatMessageController | StackVMChatMessageController | undefined, part: ReturnType['parse']>) { + private _processToolCallPart (ongoingMessageController: ChatMessageController | ExternalChatMessageController | undefined, part: ReturnType['parse']>) { assertNonNull(ongoingMessageController, 'Cannot handle chat stream part: no ongoingMessageController', part); ongoingMessageController.applyToolCall(part.value); } - private _processToolResultPart (ongoingMessageController: ChatMessageController | StackVMChatMessageController | undefined, part: ReturnType['parse']>) { + private _processToolResultPart (ongoingMessageController: ChatMessageController | ExternalChatMessageController | undefined, part: ReturnType['parse']>) { assertNonNull(ongoingMessageController, 'Cannot handle chat stream part: no ongoingMessageController', part); ongoingMessageController.applyToolResult(part.value); } + private _processMessageFinishPart (ongoingMessageController: ChatMessageController | ExternalChatMessageController | undefined, part: ReturnType['parse']>) { + assertNonNull(ongoingMessageController, 'Cannot handle chat stream part: no ongoingMessageController', part); + ongoingMessageController.applyFinishMessage(part.value); + } + private createMessage (message: ChatMessage, initialOngoingState?: true) { if (!this.chat?.engine_options) { throw new Error('Unable to decide which chat engine used.'); @@ -315,7 +323,7 @@ export class ChatController) { - const controller = new StackVMChatMessageController(message, initialOngoingState); + const controller = new ExternalChatMessageController(message, initialOngoingState); this._messages.set(message.id, controller); this.emit('message-loaded', controller as any); return controller; diff --git a/frontend/app/src/components/chat/chat-message-controller.ts b/frontend/app/src/components/chat/chat-message-controller.ts index 97bfb12a0..def1a6620 100644 --- a/frontend/app/src/components/chat/chat-message-controller.ts +++ b/frontend/app/src/components/chat/chat-message-controller.ts @@ -24,15 +24,17 @@ export interface ChatMessageControllerEventsMap { 'stream-tool-call': [id: string, name: string, args: any]; 'stream-tool-result': [id: string, result: any]; + + 'stream-state-change': []; } export abstract class BaseChatMessageController< State, - Annotation extends BaseAnnotation + Annotation extends BaseAnnotation, > extends EventEmitter> { protected _message: ChatMessage; protected _ongoing: OngoingState | undefined; - protected _ongoingHistory: OngoingStateHistoryItem[] | undefined; + protected _ongoingHistory: OngoingStateHistoryItem[] | undefined; public readonly role: ChatMessageRole; public readonly id: number; @@ -137,6 +139,9 @@ export abstract class BaseChatMessageController< this.emit('stream-tool-result', toolCallId, result); } + applyFinishMessage (_: unknown) { + } + finish () { this._ongoing = undefined; this.emit('stream-finished', this._message); @@ -155,7 +160,7 @@ export abstract class BaseChatMessageController< return this._ongoingHistory; } - abstract parseAnnotation (raw: unknown): Annotation; + abstract parseAnnotation (raw: unknown): Annotation | null; abstract createInitialOngoingState (): OngoingState; @@ -164,7 +169,7 @@ export abstract class BaseChatMessageController< protected abstract _polishMessage (message: ChatMessage, ongoing: OngoingState, annotation: Annotation): ChatMessage } -export type ChatMessageController = LegacyChatMessageController | StackVMChatMessageController; +export type ChatMessageController = LegacyChatMessageController | ExternalChatMessageController; export type ChatMessageControllerAnnotationState = C extends BaseChatMessageController ? State : never; export class LegacyChatMessageController extends BaseChatMessageController { @@ -209,8 +214,9 @@ export class LegacyChatMessageController extends BaseChatMessageController { +export class ExternalChatMessageController extends BaseChatMessageController { readonly version = 'StackVM'; + private externalStreamStart = false; applyToolCall (payload: { toolCallId: string; toolName: string; args: any }) { super.applyToolCall(payload); @@ -243,14 +249,34 @@ export class StackVMChatMessageController extends BaseChatMessageController { @@ -306,8 +332,8 @@ export class StackVMChatMessageController extends BaseChatMessageController, annotation: StackVMStateAnnotation): ChatMessage { // FIX Initial state // First step reasoning finished with PC = 1, we need to insert a PC = 0 state first. - if (annotation.state.state.program_counter === 1) { - if (!this._ongoingHistory?.find(item => item.state.state.state.program_counter === 0)) { + if (annotation.state && annotation.state.state.program_counter === 1) { + if (!this._ongoingHistory?.find(item => item.state.state?.state.program_counter === 0)) { const lastState = { state: { state: { @@ -329,4 +355,4 @@ export class StackVMChatMessageController extends BaseChatMessageController { + context: { type: 'StackVM' } +} + export type ChatMessageAnnotation = - BaseAnnotation> + BaseAnnotation> | TraceAnnotation | SourceNodesAnnotation - | RefineQuestionAnnotation; + | RefineQuestionAnnotation + | ExternalStreamStartAnnotation; -export interface StackVMStateAnnotation extends BaseAnnotation { +export interface StackVMStateAnnotation extends BaseAnnotation { } export type ChatInitialData = { diff --git a/frontend/app/src/components/chat/message-annotation-history-stackvm.tsx b/frontend/app/src/components/chat/message-annotation-history-stackvm.tsx index d005322a9..6ee8c2760 100644 --- a/frontend/app/src/components/chat/message-annotation-history-stackvm.tsx +++ b/frontend/app/src/components/chat/message-annotation-history-stackvm.tsx @@ -1,5 +1,5 @@ import { useChatMessageField, useChatMessageStreamHistoryStates, useChatMessageStreamState } from '@/components/chat/chat-hooks'; -import { type OngoingState, type OngoingStateHistoryItem, StackVMChatMessageController } from '@/components/chat/chat-message-controller'; +import { ExternalChatMessageController, type OngoingState, type OngoingStateHistoryItem } from '@/components/chat/chat-message-controller'; import type { StackVMState, StackVMToolCall } from '@/components/chat/chat-stream-state'; import { isNotFinished } from '@/components/chat/utils'; import { DiffSeconds } from '@/components/diff-seconds'; @@ -11,7 +11,7 @@ import { motion, type Target } from 'framer-motion'; import { CheckCircleIcon, ChevronUpIcon, ClockIcon, InfoIcon, Loader2Icon, SearchIcon } from 'lucide-react'; import { useEffect, useMemo, useState } from 'react'; -export function StackVMMessageAnnotationHistory ({ message }: { message: StackVMChatMessageController | undefined }) { +export function StackVMMessageAnnotationHistory ({ message }: { message: ExternalChatMessageController | undefined }) { const [show, setShow] = useState(true); const history = useChatMessageStreamHistoryStates(message); const current = useChatMessageStreamState(message); @@ -23,7 +23,7 @@ export function StackVMMessageAnnotationHistory ({ message }: { message: StackVM if (current) { return current.state.plan_id; } - return history?.[0]?.state.state.plan_id; + return history?.[0]?.state.state?.plan_id; }, [history, current]); useEffect(() => { @@ -53,13 +53,13 @@ export function StackVMMessageAnnotationHistory ({ message }: { message: StackVM className="text-sm mt-4" > {history?.map((item, index, history) => ( - + ))} {error && } {current && !current.finished && } {stackVMPlanId &&
- Visit StackVM to see more details + Visit StackVM to see more details
}