Skip to content

Commit 815e7bc

Browse files
committed
Implement concurrent limits based on item marks
This change enables users to limit individual test concurrency based on mark configuration. For example, lets say that there are 10 tests, and 8 of them can be executed concurrently, but 2 of them use some shared resource and can't be executed concurrenctly. By using the new option `--max-asyncio-tasks-by-mark`, its now possible to configure the task scheduler to execute the 2 tasks in series, but the remaining tasks 8 tasks would still be scheduled concurrently.
1 parent 01db22d commit 815e7bc

File tree

8 files changed

+302
-39
lines changed

8 files changed

+302
-39
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__pycache__

pytest_asyncio_cooperative/fixtures.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def _get_fixture(item, arg_name, fixture=None):
2323
if fixture:
2424
try:
2525
item._request.param = item._pyfuncitem.callspec.params[fixture.argname]
26-
except (AttributeError, KeyError) :
26+
except (AttributeError, KeyError):
2727
pass
2828

2929
return item._request
@@ -62,16 +62,19 @@ async def fill_fixtures(item):
6262

6363
if fixture.scope not in ["function", "module", "session"]:
6464
raise Exception(f"{fixture.scope} scope not supported")
65-
65+
6666
fixtures.append(fixture)
6767
are_autouse.append(is_autouse)
68-
68+
6969
# Fill fixtures concurrently
7070
fill_results = await asyncio.gather(
71-
*(fill_fixture_fixtures(item._fixtureinfo, fixture, item) for fixture in fixtures)
71+
*(
72+
fill_fixture_fixtures(item._fixtureinfo, fixture, item)
73+
for fixture in fixtures
74+
)
7275
)
7376

74-
for ((value, extra_teardowns), is_autouse) in zip(fill_results, are_autouse):
77+
for (value, extra_teardowns), is_autouse in zip(fill_results, are_autouse):
7578
teardowns.extend(extra_teardowns)
7679

7780
if not is_autouse:
@@ -355,7 +358,9 @@ async def fill_fixture_fixtures(_fixtureinfo, fixture, item):
355358
):
356359
return await _make_coroutine_fixture(_fixtureinfo, fixture, item)
357360

358-
elif inspect.isgeneratorfunction(fixture.func) or isinstance(fixture.func, CachedGen):
361+
elif inspect.isgeneratorfunction(fixture.func) or isinstance(
362+
fixture.func, CachedGen
363+
):
359364
return await _make_regular_generator_fixture(_fixtureinfo, fixture, item)
360365

361366
elif inspect.isfunction(fixture.func):

pytest_asyncio_cooperative/plugin.py

Lines changed: 109 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,18 @@ def pytest_addoption(parser):
2727
default=100,
2828
)
2929

