diff --git a/src/ldp/alg/callbacks.py b/src/ldp/alg/callbacks.py index 403cdbc1..2fa1c127 100644 --- a/src/ldp/alg/callbacks.py +++ b/src/ldp/alg/callbacks.py @@ -4,7 +4,8 @@ import os import time from collections import defaultdict -from collections.abc import Callable, Collection, Iterable, Sequence +from collections.abc import AsyncIterator, Callable, Collection, Iterable, Sequence +from contextlib import asynccontextmanager from pathlib import Path from typing import TYPE_CHECKING, Any, cast @@ -46,9 +47,11 @@ class Callback: callback.after_agent_init_state() * while not done: callback.before_transition() * - agent.get_asv() + async with callback.during_get_asv(): + agent.get_asv() callback.after_agent_get_asv() * - env.step() + async with callback.during_env_step(): + env.step() callback.after_env_step() * callback.after_transition() * @@ -85,6 +88,13 @@ async def before_transition( async def after_agent_init_state(self, traj_id: str, init_state: Any) -> None: """Invoked by runners after agent.init_state().""" + @asynccontextmanager + async def during_get_asv( + self, traj_id: str, agent: Agent, agent_state: Any + ) -> AsyncIterator[None]: + """Context used by runners during agent.get_asv().""" + yield + async def after_agent_get_asv( self, traj_id: str, @@ -94,6 +104,13 @@ async def after_agent_get_asv( ) -> None: """Invoked by runners after agent.get_asv().""" + @asynccontextmanager + async def during_env_step( + self, traj_id: str, env: Environment + ) -> AsyncIterator[None]: + """Context used by runners during env.step().""" + yield + async def after_env_reset( self, traj_id: str, obs: list[Message], tools: list[Tool] ) -> None: diff --git a/src/ldp/alg/rollout.py b/src/ldp/alg/rollout.py index 0601dbc8..cbf08eb0 100644 --- a/src/ldp/alg/rollout.py +++ b/src/ldp/alg/rollout.py @@ -6,7 +6,7 @@ import uuid from collections import Counter from collections.abc import Callable, Iterator, Sequence -from contextlib import contextmanager, nullcontext +from contextlib import AsyncExitStack, contextmanager, nullcontext from typing import Any, TypeVar, overload from aviary.core import Environment, Message @@ -426,11 +426,16 @@ async def _take_step( timer("agent_get_asv"), reraise_exc_as(AgentError, enabled=self.catch_agent_failures), ): - ( - action, - next_agent_state, - value, - ) = await self.agent.get_asv(agent_state, obs) + async with AsyncExitStack() as stack: + for callback in self.callbacks: + await stack.enter_async_context( + callback.during_get_asv(traj_id, self.agent, agent_state) + ) + ( + action, + next_agent_state, + value, + ) = await self.agent.get_asv(agent_state, obs) with timer("after_agent_get_asv"): await asyncio.gather(*[ @@ -444,7 +449,12 @@ async def _take_step( timer("env_step"), reraise_exc_as(EnvError, enabled=self.catch_env_failures), ): - next_obs, reward, done, trunc = await env.step(action.value) + async with AsyncExitStack() as stack: + for callback in self.callbacks: + await stack.enter_async_context( + callback.during_env_step(traj_id, env) + ) + next_obs, reward, done, trunc = await env.step(action.value) with timer("after_env_step"): await asyncio.gather(*[ callback.after_env_step(traj_id, next_obs, reward, done, trunc)