ray_tpu: Catch preemption better, don't hang if there's a C-level abort (#877)
dlwh authored Feb 3, 2025
1 parent 1d216d1 commit 59d2138
Showing 1 changed file with 36 additions and 22 deletions.
58 changes: 36 additions & 22 deletions src/levanter/infra/ray_tpu.py
@@ -8,6 +8,7 @@
 import tempfile
 import time
 from dataclasses import dataclass
+from queue import Empty as QueueEmpty
 from typing import Callable, Optional, Sequence
 
 import draccus
@@ -89,24 +90,24 @@ def do_run(remote_fn) -> _TpuRunResult:
             logger.info("TPU job finished")
             return TpuSuccess(info, out)
         except RayError as e:
-            for f in futures:
-                try:
-                    ray.cancel(f)
-                except Exception:
-                    logger.exception("Failed to kill job after primary failure")
+            _cancel_all_futures(futures)
             return _handle_ray_error(info, e)
         except Exception as e:
-            for f in futures:
-                try:
-                    ray.cancel(f)
-                except Exception:
-                    logger.exception("Failed to kill job after primary failure")
+            _cancel_all_futures(futures)
             return TpuFailed(info, e)
 
     return do_run.remote(remote_fn)
 
 
-def run_on_pod_multislice(remote_fn: RemoteFunction | Callable, tpu_type: str, num_slices: int) -> ray.ObjectRef:
+def _cancel_all_futures(futures):
+    for f in futures:
+        try:
+            ray.cancel(f)
+        except Exception:
+            logger.exception("Failed to kill job after primary failure")
+
+
+def run_on_pod_multislice(remote_fn: RemoteFunction | Callable, tpu_type: str, num_slices: int) -> list[ray.ObjectRef]:
     """
     Run a remote function on multiple TPU slices.
@@ -147,18 +148,12 @@ def do_run(self, remote_fn, coordinator_ip, slice_id, num_slices) -> _TpuRunResult:
                 logger.info("TPU job finished")
                 return TpuSuccess(info, out)
             except RayError as e:
-                for f in futures:
-                    try:
-                        ray.cancel(f)
-                    except Exception:
-                        logger.exception("Failed to kill job after primary failure")
+                logger.exception(f"Ray error {e}. Killing futures for this slice")
+                _cancel_all_futures(futures)
                 return _handle_ray_error(info, e)
             except Exception as e:
-                for f in futures:
-                    try:
-                        ray.cancel(f)
-                    except Exception:
-                        logger.exception("Failed to kill job after primary failure")
+                logger.exception(f"Exception {e}")
+                _cancel_all_futures(futures)
                 return TpuFailed(info, e)
 
     actors = [MultisliceActor.remote() for _ in range(num_slices)]  # type: ignore
@@ -310,6 +305,16 @@ def run_on_pod_multislice_resumable(
         futures = run_on_pod_multislice(remote_fn, tpu_type, num_slices)
         try:
             outs = ray.get(futures)
+        except ray.exceptions.ActorUnavailableError as e:
+            problem = e
+            num_preemptions += 1
+            logger.warning(f"Preempted {num_preemptions} times, {e}")
+            continue
+        except ray.exceptions.ActorDiedError as e:
+            problem = e
+            num_preemptions += 1
+            logger.warning(f"Preempted {num_preemptions} times, {e}")
+            continue
         except ray.exceptions.RayTaskError as e:
             for f in futures:
                 try:
@@ -425,6 +430,9 @@ def _handle_ray_error(tpu_info: _TpuInfo, e: RayError):
     if isinstance(e, NodeDiedError):
         logger.exception("Node died", exc_info=e)
         return TpuPreempted(tpu_info, e)
+    elif isinstance(e, ray.exceptions.ActorUnavailableError | ray.exceptions.ActorDiedError):
+        logger.exception("Actor died", exc_info=e)
+        return TpuPreempted(tpu_info, e)
     elif isinstance(e, WorkerCrashedError):
         logger.exception("Worker crashed", exc_info=e)
         return TpuPreempted(tpu_info, e)
@@ -506,7 +514,13 @@ def target_fn(queue, args, kwargs):
     process.join()
 
     # Retrieve the result or error from the queue
-    success, value = queue.get()
+    logger.info("Process finished")
+    try:
+        success, value = queue.get(timeout=10)
+    except QueueEmpty:
+        logger.error("Process timed out")
+        process.terminate()
+        raise RuntimeError("Process timed out")
 
     if success:
         return value
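For context, a minimal sketch of the retry-on-preemption pattern this commit strengthens. The helper name run_resumable, the make_futures callback, and the retry budget below are illustrative assumptions, not Levanter's actual API; the exception classes are the real Ray ones used in the diff above:

    # Hypothetical, simplified driver loop; the real run_on_pod_multislice_resumable
    # also handles RayTaskError, failure budgets, and multislice coordination.
    import logging

    import ray
    from ray.exceptions import ActorDiedError, ActorUnavailableError

    logger = logging.getLogger(__name__)


    def run_resumable(make_futures, max_retries_preemption: int = 100):
        """Retry a batch of Ray futures, treating actor loss as TPU preemption."""
        num_preemptions = 0
        while num_preemptions < max_retries_preemption:
            futures = make_futures()  # e.g. a call like run_on_pod_multislice(...)
            try:
                return ray.get(futures)
            except (ActorUnavailableError, ActorDiedError) as e:
                # A dead or unreachable actor usually means the slice was
                # preempted: count it and launch a fresh attempt.
                num_preemptions += 1
                logger.warning(f"Preempted {num_preemptions} times, {e}")
        raise RuntimeError(f"Preempted {num_preemptions} times; giving up")

The queue change at the bottom of the diff applies the same defensive idea one level down: queue.get(timeout=...) plus process.terminate() keeps the parent from blocking forever when the child process dies in native code (e.g. a C-level abort) without ever writing a result.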
