5
5
import asyncio
6
6
import contextlib
7
7
import logging
8
- from abc import ABC
8
+ import time
9
+ from abc import ABC , abstractmethod
9
10
from collections .abc import Sequence
10
- from typing import TYPE_CHECKING , TypeVar
11
+ from typing import TYPE_CHECKING , Any , TypeVar
11
12
from zipfile import ZipFile
12
13
13
14
from questionpy_common .api .attempt import AttemptModel , AttemptScoredModel , AttemptStartedModel
14
15
from questionpy_common .constants import DIST_DIR
15
16
from questionpy_common .elements import OptionsFormDefinition
16
- from questionpy_common .environment import RequestUser , WorkerResourceLimits
17
+ from questionpy_common .environment import RequestUser
17
18
from questionpy_common .manifest import Manifest , PackageFile
18
19
from questionpy_server .models import QuestionCreated
19
20
from questionpy_server .utils .manifest import ComparableManifest
20
21
from questionpy_server .worker import PackageFileData , Worker , WorkerState
21
- from questionpy_server .worker .exception import StaticFileSizeMismatchError , WorkerNotRunningError , WorkerStartError
22
+ from questionpy_server .worker .exception import (
23
+ StaticFileSizeMismatchError ,
24
+ WorkerCPUTimeLimitExceededError ,
25
+ WorkerNotRunningError ,
26
+ WorkerRealTimeLimitExceededError ,
27
+ WorkerStartError ,
28
+ )
22
29
from questionpy_server .worker .runtime .messages import (
30
+ BaseWorkerError ,
23
31
CreateQuestionFromOptions ,
24
32
Exit ,
25
33
GetOptionsForm ,
37
45
from questionpy_server .worker .runtime .package_location import (
38
46
DirPackageLocation ,
39
47
FunctionPackageLocation ,
40
- PackageLocation ,
41
48
ZipPackageLocation ,
42
49
)
43
50
@@ -64,14 +71,17 @@ class BaseWorker(Worker, ABC):
64
71
"""Base class implementing some common functionality of workers."""
65
72
66
73
_worker_type = "unknown"
74
+ _init_worker_timeout = 2
75
+ _load_qpy_package_timeout = 4
67
76
68
- def __init__ (self , package : PackageLocation , limits : WorkerResourceLimits | None ) -> None :
69
- super ().__init__ (package , limits )
77
+ def __init__ (self , ** kwargs : Any ) -> None :
78
+ super ().__init__ (** kwargs )
70
79
71
80
self ._observe_task : asyncio .Task | None = None
72
81
73
82
self ._connection : ServerToWorkerConnection | None = None
74
83
self ._expected_incoming_messages : list [tuple [MessageIds , asyncio .Future ]] = []
84
+ self ._receive_messages_exception : BaseException | None = None
75
85
76
86
async def _initialize (self ) -> None :
77
87
"""Initializes an already running worker and starts the observe task.
@@ -88,11 +98,14 @@ async def _initialize(self) -> None:
88
98
worker_type = self ._worker_type ,
89
99
),
90
100
InitWorker .Response ,
101
+ self ._init_worker_timeout ,
91
102
)
92
103
await self .send_and_wait_for_response (
93
- LoadQPyPackage (location = self .package , main = True ), LoadQPyPackage .Response
104
+ LoadQPyPackage (location = self .package , main = True ),
105
+ LoadQPyPackage .Response ,
106
+ self ._load_qpy_package_timeout ,
94
107
)
95
- except WorkerNotRunningError as e :
108
+ except BaseWorkerError as e :
96
109
msg = "Worker has exited before or during initialization."
97
110
raise WorkerStartError (msg ) from e
98
111
@@ -101,7 +114,9 @@ def send(self, message: MessageToWorker) -> None:
101
114
raise WorkerNotRunningError
102
115
self ._connection .send_message (message )
103
116
104
- async def send_and_wait_for_response (self , message : MessageToWorker , expected_response_message : type [_M ]) -> _M :
117
+ async def send_and_wait_for_response (
118
+ self , message : MessageToWorker , expected_response_message : type [_M ], timeout : float | None = None
119
+ ) -> _M :
105
120
self .send (message )
106
121
fut = asyncio .get_running_loop ().create_future ()
107
122
self ._expected_incoming_messages .append ((expected_response_message .message_id , fut ))
@@ -138,7 +153,8 @@ async def _receive_messages(self) -> None:
138
153
finally :
139
154
for _ , future in self ._expected_incoming_messages :
140
155
if not future .done ():
141
- future .set_exception (WorkerNotRunningError ())
156
+ exc = self ._receive_messages_exception or WorkerNotRunningError ()
157
+ future .set_exception (exc )
142
158
self ._expected_incoming_messages = []
143
159
144
160
def _get_observation_tasks (self ) -> Sequence [asyncio .Task ]:
@@ -153,10 +169,14 @@ def _get_observation_tasks(self) -> Sequence[asyncio.Task]:
153
169
154
170
async def _observe (self ) -> None :
155
171
"""Observes the tasks returned by _get_observation_tasks."""
156
- pending : Sequence [asyncio .Task ] = [ ]
172
+ pending : set [asyncio .Task ]
157
173
try :
158
174
tasks = self ._get_observation_tasks ()
159
- _ , pending = await asyncio .wait (tasks , return_when = asyncio .FIRST_COMPLETED )
175
+ done , pending = await asyncio .wait (tasks , return_when = asyncio .FIRST_COMPLETED )
176
+ for task in done :
177
+ with contextlib .suppress (asyncio .CancelledError ):
178
+ if exc := task .exception ():
179
+ self ._receive_messages_exception = exc
160
180
finally :
161
181
self .state = WorkerState .NOT_RUNNING
162
182
@@ -170,7 +190,7 @@ async def _observe(self) -> None:
170
190
async def stop (self , timeout : float ) -> None :
171
191
try :
172
192
self .send (Exit ())
173
- except WorkerNotRunningError :
193
+ except BaseWorkerError :
174
194
# No need to stop it then.
175
195
return
176
196
@@ -292,3 +312,71 @@ async def get_static_file(self, path: str) -> PackageFileData:
292
312
293
313
async def get_static_file_index (self ) -> dict [str , PackageFile ]:
294
314
return (await self .get_manifest ()).static_files
315
+
316
+
317
+ class LimitTimeUsageMixin (Worker , ABC ):
318
+ """Implements a CPU and real time usage limit for a worker.
319
+
320
+ _limit_cpu_time_usage needs to be added to the return value of :meth:`BaseWorker._get_observation_tasks`.
321
+ The worker will be killed when the CPU time limit is exceeded or the worker took more than three times
322
+ the cpu limit in real time.
323
+ """
324
+
325
+ _real_time_limit_factor = 3
326
+
327
+ def __init__ (self , ** kwargs : Any ) -> None :
328
+ super ().__init__ (** kwargs )
329
+ self ._cur_cpu_time_limit : float = 0
330
+ self ._request_started_cpu_time : float | None = None
331
+ self ._request_started_time : float | None = None
332
+ self ._request_started_event = asyncio .Event ()
333
+
334
+ @abstractmethod
335
+ def _get_cpu_time (self ) -> float :
336
+ """Get worker's current CPU time (user and system time).
337
+
338
+ Returns:
339
+ CPU time in seconds
340
+ """
341
+
342
+ def _set_time_limit (self , limit : float ) -> None :
343
+ """Set a CPU and real time limit.
344
+
345
+ The real time limit is the CPU time limit * :meth:`LimitTimeUsageMixin._real_time_limit_factor`.
346
+
347
+ Args:
348
+ limit: CPU time limit in seconds
349
+ """
350
+ self ._cur_cpu_time_limit = limit
351
+ self ._request_started_cpu_time = self ._get_cpu_time ()
352
+ self ._request_started_time = time .time ()
353
+ self ._request_started_event .set ()
354
+
355
+ def _reset_time_limit (self ) -> None :
356
+ self ._cur_cpu_time_limit = 0
357
+ self ._request_started_cpu_time = None
358
+ self ._request_started_time = None
359
+ self ._request_started_event .clear ()
360
+
361
+ async def _limit_cpu_time_usage (self ) -> None :
362
+ """Ensures that the worker will be killed when it is taking too much time. Executed as a task."""
363
+ while True :
364
+ await self ._request_started_event .wait ()
365
+
366
+ # CPU-time is always less or equal to real time (when single-threaded).
367
+ await asyncio .sleep (self ._cur_cpu_time_limit )
368
+
369
+ # Check if the start time is still set. Probably the request was already processed or
370
+ # maybe another request started meanwhile.
371
+ while self ._request_started_cpu_time is not None and self ._request_started_time is not None :
372
+ remaining_cpu_time = self ._request_started_cpu_time + self ._cur_cpu_time_limit - self ._get_cpu_time ()
373
+ if remaining_cpu_time <= 0 :
374
+ raise WorkerCPUTimeLimitExceededError (self ._cur_cpu_time_limit )
375
+
376
+ remaining_time = (
377
+ self ._request_started_time + (self ._cur_cpu_time_limit * self ._real_time_limit_factor ) - time .time ()
378
+ )
379
+ if remaining_time <= 0 :
380
+ raise WorkerRealTimeLimitExceededError (self ._cur_cpu_time_limit * self ._real_time_limit_factor )
381
+
382
+ await asyncio .sleep (max (min (remaining_cpu_time , remaining_time ), 0.05 ))
0 commit comments