Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion rllib/evaluation/collectors/sample_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(self,

@abstractmethod
def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
env_id: EnvID,
policy_id: PolicyID, t: int,
init_obs: TensorType) -> None:
"""Adds an initial obs (after reset) to this collector.
Expand Down Expand Up @@ -204,7 +205,7 @@ def postprocess_episode(self,
episode: MultiAgentEpisode,
is_done: bool = False,
check_dones: bool = False,
build: bool = False) -> Optional[MultiAgentBatch]:
build: bool = False) -> Optional[Union[SampleBatch, MultiAgentBatch]]:
"""Postprocesses all agents' trajectories in a given episode.

Generates (single-trajectory) SampleBatches for all Policies/Agents and
Expand Down
4 changes: 2 additions & 2 deletions rllib/evaluation/collectors/simple_list_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import math
import numpy as np
from typing import Any, Dict, List, Tuple, TYPE_CHECKING, Union
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
Expand Down Expand Up @@ -612,7 +612,7 @@ def postprocess_episode(
episode: MultiAgentEpisode,
is_done: bool = False,
check_dones: bool = False,
build: bool = False) -> Union[None, SampleBatch, MultiAgentBatch]:
build: bool = False) -> Optional[Union[SampleBatch, MultiAgentBatch]]:
episode_id = episode.episode_id
policy_collector_group = episode.batch_builder

Expand Down