Terminating Process Upon Cancel #41

@cwitkowitz

Currently, we do not enforce any hard stop in process_fn, so cancel does not actually terminate execution of this function. On the Gradio (and HARP) frontend this is fine, since the stale job's output is simply ignored. However, the cancelled call keeps running in the background until it finishes on its own. This hasn't been much of an issue so far, but a user could exceed our maximum processing thread limit in HARP (currently 10) by repeatedly cancelling a slow or stalled model.
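To illustrate why cancel can't stop the work, here is a minimal sketch that uses concurrent.futures as a stand-in, assuming process_fn is a synchronous function that Gradio executes on a worker thread. Once a call is running on a thread, Future.cancel() returns False and the call runs to completion; Python offers no public way to kill a thread from outside. stalled_model and cancel_running_thread_job are hypothetical names.

```python
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple


def stalled_model(started: threading.Event, release: threading.Event) -> str:
    """Hypothetical stand-in for a slow or stalled process_fn."""
    started.set()
    release.wait()  # blocks until released; nothing outside can interrupt it
    return "finished anyway"


def cancel_running_thread_job() -> Tuple[bool, bool, str]:
    started, release = threading.Event(), threading.Event()
    with ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(stalled_model, started, release)
        started.wait()                  # the call is now executing on a worker thread
        accepted = future.cancel()      # False: a running call cannot be cancelled
        still_running = future.running()
        release.set()                   # only the function itself can end the call
        return accepted, still_running, future.result()
```

Calling cancel_running_thread_job() returns (False, True, "finished anyway"): the cancel request is refused and the "model" finishes regardless, which is the zombie behaviour described above.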

ChatGPT suggested the following process-based wrapper around process_fn:

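import multiprocessing as mp
import queue
import threading
import time
import traceback
from typing import Any, Callable, Tuple

# A "spawn" context avoids fork-related problems with CUDA and threaded servers
# without changing the global start method (as set_start_method(force=True) would).
# fn and its args must be picklable, e.g. a module-level function, not a closure.
_ctx = mp.get_context("spawn")

def _worker_entry(fn: Callable, args: Tuple[Any, ...], result_q) -> None:
    try:
        result_q.put(("ok", fn(*args)))
    except Exception:
        result_q.put(("err", traceback.format_exc()))

class JobSupervisor:
    def __init__(self, timeout_s: float, poll_s: float = 0.1):
        self.timeout_s = timeout_s
        self.poll_s = poll_s  # how often run() checks for cancel / timeout
        # cancel() is called from a different worker thread than run()
        self._lock = threading.Lock()
        self._process = None       # worker process of the current job, if any
        self._cancel_event = None  # set by cancel() for the current job

    def run(self, fn: Callable, *args):
        # Enforce single-flight: stop any job that is still running
        self.cancel()

        result_q = _ctx.Queue()
        # daemon=True: terminated when the app exits normally
        # (note: a daemonic process cannot start child processes of its own)
        process = _ctx.Process(target=_worker_entry, args=(fn, args, result_q), daemon=True)
        cancel_event = threading.Event()
        process.start()
        with self._lock:
            self._process, self._cancel_event = process, cancel_event
        deadline = time.monotonic() + self.timeout_s

        try:
            # Read the result *before* joining: a child that has put a large result
            # on the queue does not exit until it is drained, so join() first can deadlock.
            while True:
                try:
                    status, payload = result_q.get(timeout=self.poll_s)
                    break
                except queue.Empty:
                    pass
                if cancel_event.is_set():
                    raise RuntimeError("Job cancelled")
                if time.monotonic() > deadline:
                    raise RuntimeError("Job timeout")
                if not process.is_alive():
                    try:  # the child exited; collect a result that may still be in flight
                        status, payload = result_q.get(timeout=1.0)
                        break
                    except queue.Empty:
                        if cancel_event.is_set():
                            raise RuntimeError("Job cancelled") from None
                        raise RuntimeError(f"Job exited with code {process.exitcode}") from None
        finally:
            self._stop(process)
            result_q.close()
            with self._lock:
                if self._process is process:
                    self._process = self._cancel_event = None

        if status == "err":
            raise RuntimeError(payload)

        return payload

    def cancel(self) -> None:
        # Safe to call when idle; never raises, so the Cancel handler shows no error
        with self._lock:
            process, cancel_event = self._process, self._cancel_event
        if process is not None:
            cancel_event.set()
            process.terminate()  # hard stop; run() notices within poll_s and cleans up

    @staticmethod
    def _stop(process, grace_s: float = 5.0) -> None:
        # SIGTERM (TerminateProcess on Windows), escalating to SIGKILL if ignored
        if process.is_alive():
            process.terminate()
            process.join(grace_s)
        if process.is_alive():
            process.kill()
        process.join()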

In build_endpoint:

supervisor = JobSupervisor(timeout_s=300)

def supervised_process(*args):
    return supervisor.run(process_fn, *args)

process_event = process_button.click(
    fn=supervised_process,
    inputs=input_components,
    outputs=output_components,
    api_name="process",
)

cancel_button.click(
    fn=supervisor.cancel,
    inputs=[],
    outputs=[],
    api_name="cancel",
    cancels=[process_event],
)
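Because the worker is started with spawn, process_fn and every input Gradio passes to it are pickled and sent to a fresh interpreter. A quick check like the hypothetical is_picklable helper below can catch a closure or lambda before it fails inside JobSupervisor.run. Passing it is necessary but not sufficient: with spawn, the child must also be able to import the module that defines the function. module_level_process_fn and make_closure_process_fn are illustrative names only.

```python
import pickle
from typing import Any


def is_picklable(obj: Any) -> bool:
    """True if pickle can serialize obj in this process (necessary, not sufficient, for spawn)."""
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        # Local functions and lambdas fail here; the exact exception type varies by Python version
        return False
    return True


def module_level_process_fn(x):
    # Hypothetical: pickled by reference (module + name), so it can be sent to a spawned worker
    return x


def make_closure_process_fn():
    def process_fn(x):  # local function: there is no importable name to pickle it by
        return x
    return process_fn
```

A function defined at the top level of the app module passes, while a process_fn created inside another function (or a lambda) does not.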

With a solution like this, cancel would actually terminate the worker process immediately instead of leaving it running, and we could get away with a single processing thread.
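The hard stop relies on the fact that an OS process, unlike a Python thread, can be killed from outside. The sketch below shows this with subprocess rather than multiprocessing so that it runs standalone (a spawn-based snippet would need a `__main__` guard against the child re-importing the script); Popen.terminate() and multiprocessing.Process.terminate() both send SIGTERM on POSIX and call TerminateProcess() on Windows. terminate_stalled_child is a hypothetical helper.

```python
import subprocess
import sys
import time
from typing import Tuple


def terminate_stalled_child(grace_s: float = 5.0) -> Tuple[int, float]:
    """Start a child that stalls for ten minutes, stop it, return (exit code, seconds taken)."""
    child = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(600)"])
    time.sleep(0.5)              # let the child start "working"
    start = time.monotonic()
    child.terminate()            # SIGTERM on POSIX, TerminateProcess() on Windows
    try:
        code = child.wait(timeout=grace_s)
    except subprocess.TimeoutExpired:
        child.kill()             # escalate if SIGTERM was ignored
        code = child.wait()
    return code, time.monotonic() - start
```

On POSIX the returned code is negative (-15, i.e. killed by SIGTERM), and the child is gone almost immediately rather than after its ten-minute "inference".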
