Skip to content

Commit eea1826

Browse files
sparse to adapt new connector
1 parent 42a5ab5 commit eea1826

File tree

7 files changed

+234
-135
lines changed

7 files changed

+234
-135
lines changed

examples/offline_inference_esa.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,19 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
6666
kv_connector=name,
6767
kv_connector_module_path=module_path,
6868
kv_role="kv_both",
69+
# kv_connector_extra_config={
70+
#     "UCM_CONFIG_FILE": "/path/to/unified-cache-management/examples/ucm_config_example.yaml"
71+
# },
6972
kv_connector_extra_config={
70-
"ucm_connector_name": "UcmNfsStore",
71-
"ucm_connector_config": {
72-
"storage_backends": data_dir,
73-
"kv_block_size": 33554432,
74-
},
73+
"ucm_connectors": [
74+
{
75+
"ucm_connector_name": "UcmNfsStore",
76+
"ucm_connector_config": {
77+
"storage_backends": data_dir,
78+
"use_direct": False,
79+
},
80+
}
81+
],
7582
"ucm_sparse_config": {
7683
"ESA": {
7784
"init_window_sz": 1,
@@ -125,8 +132,8 @@ def print_output(
125132

126133

127134
def main():
128-
module_path = "ucm.integration.vllm.uc_connector"
129-
name = "UnifiedCacheConnectorV1"
135+
module_path = "ucm.integration.vllm.ucm_connector"
136+
name = "UCMConnector"
130137
setup_environment_variables()
131138

132139
def get_prompt(prompt):

examples/offline_inference_kvcomp.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,15 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
6767
kv_connector_module_path=module_path,
6868
kv_role="kv_both",
6969
kv_connector_extra_config={
70-
"ucm_connector_name": "UcmNfsStore",
71-
"ucm_connector_config": {
72-
"storage_backends": data_dir,
73-
"kv_block_size": 33554432,
74-
},
70+
"ucm_connectors": [
71+
{
72+
"ucm_connector_name": "UcmNfsStore",
73+
"ucm_connector_config": {
74+
"storage_backends": data_dir,
75+
"use_direct": False,
76+
},
77+
}
78+
],
7579
"ucm_sparse_config": {
7680
"KvComp": {
7781
"init_window_sz": 1,
@@ -123,8 +127,8 @@ def print_output(
123127

124128

125129
def main():
126-
module_path = "ucm.integration.vllm.uc_connector"
127-
name = "UnifiedCacheConnectorV1"
130+
module_path = "ucm.integration.vllm.ucm_connector"
131+
name = "UCMConnector"
128132
setup_environment_variables()
129133

130134
def get_prompt(prompt):

examples/offline_inference_kvstar.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,15 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
6868
kv_connector_module_path=module_path,
6969
kv_role="kv_both",
7070
kv_connector_extra_config={
71-
"ucm_connector_name": "UcmNfsStore",
72-
"ucm_connector_config": {
73-
"storage_backends": data_dir,
74-
"kv_block_size": 33554432,
75-
},
71+
"ucm_connectors": [
72+
{
73+
"ucm_connector_name": "UcmNfsStore",
74+
"ucm_connector_config": {
75+
"storage_backends": data_dir,
76+
"use_direct": False,
77+
},
78+
}
79+
],
7680
"ucm_sparse_config": {
7781
"KVStarMultiStep": {
7882
"init_window_sz": 1,
@@ -123,8 +127,8 @@ def print_output(
123127

124128

125129
def main():
126-
module_path = "ucm.integration.vllm.uc_connector"
127-
name = "UnifiedCacheConnectorV1"
130+
module_path = "ucm.integration.vllm.ucm_connector"
131+
name = "UCMConnector"
128132
setup_environment_variables()
129133

130134
def get_prompt(prompt):

ucm/sparse/esa/esa.py

Lines changed: 89 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from ucm.sparse.esa.retrieval.retrieval_worker import RetrievalWorker
2727
from ucm.sparse.kvstar.utils import get_bind_cpus_for_rank
2828
from ucm.store.ucmstore import Task, UcmKVStoreBase
29+
from ucm.integration.vllm.ucm_connector import RequestHasher
2930

3031
ReqType = Union[str, int]
3132
HashType = Union[str, int]
@@ -61,6 +62,7 @@ class ReqMeta:
6162
prompt_token_ids: list[int]
6263
output_token_ids: list[int]
6364
is_preempt: bool
65+
ucm_block_hashes:list[str]
6466

6567
@property
6668
def num_prompt_tokens(self) -> int:
@@ -100,6 +102,7 @@ def add_request(
100102
prompt_token_ids: list[int],
101103
output_token_ids: list[int],
102104
is_preempt: bool,
105+
ucm_block_hashes:list[str],
103106
) -> None:
104107

105108
meta = ReqMeta(
@@ -112,6 +115,7 @@ def add_request(
112115
prompt_token_ids=prompt_token_ids,
113116
output_token_ids=output_token_ids,
114117
is_preempt=is_preempt,
118+
ucm_block_hashes=ucm_block_hashes,
115119
)
116120
self.requests.append(meta)
117121

@@ -138,21 +142,30 @@ def get_sparse_range(init_window_sz, local_window_sz, prompt_len, block_size):
138142
sparse_range = num_blocks_upper_bound - init_window_sz - local_window_sz
139143
return sparse_range
140144

141-
142145
@cache
143-
def md5(input) -> int:
144-
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
145-
md5_bytes = hashlib.md5(input_bytes).digest()
146-
return int.from_bytes(md5_bytes, byteorder="big")
146+
def compute_parent_block_hash(model_name, world_size, dtype, seed_rank=0) -> int:
    """Return the root hash that seeds a request's block-hash chain.

    The seed is the MD5 digest of the UTF-8 encoding of
    ``"{model_name}:{world_size}:{dtype}:{seed_rank}"`` followed by the
    literal suffix ``b"UCM_HASH_SEED"``, read as a big-endian unsigned
    integer. Identical arguments therefore always yield the same seed.

    Args:
        model_name: Model identifier; formatted with ``str()`` semantics.
        world_size: Parallel world size; formatted with ``str()`` semantics.
        dtype: Model dtype; formatted with ``str()`` semantics.
        seed_rank: Rank component mixed into the seed (defaults to 0).

    Returns:
        The 128-bit MD5 digest as a non-negative ``int``.
    """
    meta = f"{model_name}:{world_size}:{dtype}:{seed_rank}"
    # MD5 is used purely as a deterministic fingerprint here, not for security.
    digest = hashlib.md5(meta.encode("utf-8") + b"UCM_HASH_SEED").digest()
    return int.from_bytes(digest, byteorder="big")
147151

148152

149153
@cache
150-
def block_hash_func(parent_block_hash, curr_block_token_ids):
151-
if not parent_block_hash:
152-
parent_block_hash = md5("UCMHASHSEED")
153-
curr_block_token_ids_tuple = tuple(curr_block_token_ids)
154-
return md5((parent_block_hash, curr_block_token_ids_tuple))
154+
def compute_layer_offset(
    block_data_size: int,
    layer_id: int,
    is_v: bool,
    is_mla: bool,
) -> int:
    """Return the byte offset of one layer's K or V data inside a stored block.

    Args:
        block_data_size: Size in bytes of a single K (or V) tensor for one
            block of one layer.
        layer_id: Zero-based layer index.
        is_v: Select the value offset instead of the key offset. Ignored when
            ``is_mla`` is true.
        is_mla: When true, each layer occupies ``block_data_size`` bytes and
            the key offset is always returned. Otherwise each layer occupies
            ``2 * block_data_size`` bytes, with the value data starting
            ``block_data_size`` bytes after the key data.

    Returns:
        The offset, in the same unit as ``block_data_size``.
    """
    if is_mla:
        # Single tensor per layer: layers are packed back to back and there
        # is no separate value region, so ``is_v`` has no effect.
        return block_data_size * layer_id

    # Per layer: key data first, value data immediately after it.
    k_offset = 2 * block_data_size * layer_id
    return k_offset + block_data_size if is_v else k_offset
156169

157170
def task_hash_func(block_ids, store_type, tensor_type):
    """Return a dictionary key identifying one transfer task.

    The key is the built-in ``hash`` of ``(tuple(block_ids), store_type,
    tensor_type)``.

    Args:
        block_ids: Iterable of hashable block identifiers; order matters.
        store_type: Hashable tag for the transfer kind.
        tensor_type: Hashable tag for the tensor (callers in this file pass
            ``"key"`` or ``"value"``).

    Returns:
        An ``int`` hash value.
    """
    # NOTE: ``hash`` of str/bytes is salted per interpreter process, so this
    # value must never be persisted or compared across processes.
    return hash((tuple(block_ids), store_type, tensor_type))
@@ -178,7 +191,6 @@ def diff_two_map(map1: dict, map2: dict):
178191

179192
class ReqStatePerLayer:
180193
# handle single request per layer
181-
182194
def __init__(
183195
self,
184196
layer_name: str,
@@ -222,68 +234,37 @@ def __init__(
222234
self.head_size = vllm_config.model_config.get_head_size()
223235
self.is_mla = self.vllm_config.model_config.is_deepseek_mla
224236
self.step = 0
225-
226-
def set_block_hashes(self, token_ids):
227-
if self.block_hashes is not None:
228-
return
229-
self.block_hashes = []
230-
parent_block_hash_value = None
231-
num_total_blocks = math.ceil(len(token_ids) / self.block_size)
232-
for start in range(0, len(token_ids), self.block_size):
233-
end = start + self.block_size
234-
block_idx = start // self.block_size
235-
if block_idx >= num_total_blocks - self.esa_cfg["local_window_sz"]:
236-
continue
237-
block_token_ids = token_ids[start:end]
238-
if len(block_token_ids) < self.block_size:
239-
break
240-
curr_block_token_ids_tuple = tuple(block_token_ids)
241-
block_hash = block_hash_func(
242-
parent_block_hash_value, curr_block_token_ids_tuple
243-
)
244-
if block_idx >= self.esa_cfg["init_window_sz"]:
245-
self.block_hashes.append(str(block_hash))
246-
parent_block_hash_value = block_hash
247-
237+
248238
def update_meta(self, req_meta: "ReqMeta"):
    """Replace the request metadata held by this per-layer state.

    Args:
        req_meta: Metadata object to store; it is kept by reference in
            ``self.req_meta``, not copied.
    """
    self.req_meta = req_meta
250240

251241
def launch_transfer_task(self, transfer_type, block_hashes, vllm_block_ids):
252242
fn = getattr(self.store_instance, transfer_type)
253243
length = len(block_hashes)
254-
block_shape = (self.block_size, self.num_key_heads, self.head_size)
255244
precision = self.vllm_config.model_config.dtype.itemsize
256-
257-
block_shape = tuple(block_shape)
258-
offsets_k = [
259-
get_offset(
260-
block_shape,
261-
self.rank,
262-
self.tp_size,
263-
precision,
264-
self.layer_id,
265-
is_v=False,
266-
is_mla=self.is_mla,
267-
)
268-
] * length
269-
245+
block_data_size = self.k_cache[0].numel() * precision
246+
247+
offset_k = compute_layer_offset(
248+
block_data_size,
249+
self.layer_id,
250+
is_v=False,
251+
is_mla=self.is_mla,
252+
)
253+
offsets_k = [offset_k] * length
254+
270255
key_src_tensors = [self.k_cache[id_] for id_ in vllm_block_ids]
271256
task_k = fn(block_hashes, offsets_k, key_src_tensors)
272257
task_k_hash = task_hash_func(block_hashes, transfer_type, "key")
273258
self.tasks[task_k_hash] = task_k
274259

275260
if not self.is_mla:
276-
offsets_v = [
277-
get_offset(
278-
block_shape,
279-
self.rank,
280-
self.tp_size,
281-
precision,
282-
self.layer_id,
283-
is_v=True,
284-
is_mla=self.is_mla,
285-
)
286-
] * length
261+
offset_v = compute_layer_offset(
262+
block_data_size,
263+
self.layer_id,
264+
is_v=True,
265+
is_mla=self.is_mla,
266+
)
267+
offsets_v = [offset_v] * length
287268
value_src_tensors = [self.v_cache[id_] for id_ in vllm_block_ids]
288269
task_v = fn(block_hashes, offsets_v, value_src_tensors)
289270
task_v_hash = task_hash_func(block_hashes, transfer_type, "value")
@@ -303,7 +284,7 @@ def maybe_register_static_data(self, forward_context: ForwardContext):
303284
else:
304285
self.k_cache = kv_cache[0]
305286
self.v_cache = kv_cache[1]
306-
self.set_block_hashes(self.req_meta.prompt_token_ids)
287+
self.block_hashes = self.req_meta.ucm_block_hashes
307288
self.init_static_flag = True
308289

309290
def wait_transfer_task_done(self):
@@ -461,7 +442,6 @@ def attention_finished(
461442
self.wait_retrieval_and_start_load()
462443
self.step += 1
463444

464-
465445
class ESA(UcmSparseBase):
466446
# handle batch
467447
def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
@@ -470,7 +450,7 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
470450
self.rank = vllm_config.parallel_config.rank
471451
self.tp_size = vllm_config.parallel_config.tensor_parallel_size
472452
if role == UcmSparseRole.WORKER:
473-
self.connector = get_kv_transfer_group().connector
453+
self.connector = get_kv_transfer_group().connector.store
474454
else:
475455
self.connector = None
476456
self.esa_cfg = vllm_config.kv_transfer_config.kv_connector_extra_config[
@@ -483,6 +463,9 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
483463
self._sparse_metadata_prefill: ESASparseMetaData = ESASparseMetaData()
484464
self._sparse_metadata_decode: ESASparseMetaData = ESASparseMetaData()
485465
self._sparse_metadata: ESASparseMetaData = ESASparseMetaData()
466+
self.request_hasher = RequestHasher(vllm_config, 0)
467+
self.block_size = vllm_config.cache_config.block_size
468+
self.block_hashes: dict[int, dict[int, list[str]]] = {}
486469
global data
487470

488471
if data is None:
@@ -601,7 +584,6 @@ def attention_finished(
601584
forward_context: ForwardContext,
602585
phase: Optional[str] = None,
603586
) -> None:
604-
605587
if not self.is_mla:
606588
for req_meta in self._sparse_metadata.requests:
607589
self.update_req_state_attention_end(
@@ -643,6 +625,45 @@ def is_sparsed_request(self, req):
643625
>= self._vllm_config.cache_config.block_size * self.esa_cfg["min_blocks"]
644626
)
645627

628+
def set_block_hashes(self, req_id, token_ids):
    """Compute and cache this rank's block hashes for one request.

    Results are stored in ``self.block_hashes[req_id][self.rank]``. If an
    entry for this rank already exists the call is a no-op.

    Hashes are chained: the first block is hashed together with the seed from
    ``compute_parent_block_hash`` and every later block together with the hash
    of the block before it. Only complete blocks are hashed, and the trailing
    ``esa_cfg["local_window_sz"]`` blocks (counted over
    ``ceil(len(token_ids) / block_size)``) are excluded. Blocks with index
    below ``esa_cfg["init_window_sz"]`` take part in the chain but are not
    recorded.

    On ranks other than 0, when the model is not MLA, every recorded hash is
    re-hashed with a ``RequestHasher`` built for this rank.

    Args:
        req_id: Key under which the hashes are cached.
            NOTE(review): the visible call sites pass ``int(req_id)``, which
            raises ``ValueError`` for non-numeric request ids -- confirm that
            ids are always numeric strings.
        token_ids: Prompt token ids of the request.
    """
    per_rank = self.block_hashes.setdefault(req_id, {})
    if self.rank in per_rank:
        return

    block_size = self.block_size
    num_total_blocks = math.ceil(len(token_ids) / block_size)
    num_full_blocks = len(token_ids) // block_size
    # Hash only complete blocks that lie before the local window.
    # ``range`` of a non-positive count is empty, covering short prompts.
    num_hashable = min(
        num_full_blocks, num_total_blocks - self.esa_cfg["local_window_sz"]
    )
    init_window_sz = self.esa_cfg["init_window_sz"]

    parent_hash = compute_parent_block_hash(
        self._vllm_config.model_config.model,
        self._vllm_config.parallel_config.world_size,
        self._vllm_config.model_config.dtype,
        seed_rank=0,
    )

    hashes = []
    for block_idx in range(num_hashable):
        start = block_idx * block_size
        block_token_ids = tuple(token_ids[start : start + block_size])
        parent_hash = self.request_hasher((parent_hash, block_token_ids))
        # Init-window blocks still advance the chain but are not recorded.
        if block_idx >= init_window_sz:
            hashes.append(str(parent_hash))

    if self.rank != 0 and not self.is_mla:
        # Keep the rank-specific hasher local instead of parking it on a
        # misspelled instance attribute.
        rank_hasher = RequestHasher(self._vllm_config, self.rank)
        hashes = [str(rank_hasher(block_hash)) for block_hash in hashes]

    # Publish only once fully computed so that a failure above cannot leave
    # a partial entry that the early return would later treat as complete.
    per_rank[self.rank] = hashes
666+
646667
def build_sparse_meta(
647668
self, scheduler_output, requests, input_batch, attn_metadata
648669
) -> UcmSparseMetadata:
@@ -654,7 +675,6 @@ def build_sparse_meta(
654675
req_ids = list(getattr(input_batch, "req_ids", []))
655676
decode_ids = [rid for rid in req_ids if num_sched.get(rid, 0) == 1]
656677
decode_set = set(decode_ids)
657-
658678
cached_reqs = scheduler_output.scheduled_cached_reqs
659679
preempt_reqs = set()
660680
if cached_reqs:
@@ -670,6 +690,7 @@ def build_sparse_meta(
670690
req = requests[req_id]
671691
if not self.is_sparsed_request(req):
672692
continue
693+
self.set_block_hashes(int(req_id), req.prompt_token_ids)
673694
if isinstance(attn_metadata, dict):
674695
attn_metadata = next(iter(attn_metadata.values()))
675696

@@ -684,6 +705,7 @@ def build_sparse_meta(
684705
req.prompt_token_ids,
685706
req.output_token_ids,
686707
req_id in preempt_reqs,
708+
self.block_hashes[int(req_id)][self.rank],
687709
)
688710

689711
else:
@@ -704,6 +726,7 @@ def build_sparse_meta(
704726
req.prompt_token_ids,
705727
req.output_token_ids,
706728
req_id in preempt_reqs,
729+
self.block_hashes[int(req_id)][self.rank],
707730
)
708731

709732
else:
@@ -720,6 +743,7 @@ def build_sparse_meta(
720743
req.prompt_token_ids,
721744
req.output_token_ids,
722745
req_id in preempt_reqs,
746+
self.block_hashes[int(req_id)][self.rank],
723747
)
724748

725749
# self._sparse_metadata = sparse_meta

0 commit comments

Comments
 (0)