diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py index abcb64eff260..42b7cb78a007 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py @@ -190,3 +190,7 @@ async def save_state(self) -> Mapping[str, Any]: async def load_state(self, state: Mapping[str, Any]) -> None: """Restore agent from saved state. Default implementation for stateless agents.""" BaseState.model_validate(state) + + async def close(self) -> None: + """Called when the runtime is closed""" + pass diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py index c4ea0218916c..256f752bfa80 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py @@ -64,3 +64,7 @@ async def save_state(self) -> Mapping[str, Any]: async def load_state(self, state: Mapping[str, Any]) -> None: """Restore agent from saved state""" ... + + async def close(self) -> None: + """Called when the runtime is stopped or any stop method is called""" + ... diff --git a/python/packages/autogen-core/src/autogen_core/_agent.py b/python/packages/autogen-core/src/autogen_core/_agent.py index edb5e59b1ce3..0f37b822ff8a 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_agent.py @@ -45,3 +45,7 @@ async def load_state(self, state: Mapping[str, Any]) -> None: """ ... + + async def close(self) -> None: + """Called when the runtime is closed""" + ... diff --git a/python/packages/autogen-core/src/autogen_core/_base_agent.py b/python/packages/autogen-core/src/autogen_core/_base_agent.py index cfefb4ab72f8..bffb61b876bb 100644 --- a/python/packages/autogen-core/src/autogen_core/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_base_agent.py @@ -152,6 +152,9 @@ async def load_state(self, state: Mapping[str, Any]) -> None: warnings.warn("load_state not implemented", stacklevel=2) pass + async def close(self) -> None: + pass + @classmethod async def register( cls, diff --git a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py index f8f3669213e4..d682c1c7beb0 100644 --- a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py @@ -309,6 +309,7 @@ async def _process_send(self, message_envelope: SendMessageEnvelope) -> None: ) ) recipient_agent = await self._get_agent(recipient) + message_context = MessageContext( sender=message_envelope.sender, topic_id=None, @@ -589,10 +590,21 @@ def start(self) -> None: raise RuntimeError("Runtime is already started") self._run_context = RunContext(self) + async def close(self) -> None: + """Calls :meth:`stop` if applicable and the :meth:`Agent.close` method on all instantiated agents""" + # stop the runtime if it hasn't been stopped yet + if self._run_context is not None: + await self.stop() + # close all the agents that have been instantiated + for agent_id in self._instantiated_agents: + agent = await self._get_agent(agent_id) + await agent.close() + async def stop(self) -> None: """Immediately stop the runtime message processing loop. The currently processing message will be completed, but all others following it will be discarded.""" if self._run_context is None: raise RuntimeError("Runtime is not started") + await self._run_context.stop() self._run_context = None self._message_queue = Queue() @@ -603,6 +615,7 @@ async def stop_when_idle(self) -> None: if self._run_context is None: raise RuntimeError("Runtime is not started") await self._run_context.stop_when_idle() + self._run_context = None self._message_queue = Queue() @@ -623,6 +636,7 @@ async def stop_when(self, condition: Callable[[], bool]) -> None: if self._run_context is None: raise RuntimeError("Runtime is not started") await self._run_context.stop_when(condition) + self._run_context = None self._message_queue = Queue() diff --git a/python/packages/autogen-core/tests/test_runtime.py b/python/packages/autogen-core/tests/test_runtime.py index 16de5ccc18f6..57cef4ec4810 100644 --- a/python/packages/autogen-core/tests/test_runtime.py +++ b/python/packages/autogen-core/tests/test_runtime.py @@ -86,6 +86,8 @@ async def test_register_receives_publish(tracer_provider: TracerProvider) -> Non "autogen publish default.(default)-T", ] + await runtime.close() + @pytest.mark.asyncio async def test_register_receives_publish_with_construction(caplog: pytest.LogCaptureFixture) -> None: @@ -107,6 +109,8 @@ async def agent_factory() -> LoopbackAgent: # Check if logger has the exception. assert any("Error constructing agent" in e.message for e in caplog.records) + await runtime.close() + @pytest.mark.asyncio async def test_register_receives_publish_cascade() -> None: @@ -137,6 +141,8 @@ async def test_register_receives_publish_cascade() -> None: agent = await runtime.try_get_underlying_agent_instance(AgentId(f"name{i}", "default"), CascadingAgent) assert agent.num_calls == total_num_calls_expected + await runtime.close() + @pytest.mark.asyncio async def test_register_factory_explicit_name() -> None: @@ -162,6 +168,8 @@ async def test_register_factory_explicit_name() -> None: ) assert other_long_running_agent.num_calls == 0 + await runtime.close() + @pytest.mark.asyncio async def test_default_subscription() -> None: @@ -185,6 +193,8 @@ async def test_default_subscription() -> None: ) assert other_long_running_agent.num_calls == 0 + await runtime.close() + @pytest.mark.asyncio async def test_type_subscription() -> None: @@ -208,6 +218,8 @@ class LoopbackAgentWithSubscription(LoopbackAgent): ... ) assert other_long_running_agent.num_calls == 0 + await runtime.close() + @pytest.mark.asyncio async def test_default_subscription_publish_to_other_source() -> None: @@ -229,3 +241,5 @@ async def test_default_subscription_publish_to_other_source() -> None: AgentId("name", key="other"), type=LoopbackAgentWithDefaultSubscription ) assert other_long_running_agent.num_calls == 1 + + await runtime.close() diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py index b39ec04a3e82..4ae66e44ccf6 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py @@ -179,6 +179,7 @@ async def recv(self) -> agent_worker_pb2.Message: class GrpcWorkerAgentRuntime(AgentRuntime): + # TODO: Needs to handle agent close() call def __init__( self, host_address: str,