from collections import defaultdict
from typing import Dict, Iterable, List, Optional

from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
                                         KVCacheBlock,
                                         generate_block_hash_extra_keys,
                                         hash_block_tokens,
                                         hash_request_tokens)
from vllm.v1.request import Request

logger = init_logger(__name__)


class KVCacheManager:
    """Manages the KV cache blocks of running requests, with block
    preallocation and optional prefix caching."""

def __init__(
self,
block_size: int,
num_gpu_blocks: int,
max_model_len: int,
sliding_window: Optional[int] = None,
enable_caching: bool = True,
num_preallocate_tokens: int = 64,
) -> None:
self.block_size = block_size
self.num_gpu_blocks = num_gpu_blocks
self.max_model_len = max_model_len
self.max_num_blocks_per_req = cdiv(max_model_len, block_size)
self.sliding_window = sliding_window
self.enable_caching = enable_caching
# NOTE(woosuk): To avoid frequent block allocation, we preallocate some
# blocks for each request. For example, when a request reaches the end
# of its block table, we preallocate N blocks in advance. This way, we
# reduce the overhead of updating free_block_ids and ref_cnts for each
# request every step (at the cost of some memory waste).
# NOTE(woosuk): This is different from the "lookahead" slots since this
# does not guarantee that the request always has N empty blocks. After
# the request gets N empty blocks, it starts to use the blocks without
# further allocation. When it uses up all the N empty blocks, it gets
# N new empty blocks.
self.num_preallocate_tokens = num_preallocate_tokens
self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size)
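        # For example, with the default num_preallocate_tokens=64 and a
        # block_size of 16, cdiv(64, 16) = 4 blocks are preallocated at a
        # time (an illustrative configuration, not a requirement).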
# A Block pool of all kv-cache blocks.
self.block_pool: List[KVCacheBlock] = [
KVCacheBlock(idx) for idx in range(num_gpu_blocks)
]
# Free block queue that constructs and manipulates a doubly linked
# list of free blocks (including eviction candidates when caching is
# enabled).
self.free_block_queue = FreeKVCacheBlockQueue(self.block_pool)
# {block_hash: {block ID: block}}. A cached block is
# a full block with a block hash that can be used for prefix caching.
# The cached block may be used by running requests or in the
# free_block_queue that could potentially be evicted.
# NOTE: We currently don't de-duplicate the blocks in the cache,
# meaning that if a block becomes full and is cached, we don't check
# if there is already an identical block in the cache. This is because
# we want to make sure the allocated block IDs won't change so that
# block tables are append-only.
self.cached_block_hash_to_block: Dict[BlockHashType, Dict[
int, KVCacheBlock]] = defaultdict(dict)
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
# is finished.
        self.req_to_blocks: Dict[str, List[KVCacheBlock]] = {}

    def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]:
        """Get the computed (cached) blocks for the request.

        Note that the computed blocks must be full.

        Args:
            request: The request to get the computed blocks for.

        Returns:
            A list of blocks that are computed for the request.
        """
if not self.enable_caching:
# Prefix caching is disabled.
return []
computed_blocks = []
# The block hashes for the request may already be computed
# if the request was preempted and resumed.
if not request.kv_block_hashes:
request.set_kv_block_hashes(
hash_request_tokens(self.block_size, request))
block_hashes = request.kv_block_hashes
for block_hash in block_hashes:
            # block_hashes is a chain of block hashes. If a block hash is
            # not in cached_block_hash_to_block, the subsequent block hashes
            # are certainly not computed yet.
if cached_block := self._get_cached_block(block_hash):
computed_blocks.append(cached_block)
else:
break
        return computed_blocks

    def append_slots(
        self,
        request: Request,
        num_tokens: int,
    ) -> Optional[List[KVCacheBlock]]:
        """Append slots to the block table of the request.

        We first append slots to already-allocated blocks. If those blocks
        are not enough, we allocate new ones.

        Args:
            request: The request to append slots to.
            num_tokens: The number of tokens to append.

        Returns:
            A list of new blocks if new blocks are allocated, or None
            if new blocks are required but cannot be allocated.
        """
num_required_blocks = cdiv(request.num_computed_tokens + num_tokens,
self.block_size)
req_blocks = self.req_to_blocks[request.request_id]
num_new_blocks = num_required_blocks - len(req_blocks)
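        # Illustrative arithmetic: with block_size=16, a request with 30
        # computed tokens appending 10 more needs cdiv(40, 16) = 3 blocks in
        # total; if 2 blocks are already allocated, one new block is required
        # before preallocation is added below.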
if num_new_blocks > self.free_block_queue.num_free_blocks:
            # The request needs new blocks beyond its preallocated slots,
            # but the pool does not have enough free blocks to provide them.
return None
if num_new_blocks <= 0:
# No new block is needed.
new_blocks = []
else:
# Get new blocks from the free block pool considering
# preallocated blocks.
num_new_blocks = min(
num_new_blocks + self.num_preallocate_blocks,
self.free_block_queue.num_free_blocks,
# Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape
# [..., max_num_blocks_per_req].
# TODO(woosuk): Check and reject requests if
# num_prompt_tokens + max_tokens > max_model_len.
self.max_num_blocks_per_req - len(req_blocks),
)
assert num_new_blocks > 0
new_blocks = self._get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks)
if not self.enable_caching:
return new_blocks
num_computed_full_blocks = (request.num_computed_tokens //
self.block_size)
# NOTE(rickyx): We are assuming the `num_tokens` are actual
# tokens rather than lookahead slots (e.g. for speculative decoding).
# TODO(rickyx): When supporting speculative decoding, we will need to
# differentiate between them so that we can know how many blocks are
# full after appending the actual tokens.
num_full_blocks_after_append = (request.num_computed_tokens +
num_tokens) // self.block_size
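        # Continuing the illustrative example above: 30 // 16 = 1 block was
        # already full, and (30 + 10) // 16 = 2 blocks are full after the
        # append, so block index 1 is newly full and gets cached below.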
assert num_full_blocks_after_append <= len(req_blocks)
new_full_blocks = req_blocks[
num_computed_full_blocks:num_full_blocks_after_append]
if new_full_blocks:
self._cache_full_blocks(
request=request,
blk_start_idx=num_computed_full_blocks,
full_blocks=new_full_blocks,
prev_block=req_blocks[num_computed_full_blocks - 1]
if num_computed_full_blocks >= 1 else None,
)
        return new_blocks

    def allocate_slots(
        self,
        request: Request,
        num_tokens: int,
        computed_blocks: List[KVCacheBlock],
    ) -> Optional[List[KVCacheBlock]]:
        """Allocate slots for a new request.

        Args:
            request: The request to allocate slots for.
            num_tokens: The number of tokens to allocate. Note that this does
                not include the tokens that have already been computed.
            computed_blocks: A list of computed blocks.

        Returns:
            A list of newly allocated blocks, or None if the blocks cannot
            be allocated.
        """
if num_tokens == 0:
raise ValueError(
f"num_tokens must be greater than 0, got {num_tokens}")
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it cannot be counted as a free block
# when allocating this request.
num_evictable_computed_blocks = sum(1 for blk in computed_blocks
if blk.ref_cnt == 0)
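        # Illustrative example: if 3 of the computed blocks have
        # ref_cnt == 0, they still sit in the free queue as eviction
        # candidates, so only num_free_blocks - 3 blocks are truly available
        # for this request.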
num_required_blocks = cdiv(num_tokens, self.block_size)
if (num_required_blocks > self.free_block_queue.num_free_blocks -
num_evictable_computed_blocks):
# Cannot allocate new blocks.
return None
# Touch the computed blocks to make sure they won't be evicted.
if self.enable_caching:
self._touch(computed_blocks)
else:
assert not computed_blocks, (
"Computed blocks should be empty when "
"prefix caching is disabled")
# Determine the number of new blocks to allocate considering
# preallocated blocks.
num_new_blocks = min(
num_required_blocks + self.num_preallocate_blocks,
self.free_block_queue.num_free_blocks,
# Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape
# [..., max_num_blocks_per_req].
# TODO(woosuk): Check and reject requests if
# num_prompt_tokens + max_tokens > max_model_len.
self.max_num_blocks_per_req - len(computed_blocks),
)
assert num_new_blocks > 0
        # Concatenate the computed blocks and the new blocks.
new_blocks = self._get_new_blocks(num_new_blocks)
self.req_to_blocks[request.request_id] = computed_blocks + new_blocks
if not self.enable_caching:
return new_blocks
num_computed_tokens = len(computed_blocks) * self.block_size
num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size
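        # Illustrative arithmetic: with block_size=16 and 2 computed blocks,
        # num_computed_tokens is 32; allocating 20 more tokens yields
        # (32 + 20) // 16 = 3 full blocks, so one newly full block is cached
        # below and 4 tokens remain in a partial block.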
new_full_blocks = self.req_to_blocks[
request.request_id][len(computed_blocks):num_full_blocks]
if new_full_blocks:
self._cache_full_blocks(
request=request,
blk_start_idx=len(computed_blocks),
# The new full blocks are the full blocks that are not computed.
full_blocks=new_full_blocks,
prev_block=computed_blocks[-1] if computed_blocks else None,
)
        return new_blocks

    def free(self, request: Request) -> None:
        """Free the blocks allocated for the request.

        When caching is enabled, we free the blocks in reverse order so that
        the tail blocks are evicted first.

        Args:
            request: The request to free the blocks for.
        """
# Default to [] in case a request is freed (aborted) before alloc.
blocks = self.req_to_blocks.pop(request.request_id, [])
ordered_blocks: Iterable[KVCacheBlock] = blocks
if self.enable_caching:
# Free blocks in reverse order so that the tail blocks are
# freed first.
ordered_blocks = reversed(blocks)
for block in ordered_blocks:
block.decr_ref()
if block.ref_cnt == 0:
                self.free_block_queue.append(block)

    def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]:
        """Get new blocks from the free block pool.

        Note that we do not check the block cache in this function.

        Args:
            num_blocks: The number of blocks to allocate.

        Returns:
            A list of new blocks.
        """
if num_blocks > self.free_block_queue.num_free_blocks:
raise ValueError(
f"Cannot get {num_blocks} free blocks from the pool")
ret: List[KVCacheBlock] = []
idx = 0
while idx < num_blocks:
            # Pop the first free block from the queue.
curr_block = self.free_block_queue.popleft()
assert curr_block.ref_cnt == 0
# If the block is cached, evict it.
if self.enable_caching:
self._evict_cached_block(curr_block)
curr_block.incr_ref()
ret.append(curr_block)
idx += 1
        return ret

    def _evict_cached_block(self, block: KVCacheBlock) -> None:
        """If a block is cached in `cached_block_hash_to_block`, we reset its
        hash metadata and evict it from the cache.

        Args:
            block: The block to evict.
        """
block_hash = block.block_hash
if block_hash and block_hash in self.cached_block_hash_to_block:
block.reset_hash()
del self.cached_block_hash_to_block[block_hash][block.block_id]
if len(self.cached_block_hash_to_block[block_hash]) == 0:
                del self.cached_block_hash_to_block[block_hash]

    def _get_cached_block(self,
                          block_hash: BlockHashType) -> Optional[KVCacheBlock]:
        """Get a cached block by the block hash, or None on a cache miss.

        If there are duplicated blocks, we return the first block in the
        cache.

        Args:
            block_hash: The hash value of the block.

        Returns:
            The cached block if it exists, or None.
        """
if block_hash in self.cached_block_hash_to_block:
first_block_id = list(
self.cached_block_hash_to_block[block_hash].keys())[0]
return self.cached_block_hash_to_block[block_hash][first_block_id]
        return None

    def _touch(self, blocks: List[KVCacheBlock]) -> None:
        """Touching a block increases its reference count by 1, and may
        remove the block from the free queue. This is used when a block is
        hit by another request with the same prefix.

        Args:
            blocks: A list of blocks to touch.
        """
for block in blocks:
# ref_cnt=0 means this block is in the free list (i.e. eviction
# candidate), so remove it.
if block.ref_cnt == 0:
self.free_block_queue.remove(block)
            block.incr_ref()

    def _cache_full_blocks(
        self,
        request: Request,
        blk_start_idx: int,
        full_blocks: List[KVCacheBlock],
        prev_block: Optional[KVCacheBlock],
    ) -> None:
        """Cache a list of full blocks for prefix caching.

        This function takes a list of blocks whose block hash metadata should
        be updated and cached. Given a request, it computes the block hashes
        for the blocks starting from `blk_start_idx` to the end of the
        request's full blocks, updating the metadata for each block and
        caching it in `cached_block_hash_to_block`.

        Args:
            request: The request to cache the blocks for.
            blk_start_idx: The index of the first block in the request's
                blocks to cache.
            full_blocks: The list of blocks to update hash metadata for.
            prev_block: The previous block in the chain.
        """
num_cached_block_hashes = len(request.kv_block_hashes)
# Update the new blocks with the block hashes through the chain.
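        # The chain property (illustrative): each hash folds in the previous
        # block's hash value, i.e.
        #     block_hash_i = hash_block_tokens(prev_hash_value, tokens_i,
        #                                      extra_keys)
        # so a block can only hit the cache if its entire prefix matches.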
prev_block_hash_value = None
if prev_block is not None:
# Previous block must have a block hash because it must be
# a full, cached block.
assert prev_block.block_hash is not None
prev_block_hash_value = prev_block.block_hash.hash_value
for i, blk in enumerate(full_blocks):
blk_idx = blk_start_idx + i
if blk_idx < num_cached_block_hashes:
# The block hash may already be computed in
# "get_computed_blocks" if the tokens are not generated by
# this request (either the prompt tokens or the previously
# generated tokens with preemption). In this case we simply
# reuse the block hash.
block_hash = request.kv_block_hashes[blk_idx]
else:
# Otherwise compute the block hash and cache it in the request
# in case it will be preempted in the future.
start_token_idx = blk_idx * self.block_size
end_token_idx = (blk_idx + 1) * self.block_size
block_tokens = request.all_token_ids[
start_token_idx:end_token_idx]
                assert len(block_tokens) == self.block_size, (
                    f"Expected {self.block_size} tokens, got "
                    f"{len(block_tokens)} at block {blk_idx} for request "
                    f"{request.request_id}({request})")

                # Generate extra keys for multi-modal inputs. Note that since
                # we reach this branch only when the block is completed with
                # generated tokens, we only need to consider the last mm
                # input.
extra_keys, _ = generate_block_hash_extra_keys(
request, start_token_idx, end_token_idx, -1)
# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash_value,
block_tokens, extra_keys)
request.append_kv_block_hashes(block_hash)
            # Update the block's hash metadata and add it to the cache.
blk.block_hash = block_hash
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
prev_block_hash_value = block_hash.hash_value
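

# Typical call flow (an illustrative sketch; `request` stands for a
# scheduler-produced vllm.v1.request.Request whose construction is outside
# this file, and the numbers are placeholders):
#
#     manager = KVCacheManager(block_size=16,
#                              num_gpu_blocks=1024,
#                              max_model_len=2048)
#     computed = manager.get_computed_blocks(request)  # prefix-cache hits
#     new = manager.allocate_slots(request, num_tokens, computed)
#     if new is None:
#         ...  # not enough free blocks; the scheduler must wait or preempt
#     # Each decoding step appends slots for the newly generated token:
#     if manager.append_slots(request, num_tokens=1) is None:
#         ...  # cannot grow the block table; preempt the request
#     # When the request finishes or is aborted:
#     manager.free(request)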