Skip to content

Commit 5af1c7c

Browse files
committed
fix: move time limit logic to a mixin and respect real time
1 parent e2c4b1f commit 5af1c7c

File tree

8 files changed

+199
-79
lines changed

8 files changed

+199
-79
lines changed

questionpy_server/worker/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class Worker(ABC):
5454
"""Interface for worker implementations."""
5555

5656
def __init__(self, package: PackageLocation, limits: WorkerResourceLimits | None) -> None:
57+
super().__init__()
5758
self.package = package
5859
self.limits = limits
5960
self.state = WorkerState.NOT_RUNNING

questionpy_server/worker/exception.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,27 @@
11
# This file is part of the QuestionPy Server. (https://questionpy.org)
22
# The QuestionPy Server is free software released under terms of the MIT license. See LICENSE.md.
33
# (c) Technische Universität Berlin, innoCampus <[email protected]>
4+
from questionpy_server.worker.runtime.messages import BaseWorkerError
45

56

6-
class WorkerNotRunningError(Exception):
7+
class WorkerNotRunningError(BaseWorkerError):
78
pass
89

910

10-
class WorkerStartError(Exception):
11+
class WorkerStartError(BaseWorkerError):
1112
pass
1213

1314

14-
class WorkerCPUTimeLimitExceededError(Exception):
15-
pass
15+
class WorkerCPUTimeLimitExceededError(BaseWorkerError):
16+
def __init__(self, limit: float):
17+
self.limit = limit
18+
super().__init__(f"Worker has exceeded its CPU time limit of {limit} seconds and was killed.")
19+
20+
21+
class WorkerRealTimeLimitExceededError(BaseWorkerError):
22+
def __init__(self, limit: float):
23+
self.limit = limit
24+
super().__init__(f"Worker has exceeded its real time limit of {limit} seconds and was killed.")
1625

1726

1827
class StaticFileSizeMismatchError(Exception):

questionpy_server/worker/impl/_base.py

Lines changed: 102 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,29 @@
55
import asyncio
66
import contextlib
77
import logging
8-
from abc import ABC
8+
import time
9+
from abc import ABC, abstractmethod
910
from collections.abc import Sequence
10-
from typing import TYPE_CHECKING, TypeVar
11+
from typing import TYPE_CHECKING, Any, TypeVar
1112
from zipfile import ZipFile
1213

1314
from questionpy_common.api.attempt import AttemptModel, AttemptScoredModel, AttemptStartedModel
1415
from questionpy_common.constants import DIST_DIR
1516
from questionpy_common.elements import OptionsFormDefinition
16-
from questionpy_common.environment import RequestUser, WorkerResourceLimits
17+
from questionpy_common.environment import RequestUser
1718
from questionpy_common.manifest import Manifest, PackageFile
1819
from questionpy_server.models import QuestionCreated
1920
from questionpy_server.utils.manifest import ComparableManifest
2021
from questionpy_server.worker import PackageFileData, Worker, WorkerState
21-
from questionpy_server.worker.exception import StaticFileSizeMismatchError, WorkerNotRunningError, WorkerStartError
22+
from questionpy_server.worker.exception import (
23+
StaticFileSizeMismatchError,
24+
WorkerCPUTimeLimitExceededError,
25+
WorkerNotRunningError,
26+
WorkerRealTimeLimitExceededError,
27+
WorkerStartError,
28+
)
2229
from questionpy_server.worker.runtime.messages import (
30+
BaseWorkerError,
2331
CreateQuestionFromOptions,
2432
Exit,
2533
GetOptionsForm,
@@ -37,7 +45,6 @@
3745
from questionpy_server.worker.runtime.package_location import (
3846
DirPackageLocation,
3947
FunctionPackageLocation,
40-
PackageLocation,
4148
ZipPackageLocation,
4249
)
4350

@@ -64,14 +71,17 @@ class BaseWorker(Worker, ABC):
6471
"""Base class implementing some common functionality of workers."""
6572

6673
_worker_type = "unknown"
74+
_init_worker_timeout = 2
75+
_load_qpy_package_timeout = 4
6776

68-
def __init__(self, package: PackageLocation, limits: WorkerResourceLimits | None) -> None:
69-
super().__init__(package, limits)
77+
def __init__(self, **kwargs: Any) -> None:
78+
super().__init__(**kwargs)
7079

7180
self._observe_task: asyncio.Task | None = None
7281

7382
self._connection: ServerToWorkerConnection | None = None
7483
self._expected_incoming_messages: list[tuple[MessageIds, asyncio.Future]] = []
84+
self._receive_messages_exception: BaseException | None = None
7585

7686
async def _initialize(self) -> None:
7787
"""Initializes an already running worker and starts the observe task.
@@ -88,11 +98,14 @@ async def _initialize(self) -> None:
8898
worker_type=self._worker_type,
8999
),
90100
InitWorker.Response,
101+
self._init_worker_timeout,
91102
)
92103
await self.send_and_wait_for_response(
93-
LoadQPyPackage(location=self.package, main=True), LoadQPyPackage.Response
104+
LoadQPyPackage(location=self.package, main=True),
105+
LoadQPyPackage.Response,
106+
self._load_qpy_package_timeout,
94107
)
95-
except WorkerNotRunningError as e:
108+
except BaseWorkerError as e:
96109
msg = "Worker has exited before or during initialization."
97110
raise WorkerStartError(msg) from e
98111

@@ -101,7 +114,9 @@ def send(self, message: MessageToWorker) -> None:
101114
raise WorkerNotRunningError
102115
self._connection.send_message(message)
103116

104-
async def send_and_wait_for_response(self, message: MessageToWorker, expected_response_message: type[_M]) -> _M:
117+
async def send_and_wait_for_response(
118+
self, message: MessageToWorker, expected_response_message: type[_M], timeout: float | None = None
119+
) -> _M:
105120
self.send(message)
106121
fut = asyncio.get_running_loop().create_future()
107122
self._expected_incoming_messages.append((expected_response_message.message_id, fut))
@@ -138,7 +153,8 @@ async def _receive_messages(self) -> None:
138153
finally:
139154
for _, future in self._expected_incoming_messages:
140155
if not future.done():
141-
future.set_exception(WorkerNotRunningError())
156+
exc = self._receive_messages_exception or WorkerNotRunningError()
157+
future.set_exception(exc)
142158
self._expected_incoming_messages = []
143159

144160
def _get_observation_tasks(self) -> Sequence[asyncio.Task]:
@@ -153,10 +169,14 @@ def _get_observation_tasks(self) -> Sequence[asyncio.Task]:
153169

154170
async def _observe(self) -> None:
155171
"""Observes the tasks returned by _get_observation_tasks."""
156-
pending: Sequence[asyncio.Task] = []
172+
pending: set[asyncio.Task]
157173
try:
158174
tasks = self._get_observation_tasks()
159-
_, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
175+
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
176+
for task in done:
177+
with contextlib.suppress(asyncio.CancelledError):
178+
if exc := task.exception():
179+
self._receive_messages_exception = exc
160180
finally:
161181
self.state = WorkerState.NOT_RUNNING
162182

@@ -170,7 +190,7 @@ async def _observe(self) -> None:
170190
async def stop(self, timeout: float) -> None:
171191
try:
172192
self.send(Exit())
173-
except WorkerNotRunningError:
193+
except BaseWorkerError:
174194
# No need to stop it then.
175195
return
176196

@@ -292,3 +312,71 @@ async def get_static_file(self, path: str) -> PackageFileData:
292312

293313
async def get_static_file_index(self) -> dict[str, PackageFile]:
294314
return (await self.get_manifest()).static_files
315+
316+
317+
class LimitTimeUsageMixin(Worker, ABC):
318+
"""Implements a CPU and real time usage limit for a worker.
319+
320+
_limit_cpu_time_usage needs to be added to the return value of :meth:`BaseWorker._get_observation_tasks`.
321+
The worker will be killed when the CPU time limit is exceeded or the worker took more than three times
322+
the cpu limit in real time.
323+
"""
324+
325+
_real_time_limit_factor = 3
326+
327+
def __init__(self, **kwargs: Any) -> None:
328+
super().__init__(**kwargs)
329+
self._cur_cpu_time_limit: float = 0
330+
self._request_started_cpu_time: float | None = None
331+
self._request_started_time: float | None = None
332+
self._request_started_event = asyncio.Event()
333+
334+
@abstractmethod
335+
def _get_cpu_time(self) -> float:
336+
"""Get worker's current CPU time (user and system time).
337+
338+
Returns:
339+
CPU time in seconds
340+
"""
341+
342+
def _set_time_limit(self, limit: float) -> None:
343+
"""Set a CPU and real time limit.
344+
345+
The real time limit is the CPU time limit * :meth:`LimitTimeUsageMixin._real_time_limit_factor`.
346+
347+
Args:
348+
limit: CPU time limit in seconds
349+
"""
350+
self._cur_cpu_time_limit = limit
351+
self._request_started_cpu_time = self._get_cpu_time()
352+
self._request_started_time = time.time()
353+
self._request_started_event.set()
354+
355+
def _reset_time_limit(self) -> None:
356+
self._cur_cpu_time_limit = 0
357+
self._request_started_cpu_time = None
358+
self._request_started_time = None
359+
self._request_started_event.clear()
360+
361+
async def _limit_cpu_time_usage(self) -> None:
362+
"""Ensures that the worker will be killed when it is taking too much time. Executed as a task."""
363+
while True:
364+
await self._request_started_event.wait()
365+
366+
# CPU-time is always less or equal to real time (when single-threaded).
367+
await asyncio.sleep(self._cur_cpu_time_limit)
368+
369+
# Check if the start time is still set. Probably the request was already processed or
370+
# maybe another request started meanwhile.
371+
while self._request_started_cpu_time is not None and self._request_started_time is not None:
372+
remaining_cpu_time = self._request_started_cpu_time + self._cur_cpu_time_limit - self._get_cpu_time()
373+
if remaining_cpu_time <= 0:
374+
raise WorkerCPUTimeLimitExceededError(self._cur_cpu_time_limit)
375+
376+
remaining_time = (
377+
self._request_started_time + (self._cur_cpu_time_limit * self._real_time_limit_factor) - time.time()
378+
)
379+
if remaining_time <= 0:
380+
raise WorkerRealTimeLimitExceededError(self._cur_cpu_time_limit * self._real_time_limit_factor)
381+
382+
await asyncio.sleep(max(min(remaining_cpu_time, remaining_time), 0.05))

questionpy_server/worker/impl/subprocess.py

Lines changed: 25 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import asyncio
66
import logging
7+
import math
78
import sys
89
from asyncio import StreamReader
910
from collections.abc import Sequence
@@ -16,8 +17,8 @@
1617
from questionpy_common.environment import WorkerResourceLimits
1718
from questionpy_server.worker import WorkerResources
1819
from questionpy_server.worker.connection import ServerToWorkerConnection
19-
from questionpy_server.worker.exception import WorkerCPUTimeLimitExceededError, WorkerNotRunningError, WorkerStartError
20-
from questionpy_server.worker.impl._base import BaseWorker
20+
from questionpy_server.worker.exception import WorkerNotRunningError, WorkerStartError
21+
from questionpy_server.worker.impl._base import BaseWorker, LimitTimeUsageMixin
2122
from questionpy_server.worker.runtime.messages import MessageToServer, MessageToWorker
2223
from questionpy_server.worker.runtime.package_location import PackageLocation
2324

@@ -71,7 +72,7 @@ def flush(self) -> None:
7172
self._skipped_bytes = 0
7273

7374

74-
class SubprocessWorker(BaseWorker):
75+
class SubprocessWorker(BaseWorker, LimitTimeUsageMixin):
7576
"""Worker implementation running in a non-sandboxed subprocess."""
7677

7778
_worker_type = "process"
@@ -80,7 +81,7 @@ class SubprocessWorker(BaseWorker):
8081
_runtime_main = ["-m", "questionpy_server.worker.runtime"]
8182

8283
def __init__(self, package: PackageLocation, limits: WorkerResourceLimits | None):
83-
super().__init__(package, limits)
84+
super().__init__(package=package, limits=limits)
8485

8586
self._proc: Process | None = None
8687
self._stderr_buffer: _StderrBuffer | None = None
@@ -111,48 +112,16 @@ async def start(self) -> None:
111112
# Whether initialization was successful or not, flush the logs.
112113
self._stderr_buffer.flush()
113114

114-
async def _limit_cpu_time_usage(self, expected_response_message: type[_T]) -> None:
115-
if not self._proc or self._proc.returncode is not None:
116-
raise WorkerNotRunningError
117-
118-
if self.limits is None:
119-
return
120-
121-
psutil_proc = psutil.Process(self._proc.pid)
122-
123-
# Get current cpu times and calculate the maximum cpu time for the current call.
124-
cpu_times = psutil_proc.cpu_times()
125-
max_cpu_time = cpu_times.user + cpu_times.system + self.limits.max_cpu_time_seconds_per_call
126-
127-
# CPU-time is always less or equal to real time.
128-
await asyncio.sleep(self.limits.max_cpu_time_seconds_per_call)
129-
130-
while True:
131-
cpu_times = psutil_proc.cpu_times()
132-
remaining_time = max_cpu_time - (cpu_times.user + cpu_times.system)
133-
if remaining_time <= 0:
134-
break
135-
await asyncio.sleep(max(remaining_time, 0.05))
136-
137-
# Set the exception and kill the process.
138-
for future in [
139-
fut
140-
for expected_id, fut in self._expected_incoming_messages
141-
if expected_id == expected_response_message.message_id
142-
]:
143-
future.set_exception(WorkerCPUTimeLimitExceededError)
144-
self._expected_incoming_messages.remove((expected_response_message.message_id, future))
145-
146-
await self.kill()
147-
148-
async def send_and_wait_for_response(self, message: MessageToWorker, expected_response_message: type[_T]) -> _T:
149-
timeout = asyncio.create_task(
150-
self._limit_cpu_time_usage(expected_response_message), name="limit cpu time usage"
151-
)
115+
async def send_and_wait_for_response(
116+
self, message: MessageToWorker, expected_response_message: type[_T], timeout: float | None = None
117+
) -> _T:
152118
try:
153-
return await super().send_and_wait_for_response(message, expected_response_message)
119+
if timeout is None:
120+
timeout = self.limits.max_cpu_time_seconds_per_call if self.limits else math.inf
121+
self._set_time_limit(timeout)
122+
return await super().send_and_wait_for_response(message, expected_response_message, timeout)
154123
finally:
155-
timeout.cancel()
124+
self._reset_time_limit()
156125
# Write worker's stderr to log after every exchange.
157126
if self._stderr_buffer:
158127
self._stderr_buffer.flush()
@@ -176,8 +145,20 @@ def _get_observation_tasks(self) -> Sequence[asyncio.Task]:
176145
*super()._get_observation_tasks(),
177146
asyncio.create_task(self._proc.wait(), name="wait for worker process"),
178147
asyncio.create_task(self._stderr_buffer.read_stderr(), name="receive stderr from worker"),
148+
asyncio.create_task(self._limit_cpu_time_usage(), name="limit cpu time usage"),
179149
)
180150

181151
async def kill(self) -> None:
182152
if self._proc and self._proc.returncode is None:
183153
self._proc.kill()
154+
155+
# Make sure that all resources of the subprocesses are getting cleaned.
156+
await self._proc.wait()
157+
158+
def _get_cpu_time(self) -> float:
159+
if not self._proc or self._proc.returncode is not None:
160+
raise WorkerNotRunningError
161+
162+
psutil_proc = psutil.Process(self._proc.pid)
163+
cpu_times = psutil_proc.cpu_times()
164+
return cpu_times.user + cpu_times.system

questionpy_server/worker/impl/thread.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class ThreadWorker(BaseWorker):
6565
_worker_type = "thread"
6666

6767
def __init__(self, package: PackageLocation, limits: WorkerResourceLimits | None) -> None:
68-
super().__init__(package, limits)
68+
super().__init__(package=package, limits=limits)
6969

7070
self._pipe: DuplexPipe | None = None
7171

questionpy_server/worker/runtime/manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import Callable, Generator
66
from contextlib import contextmanager
77
from dataclasses import dataclass
8-
from typing import NoReturn, cast, TypeAlias, TypeVar
8+
from typing import NoReturn, TypeAlias, TypeVar, cast
99

1010
from questionpy_common.api.qtype import QuestionTypeInterface
1111
from questionpy_common.environment import (

questionpy_server/worker/runtime/messages.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,13 @@ def __init__(self, message_id: int, length: int):
269269
super().__init__(f"Received unknown message with id {message_id} and length {length}.")
270270

271271

272-
class WorkerMemoryLimitExceededError(Exception):
272+
class BaseWorkerError(Exception):
273273
pass
274274

275275

276-
class WorkerUnknownError(Exception):
276+
class WorkerMemoryLimitExceededError(BaseWorkerError):
277+
pass
278+
279+
280+
class WorkerUnknownError(BaseWorkerError):
277281
pass

0 commit comments

Comments
 (0)