Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion python/ray/util/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import TypeVar, Generic, Iterable, List, Callable, Any, Iterator

import ray
from ray.memory_monitor import RayOutOfMemoryError
from ray.util.iter_metrics import MetricsContext, SharedMetrics

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -525,6 +526,9 @@ def base_iterator(timeout=None):
yield _NextValueNotReady()
except TimeoutError:
yield _NextValueNotReady()
# Propagate OOM exceptions up the stack
except RayOutOfMemoryError:
raise
except (StopIteration, RuntimeError) as ex:
if was_cause_by_stop_iteration(ex):
# If we are streaming (infinite sequence) then
Expand All @@ -543,14 +547,19 @@ def base_iterator(timeout=None):
else:
active.remove(a)
else:
# ex_is is never a StopIteration since was_cause_by_stop_iteration
# in the if part of the clause will catch all StopIterations
raise ex_i

if results:
yield results
elif self.is_infinite_sequence and len(stoped_actors) == len(active):
raise ex
if not isinstance(ex, StopIteration):
raise ex
futures = [a.par_iter_next.remote() for a in active]
else:
# ex is never a StopIteration since was_cause_by_stop_iteration
# in the if part of the clause will catch all StopIterations
raise ex

name = f"{self}.batch_across_shards()"
Expand Down
35 changes: 31 additions & 4 deletions rllib/agents/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,26 @@
# yapf: enable


def is_memory_error(e: Exception) -> bool:
    """Check if an exception occurred due to a process running out of memory.

    Two cases are recognized:
      1. The exception itself is a ``RayOutOfMemoryError`` (matched by type
         name so this helper does not need to import ``ray.memory_monitor``).
      2. The exception is a ``RayTaskError*`` wrapping a remote task failure
         whose traceback text mentions one of the OOM error names.

    Args:
        e: The exception to inspect.

    Returns:
        True if the exception indicates an out-of-memory condition.
    """
    memory_error_names = (
        "ray.memory_monitor.RayOutOfMemoryError",
        "RayOutOfMemoryError",
    )
    exc_name = type(e).__name__

    # Case 1: the exception is a direct OOM error.
    if exc_name in memory_error_names:
        return True

    # Case 2: a RayTaskError embeds the remote traceback in its message.
    # Only split the message on this path; it is unused otherwise.
    if exc_name.startswith("RayTaskError"):
        msg_lines = [line for line in str(e).split("\n") if line]
        # Flattened scan; distinct loop variable avoids shadowing exc_name.
        return any(
            oom_name in line
            for oom_name in memory_error_names
            for line in msg_lines
        )
    return False


@DeveloperAPI
def with_common_config(
extra_config: PartialTrainerConfigDict) -> TrainerConfigDict:
Expand Down Expand Up @@ -554,11 +574,18 @@ def train(self) -> ResultDict:
for _ in range(1 + MAX_WORKER_FAILURE_RETRIES):
try:
result = Trainable.train(self)
except RayError as e:
except Exception as e:
if self.config["ignore_worker_failures"]:
logger.exception(
"Error in train call, attempting to recover")
self._try_recover()
# do not retry in case of OOM errors
if issubclass(e, RayError):
if not is_memory_error(e):
logger.exception(
"Error in train call, attempting to recover")
self._try_recover()
else:
logger.exception("Not attempting to recover from error in train call "
"since it was caused by an OOM error")
raise e
else:
logger.info(
"Worker crashed during call to train(). To attempt to "
Expand Down