Skip to content

Commit 19dc897

Browse files
amezin and nicoddemus committed
Make 'steal' command atomic
Either unschedule all requested tests, or none of them if that is not possible, i.e. if some of the requested tests have already been processed by the time the request arrives. This can happen if the worker runs tests faster than the controller receives and processes status updates; in that case it is probably better to just let the worker keep running. This is a prerequisite for group/scope support in the worksteal scheduler, so that groups won't be broken up incorrectly. This change could break schedulers that use the "steal" command. However: 1) the worksteal scheduler doesn't need any adjustments; 2) I'm not aware of any external schedulers relying on this command yet. So I think it's better to keep the protocol simple rather than complicate it for imaginary compatibility with some unknown and likely non-existent schedulers. Co-authored-by: Bruno Oliveira <[email protected]>
1 parent 34c5549 commit 19dc897

File tree

3 files changed

+69
-28
lines changed

3 files changed

+69
-28
lines changed

changelog/1144.feature

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
The internal `steal` command is now atomic - it unschedules either all requested tests or none.
2+
3+
This is a prerequisite for group/scope support in the `worksteal` scheduler, so test groups won't be broken up incorrectly.

src/xdist/remote.py

+60-28
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,19 @@
88

99
from __future__ import annotations
1010

11+
import collections
1112
import contextlib
1213
import enum
1314
import os
1415
import sys
1516
import time
1617
from typing import Any
1718
from typing import Generator
19+
from typing import Iterable
1820
from typing import Literal
1921
from typing import Sequence
2022
from typing import TypedDict
23+
from typing import Union
2124
import warnings
2225

2326
from _pytest.config import _prepareconfig
@@ -66,7 +69,44 @@ def worker_title(title: str) -> None:
6669

6770
class Marker(enum.Enum):
6871
SHUTDOWN = 0
69-
QUEUE_REPLACED = 1
72+
73+
74+
class TestQueue:
    """FIFO of pending items that can be inspected and edited under a lock.

    Callers that need to look at or modify the pending items as a whole
    obtain the underlying deque through the ``lock()`` context manager;
    ``get()``, ``put()`` and ``replace()`` are built on top of it.
    """

    Item = Union[int, Literal[Marker.SHUTDOWN]]

    def __init__(self, execmodel: execnet.gateway_base.ExecModel):
        # The lock and the event both come from the execmodel that is passed in.
        self._items: collections.deque[TestQueue.Item] = collections.deque()
        self._lock = execmodel.RLock()  # type: ignore[no-untyped-call]
        self._has_items_event = execmodel.Event()

    def get(self) -> Item:
        """Pop and return the leftmost item, waiting while there is none."""
        while True:
            with self.lock() as pending:
                try:
                    return pending.popleft()
                except IndexError:
                    # Nothing queued right now; fall through and wait below.
                    pass

            # The lock is released here. Wait for the event, then re-check:
            # the queue may have been emptied again in the meantime.
            self._has_items_event.wait()

    def put(self, item: Item) -> None:
        """Append ``item`` to the right end of the queue."""
        with self.lock() as pending:
            pending.append(item)

    def replace(self, iterable: Iterable[Item]) -> None:
        """Swap the whole queue content for the items of ``iterable``.

        A new deque is bound to the queue; a deque obtained earlier from
        ``lock()`` is left unmodified and no longer backs the queue.
        """
        with self.lock():
            # Consume the iterable while the lock is held.
            self._items = collections.deque(iterable)

    @contextlib.contextmanager
    def lock(self) -> Generator[collections.deque[Item], None, None]:
        """Hold the queue lock and yield the current deque of items.

        When the ``with`` body finishes (normally or with an exception), the
        "has items" event is updated to match whether the queue is empty,
        before the lock is released.
        """
        with self._lock:
            try:
                yield self._items
            finally:
                # Read ``self._items`` afresh: ``replace()`` may have rebound it.
                if not self._items:
                    self._has_items_event.clear()
                else:
                    self._has_items_event.set()
70110

71111

