Skip to content

Commit eb4388c

Browse files
committed
Make 'steal' command atomic
Either unschedule all of the requested tests, or none of them if that is not possible, i.e. if some of the requested tests have already been processed by the time the request arrives. This can happen if the worker runs tests faster than the controller receives and processes status updates; in that case it is probably better to just let the worker keep running. This is a prerequisite for group/scope support in the worksteal scheduler, so that groups/scopes won't be broken up incorrectly. This change could break schedulers that use the "steal" command. However: 1) the worksteal scheduler doesn't need any adjustments; 2) I'm not aware of any external schedulers relying on this command yet. So I think it's better to keep the protocol simple rather than complicate it for imaginary compatibility with some unknown and likely non-existent schedulers.
1 parent 9c24f0f commit eb4388c

File tree

3 files changed

+59
-31
lines changed

3 files changed

+59
-31
lines changed

changelog/1144.feature

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make "steal" command atomic - make it unschedule either all requested tests or none.

src/xdist/remote.py

+54-31
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,19 @@
88

99
from __future__ import annotations
1010

11+
import collections
1112
import contextlib
1213
import enum
1314
import os
1415
import sys
1516
import time
1617
from typing import Any
1718
from typing import Generator
19+
from typing import Iterable
1820
from typing import Literal
1921
from typing import Sequence
2022
from typing import TypedDict
23+
from typing import Union
2124
import warnings
2225

2326
from _pytest.config import _prepareconfig
@@ -66,7 +69,44 @@ def worker_title(title: str) -> None:
6669

6770
class Marker(enum.Enum):
6871
SHUTDOWN = 0
69-
QUEUE_REPLACED = 1
72+
73+
74+
class TestQueue:
75+
"""A simple queue that can be inspected and modified while the lock is held."""
76+
77+
Item = Union[int, Literal[Marker.SHUTDOWN]]
78+
79+
def __init__(self, execmodel: execnet.gateway_base.ExecModel):
80+
self._items: collections.deque[TestQueue.Item] = collections.deque()
81+
self._lock = execmodel.RLock() # type: ignore[no-untyped-call]
82+
self._has_items_event = execmodel.Event()
83+
84+
def get(self) -> Item:
85+
while True:
86+
with self.lock() as locked_items:
87+
if locked_items:
88+
return locked_items.popleft()
89+
90+
self._has_items_event.wait()
91+
92+
def put(self, item: Item) -> None:
93+
with self.lock() as locked_items:
94+
locked_items.append(item)
95+
96+
def replace(self, iterable: Iterable[Item]) -> None:
97+
with self.lock():
98+
self._items = collections.deque(iterable)
99+
100+
@contextlib.contextmanager
101+
def lock(self) -> Generator[collections.deque[Item], None, None]:
102+
with self._lock:
103+
try:
104+
yield self._items
105+
finally:
106+
if self._items:
107+
self._has_items_event.set()
108+
else:
109+
self._has_items_event.clear()
70110

71111

72112
class WorkerInteractor:
@@ -77,22 +117,10 @@ def __init__(self, config: pytest.Config, channel: execnet.Channel) -> None:
77117
self.testrunuid = workerinput["testrunuid"]
78118
self.log = Producer(f"worker-{self.workerid}", enabled=config.option.debug)
79119
self.channel = channel
80-
self.torun = self._make_queue()
120+
self.torun = TestQueue(self.channel.gateway.execmodel)
81121
self.nextitem_index: int | None | Literal[Marker.SHUTDOWN] = None
82122
config.pluginmanager.register(self)
83123

84-
def _make_queue(self) -> Any:
85-
return self.channel.gateway.execmodel.queue.Queue()
86-
87-
def _get_next_item_index(self) -> int | Literal[Marker.SHUTDOWN]:
88-
"""Gets the next item from test queue. Handles the case when the queue
89-
is replaced concurrently in another thread.
90-
"""
91-
result = self.torun.get()
92-
while result is Marker.QUEUE_REPLACED:
93-
result = self.torun.get()
94-
return result # type: ignore[no-any-return]
95-
96124
def sendevent(self, name: str, **kwargs: object) -> None:
97125
self.log("sending", name, kwargs)
98126
self.channel.send((name, kwargs))
@@ -146,30 +174,25 @@ def handle_command(
146174
self.steal(kwargs["indices"])
147175

148176
def steal(self, indices: Sequence[int]) -> None:
149-
indices_set = set(indices)
150-
stolen = []
151-
152-
old_queue, self.torun = self.torun, self._make_queue()
153-
154-
def old_queue_get_nowait_noraise() -> int | None:
155-
with contextlib.suppress(self.channel.gateway.execmodel.queue.Empty):
156-
return old_queue.get_nowait() # type: ignore[no-any-return]
157-
return None
158-
159-
for i in iter(old_queue_get_nowait_noraise, None):
160-
if i in indices_set:
161-
stolen.append(i)
177+
with self.torun.lock() as locked_queue:
178+
requested_set = set(indices)
179+
stolen = list(item for item in locked_queue if item in requested_set)
180+
181+
# Stealing only if all requested tests are still pending
182+
if len(stolen) == len(requested_set):
183+
self.torun.replace(
184+
item for item in locked_queue if item not in requested_set
185+
)
162186
else:
163-
self.torun.put(i)
187+
stolen = []
164188

165189
self.sendevent("unscheduled", indices=stolen)
166-
old_queue.put(Marker.QUEUE_REPLACED)
167190

168191
@pytest.hookimpl
169192
def pytest_runtestloop(self, session: pytest.Session) -> bool:
170193
self.log("entering main loop")
171194
self.channel.setcallback(self.handle_command, endmarker=Marker.SHUTDOWN)
172-
self.nextitem_index = self._get_next_item_index()
195+
self.nextitem_index = self.torun.get()
173196
while self.nextitem_index is not Marker.SHUTDOWN:
174197
self.run_one_test()
175198
if session.shouldfail or session.shouldstop:
@@ -179,7 +202,7 @@ def pytest_runtestloop(self, session: pytest.Session) -> bool:
179202
def run_one_test(self) -> None:
180203
assert isinstance(self.nextitem_index, int)
181204
self.item_index = self.nextitem_index
182-
self.nextitem_index = self._get_next_item_index()
205+
self.nextitem_index = self.torun.get()
183206

184207
items = self.session.items
185208
item = items[self.item_index]

testing/test_remote.py

+4
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,10 @@ def test_func4(): pass
267267

268268
worker.sendcommand("steal", indices=[1, 2])
269269
ev = worker.popevent("unscheduled")
270+
assert ev.kwargs["indices"] == []
271+
272+
worker.sendcommand("steal", indices=[2])
273+
ev = worker.popevent("unscheduled")
270274
assert ev.kwargs["indices"] == [2]
271275

272276
reports = [

0 commit comments

Comments
 (0)