diff --git a/python/ray/util/iter.py b/python/ray/util/iter.py
index 1a277ccf4945..1a18a8dfbbc8 100644
--- a/python/ray/util/iter.py
+++ b/python/ray/util/iter.py
@@ -8,6 +8,7 @@
 from typing import TypeVar, Generic, Iterable, List, Callable, Any, Iterator
 
 import ray
+from ray.memory_monitor import RayOutOfMemoryError
 from ray.util.iter_metrics import MetricsContext, SharedMetrics
 
 logger = logging.getLogger(__name__)
@@ -525,6 +526,9 @@ def base_iterator(timeout=None):
                     yield _NextValueNotReady()
                 except TimeoutError:
                     yield _NextValueNotReady()
+                # Propagate OOM exceptions up the stack.
+                except RayOutOfMemoryError:
+                    raise
                 except (StopIteration, RuntimeError) as ex:
                     if was_cause_by_stop_iteration(ex):
                         # If we are streaming (infinite sequence) then
@@ -543,14 +547,19 @@ def base_iterator(timeout=None):
                                     else:
                                         active.remove(a)
                                 else:
+                                    # ex_i is never a StopIteration since the
+                                    # if branch above catches StopIterations
                                     raise ex_i
                         if results:
                             yield results
                         elif self.is_infinite_sequence and len(stoped_actors) == len(active):
-                            raise ex
+                            if not isinstance(ex, StopIteration):
+                                raise ex
                         futures = [a.par_iter_next.remote() for a in active]
                     else:
+                        # ex is never a StopIteration since the
+                        # if branch above catches StopIterations
                         raise ex
 
 
         name = f"{self}.batch_across_shards()"
diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py
index 2a39ea93a46a..995d009f9ec7 100644
--- a/rllib/agents/trainer.py
+++ b/rllib/agents/trainer.py
@@ -431,6 +431,26 @@
 # yapf: enable
 
 
+def is_memory_error(e: Exception) -> bool:
+    """Check if an exception was caused by running out of memory."""
+    memory_error_names = [
+        "ray.memory_monitor.RayOutOfMemoryError",
+        "RayOutOfMemoryError",
+    ]
+    ename = type(e).__name__
+
+    if ename in memory_error_names:
+        return True
+
+    msg_list = [s for s in str(e).split("\n") if s]
+
+    # RayTaskError wraps the worker exception; check its message instead.
+    if ename.startswith("RayTaskError"):
+        return any(name in msg for name in memory_error_names
+                   for msg in msg_list)
+    return False
+
+
 @DeveloperAPI
 def with_common_config(
         extra_config: PartialTrainerConfigDict) -> TrainerConfigDict:
@@ -554,11 +574,18 @@ def train(self) -> ResultDict:
         for _ in range(1 + MAX_WORKER_FAILURE_RETRIES):
             try:
                 result = Trainable.train(self)
-            except RayError as e:
+            except Exception as e:
                 if self.config["ignore_worker_failures"]:
-                    logger.exception(
-                        "Error in train call, attempting to recover")
-                    self._try_recover()
+                    # Do not retry in case of OOM errors.
+                    if isinstance(e, RayError):
+                        if not is_memory_error(e):
+                            logger.exception(
+                                "Error in train call, attempting to recover")
+                            self._try_recover()
+                        else:
+                            logger.exception(
+                                "Not attempting to recover from OOM error")
+                            raise e
                 else:
                     logger.info(
                         "Worker crashed during call to train(). To attempt to "
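
Reviewer note, not part of the patch: a minimal usage sketch of the new is_memory_error() helper added to rllib/agents/trainer.py. It assumes the patch is applied and that ray is importable; RayTaskErrorStub is a hypothetical stand-in for ray.exceptions.RayTaskError, since the helper only inspects the class-name prefix and the exception message.

from ray.memory_monitor import RayOutOfMemoryError
from ray.rllib.agents.trainer import is_memory_error


class RayTaskErrorStub(Exception):
    """Stand-in whose name starts with "RayTaskError", like the real wrapper."""
    pass


# A driver-side OOM raised directly by the memory monitor.
direct = RayOutOfMemoryError("heap memory usage exceeded the threshold")

# A worker-side OOM surfaced through a RayTaskError-style wrapper; only the
# message mentions the underlying RayOutOfMemoryError.
wrapped = RayTaskErrorStub(
    "ray::RolloutWorker.sample() failed\n"
    "ray.memory_monitor.RayOutOfMemoryError: worker killed")

# Any other failure should still be treated as recoverable.
other = RuntimeError("unrelated worker failure")

assert is_memory_error(direct)
assert is_memory_error(wrapped)
assert not is_memory_error(other)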