diff --git a/vllm_gaudi/extension/bucketing/common.py b/vllm_gaudi/extension/bucketing/common.py index 7992cf8a..13565fa9 100644 --- a/vllm_gaudi/extension/bucketing/common.py +++ b/vllm_gaudi/extension/bucketing/common.py @@ -245,6 +245,8 @@ def expand_to_neighbor_buckets(bs_idx, bs_range, query_idx, query_range, max_num (bs_idx + 1, query_idx + 1)] valid = bs_range[bs_idx] * query_range[query_idx] <= max_num_batched_tokens if not valid: + omitted_buckets.add(("bs_range[bs_idx] * query_range[query_idx] <= max_num_batched_tokens", + "-> bs, quesry: ", bs_idx, query_idx)) return {} valid_candidates = [(b_idx, q_idx) for b_idx, q_idx in candidates if b_idx < len(bs_range) and q_idx < len(query_range)] @@ -253,17 +255,30 @@ def expand_to_neighbor_buckets(bs_idx, bs_range, query_idx, query_range, max_num # filter rules for buckets # prompt def not_over_max_model_len(bs, query, ctx): + if not query + ctx * block_size <= max_model_len: + omitted_buckets.add( + ("condition: query + ctx * block_size <= max_model_len", "-> bs, quesry, ctx: ", bs, query, ctx)) return query + ctx * block_size <= max_model_len def ctx_not_over_max_ctx_for_merged_prefill(bs, query, ctx): + if not ctx <= max_num_prefill_seqs * math.ceil( + (max_model_len - math.floor(query / max_num_prefill_seqs)) // block_size): + omitted_buckets.add(( + "ctx <= max_num_prefill_seqs * math.ceil((max_model_len - math.floor(query / max_num_prefill_seqs)) // block_size)", + "-> bs, quesry, ctx: ", bs, query, ctx)) return ctx <= max_num_prefill_seqs * math.ceil( (max_model_len - math.floor(query / max_num_prefill_seqs)) // block_size) # decode def block_not_greater_than_max_model_len(bs, query, ctx): + if not ctx <= bs * math.ceil(max_model_len / block_size): + omitted_buckets.add(("condition: ctx <= bs * math.ceil(max_model_len / block_size)", "-> bs, quesry, ctx: ", + bs, query, ctx)) return ctx <= bs * math.ceil(max_model_len / block_size) def batch_size_smaller_than_blocks(bs, query, ctx): + if not bs <= ctx: + omitted_buckets.add(("condition: bs <= ctx, ", "-> bs, query, ctx: ", bs, query, ctx)) return bs <= ctx filters_map = { @@ -289,6 +304,7 @@ def get_filters(is_prompt, use_merged_prefill, use_contiguous_pa): buckets = set() buckets_2d = set() + omitted_buckets = set() filters = get_filters(is_prompt, use_merged_prefill, use_contiguous_pa) for bs_idx, bs in enumerate(bs_range): for query_idx, query in enumerate(query_range): @@ -299,6 +315,15 @@ def get_filters(is_prompt, use_merged_prefill, use_contiguous_pa): for ctx in ctx_range: if all(bucket_filter(bs, query, ctx) for bucket_filter in filters): buckets.add((bs, query, ctx)) + if not buckets: + phase = 'prompt' if is_prompt else 'decode' + for bucket in omitted_buckets: + logger().error(bucket) + raise RuntimeError( + "Generated 0 " + phase + + " buckets. Please use default exponential bucketing, VLLM_EXPONENTIAL_BUCKETING=true or generate linear warmup flags according to README" + ) + return sorted(buckets)