72112
class WorkerInteractor:
@@ -77,22 +117,10 @@ def __init__(self, config: pytest.Config, channel: execnet.Channel) -> None:
77117
self.testrunuid = workerinput["testrunuid"]
78118
self.log = Producer(f"worker-{self.workerid}", enabled=config.option.debug)
79119
self.channel = channel
80-
self.torun = self._make_queue()
120+
self.torun = TestQueue(self.channel.gateway.execmodel)
81121
self.nextitem_index: int | None | Literal[Marker.SHUTDOWN] = None
82122
config.pluginmanager.register(self)
83123

84-
def _make_queue(self) -> Any:
85-
return self.channel.gateway.execmodel.queue.Queue()
86-
87-
def _get_next_item_index(self) -> int | Literal[Marker.SHUTDOWN]:
88-
"""Gets the next item from test queue. Handles the case when the queue
89-
is replaced concurrently in another thread.
90-
"""
91-
result = self.torun.get()
92-
while result is Marker.QUEUE_REPLACED:
93-
result = self.torun.get()
94-
return result # type: ignore[no-any-return]
95-
96124
def sendevent(self, name: str, **kwargs: object) -> None:
97125
self.log("sending", name, kwargs)
98126
self.channel.send((name, kwargs))
@@ -146,30 +174,34 @@ def handle_command(
146174
self.steal(kwargs["indices"])
147175

148176
def steal(self, indices: Sequence[int]) -> None:
149-
indices_set = set(indices)
150-
stolen = []
177+
"""
178+
Remove tests from the queue.
151179
152-
old_queue, self.torun = self.torun, self._make_queue()
180+
Removes either all requested tests, or none, if some of these tests
181+
are not in the queue (for example, if they were processed already).
153182
154-
def old_queue_get_nowait_noraise() -> int | None:
155-
with contextlib.suppress(self.channel.gateway.execmodel.queue.Empty):
156-
return old_queue.get_nowait() # type: ignore[no-any-return]
157-
return None
183+
:param indices: indices of the tests to remove.
184+
"""
185+
requested_set = set(indices)
186+
187+
with self.torun.lock() as locked_queue:
188+
stolen = list(item for item in locked_queue if item in requested_set)
158189

159-
for i in iter(old_queue_get_nowait_noraise, None):
160-
if i in indices_set:
161-
stolen.append(i)
190+
# Stealing only if all requested tests are still pending
191+
if len(stolen) == len(requested_set):
192+
self.torun.replace(
193+
item for item in locked_queue if item not in requested_set
194+
)
162195
else:
163-
self.torun.put(i)
196+
stolen = []
164197

165198
self.sendevent("unscheduled", indices=stolen)
166-
old_queue.put(Marker.QUEUE_REPLACED)
167199

168200
@pytest.hookimpl
169201
def pytest_runtestloop(self, session: pytest.Session) -> bool:
170202
self.log("entering main loop")
171203
self.channel.setcallback(self.handle_command, endmarker=Marker.SHUTDOWN)
172-
self.nextitem_index = self._get_next_item_index()
204+
self.nextitem_index = self.torun.get()
173205
while self.nextitem_index is not Marker.SHUTDOWN:
174206
self.run_one_test()
175207
if session.shouldfail or session.shouldstop:
@@ -179,7 +211,7 @@ def pytest_runtestloop(self, session: pytest.Session) -> bool:
179211
def run_one_test(self) -> None:
180212
assert isinstance(self.nextitem_index, int)
181213
self.item_index = self.nextitem_index
182-
self.nextitem_index = self._get_next_item_index()
214+
self.nextitem_index = self.torun.get()
183215

184216
items = self.session.items
185217
item = items[self.item_index]

testing/test_remote.py

+6
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,12 @@ def test_func4(): pass
267267

268268
worker.sendcommand("steal", indices=[1, 2])
269269
ev = worker.popevent("unscheduled")
270+
# Cannot steal index 1 because it is completed already, so do not steal any.
271+
assert ev.kwargs["indices"] == []
272+
273+
# Index 2 can be stolen, as it is still pending.
274+
worker.sendcommand("steal", indices=[2])
275+
ev = worker.popevent("unscheduled")
270276
assert ev.kwargs["indices"] == [2]
271277

272278
reports = [

0 commit comments

Comments
 (0)