Skip to content

Commit b809187

Browse files
committed
mypy fixes
1 parent 05b9494 commit b809187

File tree

7 files changed

+65
-22
lines changed

7 files changed

+65
-22
lines changed

src/xdist/dsession.py

+21-12
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import sys
88
import traceback
99
from typing import Any
10+
from typing import Callable
1011
from typing import Sequence
1112
import warnings
1213

@@ -60,14 +61,14 @@ def __init__(self, config: pytest.Config) -> None:
6061
self._failed_collection_errors: dict[object, bool] = {}
6162
self._active_nodes: set[WorkerController] = set()
6263
self._failed_nodes_count = 0
63-
self.saved_put = None
64+
self.saved_put: Callable[[tuple[str, dict[str, Any]]], None]
6465
self.remake_nodes = False
6566
self.ready_to_run_tests = False
6667
self._max_worker_restart = get_default_max_worker_restart(self.config)
6768
# summary message to print at the end of the session
6869
self._summary_report: str | None = None
6970
self.terminal = config.pluginmanager.getplugin("terminalreporter")
70-
self.worker_status: dict[WorkerController, str] = {}
71+
self.worker_status: dict[str, str] = {}
7172
if self.terminal:
7273
self.trdist = TerminalDistReporter(config)
7374
config.pluginmanager.register(self.trdist, "terminaldistreporter")
@@ -180,63 +181,71 @@ def loop_once(self) -> None:
180181
self.triggershutdown()
181182

182183

183-
def is_node_finishing(self, node: WorkerController):
184+
def is_node_finishing(self, node: WorkerController) -> bool:
184185
"""Check if a test worker is considered to be finishing.
185186
186187
Evaluate whether it's on its last test, or if no tests are pending.
187188
"""
189+
assert self.sched is not None
188190
pending = self.sched.node2pending.get(node)
189191
return pending is not None and len(pending) < 2
190192

191193

192-
def is_node_clear(self, node: WorkerController):
194+
def is_node_clear(self, node: WorkerController) -> bool:
193195
"""Check if a test worker has no pending tests."""
196+
assert self.sched is not None
194197
pending = self.sched.node2pending.get(node)
195198
return pending is None or len(pending) == 0
196199

197200

198-
def are_all_nodes_finishing(self):
201+
def are_all_nodes_finishing(self) -> bool:
199202
"""Check if all workers are finishing (See 'is_node_finishing' above)."""
203+
assert self.sched is not None
200204
return all(self.is_node_finishing(node) for node in self.sched.nodes)
201205

202206

203-
def are_all_nodes_done(self):
207+
def are_all_nodes_done(self) -> bool:
204208
"""Check if all nodes have reported to finish."""
205209
return all(s == "finished" for s in self.worker_status.values())
206210

207211

208-
def are_all_active_nodes_collected(self):
212+
def are_all_active_nodes_collected(self) -> bool:
209213
"""Check if all nodes have reported collection to be complete."""
210214
if not all(n.gateway.id in self.worker_status for n in self._active_nodes):
211215
return False
212216
return all(self.worker_status[n.gateway.id] == "collected" for n in self._active_nodes)
213217

214218

215-
def reset_nodes_if_needed(self):
219+
def reset_nodes_if_needed(self) -> None:
220+
assert self.sched is not None
216221
if self.are_all_nodes_finishing() and self.ready_to_run_tests and not self.sched.do_resched:
217222
self.reset_nodes()
218223

219224

220-
def reset_nodes(self):
225+
def reset_nodes(self) -> None:
221226
"""Issue shutdown notices to workers for rescheduling purposes."""
227+
assert self.sched is not None
222228
if len(self.sched.pending) != 0:
223229
self.remake_nodes = True
224230
for node in self.sched.nodes:
225231
if self.is_node_finishing(node):
226232
node.shutdown()
227233

228234

229-
def reschedule(self):
235+
def reschedule(self) -> None:
230236
"""Reschedule tests."""
237+
assert self.sched is not None
231238
self.sched.do_resched = False
232239
self.sched.check_schedule(self.sched.nodes[0], 1.0, True)
233240

234241

235-
def prepare_for_reschedule(self):
242+
def prepare_for_reschedule(self) -> None:
236243
"""Update test workers and their status tracking so rescheduling is ready."""
244+
assert self.sched is not None
237245
self.remake_nodes = False
238246
num_workers = self.sched.dist_groups[self.sched.pending_groups[0]]['group_workers']
239247
self.trdist._status = {}
248+
assert self.nodemanager is not None
240249
new_nodes = self.nodemanager.setup_nodes(self.saved_put, num_workers)
241250
self.worker_status = {}
242251
self._active_nodes = set()
@@ -310,7 +319,7 @@ def worker_workerfinished(self, node: WorkerController) -> None:
310319
assert not crashitem, (crashitem, node)
311320
self._active_nodes.remove(node)
312321

