8
8
9
9
from __future__ import annotations
10
10
11
+ import collections
11
12
import contextlib
12
13
import enum
13
14
import os
14
15
import sys
15
16
import time
16
17
from typing import Any
17
18
from typing import Generator
19
+ from typing import Iterable
18
20
from typing import Literal
19
21
from typing import Sequence
20
22
from typing import TypedDict
23
+ from typing import Union
21
24
import warnings
22
25
23
26
from _pytest .config import _prepareconfig
@@ -66,7 +69,44 @@ def worker_title(title: str) -> None:
66
69
67
70
class Marker (enum .Enum ):
68
71
SHUTDOWN = 0
69
- QUEUE_REPLACED = 1
72
+
73
+
74
+ class TestQueue :
75
+ """A simple queue that can be inspected and modified while the lock is held."""
76
+
77
+ Item = Union [int , Literal [Marker .SHUTDOWN ]]
78
+
79
+ def __init__ (self , execmodel : execnet .gateway_base .ExecModel ):
80
+ self ._items : collections .deque [TestQueue .Item ] = collections .deque ()
81
+ self ._lock = execmodel .RLock () # type: ignore[no-untyped-call]
82
+ self ._has_items_event = execmodel .Event ()
83
+
84
+ def get (self ) -> Item :
85
+ while True :
86
+ with self .lock () as locked_items :
87
+ if locked_items :
88
+ return locked_items .popleft ()
89
+
90
+ self ._has_items_event .wait ()
91
+
92
+ def put (self , item : Item ) -> None :
93
+ with self .lock () as locked_items :
94
+ locked_items .append (item )
95
+
96
+ def replace (self , iterable : Iterable [Item ]) -> None :
97
+ with self .lock ():
98
+ self ._items = collections .deque (iterable )
99
+
100
+ @contextlib .contextmanager
101
+ def lock (self ) -> Generator [collections .deque [Item ], None , None ]:
102
+ with self ._lock :
103
+ try :
104
+ yield self ._items
105
+ finally :
106
+ if self ._items :
107
+ self ._has_items_event .set ()
108
+ else :
109
+ self ._has_items_event .clear ()
70
110
71
111
72
112
class WorkerInteractor :
@@ -77,22 +117,10 @@ def __init__(self, config: pytest.Config, channel: execnet.Channel) -> None:
77
117
self .testrunuid = workerinput ["testrunuid" ]
78
118
self .log = Producer (f"worker-{ self .workerid } " , enabled = config .option .debug )
79
119
self .channel = channel
80
- self .torun = self ._make_queue ( )
120
+ self .torun = TestQueue ( self .channel . gateway . execmodel )
81
121
self .nextitem_index : int | None | Literal [Marker .SHUTDOWN ] = None
82
122
config .pluginmanager .register (self )
83
123
84
- def _make_queue (self ) -> Any :
85
- return self .channel .gateway .execmodel .queue .Queue ()
86
-
87
- def _get_next_item_index (self ) -> int | Literal [Marker .SHUTDOWN ]:
88
- """Gets the next item from test queue. Handles the case when the queue
89
- is replaced concurrently in another thread.
90
- """
91
- result = self .torun .get ()
92
- while result is Marker .QUEUE_REPLACED :
93
- result = self .torun .get ()
94
- return result # type: ignore[no-any-return]
95
-
96
124
def sendevent (self , name : str , ** kwargs : object ) -> None :
97
125
self .log ("sending" , name , kwargs )
98
126
self .channel .send ((name , kwargs ))
@@ -146,30 +174,25 @@ def handle_command(
146
174
self .steal (kwargs ["indices" ])
147
175
148
176
def steal (self , indices : Sequence [int ]) -> None :
149
- indices_set = set (indices )
150
- stolen = []
151
-
152
- old_queue , self .torun = self .torun , self ._make_queue ()
153
-
154
- def old_queue_get_nowait_noraise () -> int | None :
155
- with contextlib .suppress (self .channel .gateway .execmodel .queue .Empty ):
156
- return old_queue .get_nowait () # type: ignore[no-any-return]
157
- return None
158
-
159
- for i in iter (old_queue_get_nowait_noraise , None ):
160
- if i in indices_set :
161
- stolen .append (i )
177
+ with self .torun .lock () as locked_queue :
178
+ requested_set = set (indices )
179
+ stolen = list (item for item in locked_queue if item in requested_set )
180
+
181
+ # Stealing only if all requested tests are still pending
182
+ if len (stolen ) == len (requested_set ):
183
+ self .torun .replace (
184
+ item for item in locked_queue if item not in requested_set
185
+ )
162
186
else :
163
- self . torun . put ( i )
187
+ stolen = []
164
188
165
189
self .sendevent ("unscheduled" , indices = stolen )
166
- old_queue .put (Marker .QUEUE_REPLACED )
167
190
168
191
@pytest .hookimpl
169
192
def pytest_runtestloop (self , session : pytest .Session ) -> bool :
170
193
self .log ("entering main loop" )
171
194
self .channel .setcallback (self .handle_command , endmarker = Marker .SHUTDOWN )
172
- self .nextitem_index = self ._get_next_item_index ()
195
+ self .nextitem_index = self .torun . get ()
173
196
while self .nextitem_index is not Marker .SHUTDOWN :
174
197
self .run_one_test ()
175
198
if session .shouldfail or session .shouldstop :
@@ -179,7 +202,7 @@ def pytest_runtestloop(self, session: pytest.Session) -> bool:
179
202
def run_one_test (self ) -> None :
180
203
assert isinstance (self .nextitem_index , int )
181
204
self .item_index = self .nextitem_index
182
- self .nextitem_index = self ._get_next_item_index ()
205
+ self .nextitem_index = self .torun . get ()
183
206
184
207
items = self .session .items
185
208
item = items [self .item_index ]
0 commit comments