tests/full_tests/ci_perf_tests.sh (3 changes: 1 addition & 2 deletions)
@@ -36,5 +36,4 @@ vllm bench throughput \
     --backend vllm \
     --dataset_path ShareGPT_V3_unfiltered_cleaned_split.json \
     --dataset_name sharegpt \
-    --num-prompts 1000 \
-    --max-model-len 32768
+    --num-prompts 1000
vllm_gaudi/extension/bucketing/common.py (10 changes: 1 addition & 9 deletions)
@@ -427,9 +427,6 @@ def get_corrector(is_prompt, use_contiguous_pa):
         else:
             return correct_for_max_model_len
 
-    def get_max_bucket_per_query(bs, query):
-        return (bs, query, math.ceil((max_model_len - query) // block_size))
-
     buckets = set()
     buckets_2d = set()
     omitted_buckets = set()
@@ -446,14 +443,9 @@ def get_max_bucket_per_query(bs, query):
         local_buckets = expand_to_neighbor_buckets(bs_idx, bs_range, ctx_idx, ctx_range,
                                                    max_num_batched_tokens) if not is_prompt else {(bs, ctx)}
         buckets_2d.update(local_buckets)
-    max_ctx = max(ctx for _, ctx in buckets_2d)
+
    for bs, ctx in buckets_2d:
-        is_max_ctx = ctx == max_ctx
        for query in query_range:
-            if is_prompt and is_max_ctx:
-                bs, query, edge_ctx = get_max_bucket_per_query(bs, query)
-                if edge_ctx >= 0:
-                    ctx = edge_ctx
            if all(bucket_filter(bs, query, ctx) for bucket_filter in filters):
                buckets.add(corrector(bs, query, ctx))
    if not buckets:
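The net effect in common.py is that prompt buckets are now generated purely as the cross product of the batch-size, query, and context ranges, with no extra per-query "edge" buckets stretched toward max_model_len. A minimal sketch of the surviving loop, with illustrative ranges and stand-in bucket_filter/corrector callables (none of these values or helpers come from the module itself):

    import itertools

    # Illustrative ranges; the real values come from the bucketing configuration.
    bs_range = [1, 2, 4]
    query_range = [128, 256, 512]
    ctx_range = [0, 1, 2, 4]

    block_size = 128
    max_model_len = 1024

    def bucket_filter(bs, query, ctx):
        # Stand-in filter: keep buckets whose total token count fits the model.
        return query + ctx * block_size <= max_model_len

    def corrector(bs, query, ctx):
        # Stand-in corrector: the real one clamps buckets to the model limit.
        return (bs, query, ctx)

    buckets = set()
    for bs, query, ctx in itertools.product(bs_range, query_range, ctx_range):
        if bucket_filter(bs, query, ctx):
            buckets.add(corrector(bs, query, ctx))

    print(sorted(buckets))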
vllm_gaudi/extension/bucketing/exponential.py (28 changes: 18 additions & 10 deletions)
@@ -11,7 +11,6 @@
 
 
 class ExponentialBucketingStrategy():
-    long_context: bool = False
 
     def check_for_user_flags(self, phase):
         dim = ['bs', 'seq'] if phase == 'prompt' else ['bs', 'block']
@@ -39,15 +38,13 @@ def get_prompt_cfgs(self, max_num_prefill_seqs, block_size, max_num_batched_toke
         else:
             query_min = block_size
         use_merged_prefill = get_config().merged_prefill
-        self.long_context = max_model_len >= 8192
 
         # cfgs shape: [min, step, max, limit]
         prompt_bs_limit = math.ceil(math.log2(max_num_prefill_seqs)) + 1
         prompt_bs_bucket_cfg = [1, 2, max_num_prefill_seqs, prompt_bs_limit]
         max_prompt_seq_limit = math.ceil(math.log2(max_num_batched_tokens))
         prompt_query_bucket_cfg = [query_min, block_size, max_num_batched_tokens, max_prompt_seq_limit]
-        # Max ctx for all queries; later we generate additional buckets for max ctx per query
-        max_ctx = max(1, math.ceil((max_model_len - max_num_batched_tokens) // block_size))
+        max_ctx = max(1, math.ceil((max_model_len - prompt_query_bucket_cfg[0]) // block_size))
         max_prompt_ctx_limit = 2 if max_ctx == 1 else math.ceil(math.log2(max_ctx)) + 1
         prompt_ctx_bucket_cfg = [0, 1, max_ctx, max_prompt_ctx_limit]
 
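To see what the revised max_ctx bound changes, a worked example with illustrative numbers (not taken from the PR); note that // already floors, so the math.ceil wrapper is a no-op here:

    import math

    # Illustrative values, not from the PR.
    max_model_len = 4096
    block_size = 128
    max_num_batched_tokens = 2048
    query_min = block_size  # prompt_query_bucket_cfg[0]

    # Old: subtract the full batched-token budget from the model length.
    old_max_ctx = max(1, math.ceil((max_model_len - max_num_batched_tokens) // block_size))  # 16
    # New: subtract only the smallest query; with these numbers the
    # largest context bucket grows accordingly.
    new_max_ctx = max(1, math.ceil((max_model_len - query_min) // block_size))               # 31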
@@ -98,7 +95,7 @@ def get_decode_cfgs(self, max_num_seqs, block_size, max_num_batched_tokens, max_
         return decode_bs_bucket_cfg, decode_query_bucket_cfg, decode_block_bucket_cfg
 
     def get_range(self, cfg):
-        range_for_cfg = warmup_range_with_limit(cfg, self.long_context)
+        range_for_cfg = warmup_range_with_limit(cfg)
         return sorted(range_for_cfg)
 
 
@@ -138,7 +135,11 @@ def warmup_range_with_limit(config: Tuple[int, int, int, int], long_context=Fals
     if add_zero_or_one_bucket:
         bmin_origin = bmin
         bmin = bstep
-    linear_buckets = set(np.arange(bmin, bmax + 1, step=bstep))
+    assert num_buckets > 0, "num_buckets must be a positive integer"
+    if num_buckets == 1:
+        return [bmax]
+    buckets: Set[Tuple[int, int]] = set()
+
 
     if long_context:
         num_buckets_exp = math.floor(num_buckets / 2)
@@ -148,11 +149,6 @@
         num_buckets_exp = num_buckets
         first_step = bmax
 
-    if num_buckets_exp <= 1:
-        return [bmax]
-
-    buckets: Set[Tuple[int, int]] = set()
-
     for i in range(num_buckets_exp):
         power_unpadded = bmin * np.float_power(first_step / bmin, (1. / float(num_buckets_exp - 1)) * i)
         if i == num_buckets - 1 and get_config().use_contiguous_pa:
@@ -172,6 +168,18 @@
             bucket = bmax
         else:
             bucket = math.ceil(power_unpadded / bstep) * bstep
+        '''if fill and bucket in buckets:
+            available_buckets = linear_buckets.difference(buckets)
+            if len(available_buckets) == 0:
+                break # there are no more unique buckets, let's exit now
+            new_bucket = min(available_buckets,
+                             key=lambda x: abs(x - power_unpadded))
+            if new_bucket not in buckets:
+                buckets.add(new_bucket)
+        else:
+            if bucket not in buckets:
+                buckets.add(bucket)
+        '''
         if bucket not in buckets:
             buckets.add(bucket)
     if add_zero_or_one_bucket:
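For reference, a self-contained sketch of the exponential spacing that warmup_range_with_limit implements. This simplification omits the zero/one-bucket handling, the contiguous-PA special case, and the fill logic left commented out above; the helper name is ours, not the module's:

    import math

    def exponential_buckets(bmin, bstep, bmax, num_buckets):
        # Values grow geometrically from bmin to bmax, each rounded up to a
        # multiple of bstep; rounding can collide, so duplicates are dropped.
        if num_buckets == 1:
            return [bmax]
        buckets = set()
        for i in range(num_buckets):
            power_unpadded = bmin * math.pow(bmax / bmin, i / (num_buckets - 1))
            buckets.add(min(bmax, math.ceil(power_unpadded / bstep) * bstep))
        return sorted(buckets)

    print(exponential_buckets(128, 128, 4096, 6))
    # [128, 256, 512, 1024, 2048, 4096]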