diff --git a/python/ray/tests/test_actor_failures.py b/python/ray/tests/test_actor_failures.py index aa6108390753..aa1f93ad6604 100644 --- a/python/ray/tests/test_actor_failures.py +++ b/python/ray/tests/test_actor_failures.py @@ -1044,6 +1044,8 @@ def test_exit_actor_async_actor_nested_task(shutdown_only, tmp_path): assert not temp_file_atexit.exists() assert not temp_file_after_exit_actor.exists() + signal = SignalActor.remote() + @ray.remote class AsyncActor: def __init__(self): @@ -1052,24 +1054,19 @@ def f(): atexit.register(f) - async def start_exit_task(self): - asyncio.create_task(self.exit()) + async def start_exit_task(self, signal): + asyncio.create_task(self.exit(signal)) - async def exit(self): + async def exit(self, signal): + await signal.wait.remote() exit_actor() # The following code should not be executed. temp_file_after_exit_actor.touch() a = AsyncActor.remote() ray.get(a.__ray_ready__.remote()) - with pytest.raises(ray.exceptions.RayActorError) as exc_info: - ray.get(a.start_exit_task.remote()) - assert ( - # Exited when task execution returns - "exit_actor()" in str(exc_info.value) - # Exited during periodical check in worker - or "User requested to exit the actor" in str(exc_info.value) - ) + ray.get(a.start_exit_task.remote(signal)) + ray.get(signal.send.remote()) def verify(): return temp_file_atexit.exists()