Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 8 additions & 21 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from llm_cache import (clear_llm_caches, get_cached_api_server,
print_llm_cache_info)
from llm_cache_util import SortKey, sort_tests_for_llm_caching
from spyre_util import (get_spyre_backend_list, get_spyre_model_list,
from spyre_util import (get_spyre_backend_list, get_spyre_cb_param,
get_spyre_model_list, get_spyre_tp_size,
skip_unsupported_tp_size)
from vllm.connections import global_http_connection
from vllm.distributed import cleanup_dist_env_and_memory
Expand Down Expand Up @@ -54,6 +55,10 @@ def pytest_generate_tests(metafunc):
metafunc,
existing_markers,
)
_add_param("tp_size", get_spyre_tp_size(num_tp=4), metafunc,
existing_markers)
_add_param("cb", get_spyre_cb_param(use_cb=True), metafunc,
existing_markers)
_add_param(
"warmup_shapes",
[[(1024, 20, 4)]],
Expand All @@ -69,6 +74,8 @@ def pytest_generate_tests(metafunc):
metafunc,
existing_markers,
)
_add_param("tp_size", get_spyre_tp_size(), metafunc, existing_markers)
_add_param("cb", get_spyre_cb_param(), metafunc, existing_markers)
_add_param(
"warmup_shapes",
default_warmup_shape,
Expand All @@ -91,26 +98,6 @@ def pytest_generate_tests(metafunc):
existing_markers,
)

# TODO: add both these using _add_param too
# Will need to do some fancy stuff to add custom
# markers
if "cb" in metafunc.fixturenames and "cb" not in existing_markers:
metafunc.parametrize(
"cb", [pytest.param(1, marks=pytest.mark.cb, id="cb"), 0])

if "tp_size" in metafunc.fixturenames and \
"tp_size" not in existing_markers:
metafunc.parametrize(
"tp_size",
[
pytest.param(1),
pytest.param(2, marks=pytest.mark.multi),
pytest.param(4, marks=pytest.mark.multi),
pytest.param(8, marks=pytest.mark.multi),
],
ids=lambda val: f"TP({val})",
)


def _add_param(param_name: str, param_value, metafunc,
existing_markers) -> None:
Expand Down
19 changes: 19 additions & 0 deletions tests/spyre_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,25 @@ def get_spyre_backend_list():
return backends


def get_spyre_cb_param(use_cb=False):
"""Returns a list of pytest.params with cb set to True or both"""
if use_cb:
return [pytest.param(1, marks=pytest.mark.cb, id="cb")]
return [pytest.param(1, marks=pytest.mark.cb, id="cb"), 0]


def get_spyre_tp_size(num_tp=None):
"""Returns a list of pytest.params with one or multiple tp sizes"""
if num_tp and num_tp > 1:
return [pytest.param(num_tp, marks=pytest.mark.multi)]
return [
pytest.param(1),
pytest.param(2, marks=pytest.mark.multi),
pytest.param(4, marks=pytest.mark.multi),
pytest.param(8, marks=pytest.mark.multi),
]


# get model names from env, if not set then use default models for each type.
# Multiple models can be specified with a comma separated list in
# VLLM_SPYRE_TEST_MODEL_LIST
Expand Down