30+
parser.addoption(
31+
"--max-asyncio-tasks-by-mark",
32+
action="store",
33+
default=None,
34+
help="asyncio: maximum number of tasks to run concurrently for each mark (mark=int pairs)",
35+
)
36+
parser.addini(
37+
"max_asyncio_tasks_by_mark",
38+
"asyncio: asyncio: maximum number of tasks to run concurrently for each mark (mark=int pairs)",
39+
default=None,
40+
)
41+
3042
parser.addoption(
3143
"--asyncio-task-timeout",
3244
action="store",
@@ -171,15 +183,77 @@ def item_to_task(item):
171183
raise NotCoroutine
172184

173185

186+
def parse_max_tasks_by_mark(config, debug_prefix):
187+
if not config:
188+
return {}
189+
190+
result = {}
191+
pairs = config.split(",")
192+
for pair in pairs:
193+
columns = pair.split("=")
194+
if len(columns) > 2:
195+
assert False, f"`{debug_prefix}`: too many `=` in `{pair}`"
196+
try:
197+
max_tasks = int(columns[1])
198+
except ValueError:
199+
assert False, f"`{debug_prefix}`: expected integer in `{pair}`"
200+
201+
mark = columns[0]
202+
result[mark] = max_tasks
203+
204+
return result
205+
206+
207+
def item_would_exceed_max_marks(item, active_marks, max_tasks_by_mark):
208+
for mark in item.own_markers:
209+
if (
210+
mark.name in active_marks
211+
and active_marks[mark.name] >= max_tasks_by_mark[mark.name]
212+
):
213+
return True
214+
return False
215+
216+
217+
def item_add_active_marks(item, delta, active_marks):
218+
for mark in item.own_markers:
219+
if mark.name in active_marks:
220+
active_marks[mark.name] += delta
221+
222+
223+
def get_coro(task):
224+
if sys_version_info >= (3, 8):
225+
return task.get_coro()
226+
else:
227+
return task._coro
228+
229+
230+
def get_item_by_coro(task, item_by_coro):
231+
if isinstance(task, asyncio.Task):
232+
return item_by_coro[get_coro(task)]
233+
else:
234+
return item_by_coro[task]
235+
236+
174237
def _run_test_loop(tasks, session, run_tests):
175238
max_tasks = int(
176239
session.config.getoption("--max-asyncio-tasks")
177240
or session.config.getini("max_asyncio_tasks")
178241
)
179242

243+
if session.config.getoption("--max-asyncio-tasks-by-mark"):
244+
max_tasks_by_mark = parse_max_tasks_by_mark(
245+
session.config.getoption("--max-asyncio-tasks-by-mark"),
246+
"--max-asyncio-tasks-by-mark",
247+
)
248+
else:
249+
max_tasks_by_mark = parse_max_tasks_by_mark(
250+
session.config.getini("max_asyncio_tasks_by_mark"),
251+
"max_asyncio_tasks_by_mark",
252+
)
253+
180254
loop = asyncio.new_event_loop()
181255
try:
182-
task = run_tests(tasks, int(max_tasks))
256+
task = run_tests(tasks, int(max_tasks), max_tasks_by_mark)
183257
loop.run_until_complete(task)
184258
finally:
185259
loop.close()
@@ -234,15 +308,35 @@ def pytest_runtestloop(session):
234308
else:
235309
regular_items.append(item)
236310

237-
def get_coro(task):
238-
if sys_version_info >= (3, 8):
239-
return task.get_coro()
240-
else:
241-
return task._coro
311+
async def run_tests(tasks, max_tasks: int, max_tasks_by_mark):
312+
sidelined_tasks = tasks
313+
tasks = []
314+
active_marks = {mark: 0 for mark, max_tasks in max_tasks_by_mark.items()}
315+
316+
def enqueue_tasks():
317+
while len(tasks) < max_tasks:
318+
remove_index = None
319+
320+
for index, task in enumerate(sidelined_tasks):
321+
item = get_item_by_coro(task, item_by_coro)
322+
if not item_would_exceed_max_marks(
323+
item, active_marks, max_tasks_by_mark
324+
):
325+
item_add_active_marks(item, 1, active_marks)
326+
tasks.append(task)
327+
remove_index = index
328+
break
329+
330+
if remove_index == None:
331+
# No available tasks were found, give up control.
332+
break
333+
334+
# Removing element from start/middle of the list is actually
335+
# quite fast compared to iterating to the end of the list in
336+
# the above loop.
337+
sidelined_tasks.pop(remove_index)
242338

243-
async def run_tests(tasks, max_tasks: int):
244-
sidelined_tasks = tasks[max_tasks:]
245-
tasks = tasks[:max_tasks]
339+
enqueue_tasks()
246340

247341
task_timeout = int(
248342
session.config.getoption("--asyncio-task-timeout")
@@ -260,17 +354,16 @@ async def run_tests(tasks, max_tasks: int):
260354
# Mark when the task was started
261355
earliest_enqueue_time = time.time()
262356
for task in tasks:
263-
if isinstance(task, asyncio.Task):
264-
item = item_by_coro[get_coro(task)]
265-
else:
266-
item = item_by_coro[task]
357+
item = get_item_by_coro(task, item_by_coro)
267358
if not hasattr(item, "enqueue_time"):
268359
item.enqueue_time = time.time()
269360
earliest_enqueue_time = min(item.enqueue_time, earliest_enqueue_time)
270361

271362
time_to_wait = (time.time() - earliest_enqueue_time) - task_timeout
272363
done, pending = await asyncio.wait(
273-
tasks, return_when=asyncio.FIRST_COMPLETED, timeout=min(30, int(time_to_wait))
364+
tasks,
365+
return_when=asyncio.FIRST_COMPLETED,
366+
timeout=min(30, int(time_to_wait)),
274367
)
275368

276369
# Cancel tasks that have taken too long
@@ -291,6 +384,7 @@ async def run_tests(tasks, max_tasks: int):
291384

292385
for result in done:
293386
item = item_by_coro[get_coro(result)]
387+
item_add_active_marks(item, -1, active_marks)
294388

295389
# Flakey tests will be run again if they failed
296390
# TODO: add retry count
@@ -315,9 +409,7 @@ async def run_tests(tasks, max_tasks: int):
315409

316410
completed.append(result)
317411

318-
if sidelined_tasks:
319-
if len(tasks) < max_tasks:
320-
tasks.append(sidelined_tasks.pop(0))
412+
enqueue_tasks()
321413

322414
return completed
323415

tests/test_class_based.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
def test_class_based_tests(testdir):
2-
testdir.makepyfile("""
2+
testdir.makepyfile(
3+
"""
34
import pytest
45
56
@@ -8,15 +9,17 @@ class TestSuite:
89
async def test_cooperative(self):
910
assert True
1011
11-
""")
12+
"""
13+
)
1214

1315
result = testdir.runpytest()
1416

1517
result.assert_outcomes(passed=1)
1618

1719

1820
def test_class_based_tests_with_fixture(testdir):
19-
testdir.makepyfile("""
21+
testdir.makepyfile(
22+
"""
2023
import pytest
2124
2225
@@ -29,7 +32,8 @@ async def test_fixture(self):
2932
async def test_cooperative(self, test_fixture):
3033
assert test_fixture == "test_fixture"
3134
32-
""")
35+
"""
36+
)
3337

3438
result = testdir.runpytest()
3539

tests/test_fail.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,15 @@ def test_a():
3030
result.assert_outcomes(failed=1)
3131

3232

33-
@pytest.mark.parametrize("dur1, dur2, expectedfails, expectedpasses", [
34-
(1.1, 2, 2, 0),
35-
(2, 2, 2, 0),
36-
])
33+
@pytest.mark.parametrize(
34+
"dur1, dur2, expectedfails, expectedpasses",
35+
[
36+
(1.1, 2, 2, 0),
37+
(2, 2, 2, 0),
38+
],
39+
)
3740
def test_function_takes_too_long(testdir, dur1, dur2, expectedfails, expectedpasses):
38-
testdir.makeconftest(
39-
"""""")
41+
testdir.makeconftest("""""")
4042

4143
testdir.makepyfile(
4244
"""
@@ -51,7 +53,9 @@ async def test_a():
5153
@pytest.mark.asyncio_cooperative
5254
async def test_b():
5355
await asyncio.sleep({})
54-
""".format(dur1, dur2)
56+
""".format(
57+
dur1, dur2
58+
)
5559
)
5660

5761
result = testdir.runpytest("--asyncio-task-timeout", "1")

tests/test_fixture.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,9 @@ async def test_ordering(ghi, _def, abc):
364364
result.assert_outcomes(passed=1)
365365

366366

367-
def test_ordering_of_fixtures_based_off_function_arguments_with_session_fixture(testdir):
367+
def test_ordering_of_fixtures_based_off_function_arguments_with_session_fixture(
368+
testdir,
369+
):
368370
testdir.makepyfile(
369371
"""
370372
import asyncio

0 commit comments

Comments
 (0)