
Commit 7671d57

Make sglang tests use model_management.py and server_management.py (#878)
This is a continuation of #728. Next step: make the HF download path share the hf_datasets download code in sharktank and make sure caching works.
Parent: 1e06fd8

File tree: 10 files changed, +194 -348 lines

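For orientation, here is a minimal sketch of the workflow the migrated tests now follow, assembled from the ModelConfig, ModelProcessor, ServerConfig, and ServerInstance usages in the diffs below. The repo id, file names, and device settings are placeholders, and the assumption that process_model returns the ModelArtifacts accepted by ServerConfig is inferred from the shortfin test, not stated explicitly in this commit.

from pathlib import Path

from integration_tests.llm.model_management import (
    ModelConfig,
    ModelProcessor,
    ModelSource,
)
from integration_tests.llm.server_management import ServerConfig, ServerInstance

# Placeholder inputs; the real tests get these from pytest fixtures/parameters.
tmp_dir = Path("/tmp/llm_test_artifacts")
device_settings = None  # the tests pass a DeviceSettings object here
repo = "example-org/example-model"  # hypothetical HuggingFace repo id

config = ModelConfig(
    model_file="model.gguf",  # placeholder weights file name
    tokenizer_id=repo,
    batch_sizes=(1, 4),
    device_settings=device_settings,
    source=ModelSource.HUGGINGFACE,
    repo_id=repo,
)

processor = ModelProcessor(tmp_dir)
artifacts = processor.process_model(config)  # download, export, compile

# Assumption: process_model returns the ModelArtifacts that ServerConfig expects.
server = ServerInstance(ServerConfig(artifacts=artifacts, device_settings=device_settings))
server.start()
try:
    base_url = f"http://localhost:{server.port}"
    # ... run a benchmark or send requests against base_url ...
finally:
    server.stop()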
app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py (+17, -15)

@@ -14,13 +14,12 @@
 sys.path.append(
     os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
 )
-from integration_tests.llm.utils import (
-    compile_model,
-    end_log_group,
-    export_paged_llm_v1,
-    download_with_hf_datasets,
-    start_log_group,
+from integration_tests.llm.model_management import (
+    ModelConfig,
+    ModelProcessor,
+    ModelSource,
 )
+from integration_tests.llm.logging_utils import start_log_group, end_log_group

 logger = logging.getLogger(__name__)

@@ -47,16 +46,19 @@ def pre_process_model(request, tmp_path_factory):
     settings = request.param["settings"]
     batch_sizes = request.param["batch_sizes"]

-    mlir_path = tmp_dir / "model.mlir"
-    config_path = tmp_dir / "config.json"
-    vmfb_path = tmp_dir / "model.vmfb"
-
-    model_path = tmp_dir / model_param_file_name
-    download_with_hf_datasets(tmp_dir, model_name)
-
-    export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes)
+    # Configure model
+    config = ModelConfig(
+        model_file=model_param_file_name,
+        tokenizer_id=model_name,  # Using model_name as tokenizer_id, adjust if needed
+        batch_sizes=batch_sizes,
+        device_settings=settings,
+        source=ModelSource.HUGGINGFACE,
+        repo_id=model_name,  # Using model_name as repo_id, adjust if needed
+    )

-    compile_model(mlir_path, vmfb_path, settings)
+    # Process model through all stages
+    processor = ModelProcessor(tmp_dir)
+    artifacts = processor.process_model(config)

     logger.info("Model artifacts setup successfully" + end_log_group())
     MODEL_DIR_CACHE[param_key] = tmp_dir

app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py (+25, -7)

@@ -13,8 +13,12 @@
 from sglang import bench_serving

 from .utils import SGLangBenchmarkArgs, log_jsonl_result
-
-from integration_tests.llm.utils import download_tokenizer, wait_for_server
+from integration_tests.llm.model_management import (
+    ModelConfig,
+    ModelProcessor,
+    ModelSource,
+)
+from integration_tests.llm.server_management import ServerInstance

 logger = logging.getLogger(__name__)

@@ -26,17 +30,31 @@
 def test_sglang_benchmark(request_rate, tokenizer_id, sglang_args, tmp_path_factory):
     tmp_dir = tmp_path_factory.mktemp("sglang_benchmark_test")

-    # Download tokenizer for llama3_8B_fp16
-    download_tokenizer(tmp_dir, tokenizer_id)
+    # Download tokenizer using ModelProcessor
+    config = ModelConfig(
+        model_file="tokenizer.json",  # Only need tokenizer
+        tokenizer_id=tokenizer_id,
+        batch_sizes=(1,),  # Not relevant for tokenizer only
+        device_settings=None,  # Not relevant for tokenizer only
+        source=ModelSource.HUGGINGFACE,
+        repo_id=tokenizer_id,
+    )
+    processor = ModelProcessor(tmp_dir)
+    artifacts = processor.process_model(config)

     logger.info("Beginning SGLang benchmark test...")

     port = sglang_args
     base_url = f"http://localhost:{port}"

-    # Setting a high timeout gives enough time for downloading model artifacts
-    # and starting up server... Takes a little longer than shortfin.
-    wait_for_server(base_url, timeout=600)
+    # Wait for server using ServerInstance's method
+    server = ServerInstance(
+        None
+    )  # We don't need config since we're just using wait_for_ready
+    server.port = int(port)  # Set port manually since we didn't start the server
+    server.wait_for_ready(
+        timeout=600
+    )  # High timeout for model artifacts download and server startup

     benchmark_args = SGLangBenchmarkArgs(
         backend="sglang",

app_tests/benchmark_tests/llm/sglang_benchmarks/shortfin_benchmark_test.py (+16, -16)

@@ -19,12 +19,9 @@
     log_jsonl_result,
 )

-from integration_tests.llm.utils import (
-    end_log_group,
-    find_available_port,
-    start_llm_server,
-    start_log_group,
-)
+from integration_tests.llm.logging_utils import end_log_group, start_log_group
+from integration_tests.llm.server_management import ServerConfig, ServerInstance
+from integration_tests.llm.model_management import ModelArtifacts

 logger = logging.getLogger(__name__)

@@ -83,20 +80,24 @@ def test_shortfin_benchmark(
     model_path = tmp_dir / model_param_file_name

     # Start shortfin llm server
-    server_process, port = start_llm_server(
-        tokenizer_path,
-        config_path,
-        vmfb_path,
-        model_path,
-        device_settings,
-        timeout=30,
+    server_config = ServerConfig(
+        artifacts=ModelArtifacts(
+            weights_path=model_path,
+            tokenizer_path=tokenizer_path,
+            mlir_path=tmp_dir / "model.mlir",
+            vmfb_path=vmfb_path,
+            config_path=config_path,
+        ),
+        device_settings=device_settings,
     )
+    server = ServerInstance(server_config)
+    server.start()

     # Run and collect SGLang Serving Benchmark
     benchmark_args = SGLangBenchmarkArgs(
         backend="shortfin",
         num_prompt=10,
-        base_url=f"http://localhost:{port}",
+        base_url=f"http://localhost:{server.port}",
         tokenizer=tmp_dir,
         request_rate=request_rate,
     )
@@ -130,5 +131,4 @@ def test_shortfin_benchmark(
     except Exception as e:
         logger.error(e)

-    server_process.terminate()
-    server_process.wait()
+    server.stop()

app_tests/integration_tests/llm/logging_utils.py (new file, +21)

@@ -0,0 +1,21 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import os
+
+
+def start_log_group(headline):
+    """Start a collapsible log group in GitHub Actions."""
+    if os.environ.get("GITHUB_ACTIONS") == "true":
+        return f"\n::group::{headline}"
+    return ""
+
+
+def end_log_group():
+    """End a collapsible log group in GitHub Actions."""
+    if os.environ.get("GITHUB_ACTIONS") == "true":
+        return "\n::endgroup::"
+    return ""

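For reference, the tests use these helpers by concatenating the markers onto ordinary log messages; a small usage sketch (the logger name and headline are illustrative):

import logging

from integration_tests.llm.logging_utils import end_log_group, start_log_group

logger = logging.getLogger(__name__)

# Outside of GitHub Actions both helpers return "", so the markers are a no-op locally.
logger.info("Setting up model artifacts" + start_log_group("Model setup"))
# ... export / compile work happens here ...
logger.info("Model artifacts setup successfully" + end_log_group())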
app_tests/integration_tests/llm/model_management.py (+63, -12)

@@ -3,12 +3,20 @@
 from pathlib import Path
 import subprocess
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Dict
 from enum import Enum, auto

+from sharktank.utils.hf_datasets import Dataset, RemoteFile, get_dataset
+
 logger = logging.getLogger(__name__)


+class AccuracyValidationException(RuntimeError):
+    """Exception raised when accuracy validation fails."""
+
+    pass
+
+
 class ModelSource(Enum):
     HUGGINGFACE = auto()
     LOCAL = auto()

@@ -34,13 +42,17 @@ class ModelConfig:
     batch_sizes: Tuple[int, ...]
     device_settings: "DeviceSettings"
     source: ModelSource
+    dataset_name: Optional[str] = None  # Name of the dataset in hf_datasets.py
     repo_id: Optional[str] = None
     local_path: Optional[Path] = None
     azure_config: Optional[AzureConfig] = None

     def __post_init__(self):
-        if self.source == ModelSource.HUGGINGFACE and not self.repo_id:
-            raise ValueError("repo_id required for HuggingFace models")
+        if self.source == ModelSource.HUGGINGFACE:
+            if not (self.dataset_name or self.repo_id):
+                raise ValueError(
+                    "Either dataset_name or repo_id required for HuggingFace models"
+                )
         elif self.source == ModelSource.LOCAL and not self.local_path:
             raise ValueError("local_path required for local models")
         elif self.source == ModelSource.AZURE and not self.azure_config:

@@ -70,6 +82,8 @@ def __init__(self, base_dir: Path, config: ModelConfig):
     def _get_model_dir(self) -> Path:
         """Creates and returns appropriate model directory based on source."""
         if self.config.source == ModelSource.HUGGINGFACE:
+            if self.config.dataset_name:
+                return self.base_dir / self.config.dataset_name.replace("/", "_")
             return self.base_dir / self.config.repo_id.replace("/", "_")
         elif self.config.source == ModelSource.LOCAL:
             return self.base_dir / "local" / self.config.local_path.stem

@@ -82,15 +96,36 @@ def _get_model_dir(self) -> Path:
             raise ValueError(f"Unsupported model source: {self.config.source}")

     def _download_from_huggingface(self) -> Path:
-        """Downloads model from HuggingFace."""
+        """Downloads model from HuggingFace using hf_datasets.py."""
         model_path = self.model_dir / self.config.model_file
         if not model_path.exists():
-            logger.info(f"Downloading model {self.config.repo_id} from HuggingFace")
-            subprocess.run(
-                f"huggingface-cli download --local-dir {self.model_dir} {self.config.repo_id} {self.config.model_file}",
-                shell=True,
-                check=True,
-            )
+            if self.config.dataset_name:
+                logger.info(
+                    f"Downloading model {self.config.dataset_name} using hf_datasets"
+                )
+                dataset = get_dataset(self.config.dataset_name)
+                downloaded_files = dataset.download(local_dir=self.model_dir)
+
+                # Find the model file in downloaded files
+                for file_id, paths in downloaded_files.items():
+                    for path in paths:
+                        if path.name == self.config.model_file:
+                            return path
+
+                raise ValueError(
+                    f"Model file {self.config.model_file} not found in dataset {self.config.dataset_name}"
+                )
+            else:
+                logger.info(f"Downloading model {self.config.repo_id} from HuggingFace")
+                # Create a temporary dataset for direct repo downloads
+                remote_file = RemoteFile(
+                    file_id="model",
+                    repo_id=self.config.repo_id,
+                    filename=self.config.model_file,
+                )
+                downloaded_paths = remote_file.download(local_dir=self.model_dir)
+                return downloaded_paths[0]
+
         return model_path

     def _copy_from_local(self) -> Path:

@@ -132,14 +167,30 @@ def _download_from_azure(self) -> Path:
         return model_path

     def prepare_tokenizer(self) -> Path:
-        """Downloads and prepares tokenizer."""
+        """Downloads and prepares tokenizer using hf_datasets.py when possible."""
         tokenizer_path = self.model_dir / "tokenizer.json"
+
         if not tokenizer_path.exists():
-            logger.info(f"Downloading tokenizer {self.config.tokenizer_id}")
+            # First try to get tokenizer from dataset if available
+            if self.config.dataset_name:
+                dataset = get_dataset(self.config.dataset_name)
+                downloaded_files = dataset.download(local_dir=self.model_dir)
+
+                # Look for tokenizer files in downloaded files
+                for file_id, paths in downloaded_files.items():
+                    for path in paths:
+                        if path.name == "tokenizer.json":
+                            return path
+
+            # Fall back to downloading from transformers if not found in dataset
+            logger.info(
+                f"Downloading tokenizer {self.config.tokenizer_id} using transformers"
+            )
             from transformers import AutoTokenizer

             tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_id)
             tokenizer.save_pretrained(self.model_dir)
+
         return tokenizer_path

     def export_model(self, weights_path: Path) -> Tuple[Path, Path]:

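The new HuggingFace handling above supports two routes: a dataset registered in sharktank's hf_datasets.py (dataset_name) or a direct repo download via RemoteFile (repo_id). A hedged sketch of constructing configs for both, with the dataset, repo, and file names as placeholders:

from pathlib import Path

from integration_tests.llm.model_management import ModelConfig, ModelProcessor, ModelSource

# Route 1: a dataset name registered in sharktank.utils.hf_datasets (placeholder name).
dataset_config = ModelConfig(
    model_file="model.gguf",
    tokenizer_id="example-org/example-model",
    batch_sizes=(1,),
    device_settings=None,  # a real DeviceSettings object is needed for export/compile
    source=ModelSource.HUGGINGFACE,
    dataset_name="example_dataset",
)

# Route 2: a direct repo download; __post_init__ accepts either dataset_name or repo_id.
repo_config = ModelConfig(
    model_file="model.gguf",
    tokenizer_id="example-org/example-model",
    batch_sizes=(1,),
    device_settings=None,  # placeholder, as above
    source=ModelSource.HUGGINGFACE,
    repo_id="example-org/example-model",
)

# Either config goes through the same entry point; _download_from_huggingface picks the route.
artifacts = ModelProcessor(Path("/tmp/models")).process_model(dataset_config)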
app_tests/integration_tests/llm/server_management.py (+1, -1)

@@ -81,7 +81,7 @@ def start(self) -> None:
         self.process = subprocess.Popen(cmd)
         self.wait_for_ready()

-    def wait_for_ready(self, timeout: int = 10) -> None:
+    def wait_for_ready(self, timeout: int = 30) -> None:
         """Waits for server to be ready and responding to health checks."""
         if self.port is None:
             raise RuntimeError("Server hasn't been started")

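As the sglang benchmark test above shows, wait_for_ready can also be pointed at a server the test did not start, by setting the port manually. A minimal sketch of that pattern (the port value is illustrative):

from integration_tests.llm.server_management import ServerInstance

# Attach to an externally started server and only health-check it,
# mirroring the sglang benchmark test in this commit.
server = ServerInstance(None)  # no ServerConfig needed when only waiting for readiness
server.port = 30000  # illustrative port of the externally started server
server.wait_for_ready(timeout=600)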