2626from ucm .sparse .esa .retrieval .retrieval_worker import RetrievalWorker
2727from ucm .sparse .kvstar .utils import get_bind_cpus_for_rank
2828from ucm .store .ucmstore import Task , UcmKVStoreBase
29+ from ucm .integration .vllm .ucm_connector import RequestHasher
2930
3031ReqType = Union [str , int ]
3132HashType = Union [str , int ]
@@ -61,6 +62,7 @@ class ReqMeta:
6162 prompt_token_ids : list [int ]
6263 output_token_ids : list [int ]
6364 is_preempt : bool
65+ ucm_block_hashes :list [str ]
6466
6567 @property
6668 def num_prompt_tokens (self ) -> int :
@@ -100,6 +102,7 @@ def add_request(
100102 prompt_token_ids : list [int ],
101103 output_token_ids : list [int ],
102104 is_preempt : bool ,
105+ ucm_block_hashes :list [str ],
103106 ) -> None :
104107
105108 meta = ReqMeta (
@@ -112,6 +115,7 @@ def add_request(
112115 prompt_token_ids = prompt_token_ids ,
113116 output_token_ids = output_token_ids ,
114117 is_preempt = is_preempt ,
118+ ucm_block_hashes = ucm_block_hashes ,
115119 )
116120 self .requests .append (meta )
117121
@@ -138,21 +142,30 @@ def get_sparse_range(init_window_sz, local_window_sz, prompt_len, block_size):
138142 sparse_range = num_blocks_upper_bound - init_window_sz - local_window_sz
139143 return sparse_range
140144
141-
@cache
def compute_parent_block_hash(model_name, world_size, dtype, seed_rank=0) -> int:
    """Derive the deterministic seed hash anchoring a request's block-hash chain.

    The seed folds the model name, world size, dtype and a rank seed into one
    MD5 digest, so deployments that differ in any of these produce different
    chains. Memoized with functools.cache; the argument space is one tuple per
    deployment, so the cache stays tiny.
    """
    hasher = hashlib.md5()
    hasher.update(f"{model_name}:{world_size}:{dtype}:{seed_rank}".encode("utf-8"))
    hasher.update(b"UCM_HASH_SEED")
    return int.from_bytes(hasher.digest(), byteorder="big")
147151
148152
@cache
def compute_layer_offset(
    block_data_size: int,
    layer_id: int,
    is_v: bool,
    is_mla: bool,
) -> int:
    """Return the byte offset of one layer's K (or V) slice inside a stored block.

    Layout assumed by this arithmetic: layers are laid out back to back; for
    MLA there is a single tensor per layer, otherwise each layer holds K
    followed immediately by V, each block_data_size bytes long.
    """
    if is_mla:
        # One tensor per layer: is_v is irrelevant.
        return block_data_size * layer_id

    # Non-MLA: each layer occupies 2 * block_data_size (K then V).
    k_offset = 2 * block_data_size * layer_id
    return k_offset + block_data_size if is_v else k_offset
156169
def task_hash_func(block_ids, store_type, tensor_type):
    """Stable per-process key identifying one transfer task.

    block_ids is made hashable by converting to a tuple, so lists and tuples
    with equal contents map to the same key.
    """
    task_key = (tuple(block_ids), store_type, tensor_type)
    return hash(task_key)
@@ -178,7 +191,6 @@ def diff_two_map(map1: dict, map2: dict):
178191
179192class ReqStatePerLayer :
180193 # handle single request per layer
181-
182194 def __init__ (
183195 self ,
184196 layer_name : str ,
@@ -222,68 +234,37 @@ def __init__(
222234 self .head_size = vllm_config .model_config .get_head_size ()
223235 self .is_mla = self .vllm_config .model_config .is_deepseek_mla
224236 self .step = 0
225-
226- def set_block_hashes (self , token_ids ):
227- if self .block_hashes is not None :
228- return
229- self .block_hashes = []
230- parent_block_hash_value = None
231- num_total_blocks = math .ceil (len (token_ids ) / self .block_size )
232- for start in range (0 , len (token_ids ), self .block_size ):
233- end = start + self .block_size
234- block_idx = start // self .block_size
235- if block_idx >= num_total_blocks - self .esa_cfg ["local_window_sz" ]:
236- continue
237- block_token_ids = token_ids [start :end ]
238- if len (block_token_ids ) < self .block_size :
239- break
240- curr_block_token_ids_tuple = tuple (block_token_ids )
241- block_hash = block_hash_func (
242- parent_block_hash_value , curr_block_token_ids_tuple
243- )
244- if block_idx >= self .esa_cfg ["init_window_sz" ]:
245- self .block_hashes .append (str (block_hash ))
246- parent_block_hash_value = block_hash
247-
237+
248238 def update_meta (self , req_meta : ReqMeta ):
249239 self .req_meta = req_meta
250240
251241 def launch_transfer_task (self , transfer_type , block_hashes , vllm_block_ids ):
252242 fn = getattr (self .store_instance , transfer_type )
253243 length = len (block_hashes )
254- block_shape = (self .block_size , self .num_key_heads , self .head_size )
255244 precision = self .vllm_config .model_config .dtype .itemsize
256-
257- block_shape = tuple (block_shape )
258- offsets_k = [
259- get_offset (
260- block_shape ,
261- self .rank ,
262- self .tp_size ,
263- precision ,
264- self .layer_id ,
265- is_v = False ,
266- is_mla = self .is_mla ,
267- )
268- ] * length
269-
245+ block_data_size = self .k_cache [0 ].numel () * precision
246+
247+ offset_k = compute_layer_offset (
248+ block_data_size ,
249+ self .layer_id ,
250+ is_v = False ,
251+ is_mla = self .is_mla ,
252+ )
253+ offsets_k = [offset_k ] * length
254+
270255 key_src_tensors = [self .k_cache [id_ ] for id_ in vllm_block_ids ]
271256 task_k = fn (block_hashes , offsets_k , key_src_tensors )
272257 task_k_hash = task_hash_func (block_hashes , transfer_type , "key" )
273258 self .tasks [task_k_hash ] = task_k
274259
275260 if not self .is_mla :
276- offsets_v = [
277- get_offset (
278- block_shape ,
279- self .rank ,
280- self .tp_size ,
281- precision ,
282- self .layer_id ,
283- is_v = True ,
284- is_mla = self .is_mla ,
285- )
286- ] * length
261+ offset_v = compute_layer_offset (
262+ block_data_size ,
263+ self .layer_id ,
264+ is_v = True ,
265+ is_mla = self .is_mla ,
266+ )
267+ offsets_v = [offset_v ] * length
287268 value_src_tensors = [self .v_cache [id_ ] for id_ in vllm_block_ids ]
288269 task_v = fn (block_hashes , offsets_v , value_src_tensors )
289270 task_v_hash = task_hash_func (block_hashes , transfer_type , "value" )
@@ -303,7 +284,7 @@ def maybe_register_static_data(self, forward_context: ForwardContext):
303284 else :
304285 self .k_cache = kv_cache [0 ]
305286 self .v_cache = kv_cache [1 ]
306- self .set_block_hashes ( self .req_meta .prompt_token_ids )
287+ self .block_hashes = self .req_meta .ucm_block_hashes
307288 self .init_static_flag = True
308289
309290 def wait_transfer_task_done (self ):
@@ -461,7 +442,6 @@ def attention_finished(
461442 self .wait_retrieval_and_start_load ()
462443 self .step += 1
463444
464-
465445class ESA (UcmSparseBase ):
466446 # handle batch
467447 def __init__ (self , vllm_config : VllmConfig , role : UcmSparseRole ):
@@ -470,7 +450,7 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
470450 self .rank = vllm_config .parallel_config .rank
471451 self .tp_size = vllm_config .parallel_config .tensor_parallel_size
472452 if role == UcmSparseRole .WORKER :
473- self .connector = get_kv_transfer_group ().connector
453+ self .connector = get_kv_transfer_group ().connector . store
474454 else :
475455 self .connector = None
476456 self .esa_cfg = vllm_config .kv_transfer_config .kv_connector_extra_config [
@@ -483,6 +463,9 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
483463 self ._sparse_metadata_prefill : ESASparseMetaData = ESASparseMetaData ()
484464 self ._sparse_metadata_decode : ESASparseMetaData = ESASparseMetaData ()
485465 self ._sparse_metadata : ESASparseMetaData = ESASparseMetaData ()
466+ self .request_hasher = RequestHasher (vllm_config , 0 )
467+ self .block_size = vllm_config .cache_config .block_size
468+ self .block_hashes : dict [int , dict [int , list [str ]]] = {}
486469 global data
487470
488471 if data is None :
@@ -601,7 +584,6 @@ def attention_finished(
601584 forward_context : ForwardContext ,
602585 phase : Optional [str ] = None ,
603586 ) -> None :
604-
605587 if not self .is_mla :
606588 for req_meta in self ._sparse_metadata .requests :
607589 self .update_req_state_attention_end (
@@ -643,6 +625,45 @@ def is_sparsed_request(self, req):
643625 >= self ._vllm_config .cache_config .block_size * self .esa_cfg ["min_blocks" ]
644626 )
645627
628+ def set_block_hashes (self , req_id , token_ids ):
629+ if req_id not in self .block_hashes :
630+ self .block_hashes [req_id ] = {}
631+
632+ if self .rank in self .block_hashes [req_id ]:
633+ return
634+
635+ self .block_hashes [req_id ][self .rank ] = []
636+
637+ parent_block_hash_value = compute_parent_block_hash (
638+ self ._vllm_config .model_config .model ,
639+ self ._vllm_config .parallel_config .world_size ,
640+ self ._vllm_config .model_config .dtype ,
641+ seed_rank = 0 ,
642+ )
643+
644+ num_total_blocks = math .ceil (len (token_ids ) / self .block_size )
645+ for start in range (0 , len (token_ids ), self .block_size ):
646+ end = start + self .block_size
647+ block_idx = start // self .block_size
648+ if block_idx >= num_total_blocks - self .esa_cfg ["local_window_sz" ]:
649+ continue
650+ block_token_ids = token_ids [start :end ]
651+ if len (block_token_ids ) < self .block_size :
652+ break
653+ curr_block_token_ids_tuple = tuple (block_token_ids )
654+ hash_value = self .request_hasher (
655+ (parent_block_hash_value , curr_block_token_ids_tuple )
656+ )
657+ if block_idx >= self .esa_cfg ["init_window_sz" ]:
658+ self .block_hashes [req_id ][self .rank ].append (str (hash_value ))
659+
660+ parent_block_hash_value = hash_value
661+
662+ if self .rank != 0 and not self .is_mla :
663+ self .newqrequest_hasher = RequestHasher (self ._vllm_config , self .rank )
664+ for i , ucm_block_id in enumerate (self .block_hashes [req_id ][self .rank ]):
665+ self .block_hashes [req_id ][self .rank ][i ] = str (self .newqrequest_hasher (ucm_block_id ))
666+
646667 def build_sparse_meta (
647668 self , scheduler_output , requests , input_batch , attn_metadata
648669 ) -> UcmSparseMetadata :
@@ -654,7 +675,6 @@ def build_sparse_meta(
654675 req_ids = list (getattr (input_batch , "req_ids" , []))
655676 decode_ids = [rid for rid in req_ids if num_sched .get (rid , 0 ) == 1 ]
656677 decode_set = set (decode_ids )
657-
658678 cached_reqs = scheduler_output .scheduled_cached_reqs
659679 preempt_reqs = set ()
660680 if cached_reqs :
@@ -670,6 +690,7 @@ def build_sparse_meta(
670690 req = requests [req_id ]
671691 if not self .is_sparsed_request (req ):
672692 continue
693+ self .set_block_hashes (int (req_id ), req .prompt_token_ids )
673694 if isinstance (attn_metadata , dict ):
674695 attn_metadata = next (iter (attn_metadata .values ()))
675696
@@ -684,6 +705,7 @@ def build_sparse_meta(
684705 req .prompt_token_ids ,
685706 req .output_token_ids ,
686707 req_id in preempt_reqs ,
708+ self .block_hashes [int (req_id )][self .rank ],
687709 )
688710
689711 else :
@@ -704,6 +726,7 @@ def build_sparse_meta(
704726 req .prompt_token_ids ,
705727 req .output_token_ids ,
706728 req_id in preempt_reqs ,
729+ self .block_hashes [int (req_id )][self .rank ],
707730 )
708731
709732 else :
@@ -720,6 +743,7 @@ def build_sparse_meta(
720743 req .prompt_token_ids ,
721744 req .output_token_ids ,
722745 req_id in preempt_reqs ,
746+ self .block_hashes [int (req_id )][self .rank ],
723747 )
724748
725749 # self._sparse_metadata = sparse_meta
0 commit comments