@@ -27,6 +27,32 @@ def pytest_addoption(parser):
2727 default = 100 ,
2828 )
2929
30+ parser .addoption (
31+ "--max-asyncio-tasks-by-mark" ,
32+ action = "store" ,
33+ default = None ,
34+ help = "asyncio: maximum number of tasks to run concurrently for each mark (mark1,mark2,..=int pairs, space separated)" ,
35+ )
36+ parser .addini (
37+ "max_asyncio_tasks_by_mark" ,
38+ "asyncio: asyncio: maximum number of tasks to run concurrently for each mark (mark1,mark2,..=int pairs, space separated)" ,
39+ default = None ,
40+ )
41+
42+ parser .addoption (
43+ "--max-asyncio-tasks-by-mark-remainder" ,
44+ action = "store" ,
45+ default = 0 ,
46+ help = "asyncio: maximum number of tasks to run concurrently for tasks that didn't match any "
47+ "marks in `--max-asyncio-tasks-by-mark` (default 0, unlimited)" ,
48+ )
49+ parser .addini (
50+ "max_asyncio_tasks_by_mark_remainder" ,
51+ help = "asyncio: maximum number of tasks to run concurrently for tasks that didn't match any "
52+ "marks in `--max-asyncio-tasks-by-mark` (default 0, unlimited)" ,
53+ default = 0 ,
54+ )
55+
3056 parser .addoption (
3157 "--asyncio-task-timeout" ,
3258 action = "store" ,
@@ -171,15 +197,128 @@ def item_to_task(item):
171197 raise NotCoroutine
172198
173199
200+ class MarkLimits :
201+ # To support multiple marks in the same group separate by comma, a shared group object
202+ # is created and inserted into the groups dictionary multiple times for each mark in
203+ # the group. When a test has multiple marks that belong to the same group, each mark
204+ # will be added multiple times to the set, but the result will be that the number
205+ # of active items in the group will only increase by 1. A simple ref count is not enough
206+ # in this situation, because it would overcount how many items are actively running.
207+ class Group :
208+ def __init__ (self , max ):
209+ self .max = max
210+ self .items = set ()
211+
212+ def __init__ (self ):
213+ self .groups = {}
214+ self .remainder = 0
215+ self .remainder_max = 0
216+
217+ def update_config (self , session ):
218+ if session .config .getoption ("--max-asyncio-tasks-by-mark" ):
219+ self .groups = MarkLimits .parse_max_tasks_by_mark (
220+ session .config .getoption ("--max-asyncio-tasks-by-mark" ),
221+ "--max-asyncio-tasks-by-mark" ,
222+ )
223+ else :
224+ self .groups = MarkLimits .parse_max_tasks_by_mark (
225+ session .config .getini ("max_asyncio_tasks_by_mark" ),
226+ "max_asyncio_tasks_by_mark" ,
227+ )
228+
229+ self .remainder_max = int (
230+ session .config .getoption ("--max-asyncio-tasks-by-mark-remainder" )
231+ or session .config .getini ("max_asyncio_tasks_by_mark_remainder" )
232+ )
233+
234+ @staticmethod
235+ def parse_max_tasks_by_mark (config_value , debug_prefix ):
236+ if not config_value :
237+ return {}
238+
239+ result = {}
240+ pairs = config_value .split ()
241+ for pair in pairs :
242+ columns = pair .split ("=" )
243+ if len (columns ) > 2 :
244+ assert False , f"`{ debug_prefix } `: too many `=` in `{ pair } `"
245+
246+ try :
247+ max_tasks = int (columns [1 ])
248+ except ValueError :
249+ assert False , f"`{ debug_prefix } `: expected integer in `{ pair } `"
250+
251+ group = MarkLimits .Group (max_tasks )
252+ for mark in columns [0 ].split ("," ):
253+ assert (
254+ mark not in result
255+ ), f"`{ debug_prefix } `: multiple occurences of mark `{ mark } `"
256+ result [mark ] = group
257+
258+ return result
259+
260+ def would_exceed_max_marks (self , item ):
261+ for mark in item .own_markers :
262+ group = self .groups .get (mark .name )
263+ if group and len (group .items ) >= group .max :
264+ return True
265+
266+ if self .remainder_max and self .remainder >= self .remainder_max :
267+ return True
268+
269+ return False
270+
271+ def update_active_marks (self , item , add ):
272+ matched_mark = False
273+
274+ for mark in item .own_markers :
275+ group = self .groups .get (mark .name )
276+ if group :
277+ # Not returning here, because each mark must be accounted for,
278+ # not just the first matching one.
279+ matched_mark = True
280+ if add :
281+ group .items .add (item )
282+ elif item in group .items :
283+ group .items .remove (item )
284+
285+ # Make sure not to overcount the remainder
286+ if matched_mark :
287+ return
288+
289+ if self .remainder_max :
290+ if add :
291+ self .remainder += 1
292+ else :
293+ self .remainder -= 1
294+
295+
296+ def get_coro (task ):
297+ if sys_version_info >= (3 , 8 ):
298+ return task .get_coro ()
299+ else :
300+ return task ._coro
301+
302+
303+ def get_item_by_coro (task , item_by_coro ):
304+ if isinstance (task , asyncio .Task ):
305+ return item_by_coro [get_coro (task )]
306+ else :
307+ return item_by_coro [task ]
308+
309+
174310def _run_test_loop (tasks , session , run_tests ):
175311 max_tasks = int (
176312 session .config .getoption ("--max-asyncio-tasks" )
177313 or session .config .getini ("max_asyncio_tasks" )
178314 )
179315
316+ mark_limits = MarkLimits ()
317+ mark_limits .update_config (session )
318+
180319 loop = asyncio .new_event_loop ()
181320 try :
182- task = run_tests (tasks , int (max_tasks ))
321+ task = run_tests (tasks , int (max_tasks ), mark_limits )
183322 loop .run_until_complete (task )
184323 finally :
185324 loop .close ()
@@ -234,15 +373,32 @@ def pytest_runtestloop(session):
234373 else :
235374 regular_items .append (item )
236375
237- def get_coro (task ):
238- if sys_version_info >= (3 , 8 ):
239- return task .get_coro ()
240- else :
241- return task ._coro
376+ async def run_tests (tasks , max_tasks : int , mark_limits ):
377+ sidelined_tasks = tasks
378+ tasks = []
379+
380+ def enqueue_tasks ():
381+ while len (tasks ) < max_tasks :
382+ remove_index = None
383+
384+ for index , task in enumerate (sidelined_tasks ):
385+ item = get_item_by_coro (task , item_by_coro )
386+ if not mark_limits .would_exceed_max_marks (item ):
387+ mark_limits .update_active_marks (item , add = True )
388+ tasks .append (task )
389+ remove_index = index
390+ break
391+
392+ if remove_index == None :
393+ # No available tasks were found, give up control.
394+ break
395+
396+ # Removing element from start/middle of the list is actually
397+ # quite fast compared to iterating to the end of the list in
398+ # the above loop.
399+ sidelined_tasks .pop (remove_index )
242400
243- async def run_tests (tasks , max_tasks : int ):
244- sidelined_tasks = tasks [max_tasks :]
245- tasks = tasks [:max_tasks ]
401+ enqueue_tasks ()
246402
247403 task_timeout = int (
248404 session .config .getoption ("--asyncio-task-timeout" )
@@ -260,10 +416,7 @@ async def run_tests(tasks, max_tasks: int):
260416 # Mark when the task was started
261417 earliest_enqueue_time = time .time ()
262418 for task in tasks :
263- if isinstance (task , asyncio .Task ):
264- item = item_by_coro [get_coro (task )]
265- else :
266- item = item_by_coro [task ]
419+ item = get_item_by_coro (task , item_by_coro )
267420 if not hasattr (item , "enqueue_time" ):
268421 item .enqueue_time = time .time ()
269422 earliest_enqueue_time = min (item .enqueue_time , earliest_enqueue_time )
@@ -293,6 +446,7 @@ async def run_tests(tasks, max_tasks: int):
293446
294447 for result in done :
295448 item = item_by_coro [get_coro (result )]
449+ mark_limits .update_active_marks (item , add = False )
296450
297451 # Flakey tests will be run again if they failed
298452 # TODO: add retry count
@@ -317,9 +471,7 @@ async def run_tests(tasks, max_tasks: int):
317471
318472 completed .append (result )
319473
320- if sidelined_tasks :
321- if len (tasks ) < max_tasks :
322- tasks .append (sidelined_tasks .pop (0 ))
474+ enqueue_tasks ()
323475
324476 return completed
325477
0 commit comments