-
Notifications
You must be signed in to change notification settings - Fork 1.9k
[TRTLLM-6371][feat] Restructure C++ KVCacheManager to better handle limited attention layers #7510
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
9cc92f6
28d6dd1
8a06a28
f30fa42
9d8c420
6ceed6b
64f4504
5eaf9d5
800b05d
21643e3
e0499cf
4840a98
8647828
91445c5
76f317e
ca48af9
3399798
78bcf4d
8836822
86ed0f8
42fe938
cef3334
ad2aa8f
cfe0609
02abe39
f025cfb
735fe3c
a678b91
787b681
50a6f3c
a487b55
78c6253
bb421b9
f3e5c13
d427948
d57c3b4
3f696f6
c8177d3
3802864
e3a1921
24dfa89
61f9702
984a4d2
23c28fb
1014a3e
a8ad970
39e81e1
64ef5f3
fd364d9
3f22026
9367bbe
98c0d29
a63d0b6
12da9ae
f2c1d9a
0ee0004
b98a2fd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -207,47 +207,140 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst, | |
| } | ||
| } | ||
|
|
||
| // | ||
| // Note about recording events to wait for cudaMemcpyAsync calls between blocks: | ||
| // The memory copy involves raw memory blocks, which are pointed to by the | ||
| // memory pool block index. When recording events, you must use getMemoryPoolBlockIndex() | ||
| // as the raw memory block identifier. Earlier versions of this code used getBlockId() | ||
| // when recording events; this is wrong. getBlockId() returns the logical block id, | ||
| // which has nothing to do with the raw memory block pointers involved in a cudaMemcpy. | ||
| // | ||
|
|
||
| // | ||
| // Notes about need for synchronization: | ||
| // | ||
| // Earlier versions of this code relied on the decoder implicitly syncing GPU with CPU. | ||
| // This is inherently dangerous: it is not guaranteed that the decoder will always explicitly sync | ||
| // GPU with CPU for every step, and a major design goal of ongoing work is to avoid this. | ||
| // To make the code future proof, we introduce a new method syncWithBufferManager() | ||
| // that ensures that internal copy streams will wait for prefill and decode kernels | ||
| // that have already been scheduled. | ||
| // | ||
| // Earlier versions of this code did not account for all possible cases where a new block copy | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| // needed to wait for a previously scheduled copy to finish. For instance, it is possible | ||
| // that two primary blocks are offloaded to the same secondary block in a single step; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rightfully the secondary block provided by the policy should have the popped block removed from the secondary free block queue. Should the problem be inside eviction policy that this happens? |
||
| // scheduling the second offloading without waiting for the first one to finish leads to | ||
| // a corrupted block after offloading. It is also possible that partial reuse will copy | ||
| // from a block that is currently being onboarded; scheduling the partial copy without | ||
| // waiting for the onboarding to finish will lead to a corrupted block. To handle all | ||
| // possible cases needing synchronization, we record separate events for reads and writes | ||
| // to a block. When a new block copy is scheduled, we wait for all writes to the source | ||
| // block and all reads and writes to the destination block. | ||
| // | ||
| // As before, syncTransfers() must be called after the last call to KVCacheManager::addSequence. | ||
| // Failing to do so will lead to corrupted blocks eventually. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we have a mechanism to fool-proof this? |
||
| // | ||
|
|
||
| void KVCacheTransferManager::onboard(BlockPtr const& offloadBlock, BlockPtr const& block, | ||
| std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode, | ||
| std::string const& directory) | ||
| { | ||
| if (mode != executor::KvCacheTransferMode::DRAM | ||
| && mPendingOffloads.find(offloadBlock->getBlockId()) == mPendingOffloads.end()) | ||
| // Wait for any pending writes before reading from offloadBlock | ||
| auto offloadBlockPendingWriteItr = mPendingWrites.find(offloadBlock->getMemoryPoolBlockIndex()); | ||
| if (offloadBlockPendingWriteItr != mPendingWrites.end()) | ||
| { | ||
| TLLM_LOG_DEBUG("Skipping onboard for block %d because it was never previously offloaded to disk", | ||
| offloadBlock->getBlockId()); | ||
| return; | ||
| mOnboardManager.getStream().wait(offloadBlockPendingWriteItr->second); | ||
| // Don't erase, we are not changing state of offloadBlock | ||
| } | ||
|
|
||
| if (mPendingOffloads.find(offloadBlock->getBlockId()) != mPendingOffloads.end()) | ||
| // Wait for any pending reads before overwriting block | ||
| auto blockPendingReadItr = mPendingReads.find(block->getMemoryPoolBlockIndex()); | ||
| if (blockPendingReadItr != mPendingReads.end()) | ||
| { | ||
| mOnboardManager.getStream().wait(mPendingOffloads[offloadBlock->getBlockId()]); | ||
| mOnboardManager.getStream().wait(blockPendingReadItr->second); | ||
| mPendingReads.erase(blockPendingReadItr); | ||
| } | ||
| // Wait for any pending writes before overwriting block | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it possible to have a pending read depending on the pending write here? |
||
| auto blockPendingWriteItr = mPendingWrites.find(block->getMemoryPoolBlockIndex()); | ||
| if (blockPendingWriteItr != mPendingWrites.end()) | ||
| { | ||
| mOnboardManager.getStream().wait(blockPendingWriteItr->second); | ||
| mPendingWrites.erase(blockPendingWriteItr); | ||
| } | ||
|
|
||
| copyBlock(offloadBlock, block, pools, false, numTokensToCopy, mode, directory); | ||
|
|
||
| // Record new pending read from offloadBlock | ||
| mPendingReads[offloadBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent(); | ||
| mOnboardManager.getStream().record(mPendingReads[offloadBlock->getMemoryPoolBlockIndex()]); | ||
| // Record new pending write to block | ||
| mPendingWrites[block->getMemoryPoolBlockIndex()] = tr::CudaEvent(); | ||
| mOnboardManager.getStream().record(mPendingWrites[block->getMemoryPoolBlockIndex()]); | ||
| } | ||
|
|
||
| void KVCacheTransferManager::offload(BlockPtr const& block, BlockPtr const& offloadBlock, | ||
| std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode, | ||
| std::string const& directory) | ||
| { | ||
| mPendingOffloads[block->getBlockId()] = tr::CudaEvent(); | ||
| // Wait for any pending writes before reading from block | ||
| auto blockPendingWriteItr = mPendingWrites.find(block->getMemoryPoolBlockIndex()); | ||
| if (blockPendingWriteItr != mPendingWrites.end()) | ||
| { | ||
| mOffloadManager.getStream().wait(blockPendingWriteItr->second); | ||
| // Don't erase, we are not changing state of block | ||
| } | ||
| // Wait for any pending reads before overwriting offloadBlock | ||
| auto offloadBlockPendingReadItr = mPendingReads.find(offloadBlock->getMemoryPoolBlockIndex()); | ||
| if (offloadBlockPendingReadItr != mPendingReads.end()) | ||
| { | ||
| mOffloadManager.getStream().wait(offloadBlockPendingReadItr->second); | ||
| mPendingReads.erase(offloadBlockPendingReadItr); | ||
| } | ||
| // Wait for any pending writes before overwriting offloadBlock | ||
| auto offloadBlockPendingWriteItr = mPendingWrites.find(offloadBlock->getMemoryPoolBlockIndex()); | ||
| if (offloadBlockPendingWriteItr != mPendingWrites.end()) | ||
| { | ||
| mOffloadManager.getStream().wait(offloadBlockPendingWriteItr->second); | ||
| mPendingWrites.erase(offloadBlockPendingWriteItr); | ||
| } | ||
|
|
||
| copyBlock(block, offloadBlock, pools, true, numTokensToCopy, mode, directory); | ||
| mOffloadManager.getStream().record(mPendingOffloads[block->getBlockId()]); | ||
|
|
||
| // Record new pending read from block | ||
| mPendingReads[block->getMemoryPoolBlockIndex()] = tr::CudaEvent(); | ||
| mOffloadManager.getStream().record(mPendingReads[block->getMemoryPoolBlockIndex()]); | ||
| // Record new pending write to offloadBlock | ||
| mPendingWrites[offloadBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent(); | ||
| mOffloadManager.getStream().record(mPendingWrites[offloadBlock->getMemoryPoolBlockIndex()]); | ||
| } | ||
|
|
||
| void KVCacheTransferManager::syncWithBufferManager() | ||
| { | ||
| tr::CudaEvent readyForOffloadEvent; | ||
| mBufferManager.getStream().record(readyForOffloadEvent); | ||
| mOffloadManager.getStream().wait(readyForOffloadEvent); | ||
|
|
||
| tr::CudaEvent readyForOnboardEvent; | ||
| mBufferManager.getStream().record(readyForOnboardEvent); | ||
| mOnboardManager.getStream().wait(readyForOnboardEvent); | ||
|
|
||
| // Once we synchronize, clear our list of pending transfers. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. type-o, transfer |
||
| mPendingReads.clear(); | ||
| mPendingWrites.clear(); | ||
| } | ||
|
|
||
| void KVCacheTransferManager::syncTransfers() | ||
| { | ||
| tr::CudaEvent offloadEvent; | ||
| mOffloadManager.getStream().record(offloadEvent); | ||
| mBufferManager.getStream().wait(offloadEvent); | ||
|
|
||
| tr::CudaEvent onboardEvent; | ||
| mOnboardManager.getStream().record(onboardEvent); | ||
|
|
||
| mBufferManager.getStream().wait(offloadEvent); | ||
| mBufferManager.getStream().wait(onboardEvent); | ||
|
|
||
| // Once we synchronize, clear our list of pending transfers. | ||
| mPendingOffloads.clear(); | ||
| mPendingReads.clear(); | ||
| mPendingWrites.clear(); | ||
| } | ||
|
|
||
| } // namespace tensorrt_llm::batch_manager::kv_cache_manager | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Type-o,
introdduce