Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions vllm_gaudi/extension/bucketing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ def expand_to_neighbor_buckets(bs_idx, bs_range, query_idx, query_range, max_num
(bs_idx + 1, query_idx + 1)]
valid = bs_range[bs_idx] * query_range[query_idx] <= max_num_batched_tokens
if not valid:
omitted_buckets.add(("bs_range[bs_idx] * query_range[query_idx] <= max_num_batched_tokens",
"-> bs, quesry: ", bs_idx, query_idx))
return {}
valid_candidates = [(b_idx, q_idx) for b_idx, q_idx in candidates
if b_idx < len(bs_range) and q_idx < len(query_range)]
Expand All @@ -253,17 +255,30 @@ def expand_to_neighbor_buckets(bs_idx, bs_range, query_idx, query_range, max_num
# filter rules for buckets
# prompt
def not_over_max_model_len(bs, query, ctx):
    """Prompt-bucket filter: keep a bucket only if its query length plus its
    context (in tokens, i.e. ``ctx`` blocks of ``block_size``) fits within
    ``max_model_len``.

    Records the rejected bucket in the enclosing scope's ``omitted_buckets``
    set so the caller can log why no buckets were generated.
    NOTE(review): ``block_size``, ``max_model_len`` and ``omitted_buckets``
    are closure variables from the enclosing bucketing function.
    """
    # Evaluate the condition once instead of duplicating the expression.
    fits = query + ctx * block_size <= max_model_len
    if not fits:
        # Fixed typo in diagnostic text: "quesry" -> "query".
        omitted_buckets.add(
            ("condition: query + ctx * block_size <= max_model_len", "-> bs, query, ctx: ", bs, query, ctx))
    return fits

def ctx_not_over_max_ctx_for_merged_prefill(bs, query, ctx):
    """Prompt-bucket filter (merged prefill): keep a bucket only if ``ctx``
    does not exceed the maximum number of context blocks that
    ``max_num_prefill_seqs`` sequences can hold once the query tokens are
    distributed evenly across them.

    Records the rejected bucket in ``omitted_buckets`` for later logging.
    NOTE(review): ``math.ceil(x // block_size)`` is a no-op on integers —
    floor division already returns an int, so the ceil never rounds up.
    The sibling decode filter uses true division (``/``) inside ``ceil``;
    this looks like it was meant to be ``/`` here too, but behavior is
    preserved pending confirmation.
    """
    # Compute the bound once instead of duplicating the whole expression.
    max_ctx = max_num_prefill_seqs * math.ceil(
        (max_model_len - math.floor(query / max_num_prefill_seqs)) // block_size)
    if ctx > max_ctx:
        # Fixed typo in diagnostic text: "quesry" -> "query".
        omitted_buckets.add((
            "ctx <= max_num_prefill_seqs * math.ceil((max_model_len - math.floor(query / max_num_prefill_seqs)) // block_size)",
            "-> bs, query, ctx: ", bs, query, ctx))
    return ctx <= max_ctx

# decode
def block_not_greater_than_max_model_len(bs, query, ctx):
    """Decode-bucket filter: keep a bucket only if its total context blocks
    ``ctx`` do not exceed what ``bs`` sequences can need at most, i.e.
    ``bs * ceil(max_model_len / block_size)`` blocks.

    Records the rejected bucket in ``omitted_buckets`` for later logging.
    NOTE(review): ``block_size``, ``max_model_len`` and ``omitted_buckets``
    are closure variables from the enclosing bucketing function.
    """
    # Evaluate the condition once instead of duplicating the expression.
    fits = ctx <= bs * math.ceil(max_model_len / block_size)
    if not fits:
        # Fixed typo in diagnostic text: "quesry" -> "query".
        omitted_buckets.add(("condition: ctx <= bs * math.ceil(max_model_len / block_size)", "-> bs, query, ctx: ",
                             bs, query, ctx))
    return fits

def batch_size_smaller_than_blocks(bs, query, ctx):
    """Decode-bucket filter: a batch of ``bs`` sequences needs at least one
    context block per sequence, so require ``bs <= ctx``.

    On rejection the bucket is recorded in the enclosing scope's
    ``omitted_buckets`` set for diagnostic logging.
    """
    enough_blocks = bs <= ctx
    if not enough_blocks:
        omitted_buckets.add(("condition: bs <= ctx, ", "-> bs, query, ctx: ", bs, query, ctx))
    return enough_blocks

filters_map = {
Expand All @@ -289,6 +304,7 @@ def get_filters(is_prompt, use_merged_prefill, use_contiguous_pa):

buckets = set()
buckets_2d = set()
omitted_buckets = set()
filters = get_filters(is_prompt, use_merged_prefill, use_contiguous_pa)
for bs_idx, bs in enumerate(bs_range):
for query_idx, query in enumerate(query_range):
Expand All @@ -299,6 +315,15 @@ def get_filters(is_prompt, use_merged_prefill, use_contiguous_pa):
for ctx in ctx_range:
if all(bucket_filter(bs, query, ctx) for bucket_filter in filters):
buckets.add((bs, query, ctx))
if not buckets:
phase = 'prompt' if is_prompt else 'decode'
for bucket in omitted_buckets:
logger().error(bucket)
raise RuntimeError(
"Generated 0 " + phase +
" buckets. Please use default exponential bucketing, VLLM_EXPONENTIAL_BUCKETING=true or generate linear warmup flags according to README"
)

return sorted(buckets)


Expand Down