313-
def update_worker_status(self, node, status):
322+
def update_worker_status(self, node: WorkerController, status: str) -> None:
314323
"""Track the worker status.
315324
316325
Can be used at callbacks like 'worker_workerfinished' so we remember wchic event

src/xdist/scheduler/customgroup.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def remove_pending_tests_from_node(
205205
) -> None:
206206
raise NotImplementedError()
207207

208-
def check_schedule(self, node: WorkerController, duration: float = 0, from_dsession=False) -> None:
208+
def check_schedule(self, node: WorkerController, duration: float = 0, from_dsession: bool = False) -> None:
209209
"""Maybe schedule new items on the node.
210210
211211
If there are any globally pending nodes left then this will
@@ -228,7 +228,7 @@ def check_schedule(self, node: WorkerController, duration: float = 0, from_dsess
228228
dist_group_key = self.pending_groups.pop(0)
229229
dist_group = self.dist_groups[dist_group_key]
230230
nodes = cycle(self.nodes[0:dist_group['group_workers']])
231-
schedule_log = {n.gateway.id:[] for n in self.nodes[0:dist_group['group_workers']]}
231+
schedule_log: dict[str, Any] = {n.gateway.id:[] for n in self.nodes[0:dist_group['group_workers']]}
232232
for _ in range(len(dist_group['test_indices'])):
233233
n = next(nodes)
234234
#needs cleaner way to be identified
@@ -242,7 +242,7 @@ def check_schedule(self, node: WorkerController, duration: float = 0, from_dsess
242242
self.report_line(message)
243243

244244
else:
245-
pending = self.node2pending.get(node)
245+
pending = self.node2pending.get(node, [])
246246
if len(pending) < 2:
247247
self.report_line(
248248
f"[-] [csg] Shutting down {node.workerinput['workerid']} because only one case is pending"
@@ -306,7 +306,7 @@ def schedule(self) -> None:
306306
if not self.collection:
307307
return
308308

309-
dist_groups = {}
309+
dist_groups: dict[str, dict[Any, Any]] = {}
310310

311311
if self.is_first_time:
312312
for i, test in enumerate(self.collection):
@@ -343,7 +343,7 @@ def schedule(self) -> None:
343343
dist_group_key = self.pending_groups.pop(0)
344344
dist_group = self.dist_groups[dist_group_key]
345345
nodes = cycle(self.nodes[0:dist_group['group_workers']])
346-
schedule_log = {n.gateway.id: [] for n in self.nodes[0:dist_group['group_workers']]}
346+
schedule_log: dict[str, Any] = {n.gateway.id: [] for n in self.nodes[0:dist_group['group_workers']]}
347347
for _ in range(len(dist_group['test_indices'])):
348348
n = next(nodes)
349349
# needs cleaner way to be identified
@@ -362,7 +362,7 @@ def _send_tests(self, node: WorkerController, num: int) -> None:
362362
self.node2pending[node].extend(tests_per_node)
363363
node.send_runtest_some(tests_per_node)
364364

365-
def _send_tests_group(self, node: WorkerController, num: int, dist_group_key) -> None:
365+
def _send_tests_group(self, node: WorkerController, num: int, dist_group_key: str) -> None:
366366
tests_per_node = self.dist_groups[dist_group_key]['pending_indices'][:num]
367367
if tests_per_node:
368368
del self.dist_groups[dist_group_key]['pending_indices'][:num]

src/xdist/scheduler/each.py

+8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from typing import Any
34
from typing import Sequence
45

56
import pytest
@@ -29,6 +30,10 @@ def __init__(self, config: pytest.Config, log: Producer | None = None) -> None:
2930
self.numnodes = len(parse_spec_config(config))
3031
self.node2collection: dict[WorkerController, list[str]] = {}
3132
self.node2pending: dict[WorkerController, list[int]] = {}
33+
self.do_resched: bool = False
34+
self.pending: list[int] = []
35+
self.dist_groups: dict[str, Any] = {}
36+
self.pending_groups: list[str] = []
3237
self._started: list[WorkerController] = []
3338
self._removed2pending: dict[WorkerController, list[int]] = {}
3439
if log is None:
@@ -106,6 +111,9 @@ def add_node_collection(
106111
self.node2pending[node] = pending
107112
break
108113

114+
def check_schedule(self, node: WorkerController, duration: float = 0, from_dsession: bool = False) -> None:
115+
raise NotImplementedError()
116+
109117
def mark_test_complete(
110118
self, node: WorkerController, item_index: int, duration: float = 0
111119
) -> None:

src/xdist/scheduler/load.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from itertools import cycle
4+
from typing import Any
45
from typing import Sequence
56

67
import pytest
@@ -62,6 +63,9 @@ def __init__(self, config: pytest.Config, log: Producer | None = None) -> None:
6263
self.node2collection: dict[WorkerController, list[str]] = {}
6364
self.node2pending: dict[WorkerController, list[int]] = {}
6465
self.pending: list[int] = []
66+
self.do_resched: bool = False
67+
self.dist_groups: dict[str, Any] = {}
68+
self.pending_groups: list[str] = []
6569
self.collection: list[str] | None = None
6670
if log is None:
6771
self.log = Producer("loadsched")
@@ -176,7 +180,8 @@ def remove_pending_tests_from_node(
176180
) -> None:
177181
raise NotImplementedError()
178182

179-
def check_schedule(self, node: WorkerController, duration: float = 0) -> None:
183+
184+
def check_schedule(self, node: WorkerController, duration: float = 0, from_dsession: bool = False) -> None:
180185
"""Maybe schedule new items on the node.
181186
182187
If there are any globally pending nodes left then this will

src/xdist/scheduler/loadscope.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from collections import OrderedDict
4+
from typing import Any
45
from typing import NoReturn
56
from typing import Sequence
67

@@ -93,11 +94,14 @@ class LoadScopeScheduling:
9394
def __init__(self, config: pytest.Config, log: Producer | None = None) -> None:
9495
self.numnodes = len(parse_spec_config(config))
9596
self.collection: list[str] | None = None
96-
97+
self.node2pending: dict[WorkerController, list[int]] = {}
9798
self.workqueue: OrderedDict[str, dict[str, bool]] = OrderedDict()
9899
self.assigned_work: dict[WorkerController, dict[str, dict[str, bool]]] = {}
99100
self.registered_collections: dict[WorkerController, list[str]] = {}
100-
101+
self.do_resched: bool = False
102+
self.pending: list[int] = []
103+
self.dist_groups: dict[str, Any] = {}
104+
self.pending_groups: list[str] = []
101105
if log is None:
102106
self.log = Producer("loadscopesched")
103107
else:
@@ -163,6 +167,9 @@ def add_node(self, node: WorkerController) -> None:
163167
assert node not in self.assigned_work
164168
self.assigned_work[node] = {}
165169

170+
def check_schedule(self, node: WorkerController, duration: float = 0, from_dsession: bool = False) -> None:
171+
raise NotImplementedError()
172+
166173
def remove_node(self, node: WorkerController) -> str | None:
167174
"""Remove a node from the scheduler.
168175

src/xdist/scheduler/protocol.py

+9
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
from __future__ import annotations
22

3+
from typing import Any
34
from typing import Protocol
45
from typing import Sequence
56

67
from xdist.workermanage import WorkerController
78

89

910
class Scheduling(Protocol):
11+
node2pending: Any
12+
do_resched: bool
13+
pending: list[int]
14+
dist_groups: dict[str, Any]
15+
pending_groups: list[str]
16+
1017
@property
1118
def nodes(self) -> list[WorkerController]: ...
1219

@@ -27,6 +34,8 @@ def add_node_collection(
2734
collection: Sequence[str],
2835
) -> None: ...
2936

37+
def check_schedule(self, node: WorkerController, duration: float = 0, from_dsession: bool = False) -> None: ...
38+
3039
def mark_test_complete(
3140
self,
3241
node: WorkerController,

src/xdist/scheduler/worksteal.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from typing import Any
34
from typing import NamedTuple
45
from typing import Sequence
56

@@ -70,6 +71,9 @@ def __init__(self, config: pytest.Config, log: Producer | None = None) -> None:
7071
self.node2pending: dict[WorkerController, list[int]] = {}
7172
self.pending: list[int] = []
7273
self.collection: list[str] | None = None
74+
self.do_resched: bool = False
75+
self.dist_groups: dict[str, Any] = {}
76+
self.pending_groups: list[str] = []
7377
if log is None:
7478
self.log = Producer("workstealsched")
7579
else:
@@ -193,7 +197,8 @@ def remove_pending_tests_from_node(
193197
self.pending.extend(indices)
194198
self.check_schedule()
195199

196-
def check_schedule(self) -> None:
200+
def check_schedule(self, node: WorkerController | None = None, duration: float = 0, from_dsession: bool = False
201+
) -> None:
197202
"""Reschedule tests/perform load balancing."""
198203
nodes_up = [
199204
NodePending(node, pending)

0 commit comments

Comments
 (0)