225 changes: 225 additions & 0 deletions examples/offline_inference_blend.py

Large diffs are not rendered by default.

316 changes: 305 additions & 11 deletions ucm/integration/vllm/patch/patch_funcs/v092/vllm_adapt.py

Large diffs are not rendered by default.

115 changes: 102 additions & 13 deletions ucm/integration/vllm/uc_connector.py
@@ -44,6 +44,7 @@
from ucm.logger import init_logger
from ucm.store.factory import UcmConnectorFactory
from ucm.store.ucmstore import Task
from ucm.sparse.blend.chunk_processor import ChunkProcessor, ChunkMetaData, hash_token_ids

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@@ -78,6 +79,8 @@ class ReqMeta:
dump_blocks: list[tuple[str, int]] = field(default_factory=list)
# Whether use load_async
load_async: bool = False
# RAG chunk metadata to load for cache blending
chunks_load_meta: list[ChunkMetaData] = field(default_factory=list)


@dataclass
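Note: ucm/sparse/blend/chunk_processor.py is not shown in this view. From the fields this patch reads and writes, ChunkMetaData appears to have roughly the shape sketched below; the field names come from this diff, but the types and defaults are assumptions rather than the actual definition.

from dataclasses import dataclass, field

@dataclass
class ChunkMetaData:
    chunk_blks_hash: list[str] = field(default_factory=list)    # per-block hashes of the RAG chunk
    store_hits: list[bool] = field(default_factory=list)        # connector.lookup() result per block
    start_idx_in_req: int = 0                                    # token span of the chunk in the prompt
    end_idx_in_req: int = 0
    start_idx_in_req_blks: int = 0                               # block span of the chunk in the request
    end_idx_in_req_blks: int = 0
    vllm_blk_ids: list[int] = field(default_factory=list)        # HBM block ids assigned to the chunk
    hits_vllm_blk_ids: list[int] = field(default_factory=list)   # HBM block ids whose KV was loaded from storage
    position_offset: int = 0                                     # position delta used for block-wise re-RoPE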
@@ -160,6 +163,18 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
]
)

self.enable_blend = False
self.req2rag_load_chunks: dict[str, list[ChunkMetaData]] = {}
end_token_id = 0
if (("ucm_sparse_config" in self._vllm_config.kv_transfer_config.kv_connector_extra_config)
and "Blend" in self._vllm_config.kv_transfer_config.kv_connector_extra_config["ucm_sparse_config"]):
ucm_blend_config = self._vllm_config.kv_transfer_config.kv_connector_extra_config["ucm_sparse_config"]["Blend"]
self.enable_blend = True
end_token_id = ucm_blend_config["chunk_end_token_id"]
self.chunk_processor = ChunkProcessor(
config={"chunk_end_token_id": end_token_id, "block_size": self.block_size}
)
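The branch above is driven by kv_transfer_config.kv_connector_extra_config. A minimal sketch of the expected shape, using only the keys read above (the token id value is a placeholder, e.g. the separator id that terminates a RAG chunk):

kv_connector_extra_config = {
    "ucm_sparse_config": {
        "Blend": {
            "chunk_end_token_id": 2,  # placeholder id of the token that marks a chunk boundary
        },
    },
}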

def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"):
for layer_name in forward_context.no_compile_layers:
attn_layer = forward_context.no_compile_layers[layer_name]
@@ -303,6 +318,10 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
self._load_failed_reqs.add(req_id)
logger.error(f"Failed to load blocks for req {req_id}")

def setup_model(self, model) -> None:
# cache the model's cos/sin rotary embedding for KV-cache load post-processing (block-wise delta RoPE)
self.chunk_processor.setup_rotary_emb(model)

def wait_for_layer_load(self, layer_name: str) -> None:
"""
Block until the KV for a specific layer is loaded into vLLM's
@@ -333,6 +352,18 @@ def wait_for_layer_load(self, layer_name: str) -> None:
)
continue
logger.debug(f"Load tasks for {request_id} on layer {layer_name} finished.")
# re-apply RoPE to the K cache of loaded RAG chunk blocks (shift each block by its chunk's position delta)
k_cache = self.kv_caches[layer_name][0]
all_hits_vllm_ids = []
positions = []
for req_meta in self._connector_metadata.requests:
for meta in req_meta.chunks_load_meta:
all_hits_vllm_ids.extend(meta.hits_vllm_blk_ids)
positions.extend([meta.position_offset] * len(meta.hits_vllm_blk_ids))
if all_hits_vllm_ids:
vllm_ids = torch.tensor(all_hits_vllm_ids, device=k_cache.device)
positions = torch.tensor(positions, device=k_cache.device)
self.chunk_processor.process_chunk_cache(k_cache, vllm_ids, positions)
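ChunkProcessor.process_chunk_cache is not shown in this diff, so the following is only an illustration of the "block-wise delta rope" idea it implements: RoPE rotations compose, R(p + d) = R(d) R(p), so keys cached at their original chunk positions can be shifted to their new positions in the blended prompt by one extra rotation by the position delta. The tensor layout and the NEOX-style half split below are assumptions.

import torch

def rerope_block(k_block: torch.Tensor, delta: int, inv_freq: torch.Tensor) -> torch.Tensor:
    # k_block: [block_size, num_kv_heads, head_dim], already RoPE'd at its original positions
    # inv_freq: [head_dim // 2] rotary inverse frequencies
    angles = delta * inv_freq
    cos, sin = angles.cos(), angles.sin()
    half = k_block.shape[-1] // 2
    x1, x2 = k_block[..., :half], k_block[..., half:]
    # rotate each (x1, x2) pair by the extra angle delta * theta_i
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)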

def save_kv_layer(
self,
@@ -548,8 +579,8 @@ def hash_request_tokens(
hash_value = hash_function(
(parent_block_hash_value, block_token_ids_tuple)
)
parent_block_hash_value = hash_value
ret.append(str(hash_value))
parent_block_hash_value = str(hash_value)
ret.append(parent_block_hash_value)

return ret
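The two-line change above makes each block hash chain on the string form of its parent, i.e. the same value that is appended, stored, and later passed as the parent argument of hash_token_ids in the blend path below. A toy illustration of the chaining, with a generic hash_function and an empty-string seed standing in for the real initial parent value:

def chained_block_hashes(hash_function, block_token_ids: list[tuple[int, ...]]) -> list[str]:
    parent = ""  # stand-in for the real initial parent hash
    hashes = []
    for block in block_token_ids:
        value = hash_function((parent, block))
        parent = str(value)  # chain on the string that is actually stored and looked up
        hashes.append(parent)
    return hashes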

@@ -562,7 +593,10 @@ def hash_request_tokens(
# Calculate start position (exclude blocks already in HBM)
start_position = num_computed_tokens // self.block_size

block_operations = [BlockOperation.NONE] * len(block_hashes)
# block_operations state machine: blocks already computed in HBM stay NONE; the rest default to DUMP,
# flip to LOAD on a cache hit, and fall back to NONE if block creation fails later
block_operations = ([BlockOperation.NONE] * start_position +
[BlockOperation.DUMP] * (len(block_hashes) - start_position))

remain_hashes = block_hashes[start_position:]
if not remain_hashes:
@@ -580,10 +614,57 @@
else:
# TODO we will fix hole match later
break

num_rag_lookup_hits = 0
if self.enable_blend:
# blocks that missed the prefix-cache lookup may still match a cached RAG chunk
rag_start_blk_idx = start_position + num_lookup_hits
rag_chunks_meta, is_build_cache = self.chunk_processor.process_request(request, md5, rag_start_blk_idx)

if not is_build_cache:
# blend stage: match each RAG chunk against the store
final_rag_chunks_meta = []
old_lookup_results = [False]
old_chunk_meta = None
for chunk_meta in rag_chunks_meta:
lookup_results = self.connector.lookup(chunk_meta.chunk_blks_hash)
chunk_meta.store_hits = lookup_results
final_rag_chunks_meta.append(chunk_meta)
if sum(lookup_results) == 0 and old_lookup_results[-1]:
# the whole current chunk missed while the previous chunk's last block hit:
# re-hash the current chunk as a continuation of the previous one and try to merge them
merge_tokens = request.prompt_token_ids[chunk_meta.start_idx_in_req:chunk_meta.end_idx_in_req]
merge_chunk_blks_hash = hash_token_ids(md5, self.block_size, merge_tokens, old_chunk_meta.chunk_blks_hash[-1])
merge_lookup_results = self.connector.lookup(merge_chunk_blks_hash)
if merge_lookup_results[0]:
# the continuation hashes hit, so fold the current chunk into the previous one
chunk_meta.store_hits = merge_lookup_results
chunk_meta.chunk_blks_hash = merge_chunk_blks_hash
self.chunk_processor.merge_chunks(old_chunk_meta, chunk_meta)
final_rag_chunks_meta.pop()
for i, hit in enumerate(chunk_meta.store_hits):
# replace the original prefix-cache hash with the chunk hash
# TODO: consider also invalidating these block hashes in HBM's block manager; after cache blending,
# the RAG chunk's KV in HBM is recomputed and can attend across all chunks in this request,
# so it is no longer context-independent
block_hashes[rag_start_blk_idx] = chunk_meta.chunk_blks_hash[i]
if hit:
num_rag_lookup_hits += 1
block_operations[rag_start_blk_idx] = BlockOperation.LOAD
else:
# cache blending recomputes this missing block, but the recomputed KV is no longer
# context-independent, so leave it as NONE (recompute, do not dump)
block_operations[rag_start_blk_idx] = BlockOperation.NONE
rag_start_blk_idx += 1
old_chunk_meta = final_rag_chunks_meta[-1]
old_lookup_results = old_chunk_meta.store_hits

if num_rag_lookup_hits:
self.req2rag_load_chunks[request.request_id] = final_rag_chunks_meta

logger.info(
f"num_total_blocks: {len(block_hashes)}, "
f"num_lookup_hits on hbm: {start_position}, "
f"num_lookup_hits on storage except hbm: {num_lookup_hits}"
f"num_lookup_hits on storage except hbm: {num_lookup_hits}, "
f"num_lookup_hits on rag chunk: {num_rag_lookup_hits}"
)

# Load async when Decode instance need to load
@@ -643,17 +724,20 @@ def update_state_after_alloc(
block_operations = request_block_info.block_operations
block_hashes = request_block_info.block_hashes
start_create_pos = start_position + num_external_tokens // self.block_size
remaining_hashes = block_hashes[start_create_pos:]
if remaining_hashes:
create_results = self.connector.create(remaining_hashes)
need_dump_blks = []
need_dump_blks_idx = []
for idx in range(start_create_pos, len(block_hashes)):
# only blocks marked DUMP need storage slots; chunk-cache hits (LOAD) and skipped blocks are not saved
if block_operations[idx] == BlockOperation.DUMP:
need_dump_blks.append(block_hashes[idx])
need_dump_blks_idx.append(idx)
if need_dump_blks:
create_results = self.connector.create(need_dump_blks)
if any(ret != 0 for ret in create_results):
logger.warning(f"\ncreate_results on storage: {create_results}\n")
for j, ret in enumerate(create_results):
idx = start_create_pos + j
block_operations[idx] = (
BlockOperation.DUMP if ret == 0 else BlockOperation.NONE
)
# set start_position to 0, so that we can process from the beginning
for i, ret in enumerate(create_results):
block_operations[need_dump_blks_idx[i]] = BlockOperation.DUMP if ret == 0 else BlockOperation.NONE
# set start_position to 0, so that we can process from the beginning
request_block_info.start_position = 0

def build_connector_meta(
@@ -694,11 +778,16 @@ def build_connector_meta(
vllm_block_ids, block_info
)
if load_blocks or dump_blocks:
chunks_load_meta = self.req2rag_load_chunks.pop(req_id, [])
for chunk_meta in chunks_load_meta:
chunk_meta.vllm_blk_ids = vllm_block_ids[chunk_meta.start_idx_in_req_blks:chunk_meta.end_idx_in_req_blks]

meta.requests.append(
ReqMeta(
request_id=req_id,
load_blocks=load_blocks,
dump_blocks=dump_blocks,
chunks_load_meta=chunks_load_meta,
)
)

41 changes: 37 additions & 4 deletions ucm/sparse/base.py
@@ -23,7 +23,7 @@

import enum
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -117,11 +117,11 @@ def execute_begin(self, scheduler_output: SchedulerOutput):
"""
pass

def execute_finished(self):
def execute_finished(self, logits_indices: torch.Tensor) -> torch.Tensor:
"""
This is called at the end of "ModelRunner->execute_model" function.
"""
pass
return logits_indices

def attention_begin(
self,
@@ -130,8 +130,9 @@ def attention_begin(
value: torch.Tensor,
layer_name: str,
forward_context: ForwardContext,
output: Optional[torch.Tensor] = None,
phase: Optional[str] = None,
) -> None:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
This is called at the beginning of "unified_attention".
Sparse attention algorithm can modify forward_context.attn_metadata if necessary.
@@ -154,6 +155,38 @@ def attention_finished(
"""
pass

def attention_end(
self,
attn_output: torch.Tensor,
layer_name: str
) -> torch.Tensor:
"""
This is called at the end of the attention forward pass.
For Blend, the attention outputs of prefill tokens already served from the cache are dropped ("sparsed") here.
"""
return attn_output

def self_attention_finished(
self,
residual: torch.Tensor,
hidden_states: torch.Tensor
) -> torch.Tensor:
"""
This is called at the end of the self-attention forward pass of each decoder layer.
"""
return residual

def layer_finished(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor
) -> torch.Tensor:
"""
This is called at the end of each decoder layer's forward pass.
"""
return positions


def request_finished_in_worker(self, request_id: Union[int, str]):
"""
This function releases the resources of finished requests at worker-side.
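A minimal sketch of how a concrete sparse/blend module might override the hooks added above; only the hook signatures come from this diff, while the class names and bodies are illustrative assumptions (UcmSparseBase stands for the base class defined in ucm/sparse/base.py, whose definition is collapsed in this view).

import torch

class BlendSparse(UcmSparseBase):
    def execute_finished(self, logits_indices: torch.Tensor) -> torch.Tensor:
        # blending may change which positions produce logits; return the (possibly remapped) indices
        return logits_indices

    def attention_end(self, attn_output: torch.Tensor, layer_name: str) -> torch.Tensor:
        # drop ("sparse") the attention outputs of prefill tokens that were served from cache
        return attn_output

    def self_attention_finished(self, residual: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
        # adjust the residual stream after each decoder layer's self-attention
        return residual

    def layer_finished(self, positions: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
        # adjust positions (e.g. for re-roped blended tokens) at the end of each decoder layer
        return positions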
File renamed without changes.
Empty file added ucm/sparse/blend/__init__.py
Empty file.