7 changes: 7 additions & 0 deletions vllm_ascend/patch/worker/patch_rejection_sampler.py
@@ -0,0 +1,7 @@
import vllm.v1.sample.rejection_sampler as rs

from vllm_ascend.sample.rejection_sampler import (expand_batch_to_tokens,
                                                  rejection_sample)

rs.expand_batch_to_tokens = expand_batch_to_tokens
rs.rejection_sample = rejection_sample
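
With this module in place, merely importing it rebinds the module-level `expand_batch_to_tokens` and `rejection_sample` in `vllm.v1.sample.rejection_sampler`. Because upstream `RejectionSampler.forward` resolves those names through that module's globals at call time, the Ascend implementations take effect without subclassing. A minimal sketch of the effect; the exact import point is an assumption (vllm-ascend applies its worker patches during worker/platform initialization):

```python
import vllm.v1.sample.rejection_sampler as rs

# Importing the patch module rebinds the two functions as a side effect.
import vllm_ascend.patch.worker.patch_rejection_sampler  # noqa: F401

# Call sites that look the functions up through the module's namespace
# now resolve to the Ascend implementations.
assert rs.rejection_sample.__module__ == "vllm_ascend.sample.rejection_sampler"
assert rs.expand_batch_to_tokens.__module__ == "vllm_ascend.sample.rejection_sampler"
```

One caveat worth noting: rebinding a module attribute does not affect code that bound the original function via `from ... import ...` before the patch ran, so the patch module must be imported before any such alias is taken.
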
96 changes: 1 addition & 95 deletions vllm_ascend/sample/rejection_sampler.py
@@ -2,14 +2,9 @@
from typing import Optional

import torch
import torch.nn as nn
import vllm.v1.sample.rejection_sampler as rs
from vllm.triton_utils import HAS_TRITON, tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import (RejectionSampler,
                                              apply_sampling_constraints,
                                              generate_uniform_probs)
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.sample.rejection_sampler import generate_uniform_probs

PLACEHOLDER_TOKEN_ID = -1
GREEDY_TEMPERATURE = -1
@@ -18,92 +13,6 @@
MAX_SPEC_LEN = 32


class AscendRejectionSampler(RejectionSampler, nn.Module):
"""
The implementation strictly follows the algorithm described in
https://arxiv.org/abs/2211.17192.
However, we want to clarify the terminology used in the implementation:
accepted tokens: tokens that are accepted based on the relationship
between the "raw" draft and target probabilities.
recovered tokens: tokens that are sampled based on the adjusted probability
distribution, which is derived from both the draft and target
probabilities.
bonus tokens:
If all proposed tokens are accepted, the bonus token is added to the
end of the sequence. The bonus token is only sampled from the target
probabilities. We pass in the bonus tokens instead of sampling them
in the rejection sampler to allow for more flexibility in the
sampling process. For example, we can use top_p, top_k sampling for
bonus tokens, while spec decode does not support these sampling
strategies.
output tokens:
The tokens finally emitted by the rejection sampler.
output tokens = accepted tokens + recovered tokens + bonus tokens
"""

def forward(
self,
metadata: SpecDecodeMetadata,
# [num_tokens, vocab_size]
draft_probs: Optional[torch.Tensor],
# [num_tokens, vocab_size]
target_logits: torch.Tensor,
# [batch_size, 1]
bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
'''
Args:
metadata:
Metadata for spec decoding.
draft_probs (Optional[torch.Tensor]):
Probability distribution for the draft tokens. Shape is
[num_tokens, vocab_size]. Can be None if probabilities are
not provided, which is the case for ngram spec decode.
target_logits (torch.Tensor):
Target model's logits.
Shape is [num_tokens, vocab_size]. Logits from
different requests are flattened into a single tensor because
this is the shape of the output logits.
NOTE: `target_logits` can be updated in place to save memory.
bonus_token_ids (torch.Tensor):
A tensor containing bonus tokens. Shape is [batch_size, 1].
Bonus tokens are added to the end of the sequence if all
proposed tokens are accepted. We generate the bonus tokens
outside of the rejection sampler with the default sampling
strategy. It allows for more flexibility in the sampling
process such as top_p, top_k sampling.
sampling_metadata (SamplingMetadata):
Additional metadata needed for sampling, such as temperature,
top-k/top-p parameters, or other relevant information.
Returns:
output_token_ids (torch.Tensor):
A tensor containing the final output token IDs.
'''
assert metadata.max_spec_len <= MAX_SPEC_LEN
# [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside
# `apply_sampling_constraints`.
target_logits = apply_sampling_constraints(
target_logits,
metadata.cu_num_draft_tokens,
sampling_metadata,
)
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)

output_token_ids = rejection_sample(
metadata.draft_token_ids,
metadata.num_draft_tokens,
metadata.max_spec_len,
metadata.cu_num_draft_tokens,
draft_probs,
target_probs,
bonus_token_ids,
sampling_metadata,
)
return output_token_ids


def rejection_sample(
# [num_tokens]
draft_token_ids: torch.Tensor,
@@ -777,6 +686,3 @@ def sample_recovered_tokens_kernel(
tl.store(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
orig_prob)


rs.expand_batch_to_tokens = expand_batch_to_tokens
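
Net effect of the diff: the `AscendRejectionSampler` subclass and the in-module monkey patch removed above are superseded by the dedicated patch module, which rebinds the Ascend `rejection_sample` and `expand_batch_to_tokens` into upstream `vllm.v1.sample.rejection_sampler`. For orientation, the algorithm the removed docstring describes (constrain the target logits, softmax them, then accept, recover, or append a bonus token) reduces, for a single request, to roughly the following reference sketch. The function name and shapes here are illustrative only; the real implementation is batched and Triton-accelerated:

```python
import torch

def rejection_sample_reference(
    draft_token_ids: torch.Tensor,  # [k] draft tokens for one request
    draft_probs: torch.Tensor,      # [k, vocab_size] draft distribution
    target_probs: torch.Tensor,     # [k, vocab_size] target distribution
    bonus_token_id: int,            # sampled outside; used only if all accepted
) -> list[int]:
    output: list[int] = []
    for pos, token in enumerate(draft_token_ids.tolist()):
        p_draft = draft_probs[pos, token]    # > 0, since the draft sampled it
        p_target = target_probs[pos, token]
        # Accepted token: keep the draft with probability
        # min(1, p_target / p_draft).
        if torch.rand(()) <= p_target / p_draft:
            output.append(token)
            continue
        # Recovered token: sample from the renormalized residual
        # max(0, p_target - p_draft); assumes a non-degenerate residual.
        residual = (target_probs[pos] - draft_probs[pos]).clamp_(min=0)
        output.append(torch.multinomial(residual / residual.sum(), 1).item())
        return output  # stop at the first rejection
    # All drafts accepted: append the externally sampled bonus token.
    output.append(bonus_token_id)
    return output
```

If all k drafts are accepted the sequence grows by k + 1 tokens (drafts plus bonus); on the first rejection it grows by the accepted prefix plus one recovered token. This accept/recover rule is what keeps speculative decoding lossless with respect to the target distribution.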