Skip to content

Commit d29cc61

Browse files
rbrenopenhands-agentenystamanape
authored
Remove while True in AgentController (#5868)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: Engel Nyst <enyst@users.noreply.github.com> Co-authored-by: amanape <83104063+amanape@users.noreply.github.com>
1 parent a2e9e20 commit d29cc61

10 files changed

Lines changed: 209 additions & 155 deletions

File tree

openhands/controller/agent_controller.py

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
)
4848
from openhands.events.serialization.event import truncate_content
4949
from openhands.llm.llm import LLM
50-
from openhands.utils.shutdown_listener import should_continue
5150

5251
# note: RESUME is only available on web GUI
5352
TRAFFIC_CONTROL_REMINDER = (
@@ -64,7 +63,6 @@ class AgentController:
6463
confirmation_mode: bool
6564
agent_to_llm_config: dict[str, LLMConfig]
6665
agent_configs: dict[str, AgentConfig]
67-
agent_task: asyncio.Future | None = None
6866
parent: 'AgentController | None' = None
6967
delegate: 'AgentController | None' = None
7068
_pending_action: Action | None = None
@@ -109,7 +107,6 @@ def __init__(
109107
headless_mode: Whether the agent is run in headless mode.
110108
status_callback: Optional callback function to handle status updates.
111109
"""
112-
self._step_lock = asyncio.Lock()
113110
self.id = sid
114111
self.agent = agent
115112
self.headless_mode = headless_mode
@@ -199,32 +196,44 @@ async def _react_to_exception(
199196
err_id = 'STATUS$ERROR_LLM_AUTHENTICATION'
200197
self.status_callback('error', err_id, type(e).__name__ + ': ' + str(e))
201198

202-
async def start_step_loop(self):
203-
"""The main loop for the agent's step-by-step execution."""
204-
self.log('info', 'Starting step loop...')
205-
while True:
206-
if not self._is_awaiting_observation() and not should_continue():
207-
break
208-
if self._closed:
209-
break
210-
try:
211-
await self._step()
212-
except asyncio.CancelledError:
213-
self.log('debug', 'AgentController task was cancelled')
214-
break
215-
except Exception as e:
216-
traceback.print_exc()
217-
self.log('error', f'Error while running the agent: {e}')
218-
await self._react_to_exception(e)
199+
def step(self):
200+
asyncio.create_task(self._step_with_exception_handling())
219201

220-
await asyncio.sleep(0.1)
202+
async def _step_with_exception_handling(self):
203+
try:
204+
await self._step()
205+
except Exception as e:
206+
traceback.print_exc()
207+
self.log('error', f'Error while running the agent: {e}')
208+
reported = RuntimeError(
209+
'There was an unexpected error while running the agent.'
210+
)
211+
if isinstance(e, litellm.LLMError):
212+
reported = e
213+
await self._react_to_exception(reported)
221214

222-
async def on_event(self, event: Event) -> None:
215+
def should_step(self, event: Event) -> bool:
216+
if isinstance(event, Action):
217+
if isinstance(event, MessageAction) and event.source == EventSource.USER:
218+
return True
219+
return False
220+
if isinstance(event, Observation):
221+
if isinstance(event, NullObservation) or isinstance(
222+
event, AgentStateChangedObservation
223+
):
224+
return False
225+
return True
226+
return False
227+
228+
def on_event(self, event: Event) -> None:
223229
"""Callback from the event stream. Notifies the controller of incoming events.
224230
225231
Args:
226232
event (Event): The incoming event to process.
227233
"""
234+
asyncio.get_event_loop().run_until_complete(self._on_event(event))
235+
236+
async def _on_event(self, event: Event) -> None:
228237
if hasattr(event, 'hidden') and event.hidden:
229238
return
230239

@@ -237,6 +246,9 @@ async def on_event(self, event: Event) -> None:
237246
elif isinstance(event, Observation):
238247
await self._handle_observation(event)
239248

249+
if self.should_step(event):
250+
self.step()
251+
240252
async def _handle_action(self, action: Action) -> None:
241253
"""Handles actions from the event stream.
242254
@@ -487,19 +499,16 @@ async def start_delegate(self, action: AgentDelegateAction) -> None:
487499
async def _step(self) -> None:
488500
"""Executes a single step of the parent or delegate agent. Detects stuck agents and limits on the number of iterations and the task budget."""
489501
if self.get_agent_state() != AgentState.RUNNING:
490-
await asyncio.sleep(1)
491502
return
492503

493504
if self._pending_action:
494-
await asyncio.sleep(1)
495505
return
496506

497507
if self.delegate is not None:
498508
assert self.delegate != self
499-
if self.delegate.get_agent_state() == AgentState.PAUSED:
500-
# no need to check too often
501-
await asyncio.sleep(1)
502-
else:
509+
# TODO this conditional will always be false, because the parent controllers are unsubscribed
510+
# remove if it's still useless when delegation is reworked
511+
if self.delegate.get_agent_state() != AgentState.PAUSED:
503512
await self._delegate_step()
504513
return
505514

@@ -509,7 +518,6 @@ async def _step(self) -> None:
509518
extra={'msg_type': 'STEP'},
510519
)
511520

512-
# check if agent hit the resources limit
513521
stop_step = False
514522
if self.state.iteration >= self.state.max_iterations:
515523
stop_step = await self._handle_traffic_control(
@@ -522,6 +530,7 @@ async def _step(self) -> None:
522530
'budget', current_cost, self.max_budget_per_task
523531
)
524532
if stop_step:
533+
logger.warning('Stopping agent due to traffic control')
525534
return
526535

527536
if self._is_stuck():
@@ -967,7 +976,7 @@ def __repr__(self):
967976
return (
968977
f'AgentController(id={self.id}, agent={self.agent!r}, '
969978
f'event_stream={self.event_stream!r}, '
970-
f'state={self.state!r}, agent_task={self.agent_task!r}, '
979+
f'state={self.state!r}, '
971980
f'delegate={self.delegate!r}, _pending_action={self._pending_action!r})'
972981
)
973982

openhands/core/loop.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ async def run_agent_until_done(
1616
the agent until it reaches a terminal state.
1717
Note that runtime must be connected before being passed in here.
1818
"""
19-
controller.agent_task = asyncio.create_task(controller.start_step_loop())
2019

2120
def status_callback(msg_type, msg_id, msg):
2221
if msg_type == 'error':
@@ -41,10 +40,3 @@ def status_callback(msg_type, msg_id, msg):
4140

4241
while controller.state.agent_state not in end_states:
4342
await asyncio.sleep(1)
44-
45-
if not controller.agent_task.done():
46-
controller.agent_task.cancel()
47-
try:
48-
await controller.agent_task
49-
except asyncio.CancelledError:
50-
pass

openhands/events/stream.py

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import asyncio
22
import threading
3-
from dataclasses import dataclass, field
3+
from concurrent.futures import ThreadPoolExecutor
44
from datetime import datetime
55
from enum import Enum
6+
from queue import Queue
67
from typing import Callable, Iterable
78

89
from openhands.core.logger import openhands_logger as logger
@@ -52,15 +53,26 @@ async def __aiter__(self):
5253
yield await loop.run_in_executor(None, lambda e=event: e) # type: ignore
5354

5455

55-
@dataclass
5656
class EventStream:
5757
sid: str
5858
file_store: FileStore
5959
# For each subscriber ID, there is a map of callback functions - useful
6060
# when there are multiple listeners
61-
_subscribers: dict[str, dict[str, Callable]] = field(default_factory=dict)
61+
_subscribers: dict[str, dict[str, Callable]]
6262
_cur_id: int = 0
63-
_lock: threading.Lock = field(default_factory=threading.Lock)
63+
_lock: threading.Lock
64+
65+
def __init__(self, sid: str, file_store: FileStore, num_workers: int = 1):
66+
self.sid = sid
67+
self.file_store = file_store
68+
self._queue: Queue[Event] = Queue()
69+
self._thread_pools: dict[str, dict[str, ThreadPoolExecutor]] = {}
70+
self._queue_thread = threading.Thread(target=self._run_queue_loop)
71+
self._queue_thread.daemon = True
72+
self._queue_thread.start()
73+
self._subscribers = {}
74+
self._lock = threading.Lock()
75+
self._cur_id = 0
6476

6577
def __post_init__(self) -> None:
6678
try:
@@ -76,6 +88,10 @@ def __post_init__(self) -> None:
7688
if id >= self._cur_id:
7789
self._cur_id = id + 1
7890

91+
def _init_thread_loop(self):
92+
loop = asyncio.new_event_loop()
93+
asyncio.set_event_loop(loop)
94+
7995
def _get_filename_for_id(self, id: int) -> str:
8096
return get_conversation_event_filename(self.sid, id)
8197

@@ -157,15 +173,18 @@ def get_latest_event_id(self) -> int:
157173
def subscribe(
158174
self, subscriber_id: EventStreamSubscriber, callback: Callable, callback_id: str
159175
):
176+
pool = ThreadPoolExecutor(max_workers=1, initializer=self._init_thread_loop)
160177
if subscriber_id not in self._subscribers:
161178
self._subscribers[subscriber_id] = {}
179+
self._thread_pools[subscriber_id] = {}
162180

163181
if callback_id in self._subscribers[subscriber_id]:
164182
raise ValueError(
165183
f'Callback ID on subscriber {subscriber_id} already exists: {callback_id}'
166184
)
167185

168186
self._subscribers[subscriber_id][callback_id] = callback
187+
self._thread_pools[subscriber_id][callback_id] = pool
169188

170189
def unsubscribe(self, subscriber_id: EventStreamSubscriber, callback_id: str):
171190
if subscriber_id not in self._subscribers:
@@ -179,13 +198,6 @@ def unsubscribe(self, subscriber_id: EventStreamSubscriber, callback_id: str):
179198
del self._subscribers[subscriber_id][callback_id]
180199

181200
def add_event(self, event: Event, source: EventSource):
182-
try:
183-
asyncio.get_running_loop().create_task(self._async_add_event(event, source))
184-
except RuntimeError:
185-
# No event loop running...
186-
asyncio.run(self._async_add_event(event, source))
187-
188-
async def _async_add_event(self, event: Event, source: EventSource):
189201
if hasattr(event, '_id') and event.id is not None:
190202
raise ValueError(
191203
'Event already has an ID. It was probably added back to the EventStream from inside a handler, trigging a loop.'
@@ -199,14 +211,22 @@ async def _async_add_event(self, event: Event, source: EventSource):
199211
data = event_to_dict(event)
200212
if event.id is not None:
201213
self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data))
202-
tasks = []
203-
for key in sorted(self._subscribers.keys()):
204-
callbacks = self._subscribers[key]
205-
for callback_id in callbacks:
206-
callback = callbacks[callback_id]
207-
tasks.append(asyncio.create_task(callback(event)))
208-
if tasks:
209-
await asyncio.wait(tasks)
214+
self._queue.put(event)
215+
216+
def _run_queue_loop(self):
217+
loop = asyncio.new_event_loop()
218+
asyncio.set_event_loop(loop)
219+
loop.run_until_complete(self._process_queue())
220+
221+
async def _process_queue(self):
222+
while should_continue():
223+
event = self._queue.get()
224+
for key in sorted(self._subscribers.keys()):
225+
callbacks = self._subscribers[key]
226+
for callback_id in callbacks:
227+
callback = callbacks[callback_id]
228+
pool = self._thread_pools[key][callback_id]
229+
pool.submit(callback, event)
210230

211231
def _callback(self, callback: Callable, event: Event):
212232
asyncio.run(callback(event))

openhands/runtime/base.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import atexit
23
import copy
34
import json
@@ -167,38 +168,40 @@ def add_env_vars(self, env_vars: dict[str, str]) -> None:
167168
f'Failed to add env vars [{env_vars}] to environment: {obs.content}'
168169
)
169170

170-
async def on_event(self, event: Event) -> None:
171+
def on_event(self, event: Event) -> None:
171172
if isinstance(event, Action):
172-
# set timeout to default if not set
173-
if event.timeout is None:
174-
event.timeout = self.config.sandbox.timeout
175-
assert event.timeout is not None
176-
try:
177-
observation: Observation = await call_sync_from_async(
178-
self.run_action, event
179-
)
180-
except Exception as e:
181-
err_id = ''
182-
if isinstance(e, ConnectionError) or isinstance(
183-
e, AgentRuntimeDisconnectedError
184-
):
185-
err_id = 'STATUS$ERROR_RUNTIME_DISCONNECTED'
186-
logger.error(
187-
'Unexpected error while running action',
188-
exc_info=True,
189-
stack_info=True,
190-
)
191-
self.log('error', f'Problematic action: {str(event)}')
192-
self.send_error_message(err_id, str(e))
193-
self.close()
194-
return
195-
196-
observation._cause = event.id # type: ignore[attr-defined]
197-
observation.tool_call_metadata = event.tool_call_metadata
198-
199-
# this might be unnecessary, since source should be set by the event stream when we're here
200-
source = event.source if event.source else EventSource.AGENT
201-
self.event_stream.add_event(observation, source) # type: ignore[arg-type]
173+
asyncio.get_event_loop().run_until_complete(self._handle_action(event))
174+
175+
async def _handle_action(self, event: Action) -> None:
176+
if event.timeout is None:
177+
event.timeout = self.config.sandbox.timeout
178+
assert event.timeout is not None
179+
try:
180+
observation: Observation = await call_sync_from_async(
181+
self.run_action, event
182+
)
183+
except Exception as e:
184+
err_id = ''
185+
if isinstance(e, ConnectionError) or isinstance(
186+
e, AgentRuntimeDisconnectedError
187+
):
188+
err_id = 'STATUS$ERROR_RUNTIME_DISCONNECTED'
189+
logger.error(
190+
'Unexpected error while running action',
191+
exc_info=True,
192+
stack_info=True,
193+
)
194+
self.log('error', f'Problematic action: {str(event)}')
195+
self.send_error_message(err_id, str(e))
196+
self.close()
197+
return
198+
199+
observation._cause = event.id # type: ignore[attr-defined]
200+
observation.tool_call_metadata = event.tool_call_metadata
201+
202+
# this might be unnecessary, since source should be set by the event stream when we're here
203+
source = event.source if event.source else EventSource.AGENT
204+
self.event_stream.add_event(observation, source) # type: ignore[arg-type]
202205

203206
def clone_repo(self, github_token: str | None, selected_repository: str | None):
204207
if not github_token or not selected_repository:

0 commit comments

Comments
 (0)