forked from continuedev/continue
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgui.py
202 lines (159 loc) · 7.76 KB
/
gui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import asyncio
import json
from fastapi import Depends, Header, WebSocket, APIRouter
from starlette.websockets import WebSocketState, WebSocketDisconnect
from typing import Any, List, Type, TypeVar
from pydantic import BaseModel
import traceback
from uvicorn.main import Server
from .session_manager import session_manager, Session
from ..plugins.steps.core.core import DisplayErrorStep, MessageStep
from .gui_protocol import AbstractGUIProtocolServer
from ..libs.util.queue import AsyncSubscriptionQueue
from ..libs.util.telemetry import posthog_logger
from ..libs.util.create_async_task import create_async_task
from ..libs.util.logging import logger
# Router for the GUI endpoints; every route below is mounted under "/gui".
router = APIRouter(prefix="/gui", tags=["gui"])

# Graceful shutdown by closing websockets.
# Keep a reference to uvicorn's own exit handler so the replacement below can
# delegate to it after recording that shutdown has started. This must be
# captured BEFORE Server.handle_exit is reassigned at the bottom of this block.
original_handler = Server.handle_exit


class AppStatus:
    """Process-wide shutdown flag shared with the GUI websocket loop."""

    # Flipped to True by handle_exit; websocket_endpoint's receive loop checks
    # it before reading each new message and stops once it is set.
    should_exit = False

    @staticmethod
    def handle_exit(*args, **kwargs):
        """Mark shutdown as started, then run uvicorn's original exit handler.

        Installed as ``Server.handle_exit`` below. Accessing a staticmethod
        through the class yields the plain function, so once stored on
        ``Server`` it binds like an ordinary method: ``args`` begins with the
        ``Server`` instance, and everything is forwarded unchanged to
        ``original_handler``.
        """
        AppStatus.should_exit = True
        logger.debug("Shutting down")
        original_handler(*args, **kwargs)


# Monkeypatch uvicorn so that its exit path also flips AppStatus.should_exit.
# NOTE(review): presumably uvicorn invokes handle_exit from its SIGINT/SIGTERM
# handlers — confirm against the installed uvicorn version.
Server.handle_exit = AppStatus.handle_exit
async def websocket_session(session_id: str) -> Session:
    """FastAPI dependency: look up the Session for ``session_id`` via the session manager."""
    session = await session_manager.get_session(session_id)
    return session
# Type variable for the pydantic response model that
# GUIProtocolServer._send_and_receive_json parses the GUI's reply into.
T = TypeVar("T", bound=BaseModel)
# You should probably abstract away the websocket stuff into a separate class
class GUIProtocolServer(AbstractGUIProtocolServer):
    """Server side of the GUI websocket protocol for one session.

    ``handle_json`` routes each incoming GUI message to an ``on_*`` handler,
    which schedules the matching autopilot coroutine with
    ``create_async_task`` and reports failures through ``on_error``.
    Outgoing messages are wrapped as ``{"messageType": ..., "data": ...}``.
    """

    # Not set by __init__; the caller assigns it right after construction.
    websocket: WebSocket
    session: Session
    # Replies awaited by _receive_json, fetched per message type.
    sub_queue: AsyncSubscriptionQueue

    def __init__(self, session: Session):
        self.session = session
        # Created per instance: a single class-level queue would be shared by
        # every connected GUI, letting one session consume another's replies.
        self.sub_queue = AsyncSubscriptionQueue()

    async def _send_json(self, message_type: str, data: Any):
        """Send ``data`` to the GUI under ``message_type``; no-op once the socket is closed."""
        if self.websocket.application_state == WebSocketState.DISCONNECTED:
            return
        await self.websocket.send_json({
            "messageType": message_type,
            "data": data
        })

    async def _receive_json(self, message_type: str, timeout: int = 20) -> Any:
        """Wait for the next payload of ``message_type`` from ``sub_queue``.

        Args:
            message_type: the message type to wait for.
            timeout: seconds to wait before giving up.

        Raises:
            Exception: if nothing arrives within ``timeout`` seconds.
        """
        try:
            return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout)
        except asyncio.TimeoutError as e:
            # Report the timeout actually used instead of a hard-coded value.
            raise Exception(
                f"GUI Protocol _receive_json timed out after {timeout} seconds") from e

    async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:
        """Send ``data`` as ``message_type`` and parse the GUI's reply into ``resp_model``."""
        await self._send_json(message_type, data)
        resp = await self._receive_json(message_type)
        return resp_model.parse_obj(resp)

    def on_error(self, e: Exception):
        """Error callback for scheduled tasks: run a DisplayErrorStep for ``e``."""
        return self.session.autopilot.continue_sdk.run_step(DisplayErrorStep(e=e))

    def handle_json(self, message_type: str, data: Any):
        """Dispatch one GUI message to its handler.

        Unknown ``message_type`` values are ignored. If ``data`` lacks a field
        the chosen handler needs, the lookup raises ``KeyError``.
        """
        if message_type == "main_input":
            self.on_main_input(data["input"])
        elif message_type == "step_user_input":
            self.on_step_user_input(data["input"], data["index"])
        elif message_type == "refinement_input":
            self.on_refinement_input(data["input"], data["index"])
        elif message_type == "reverse_to_index":
            self.on_reverse_to_index(data["index"])
        elif message_type == "retry_at_index":
            self.on_retry_at_index(data["index"])
        elif message_type == "clear_history":
            self.on_clear_history()
        elif message_type == "delete_at_index":
            self.on_delete_at_index(data["index"])
        elif message_type == "delete_context_with_ids":
            self.on_delete_context_with_ids(data["ids"])
        elif message_type == "toggle_adding_highlighted_code":
            self.on_toggle_adding_highlighted_code()
        elif message_type == "set_editing_at_ids":
            self.on_set_editing_at_ids(data["ids"])
        elif message_type == "show_logs_at_index":
            self.on_show_logs_at_index(data["index"])
        elif message_type == "select_context_item":
            self.select_context_item(data["id"], data["query"])

    def on_main_input(self, input: str):
        """Pass the user's main input box text to the autopilot."""
        create_async_task(
            self.session.autopilot.accept_user_input(input), self.on_error)

    def on_reverse_to_index(self, index: int):
        """Reverse the history to the given index."""
        create_async_task(
            self.session.autopilot.reverse_to_index(index), self.on_error)

    def on_step_user_input(self, input: str, index: int):
        """Give user input to the history step at ``index``."""
        create_async_task(
            self.session.autopilot.give_user_input(input, index), self.on_error)

    def on_refinement_input(self, input: str, index: int):
        """Pass refinement input for the history step at ``index``."""
        create_async_task(
            self.session.autopilot.accept_refinement_input(input, index), self.on_error)

    def on_retry_at_index(self, index: int):
        """Retry the history step at ``index``."""
        create_async_task(
            self.session.autopilot.retry_at_index(index), self.on_error)

    def on_clear_history(self):
        """Clear the session's history."""
        create_async_task(
            self.session.autopilot.clear_history(), self.on_error)

    def on_delete_at_index(self, index: int):
        """Delete the history step at ``index``."""
        create_async_task(
            self.session.autopilot.delete_at_index(index), self.on_error)

    def on_delete_context_with_ids(self, ids: List[str]):
        """Delete the context items with the given ids."""
        create_async_task(
            self.session.autopilot.delete_context_with_ids(ids), self.on_error)

    def on_toggle_adding_highlighted_code(self):
        """Toggle whether highlighted code is added."""
        create_async_task(
            self.session.autopilot.toggle_adding_highlighted_code(), self.on_error)

    def on_set_editing_at_ids(self, ids: List[str]):
        """Mark the items with the given ids as being edited."""
        create_async_task(
            self.session.autopilot.set_editing_at_ids(ids), self.on_error)

    def on_show_logs_at_index(self, index: int):
        """Open the LLM prompt/completion logs of history step ``index`` in a virtual file.

        Raises ``IndexError`` synchronously if ``index`` is outside the timeline.
        """
        name = "continue_logs.txt"
        header = "This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"
        step_logs = self.session.autopilot.continue_sdk.history.timeline[index].logs
        logs = "\n\n############################################\n\n".join(
            [header] + step_logs)
        create_async_task(
            self.session.autopilot.ide.showVirtualFile(name, logs), self.on_error)

    def select_context_item(self, id: str, query: str):
        """Called when user selects an item from the dropdown"""
        create_async_task(
            self.session.autopilot.select_context_item(id, query), self.on_error)
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)):
    """Serve one GUI websocket connection for ``session``.

    Accepts the socket, registers it with the session manager, pushes the
    current autopilot state, then dispatches every incoming
    ``{"messageType": ..., "data": ...}`` frame to a ``GUIProtocolServer``.
    The loop runs until the client disconnects or ``AppStatus.should_exit`` is
    set. Frames that are not JSON objects carrying both keys are logged and
    skipped rather than ending the session.

    Any other error is logged, reported to PostHog, shown in the IDE, and
    re-raised. In every case the socket is closed (if still open), and the
    session is persisted and then removed from the session manager.
    """
    try:
        logger.debug(f"Received websocket connection at url: {websocket.url}")
        await websocket.accept()

        logger.debug("Session started")
        session_manager.register_websocket(session.session_id, websocket)
        protocol = GUIProtocolServer(session)
        protocol.websocket = websocket

        # Update any history that may have happened before connection
        await protocol.session.autopilot.update_subscribers()

        # should_exit is only re-checked after each frame is received.
        while not AppStatus.should_exit:
            raw = await websocket.receive_text()
            logger.debug(f"Received GUI message {raw}")

            # One malformed frame should not tear down the whole session.
            try:
                message = json.loads(raw)
            except json.JSONDecodeError:
                logger.debug("Ignoring GUI message that is not valid JSON")
                continue
            if not isinstance(message, dict) or "messageType" not in message or "data" not in message:
                logger.debug("Ignoring GUI message without messageType/data")
                continue

            protocol.handle_json(message["messageType"], message["data"])
    except WebSocketDisconnect:
        logger.debug("GUI websocket disconnected")
    except Exception as e:
        # Log, send to PostHog, and send to GUI
        logger.debug(f"ERROR in gui websocket: {e}")
        # format_exception's entries already end in "\n"; joining them with
        # "\n" would double-space the traceback.
        err_msg = "".join(traceback.format_exception(e))
        posthog_logger.capture_event("gui_error", {
            "error_title": e.__str__() or e.__repr__(), "error_message": err_msg})
        await session.autopilot.ide.showMessage(err_msg)
        raise
    finally:
        logger.debug("Closing gui websocket")
        if websocket.client_state != WebSocketState.DISCONNECTED:
            await websocket.close()
        await session_manager.persist_session(session.session_id)
        await session_manager.remove_session(session.session_id)