55import asyncio
66import contextlib
77import logging
8- from abc import ABC
8+ import time
9+ from abc import ABC , abstractmethod
910from collections .abc import Sequence
10- from typing import TYPE_CHECKING , TypeVar
11+ from typing import TYPE_CHECKING , Any , TypeVar
1112from zipfile import ZipFile
1213
1314from questionpy_common .api .attempt import AttemptModel , AttemptScoredModel , AttemptStartedModel
1415from questionpy_common .constants import DIST_DIR
1516from questionpy_common .elements import OptionsFormDefinition
16- from questionpy_common .environment import RequestUser , WorkerResourceLimits
17+ from questionpy_common .environment import RequestUser
1718from questionpy_common .manifest import Manifest , PackageFile
1819from questionpy_server .models import QuestionCreated
1920from questionpy_server .utils .manifest import ComparableManifest
2021from questionpy_server .worker import PackageFileData , Worker , WorkerState
21- from questionpy_server .worker .exception import StaticFileSizeMismatchError , WorkerNotRunningError , WorkerStartError
22+ from questionpy_server .worker .exception import (
23+ StaticFileSizeMismatchError ,
24+ WorkerCPUTimeLimitExceededError ,
25+ WorkerNotRunningError ,
26+ WorkerRealTimeLimitExceededError ,
27+ WorkerStartError ,
28+ )
2229from questionpy_server .worker .runtime .messages import (
30+ BaseWorkerError ,
2331 CreateQuestionFromOptions ,
2432 Exit ,
2533 GetOptionsForm ,
3745from questionpy_server .worker .runtime .package_location import (
3846 DirPackageLocation ,
3947 FunctionPackageLocation ,
40- PackageLocation ,
4148 ZipPackageLocation ,
4249)
4350
@@ -64,14 +71,17 @@ class BaseWorker(Worker, ABC):
6471 """Base class implementing some common functionality of workers."""
6572
6673 _worker_type = "unknown"
74+ _init_worker_timeout = 2
75+ _load_qpy_package_timeout = 4
6776
68- def __init__ (self , package : PackageLocation , limits : WorkerResourceLimits | None ) -> None :
69- super ().__init__ (package , limits )
77+ def __init__ (self , ** kwargs : Any ) -> None :
78+ super ().__init__ (** kwargs )
7079
7180 self ._observe_task : asyncio .Task | None = None
7281
7382 self ._connection : ServerToWorkerConnection | None = None
7483 self ._expected_incoming_messages : list [tuple [MessageIds , asyncio .Future ]] = []
84+ self ._receive_messages_exception : BaseException | None = None
7585
7686 async def _initialize (self ) -> None :
7787 """Initializes an already running worker and starts the observe task.
@@ -88,11 +98,14 @@ async def _initialize(self) -> None:
8898 worker_type = self ._worker_type ,
8999 ),
90100 InitWorker .Response ,
101+ self ._init_worker_timeout ,
91102 )
92103 await self .send_and_wait_for_response (
93- LoadQPyPackage (location = self .package , main = True ), LoadQPyPackage .Response
104+ LoadQPyPackage (location = self .package , main = True ),
105+ LoadQPyPackage .Response ,
106+ self ._load_qpy_package_timeout ,
94107 )
95- except WorkerNotRunningError as e :
108+ except BaseWorkerError as e :
96109 msg = "Worker has exited before or during initialization."
97110 raise WorkerStartError (msg ) from e
98111
@@ -101,7 +114,9 @@ def send(self, message: MessageToWorker) -> None:
101114 raise WorkerNotRunningError
102115 self ._connection .send_message (message )
103116
104- async def send_and_wait_for_response (self , message : MessageToWorker , expected_response_message : type [_M ]) -> _M :
117+ async def send_and_wait_for_response (
118+ self , message : MessageToWorker , expected_response_message : type [_M ], timeout : float | None = None
119+ ) -> _M :
105120 self .send (message )
106121 fut = asyncio .get_running_loop ().create_future ()
107122 self ._expected_incoming_messages .append ((expected_response_message .message_id , fut ))
@@ -138,7 +153,8 @@ async def _receive_messages(self) -> None:
138153 finally :
139154 for _ , future in self ._expected_incoming_messages :
140155 if not future .done ():
141- future .set_exception (WorkerNotRunningError ())
156+ exc = self ._receive_messages_exception or WorkerNotRunningError ()
157+ future .set_exception (exc )
142158 self ._expected_incoming_messages = []
143159
144160 def _get_observation_tasks (self ) -> Sequence [asyncio .Task ]:
@@ -153,10 +169,14 @@ def _get_observation_tasks(self) -> Sequence[asyncio.Task]:
153169
154170 async def _observe (self ) -> None :
155171 """Observes the tasks returned by _get_observation_tasks."""
156- pending : Sequence [asyncio .Task ] = [ ]
172+ pending : set [asyncio .Task ]
157173 try :
158174 tasks = self ._get_observation_tasks ()
159- _ , pending = await asyncio .wait (tasks , return_when = asyncio .FIRST_COMPLETED )
175+ done , pending = await asyncio .wait (tasks , return_when = asyncio .FIRST_COMPLETED )
176+ for task in done :
177+ with contextlib .suppress (asyncio .CancelledError ):
178+ if exc := task .exception ():
179+ self ._receive_messages_exception = exc
160180 finally :
161181 self .state = WorkerState .NOT_RUNNING
162182
@@ -170,7 +190,7 @@ async def _observe(self) -> None:
170190 async def stop (self , timeout : float ) -> None :
171191 try :
172192 self .send (Exit ())
173- except WorkerNotRunningError :
193+ except BaseWorkerError :
174194 # No need to stop it then.
175195 return
176196
@@ -292,3 +312,71 @@ async def get_static_file(self, path: str) -> PackageFileData:
292312
293313 async def get_static_file_index (self ) -> dict [str , PackageFile ]:
294314 return (await self .get_manifest ()).static_files
315+
316+
317+ class LimitTimeUsageMixin (Worker , ABC ):
318+ """Implements a CPU and real time usage limit for a worker.
319+
320+ _limit_cpu_time_usage needs to be added to the return value of :meth:`BaseWorker._get_observation_tasks`.
321+ The worker will be killed when the CPU time limit is exceeded or the worker took more than three times
322+ the cpu limit in real time.
323+ """
324+
325+ _real_time_limit_factor = 3
326+
327+ def __init__ (self , ** kwargs : Any ) -> None :
328+ super ().__init__ (** kwargs )
329+ self ._cur_cpu_time_limit : float = 0
330+ self ._request_started_cpu_time : float | None = None
331+ self ._request_started_time : float | None = None
332+ self ._request_started_event = asyncio .Event ()
333+
334+ @abstractmethod
335+ def _get_cpu_time (self ) -> float :
336+ """Get worker's current CPU time (user and system time).
337+
338+ Returns:
339+ CPU time in seconds
340+ """
341+
342+ def _set_time_limit (self , limit : float ) -> None :
343+ """Set a CPU and real time limit.
344+
345+ The real time limit is the CPU time limit * :meth:`LimitTimeUsageMixin._real_time_limit_factor`.
346+
347+ Args:
348+ limit: CPU time limit in seconds
349+ """
350+ self ._cur_cpu_time_limit = limit
351+ self ._request_started_cpu_time = self ._get_cpu_time ()
352+ self ._request_started_time = time .time ()
353+ self ._request_started_event .set ()
354+
355+ def _reset_time_limit (self ) -> None :
356+ self ._cur_cpu_time_limit = 0
357+ self ._request_started_cpu_time = None
358+ self ._request_started_time = None
359+ self ._request_started_event .clear ()
360+
361+ async def _limit_cpu_time_usage (self ) -> None :
362+ """Ensures that the worker will be killed when it is taking too much time. Executed as a task."""
363+ while True :
364+ await self ._request_started_event .wait ()
365+
366+ # CPU-time is always less or equal to real time (when single-threaded).
367+ await asyncio .sleep (self ._cur_cpu_time_limit )
368+
369+ # Check if the start time is still set. Probably the request was already processed or
370+ # maybe another request started meanwhile.
371+ while self ._request_started_cpu_time is not None and self ._request_started_time is not None :
372+ remaining_cpu_time = self ._request_started_cpu_time + self ._cur_cpu_time_limit - self ._get_cpu_time ()
373+ if remaining_cpu_time <= 0 :
374+ raise WorkerCPUTimeLimitExceededError (self ._cur_cpu_time_limit )
375+
376+ remaining_time = (
377+ self ._request_started_time + (self ._cur_cpu_time_limit * self ._real_time_limit_factor ) - time .time ()
378+ )
379+ if remaining_time <= 0 :
380+ raise WorkerRealTimeLimitExceededError (self ._cur_cpu_time_limit * self ._real_time_limit_factor )
381+
382+ await asyncio .sleep (max (min (remaining_cpu_time , remaining_time ), 0.05 ))
0 commit comments