diff --git a/tests/full_tests/ci_perf_tests.sh b/tests/full_tests/ci_perf_tests.sh index fb94a1956..3704c04a2 100644 --- a/tests/full_tests/ci_perf_tests.sh +++ b/tests/full_tests/ci_perf_tests.sh @@ -36,5 +36,4 @@ vllm bench throughput \ --backend vllm \ --dataset_path ShareGPT_V3_unfiltered_cleaned_split.json \ --dataset_name sharegpt \ - --num-prompts 1000 \ - --max-model-len 32768 + --num-prompts 1000 diff --git a/vllm_gaudi/extension/bucketing/common.py b/vllm_gaudi/extension/bucketing/common.py index 2bfeed264..4c94f9f28 100644 --- a/vllm_gaudi/extension/bucketing/common.py +++ b/vllm_gaudi/extension/bucketing/common.py @@ -427,9 +427,6 @@ def get_corrector(is_prompt, use_contiguous_pa): else: return correct_for_max_model_len - def get_max_bucket_per_query(bs, query): - return (bs, query, math.ceil((max_model_len - query) // block_size)) - buckets = set() buckets_2d = set() omitted_buckets = set() @@ -446,14 +443,9 @@ def get_max_bucket_per_query(bs, query): local_buckets = expand_to_neighbor_buckets(bs_idx, bs_range, ctx_idx, ctx_range, max_num_batched_tokens) if not is_prompt else {(bs, ctx)} buckets_2d.update(local_buckets) - max_ctx = max(ctx for _, ctx in buckets_2d) + for bs, ctx in buckets_2d: - is_max_ctx = ctx == max_ctx for query in query_range: - if is_prompt and is_max_ctx: - bs, query, edge_ctx = get_max_bucket_per_query(bs, query) - if edge_ctx >= 0: - ctx = edge_ctx if all(bucket_filter(bs, query, ctx) for bucket_filter in filters): buckets.add(corrector(bs, query, ctx)) if not buckets: diff --git a/vllm_gaudi/extension/bucketing/exponential.py b/vllm_gaudi/extension/bucketing/exponential.py index d43d8d12b..87a4d1aee 100644 --- a/vllm_gaudi/extension/bucketing/exponential.py +++ b/vllm_gaudi/extension/bucketing/exponential.py @@ -11,7 +11,6 @@ class ExponentialBucketingStrategy(): - long_context: bool = False def check_for_user_flags(self, phase): dim = ['bs', 'seq'] if phase == 'prompt' else ['bs', 'block'] @@ -39,15 +38,13 @@ def 
get_prompt_cfgs(self, max_num_prefill_seqs, block_size, max_num_batched_toke else: query_min = block_size use_merged_prefill = get_config().merged_prefill - self.long_context = max_model_len >= 8192 # cfgs shape: [min, step, max, limit] prompt_bs_limit = math.ceil(math.log2(max_num_prefill_seqs)) + 1 prompt_bs_bucket_cfg = [1, 2, max_num_prefill_seqs, prompt_bs_limit] max_prompt_seq_limit = math.ceil(math.log2(max_num_batched_tokens)) prompt_query_bucket_cfg = [query_min, block_size, max_num_batched_tokens, max_prompt_seq_limit] - # Max ctx for all queries; later we generate additional buckets for max ctx per query - max_ctx = max(1, math.ceil((max_model_len - max_num_batched_tokens) // block_size)) + max_ctx = max(1, math.ceil((max_model_len - prompt_query_bucket_cfg[0]) // block_size)) max_prompt_ctx_limit = 2 if max_ctx == 1 else math.ceil(math.log2(max_ctx)) + 1 prompt_ctx_bucket_cfg = [0, 1, max_ctx, max_prompt_ctx_limit] @@ -98,7 +95,7 @@ def get_decode_cfgs(self, max_num_seqs, block_size, max_num_batched_tokens, max_ return decode_bs_bucket_cfg, decode_query_bucket_cfg, decode_block_bucket_cfg def get_range(self, cfg): - range_for_cfg = warmup_range_with_limit(cfg, self.long_context) + range_for_cfg = warmup_range_with_limit(cfg) return sorted(range_for_cfg) @@ -138,7 +135,10 @@ def warmup_range_with_limit(config: Tuple[int, int, int, int], long_context=Fals if add_zero_or_one_bucket: bmin_origin = bmin bmin = bstep assert num_buckets > 0, "num_buckets must be a positive integer" + if num_buckets == 1: + return [bmax] + buckets: Set[Tuple[int, int]] = set() if long_context: num_buckets_exp = math.floor(num_buckets / 2) @@ -148,11 +149,6 @@ def warmup_range_with_limit(config: Tuple[int, int, int, int], long_context=Fals num_buckets_exp = num_buckets first_step = bmax - if num_buckets_exp <= 1: - return [bmax] - - buckets: Set[Tuple[int, int]] = set() - for i in range(num_buckets_exp):
power_unpadded = bmin * np.float_power(first_step / bmin, (1. / float(num_buckets_exp - 1)) * i) if i == num_buckets - 1 and get_config().use_contiguous_pa: