8
8
9
9
from __future__ import annotations
10
10
11
+ import collections
11
12
import contextlib
12
13
import enum
13
14
import os
14
15
import sys
15
16
import time
16
17
from typing import Any
17
18
from typing import Generator
19
+ from typing import Iterable
18
20
from typing import Literal
19
21
from typing import Sequence
20
22
from typing import TypedDict
23
+ from typing import Union
21
24
import warnings
22
25
23
26
from _pytest .config import _prepareconfig
@@ -66,7 +69,44 @@ def worker_title(title: str) -> None:
66
69
67
70
class Marker (enum .Enum ):
68
71
SHUTDOWN = 0
69
- QUEUE_REPLACED = 1
72
+
73
+
74
+ class TestQueue :
75
+ """A simple queue that can be inspected and modified while the lock is held via the ``lock()`` method."""
76
+
77
+ Item = Union [int , Literal [Marker .SHUTDOWN ]]
78
+
79
+ def __init__ (self , execmodel : execnet .gateway_base .ExecModel ):
80
+ self ._items : collections .deque [TestQueue .Item ] = collections .deque ()
81
+ self ._lock = execmodel .RLock () # type: ignore[no-untyped-call]
82
+ self ._has_items_event = execmodel .Event ()
83
+
84
+ def get (self ) -> Item :
85
+ while True :
86
+ with self .lock () as locked_items :
87
+ if locked_items :
88
+ return locked_items .popleft ()
89
+
90
+ self ._has_items_event .wait ()
91
+
92
+ def put (self , item : Item ) -> None :
93
+ with self .lock () as locked_items :
94
+ locked_items .append (item )
95
+
96
+ def replace (self , iterable : Iterable [Item ]) -> None :
97
+ with self .lock ():
98
+ self ._items = collections .deque (iterable )
99
+
100
+ @contextlib .contextmanager
101
+ def lock (self ) -> Generator [collections .deque [Item ], None , None ]:
102
+ with self ._lock :
103
+ try :
104
+ yield self ._items
105
+ finally :
106
+ if self ._items :
107
+ self ._has_items_event .set ()
108
+ else :
109
+ self ._has_items_event .clear ()
70
110
71
111
72
112
class WorkerInteractor :
@@ -77,22 +117,10 @@ def __init__(self, config: pytest.Config, channel: execnet.Channel) -> None:
77
117
self .testrunuid = workerinput ["testrunuid" ]
78
118
self .log = Producer (f"worker-{ self .workerid } " , enabled = config .option .debug )
79
119
self .channel = channel
80
- self .torun = self ._make_queue ( )
120
+ self .torun = TestQueue ( self .channel . gateway . execmodel )
81
121
self .nextitem_index : int | None | Literal [Marker .SHUTDOWN ] = None
82
122
config .pluginmanager .register (self )
83
123
84
- def _make_queue (self ) -> Any :
85
- return self .channel .gateway .execmodel .queue .Queue ()
86
-
87
- def _get_next_item_index (self ) -> int | Literal [Marker .SHUTDOWN ]:
88
- """Gets the next item from test queue. Handles the case when the queue
89
- is replaced concurrently in another thread.
90
- """
91
- result = self .torun .get ()
92
- while result is Marker .QUEUE_REPLACED :
93
- result = self .torun .get ()
94
- return result # type: ignore[no-any-return]
95
-
96
124
def sendevent (self , name : str , ** kwargs : object ) -> None :
97
125
self .log ("sending" , name , kwargs )
98
126
self .channel .send ((name , kwargs ))
@@ -146,30 +174,34 @@ def handle_command(
146
174
self .steal (kwargs ["indices" ])
147
175
148
176
def steal (self , indices : Sequence [int ]) -> None :
149
- indices_set = set ( indices )
150
- stolen = []
177
+ """
178
+ Remove tests from the queue.
151
179
152
- old_queue , self .torun = self .torun , self ._make_queue ()
180
+ Removes either all requested tests, or none, if some of these tests
181
+ are not in the queue (for example, if they were processed already).
153
182
154
- def old_queue_get_nowait_noraise () -> int | None :
155
- with contextlib .suppress (self .channel .gateway .execmodel .queue .Empty ):
156
- return old_queue .get_nowait () # type: ignore[no-any-return]
157
- return None
183
+ :param indices: indices of the tests to remove.
184
+ """
185
+ requested_set = set (indices )
186
+
187
+ with self .torun .lock () as locked_queue :
188
+ stolen = list (item for item in locked_queue if item in requested_set )
158
189
159
- for i in iter (old_queue_get_nowait_noraise , None ):
160
- if i in indices_set :
161
- stolen .append (i )
190
+ # Stealing only if all requested tests are still pending
191
+ if len (stolen ) == len (requested_set ):
192
+ self .torun .replace (
193
+ item for item in locked_queue if item not in requested_set
194
+ )
162
195
else :
163
- self . torun . put ( i )
196
+ stolen = []
164
197
165
198
self .sendevent ("unscheduled" , indices = stolen )
166
- old_queue .put (Marker .QUEUE_REPLACED )
167
199
168
200
@pytest .hookimpl
169
201
def pytest_runtestloop (self , session : pytest .Session ) -> bool :
170
202
self .log ("entering main loop" )
171
203
self .channel .setcallback (self .handle_command , endmarker = Marker .SHUTDOWN )
172
- self .nextitem_index = self ._get_next_item_index ()
204
+ self .nextitem_index = self .torun . get ()
173
205
while self .nextitem_index is not Marker .SHUTDOWN :
174
206
self .run_one_test ()
175
207
if session .shouldfail or session .shouldstop :
@@ -179,7 +211,7 @@ def pytest_runtestloop(self, session: pytest.Session) -> bool:
179
211
def run_one_test (self ) -> None :
180
212
assert isinstance (self .nextitem_index , int )
181
213
self .item_index = self .nextitem_index
182
- self .nextitem_index = self ._get_next_item_index ()
214
+ self .nextitem_index = self .torun . get ()
183
215
184
216
items = self .session .items
185
217
item = items [self .item_index ]
0 commit comments