2 changes: 1 addition & 1 deletion .github/workflows/format_pr_body.yaml
@@ -36,7 +36,7 @@ jobs:

- name: Get vLLM version
run: |
VLLM_COMMIT=ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9
VLLM_COMMIT=aaddc9c82a6df73f0f93912d3aee987859d28a53
echo "VLLM_COMMIT=https://github.com/vllm-project/vllm/commit/$VLLM_COMMIT" >> $GITHUB_ENV

- name: Checkout repository
2 changes: 1 addition & 1 deletion .github/workflows/vllm_ascend_test_pr_full.yaml
@@ -69,7 +69,7 @@ jobs:
name: e2e-full
strategy:
matrix:
vllm_version: [ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9, v0.12.0]
vllm_version: [aaddc9c82a6df73f0f93912d3aee987859d28a53, v0.12.0]
needs: [changes]
if: ${{ needs.changes.outputs.e2e_tracker == 'true' }}
uses: ./.github/workflows/_e2e_test.yaml
6 changes: 3 additions & 3 deletions .github/workflows/vllm_ascend_test_pr_light.yaml
@@ -42,7 +42,7 @@ jobs:
lint:
uses: ./.github/workflows/pre-commit.yml
with:
vllm: ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9
vllm: aaddc9c82a6df73f0f93912d3aee987859d28a53
changes:
runs-on: ubuntu-latest
outputs:
@@ -85,7 +85,7 @@ jobs:
SOC_VERSION: ascend910b1
strategy:
matrix:
vllm_version: [ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9, v0.12.0]
vllm_version: [aaddc9c82a6df73f0f93912d3aee987859d28a53, v0.12.0]

steps:
- name: Free up disk space
@@ -152,7 +152,7 @@ jobs:
name: e2e-light
strategy:
matrix:
vllm_version: [ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9, v0.12.0]
vllm_version: [aaddc9c82a6df73f0f93912d3aee987859d28a53, v0.12.0]
# Note (yikun): If CI resource are limited we can split job into two chain jobs
needs: [lint, changes]
# only trigger e2e test after lint passed and the change is e2e related with pull request.
2 changes: 1 addition & 1 deletion docs/source/community/versioning_policy.md
@@ -44,7 +44,7 @@ The table below is the release compatibility matrix for vLLM Ascend release.
For main branch of vLLM Ascend, we usually make it compatible with the latest vLLM release and a newer commit hash of vLLM. Please note that this table is usually updated. Please check it regularly.
| vLLM Ascend | vLLM | Python | Stable CANN | PyTorch/torch_npu |
|-------------|--------------|------------------|-------------|--------------------|
| main | ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9, v0.12.0 tag | >= 3.10, < 3.12 | 8.3.RC2 | 2.8.0 / 2.8.0 |
| main | aaddc9c82a6df73f0f93912d3aee987859d28a53, v0.12.0 tag | >= 3.10, < 3.12 | 8.3.RC2 | 2.8.0 / 2.8.0 |

## Release cadence

Empty file added vllm_ascend/pool/__init__.py
Empty file.
11 changes: 11 additions & 0 deletions vllm_ascend/pool/medatata.py
@@ -0,0 +1,11 @@
import torch


class PoolingStates:
# NOTE: This should be removed after we drop support of vLLM v0.12.0
def __init__(self):
# for chunked prefill with ALL pooling
self.hidden_states_cache: list[torch.Tensor] = []

def clean(self):
self.hidden_states_cache.clear()
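
The new `PoolingStates` container is deliberately minimal: it only buffers per-chunk hidden states for ALL pooling under chunked prefill and frees them when the request is done. A usage sketch, not part of the diff; the chunk sizes and hidden size below are made up for illustration:

```python
# Illustrative sketch only: accumulate per-chunk hidden states for a single
# pooling request during chunked prefill, then release the cache afterwards.
import torch

from vllm_ascend.pool.medatata import PoolingStates

states = PoolingStates()

# Each prefill chunk appends the hidden states it produced for this request.
for chunk_len in (4, 4, 3):  # hypothetical chunk sizes
    chunk_hidden = torch.zeros(chunk_len, 8)  # (num_tokens, hidden_size)
    states.hidden_states_cache.append(chunk_hidden)

# ALL pooling can then operate on the full prompt's hidden states.
full_hidden = torch.cat(states.hidden_states_cache, dim=0)
assert full_hidden.shape == (11, 8)

# Drop the cached tensors once the request has finished.
states.clean()
```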
103 changes: 70 additions & 33 deletions vllm_ascend/worker/model_runner_v1.py
@@ -136,6 +136,7 @@
from vllm_ascend.eplb.utils import model_register
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.pool.medatata import PoolingStates
from vllm_ascend.sample.logits_processor import build_logitsprocs
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.spec_decode import get_spec_decode_method
@@ -146,7 +147,7 @@
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
AscendDeviceType, ProfileExecuteDuration,
enable_sp, get_ascend_device_type, is_enable_nz,
is_moe_model, lmhead_tp_enable)
is_moe_model, lmhead_tp_enable, vllm_version_is)
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch

if TYPE_CHECKING:
@@ -2134,33 +2135,37 @@ def _pool(
hidden_states: torch.Tensor,
num_scheduled_tokens: int,
num_scheduled_tokens_np: np.ndarray,
finished_sending: Optional[set[str]] = None,
finished_recving: Optional[set[str]] = None,
kv_connector_output: Optional["KVConnectorOutput"] = None,
) -> ModelRunnerOutput:
assert self.input_batch.num_reqs ==\
len(self.input_batch.pooling_params), \
"Either all or none of the requests in" \
" a batch must be pooling request"
assert self.input_batch.num_reqs == len(
self.input_batch.pooling_params
), ("Either all or none of the requests in a batch must be pooling request"
)

hidden_states = hidden_states[:num_scheduled_tokens]
pooling_metadata = self.input_batch.pooling_metadata
pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
device=hidden_states.device)
seq_lens_cpu = self.seq_lens_cpu[:self.input_batch.num_reqs]

pooling_metadata = self.input_batch.get_pooling_metadata()
if vllm_version_is("0.12.0"):
pooling_metadata.build_pooling_cursor(
num_scheduled_tokens_np.tolist(), device=hidden_states.device)
else:
pooling_metadata.build_pooling_cursor(
num_scheduled_tokens_np.tolist(),
seq_lens_cpu,
device=hidden_states.device)

model = cast(VllmModelForPooling, self.model)
raw_pooler_output = model.pooler(
raw_pooler_output: PoolerOutput = model.pooler(
hidden_states=hidden_states,
pooling_metadata=pooling_metadata,
)
raw_pooler_output = json_map_leaves(
lambda x: x.to("cpu", non_blocking=True),
lambda x: x.to("cpu", non_blocking=True) if x is not None else x,
raw_pooler_output,
)
torch.npu.synchronize()

pooler_output: list[Optional[torch.Tensor]] = []
pooler_output: list[torch.Tensor | None] = []
for raw_output, seq_len, prompt_len in zip(
raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
output = raw_output if seq_len == prompt_len else None
@@ -2173,7 +2178,6 @@ def _pool(
logprobs=None,
prompt_logprobs_dict={},
pooler_output=pooler_output,
kv_connector_output=kv_connector_output,
)

def _select_moe_comm_method(self,
@@ -2364,8 +2368,7 @@ def execute_model(
pool_output = self._pool(
hidden_states,
scheduler_output.total_num_scheduled_tokens,
num_scheduled_tokens_np, finished_sending,
finished_recving, kv_connector_output)
num_scheduled_tokens_np)
if need_dump:
assert self.debugger is not None
self.debugger.stop()
@@ -3100,35 +3103,51 @@ def _dummy_pooler_run_task(

req_num_tokens = num_tokens // num_reqs

dummy_prompt_lens = torch.tensor(
num_scheduled_tokens_list,
device="cpu",
)
dummy_token_ids = torch.zeros((num_reqs, req_num_tokens),
dtype=torch.int32,
device=self.device)

model = cast(VllmModelForPooling, self.get_model())
dummy_pooling_params = PoolingParams(task=task)
dummy_pooling_params.verify(task=task, model_config=self.model_config)
to_update = model.pooler.get_pooling_updates(task)
to_update.apply(dummy_pooling_params)

dummy_prompt_lens = torch.tensor(
num_scheduled_tokens_list,
device="cpu",
)
dummy_metadata = PoolingMetadata(
prompt_lens=dummy_prompt_lens,
prompt_token_ids=dummy_token_ids,
pooling_params=[dummy_pooling_params] * num_reqs,
)
if vllm_version_is("0.12.0"):
dummy_metadata = PoolingMetadata(
prompt_lens=dummy_prompt_lens,
prompt_token_ids=dummy_token_ids,
pooling_params=[dummy_pooling_params] * num_reqs,
)
dummy_metadata.build_pooling_cursor(
num_scheduled_tokens_list,
device=hidden_states.device,
)
else:
dummy_metadata = PoolingMetadata(
prompt_lens=dummy_prompt_lens,
prompt_token_ids=dummy_token_ids,
pooling_params=[dummy_pooling_params] * num_reqs,
pooling_states=[PoolingStates() for i in range(num_reqs)],
)

dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list,
device=hidden_states.device)
dummy_metadata.build_pooling_cursor(
num_scheduled_tokens_list,
seq_lens_cpu=dummy_prompt_lens,
device=hidden_states.device,
)

try:
return model.pooler(hidden_states=hidden_states,
pooling_metadata=dummy_metadata)
except RuntimeError as e:
if 'out of memory' in str(e):
if "out of memory" in str(e):
raise RuntimeError(
"CUDA out of memory occurred when warming up pooler "
"NPU out of memory occurred when warming up pooler "
f"({task=}) with {num_reqs} dummy requests. Please try "
"lowering `max_num_seqs` or `gpu_memory_utilization` when "
"initializing the engine.") from e
Expand All @@ -3141,8 +3160,17 @@ def _dummy_pooler_run(
hidden_states: torch.Tensor,
) -> PoolerOutput:
# Find the task that has the largest output for subsequent steps
supported_pooling_tasks = self.get_supported_pooling_tasks()

if not supported_pooling_tasks:
raise RuntimeError(
f"Model {self.model_config.model} does not support "
"any pooling tasks. See "
"https://docs.vllm.ai/en/latest/models/pooling_models.html "
"to learn more.")

output_size = dict[PoolingTask, float]()
for task in self.get_supported_pooling_tasks():
for task in supported_pooling_tasks:
# Run a full batch with each task to ensure none of them OOMs
output = self._dummy_pooler_run_task(hidden_states, task)
output_size[task] = sum(o.nbytes for o in output)
@@ -4134,12 +4162,21 @@ def _get_prompt_logprobs_dict(

return prompt_logprobs_dict

def get_supported_pooling_tasks(self):
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
model = self.get_model()
if not is_pooling_model(model):
return []

return list(model.pooler.get_supported_tasks())
supported_tasks = list(model.pooler.get_supported_tasks())

if "score" in supported_tasks:
num_labels = getattr(self.model_config.hf_config, "num_labels", 0)
if num_labels != 1:
supported_tasks.remove("score")
logger.debug_once(
"Score API is only enabled for num_labels == 1.")

return supported_tasks

def _build_drafter_prepare_inputs_torchair_param(self):
return False
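
Throughout `model_runner_v1.py`, the change keeps a single call site per feature and branches on `vllm_version_is("0.12.0")`, because the pooling interfaces changed after the v0.12.0 tag: on the newer pinned vLLM commit, `build_pooling_cursor` also takes `seq_lens_cpu`, and `PoolingMetadata` gains a `pooling_states` field. A hedged sketch of that gating pattern, with a hypothetical helper name:

```python
# Sketch of the version-gating pattern used in this PR; get_cursor_kwargs is
# a hypothetical helper, not part of the change itself.
import torch

from vllm_ascend.utils import vllm_version_is


def get_cursor_kwargs(seq_lens_cpu: torch.Tensor, device: torch.device) -> dict:
    """Build keyword arguments for PoolingMetadata.build_pooling_cursor."""
    if vllm_version_is("0.12.0"):
        # v0.12.0 signature: build_pooling_cursor(num_scheduled_tokens, device=...)
        return {"device": device}
    # Newer commits also expect the per-request sequence lengths on CPU.
    return {"seq_lens_cpu": seq_lens_cpu, "device": device}
```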
52 changes: 38 additions & 14 deletions vllm_ascend/worker/npu_input_batch.py
@@ -40,6 +40,8 @@
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
from vllm.v1.utils import copy_slice

from vllm_ascend.pool.medatata import PoolingStates
from vllm_ascend.utils import vllm_version_is
from vllm_ascend.worker.block_table import MultiGroupBlockTable


@@ -49,7 +51,6 @@ class CachedRequestState:
req_id: str
prompt_token_ids: Optional[list[int]]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
generator: Optional[torch.Generator]

block_ids: tuple[list[int], ...]
@@ -65,12 +66,18 @@ class CachedRequestState:
mm_positions: Optional[list[PlaceholderRange]] = None
mm_hashes: Optional[list[PlaceholderRange]] = None

# for pooling models
pooling_params: PoolingParams | None = None
pooling_states: PoolingStates | None = None

lora_request: Optional[LoRARequest] = None
prompt_embeds: Optional[torch.Tensor] = None

def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds)
if self.pooling_params is not None:
self.pooling_states = PoolingStates()

@property
def num_tokens(self) -> int:
@@ -301,7 +308,9 @@ def __init__(
# This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata()

# for pooling models
self.pooling_params: dict[str, PoolingParams] = {}
self.pooling_states: dict[str, PoolingStates] = {}

# Cached reference to the GPU tensor of previously sampled tokens
self.prev_sampled_token_ids: Optional[torch.Tensor] = None
@@ -451,11 +460,15 @@ def add_request(
self.bad_words_token_ids[
req_index] = sampling_params.bad_words_token_ids
elif pooling_params := request.pooling_params:
pooling_states = request.pooling_states
assert pooling_states is not None

self.pooling_params[req_id] = pooling_params
self.pooling_states[req_id] = pooling_states
self.logits_processing_needs_token_ids[req_index] = (
pooling_params.requires_token_ids)
else:
raise NotImplementedError(request)
raise NotImplementedError("Unrecognized request type")

# Speculative decoding: by default 1 token is generated.
self.num_accepted_tokens_cpu[req_index] = 1
@@ -522,7 +535,10 @@ def remove_request(self, req_id: str) -> Optional[int]:
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
self.bad_words_token_ids.pop(req_index, None)
self.pooling_params.pop(req_id, None)
if self.is_pooling_model:
self.pooling_params.pop(req_id, None)
self.pooling_states.pop(req_id, None)
return req_index
return req_index

def swap_states(self, i1: int, i2: int) -> None:
@@ -791,23 +807,31 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
logitsprocs=self.logitsprocs,
)

@property
def pooling_metadata(self) -> PoolingMetadata:
if len(self.pooling_params) == 0:
pooling_params = []
else:
# Note, for now this assumes that all request in the batch
# are either sampling or pooling requests
assert len(self.req_ids) == len(self.pooling_params)
pooling_params = [
self.pooling_params[req_id] for req_id in self.req_ids
]
def get_pooling_params(self) -> list[PoolingParams]:
assert len(self.req_ids) == len(self.pooling_params)
return [self.pooling_params[req_id] for req_id in self.req_ids]

def get_pooling_states(self) -> list[PoolingStates]:
assert len(self.req_ids) == len(self.pooling_states)
return [self.pooling_states[req_id] for req_id in self.req_ids]

def get_pooling_metadata(self) -> PoolingMetadata:
pooling_params = self.get_pooling_params()

if vllm_version_is("0.12.0"):
return PoolingMetadata(
prompt_lens=torch.from_numpy(
self.num_prompt_tokens[:self.num_reqs]),
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
pooling_params=pooling_params,
)

return PoolingMetadata(
prompt_lens=torch.from_numpy(
self.num_prompt_tokens[:self.num_reqs]),
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
pooling_params=pooling_params,
pooling_states=self.get_pooling_states(),
)

def _make_prompt_token_ids_tensor(self) -> torch.Tensor: