Skip to content

Commit df68d87

Browse files
committed
Implement concurrent limits based on item marks
This change enables users to limit the concurrency of individual tests based on their marks. For example, let's say that there are 10 tests: 8 of them can be executed concurrently, but 2 of them use some shared resource and can't be executed concurrently. By using the new option `--max-asyncio-tasks-by-mark`, it's now possible to configure the task scheduler to execute those 2 tests in series, while the remaining 8 tests are still scheduled concurrently.
1 parent 78fd29d commit df68d87

File tree

3 files changed

+475
-16
lines changed

3 files changed

+475
-16
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__pycache__

pytest_asyncio_cooperative/plugin.py

Lines changed: 168 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,32 @@ def pytest_addoption(parser):
2727
default=100,
2828
)
2929

30+
parser.addoption(
31+
"--max-asyncio-tasks-by-mark",
32+
action="store",
33+
default=None,
34+
help="asyncio: maximum number of tasks to run concurrently for each mark (mark1,mark2,..=int pairs, space separated)",
35+
)
36+
parser.addini(
37+
"max_asyncio_tasks_by_mark",
38+
"asyncio: asyncio: maximum number of tasks to run concurrently for each mark (mark1,mark2,..=int pairs, space separated)",
39+
default=None,
40+
)
41+
42+
parser.addoption(
43+
"--max-asyncio-tasks-by-mark-remainder",
44+
action="store",
45+
default=0,
46+
help="asyncio: maximum number of tasks to run concurrently for tasks that didn't match any "
47+
"marks in `--max-asyncio-tasks-by-mark` (default 0, unlimited)",
48+
)
49+
parser.addini(
50+
"max_asyncio_tasks_by_mark_remainder",
51+
help="asyncio: maximum number of tasks to run concurrently for tasks that didn't match any "
52+
"marks in `--max-asyncio-tasks-by-mark` (default 0, unlimited)",
53+
default=0,
54+
)
55+
3056
parser.addoption(
3157
"--asyncio-task-timeout",
3258
action="store",
@@ -171,15 +197,128 @@ def item_to_task(item):
171197
raise NotCoroutine
172198

173199

200+
class MarkLimits:
    """Track how many running test items belong to each configured mark group.

    Limits are configured as space-separated ``mark1,mark2,..=int`` pairs.
    To support multiple marks in the same group (separated by commas), a
    shared ``Group`` object is created and inserted into the ``groups``
    dictionary once for each mark in the group. When a test carries several
    marks that belong to the same group, the item is added to the same set
    several times, so the number of active items in the group only increases
    by one. A simple ref count is not enough in this situation, because it
    would overcount how many items are actively running.

    Items that match none of the configured marks are counted in
    ``remainder`` and limited by ``remainder_max`` (0 means unlimited).
    """

    class Group:
        """One concurrency limit shared by every mark of a ``a,b=N`` pair."""

        def __init__(self, max):
            # Maximum number of items of this group allowed to run at once.
            self.max = max
            # Items of this group that are currently running.
            self.items = set()

    def __init__(self):
        # Mark name -> Group; several names may point to the same Group.
        self.groups = {}
        # Number of running items that matched no configured mark.
        self.remainder = 0
        # Limit for ``remainder``; 0 disables the limit.
        self.remainder_max = 0

    def update_config(self, session):
        """Load the mark limits from the session's command line / ini config.

        The ``--max-asyncio-tasks-by-mark`` command line option takes
        precedence; the ``max_asyncio_tasks_by_mark`` ini value is used when
        the option is empty. The remainder limit is resolved the same way.
        """
        cli_value = session.config.getoption("--max-asyncio-tasks-by-mark")
        if cli_value:
            self.groups = MarkLimits.parse_max_tasks_by_mark(
                cli_value,
                "--max-asyncio-tasks-by-mark",
            )
        else:
            self.groups = MarkLimits.parse_max_tasks_by_mark(
                session.config.getini("max_asyncio_tasks_by_mark"),
                "max_asyncio_tasks_by_mark",
            )

        self.remainder_max = int(
            session.config.getoption("--max-asyncio-tasks-by-mark-remainder")
            or session.config.getini("max_asyncio_tasks_by_mark_remainder")
        )

    @staticmethod
    def parse_max_tasks_by_mark(config_value, debug_prefix):
        """Parse ``mark1,mark2,..=int`` pairs into a ``{mark: Group}`` dict.

        Args:
            config_value: Whitespace-separated pairs, or a falsy value when
                nothing was configured.
            debug_prefix: Name of the option being parsed; only used to
                prefix error messages.

        Returns:
            A dict mapping every mark name to its ``Group``. Marks listed in
            the same pair share one ``Group`` instance. Empty when
            ``config_value`` is falsy.

        Raises:
            AssertionError: If a pair has no ``=`` or more than one, its
                limit is not an integer, or a mark is listed more than once.
        """
        if not config_value:
            return {}

        result = {}
        for pair in config_value.split():
            columns = pair.split("=")
            # Raised explicitly (instead of using ``assert``) so validation
            # still happens when Python runs with ``-O``.
            if len(columns) > 2:
                raise AssertionError(f"`{debug_prefix}`: too many `=` in `{pair}`")
            if len(columns) < 2:
                # Previously surfaced as a bare IndexError on ``columns[1]``.
                raise AssertionError(f"`{debug_prefix}`: missing `=` in `{pair}`")

            try:
                max_tasks = int(columns[1])
            except ValueError as err:
                raise AssertionError(
                    f"`{debug_prefix}`: expected integer in `{pair}`"
                ) from err

            group = MarkLimits.Group(max_tasks)
            for mark in columns[0].split(","):
                if mark in result:
                    raise AssertionError(
                        f"`{debug_prefix}`: multiple occurences of mark `{mark}`"
                    )
                result[mark] = group

        return result

    def would_exceed_max_marks(self, item):
        """Return True if starting ``item`` now would break a limit.

        An item whose own markers match at least one configured group is
        only checked against those groups. The remainder limit applies
        solely to items that match no configured mark, mirroring how
        ``update_active_marks`` counts them.
        """
        matched_mark = False

        for mark in item.own_markers:
            group = self.groups.get(mark.name)
            if group is None:
                continue
            matched_mark = True
            if len(group.items) >= group.max:
                return True

        if matched_mark:
            # Matched items never count toward the remainder, so they must
            # not be held back by it either.
            return False

        return bool(self.remainder_max) and self.remainder >= self.remainder_max

    def update_active_marks(self, item, add):
        """Record that ``item`` started (``add=True``) or finished running."""
        matched_mark = False

        for mark in item.own_markers:
            group = self.groups.get(mark.name)
            if group is None:
                continue
            # Not returning here, because each mark must be accounted for,
            # not just the first matching one.
            matched_mark = True
            if add:
                group.items.add(item)
            else:
                group.items.discard(item)

        # Make sure not to overcount the remainder
        if matched_mark:
            return

        if self.remainder_max:
            if add:
                self.remainder += 1
            else:
                self.remainder -= 1
294+
295+
296+
def get_coro(task):
    """Return the coroutine object wrapped by ``task``.

    Uses the public ``get_coro()`` accessor when the interpreter is
    Python 3.8 or newer, and the private ``_coro`` attribute otherwise.
    """
    has_public_accessor = sys_version_info >= (3, 8)
    return task.get_coro() if has_public_accessor else task._coro
301+
302+
303+
def get_item_by_coro(task, item_by_coro):
    """Look up the test item registered for ``task`` in ``item_by_coro``.

    ``task`` may be an ``asyncio.Task``, in which case the lookup key is the
    coroutine it wraps, or any other object, which is used as the key
    directly.
    """
    key = get_coro(task) if isinstance(task, asyncio.Task) else task
    return item_by_coro[key]
308+
309+
174310
def _run_test_loop(tasks, session, run_tests):
175311
max_tasks = int(
176312
session.config.getoption("--max-asyncio-tasks")
177313
or session.config.getini("max_asyncio_tasks")
178314
)
179315

316+
mark_limits = MarkLimits()
317+
mark_limits.update_config(session)
318+
180319
loop = asyncio.new_event_loop()
181320
try:
182-
task = run_tests(tasks, int(max_tasks))
321+
task = run_tests(tasks, int(max_tasks), mark_limits)
183322
loop.run_until_complete(task)
184323
finally:
185324
loop.close()
@@ -234,15 +373,32 @@ def pytest_runtestloop(session):
234373
else:
235374
regular_items.append(item)
236375

237-
def get_coro(task):
238-
if sys_version_info >= (3, 8):
239-
return task.get_coro()
240-
else:
241-
return task._coro
376+
async def run_tests(tasks, max_tasks: int, mark_limits):
377+
sidelined_tasks = tasks
378+
tasks = []
379+
380+
def enqueue_tasks():
381+
while len(tasks) < max_tasks:
382+
remove_index = None
383+
384+
for index, task in enumerate(sidelined_tasks):
385+
item = get_item_by_coro(task, item_by_coro)
386+
if not mark_limits.would_exceed_max_marks(item):
387+
mark_limits.update_active_marks(item, add=True)
388+
tasks.append(task)
389+
remove_index = index
390+
break
391+
392+
if remove_index == None:
393+
# No available tasks were found, give up control.
394+
break
395+
396+
# Removing element from start/middle of the list is actually
397+
# quite fast compared to iterating to the end of the list in
398+
# the above loop.
399+
sidelined_tasks.pop(remove_index)
242400

243-
async def run_tests(tasks, max_tasks: int):
244-
sidelined_tasks = tasks[max_tasks:]
245-
tasks = tasks[:max_tasks]
401+
enqueue_tasks()
246402

247403
task_timeout = int(
248404
session.config.getoption("--asyncio-task-timeout")
@@ -260,10 +416,7 @@ async def run_tests(tasks, max_tasks: int):
260416
# Mark when the task was started
261417
earliest_enqueue_time = time.time()
262418
for task in tasks:
263-
if isinstance(task, asyncio.Task):
264-
item = item_by_coro[get_coro(task)]
265-
else:
266-
item = item_by_coro[task]
419+
item = get_item_by_coro(task, item_by_coro)
267420
if not hasattr(item, "enqueue_time"):
268421
item.enqueue_time = time.time()
269422
earliest_enqueue_time = min(item.enqueue_time, earliest_enqueue_time)
@@ -293,6 +446,7 @@ async def run_tests(tasks, max_tasks: int):
293446

294447
for result in done:
295448
item = item_by_coro[get_coro(result)]
449+
mark_limits.update_active_marks(item, add=False)
296450

297451
# Flakey tests will be run again if they failed
298452
# TODO: add retry count
@@ -317,9 +471,7 @@ async def run_tests(tasks, max_tasks: int):
317471

318472
completed.append(result)
319473

320-
if sidelined_tasks:
321-
if len(tasks) < max_tasks:
322-
tasks.append(sidelined_tasks.pop(0))
474+
enqueue_tasks()
323475

324476
return completed
325477

0 commit comments

Comments
 (0)