From 4a5d07a902cb3bd734c19a632640cc3152c43863 Mon Sep 17 00:00:00 2001
From: Kiko Aumond
Date: Tue, 24 Jan 2023 13:16:14 -0800
Subject: [PATCH 1/4] re-raise RayOutOfMemoryErrors; do not bubble up StopIterations

---
 python/ray/util/iter.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/python/ray/util/iter.py b/python/ray/util/iter.py
index 1a277ccf4945..1a18a8dfbbc8 100644
--- a/python/ray/util/iter.py
+++ b/python/ray/util/iter.py
@@ -8,6 +8,7 @@ from typing import TypeVar, Generic, Iterable, List, Callable, Any, Iterator
 
 import ray
+from ray.memory_monitor import RayOutOfMemoryError
 from ray.util.iter_metrics import MetricsContext, SharedMetrics
 
 logger = logging.getLogger(__name__)
 
@@ -525,6 +526,9 @@ def base_iterator(timeout=None):
                         yield _NextValueNotReady()
                 except TimeoutError:
                     yield _NextValueNotReady()
+                # Propagate OOM exceptions up the stack
+                except RayOutOfMemoryError:
+                    raise
                 except (StopIteration, RuntimeError) as ex:
                     if was_cause_by_stop_iteration(ex):
                         # If we are streaming (infinite sequence) then
@@ -543,14 +547,19 @@ def base_iterator(timeout=None):
                                     else:
                                         active.remove(a)
                                 else:
+                                    # ex_i is never a StopIteration here since the
+                                    # was_cause_by_stop_iteration check above handles all StopIterations
                                     raise ex_i
 
                         if results:
                             yield results
                         elif self.is_infinite_sequence and len(stoped_actors) == len(active):
-                            raise ex
+                            if not isinstance(ex, StopIteration):
+                                raise ex
                         futures = [a.par_iter_next.remote() for a in active]
                     else:
+                        # ex is never a StopIteration here since the
+                        # was_cause_by_stop_iteration check above handles all StopIterations
                         raise ex
 
         name = f"{self}.batch_across_shards()"

From d3753697b1c877e6c6c6bca5364e8ab09e836f34 Mon Sep 17 00:00:00 2001
From: Kiko Aumond
Date: Tue, 24 Jan 2023 13:58:02 -0800
Subject: [PATCH 2/4] added RayOutOfMemoryErrors as conditions for terminating workers in trainer.py

---
 rllib/agents/trainer.py | 43 ++++++++++++++++++++++++++++++++++-------
 1 file changed, 36 insertions(+), 7 deletions(-)

diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py
index 2a39ea93a46a..ccc11c7d03c7 100644
--- a/rllib/agents/trainer.py
+++ b/rllib/agents/trainer.py
@@ -431,6 +431,26 @@
 # yapf: enable
 
 
+def is_memory_error(e: Exception) -> bool:
+    """Check if an exception occurred due to a process running out of memory."""
+    memory_error_names = [
+        "ray.memory_monitor.RayOutOfMemoryError",
+        "RayOutOfMemoryError",
+    ]
+    ename = type(e).__name__
+
+    if ename in memory_error_names:
+        return True
+
+    msg_list = list(filter(lambda s: len(s) > 0, str(e).split("\n")))
+
+    if ename.startswith("RayTaskError"):
+        return any(
+            any(ename in msg for msg in msg_list) for ename in memory_error_names
+        )
+    return False
+
+
 @DeveloperAPI
 def with_common_config(
         extra_config: PartialTrainerConfigDict) -> TrainerConfigDict:
@@ -1217,13 +1237,22 @@ def _try_recover(self):
             try:
                 ray.get(obj_ref)
                 healthy_workers.append(w)
-                logger.info("Worker {} looks healthy".format(i + 1))
-            except RayError:
-                logger.exception("Removing unhealthy worker {}".format(i + 1))
-                try:
-                    w.__ray_terminate__.remote()
-                except Exception:
-                    logger.exception("Error terminating unhealthy worker")
+                logger.info(f"Worker {i + 1} looks healthy")
+            except Exception as e:
+                if is_memory_error(e):
+                    logger.exception(f"Removing unhealthy worker {i + 1} due to an OOM error")
+                    terminate = True
+                elif isinstance(e, RayError):
+                    logger.exception(f"Removing unhealthy worker {i + 1}")
+                    terminate = True
+                else:
+                    terminate = False
+
+                if terminate:
+                    try:
+                        w.__ray_terminate__.remote()
+                    except Exception:
+                        logger.exception(f"Error terminating unhealthy worker {i + 1}")
 
         if len(healthy_workers) < 1:
             raise RuntimeError(

From 3fd77c190e74321d682b96e14044dcf872549723 Mon Sep 17 00:00:00 2001
From: Kiko Aumond
Date: Tue, 24 Jan 2023 14:16:50 -0800
Subject: [PATCH 3/4] do not attempt to restart workers when OOM errors occur

---
 rllib/agents/trainer.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py
index ccc11c7d03c7..19df66cc89be 100644
--- a/rllib/agents/trainer.py
+++ b/rllib/agents/trainer.py
@@ -574,11 +574,18 @@ def train(self) -> ResultDict:
         for _ in range(1 + MAX_WORKER_FAILURE_RETRIES):
             try:
                 result = Trainable.train(self)
-            except RayError as e:
+            except Exception as e:
                 if self.config["ignore_worker_failures"]:
-                    logger.exception(
-                        "Error in train call, attempting to recover")
-                    self._try_recover()
+                    # Do not retry in case of OOM errors
+                    if isinstance(e, RayError):
+                        if not is_memory_error(e):
+                            logger.exception(
+                                "Error in train call, attempting to recover")
+                            self._try_recover()
+                        else:
+                            logger.exception("Not attempting to recover from error in train call "
+                                             "since it was caused by an OOM error")
+                            raise e
                 else:
                     logger.info(
                         "Worker crashed during call to train(). To attempt to "

From c73e0a1a11572a3be6ebe9499ae168f4ba11569a Mon Sep 17 00:00:00 2001
From: Kiko Aumond
Date: Tue, 24 Jan 2023 14:19:52 -0800
Subject: [PATCH 4/4] do not attempt to restart workers when OOM errors occur

---
 rllib/agents/trainer.py | 23 +++++++----------------
 1 file changed, 7 insertions(+), 16 deletions(-)

diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py
index 19df66cc89be..995d009f9ec7 100644
--- a/rllib/agents/trainer.py
+++ b/rllib/agents/trainer.py
@@ -1244,22 +1244,13 @@ def _try_recover(self):
             try:
                 ray.get(obj_ref)
                 healthy_workers.append(w)
-                logger.info(f"Worker {i + 1} looks healthy")
-            except Exception as e:
-                if is_memory_error(e):
-                    logger.exception(f"Removing unhealthy worker {i + 1} due to an OOM error")
-                    terminate = True
-                elif isinstance(e, RayError):
-                    logger.exception(f"Removing unhealthy worker {i + 1}")
-                    terminate = True
-                else:
-                    terminate = False
-
-                if terminate:
-                    try:
-                        w.__ray_terminate__.remote()
-                    except Exception:
-                        logger.exception(f"Error terminating unhealthy worker {i + 1}")
+                logger.info("Worker {} looks healthy".format(i + 1))
+            except RayError:
+                logger.exception("Removing unhealthy worker {}".format(i + 1))
+                try:
+                    w.__ray_terminate__.remote()
+                except Exception:
+                    logger.exception("Error terminating unhealthy worker")
 
         if len(healthy_workers) < 1:
             raise RuntimeError(
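
For reference, a minimal sketch (not part of the patches themselves) of how the
is_memory_error() helper added in PATCH 2/4 is expected to classify worker
failures. It assumes the patched ray.rllib.agents.trainer module is importable;
RayTaskErrorStub and the message strings are hypothetical stand-ins for
ray.exceptions.RayTaskError and its error text, which is enough here because
is_memory_error() only inspects the exception's type name and message.

from ray.memory_monitor import RayOutOfMemoryError
from ray.rllib.agents.trainer import is_memory_error


class RayTaskErrorStub(Exception):
    """Stand-in whose type name starts with "RayTaskError", like Ray's own."""


# OOM raised directly by the memory monitor: matched by the exception type name.
direct_oom = RayOutOfMemoryError("heap memory usage exceeded the threshold")

# OOM that happened inside a remote worker and comes back wrapped in a task
# error: only the embedded RayOutOfMemoryError text in the message identifies it.
wrapped_oom = RayTaskErrorStub(
    "ray::RolloutWorker.sample()\n"
    "ray.memory_monitor.RayOutOfMemoryError: worker ran out of memory")

# Any other worker failure should still go through the normal recovery path.
other_failure = RayTaskErrorStub("ValueError: something unrelated went wrong")

assert is_memory_error(direct_oom)
assert is_memory_error(wrapped_oom)
assert not is_memory_error(other_failure)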