
Commit 533d2a1

rkooo567 and SangBin Cho authored
[Typing] Mypy typing part 2 (vllm-project#4043)
Co-authored-by: SangBin Cho <[email protected]>
1 parent a532225 commit 533d2a1

20 files changed  (+180 −126 lines changed)

Diff for: .github/workflows/mypy.yaml  (+4 −4)

@@ -41,10 +41,10 @@ jobs:
         mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
         mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml

+        mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
         # TODO(sang): Follow up
-        # mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
-        # mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
-        # mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
-        # mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
         # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml

Diff for: format.sh  (+4 −4)

@@ -104,10 +104,10 @@ mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
 mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml

 # TODO(sang): Follow up
-# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
-# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
-# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
-# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
 # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
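
Note on the two config changes above: both the CI workflow and format.sh now type-check vllm/engine, vllm/worker, vllm/spec_decode, and vllm/model_executor with --follow-imports=skip against pyproject.toml, so each package is checked in isolation (the spec_decoding path in the old commented-out lines was a typo). A minimal sketch of driving the same per-package checks from Python through mypy's programmatic API; the helper and the package list are illustrative assumptions, not part of this commit:

# Hypothetical helper (not in vLLM): run the newly enabled per-package
# checks with the same flags the workflow and format.sh use.
# Assumes mypy is installed and the script runs from the repo root.
from mypy import api  # mypy.api.run returns (stdout, stderr, exit_status)

PACKAGES = ["vllm/engine", "vllm/worker", "vllm/spec_decode", "vllm/model_executor"]

for package in PACKAGES:
    stdout, stderr, status = api.run(
        [package, "--follow-imports=skip", "--config-file", "pyproject.toml"])
    print(f"{package}: {'ok' if status == 0 else 'errors'}")
    if status != 0:
        print(stdout)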

Diff for: vllm/engine/async_llm_engine.py  (+25 −19)

@@ -2,8 +2,8 @@
 import os
 import time
 from functools import partial
-from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
-                    Set, Tuple, Type, Union)
+from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List,
+                    Optional, Set, Tuple, Type, Union)

 from transformers import PreTrainedTokenizer

@@ -52,7 +52,7 @@ class AsyncStream:

     def __init__(self, request_id: str) -> None:
         self.request_id = request_id
-        self._queue = asyncio.Queue()
+        self._queue: asyncio.Queue = asyncio.Queue()
         self._finished = False

     def put(self, item: Union[RequestOutput, Exception]) -> None:
@@ -312,15 +312,17 @@ def __init__(self,
         self.max_log_len = max_log_len
         self.engine = self._init_engine(*args, **kwargs)

-        self.background_loop = None
+        self.background_loop: Optional[asyncio.Future] = None
         # We need to keep a reference to unshielded
         # task as well to prevent it from being garbage
         # collected
-        self._background_loop_unshielded = None
+        self._background_loop_unshielded: Optional[asyncio.Task[Any]] = None
         self.start_engine_loop = start_engine_loop
-        self._request_tracker: Optional[RequestTracker] = None
         self._errored_with: Optional[BaseException] = None

+        # Lazy initialized fields
+        self._request_tracker: RequestTracker
+
     @classmethod
     def from_engine_args(
         cls,
@@ -361,11 +363,13 @@ def from_engine_args(
     @property
     def is_running(self) -> bool:
         return (self.background_loop is not None
+                and self._background_loop_unshielded is not None
                 and not self._background_loop_unshielded.done())

     @property
     def is_stopped(self) -> bool:
-        return self.errored or (self.background_loop is not None
+        return self.errored or (self.background_loop is not None and
+                                self._background_loop_unshielded is not None
                                 and self._background_loop_unshielded.done())

     @property
@@ -381,7 +385,7 @@ def _error_callback(self, exc: Exception) -> None:

     async def get_tokenizer(self) -> "PreTrainedTokenizer":
         if self.engine_use_ray:
-            return await self.engine.get_tokenizer.remote()
+            return await self.engine.get_tokenizer.remote()  # type: ignore
         else:
             return self.engine.get_tokenizer()

@@ -434,7 +438,8 @@ async def engine_step(self) -> bool:
             # TODO: Maybe add add_request_batch to reduce Ray overhead
             try:
                 if self.engine_use_ray:
-                    await self.engine.add_request.remote(**new_request)
+                    await self.engine.add_request.remote(  # type: ignore
+                        **new_request)
                 else:
                     await self.engine.add_request_async(**new_request)
             except ValueError as e:
@@ -449,7 +454,7 @@ async def engine_step(self) -> bool:
             await self._engine_abort(finished_requests)

         if self.engine_use_ray:
-            request_outputs = await self.engine.step.remote()
+            request_outputs = await self.engine.step.remote()  # type: ignore
         else:
             request_outputs = await self.engine.step_async()

@@ -462,7 +467,7 @@ async def engine_step(self) -> bool:

     async def _engine_abort(self, request_ids: Iterable[str]):
         if self.engine_use_ray:
-            await self.engine.abort_request.remote(request_ids)
+            await self.engine.abort_request.remote(request_ids)  # type: ignore
         else:
             self.engine.abort_request(request_ids)

@@ -525,11 +530,12 @@ async def add_request(
             arrival_time = time.time()

         if self.engine_use_ray:
-            prompt_token_ids = await self.engine.encode_request_async.remote(
-                request_id=request_id,
-                prompt=prompt,
-                prompt_token_ids=prompt_token_ids,
-                lora_request=lora_request)
+            prompt_token_ids = await (
+                self.engine.encode_request_async.remote(  # type: ignore
+                    request_id=request_id,
+                    prompt=prompt,
+                    prompt_token_ids=prompt_token_ids,
+                    lora_request=lora_request))
         else:
             prompt_token_ids = await self.engine.encode_request_async(
                 request_id=request_id,
@@ -676,13 +682,13 @@ def _abort(self, request_id: str) -> None:
     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the vLLM engine."""
         if self.engine_use_ray:
-            return await self.engine.get_model_config.remote()
+            return await self.engine.get_model_config.remote()  # type: ignore
         else:
             return self.engine.get_model_config()

     async def do_log_stats(self) -> None:
         if self.engine_use_ray:
-            await self.engine.do_log_stats.remote()
+            await self.engine.do_log_stats.remote()  # type: ignore
         else:
             self.engine.do_log_stats()

@@ -695,7 +701,7 @@ async def check_health(self) -> None:

         if self.engine_use_ray:
             try:
-                await self.engine.check_health.remote()
+                await self.engine.check_health.remote()  # type: ignore
             except ray.exceptions.RayActorError as e:
                 raise RuntimeError("Engine is dead.") from e
         else:
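
The engine changes above follow three recurring patterns: Optional attributes get explicit annotations, lazily created attributes (such as _request_tracker) are declared with a bare annotation and assigned later, and properties add extra `is not None` checks so mypy can narrow the Optionals before dereferencing them; the `# type: ignore` comments are confined to Ray's dynamically generated .remote() attributes. A stand-alone sketch of those patterns, with toy class names that are not vLLM's:

# Toy sketch: annotate Optional attributes, declare lazily initialized
# ones without a value, and narrow Optionals before dereferencing them.
import asyncio
from typing import Any, Optional


class RequestTracker:
    pass


class ToyAsyncEngine:

    def __init__(self) -> None:
        self.background_loop: Optional[asyncio.Future] = None
        self._background_loop_unshielded: Optional[asyncio.Task[Any]] = None
        # Lazy initialized field: the annotation alone tells mypy the type;
        # start_background_loop() assigns it before any use.
        self._request_tracker: RequestTracker

    def start_background_loop(self) -> None:
        self._request_tracker = RequestTracker()

    @property
    def is_running(self) -> bool:
        # Both "is not None" checks are needed so mypy can narrow the
        # Optional types before .done() is called.
        return (self.background_loop is not None
                and self._background_loop_unshielded is not None
                and not self._background_loop_unshielded.done())


print(ToyAsyncEngine().is_running)  # False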

Diff for: vllm/lora/worker_manager.py  (+2 −2)

@@ -107,12 +107,12 @@ def create_lora_manager(
         self._lora_manager: LoRAModelManager = lora_manager
         return lora_manager.model

-    def set_active_loras(self, lora_requests: List[LoRARequest],
+    def set_active_loras(self, lora_requests: Set[LoRARequest],
                          lora_mapping: LoRAMapping) -> None:
         self._apply_loras(lora_requests)
         self._lora_manager.set_lora_mapping(lora_mapping)

-    def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
+    def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
         loras_that_exist = self.list_loras()
         loras_map = {
             lora_request.lora_int_id: lora_request
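
The signature change above tightens the parameter to Set[LoRARequest], matching callers that pass a deduplicated set of requests. A small sketch of the same idea with a hypothetical stand-in request type:

# Toy sketch (ToyLoRARequest is a stand-in): requests are hashable and
# deduplicated by id, and mypy rejects a list argument where Set is declared.
from dataclasses import dataclass
from typing import Dict, Set


@dataclass(frozen=True)
class ToyLoRARequest:
    lora_int_id: int


def apply_loras(lora_requests: Set[ToyLoRARequest]) -> Dict[int, ToyLoRARequest]:
    return {request.lora_int_id: request for request in lora_requests}


# Duplicates collapse in the set; two entries remain.
print(apply_loras({ToyLoRARequest(1), ToyLoRARequest(1), ToyLoRARequest(2)}))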

Diff for: vllm/model_executor/guided_decoding/outlines_decoding.py  (+2 −2)

@@ -55,7 +55,7 @@ class GuidedDecodingMode(Enum):

 async def get_outlines_guided_decoding_logits_processor(
         request: Union[CompletionRequest, ChatCompletionRequest],
-        tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
+        tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
     """
     Given an OpenAI-compatible request, check for guided decoding parameters
     and get the necessary logits processor for the given guide.
@@ -84,7 +84,7 @@ async def get_outlines_guided_decoding_logits_processor(

 def _get_guide_and_mode(
     request: Union[CompletionRequest, ChatCompletionRequest]
-) -> Tuple[str, GuidedDecodingMode]:
+) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:

     if request.guided_json:
         json = request.guided_json
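
The widened return annotations above make the "nothing to guide" case explicit: the processor function may return None and the helper may return (None, None), so callers must handle both outcomes. A minimal sketch with a toy request type rather than the OpenAI request models:

# Toy sketch: a helper that may find no guide returns Tuple[None, None],
# which forces callers (under mypy) to handle the None branch.
from typing import Dict, Tuple, Union


def get_guide_and_mode(
        request: Dict[str, str]) -> Union[Tuple[str, str], Tuple[None, None]]:
    if "guided_json" in request:
        return request["guided_json"], "json"
    if "guided_regex" in request:
        return request["guided_regex"], "regex"
    return None, None


guide, mode = get_guide_and_mode({"guided_regex": r"\d+"})
if guide is None:
    print("no guided decoding requested")
else:
    print(f"guide={guide!r}, mode={mode}")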

Diff for: vllm/model_executor/guided_decoding/outlines_logits_processors.py  (+5 −1)

@@ -21,14 +21,18 @@
 from typing import Callable, DefaultDict, Dict, List, Optional, Union

 import torch
-from outlines.fsm.fsm import CFGFSM, RegexFSM
+from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
 from outlines.fsm.json_schema import build_regex_from_schema
 from pydantic import BaseModel
 from transformers import PreTrainedTokenizerBase


 class BaseLogitsProcessor:

+    def __init__(self):
+        # Child class should use initialize in their init.
+        self.fsm: FSM
+
     def init_state(self):
         """Initialize the FSM states."""
         self.fsm_state: DefaultDict[int, int] = defaultdict(int)
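
The new BaseLogitsProcessor.__init__ above declares self.fsm: FSM without assigning it, so mypy knows the attribute's type while each subclass remains responsible for creating the concrete FSM. A toy sketch of the same base-class pattern, using stand-in FSM classes rather than outlines':

# Toy sketch: the base annotates self.fsm, each child assigns a concrete
# FSM, and mypy checks every use against the declared base type.
from typing import List


class FSM:  # stand-in for outlines' FSM base class
    def allowed_token_ids(self, state: int) -> List[int]:
        return []


class RegexFSM(FSM):
    pass


class BaseLogitsProcessor:

    def __init__(self) -> None:
        # Child class should assign self.fsm in its own __init__.
        self.fsm: FSM


class RegexLogitsProcessor(BaseLogitsProcessor):

    def __init__(self) -> None:
        super().__init__()
        self.fsm = RegexFSM()  # type-checked against the declared FSM


print(RegexLogitsProcessor().fsm.allowed_token_ids(0))  # []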

Diff for: vllm/model_executor/model_loader/neuron.py  (+9 −7)

@@ -1,7 +1,7 @@
 """Utilities for selecting and loading neuron models."""
 import importlib
 import os
-from typing import Optional, Type
+from typing import Dict, Optional, Tuple

 import torch
 import torch.nn as nn
@@ -27,7 +27,7 @@
 }

 # Models supported by Neuron.
-_NEURON_SUPPORTED_MODELS = {
+_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
     "LlamaForCausalLM": ("transformers_neuronx.llama.model",
                          "LlamaForSampling", "LlamaForCausalLM"),
     "MistralForCausalLM": ("transformers_neuronx.mistral.model",
@@ -43,11 +43,13 @@ def __init__(
     ) -> None:
         super().__init__()
         self.config = config
-        self.model = None
         self.logits_processor = LogitsProcessor(config.vocab_size,
                                                 logits_as_input=True)
         self.sampler = Sampler()

+        # Lazy initialized
+        self.model: nn.Module
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -74,17 +76,17 @@ def sample(

     def load_weights(self, model_name_or_path: str, **kwargs):
         arch = _get_model_architecture(self.config)
-        neuronx_module_path, neuronx_model_cls, hf_model_cls = (
+        neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = (
             _NEURON_SUPPORTED_MODELS[arch])
         neuronx_module = importlib.import_module(neuronx_module_path)
-        neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls)
+        neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)

         split_model_dir = f"{model_name_or_path}-split"
         if os.path.isdir(os.path.join(model_name_or_path,
                                       "pytorch_model.bin")):
             split_model_dir = model_name_or_path
         elif not os.path.exists(f"{model_name_or_path}-split"):
-            hf_model_cls = getattr(transformers, hf_model_cls)
+            hf_model_cls = getattr(transformers, hf_model_cls_name)
             from transformers_neuronx.module import save_pretrained_split

             hf_model = hf_model_cls.from_pretrained(model_name_or_path,
@@ -96,7 +98,7 @@ def load_weights(self, model_name_or_path: str, **kwargs):
         self.model.to_neuron()


-def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
+def _get_model_architecture(config: PretrainedConfig) -> str:
     architectures = getattr(config, "architectures", [])
     for arch in architectures:
         if arch in _NEURON_SUPPORTED_MODELS:
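
Besides annotating the registry and the lazily built self.model, the neuron loader above renames the strings looked up from _NEURON_SUPPORTED_MODELS (…_cls_name) so they are never reassigned to the classes they resolve to, keeping one type per variable. A sketch of that renaming with an illustrative registry entry pointing at the standard library instead of transformers_neuronx:

# Illustrative sketch: keep the registry string and the class resolved
# from it in separately named variables so each keeps a single type.
import importlib
from typing import Dict, Tuple

_SUPPORTED_MODELS: Dict[str, Tuple[str, str]] = {
    "OrderedDictModel": ("collections", "OrderedDict"),  # (module, class name)
}


def resolve(arch: str) -> type:
    module_path, model_cls_name = _SUPPORTED_MODELS[arch]  # both are str
    module = importlib.import_module(module_path)
    model_cls = getattr(module, model_cls_name)  # a new name for the class
    return model_cls


print(resolve("OrderedDictModel"))  # <class 'collections.OrderedDict'>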

Diff for: vllm/model_executor/model_loader/tensorizer.py  (+1)

@@ -167,6 +167,7 @@ def __post_init__(self):
             decryption_params = DecryptionParams.from_key(key)
             self.deserializer_params['encryption'] = decryption_params

+    @staticmethod
     def add_cli_args(
             parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
         """Tensorizer CLI arguments"""

Diff for: vllm/model_executor/sampling_metadata.py  (+4)

@@ -113,6 +113,8 @@ def from_sampling_metadata(
                                      get_num_triton_sampler_splits(vocab_size))

         sample_indices_start_idx = 0
+        assert sampling_metadata.seq_groups is not None
+        assert sampling_metadata.seq_data is not None
         for i, seq_group in enumerate(sampling_metadata.seq_groups):
             seq_ids, sampling_params = seq_group
             temperature = sampling_params.temperature
@@ -147,6 +149,7 @@ def from_sampling_metadata(
                     and sampling_params.prompt_logprobs is not None):
                 # For tokens in the prompt that we only need to get
                 # their logprobs
+                assert sampling_metadata.prompt_lens is not None
                 prompt_len = sampling_metadata.prompt_lens[i]
                 temperatures += [temperature] * (prompt_len - 1)
                 top_ps += [top_p] * (prompt_len - 1)
@@ -172,6 +175,7 @@ def from_sampling_metadata(
             is_prompt = i < sampling_metadata.num_prompts
             if is_prompt:
                 prompt_best_of.append(sampling_params.best_of)
+                assert sampling_metadata.prompt_lens is not None
                 prompt_len = sampling_metadata.prompt_lens[i]

                 if sampling_params.prompt_logprobs is not None:
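
The asserts added above narrow Optional fields (seq_groups, seq_data, prompt_lens) to their concrete types for the rest of the block, and double as runtime checks of the caller's precondition. A minimal sketch with a toy metadata class:

# Toy sketch: asserting an Optional field is not None lets mypy treat it
# as the concrete type afterwards, and fails loudly if the precondition
# was violated.
from typing import List, Optional


class ToySamplingMetadata:
    def __init__(self, prompt_lens: Optional[List[int]] = None) -> None:
        self.prompt_lens = prompt_lens


def total_prompt_tokens(metadata: ToySamplingMetadata) -> int:
    assert metadata.prompt_lens is not None, "prompt_lens required here"
    # mypy now sees List[int] instead of Optional[List[int]].
    return sum(metadata.prompt_lens)


print(total_prompt_tokens(ToySamplingMetadata([3, 5, 2])))  # 10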

Diff for: vllm/spec_decode/batch_expansion.py  (+3 −3)

@@ -106,7 +106,7 @@ def score_proposals(
     def _expand_batch(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-        proposal_token_ids_list: List[TokenId],
+        proposal_token_ids_list: List[List[TokenId]],
         proposal_lens_list: List[int],
     ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
         """Given the input sequences and potentially multiple corresponding
@@ -218,7 +218,7 @@ def _create_scoring_model_input(
     def _create_target_seq_group_metadata(
         self,
         input_seq_group_metadata: SequenceGroupMetadata,
-        proposal_token_ids: List[TokenId],  # shape: [batch_size, k]
+        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
         batch_index: int,
         target_seq_ids_iter: Iterator[TargetSeqId],
     ) -> List[SequenceGroupMetadata]:
@@ -360,7 +360,7 @@ def _get_token_ids_to_score(
         [0, 1, 2]
         [0, 1, 2, 3]
         """
-        empty_token_ids = []
+        empty_token_ids: List[TokenId] = []

         token_ids_to_score = [empty_token_ids]
         token_ids_to_score.extend([
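
The last hunk above annotates an empty list, since mypy cannot infer an element type for []. A short sketch of the same fix; the TokenId alias here mirrors the int alias used in spec_decode:

# Short sketch: an empty list has no inferable element type, so it is
# annotated explicitly before being nested and extended with token ids.
from typing import List

TokenId = int  # assumed alias for illustration

empty_token_ids: List[TokenId] = []  # without the annotation, mypy asks for one

token_ids_to_score: List[List[TokenId]] = [empty_token_ids]
token_ids_to_score.extend([[0], [0, 1], [0, 1, 2]])
print(token_ids_to_score)  # [[], [0], [0, 1], [0, 1, 2]]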

Diff for: vllm/spec_decode/interfaces.py  (+2 −2)

@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional

 import torch

@@ -73,5 +73,5 @@ def score_proposals(
         blocks_to_copy: Optional[Dict[int, List[int]]],
         k: int,
         proposals: SpeculativeProposals,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> SpeculativeScores:
         raise NotImplementedError

Diff for: vllm/spec_decode/metrics.py  (+1)

@@ -112,6 +112,7 @@ def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:

         Returns a CUDA event recording when the copy is complete.
         """
+        assert self._copy_stream is not None
         self._copy_stream.wait_stream(torch.cuda.current_stream())

         with torch.cuda.stream(self._copy_stream):
