diff --git a/python/ray/util/iter.py b/python/ray/util/iter.py index b1161d48bdb9..eb6cf060e6ac 100644 --- a/python/ray/util/iter.py +++ b/python/ray/util/iter.py @@ -512,6 +512,10 @@ def base_iterator(timeout=None): results.append(ray.get(f)) except (StopIteration, RuntimeError) as ex_i: if was_cause_by_stop_iteration(ex_i): + logger.exception( + "Encountered an exception while extracting " + "the valid data from `futures`." + ) if self.is_infinite_sequence: stoped_actors.append(a) else: @@ -520,11 +524,17 @@ def base_iterator(timeout=None): raise ex_i if results: + logger.info( + f"Gathered {len(results)} shards of batch data " + f"with {len(stoped_actors)} stopped actors." + ) yield results elif self.is_infinite_sequence and len(stoped_actors) == len(active): raise ex + logger.info(f"Kicking off {len(active)} new sampling tasks.") futures = [a.par_iter_next.remote() for a in active] else: + logger.exception("Encountered a non-StopIteration exception.") raise ex name = "{}.batch_across_shards()".format(self) diff --git a/rllib/execution/rollout_ops.py b/rllib/execution/rollout_ops.py index adf4fc12650a..01e0f20e4b26 100644 --- a/rllib/execution/rollout_ops.py +++ b/rllib/execution/rollout_ops.py @@ -158,6 +158,11 @@ def __init__(self, min_batch_size: int): self.buffer = [] self.count = 0 self.batch_start_time = None + # We would like to log once per 100 received batch shards. Since each + # shard size varies, the batch count may not exactly hit a multiple + # of 100, and this variable is used for resetting the count for + # every 100 or more shards.
+ self._count_for_log = 0 def _on_fetch_start(self): if self.batch_start_time is None: @@ -168,6 +173,7 @@ def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]: self.buffer.append(batch) self.count += batch.count if self.count >= self.min_batch_size: + logger.info(f"Completed gathering a full batch with size {self.count}.") if self.count > self.min_batch_size * 2: logger.info("Collected more training samples than expected " "(actual={}, expected={}). ".format( @@ -181,7 +187,11 @@ def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]: self.batch_start_time = None self.buffer = [] self.count = 0 + self._count_for_log = 0 return [out] + if self.count >= 100 + self._count_for_log: + logger.info(f"Gathered a partial batch with size {self.count}.") + self._count_for_log = self.count return []