diff --git a/python/ray/tests/test_actor_failures.py b/python/ray/tests/test_actor_failures.py index bf603b2197ec..1410c1caa7f2 100644 --- a/python/ray/tests/test_actor_failures.py +++ b/python/ray/tests/test_actor_failures.py @@ -989,6 +989,39 @@ def verify(): wait_for_condition(verify) +def test_exit_actor_async_actor_nested_task(shutdown_only, tmp_path): + async_temp_file = tmp_path / "async_actor.log" + async_temp_file.touch() + + @ray.remote + class AsyncActor: + def __init__(self): + def f(): + print("atexit handler") + with open(async_temp_file, "w") as f: + f.write("Async Actor\n") + + atexit.register(f) + + async def start_exit_task(self): + asyncio.create_task(self.exit()) + + async def exit(self): + exit_actor() + + a = AsyncActor.remote() + ray.get(a.__ray_ready__.remote()) + with pytest.raises(ray.exceptions.RayActorError): + ray.get(a.start_exit_task.remote()) + + def verify(): + with open(async_temp_file) as f: + assert f.readlines() == ["Async Actor\n"] + return True + + wait_for_condition(verify) + + def test_exit_actor_queued(shutdown_only): """Verify after exit_actor is called the queued tasks won't execute."""