From 0307c650973f832b61df0faabc141905741883d3 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Tue, 14 Oct 2025 10:27:56 -0700 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20use=20defaults=20for=20all?= =?UTF-8?q?=20params=20for=20full=5Fmodel=20marker?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- tests/conftest.py | 29 ++++++++--------------------- tests/spyre_util.py | 19 +++++++++++++++++++ 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 97308163..1863f952 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,8 @@ from llm_cache import (clear_llm_caches, get_cached_api_server, print_llm_cache_info) from llm_cache_util import SortKey, sort_tests_for_llm_caching -from spyre_util import (get_spyre_backend_list, get_spyre_model_list, +from spyre_util import (get_spyre_backend_list, get_spyre_cb_param, + get_spyre_model_list, get_spyre_tp_size, skip_unsupported_tp_size) from vllm.connections import global_http_connection from vllm.distributed import cleanup_dist_env_and_memory @@ -54,6 +55,10 @@ def pytest_generate_tests(metafunc): metafunc, existing_markers, ) + _add_param("tp_size", get_spyre_tp_size(num_tp=4), metafunc, + existing_markers) + _add_param("cb", get_spyre_cb_param(use_cb=True), metafunc, + existing_markers) _add_param( "warmup_shapes", [[(1024, 20, 4)]], @@ -69,6 +74,8 @@ def pytest_generate_tests(metafunc): metafunc, existing_markers, ) + _add_param("tp_size", get_spyre_tp_size(), metafunc, existing_markers) + _add_param("cb", get_spyre_cb_param(), metafunc, existing_markers) _add_param( "warmup_shapes", default_warmup_shape, @@ -91,26 +98,6 @@ def pytest_generate_tests(metafunc): existing_markers, ) - # TODO: add both these using _add_param too - # Will need to do some fancy stuff to add custom - # markers - if "cb" in metafunc.fixturenames and "cb" not in existing_markers: - metafunc.parametrize( - "cb", [pytest.param(1, marks=pytest.mark.cb, id="cb"), 0]) - - if "tp_size" in metafunc.fixturenames and \ - "tp_size" not in existing_markers: - metafunc.parametrize( - "tp_size", - [ - pytest.param(1), - pytest.param(2, marks=pytest.mark.multi), - pytest.param(4, marks=pytest.mark.multi), - pytest.param(8, marks=pytest.mark.multi), - ], - ids=lambda val: f"TP({val})", - ) - def _add_param(param_name: str, param_value, metafunc, existing_markers) -> None: diff --git a/tests/spyre_util.py b/tests/spyre_util.py index 788c5670..8dbbe762 100644 --- a/tests/spyre_util.py +++ b/tests/spyre_util.py @@ -213,6 +213,25 @@ def get_spyre_backend_list(): return backends +def get_spyre_cb_param(use_cb=False): + """Returns a list of pytest.params with cb set to True or both""" + if use_cb: + return [pytest.param(1, marks=pytest.mark.cb, id="cb")] + return [pytest.param(1, marks=pytest.mark.cb, id="cb"), 0] + + +def get_spyre_tp_size(num_tp=None): + """Returns a list of pytest.params with one or multiple tp sizes""" + if num_tp and num_tp > 1: + return [pytest.param(num_tp, marks=pytest.mark.multi)] + return [ + pytest.param(1), + pytest.param(2, marks=pytest.mark.multi), + pytest.param(4, marks=pytest.mark.multi), + pytest.param(8, marks=pytest.mark.multi), + ] + + # get model names from env, if not set then use default models for each type. # Multiple models can be specified with a comma separated list in # VLLM_SPYRE_TEST_MODEL_LIST