diff --git a/rllib/evaluation/collectors/sample_collector.py b/rllib/evaluation/collectors/sample_collector.py index 415eac9592a7..6024aafc1180 100644 --- a/rllib/evaluation/collectors/sample_collector.py +++ b/rllib/evaluation/collectors/sample_collector.py @@ -59,6 +59,7 @@ def __init__(self, @abstractmethod def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID, + env_id: EnvID, policy_id: PolicyID, t: int, init_obs: TensorType) -> None: """Adds an initial obs (after reset) to this collector. @@ -204,7 +205,7 @@ def postprocess_episode(self, episode: MultiAgentEpisode, is_done: bool = False, check_dones: bool = False, - build: bool = False) -> Optional[MultiAgentBatch]: + build: bool = False) -> Optional[Union[SampleBatch, MultiAgentBatch]]: """Postprocesses all agents' trajectories in a given episode. Generates (single-trajectory) SampleBatches for all Policies/Agents and diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 3c4838627380..0876b322cdc9 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -3,7 +3,7 @@ import logging import math import numpy as np -from typing import Any, Dict, List, Tuple, TYPE_CHECKING, Union +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union from ray.rllib.env.base_env import _DUMMY_AGENT_ID from ray.rllib.evaluation.collectors.sample_collector import SampleCollector @@ -612,7 +612,7 @@ def postprocess_episode( episode: MultiAgentEpisode, is_done: bool = False, check_dones: bool = False, - build: bool = False) -> Union[None, SampleBatch, MultiAgentBatch]: + build: bool = False) -> Optional[Union[SampleBatch, MultiAgentBatch]]: episode_id = episode.episode_id policy_collector_group = episode.batch_builder