diff --git a/cpp/include/tensorrt_llm/batch_manager/evictionPolicy.h b/cpp/include/tensorrt_llm/batch_manager/evictionPolicy.h index affa83279b7..6b1b21ac8d6 100644 --- a/cpp/include/tensorrt_llm/batch_manager/evictionPolicy.h +++ b/cpp/include/tensorrt_llm/batch_manager/evictionPolicy.h @@ -54,6 +54,9 @@ class BaseEvictionPolicy virtual void refresh() = 0; virtual bool verifyQueueIntegrity() = 0; + + /// @brief Print block Id of all blocks waiting in free queues + virtual std::string printFreeQueues() const = 0; }; struct ExpiringBlockComparator @@ -91,6 +94,8 @@ class LRUEvictionPolicy : public BaseEvictionPolicy bool verifyQueueIntegrity() override; + std::string printFreeQueues() const override; + private: // Queues of available leaf blocks, split by cache level and priority level std::vector> mFreeQueues; diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 0d8f7aa0e13..54097ed1a2d 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -63,15 +63,21 @@ static constexpr SizeType32 kSecondaryLevel = 1; static constexpr SizeType32 kSWAExtraBlock = 1; class KVCacheBlock; +class KVCachePromptLookupNode; +class KVCachePromptLookup; class BlockManager; class KVCacheManager; class KVCacheTransferManager; +class WindowBlockManager; +class GenerationRequest; using SizeType32 = tensorrt_llm::runtime::SizeType32; using TokenIdType = tensorrt_llm::runtime::TokenIdType; using VecTokens = std::vector; using BeamTokens = std::vector; using BlockPtr = std::shared_ptr; +using LookupNodePtr = std::shared_ptr; +using LookupPtr = std::shared_ptr; using FreeBlocksQueue = std::list; using UniqueToken = tensorrt_llm::runtime::UniqueToken; using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens; @@ -193,6 +199,32 @@ struct BlockKey } return numMatched; } + + //! 
\brief Deep copy, optionally reducing number of tokens + struct BlockKey clone(int newNumberOfTokens = 0) const + { + BlockKey blockKey; + blockKey.usesExtraIds = usesExtraIds; + blockKey.loraTaskId = loraTaskId; + if (newNumberOfTokens > 0) + { + // Reduce token length (for partial matching) + TLLM_CHECK_WITH_INFO(newNumberOfTokens <= static_cast(uniqueTokens.size()), + "newNumberOfTokens = %d must be <= uniqueTokens.size() = %d", newNumberOfTokens, + static_cast(uniqueTokens.size())); + blockKey.uniqueTokens.insert( + blockKey.uniqueTokens.begin(), uniqueTokens.begin(), uniqueTokens.begin() + newNumberOfTokens); + } + else + { + // Copy all tokens + blockKey.uniqueTokens.insert(blockKey.uniqueTokens.begin(), uniqueTokens.begin(), uniqueTokens.end()); + } + blockKey.extraKeys.insert(blockKey.extraKeys.begin(), extraKeys.begin(), extraKeys.end()); + blockKey.cacheSaltID = cacheSaltID; + return blockKey; + // TODO: Add unit test verifying correct copy of both partial and full token Ids. + } }; std::vector buildBlockKeys(std::list& blockedUniqueTokens, LlmRequest const& llmRequest); @@ -210,7 +242,7 @@ struct BlockKeyHasher } }; -using NextBlockMap = std::unordered_map; +using NextNodeMap = std::unordered_map; struct KvCacheStats { @@ -238,6 +270,179 @@ struct KvCacheStats std::size_t allocatedBytes{}; }; +using LookupResult = std::vector>; + +// Vector of LookupResult, one for each BlockKey used during search. +// If no match was found, vector will be empty. +// If an exact match was found, vector will have one item. +// If partial matching is enabled and no exact match was found, +// vector will list all nodes with at least one matching token. +// Partially matching nodes are sorted in descending order of number of matching tokens. +using LookupResults = std::vector; + +// Implement an object that represents a given prompt prefix in search structure. +// The node contains pointers to all reusable state for the prompt prefix. 
+class KVCachePromptLookupNode +{ +public: + explicit KVCachePromptLookupNode(BlockKey const& blockKey, bool isFull, bool isRoot = false); + + void setBlockKey(BlockKey const& blockKey, bool isFull); + + BlockKey getBlockKey() const; + + [[nodiscard]] VecUniqueTokens const& getUniqueTokens() const; + + [[nodiscard]] bool isRoot() const; + + LookupNodePtr const& getPrevNode() const; + + void setPrevNode(LookupNodePtr prevNode); + + [[nodiscard]] NextNodeMap getNextNodes() const; + + void addNextNode(BlockKey const& blockKey, LookupNodePtr block); + + void removeNextNode(BlockKey const& blockKey); + + //! \brief Find block matching blockKey. If allowPartial is true, the returned block may match only a prefix of + //! blockKey. + //! @return tuple of [partialMatch, numMatched, block], partialMatch is true if not all the tokens of the block were + //! matched. + [[nodiscard]] LookupResult findMatchingNodes( + BlockKey const& blockKey, bool enablePartialReuse, bool ignoreNodesWithoutBlocks) const; + + void setBlock(SizeType32 windowSize, BlockPtr block); + + [[nodiscard]] BlockPtr getBlock(SizeType32 windowSize) const; + + [[nodiscard]] bool hasBlocks() const; + + [[nodiscard]] bool isFull() const; + + [[nodiscard]] bool isLeaf() const; + + [[nodiscard]] bool canBeDeleted() const; + + [[nodiscard]] std::vector getFullBlockKey() const; + + void deleteNodeIfPossible(); + + void printSearchTree(std::ostream& os, SizeType32 indent) const; + +private: + // Key of this block in mNextBlocks map in block pointed to by mPrevBlock + BlockKey mBlockKey; + // Flag indicating if block is full + bool mIsFull; + // Flag indicating if this is root node + bool mIsRoot; + // Previous node in search structure + LookupNodePtr mPrevNode; + // Next node(s) in sequence(s) + NextNodeMap mNextNodes; + // Pointers to blocks holding KV state for this prompt prefix + std::unordered_map mBlocks; +}; + +// Print methods used for debugging. 
+std::ostream& operator<<(std::ostream& out, LookupResult const& match); +std::ostream& operator<<(std::ostream& out, LookupResults const& matches); +std::ostream& operator<<(std::ostream& out, std::tuple const& match); +std::ostream& operator<<( + std::ostream& out, std::vector> const& matches); +std::ostream& operator<<(std::ostream& out, + std::unordered_map>> const& matches); + +template +std::string streamPrint(T v) +{ + std::stringstream out; + out << v; + return out.str(); +} + +template +std::string streamPrint(C callable, Args... remainingArgs) +{ + std::stringstream out; + out << callable(out, remainingArgs...); + return out.str(); +} + +// +// Basic building block of KV cache. +// +// Implements a radix tree that is used to search for blocks containing reusable state for a common prefix. +// There is only one instance of the search tree, it is owned by BlockManager. +// Each node in the search tree corresponds to a particular prefix. +// Each node has a map from window size to KVCacheBlock, if reusable state exists for a given window size, +// the map will have a pointer to the block containing it. +// KVCacheBlocks have a pointer back to the node that is referring to them. This makes it easier to remove +// blocks from the search tree and also allows searches like getPrevBlock(). +// +// Notes: +// None of these methods are thread safe. It is up to the caller to ensure thread safety. +// These methods return values for all window sizes with a single pass through the search tree. +// +class KVCachePromptLookup +{ +public: + explicit KVCachePromptLookup(CacheType cacheType, SizeType32 tokensPerBlock); + + //! \brief Return vector of BlockKey for the first inputLength tokens of prompt stored in llmRequest. + //! \details If inputLength < 0, effective input length is total number of tokens + inputLength. If you want to skip + //! the last token, you can simply do inputLength = -1. 
+ [[nodiscard]] std::vector getBlockKeys( + LlmRequest const& llmRequest, SizeType32 inputLength, bool allowPartiallyFilledBlock) const; + + //! \brief Find last valid block for prefix given in blockKey. + //! \details Seems like blockKey contains a full prefix, not just key for an individual block. This is very hacky + //! and will lead to hilarious bugs. \details Since all KV cache manager bugs eventually gets reassigned to me, THIS + //! MUST BE FIXED AT ALL COSTS. \details We should introduce a new data type for full prefix keys (maybe named + //! FullPrefixKey?). \details Alternatively, we can introduce a new field that tracks whether BlockKey object + //! contains a full prefix or just key for one block. + std::unordered_map> findBlocksInReuseTreeByBlockKey( + std::vector const& windowSizes, BlockKey const& blockKey) const; + + //! \brief Find first new context block for each window block manager. + //! \param llmRequest The new request. + //! \param inputLength Number of useful prompt tokens. If zero, length of prompt minus 1 is used. + //! \param allowPartiallyFilledBlock Allow matching of blocks that are not full. + //! \param windowBlockManagers Map of window block managers vs window size. Method will search for a new context + //! block for each window size. \return map of BlockKey vs windowSize. The block key is that of first new context + //! block for that window size. + [[nodiscard]] std::unordered_map findNewContextBlock(LlmRequest const& llmRequest, + SizeType32 inputLength, bool allowPartiallyFilledBlock, std::vector const& windowSizes) const; + + //! \brief Find matching nodes for a given prompt prefix. + //! \details Nodes are created if not found. + //! \param allowPartiallyFilledBlock Allow last block in prompt to have less than tokensPerBlock tokens. + [[nodiscard]] LookupResults lookup( + LlmRequest const& llmRequest, SizeType32 inputLength, bool allowPartiallyFilledBlock); + + //! 
\brief Find matching blocks for a given prompt prefix for all window sizes. + //! \details return map of matching blocks vs window size. Matching blocks is a vector of varying size. + std::unordered_map>> lookupBlocks( + std::map const& windowBlockManagers, LlmRequest const& llmRequest, + SizeType32 inputLength, bool allowPartiallyFilledBlock, bool enablePartialReuse) const; + + //! \brief Print input prompt given by llmRequest. + //! \details Uses some member variables for formatting, hence cannot be made static. + std::string printPrompt(LlmRequest const& llmRequest) const; + + //! \brief Print search tree. + std::string printSearchTree() const; + +private: + // Root of search structure + LookupNodePtr mRoot; + // KV cache type (self or cross) + CacheType mCacheType; + // Number of tokens per one block + SizeType32 mTokensPerBlock; +}; + // Basic building block of a paged KV cache - a single // cache block. This class just holds metadata, no pointers // since it is reused across all layers. 
@@ -248,14 +453,12 @@ class KVCacheBlock static constexpr IdType kCachedBlocksRootId = -1; - explicit KVCacheBlock(IdType blockId, kernels::KVCacheIndex blockIdx); + explicit KVCacheBlock(IdType blockId, kernels::KVCacheIndex blockIdx, SizeType32 windowSize); void startScheduling(); [[nodiscard]] IdType getBlockId() const; - [[nodiscard]] NextBlockMap getNextBlocks() const; - [[nodiscard]] kernels::KVCacheIndex::UnderlyingType getMemoryPoolBlockIndex() const; [[nodiscard]] bool isPrimary() const; @@ -272,40 +475,26 @@ class KVCacheBlock [[nodiscard]] bool hasSchedulingRefs() const; + // This info is duplicated in KVCacheBlock and KVCachePromptLookupNode + // because it is needed by the former when KVCacheBlock might not be stored + // in lookup structure and therefore cannot get this value from there void setBlockKey(BlockKey const& blockKey, bool isFull); - - BlockKey getBlockKey(); - + BlockKey getBlockKey() const; [[nodiscard]] VecUniqueTokens const& getUniqueTokens() const; - BlockPtr const& getPrevBlock() const; - - void setPrevBlock(BlockPtr prevBlock); - BlockPtr const& getPrevBlockInSeq() const; - void setPrevBlockInSeq(BlockPtr prevBlock); - void addNextBlock(BlockKey const& blockKey, BlockPtr block); - - void removeNextBlock(BlockKey const& blockKey); - - //! \brief Find block matching blockKey. If allowPartial is true, the returned block may match only a prefix of - //! blockKey. - //! @return tuple of [partialMatch, numMatched, block], partialMatch is true if not all the tokens of the block were - //! matched. - [[nodiscard]] std::tuple findMatchingBlock( - BlockKey const& blockKey, bool enablePartialReuse, bool copyOnPartialReuse) const; - - //! \brief Free block from previous block if present. - void freeLeafBlock(); + //! \brief Return previous block in search tree if this block is stored in search tree. + //! \details Returns nullptr if this block is not in search tree or it is at root level. + //! 
\details Previously this function would have returned pointer to a special 'root' block if already at root + //! level. + BlockPtr getPrevBlock() const; [[nodiscard]] bool isFull() const; [[nodiscard]] bool isShared() const; - [[nodiscard]] bool isLeaf() const; - void setPriority(executor::RetentionPriority priority); [[nodiscard]] executor::RetentionPriority getPriority() const; @@ -325,6 +514,21 @@ class KVCacheBlock size_t getHash() const; + // attach to lookup node (register block for reuse) + void attachToLookupNode(LookupNodePtr lookupNode, BlockPtr block); + + // detach from lookup node (unregister block for reuse) + void detachFromLookupNode(); + + // get lookup node using this block. Can be nullptr + [[nodiscard]] LookupNodePtr getLookupNode() const; + + //! \brief Check if block is still valid for reuse. + [[nodiscard]] bool isValidForReuse() const + { + return mLookupNode != nullptr; + } + private: // Linear ID of block independent of pool IdType mBlockId; @@ -342,15 +546,9 @@ class KVCacheBlock // Key of this block in mNextBlocks map in block pointed to by mPrevBlock BlockKey mBlockKey; - // Previous block in reuse tree, or nullptr if not reusing - BlockPtr mPrevBlock; - // Previous block in sequence, == nullptr for first block, == mPrevBlock if reusing and not first BlockPtr mPrevBlockInSeq; - // Next block(s) in sequence(s) - NextBlockMap mNextBlocks; - // Iterator pointing to this block in mFreeBlocks. std::optional mFreeBlockIterator; @@ -365,6 +563,11 @@ class KVCacheBlock std::optional mExpirationTime; // Hash for the event manager size_t mHash; + + // Pointer to search tree lookup node using this block + LookupNodePtr mLookupNode; + // Window size using this block (0 if not in use) + SizeType32 mWindowSize; }; class GenerationRequest @@ -604,8 +807,9 @@ class WindowBlockManager void startScheduling(); //! \brief Assign blocks for new sequence. Try to reuse blocks. 
- void addSequence( - GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest); + void addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, + LlmRequest& llmRequest, + std::vector> const& matchedBlocks); //! \brief Assign blocks for new sequence. Does not try to reuse blocks. void addSequence(GenerationRequest& sequence, SizeType32 numContextBlocks, bool isShareLastContextBlock); @@ -625,9 +829,7 @@ class WindowBlockManager void pinBlocks(GenerationRequest& sequence); //! \brief Release blocks of the sequence. - //! \details When llmRequest is provided and reuse is enabled, blocks will be stored. - std::optional releaseBlocks( - GenerationRequest& sequence, OptionalRef llmRequest); + std::optional releaseBlocks(GenerationRequest& sequence); //! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId); @@ -799,16 +1001,14 @@ class WindowBlockManager void offloadBlock(BlockPtr const& block, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = ""); - //! \brief Find first new block that must be allocated for context phase and return it's concatenated token vectors. - //! \details Only full blocks are considered. - [[nodiscard]] std::optional findNewContextBlock( - VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const; - [[nodiscard]] runtime::BufferManager const& getBufferManager() const { return mBufferManager; } + //! \brief Sync internal streams used by transfer manager with buffer manager stream + void syncTransferManagerWithBufferManager(); + //! \brief Perform per-request bookkeeping void refreshBlocks(); @@ -820,8 +1020,7 @@ class WindowBlockManager //! \param pinBlocks If true, increment ref count for blocks while storing (pin on store). //! 
\return Pair of (num blocks stored for reuse, id of the last block stored if any). [[nodiscard]] std::pair> storeBlocks( - std::vector const& blockKeys, std::vector const& blockIds, - bool pinBlocks = false); + LookupResults const& lookupNodes, std::vector const& blockIds, bool pinBlocks = false); [[nodiscard]] bool verifyQueueIntegrity(); @@ -849,27 +1048,10 @@ class WindowBlockManager return mIsSWA; } - [[nodiscard]] std::shared_ptr findBlocksInReuseTreeByBlockKey(BlockKey const& blockKey); - //! \brief Unpin blocks by starting from a block id and walking prev pointers. void unpinBlocksById(KVCacheBlock::IdType blockId); - void initializeSequenceStorageValidity(LlmRequest::RequestIdType requestId) - { - mIsValidStoreForReuseSequence[requestId] = true; - } - - void releaseSequenceStorageValidity(LlmRequest::RequestIdType requestId) - { - mIsValidStoreForReuseSequence.erase(requestId); - } - - //! \brief Return whether this sequence is valid for store for reuse - [[nodiscard]] bool isSequenceValidForStoreForReuse(LlmRequest::RequestIdType requestId) const - { - TLLM_CHECK_WITH_INFO(mIsValidStoreForReuseSequence.count(requestId) > 0, "Sequence should be bookkeeped"); - return mIsValidStoreForReuseSequence.at(requestId); - } + [[nodiscard]] std::string printFreeQueues() const; private: //! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq. @@ -882,23 +1064,19 @@ class WindowBlockManager //! \param blockKeys Key of each block. //! \param sequence Sequence to which blocks are assigned. //! \return Number of matched tokens from loaded blocks. 
- SizeType32 loadOrAllocateBlocks(std::vector const& blockKeys, SizeType32 numContextBlocks, - GenerationRequest& sequence, std::vector const& perBlockRetentions, + SizeType32 loadOrAllocateBlocks( + std::vector> const& matchedBlocks, + SizeType32 numContextBlocks, GenerationRequest& sequence, + std::vector const& perBlockRetentions, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = ""); - //! \brief Free block and all it's descendants. This makes block a claimed leaf block. - void freeChildren(BlockPtr const& block); - //! \brief Find block least likely to be reused, free it if necessary and return. //! \param sequence Sequence which the free block is allocated for - [[nodiscard]] BlockPtr getFreeBlock(GenerationRequest& sequence, + [[nodiscard]] BlockPtr getFreeBlock( executor::RetentionPriority = executor::KvCacheRetentionConfig::kDefaultRetentionPriority, std::optional durationMs = std::nullopt, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = ""); - //! \brief Calls KVCacheBlock::freeLeafBlock to remove block from search tree. - void freeLeafBlock(BlockPtr const& block); - //! \brief For FP4 quantization. Creates pool objects for FP4 block scalars. void createBlockScalePools(SizeType32 blockSize); @@ -935,8 +1113,6 @@ class WindowBlockManager bool mIsSWA; // List of all blocks by idx std::vector mAllBlocksById; - // Dummy block acting as root for BlockToken searches - BlockPtr mCachedBlocksRoot; // KV cache type (self or cross) CacheType mCacheType; // Eviction Policy @@ -975,17 +1151,6 @@ class WindowBlockManager // The kv cache connector manager std::shared_ptr mKvCacheConnectorManager; - // Mutex for the cached blocks root - std::mutex mCachedBlocksRootMutex; - - // Record which sequence is using the block - std::map mBlockToSequence; - // Record whether a sequence has all blocks held valid. 
- // The boolean value is set to true upon first encounter of a new sequence. - // It may be invalidated to false when other sequence acquires a block that - // is used by another sequence. - std::map mIsValidStoreForReuseSequence; - // Whether to enable indexer K cache bool mEnableIndexerKCache; // Quant block size for indexer K cache @@ -1077,10 +1242,10 @@ class BlockManager executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = ""); [[nodiscard]] std::pair> storeBlocks( - std::vector const& blockKeys, std::vector const& blockIds, - SizeType32 windowSize, bool pinBlocks = false) + LookupResults const& lookupNodes, std::vector const& blockIds, SizeType32 windowSize, + bool pinBlocks = false) { - return mWindowBlockManagers.at(windowSize).storeBlocks(blockKeys, blockIds, pinBlocks); + return mWindowBlockManagers.at(windowSize).storeBlocks(lookupNodes, blockIds, pinBlocks); } [[nodiscard]] bool verifyQueueIntegrity(SizeType32 windowSize); @@ -1255,10 +1420,7 @@ class BlockManager } [[nodiscard]] std::shared_ptr findBlocksInReuseTreeByBlockKey( - BlockKey const& blockKey, SizeType32 windowSize) - { - return mWindowBlockManagers.at(windowSize).findBlocksInReuseTreeByBlockKey(blockKey); - } + BlockKey const& blockKey, SizeType32 windowSize) const; [[nodiscard]] SizeType32 getNumPrimaryBlocks() const { @@ -1276,6 +1438,9 @@ class BlockManager //! \brief Store newest block for reuse void storeNewBlock(GenerationRequest& sequence, OptionalRef llmRequest); + //! \brief Sync internal streams used by transfer manager with buffer manager stream + void syncTransferManagerWithBufferManager(); + //! \brief Perform per-request bookkeeping void refreshBlocks(); @@ -1305,46 +1470,17 @@ class BlockManager //! context block that goes OOW. void adjustBlocksIfNeeded(GenerationRequest& sequence); - //! 
\brief Return whether the sequence is already managed by the block manager - [[nodiscard]] bool isSequenceHeld(LlmRequest::RequestIdType requestId) const + //! \brief Print free queues maintained by eviction policy + //! \details This method is meant for debugging + [[nodiscard]] std::string printFreeQueues(SizeType32 windowSize) const { - return mManagedSequences.count(requestId) > 0; + return mWindowBlockManagers.at(windowSize).printFreeQueues(); } - //! \brief Add a sequence to the managed sequences - //! \details Take the sequence into account for the manager. Initialize - //! sequence storage validity under all window sizes. - void holdSequence(LlmRequest::RequestIdType requestId) + //! \brief Print search tree. + [[nodiscard]] std::string printSearchTree() const { - mManagedSequences.insert(requestId); - for (auto const& [windowSize, metadata] : mWindowSizeToMetadata) - { - mWindowBlockManagers.at(windowSize).initializeSequenceStorageValidity(requestId); - } - } - - //! \brief Remove a sequence from the managed sequences. - //! \details Remove sequence from the managed sequences and remove sequence - //! storage - void releaseSequence(LlmRequest::RequestIdType requestId) - { - mManagedSequences.erase(requestId); - for (auto const& [windowSize, metadata] : mWindowSizeToMetadata) - { - mWindowBlockManagers.at(windowSize).releaseSequenceStorageValidity(requestId); - } - } - - //! \brief Return whether the sequence is still valid for store-for-reuse - //! regarding the specific window size. - //! \details Currently this utility function is only used under - //! kvCacheManagerTest.cpp. Checking for store-for-reuse for each window - //! size is done in an iterating fashion under BlockManager::releaseBlocks. 
- bool isSequenceValidForStoreForReuse(LlmRequest::RequestIdType requestId, SizeType32 windowSize) const - { - TLLM_CHECK_WITH_INFO( - mWindowBlockManagers.count(windowSize) > 0, "Querying window size is not found under mWindowBlockManager"); - return mWindowBlockManagers.at(windowSize).isSequenceValidForStoreForReuse(requestId); + return mLookup->printSearchTree(); } private: @@ -1365,6 +1501,9 @@ class BlockManager return getWindowSizeMetadata(windowSize).absolutePoolsOffset; } + std::optional notThreadSafeStoreBlocksForReuse( + GenerationRequest& sequence, OptionalRef llmRequest, bool pinBlocks); + private: SizeType32 mNumLayers; SizeType32 mTokensPerBlock; @@ -1381,8 +1520,11 @@ class BlockManager std::vector mLayerToWindowSize; std::vector mAbsolutePoolToWindowSize; std::vector mAbsolutePoolToRelativePoolIndex; - // Record what sequences are currently managed by the block manager - std::set mManagedSequences; + + bool mEnablePartialReuse; + // Mutex for the cached blocks root + mutable std::mutex mCachedBlocksRootMutex; + LookupPtr mLookup; }; struct OffsetTableDimensions @@ -1531,6 +1673,7 @@ class BaseKVCacheManager [[nodiscard]] virtual runtime::ITensor::SharedPtr getIndexerKCachePool() const = 0; [[nodiscard]] virtual SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const = 0; + virtual void syncTransferManagerWithBufferManager() = 0; virtual void refreshBlocks() = 0; virtual void flushIterationEvents() = 0; @@ -1896,6 +2039,11 @@ class KVCacheManager : public BaseKVCacheManager return mBlockManager.getPoolLayerIdx(layer_idx); } + void syncTransferManagerWithBufferManager() override + { + mBlockManager.syncTransferManagerWithBufferManager(); + } + //! 
\brief Perform per-iteration bookkeeping void refreshBlocks() override { diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h index 45f615cafe7..00540dc671e 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h @@ -46,7 +46,15 @@ class KVCacheTransferManager int numTokensToCopy = 0, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = ""); - //! \brief Synchronize the offload/onboard streams with the bufferManager stream. + //! \brief Synchronize internal streams with bufferManager stream. + //! \details The buffer manager uses the same stream as the prefill and decode kernels. This method ensures that the + //! internal kernels used for offloading and onboarding will wait for prefill and decode kernels before performing + //! any block copies. This method must be called before the first call to KVCacheManager::addSequence in every step. + void syncWithBufferManager(); + + //! \brief Synchronize bufferManager stream with internal streams. This method ensures that prefill and decode + //! kernels for next step will wait for offloading and onboarding work that has already been scheduled. This method + //! must be called after last call to KVCacheManager::addSequence in every step. void syncTransfers(); private: @@ -75,8 +83,10 @@ class KVCacheTransferManager runtime::BufferManager mOnboardManager; runtime::BufferManager mOffloadManager; - // Track the block ids offloaded in this iteration. - std::unordered_map mPendingOffloads; + // Track reads and writes for blocks. Note that it is the memory pool index that + // identifies the raw memory blocks involved in I/O, not the block Id. 
+ std::unordered_map mPendingReads; + std::unordered_map mPendingWrites; // Reference to parent loopback agent std::shared_ptr mLoopbackAgent; int mDeviceId; diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h index 740badd6370..5adaa1c8351 100644 --- a/cpp/include/tensorrt_llm/executor/executor.h +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -547,6 +547,32 @@ struct RetentionPriorityAndDuration std::optional retentionPriority; std::optional durationMs; + + std::string print() + { + std::stringstream out; + out << "RetentionPriorityAndDuration{"; + bool needComma = false; + if (retentionPriority.has_value()) + { + out << "retentionPriority=" << retentionPriority.value(); + needComma = true; + } + if (durationMs.has_value()) + { + if (needComma) + { + out << ","; + } + else + { + needComma = true; + } + out << "durationMs=" << durationMs.value().count(); + } + out << "}"; + return out.str(); + } }; /// @brief Configuration for the request's retention in the KV Cache @@ -580,6 +606,24 @@ class KvCacheRetentionConfig /// have no expiration time, and keep the block at the given priority level until it gets reclaimed. After the /// duration has passed, the block will be moved back to the `kDefaultRetentionPriority` level. 
std::optional durationMs; + + std::string print() const + { + std::stringstream out; + out << "TokenRangeRetentionConfig={"; + out << "tokenStart=" << tokenStart; + if (tokenEnd.has_value()) + { + out << ",tokenEnd=" << tokenEnd.value(); + } + out << ",priority=" << priority; + if (durationMs.has_value()) + { + out << ",durationMs=" << durationMs.value().count(); + } + out << "}"; + return out.str(); + } }; explicit KvCacheRetentionConfig() @@ -611,6 +655,27 @@ class KvCacheRetentionConfig && mDirectory == other.mDirectory; } + std::string print() const + { + std::stringstream out; + out << "KvCacheRetentionConfig={"; + bool firstIteration = true; + for (auto trrc : mTokenRangeRetentionConfigs) + { + if (firstIteration) + { + firstIteration = false; + } + else + { + out << ","; + } + out << trrc.print(); + } + out << "}"; + return out.str(); + } + private: /// @brief The token ranges and priority levels to update. Ranges must be non-overlapping. For example [(0, 64), /// (100, 128), (70, 80)] is valid, whereas diff --git a/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp b/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp index c0482deb554..211abe78186 100644 --- a/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp +++ b/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp @@ -26,6 +26,8 @@ void tensorrt_llm::batch_manager::AllocateKvCache::operator()(BaseKVCacheManager TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); NVTX3_SCOPED_RANGE(allocateKvCache); + kvCacheManager.syncTransferManagerWithBufferManager(); + for (auto const& llmReq : contextRequests) { if (llmReq->isFirstContextChunk()) diff --git a/cpp/tensorrt_llm/batch_manager/evictionPolicy.cpp b/cpp/tensorrt_llm/batch_manager/evictionPolicy.cpp index 45f6522a509..fef2a65498a 100644 --- a/cpp/tensorrt_llm/batch_manager/evictionPolicy.cpp +++ b/cpp/tensorrt_llm/batch_manager/evictionPolicy.cpp @@ -103,6 +103,35 @@ bool LRUEvictionPolicy::verifyQueueIntegrity() return !queueCompromised; } +std::string 
LRUEvictionPolicy::printFreeQueues() const +{ + std::stringstream os; + os << "Free queues:" << std::endl; + for (SizeType32 cacheLevel = 0; cacheLevel < 2; cacheLevel++) + { + switch (cacheLevel) + { + case 0: os << " Primary" << std::endl; break; + case 1: os << " Secondary" << std::endl; break; + default: throw "Unknown cacheLevel"; + } + for (SizeType32 level = 0; level < kMaxPriority - kMinPriority + 1; level++) + { + auto const& blocks = mFreeQueues[cacheLevel][level]; + if (!blocks.empty()) + { + os << " " << level << std::setw(3) << " : "; + for (auto const& block : blocks) + { + os << block->getBlockId() << " "; + } + os << std::endl; + } + } + } + return os.str(); +} + std::tuple LRUEvictionPolicy::getFreeBlock(SizeType32 cacheLevel) { for (SizeType32 level = 0; level < kMaxPriority - kMinPriority + 1; level++) @@ -132,8 +161,11 @@ void LRUEvictionPolicy::releaseBlock(BlockPtr block, bool toFront) SizeType32 const cacheLevel = getCacheLevel(block); SizeType32 const id = block->getBlockId(); - // If there are no children, this is a leaf block. Insert into a queue. 
- auto& q = mFreeQueues[cacheLevel][getPriorityIdx(block->getPriority())]; + auto priority = block->getPriority(); + auto priorityIdx = getPriorityIdx(priority); + TLLM_LOG_DEBUG("%s;%d - LRUEvictionPolicy::releaseBlock :: blockId=%d, cacheLevel=%d, priority=%d, priorityIdx=%d", + __FILE__, __LINE__, id, cacheLevel, priority, priorityIdx); + auto& q = mFreeQueues[cacheLevel][priorityIdx]; if (toFront) { mFreeBlockIterators[id] = q.insert(q.begin(), block); @@ -144,6 +176,8 @@ void LRUEvictionPolicy::releaseBlock(BlockPtr block, bool toFront) } mNumFreeBlocksPerLevel[cacheLevel]++; + TLLM_LOG_DEBUG("Increased mNumFreeBlocksPerLevel[%d] from %d to %d", cacheLevel, + mNumFreeBlocksPerLevel[cacheLevel] - 1, mNumFreeBlocksPerLevel[cacheLevel]); if (block->getDurationMs().has_value() && block->getPriority() != executor::KvCacheRetentionConfig::kDefaultRetentionPriority) @@ -169,22 +203,32 @@ void LRUEvictionPolicy::claimBlock(BlockPtr block, std::optionalgetBlockId(); SizeType32 const cacheLevel = getCacheLevel(block); + auto priorityIdx = getPriorityIdx(getPriorityIdx(block->getPriority())); + TLLM_LOG_DEBUG("%s;%d - LRUEvictionPolicy::claimBlock :: blockId=%d, cacheLevel=%d, priority=%d, priorityIdx=%d", + __FILE__, __LINE__, id, cacheLevel, priority, priorityIdx); + // Detach block from free queue if (mFreeBlockIterators[id] != std::nullopt) { - mFreeQueues[cacheLevel][getPriorityIdx(block->getPriority())].erase(*mFreeBlockIterators[id]); + mFreeQueues[cacheLevel][priorityIdx].erase(*mFreeBlockIterators[id]); mNumFreeBlocksPerLevel[cacheLevel] -= 1; + TLLM_LOG_DEBUG("Decreased mNumFreeBlocksPerLevel[%d] from %d to %d", cacheLevel, + mNumFreeBlocksPerLevel[cacheLevel] + 1, mNumFreeBlocksPerLevel[cacheLevel]); + mFreeBlockIterators[id] = std::nullopt; } - mFreeBlockIterators[id] = std::nullopt; - + // Explicitly set priority, if provided if (priority.has_value()) { block->setPriority(*priority); } + // Detach block from expiring heap (processing of time limited 
retention priority) mExpiringBlockHeap.erase(block); - block->setDurationMs(durationMs); + if (durationMs.has_value()) + { + block->setDurationMs(durationMs); + } } std::chrono::steady_clock::time_point::duration LRUEvictionPolicy::getTime() const diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index b0b7b494fa6..8168251ca10 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -36,6 +36,8 @@ #include #include #include +#include +#include #include namespace tc = tensorrt_llm::common; @@ -249,18 +251,68 @@ size_t BlockKeyHasher::hash(BlockKey const& blockKey, std::size_t parentHash) no return seed; } -KVCacheBlock::KVCacheBlock(IdType blockId, tk::KVCacheIndex blockIdx) +//! \brief Print blockKey to output stream. Intended for debugging. +std::ostream& operator<<(std::ostream& out, BlockKey const& blockKey) +{ + // Note: << operator requires friend declaration if it will print any protected or private members. + // BlockKey is a struct, not a class, hence it has only public members. + bool firstIteration = true; + for (auto uniqueToken : blockKey.uniqueTokens) + { + if (firstIteration) + { + firstIteration = false; + out << "["; + } + else + { + out << " "; + } + if (blockKey.usesExtraIds) + { + out << "(" << uniqueToken.tokenId << "," << uniqueToken.tokenExtraId << ")"; + } + else + { + out << uniqueToken.tokenId; + } + } + out << "]"; + return out; +} + +//! \brief Print vector of BlockKey to output stream. Intended for debugging. 
+std::ostream& operator<<(std::ostream& out, std::vector const& blockKeys) +{ + bool firstIteration = true; + for (auto const& blockKey : blockKeys) + { + if (firstIteration) + { + firstIteration = false; + } + else + { + out << ", "; + } + out << blockKey; + } + return out; +} + +KVCacheBlock::KVCacheBlock(IdType blockId, tk::KVCacheIndex blockIdx, SizeType32 windowSize) : mBlockId(blockId) , mMemoryPoolBlockIndex{blockIdx} , mRefCount(0) , mSchedulingRefCount(0) - , mPrevBlock(nullptr) , mFreeBlockIterator(std::nullopt) , mIsFull{false} , mPriority{executor::KvCacheRetentionConfig::kDefaultRetentionPriority} , mDurationMs{std::nullopt} , mExpirationTime{std::nullopt} , mHash{0} + , mLookupNode{nullptr} + , mWindowSize{windowSize} { } @@ -274,11 +326,6 @@ KVCacheBlock::IdType KVCacheBlock::getBlockId() const return mBlockId; } -NextBlockMap KVCacheBlock::getNextBlocks() const -{ - return mNextBlocks; -} - tk::KVCacheIndex::UnderlyingType KVCacheBlock::getMemoryPoolBlockIndex() const { return mMemoryPoolBlockIndex.get(); @@ -320,7 +367,7 @@ bool KVCacheBlock::hasRefs() const bool KVCacheBlock::isShared() const { // block is considered shared if ready for reuse - return mRefCount > 1 || mPrevBlock != nullptr; + return mRefCount > 1 || mLookupNode != nullptr; } bool KVCacheBlock::hasSchedulingRefs() const @@ -334,7 +381,7 @@ void KVCacheBlock::setBlockKey(BlockKey const& blockKey, bool isFull) mIsFull = isFull; } -BlockKey KVCacheBlock::getBlockKey() +BlockKey KVCacheBlock::getBlockKey() const { return mBlockKey; } @@ -369,117 +416,734 @@ std::optional KVCacheBlock::get return mExpirationTime; } -void KVCacheBlock::setHash(size_t hash) +void KVCacheBlock::setHash(size_t hash) +{ + mHash = hash; +} + +void KVCacheBlock::setHash() +{ + mHash = BlockKeyHasher()(mBlockKey, mPrevBlockInSeq ? 
mPrevBlockInSeq->getHash() : 0); +} + +size_t KVCacheBlock::getHash() const +{ + return mHash; +} + +VecUniqueTokens const& KVCacheBlock::getUniqueTokens() const +{ + return mBlockKey.uniqueTokens; +} + +BlockPtr KVCacheBlock::getPrevBlock() const +{ + if (mLookupNode != nullptr && mLookupNode->getPrevNode() != nullptr) + { + return mLookupNode->getPrevNode()->getBlock(mWindowSize); + } + return nullptr; +} + +BlockPtr const& KVCacheBlock::getPrevBlockInSeq() const +{ + return mPrevBlockInSeq; +} + +void KVCacheBlock::setPrevBlockInSeq(BlockPtr prevBlock) +{ + mPrevBlockInSeq = std::move(prevBlock); +} + +bool KVCacheBlock::isFull() const +{ + return mIsFull; +} + +void KVCacheBlock::attachToLookupNode(LookupNodePtr lookupNode, BlockPtr block) +{ + TLLM_CHECK_WITH_INFO(lookupNode != nullptr && block != nullptr, "lookupNode and block arguments cannot be nullptr"); + TLLM_CHECK_WITH_INFO(getBlockId() == block->getBlockId(), "blockIds differ"); + mLookupNode = lookupNode; + lookupNode->setBlock(mWindowSize, block); + mBlockKey = lookupNode->getBlockKey(); + mIsFull = lookupNode->isFull(); +} + +void KVCacheBlock::detachFromLookupNode() +{ + if (mLookupNode != nullptr) + { + mLookupNode->setBlock(mWindowSize, nullptr); + mLookupNode->deleteNodeIfPossible(); + } + mLookupNode = nullptr; + mBlockKey = BlockKey(); + mIsFull = false; +} + +LookupNodePtr KVCacheBlock::getLookupNode() const +{ + return mLookupNode; +} + +KVCachePromptLookup::KVCachePromptLookup(CacheType cacheType, SizeType32 tokensPerBlock) + : mRoot(std::make_shared(BlockKey(), false, true)) + , mCacheType(cacheType) + , mTokensPerBlock(tokensPerBlock) +{ +} + +std::string KVCachePromptLookup::printPrompt(LlmRequest const& llmRequest) const +{ + std::stringstream out; + auto constexpr beamIdx = 0; + auto const& uniqueTokens = (mCacheType == CacheType::kSELF || mCacheType == CacheType::kSELFKONLY) + ? 
llmRequest.getUniqueTokens(beamIdx) + : *(llmRequest.getEncoderUniqueTokens().value()); + bool firstIteration = true; + for (auto token : uniqueTokens) + { + if (firstIteration) + { + firstIteration = false; + } + else + { + out << " "; + } + out << token.tokenId; + } + return out.str(); +} + +void KVCachePromptLookupNode::printSearchTree(std::ostream& os, SizeType32 indent) const +{ + std::stringstream ss; + for (int i = 0; i < indent; ++i) + { + ss << " "; + } + auto indentStr = ss.str(); + os << getBlockKey() << "(" << getNextNodes().size() << " children)" << std::endl; + for (auto const& [key, child] : getNextNodes()) + { + os << indentStr << "+ "; + child->printSearchTree(os, indent + 2); + } +} + +std::string KVCachePromptLookup::printSearchTree() const +{ + std::stringstream os; + os << "KVCachePromptLookup cache:" << std::endl; + for (auto const& [key, child] : mRoot->getNextNodes()) + { + child->printSearchTree(os, 0); + } + return os.str(); +} + +std::ostream& operator<<(std::ostream& out, LookupResult const& match) +{ + bool firstIteration = true; + for (auto const& nodeInfo : match) + { + auto const [partialMatch, nuMatched, matchedNode] = nodeInfo; + if (firstIteration) + { + firstIteration = false; + } + else + { + out << "|"; + } + if (matchedNode != nullptr) + { + out << matchedNode->getBlockKey(); + } + else + { + out << "nil"; + } + } + return out; +} + +std::ostream& operator<<(std::ostream& out, LookupResults const& matches) +{ + bool firstIteration = true; + for (auto const& match : matches) + { + if (firstIteration) + { + firstIteration = false; + } + else + { + out << ", "; + } + out << "[" << match << "]"; + } + return out; +} + +std::ostream& operator<<(std::ostream& out, std::tuple const& match) +{ + [[maybe_unused]] auto const [partialMatch, nuMatched, matchedNode, matchedBlock] = match; + if (matchedNode != nullptr) + { + out << matchedNode->getBlockKey(); + } + else + { + out << "nil"; + } + return out; +} + +std::ostream& operator<<( + 
std::ostream& out, std::vector> const& matches) +{ + bool firstIteration = true; + for (auto const& match : matches) + { + if (firstIteration) + { + firstIteration = false; + } + else + { + out << ", "; + } + out << "[" << match << "]"; + } + return out; +} + +std::ostream& operator<<(std::ostream& out, + std::unordered_map>> const& matches) +{ + for (auto [windowSize, singleWindowMatches] : matches) + { + out << "windowSize=" << windowSize << " :: " << singleWindowMatches << std::endl; + } + return out; +} + +std::vector KVCachePromptLookup::getBlockKeys( + LlmRequest const& llmRequest, SizeType32 inputLength, bool allowPartiallyFilledBlock) const +{ + auto constexpr beamIdx = 0; + auto const& uniqueTokens = (mCacheType == CacheType::kSELF || mCacheType == CacheType::kSELFKONLY) + ? llmRequest.getUniqueTokens(beamIdx) + : *(llmRequest.getEncoderUniqueTokens().value()); + auto usefulInputLength + = (inputLength < 0) ? static_cast(uniqueTokens.size()) + inputLength : inputLength; + usefulInputLength = std::max(0, usefulInputLength); + + // Ignore last token because it can't be recovered + auto blockedUniqueTokens = chopVectorIntoBlocks( + uniqueTokens, usefulInputLength, mTokensPerBlock, allowPartiallyFilledBlock); + auto blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest); + + return blockKeys; +} + +//! \brief Find last valid block for prefix given in blockKey. +//! \details Seems like blockKey contains a full prefix, not just key for an individual block. This is very hacky and +//! will lead to hilarious bugs. \details Since all KV cache manager bugs eventually gets reassigned to me, THIS MUST BE +//! FIXED AT ALL COSTS. \details We should introduce a new data type for full prefix keys (maybe named FullPrefixKey?). +//! \details Alternatively, we can introduce a new field that tracks whether BlockKey object contains a full prefix or +//! just key for one block. 
+std::unordered_map> KVCachePromptLookup::findBlocksInReuseTreeByBlockKey( + std::vector const& windowSizes, BlockKey const& blockKey) const +{ + auto blockedUniqueTokens + = chopVectorIntoBlocks(blockKey.uniqueTokens, blockKey.uniqueTokens.size(), mTokensPerBlock, true); + + // TODO: Can buildBlockKeys(...) replace this paragraph? + std::vector blockKeys; + for (auto const& blockedUniqueTokensList : blockedUniqueTokens) + { + blockKeys.push_back(blockKey); + blockKeys.back().uniqueTokens = blockedUniqueTokensList; + } + + // Return value + std::unordered_map> results; + // Keep track of all window sizes that are still looking for matching blocks + std::unordered_set stillLooking; + for (auto windowSize : windowSizes) + { + stillLooking.insert(windowSize); + } + + auto searchRoot = mRoot; + for (auto const& blockKey : blockKeys) + { + if (searchRoot == nullptr) + { + // Neither exact nor partial matches can be found + // Cancel all window sizes still looking for matches + stillLooking.clear(); + } + else + { + // Consider exact match + LookupNodePtr nextSearchRoot = nullptr; + auto exactMatch = searchRoot->findMatchingNodes(blockKey, false, true); + TLLM_CHECK_WITH_INFO( + exactMatch.size() == 0 || exactMatch.size() == 1, "exactMatch must contain either one node or no node"); + if (exactMatch.size() == 1) + { + [[maybe_unused]] auto const& [dummy1, dummy2, exactMatchingNode] = exactMatch[0]; + for (auto windowSize : stillLooking) + { + auto block = exactMatchingNode->getBlock(windowSize); + if (block != nullptr) + { + results.insert_or_assign(windowSize, block); + } + else + { + // Enforce rule that all blocks must be preceded by only valid blocks (no nullptrs allowed) + stillLooking.erase(windowSize); + } + } + searchRoot = exactMatchingNode; + } + } + if (stillLooking.empty()) + { + // Done with all window sizes + break; + } + } + + return results; +} + +//! \brief Return map of vector of matching blocks for a given prompt (provided by llmRequest). +//! 
\details Map key is windowSize. Vector matches blocks from start of sequence until no more matches can be found. +// TODO: Current return logic is tailored for full attention. Return logic for SWA should be different. +std::unordered_map>> +KVCachePromptLookup::lookupBlocks(std::map const& windowBlockManagers, + LlmRequest const& llmRequest, SizeType32 inputLength, bool allowPartiallyFilledBlock, bool enablePartialReuse) const +{ + // Return value + std::unordered_map>> results; + // Keep track of all window sizes that are still looking for matching blocks + std::unordered_set stillLooking; + for ([[maybe_unused]] auto const& [windowSize, dummy1] : windowBlockManagers) + { + stillLooking.insert(windowSize); + results.emplace( + std::make_pair(windowSize, std::vector>())); + } + + // Search tree + auto blockKeys = getBlockKeys(llmRequest, inputLength, allowPartiallyFilledBlock); + auto searchRoot = mRoot; + for (auto const& blockKey : blockKeys) + { + if (searchRoot == nullptr) + { + // Neither exact nor partial matches can be found + // Cancel all window sizes still looking for matches + stillLooking.clear(); + } + else + { + // Consider exact match + LookupNodePtr nextSearchRoot = nullptr; + std::unordered_set needPartials; + auto exactMatch = searchRoot->findMatchingNodes(blockKey, false, true); + TLLM_CHECK_WITH_INFO( + exactMatch.size() == 0 || exactMatch.size() == 1, "exactMatch must contain either one node or no node"); + if (exactMatch.size() == 1) + { + // found exact matching node + [[maybe_unused]] auto const& [dummy1, dummy2, exactMatchingNode] = exactMatch[0]; + nextSearchRoot = exactMatchingNode; + for (auto windowSize : stillLooking) + { + auto block = exactMatchingNode->getBlock(windowSize); + if (block != nullptr) + { + // found exact matching block + auto& winres = results[windowSize]; + // TODO: verify these outputs + winres.emplace_back(std::make_tuple(!block->isFull(), 0, block, exactMatchingNode)); + } + else + { + // did not find exact match + 
if (enablePartialReuse) + { + // look for partial match + needPartials.insert(windowSize); + } + else + { + // partial match disabled, cancel + stillLooking.erase(windowSize); + } + } + } + } + else + { + // did not find exact matching node + if (enablePartialReuse) + { + for (auto windowSize : stillLooking) + { + needPartials.insert(windowSize); + } + } + else + { + stillLooking.clear(); + } + } + + // Consider partial match + if (!needPartials.empty()) + { + // Note: Returns partial matches sorted in descending order on num matched tokens. + auto partialMatches = searchRoot->findMatchingNodes(blockKey, true, true); + for (auto windowSize : needPartials) + { + for (auto [partialMatch, numMatched, node] : partialMatches) + { + auto block = node->getBlock(windowSize); + if (block != nullptr) + { + // found partial match + auto& winres = results[windowSize]; + winres.emplace_back(std::make_tuple(true, numMatched, block, node)); + break; + } + } + stillLooking.erase(windowSize); + } + } + + // Advance searchRoot to exact matching node if found, otherwise set to nullptr + searchRoot = nextSearchRoot; + } + + if (stillLooking.empty()) + break; + } + + return results; +} + +//! \brief Get lookup nodes for prompt given by llmRequest. +//! \details Nodes are created if necessary. Returned nodes should be used to store blocks in search tree. For lookup of +//! reusable blocks, please use lookupBlocks instead. +LookupResults KVCachePromptLookup::lookup( + LlmRequest const& llmRequest, SizeType32 inputLength, bool allowPartiallyFilledBlock) +{ + auto blockKeys = getBlockKeys(llmRequest, inputLength, allowPartiallyFilledBlock); + + auto constexpr enablePartialReuse = false; + auto constexpr create = true; + + LookupResults results; + auto searchRoot = mRoot; + for (auto const& blockKey : blockKeys) + { + auto matches = searchRoot != nullptr ? 
searchRoot->findMatchingNodes(blockKey, enablePartialReuse, false) + : LookupResult(); + if (create && matches.empty()) + { + // No match, create blank prompt node + bool isFull = static_cast(blockKey.uniqueTokens.size()) == mTokensPerBlock; + auto newNode = std::make_shared(blockKey, isFull); + newNode->setPrevNode(searchRoot); + searchRoot->addNextNode(blockKey, newNode); + matches.emplace_back( + std::make_tuple(isFull, static_cast(blockKey.uniqueTokens.size()), newNode)); + searchRoot = newNode; + } + else if (matches.empty()) + { + // No match + // Stop search here + searchRoot = nullptr; + } + else + { + // Stop search if first node is a partial match + // TODO: This may not be correct, findMatchingNodes can return both a full match and several partial matches + // at once. + auto const& [partialMatch, _, matchingNode] = matches[0]; + searchRoot = partialMatch ? nullptr : matchingNode; + } + results.emplace_back(std::move(matches)); + } + return results; +} + +// Return BlockKey of first new context block for all window sizes. If no new context block was found for a window size, +// return value will not have an entry for that window size. +std::unordered_map KVCachePromptLookup::findNewContextBlock(LlmRequest const& llmRequest, + SizeType32 inputLength, bool allowPartiallyFilledBlock, std::vector const& windowSizes) const +{ + auto blockKeys = getBlockKeys(llmRequest, inputLength, allowPartiallyFilledBlock); + + // New context block is the block key of the first block that isn't found in search structure. + std::unordered_set stillLooking; + for (auto const windowSize : windowSizes) + { + stillLooking.insert(windowSize); + } + std::unordered_map results; + auto searchRoot = mRoot; + BlockKey prevBlockKey; + for (auto const& blockKey : blockKeys) + { + auto matches = searchRoot != nullptr ? searchRoot->findMatchingNodes(blockKey, false, true) : LookupResult(); + [[maybe_unused]] auto const& [dummy1, dummy2, matchingNode] + = matches.empty() ? 
std::make_tuple(false, 0, nullptr) : matches[0]; + for (auto const windowSize : windowSizes) + { + if (stillLooking.count(windowSize) + && (matchingNode == nullptr || matchingNode->getBlock(windowSize) == nullptr)) + { + // Found new context block for current window size + results[windowSize] = blockKey; + stillLooking.erase(windowSize); + } + } + searchRoot = matchingNode; + if (stillLooking.size() == 0) + { + // Not looking for any more new context block + break; + } + } + + return results; +} + +KVCachePromptLookupNode::KVCachePromptLookupNode(BlockKey const& blockKey, bool isFull, bool isRoot) + : mBlockKey{blockKey} + , mIsFull{isFull} + , mIsRoot{isRoot} + , mPrevNode{nullptr} + , mNextNodes{} + , mBlocks{} +{ +} + +void KVCachePromptLookupNode::setBlockKey(BlockKey const& blockKey, bool isFull) +{ + mBlockKey = blockKey; + mIsFull = isFull; +} + +BlockKey KVCachePromptLookupNode::getBlockKey() const { - mHash = hash; + return mBlockKey; } -void KVCacheBlock::setHash() +VecUniqueTokens const& KVCachePromptLookupNode::getUniqueTokens() const { - mHash = BlockKeyHasher()(mBlockKey, mPrevBlockInSeq ? 
mPrevBlockInSeq->getHash() : 0); + return mBlockKey.uniqueTokens; } -size_t KVCacheBlock::getHash() const +bool KVCachePromptLookupNode::isRoot() const { - return mHash; + return mIsRoot; } -VecUniqueTokens const& KVCacheBlock::getUniqueTokens() const +LookupNodePtr const& KVCachePromptLookupNode::getPrevNode() const { - return mBlockKey.uniqueTokens; + return mPrevNode; } -BlockPtr const& KVCacheBlock::getPrevBlock() const +void KVCachePromptLookupNode::setPrevNode(LookupNodePtr prevNode) { - return mPrevBlock; + mPrevNode = prevNode; } -void KVCacheBlock::setPrevBlock(BlockPtr prevBlock) +NextNodeMap KVCachePromptLookupNode::getNextNodes() const { - mPrevBlock = std::move(prevBlock); + return mNextNodes; } -BlockPtr const& KVCacheBlock::getPrevBlockInSeq() const +void KVCachePromptLookupNode::addNextNode(BlockKey const& blockKey, LookupNodePtr node) { - return mPrevBlockInSeq; + mNextNodes[blockKey] = node; } -void KVCacheBlock::setPrevBlockInSeq(BlockPtr prevBlock) +void KVCachePromptLookupNode::removeNextNode(BlockKey const& blockKey) { - mPrevBlockInSeq = std::move(prevBlock); + mNextNodes.erase(blockKey); } -void KVCacheBlock::addNextBlock(BlockKey const& blockKey, BlockPtr block) +LookupResult KVCachePromptLookupNode::findMatchingNodes( + BlockKey const& blockKey, bool enablePartialReuse, bool ignoreNodesWithoutBlocks) const { - if (mNextBlocks.find(blockKey) == mNextBlocks.end()) + LookupResult result; + if (blockKey.uniqueTokens.size() == 0 || mNextNodes.size() == 0) + { + // invalid search key or no searchable nodes + return result; + } + auto itr = mNextNodes.find(blockKey); + // TODO: Skipping nodes that have no blocks only works when looking up blocks. + // When looking up nodes and possibly creating new nodes, we must ignore whether node has blocks or not. 
+ if (itr != mNextNodes.end() && (!ignoreNodesWithoutBlocks || itr->second->hasBlocks())) + { + // found exact match + auto node = itr->second; + result.emplace_back( + std::make_tuple(!node->isFull(), static_cast(blockKey.uniqueTokens.size()), node)); + return result; + } + if (enablePartialReuse) { - mNextBlocks[blockKey] = std::move(block); + // find all nodes with at least one matching token + for (auto const& [key, node] : mNextNodes) + { + SizeType32 numMatched = key.partialMatch(blockKey); + if (numMatched > 0 && node != nullptr && node->hasBlocks()) + { + result.emplace_back(std::make_tuple(true, numMatched, node)); + } + } + // sort partial matches in ascending order + std::sort(result.begin(), result.end(), + [](std::tuple const& a, + std::tuple const& b) + { + [[maybe_unused]] auto [dummy1, numMatchedA, dummy2] = a; + [[maybe_unused]] auto [dummy3, numMatchedB, dumm4] = b; + return numMatchedA > numMatchedB; + }); + return result; } + return result; } -std::tuple KVCacheBlock::findMatchingBlock( - BlockKey const& blockKey, bool enablePartialReuse, bool copyOnPartialReuse) const +/* +std::tuple KVCachePromptLookupNode::findMatchingNode(BlockKey const& blockKey, bool +enablePartialReuse) const { - if (blockKey.uniqueTokens.size() == 0 || mNextBlocks.size() == 0) + if (blockKey.uniqueTokens.size() == 0 || mNextNodes.size() == 0) { return {false, 0, nullptr}; } - auto itr = mNextBlocks.find(blockKey); - if (itr == mNextBlocks.end()) + auto itr = mNextNodes.find(blockKey); + if (itr == mNextNodes.end()) { if (enablePartialReuse) { SizeType32 bestNumMatched{0}; - BlockPtr bestBlock{nullptr}; - for (auto const& [key, block] : mNextBlocks) + LookupNodePtr bestNode{nullptr}; + for (auto const& [key, node] : mNextNodes) { - if (copyOnPartialReuse || (!block->hasRefs() && block->isLeaf())) + SizeType32 numMatched = key.partialMatch(blockKey); + if (numMatched > bestNumMatched) { - SizeType32 numMatched = key.partialMatch(blockKey); - if (numMatched > 
bestNumMatched) - { - bestNumMatched = numMatched; - bestBlock = block; - } + bestNumMatched = numMatched; + bestNode = node; } } if (bestNumMatched > 0) { - return {true, bestNumMatched, bestBlock}; + return {true, bestNumMatched, bestNode}; } } return {false, 0, nullptr}; } - auto block = itr->second; - return {!block->isFull(), static_cast(blockKey.uniqueTokens.size()), block}; + auto node = itr->second; + return {!node->isFull(), static_cast(blockKey.uniqueTokens.size()), node}; } +*/ -void KVCacheBlock::freeLeafBlock() +void KVCachePromptLookupNode::setBlock(SizeType32 windowSize, BlockPtr block) { - // assure that this is a leaf block - TLLM_CHECK(isLeaf()); + if (block == nullptr) + { + auto blockItr = mBlocks.find(windowSize); + if (blockItr != mBlocks.end()) + { + auto block = blockItr->second; + mBlocks.erase(blockItr); + } + } + else + { + mBlocks.emplace(std::make_pair(windowSize, block)); + } +} - // free from previous block - if (mPrevBlock != nullptr) +BlockPtr KVCachePromptLookupNode::getBlock(SizeType32 windowSize) const +{ + auto blockItr = mBlocks.find(windowSize); + if (blockItr != mBlocks.end()) + { + return blockItr->second; + } + else { - mPrevBlock->removeNextBlock(mBlockKey); - mPrevBlock = nullptr; + return nullptr; } } -void KVCacheBlock::removeNextBlock(BlockKey const& blockKey) +bool KVCachePromptLookupNode::hasBlocks() const { - mNextBlocks.erase(blockKey); + return !mBlocks.empty(); } -bool KVCacheBlock::isFull() const +bool KVCachePromptLookupNode::isFull() const { return mIsFull; } -bool KVCacheBlock::isLeaf() const +bool KVCachePromptLookupNode::isLeaf() const +{ + return mNextNodes.empty(); +} + +bool KVCachePromptLookupNode::canBeDeleted() const +{ + return !isRoot() && isLeaf() && !hasBlocks(); +} + +std::vector KVCachePromptLookupNode::getFullBlockKey() const +{ + std::list keys; + keys.emplace_back(getBlockKey()); + for (auto prevNode = getPrevNode(); prevNode != nullptr && !prevNode->isRoot(); prevNode = 
prevNode->getPrevNode()) + { + keys.emplace_back(prevNode->getBlockKey()); + } + std::vector retval; + if (!keys.empty()) + { + retval.reserve(keys.size()); + retval.insert(retval.end(), keys.rbegin(), keys.rend()); + } + return retval; +} + +void KVCachePromptLookupNode::deleteNodeIfPossible() { - return mNextBlocks.empty(); + if (canBeDeleted()) + { + TLLM_LOG_DEBUG("Deleted node " + streamPrint(getFullBlockKey())); + auto prevNode = mPrevNode; + mPrevNode = nullptr; + prevNode->removeNextNode(getBlockKey()); + prevNode->deleteNodeIfPossible(); + } } // This function calculates the number of block a layer should have, given @@ -551,6 +1215,8 @@ BlockManager::BlockManager(std::vector const& numKvHeadsPerLayer, Si , mEventManager{std::move(eventManager)} , mStream{stream} , mCacheType{cacheType} + , mEnablePartialReuse{enablePartialReuse} + , mLookup{std::make_shared(cacheType, tokensPerBlock)} { if (agentConfig.has_value()) mLoopbackAgent = makeLoopbackAgent("nixl", &agentConfig.value()); @@ -654,7 +1320,6 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind , mSchedulingNumFreeBlocks{0} , mTokensPerBlock{tokensPerBlock} , mIsSWA{isSWA} - , mCachedBlocksRoot{std::make_shared(KVCacheBlock::kCachedBlocksRootId, tk::KVCacheIndex{0})} , mCacheType{cacheType} , mEventManager(std::move(eventManager)) , mLoopbackAgent{loopbackAgent} @@ -725,13 +1390,17 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind mAllBlocksById.reserve(blocksInPrimaryPool + blocksInSecondaryPool); for (KVCacheBlock::IdType blockId = 0; blockId < blocksInPrimaryPool; ++blockId) { - mAllBlocksById.emplace_back(std::make_shared(blockId, tk::KVCacheIndex{blockId, false})); + mAllBlocksById.emplace_back( + std::make_shared(blockId, tk::KVCacheIndex{blockId, false}, windowSize)); } for (KVCacheBlock::IdType blockId = 0; blockId < blocksInSecondaryPool; ++blockId) { mAllBlocksById.emplace_back( - std::make_shared(blocksInPrimaryPool + 
blockId, tk::KVCacheIndex{blockId, true})); + std::make_shared(blocksInPrimaryPool + blockId, tk::KVCacheIndex{blockId, true}, windowSize)); } + // TODO: This duplicates the information stored in GenerationRequest::mCacheBlockIndices. + // Why do we need this? What happens when these two no longer are in sync? + // Should we get rid of GenerationRequest::mCacheBlockIndices? mAllocatedBlocksPerSeq.reserve(maxNumSequences); mEvictionPolicy = std::make_shared(); @@ -739,6 +1408,8 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind mAllBlocksById, {blocksInPrimaryPool, blocksInSecondaryPool}, secondaryOffloadMinPriority); if (mEventManager) { + TLLM_LOG_DEBUG("%s;%d - mEventManager->enqueueCreatedEvent(blocksInPrimaryPool=%d, blocksInSecondaryPool=%d) ", + __FILE__, __LINE__, blocksInPrimaryPool, blocksInSecondaryPool); mEventManager->enqueueCreatedEvent({blocksInPrimaryPool, blocksInSecondaryPool}, mWindowSize); } } @@ -776,22 +1447,15 @@ bool WindowBlockManager::verifyQueueIntegrity() void BlockManager::storeContextBlocks(GenerationRequest& sequence, LlmRequest const& llmRequest) { - constexpr int beamIdx = 0; // no need to consider more than one beam for input tokens + std::lock_guard lock(mCachedBlocksRootMutex); + // Lookup prompt nodes once for all window block managers + auto constexpr beamIdx = 0; + auto matchedPromptNodes = mLookup->lookup(llmRequest, -1, false); for (auto const& [windowSize, _] : mWindowBlockManagers) { - if (mWindowBlockManagers.at(windowSize).isSWA()) - { - // SWA cannot store new blocks on the fly because the block stored - // may go OOW and be reused by another sequence. 
- continue; - } auto cacheBlockIds = sequence.getCacheBlockIds(windowSize); - auto const& uniqueTokens = llmRequest.getUniqueTokens(beamIdx); - - auto blockedUniqueTokens - = chopVectorIntoBlocks(uniqueTokens, uniqueTokens.size() - 1, getTokensPerBlock(), false); - auto blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest); - (void) mWindowBlockManagers.at(windowSize).storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]); + // TODO: Should we return these values? + [[maybe_unused]] auto dummy1 = storeBlocks(matchedPromptNodes, cacheBlockIds[beamIdx], windowSize); } } @@ -937,6 +1601,8 @@ void BlockManager::startScheduling() void WindowBlockManager::startScheduling() { mSchedulingNumFreeBlocks = mEvictionPolicy->getNumFreeBlocks(kPrimaryLevel); + // TODO: For SWA, mAllocatedBlocksPerSeq contains pointers to all the blocks allocated to a particular sequence, + // including blocks that have been detached because they went OOW. Is it an issue that we include those here? for (auto& [requestId, slotAllocatedBlocks] : mAllocatedBlocksPerSeq) { for (auto& allocatedBlock : slotAllocatedBlocks) @@ -946,36 +1612,14 @@ void WindowBlockManager::startScheduling() } } -void WindowBlockManager::freeLeafBlock(BlockPtr const& block) -{ - // The eviction policy needs blocks to still be linked to their old parents when they're reclaimed. - // This is so it can check if the parent should be queued for eviction. 
- block->freeLeafBlock(); -} - -void WindowBlockManager::freeChildren(BlockPtr const& block) -{ - // Free all descendants of block - for (auto const& p : block->getNextBlocks()) - { - auto childBlock = p.second; - freeChildren(childBlock); - } - - // Free block - if (mEventManager && blockInRadixTree(block)) - { - mEventManager->enqueueRemovedEvent(block, mWindowSize); - } - freeLeafBlock(block); -} - -BlockPtr WindowBlockManager::getFreeBlock(GenerationRequest& sequence, executor::RetentionPriority priority, +BlockPtr WindowBlockManager::getFreeBlock(executor::RetentionPriority priority, std::optional durationMs, executor::KvCacheTransferMode mode, std::string const& directory) { // eviction policy get free primary block auto [block, canOffload] = mEvictionPolicy->getFreeBlock(kPrimaryLevel); + // remove block from free queue and set priority + mEvictionPolicy->claimBlock(block, priority, durationMs); if (block->getUniqueTokens().empty()) { ++mAllocNewBlocks; @@ -989,58 +1633,39 @@ BlockPtr WindowBlockManager::getFreeBlock(GenerationRequest& sequence, executor: if (!block->getUniqueTokens().empty() && canOffload && mEvictionPolicy->getNumFreeBlocks(kSecondaryLevel) > 0 && mOnboardBlocks) { - // Offload block in primary memory before repurposing + // Offload block to primary memory before repurposing auto offloadBlock = std::get<0>(mEvictionPolicy->getFreeBlock(kSecondaryLevel)); + // Remove block from secondary free queue + mEvictionPolicy->claimBlock(offloadBlock); + TLLM_LOG_DEBUG("Offloading block %d to %d", block->getBlockId(), offloadBlock->getBlockId()); mTransferManager->offload(block, offloadBlock, mPools, 0, mode, directory); // swap linear block offsets (i.e. 
make block the offload block) block->swapMemoryPoolBlockOffset(offloadBlock); if (mEventManager && blockInRadixTree(block)) { + TLLM_LOG_DEBUG("%s;%d - mEventManager->enqueueUpdatedEvent(block=%d) Primary->Secondary", __FILE__, + __LINE__, block->getBlockId()); mEventManager->enqueueUpdatedEvent( tle::KVCacheUpdatedData(block->getHash()).cacheLevelUpdated(kPrimaryLevel, kSecondaryLevel), mWindowSize); } - // Update the block as a secondary block (maintaining its priority) - mEvictionPolicy->claimBlock(block); - // Release the block into secondary block queue + + // append offload block to mFreeSecondaryBlocks queue mEvictionPolicy->releaseBlock(block); - // We have the offloaded block as the block to use now. block = offloadBlock; } - - // Removes children of the block from the search tree - freeChildren(block); - // Claim the block in primary block queue - mEvictionPolicy->claimBlock(block, priority, durationMs); - - // Deal with invalidating block save for reuse for the sequence - if (mBlockToSequence.count(block->getBlockId()) > 0) + if (mEventManager && blockInRadixTree(block)) { - auto const& originalOwnerSequenceId = mBlockToSequence[block->getBlockId()]; - if (mIsValidStoreForReuseSequence.count(originalOwnerSequenceId) > 0 - && sequence.getRequestId() != originalOwnerSequenceId) - { - TLLM_LOG_DEBUG("%s::getFreeBlock - Block %d was originally held but released from sequence %d", - mLogPrefix.c_str(), block->getBlockId(), originalOwnerSequenceId); - if (mIsValidStoreForReuseSequence[originalOwnerSequenceId]) - { - TLLM_LOG_DEBUG("%s::getFreeBlock - Invalidate store block for reuse for sequence %d", - mLogPrefix.c_str(), originalOwnerSequenceId); - } - else - { - TLLM_LOG_DEBUG("%s::getFreeBlock - Store block for reuse for sequence %d is already invalid", - mLogPrefix.c_str(), originalOwnerSequenceId); - } - mIsValidStoreForReuseSequence[originalOwnerSequenceId] = false; - } + TLLM_LOG_DEBUG("%s;%d - mEventManager->enqueueRemovedEvent(block=%d)", __FILE__, 
__LINE__, block->getBlockId()); + mEventManager->enqueueRemovedEvent(block, mWindowSize); + } + // Detach block from search structure + if (block->getLookupNode() != nullptr) + { + TLLM_LOG_DEBUG("Evicting block %d", block->getBlockId()); + block->detachFromLookupNode(); } - - // Record which sequence is using the block - mBlockToSequence[block->getBlockId()] = sequence.getRequestId(); - TLLM_LOG_DEBUG("%s::getFreeBlock - Block %d is now acquired by sequence %d", mLogPrefix.c_str(), - block->getBlockId(), sequence.getRequestId()); return block; } @@ -1084,14 +1709,16 @@ void WindowBlockManager::onboardBlock(GenerationRequest& sequence, BlockPtr cons { if (mOnboardBlocks && !offloadBlock->isPrimary()) { - auto block = getFreeBlock( - sequence, executor::KvCacheRetentionConfig::kDefaultRetentionPriority, std::nullopt, mode, directory); + auto block + = getFreeBlock(executor::KvCacheRetentionConfig::kDefaultRetentionPriority, std::nullopt, mode, directory); mTransferManager->onboard(offloadBlock, block, mPools, 0, mode, directory); // swap linear block offsets (i.e. 
make block the offload block and vice versa) offloadBlock->swapMemoryPoolBlockOffset(block); if (mEventManager) { + TLLM_LOG_DEBUG("%s;%d - mEventManager->enqueueUpdatedEvent(block=%d) Secondary->Primary", __FILE__, + __LINE__, offloadBlock->getBlockId()); mEventManager->enqueueUpdatedEvent( tle::KVCacheUpdatedData(offloadBlock->getHash()).cacheLevelUpdated(kSecondaryLevel, kPrimaryLevel), mWindowSize); @@ -1128,6 +1755,8 @@ void WindowBlockManager::offloadBlock( if (mEventManager && blockInRadixTree(block)) { + TLLM_LOG_DEBUG("%s;%d - mEventManager->enqueueUpdatedEvent(block=%d) Primary->Secondary", __FILE__, + __LINE__, block->getBlockId()); mEventManager->enqueueUpdatedEvent( tle::KVCacheUpdatedData(block->getHash()).cacheLevelUpdated(kPrimaryLevel, kSecondaryLevel), mWindowSize); @@ -1140,97 +1769,96 @@ void WindowBlockManager::offloadBlock( [[nodiscard]] std::optional BlockManager::findNewContextBlock( VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const { + std::lock_guard lock(mCachedBlocksRootMutex); TLLM_CHECK_WITH_INFO( !isVariableWindow(), "The optimization of delaying requests won't work for variable window attention"); - auto const& onlyManager = mWindowBlockManagers.cbegin()->second; - return onlyManager.findNewContextBlock(uniqueTokens, llmRequest); -} + // TODO: Deciding whether a request should be delayed should be done for each window block manager, + // aggregate decision is true if any of them returns true + // For now we replicate old behavior -std::optional WindowBlockManager::findNewContextBlock( - VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const -{ - auto blockedUniqueTokens - = chopVectorIntoBlocks(uniqueTokens, uniqueTokens.size(), mTokensPerBlock, false); - auto blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest); - BlockKey ret; - ret.loraTaskId = llmRequest.getLoraTaskId(); - auto searchRoot = mCachedBlocksRoot; - for (auto const& blockKey : blockKeys) + // Copy window sizes we are 
interested in (just first one for now) + std::vector windowSizes; + windowSizes.reserve(1); + auto firstWindowSize = mWindowBlockManagers.cbegin()->first; + windowSizes.emplace_back(firstWindowSize); + + // Get blockKey for first window size + auto newBlockKeys = mLookup->findNewContextBlock(llmRequest, -1, true, windowSizes); + + // If newBlockKeys does not have an entry for firstWindowSize, it means no new context block was found + if (newBlockKeys.count(firstWindowSize)) { - ret.uniqueTokens.insert(ret.uniqueTokens.end(), blockKey.uniqueTokens.begin(), blockKey.uniqueTokens.end()); - auto [partialMatch, numMatched, matchingBlock] = searchRoot != nullptr - ? searchRoot->findMatchingBlock(blockKey, false, false) - : std::make_tuple(false, 0, nullptr); - if (matchingBlock == nullptr) - { - return ret; - } - searchRoot = std::move(matchingBlock); + return newBlockKeys.at(firstWindowSize); + } + else + { + return std::nullopt; } - return std::nullopt; } bool WindowBlockManager::blockInRadixTree(BlockPtr const& block) { - return !block->getUniqueTokens().empty() && block->getPrevBlock() != nullptr; + return !block->getUniqueTokens().empty() && block->getLookupNode() != nullptr; } -std::shared_ptr WindowBlockManager::findBlocksInReuseTreeByBlockKey(BlockKey const& blockKey) +std::shared_ptr BlockManager::findBlocksInReuseTreeByBlockKey( + BlockKey const& blockKey, SizeType32 windowSize) const { std::lock_guard lock(mCachedBlocksRootMutex); - auto blockedUniqueTokens - = chopVectorIntoBlocks(blockKey.uniqueTokens, blockKey.uniqueTokens.size(), mTokensPerBlock, true); - - std::vector blockKeys; - for (auto const& blockedUniqueTokensList : blockedUniqueTokens) - { - blockKeys.push_back(blockKey); - blockKeys.back().uniqueTokens = blockedUniqueTokensList; - } - auto searchRoot = mCachedBlocksRoot; - for (auto const& blockKey : blockKeys) - { - auto [partialMatch, numMatched, matchingBlock] = searchRoot != nullptr - ? 
searchRoot->findMatchingBlock(blockKey, true, true) - : std::make_tuple(false, 0, nullptr); - - if (matchingBlock == nullptr) - { - return nullptr; - } + std::vector windowSizes; + windowSizes.reserve(1); + windowSizes.emplace_back(windowSize); + auto lastBlocks = mLookup->findBlocksInReuseTreeByBlockKey(windowSizes, blockKey); + auto itr = lastBlocks.find(windowSize); + TLLM_CHECK_WITH_INFO(itr != lastBlocks.end(), "No block returned for windowSize %d", windowSize); - searchRoot = std::move(matchingBlock); - } - return searchRoot; + return itr->second; } -SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector const& blockKeys, SizeType32 numContextBlocks, - GenerationRequest& sequence, std::vector const& perBlockRetentions, - executor::KvCacheTransferMode mode, std::string const& directory) +SizeType32 WindowBlockManager::loadOrAllocateBlocks( + std::vector> const& matchedBlocks, + SizeType32 numContextBlocks, GenerationRequest& sequence, + std::vector const& perBlockRetentions, executor::KvCacheTransferMode mode, + std::string const& directory) { - std::lock_guard lock(mCachedBlocksRootMutex); SizeType32 numMatchedTokens{0}; - auto searchRoot = mCachedBlocksRoot; // The last block cannot be shared between beams because it will be written to. // Make sure a unique block is allocated per beam. auto const beamWidth = sequence.getBeamWidth(); SizeType32 numSharedContextBlocks = beamWidth > 1 ? 
numContextBlocks - 1 : numContextBlocks; - auto blockItr = blockKeys.begin(); + // Claim all reusable blocks to prevent them from accidentally being overwritten by offloading/onboarding logic + for (auto& [partialMatch, numMatched, matchingBlock, matchingNode] : matchedBlocks) + { + if (matchingBlock != nullptr && !matchingBlock->hasRefs()) + { + mEvictionPolicy->claimBlock(matchingBlock); + } + } + + bool validForReuse = true; + auto blockItr = matchedBlocks.begin(); for (int bi = 0; bi < numSharedContextBlocks; ++bi) { - auto [partialMatch, numMatched, matchingBlock] = searchRoot != nullptr && blockItr != blockKeys.end() - ? searchRoot->findMatchingBlock(*blockItr, mEnablePartialReuse, mCopyOnPartialReuse) - : std::make_tuple(false, 0, nullptr); - if (matchingBlock != nullptr) + auto [partialMatch, numMatched, matchingBlock, matchingNode] = validForReuse && blockItr != matchedBlocks.end() + ? *(blockItr++) + : std::make_tuple(false, 0, nullptr, nullptr); + TLLM_LOG_DEBUG("%s;%d", __FILE__, __LINE__); + // Check if matchingBlock is still valid for reuse. + // It is possible that a matching block has been evicted after the last scan of search tree. + validForReuse = validForReuse && matchingBlock != nullptr && matchingBlock->isValidForReuse(); + if (validForReuse) { KVCacheBlock::IdType matchingBlockId = matchingBlock->getBlockId(); - numMatchedTokens += numMatched > 0 ? numMatched : blockItr->uniqueTokens.size(); + numMatchedTokens += numMatched > 0 ? 
numMatched : matchingNode->getUniqueTokens().size(); if (perBlockRetentions[bi].retentionPriority.has_value() && matchingBlock->getPriority() != perBlockRetentions[bi].retentionPriority && mEventManager) { + TLLM_LOG_DEBUG("%s;%d - mEventManager->enqueueUpdatedEvent(block=%d) Priority changed from %d to %d", + __FILE__, __LINE__, matchingBlock->getBlockId(), matchingBlock->getPriority(), + *perBlockRetentions[bi].retentionPriority); mEventManager->enqueueUpdatedEvent( tle::KVCacheUpdatedData(matchingBlock->getHash()) .priorityUpdated(matchingBlock->getPriority(), *perBlockRetentions[bi].retentionPriority), @@ -1238,33 +1866,46 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector const& } if (partialMatch) { - if (matchingBlock->hasRefs() || !matchingBlock->isLeaf()) + auto partialBlockKey = matchingNode->getBlockKey().clone(numMatched); + if (matchingBlock->hasRefs() || !matchingNode->isLeaf()) { - // Somebody else is using block or it is not a leaf, copy reusable tokens - auto newBlock = getFreeBlock( - sequence, matchingBlock->getPriority(), matchingBlock->getDurationMs(), mode, directory); + // Somebody else is using block or this is full attention and block is not a leaf, copy reusable + // tokens + // TODO: Consider whether non-leaf blocks should be reused instead of copied for SWA layers. + auto newBlock + = getFreeBlock(matchingBlock->getPriority(), matchingBlock->getDurationMs(), mode, directory); mTransferManager->onboard(matchingBlock, newBlock, mPools, numMatched, mode, directory); // TODO: (optional) Send out event - matchingBlock = newBlock; - if (blockItr != blockKeys.end()) + TLLM_LOG_DEBUG( + "%s::loadOrAllocateBlocks - Copied partially filled block %d into %d. 
Reserved block %d for " + "sequence %lu", + mLogPrefix.c_str(), matchingBlock->getBlockId(), newBlock->getBlockId(), newBlock->getBlockId(), + sequence.getRequestId()); + if (!matchingBlock->hasRefs()) { - matchingBlock->setBlockKey( - *blockItr, blockItr->uniqueTokens.size() == static_cast(mTokensPerBlock)); + // Release the matching block. + // Transfer manager keeps track of all block copies and make sure that new block + // copies will wait until pending ones have finished. + mEvictionPolicy->releaseBlock(matchingBlock); } - matchingBlock->setHash(); - TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Copied partially filled block %d", mLogPrefix.c_str(), - matchingBlockId); + matchingBlock = newBlock; } else { // Leaf block that nobody is using. Make block private and reuse - freeLeafBlock(matchingBlock); mEvictionPolicy->claimBlock( matchingBlock, perBlockRetentions[bi].retentionPriority, perBlockRetentions[bi].durationMs); + // Free block from search structure + matchingBlock->detachFromLookupNode(); + TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Reused partially filled block %d", mLogPrefix.c_str(), matchingBlockId); } - searchRoot = nullptr; // no matching needed for following blocks + // TODO: In theory, there is no way a partial block key can be considered full, so last argument is + // probably always false + matchingBlock->setBlockKey( + partialBlockKey, partialBlockKey.uniqueTokens.size() == static_cast(mTokensPerBlock)); + matchingBlock->setHash(); } else { @@ -1272,7 +1913,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector const& mEvictionPolicy->claimBlock( matchingBlock, perBlockRetentions[bi].retentionPriority, perBlockRetentions[bi].durationMs); TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Matched full block %d", mLogPrefix.c_str(), matchingBlockId); - searchRoot = matchingBlock; } onboardBlock(sequence, matchingBlock, mode, directory); addBlockToAllBeams(matchingBlock, sequence); @@ -1283,26 +1923,25 @@ SizeType32 
WindowBlockManager::loadOrAllocateBlocks(std::vector const& reusedBlockIds.insert(matchingBlockId); ++mReusedUniqueBlocks; } - ++blockItr; } else { // If we haven't set a priority, set it to the default priority level (low) - auto freeBlock = getFreeBlock(sequence, - perBlockRetentions[bi].retentionPriority.value_or( - executor::KvCacheRetentionConfig::kDefaultRetentionPriority), + auto freeBlock = getFreeBlock(perBlockRetentions[bi].retentionPriority.value_or( + executor::KvCacheRetentionConfig::kDefaultRetentionPriority), perBlockRetentions[bi].durationMs, mode, directory); addBlockToAllBeams(freeBlock, sequence); - TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - No match, allocated new block %d for sequence %lu", + TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - No match. Reserved new block %d for sequence %lu", mLogPrefix.c_str(), freeBlock->getBlockId(), sequence.getRequestId()); - searchRoot = nullptr; // no matching needed for following blocks - if (blockItr != blockKeys.end()) + /* + // TODO: Clean this up. Does it need to be done here? Should be done when block is stored. 
+ if (matchingBlock != nullptr) { freeBlock->setBlockKey( - *blockItr, blockItr->uniqueTokens.size() == static_cast(mTokensPerBlock)); - ++blockItr; + *blockItr, nodeItr->uniqueTokens.size() == static_cast(mTokensPerBlock)); } freeBlock->setHash(); + */ ++mMissedBlocks; } } @@ -1315,31 +1954,33 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector const& for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx) { // If we haven't set a priority, set it to the default priority level (low) - auto freeBlock = getFreeBlock(sequence, - perBlockRetentions[bi].retentionPriority.value_or( - executor::KvCacheRetentionConfig::kDefaultRetentionPriority), + auto freeBlock = getFreeBlock(perBlockRetentions[bi].retentionPriority.value_or( + executor::KvCacheRetentionConfig::kDefaultRetentionPriority), perBlockRetentions[bi].durationMs, mode, directory); addBlockToBeam(freeBlock, sequence, beamIdx); - if (blockItr != blockKeys.end()) - { - freeBlock->setBlockKey( - *blockItr, blockItr->uniqueTokens.size() == static_cast(mTokensPerBlock)); - ++blockItr; - } - freeBlock->setHash(); TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Beam %d. 
Allocated non-shared block %d for bi %d", mLogPrefix.c_str(), beamIdx, freeBlock->getBlockId(), bi); + // TODO: freeBlock->setBlockKey + setHash } ++mMissedBlocks; - if (blockItr != blockKeys.end()) - { - ++blockItr; - } } return numMatchedTokens; } +void BlockManager::syncTransferManagerWithBufferManager() +{ + for (auto& [_, manager] : mWindowBlockManagers) + { + manager.syncTransferManagerWithBufferManager(); + } +} + +void WindowBlockManager::syncTransferManagerWithBufferManager() +{ + mTransferManager->syncWithBufferManager(); +} + void BlockManager::refreshBlocks() { for (auto& [_, manager] : mWindowBlockManagers) @@ -1359,13 +2000,22 @@ void WindowBlockManager::refreshBlocks() void BlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest, SizeType32 windowSize) { - mWindowBlockManagers.at(windowSize).addSequence(sequence, inputLength, numContextBlocks, llmRequest); + std::lock_guard lock(mCachedBlocksRootMutex); + TLLM_LOG_DEBUG("BlockManager::addSequence - prompt = " + mLookup->printPrompt(llmRequest)); + + // Find matching blocks that can be reused + auto matchedBlocks + = mLookup->lookupBlocks(mWindowBlockManagers, llmRequest, inputLength - 1, true, mEnablePartialReuse); + TLLM_LOG_DEBUG("BlockManager::addSequence - blocks = " + streamPrint(matchedBlocks)); + // Find kv cache blocks to reuse for each window manager + mWindowBlockManagers.at(windowSize) + .addSequence(sequence, inputLength, numContextBlocks, llmRequest, matchedBlocks[windowSize]); } // There are two versions of WindowBlockManager::addSequence function. // This is called when block reuse is enabled. 
-void WindowBlockManager::addSequence( - GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest) +void WindowBlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, + LlmRequest& llmRequest, std::vector> const& matchedBlocks) { auto const requestId = sequence.getRequestId(); auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector{}); @@ -1387,9 +2037,18 @@ void WindowBlockManager::addSequence( auto blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest); auto config = llmRequest.getKvCacheRetentionConfig(); + if (config.has_value()) + { + TLLM_LOG_DEBUG("%s;%d - KvCacheRetentionConfig = %s", __FILE__, __LINE__, config.value().print().c_str()); + } auto perBlockRetentions = config.value_or(executor::KvCacheRetentionConfig()) .getPerBlockRetentionPriorityDuration(getTokensPerBlock(), inputLength); + for (auto perBlockRetention : perBlockRetentions) + { + TLLM_LOG_DEBUG( + "%s;%d - per Block KvCacheRetentionConfig = %s", __FILE__, __LINE__, perBlockRetention.print().c_str()); + } auto mode = config.value_or(executor::KvCacheRetentionConfig()).getTransferMode(); auto directory = config.value_or(executor::KvCacheRetentionConfig()).getDirectory(); @@ -1404,7 +2063,7 @@ void WindowBlockManager::addSequence( TLLM_CHECK(perBlockRetentions.size() == (size_t) numContextBlocks); auto const prepopulatedPromptLen - = loadOrAllocateBlocks(blockKeys, numContextBlocks, sequence, perBlockRetentions, mode, directory); + = loadOrAllocateBlocks(matchedBlocks, numContextBlocks, sequence, perBlockRetentions, mode, directory); mReusedTokens += static_cast(prepopulatedPromptLen); mTotalInputTokens += static_cast(uniqueTokens.size()); @@ -1449,6 +2108,11 @@ void WindowBlockManager::adjustBlocksIfNeeded(GenerationRequest& sequence) } } +std::string WindowBlockManager::printFreeQueues() const +{ + return mEvictionPolicy->printFreeQueues(); +} + // There are 
two versions of BlockManager::addSequence function. // This is called when block reuse is disabled. void BlockManager::addSequence( @@ -1522,7 +2186,7 @@ void WindowBlockManager::allocateBlock(GenerationRequest& sequence, bool shareAm if (shareAmongBeams) { // add same block to all beams - auto block = getFreeBlock(sequence, sequence.getDecodeRetentionPriority(), sequence.getDecodeDurationMs(), + auto block = getFreeBlock(sequence.getDecodeRetentionPriority(), sequence.getDecodeDurationMs(), sequence.getTransferMode(), sequence.getDirectory()); for (auto beamIdx = 0; beamIdx < beamWidth; ++beamIdx) { @@ -1534,7 +2198,7 @@ void WindowBlockManager::allocateBlock(GenerationRequest& sequence, bool shareAm // add different block to each beam for (auto beamIdx = 0; beamIdx < beamWidth; ++beamIdx) { - auto block = getFreeBlock(sequence, sequence.getDecodeRetentionPriority(), sequence.getDecodeDurationMs(), + auto block = getFreeBlock(sequence.getDecodeRetentionPriority(), sequence.getDecodeDurationMs(), sequence.getTransferMode(), sequence.getDirectory()); addBlockToBeam(block, sequence, beamIdx); } @@ -1542,54 +2206,41 @@ void WindowBlockManager::allocateBlock(GenerationRequest& sequence, bool shareAm } std::pair> WindowBlockManager::storeBlocks( - std::vector const& blockKeys, std::vector const& blockIds, bool pinBlocks) + LookupResults const& lookupNodes, std::vector const& blockIds, bool pinBlocks) { SizeType32 numBlocksStoredForReuse = 0; - std::lock_guard lock(mCachedBlocksRootMutex); TLLM_LOG_DEBUG( - "%s::storeBlocks - %zu blockKeys, %zu blockIds", mLogPrefix.c_str(), blockKeys.size(), blockIds.size()); - - auto searchRoot = mCachedBlocksRoot; - bool needMatch = true; + "%s::storeBlocks - %zu lookupNodes, %zu blockIds", mLogPrefix.c_str(), lookupNodes.size(), blockIds.size()); - auto numBlocks = blockKeys.size(); + auto numNodes = lookupNodes.size(); + TLLM_CHECK_WITH_INFO(numNodes <= blockIds.size(), "BlockIds not provided for all lookup nodes"); std::vector 
storedBlocks; + // TODO: Verify that storedBlocks is supposed to be vector of newly stored blocks std::optional lastStoredId = std::nullopt; - for (std::size_t blockCnt = 0; blockCnt < numBlocks; ++blockCnt) + for (std::size_t blockCnt = 0; blockCnt < numNodes; ++blockCnt) { auto const bid = blockIds[blockCnt]; TLLM_LOG_DEBUG("%s::storeBlocks - Searching match for block %d", mLogPrefix.c_str(), bid); auto& block = mAllBlocksById[bid]; - auto const& blockKey = blockKeys[blockCnt]; - - auto [partialMatch, numMatched, matchedBlock] - = needMatch ? searchRoot->findMatchingBlock(blockKey, false, false) : std::make_tuple(false, 0, nullptr); - if (matchedBlock != nullptr) - { - // Found match - TLLM_LOG_DEBUG( - "%s::storeBlocks - Found matching block %d, traverse", mLogPrefix.c_str(), matchedBlock->getBlockId()); - searchRoot = matchedBlock; - // TODO possible optimization: if bid != matchedBlock->getBlockId(), - // block can be freed and inserted at mFreePrimaryBlocks.begin() - } - else - { - // No match - TLLM_LOG_DEBUG("%s::storeBlocks - No match, inserting block %d into search structure", mLogPrefix.c_str(), - block->getBlockId()); - TLLM_CHECK_WITH_INFO(block->getBlockId() == bid, - "Block id mismatch " + std::to_string(block->getBlockId()) + " != " + std::to_string(bid)); - needMatch = false; // no matching needed for following blocks - block->setBlockKey(blockKey, static_cast(blockKey.uniqueTokens.size()) == mTokensPerBlock); - block->setPrevBlock(searchRoot); - block->setPrevBlockInSeq(searchRoot); - searchRoot->addNextBlock(blockKey, block); - - // Sanity check. The list of stored blocks should be connected. 
- TLLM_CHECK(storedBlocks.empty() || block->getPrevBlock() == storedBlocks.back()); - + auto& lookupNode = lookupNodes[blockCnt]; + TLLM_CHECK_WITH_INFO(lookupNode.size() == 1, "Number of matching nodes must be 1"); + auto const [partialMatch, numMatched, matchedNode] = lookupNode[0]; + auto matchedBlock = matchedNode->getBlock(mWindowSize); + if (matchedBlock == nullptr) + { + auto prevNode = matchedNode->getPrevNode(); + TLLM_CHECK_WITH_INFO(prevNode != nullptr, "prevNode cannot be nullptr"); + auto prevBlock = prevNode != nullptr ? prevNode->getBlock(mWindowSize) : nullptr; + // TLLM_CHECK_WITH_INFO(prevBlock != nullptr, "prevBlock cannot be nullptr"); + TLLM_LOG_DEBUG("%s::storeBlocks - Storing block %d with prevBlock %d", mLogPrefix.c_str(), bid, + prevBlock != nullptr ? prevBlock->getBlockId() : -999); + block->setBlockKey(matchedNode->getBlockKey(), matchedNode->isFull()); + block->setPrevBlockInSeq(prevBlock); + block->setHash(); // TODO: Why this is necessary? Can it be replaced with hash from matchedNode? + block->attachToLookupNode(matchedNode, block); + matchedNode->setBlock(mWindowSize, block); storedBlocks.push_back(block); + /* TODO: Is this needed? 
TLLM_CHECK(block->getPrevBlockInSeq() == nullptr || block->getPrevBlockInSeq()->getHash() == searchRoot->getHash()); auto oldHash = block->getHash(); @@ -1600,16 +2251,20 @@ std::pair> WindowBlockManager::s block->setHash(newHash); } searchRoot = block; + */ numBlocksStoredForReuse++; + matchedBlock = block; } if (pinBlocks) { - searchRoot->incRefCount(); + matchedBlock->incRefCount(); } - lastStoredId = searchRoot->getBlockId(); + lastStoredId = matchedBlock->getBlockId(); } if (mEventManager) { + TLLM_LOG_DEBUG("%s;%d - mEventManager->enqueueStoredEvent(storedBlocks=%d)", __FILE__, __LINE__, + static_cast(storedBlocks.size())); mEventManager->enqueueStoredEvent(storedBlocks, mWindowSize); } return {numBlocksStoredForReuse, lastStoredId}; @@ -1648,7 +2303,7 @@ void WindowBlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeTyp TLLM_CHECK_WITH_INFO(hasFreeBlocks(beamWidth), "Can't allocate new blocks. No free blocks left."); for (auto beamIdx = 0; beamIdx < beamWidth; ++beamIdx) { - auto block = getFreeBlock(sequence, executor::KvCacheRetentionConfig::kDefaultRetentionPriority, std::nullopt, + auto block = getFreeBlock(executor::KvCacheRetentionConfig::kDefaultRetentionPriority, std::nullopt, sequence.getTransferMode(), sequence.getDirectory()); block->incRefCount(); if (sequence.getCacheBlockIds(mWindowSize).at(beamIdx).size() == 0) @@ -1700,33 +2355,61 @@ std::deque BlockManager::getLatestEvents(std::optionalgetEvents(timeout) : std::deque{}; } -std::optional BlockManager::storeBlocksForReuse( +std::optional BlockManager::notThreadSafeStoreBlocksForReuse( GenerationRequest& sequence, OptionalRef llmRequest, bool pinBlocks) { + // When releasing the blocks for a sequence, we store those blocks for potential reuse only if: + // - Block reuse is enabled. 
+ // - A request was provided to this function call to identify which tokens these blocks cover + // - Beam search is NOT enabled <=> beam width == 1 + // - The sequence was not marked for use with cyclic kv-cache when it was added (when its context is too long to fit + // the max attention window). + // - The sequence did not switch to cyclic kv-cache during generation phase. + // A sequence is cyclic if its *minimum window size* is crossed, even if other window sizes were not reached. + // - The sequence is not a dummy request. + auto constexpr beamIdx = 0; + bool const storeBlocksForReuse + = sequence.getBeamWidth() == 1 && llmRequest.has_value() && !llmRequest->isDummyRequest(); + // Lookup prompt nodes once for all window block managers. + // Create nodes if no match is found. + SizeType32 numBlocksStoredForReuse = 0; std::optional lastStoredId = std::nullopt; - for (auto& [_, manager] : mWindowBlockManagers) + if (storeBlocksForReuse) { - lastStoredId = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks); + TLLM_LOG_DEBUG( + "BlockManager::notThreadSafeStoreBlocksForReuse - prompt = " + mLookup->printPrompt(llmRequest.value())); + // Looking for an exact match, but the last block can be partially filled. 
+ auto constexpr allowPartiallyFilledBlock = true; + auto matchedPromptNodes = mLookup->lookup(llmRequest.value(), -1, allowPartiallyFilledBlock); + TLLM_LOG_DEBUG("BlockManager::notThreadSafeStoreBlocksForReuse - nodes = " + streamPrint(matchedPromptNodes)); + // TODO: This loop can be parallelized with openmp + for (auto& [windowSize, manager] : mWindowBlockManagers) + { + auto cacheBlockIds = sequence.getCacheBlockIds(windowSize); + std::tie(numBlocksStoredForReuse, lastStoredId) + = manager.storeBlocks(matchedPromptNodes, cacheBlockIds[beamIdx], pinBlocks); + } } return lastStoredId; } +std::optional BlockManager::storeBlocksForReuse( + GenerationRequest& sequence, OptionalRef llmRequest, bool pinBlocks) +{ + std::lock_guard lock(mCachedBlocksRootMutex); + return notThreadSafeStoreBlocksForReuse(sequence, llmRequest, pinBlocks); +} + std::optional BlockManager::releaseBlocks( GenerationRequest& sequence, OptionalRef llmRequest, bool pinBlocks) { + std::lock_guard lock(mCachedBlocksRootMutex); // Released block will be stored when reuse is enabled. // Reuse is implied to be enabled if llmRequest is provided. 
- std::optional lastStoredId = std::nullopt; - for (auto& [_, manager] : mWindowBlockManagers) + auto lastStoredId = notThreadSafeStoreBlocksForReuse(sequence, llmRequest, pinBlocks); + for (auto& [windowSize, manager] : mWindowBlockManagers) { - if (!llmRequest.has_value() || llmRequest->isDummyRequest() || sequence.getBeamWidth() > 1) - { - lastStoredId = manager.releaseBlocks(sequence, std::nullopt); - } - else - { - lastStoredId = manager.releaseBlocks(sequence, llmRequest); - } + manager.releaseBlocks(sequence); } return lastStoredId; } @@ -1774,131 +2457,57 @@ void WindowBlockManager::unpinBlocksById(KVCacheBlock::IdType blockId) { mEvictionPolicy->releaseBlock(block); } - block = std::move(block->getPrevBlock()); + block = block->getPrevBlock(); } } void BlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef llmRequest) { - for (auto& [_, manager] : mWindowBlockManagers) + // we store newest block for potential reuse only if: + // - Block reuse is enabled. + // - A request was provided to this function call to identify which tokens these blocks cover + // - Beam search is NOT enabled <=> beam width == 1 + // - The sequence was not marked for use with cyclic kv-cache when it was added (when its context is too long to fit + // the max attention window). + // - The sequence did not switch to cyclic kv-cache during generation phase. + // A sequence is cyclic if its *minimum window size* is crossed, even if other window sizes were not reached. + bool const storeBlocksForReuse + = sequence.getBeamWidth() == 1 && llmRequest.has_value() && !llmRequest->isDummyRequest(); + if (!storeBlocksForReuse) { - if (manager.isSWA()) - { - // SWA cannot store new blocks on the fly because the block stored - // may go OOW and be reused by another sequence. 
- continue; - } - manager.storeNewBlock(sequence, llmRequest); + return; } -} -void WindowBlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef llmRequest) -{ + // store new block once; when it first fills. + // to know if block has filled we need to know how many unique tokens there are. auto constexpr beamIdx = 0; auto const& uniqueTokens = llmRequest->getUniqueTokens(beamIdx); - auto const& cacheBlockIds = sequence.getCacheBlockIds(mWindowSize); - - if (uniqueTokens.size() == 0) - { - return; - } - - // TODO: get the caller to mark tokens as filled / not filled, so that the kv-cache manager doesn't - // have to guess. Only (length - 1) tokens of the sequence have their kv-state recorded in kv-cache. We assume - // the last token's state is not filled yet. + // last token is not in kv cache, so we subtract that from usable size. auto const usableSize = static_cast(uniqueTokens.size()) - 1; if (usableSize % mTokensPerBlock != 0) { - return; - } - auto blockedUniqueTokens = chopVectorIntoBlocks(uniqueTokens, usableSize, mTokensPerBlock, true); - auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest); - if (blockKeys.size() < 2 || cacheBlockIds[beamIdx].size() < blockKeys.size()) - { - // store all blocks - TLLM_LOG_DEBUG("%s::storeNewBlock - store all blocks", mLogPrefix.c_str()); - (void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]); - return; - } - - auto lastBlock = mAllBlocksById.at(cacheBlockIds[beamIdx][blockKeys.size() - 1]); - auto prevBlock = mAllBlocksById.at(cacheBlockIds[beamIdx][blockKeys.size() - 2]); - - // If the previous block is not in the radix tree, we need to store all blocks - if (prevBlock->getPrevBlock() == nullptr) - { - TLLM_LOG_DEBUG("%s::storeNewBlock - store all blocks", mLogPrefix.c_str()); - (void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]); + // have not crossed block boundary yet; do nothing return; } - if (lastBlock->getPrevBlock() != nullptr) + std::lock_guard 
lock(mCachedBlocksRootMutex); + // Lookup prompt nodes once for all window block managers + auto matchedPromptNodes = mLookup->lookup(llmRequest.value(), -1, false); + TLLM_LOG_DEBUG("BlockManager::storeNewBlock - nodes = " + streamPrint(matchedPromptNodes)); + for (auto& [windowSize, manager] : mWindowBlockManagers) { - // If the last block is not in the radix tree, we need to store all blocks - TLLM_LOG_DEBUG("%s::storeNewBlock - no need to store", mLogPrefix.c_str()); - return; + auto cacheBlockIds = sequence.getCacheBlockIds(windowSize); + [[maybe_unused]] auto const& [dummy1, dummy2] = manager.storeBlocks(matchedPromptNodes, cacheBlockIds[beamIdx]); } - TLLM_LOG_DEBUG("%s::storeNewBlock - store the last block", mLogPrefix.c_str()); - (void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]); -} - -std::optional WindowBlockManager::storeBlocksForReuse( - GenerationRequest& sequence, OptionalRef llmRequest, bool pinBlocks) -{ - auto constexpr beamIdx = 0; - auto const& uniqueTokens = llmRequest->getUniqueTokens(beamIdx); - auto const& cacheBlockIds = sequence.getCacheBlockIds(mWindowSize); - - // TODO: get the caller to mark tokens as filled / not filled, so that the kv-cache manager doesn't - // have to guess. Only (length - 1) tokens of the sequence have their kv-state recorded in kv-cache. We assume - // the last token's state is not filled yet. 
- auto const usableSize = static_cast(uniqueTokens.size()) - 1; - auto blockedUniqueTokens = chopVectorIntoBlocks(uniqueTokens, usableSize, mTokensPerBlock, true); - auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest); - return storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks).second; } -std::optional WindowBlockManager::releaseBlocks( - GenerationRequest& sequence, OptionalRef llmRequest) +std::optional WindowBlockManager::releaseBlocks(GenerationRequest& sequence) { auto const requestId = sequence.getRequestId(); std::optional lastStoredId = std::nullopt; auto node = mAllocatedBlocksPerSeq.extract(requestId); TLLM_CHECK(node); auto& allocatedBlocks = node.mapped(); - if (llmRequest.has_value()) - { - // If llmRequest is provided, block store for reuse is enabled. - if (!isSequenceValidForStoreForReuse(requestId)) - { - TLLM_LOG_DEBUG( - "%s::releaseBlocks - sequence %lu does not have all blocks valid, block is not saved for reuse", - mLogPrefix.c_str(), sequence.getRequestId()); - } - else - { - if (mIsSWA) - { - TLLM_LOG_DEBUG("%s::releaseBlocks - sequence %lu is valid for store for reuse", mLogPrefix.c_str(), - sequence.getRequestId()); - } - auto const& uniqueTokens = llmRequest->getUniqueTokens(/*beamIdx=*/0); - // Only (length - 1) tokens of the sequence have their kv-state - // recorded in kv-cache. We assume the last token's state is not filled yet. 
- auto const usableSize = static_cast(uniqueTokens.size()) - 1; - auto blockedUniqueTokens - = chopVectorIntoBlocks(uniqueTokens, usableSize, mTokensPerBlock, /*allowPartial=*/true); - auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest); - - std::vector cacheBlockIds(allocatedBlocks.size()); - std::transform(allocatedBlocks.begin(), allocatedBlocks.end(), cacheBlockIds.begin(), - [](BlockPtr const& block) { return block->getBlockId(); }); - - auto [numBlocksStoredForReuse, lastStoredId] = storeBlocks(std::move(blockKeys), cacheBlockIds); - TLLM_LOG_DEBUG("%s::releaseBlocks Request %lu, %d blocks stored for reuse", mLogPrefix.c_str(), - sequence.getRequestId(), numBlocksStoredForReuse); - } - } for (auto it = allocatedBlocks.rbegin(); it != allocatedBlocks.rend() - sequence.getNumFrontBlocksRemoved(); ++it) { auto& block = *it; @@ -2349,19 +2958,6 @@ void KVCacheManager::addSequence( SizeType32 const numReusedBlocksPreRequest = mBlockManager.getNumReusedBlocks(); SizeType32 const numMissedBlocksPreRequest = mBlockManager.getNumMissedBlocks(); - if (!mBlockManager.isSequenceHeld(requestId)) - { - mBlockManager.holdSequence(requestId); - TLLM_LOG_DEBUG( - "[kv cache manager] Encounter new sequence %d, initialize sequence storage validity for all window sizes", - requestId); - } - else - { - TLLM_LOG_DEBUG( - "[kv cache manager] Encounter existing sequence %d, skip sequence storage validity initialization", - requestId); - } for (auto const [windowSize, metadata] : mBlockManager.getWindowSizesMetadata()) { // NOTE: Caller to KVCacheManager::addSequence should deal with the chunking @@ -2373,6 +2969,7 @@ void KVCacheManager::addSequence( auto const numContextBlocks = tc::ceilDiv(effectiveInputLength, getTokensPerBlock()); if (mEnableBlockReuse) { + // TODO: Should this be effectiveInputLength - 1? 
mBlockManager.addSequence(sequence, effectiveInputLength, numContextBlocks, *llmRequest, windowSize); } else @@ -2422,16 +3019,12 @@ void KVCacheManager::storeContextBlocks(LlmRequest const& llmRequest) void KVCacheManager::storeNewBlock(LlmRequest const& llmRequest) { - // We store newest block for potential reuse only if: - // - Beam search is NOT enabled - // - Block reuse is enabled. - auto const requestId = llmRequest.mRequestId; - auto& sequence = getSequence(requestId); - if (sequence.getBeamWidth() > 1 || !mEnableBlockReuse) + if (mEnableBlockReuse) { - return; + auto const requestId = llmRequest.mRequestId; + auto& sequence = getSequence(requestId); + mBlockManager.storeNewBlock(sequence, llmRequest); } - mBlockManager.storeNewBlock(sequence, llmRequest); } std::optional KVCacheManager::removeSequence( @@ -2455,12 +3048,6 @@ std::optional KVCacheManager::removeSequence( lastStoredId = mBlockManager.releaseBlocks(sequenceNode.mapped(), std::nullopt, pinBlocks); } } - if (mBlockManager.isSequenceHeld(requestId)) - { - mBlockManager.releaseSequence(requestId); - TLLM_LOG_DEBUG("Remove sequence %d, release sequence storage validity for all window sizes", requestId); - } - TLLM_CHECK(!mBlockManager.isSequenceHeld(requestId)); TLLM_LOG_TRACE("[%s]::%s stop", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__); return lastStoredId; } diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp index fd5758a8368..dbc3b2f93fa 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp @@ -207,47 +207,140 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst, } } +// +// Note about recording events to wait for cudaMemcpyAsync calls between blocks: +// The memory copy involves raw memory blocks, which are pointed to by the +// memory pool block index. 
When recording events, you must use getMemoryPoolBlockIndex() +// as the raw memory block identifier. Earlier versions of this code used getBlockId() +// when recording events, this is just wrong. getBlockId() returns the logical block id, +// which has nothing to do with the raw memory block pointers involved in a cudaMemcpy. +// + +// +// Notes about need for synchronization: +// +// Earlier versions of this code relied on decoder implicitly syncing GPU with CPU. +// This is inherently dangerous, it is not given that decoder will always explicitly sync +// GPU with CPU for every step, a major design goal of ongoing work is to avoid this. +// To make the code future proof, we introduce a new method syncWithBufferManager() +// that ensures that internal copy streams will wait for prefill and decode kernels +// that have already been scheduled. +// +// Earlier versions of this code did not account for all possible cases where a new block copy +// needed to wait for a previously scheduled copy to finish. For instance, it is possible +// that two primary blocks are offloaded to the same secondary block in a single step, +// scheduling the second offloading without waiting for the first one to finish leads to +// a corrupted block after offloading. It is possible that partial reuse will copy +// from a block that is currently being onboarded, scheduling the partial copy without +// waiting for the onboarding to finish will lead to a corrupted block. To handle all +// possible cases needing synchronization we record separate events for reads and writes +// to a block. When a new block copy is scheduled, we wait for all writes to the source +// block and all reads and writes to a destination block. +// +// As before, syncTransfers() must be called after last call to KVCacheManager::addSequence. +// Failing to do so will lead to corrupted blocks eventually. 
+// + void KVCacheTransferManager::onboard(BlockPtr const& offloadBlock, BlockPtr const& block, std::vector const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode, std::string const& directory) { - if (mode != executor::KvCacheTransferMode::DRAM - && mPendingOffloads.find(offloadBlock->getBlockId()) == mPendingOffloads.end()) + // Wait for any pending writes before reading from offloadBlock + auto offloadBlockPendingWriteItr = mPendingWrites.find(offloadBlock->getMemoryPoolBlockIndex()); + if (offloadBlockPendingWriteItr != mPendingWrites.end()) { - TLLM_LOG_DEBUG("Skipping onboard for block %d because it was never previously offloaded to disk", - offloadBlock->getBlockId()); - return; + mOnboardManager.getStream().wait(offloadBlockPendingWriteItr->second); + // Don't erase, we are not changing state of offloadBlock } - - if (mPendingOffloads.find(offloadBlock->getBlockId()) != mPendingOffloads.end()) + // Wait for any pending reads before overwriting block + auto blockPendingReadItr = mPendingReads.find(block->getMemoryPoolBlockIndex()); + if (blockPendingReadItr != mPendingReads.end()) { - mOnboardManager.getStream().wait(mPendingOffloads[offloadBlock->getBlockId()]); + mOnboardManager.getStream().wait(blockPendingReadItr->second); + mPendingReads.erase(blockPendingReadItr); } + // Wait for any pending writes before overwriting block + auto blockPendingWriteItr = mPendingWrites.find(block->getMemoryPoolBlockIndex()); + if (blockPendingWriteItr != mPendingWrites.end()) + { + mOnboardManager.getStream().wait(blockPendingWriteItr->second); + mPendingWrites.erase(blockPendingWriteItr); + } + copyBlock(offloadBlock, block, pools, false, numTokensToCopy, mode, directory); + + // Record new pending read from offloadBlock + mPendingReads[offloadBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent(); + mOnboardManager.getStream().record(mPendingReads[offloadBlock->getMemoryPoolBlockIndex()]); + // Record new pending write to block + 
mPendingWrites[block->getMemoryPoolBlockIndex()] = tr::CudaEvent(); + mOnboardManager.getStream().record(mPendingWrites[block->getMemoryPoolBlockIndex()]); } void KVCacheTransferManager::offload(BlockPtr const& block, BlockPtr const& offloadBlock, std::vector const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode, std::string const& directory) { - mPendingOffloads[block->getBlockId()] = tr::CudaEvent(); + // Wait for any pending writes before reading from block + auto blockPendingWriteItr = mPendingWrites.find(block->getMemoryPoolBlockIndex()); + if (blockPendingWriteItr != mPendingWrites.end()) + { + mOffloadManager.getStream().wait(blockPendingWriteItr->second); + // Don't erase, we are not changing state of block + } + // Wait for any pending reads before overwriting offloadBlock + auto offloadBlockPendingReadItr = mPendingReads.find(offloadBlock->getMemoryPoolBlockIndex()); + if (offloadBlockPendingReadItr != mPendingReads.end()) + { + mOffloadManager.getStream().wait(offloadBlockPendingReadItr->second); + mPendingReads.erase(offloadBlockPendingReadItr); + } + // Wait for any pending writes before overwriting offloadBlock + auto offloadBlockPendingWriteItr = mPendingWrites.find(offloadBlock->getMemoryPoolBlockIndex()); + if (offloadBlockPendingWriteItr != mPendingWrites.end()) + { + mOffloadManager.getStream().wait(offloadBlockPendingWriteItr->second); + mPendingWrites.erase(offloadBlockPendingWriteItr); + } + copyBlock(block, offloadBlock, pools, true, numTokensToCopy, mode, directory); - mOffloadManager.getStream().record(mPendingOffloads[block->getBlockId()]); + + // Record new pending read from block + mPendingReads[block->getMemoryPoolBlockIndex()] = tr::CudaEvent(); + mOffloadManager.getStream().record(mPendingReads[block->getMemoryPoolBlockIndex()]); + // Record new pending write to offloadBlock + mPendingWrites[offloadBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent(); + 
mOffloadManager.getStream().record(mPendingWrites[offloadBlock->getMemoryPoolBlockIndex()]); +} + +void KVCacheTransferManager::syncWithBufferManager() +{ + tr::CudaEvent readyForOffloadEvent; + mBufferManager.getStream().record(readyForOffloadEvent); + mOffloadManager.getStream().wait(readyForOffloadEvent); + + tr::CudaEvent readyForOnboardEvent; + mBufferManager.getStream().record(readyForOnboardEvent); + mOnboardManager.getStream().wait(readyForOnboardEvent); + + // Once we synchronize, clear our list of pending transfers. + mPendingReads.clear(); + mPendingWrites.clear(); } void KVCacheTransferManager::syncTransfers() { tr::CudaEvent offloadEvent; mOffloadManager.getStream().record(offloadEvent); + mBufferManager.getStream().wait(offloadEvent); tr::CudaEvent onboardEvent; mOnboardManager.getStream().record(onboardEvent); - - mBufferManager.getStream().wait(offloadEvent); mBufferManager.getStream().wait(onboardEvent); // Once we synchronize, clear our list of pending thransfers. - mPendingOffloads.clear(); + mPendingReads.clear(); + mPendingWrites.clear(); } } // namespace tensorrt_llm::batch_manager::kv_cache_manager diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index 092abd51063..ef8a51e1ff2 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -235,6 +235,11 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager NB_OVERRIDE_PURE(getPoolLayerIdx, layer_idx); } + void syncTransferManagerWithBufferManager() override + { + NB_OVERRIDE_PURE(syncTransferManagerWithBufferManager); + } + void refreshBlocks() override { NB_OVERRIDE_PURE(refreshBlocks); @@ -482,7 +487,10 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents, nb::call_guard()) .def("get_last_block_id", 
&BaseKVCacheManager::getLastBlockId, nb::call_guard()) - .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, nb::call_guard()); + .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, nb::call_guard()) + .def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager, + nb::call_guard()) + .def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, nb::call_guard()); nb::bind_vector(m, "CacheBlockIds") .def("__getstate__", [](CacheBlockIds const& v) { return nb::make_tuple(v); }) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index ee63cdbd8c8..ccdae333d3a 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -240,6 +240,11 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager PYBIND11_OVERLOAD_PURE(SizeType32, tbk::BaseKVCacheManager, getPoolLayerIdx, layer_idx); } + void syncTransferManagerWithBufferManager() override + { + PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, syncTransferManagerWithBufferManager); + } + void refreshBlocks() override { PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, refreshBlocks); @@ -486,7 +491,10 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents, py::call_guard()) .def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, py::call_guard()) - .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, py::call_guard()); + .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, py::call_guard()) + .def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager, + py::call_guard()) + .def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, py::call_guard()); py::enum_(m, "CacheType") .value("SELF", tbk::CacheType::kSELF) diff --git 
a/cpp/tests/unit_tests/batch_manager/evictionPolicyTest.cpp b/cpp/tests/unit_tests/batch_manager/evictionPolicyTest.cpp index 250a8602fd0..1b99fb9a082 100644 --- a/cpp/tests/unit_tests/batch_manager/evictionPolicyTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/evictionPolicyTest.cpp @@ -41,18 +41,20 @@ class LRUPolicyTest : public ::testing::Test public: void SetUp() override { + auto constexpr windowSize = 2; policy = std::make_shared(); std::vector allBlocksById; for (KVCacheBlock::IdType blockId = 0; blockId < NUM_PRIMARY_BLOCKS; ++blockId) { - allBlocksById.push_back(std::make_shared(blockId, tk::KVCacheIndex{blockId, false})); + allBlocksById.push_back( + std::make_shared(blockId, tk::KVCacheIndex{blockId, false}, windowSize)); } for (KVCacheBlock::IdType blockId = 0; blockId < NUM_SECONDARY_BLOCKS; ++blockId) { - allBlocksById.push_back( - std::make_shared(NUM_PRIMARY_BLOCKS + blockId, tk::KVCacheIndex{blockId, true})); + allBlocksById.push_back(std::make_shared( + NUM_PRIMARY_BLOCKS + blockId, tk::KVCacheIndex{blockId, true}, windowSize)); } policy->initialize(allBlocksById, {NUM_PRIMARY_BLOCKS, NUM_SECONDARY_BLOCKS}, std::nullopt); } diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp index 2e4fa943ab0..47e5a5fcd03 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp @@ -291,7 +291,6 @@ void runPartialCopyTest() GenerationRequest seq0{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; auto promptLen0 = llmRequest0->getNumTokens(beamIdx); auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq0.getRequestId()); blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); auto cacheBlockIds = 
seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx); @@ -329,7 +328,6 @@ void runPartialCopyTest() EXPECT_TRUE(blockManager.verifyQueueIntegrity(maxAttentionWindow)); } blockManager.releaseBlocks(seq0, llmRequest0); - blockManager.releaseSequence(seq0.getRequestId()); // Add sequence [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] auto inputTokens1 = inputTokens; @@ -339,7 +337,6 @@ void runPartialCopyTest() GenerationRequest seq1{requestId, inputLength1, beamWidth, blockManager.getWindowSizesMetadata()}; auto promptLen1 = llmRequest1->getNumTokens(beamIdx); auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq1.getRequestId()); blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 16); auto cacheBlockIds1 = seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx); @@ -364,7 +361,6 @@ void runPartialCopyTest() GenerationRequest seq2{requestId, inputLength2, beamWidth, blockManager.getWindowSizesMetadata()}; auto promptLen2 = llmRequest2->getNumTokens(beamIdx); auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq2.getRequestId()); blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow); EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 11); auto cacheBlockIds2 = seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx); @@ -407,8 +403,6 @@ void runPartialCopyTest() blockManager.releaseBlocks(seq1, llmRequest1); blockManager.releaseBlocks(seq2, llmRequest2); - blockManager.releaseSequence(seq1.getRequestId()); - blockManager.releaseSequence(seq2.getRequestId()); if constexpr (transferMode == KvCacheTransferMode::GDS) fs::remove_all(directory); @@ -641,19 +635,21 @@ TEST_F(KVCacheManagerTest, FindBlocksInReuseTreeByBlockKeysTest) EXPECT_THAT(cacheBlockIds, ::testing::ElementsAreArray({0, 1, 2})); (void) 
kvCacheManager.removeSequence(requestId, llmRequest0); + // Store blocks 0 and 1 with tokens [0,1,2,3,4,5,6,7] and [8,9,10,11,12,13,14,15] + // Block 2 with tokens [16] is not stored because the last token cannot be reused inputTokens->pop_back(); BlockKey fullKey{*inputTokens}; auto const foundFull = kvCacheManager.findBlocksInReuseTreeByBlockKey(fullKey, maxAttentionWindow); ASSERT_NE(foundFull, nullptr); auto const& lastBlock = foundFull; + // lastBlock = [8,9,10,11,12,13,14,15] // Check the chain back to previous blocks - auto const prev2 = lastBlock->getPrevBlock(); - ASSERT_NE(prev2, nullptr); - auto const prev1 = prev2->getPrevBlock(); + auto const prev1 = lastBlock->getPrevBlock(); // prev1 = [0,1,2,3,4,5,6,7] ASSERT_NE(prev1, nullptr); - EXPECT_EQ(prev1->getPrevBlock(), nullptr); + auto const prev0 = prev1->getPrevBlock(); // prev0 = nullptr + ASSERT_EQ(prev0, nullptr); } #ifdef ENABLE_FP4 @@ -741,7 +737,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) auto constexpr beamIdx = 0; auto promptLen0 = llmRequest0->getNumTokens(beamIdx); auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq0.getRequestId()); blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); @@ -755,7 +750,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) // blocks 0, 1, 2 are stored for reuse (blocks contain [0, 1, 2, 3], [4, 5, 6, 7], [8, 9]) blockManager.releaseBlocks(seq0, llmRequest0); - blockManager.releaseSequence(seq0.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -768,7 +762,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) // reuse blocks 0, 1 ([0, 1, 2, 3], [4, 5, 6, 7]) and get new block 3 auto promptLen1 = 
llmRequest1->getNumTokens(beamIdx); auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq1.getRequestId()); blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock); EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3})); @@ -780,7 +773,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) // block 3 matches block 2 and will be freed (blocks contain [8, 9]) blockManager.releaseBlocks(seq1, llmRequest1); - blockManager.releaseSequence(seq1.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -793,7 +785,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) llmRequest0 = std::make_shared(requestId, maxNewTokens, inputTokens0, samplingConfig, isStreaming); promptLen0 = llmRequest0->getNumTokens(beamIdx); numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq0.getRequestId()); blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), promptLen0 - 1); EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); @@ -807,7 +798,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) llmRequest1 = std::make_shared(requestId, maxNewTokens, inputTokens1, samplingConfig, isStreaming); promptLen1 = llmRequest1->getNumTokens(beamIdx); numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq1.getRequestId()); blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock); 
EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 4})); @@ -817,12 +807,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) // block 2 is stored for reuse (block contains [8]). nb! Last token of last block is never stored blockManager.releaseBlocks(seq0, llmRequest0); - blockManager.releaseSequence(seq0.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks); // block 4 is stored for reuse (block contains [8, 9]). nb! Last token of last block is never stored blockManager.releaseBlocks(seq1, llmRequest1); - blockManager.releaseSequence(seq1.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -840,7 +828,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) // reuse block 0 ([0, 1, 2, 3]), get new block 5 auto promptLen2 = llmRequest2->getNumTokens(beamIdx); auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq2.getRequestId()); blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow); EXPECT_EQ(llmRequest2->getContextCurrentPosition(), tokensPerBlock); EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 5})); @@ -862,7 +849,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) // reuse blocks 0, 1, 4(p) ([0, 1, 2, 3], [4, 5, 6, 7], [8, 9]) auto promptLen3 = llmRequest3->getNumTokens(beamIdx); auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq3.getRequestId()); blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow); EXPECT_EQ(llmRequest3->getContextCurrentPosition(), numTokens - 1); EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 
4})); @@ -875,10 +861,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) // block 5 is not stored since it is last block and has only one token blockManager.releaseBlocks(seq2, llmRequest2); - blockManager.releaseSequence(seq2.getRequestId()); // block 4 is stored for reuse (block contains [8, 9]). nb! Last token of last block not stored blockManager.releaseBlocks(seq3, llmRequest3); - blockManager.releaseSequence(seq3.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -895,7 +879,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) // reuse blocks 0, 1, 4(p) ([0, 1, 2, 3], [4, 5, 6, 7], [8,9]) auto promptLen4 = llmRequest4->getNumTokens(beamIdx); auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq4.getRequestId()); blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow); EXPECT_EQ(llmRequest4->getContextCurrentPosition(), promptLen4 - 1); EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 4})); @@ -911,7 +894,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) // blocks 0 and 1 ([0, 1, 2, 3], [4, 5, 6, 7]) are already stored, // block 4 is freed blockManager.releaseBlocks(seq4, llmRequest4Short); - blockManager.releaseSequence(seq4.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -925,7 +907,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) llmRequest4 = std::make_shared(requestId, maxNewTokens, inputTokens4, samplingConfig, isStreaming); promptLen4 = llmRequest4->getNumTokens(beamIdx); numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq4.getRequestId()); blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow); 
EXPECT_EQ(llmRequest4->getContextCurrentPosition(), promptLen4 - 2); EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); @@ -935,7 +916,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks); blockManager.releaseBlocks(seq4, llmRequest4); - blockManager.releaseSequence(seq4.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -952,7 +932,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) // no reuse, all blocks need to be freed auto promptLen5 = llmRequest5->getNumTokens(beamIdx); auto numContextBlocks5 = tc::ceilDiv(promptLen5, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq5.getRequestId()); blockManager.addSequence(seq5, promptLen5, numContextBlocks5, *llmRequest5, maxAttentionWindow); llmRequest5->addNewToken(0, beamIdx); EXPECT_EQ(llmRequest5->getContextCurrentPosition(), 1); // incidental reuse @@ -961,7 +940,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), 0); blockManager.releaseBlocks(seq5, llmRequest5); - blockManager.releaseSequence(seq5.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -977,7 +955,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) // no reuse, all blocks need to be freed auto promptLen6 = llmRequest6->getNumTokens(beamIdx); auto numContextBlocks6 = tc::ceilDiv(promptLen6, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq6.getRequestId()); blockManager.addSequence(seq6, promptLen6, numContextBlocks6, *llmRequest6, maxAttentionWindow); llmRequest6->addNewToken(0, beamIdx); // no reuse occurs because we are unable to reuse last input token and inputLength6 == 1. 
@@ -987,7 +964,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - 1); blockManager.releaseBlocks(seq6, llmRequest6); - blockManager.releaseSequence(seq6.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); } @@ -1047,7 +1023,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) auto constexpr beamIdx = 0; auto promptLen0 = llmRequest0->getNumTokens(beamIdx); auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq0.getRequestId()); blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); @@ -1061,7 +1036,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) // blocks 0, 1, 2 are stored for reuse (block 2 contains [(2, 0), (3, 0)]) blockManager.releaseBlocks(seq0, llmRequest0); - blockManager.releaseSequence(seq0.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -1079,7 +1053,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) // reuse blocks 0, 1 and get new block 3 auto promptLen1 = llmRequest1->getNumTokens(beamIdx); auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq1.getRequestId()); blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock); EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3})); @@ -1090,7 +1063,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) // block 3 matches block 2 and 
will be freed blockManager.releaseBlocks(seq1, llmRequest1); - blockManager.releaseSequence(seq1.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -1105,7 +1077,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds, numReturnSequences); promptLen0 = llmRequest0->getNumTokens(beamIdx); numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq0.getRequestId()); blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); llmRequest0->addNewToken(3, beamIdx); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 2 * tokensPerBlock); @@ -1126,7 +1097,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds1, numReturnSequences); promptLen1 = llmRequest1->getNumTokens(beamIdx); numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq1.getRequestId()); blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), llmRequest1->getNumTokens(beamIdx) - 1); EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); @@ -1135,12 +1105,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks - 1); blockManager.releaseBlocks(seq0, llmRequest0); - blockManager.releaseSequence(seq0.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks); // blocks 2 is stored for reuse (block contains [(2, 0), (3, 0), (4, 0)]) blockManager.releaseBlocks(seq1, llmRequest1); - 
blockManager.releaseSequence(seq1.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -1160,7 +1128,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) // no reuse, get new block 5, 6, 7 auto promptLen2 = llmRequest2->getNumTokens(beamIdx); auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq2.getRequestId()); blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow); EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 0); EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({5, 6, 7})); @@ -1186,7 +1153,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) // reuse block 0, get new block 8, 9 auto promptLen3 = llmRequest3->getNumTokens(beamIdx); auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq3.getRequestId()); blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow); EXPECT_EQ(llmRequest3->getContextCurrentPosition(), tokensPerBlock); EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 8, 9})); @@ -1198,8 +1164,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) blockManager.releaseBlocks(seq2, llmRequest2); blockManager.releaseBlocks(seq3, llmRequest3); - blockManager.releaseSequence(seq2.getRequestId()); - blockManager.releaseSequence(seq3.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); } @@ -1263,7 +1227,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest) auto constexpr beamIdx = 0; auto promptLen0 = llmRequest0->getNumTokens(beamIdx); auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); - 
blockManager.holdSequence(seq0.getRequestId()); blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); @@ -1282,7 +1245,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest) // Block 1: [104, 105, 0, 1] ← Contains multimodal (104, 105) // Block 2: [2, 3, 4] ← No multimodal blockManager.releaseBlocks(seq0, llmRequest0); - blockManager.releaseSequence(seq0.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -1300,7 +1262,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest) // should reuse blocks 0, 1 and get new block 3 auto promptLen1 = llmRequest1->getNumTokens(beamIdx); auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq1.getRequestId()); blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock); EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3})); @@ -1310,7 +1271,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks); // block 3 matches block 2 and will be freed blockManager.releaseBlocks(seq1, llmRequest1); - blockManager.releaseSequence(seq1.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -1335,7 +1295,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest) // no reuse, get new blocks 4, 5, 6 auto promptLen2 = llmRequest2->getNumTokens(beamIdx); auto numContextBlocks2 = tc::ceilDiv(promptLen2, 
blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq2.getRequestId()); blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow); EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 0); EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({4, 5, 6})); @@ -1368,7 +1327,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest) // reuse block 0, get new blocks 7, 8 auto promptLen3 = llmRequest3->getNumTokens(beamIdx); auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq3.getRequestId()); blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow); EXPECT_EQ(llmRequest3->getContextCurrentPosition(), tokensPerBlock); // only reuse block 0 [100, 101, 102, 103] with same hash/offset @@ -1382,8 +1340,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest) // clean up blockManager.releaseBlocks(seq2, llmRequest2); blockManager.releaseBlocks(seq3, llmRequest3); - blockManager.releaseSequence(seq2.getRequestId()); - blockManager.releaseSequence(seq3.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); } @@ -1439,7 +1395,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) auto promptLen0 = llmRequest0->getNumTokens(beamIdx); auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); // get new blocks 0, 1, 2 ([0,1,2,3], [4,5,6,7], [8]) - blockManager.holdSequence(seq0.getRequestId()); blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); @@ -1453,7 +1408,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) // store 
blocks 0, 1, 2 for reuse ([0,1,2,3], [4,5,6,7], [8,9]) blockManager.releaseBlocks(seq0, llmRequest0); - blockManager.releaseSequence(seq0.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -1469,7 +1423,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) // reuse blocks 0, 1 and get new block 3 auto promptLen1 = llmRequest1->getNumTokens(beamIdx); auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq1.getRequestId()); blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock); EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3})); @@ -1480,7 +1433,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) // store block 3 for reuse ([8,9]) blockManager.releaseBlocks(seq1, llmRequest1); - blockManager.releaseSequence(seq1.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -1493,7 +1445,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId); promptLen0 = llmRequest0->getNumTokens(beamIdx); numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq0.getRequestId()); blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); // nb! addNewToken adds new generated token, number of input tokens stay the same. 
// calling addNewToken before addSequence potentially triggers this error message: @@ -1514,7 +1465,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) promptLen1 = llmRequest1->getNumTokens(beamIdx); numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); // reuse 0, 1, 2(p) ([0,1,2,3], [4,5,6,7], [8]) - blockManager.holdSequence(seq1.getRequestId()); blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), llmRequest1->getNumTokens(beamIdx) - 1); EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); @@ -1524,12 +1474,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) // store block 4 for reuse ([8]) blockManager.releaseBlocks(seq0, llmRequest0); - blockManager.releaseSequence(seq0.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks); // blocks 2 is stored for reuse (block contains [8, 9]). nb! Last token of last block is not stored blockManager.releaseBlocks(seq1, llmRequest1); - blockManager.releaseSequence(seq1.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -1546,7 +1494,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) // no reuse, get new block 5, 6, 7 auto promptLen2 = llmRequest2->getNumTokens(beamIdx); auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq2.getRequestId()); blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow); // no reuse expected. Input tokens match blocks 0 and 1, but lora task id differs. 
EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 0); @@ -1558,7 +1505,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks); // store blocks 5, 6, 7 for reuse ([0,1,2,3], [4,5,6,7], [8]) with loraTaskId 1 blockManager.releaseBlocks(seq2, llmRequest2); - blockManager.releaseSequence(seq2.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -1575,7 +1521,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) // reuse blocks 5, 6, 7(p) ([0,1,2,3], [4,5,6,7], [8]) auto promptLen3 = llmRequest3->getNumTokens(beamIdx); auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq3.getRequestId()); blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow); EXPECT_EQ(llmRequest3->getContextCurrentPosition(), promptLen3 - 2); EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({5, 6, 7})); @@ -1586,7 +1531,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks); // store block 7 for reuse ([8,9]) with loraTaskId 1 blockManager.releaseBlocks(seq3, llmRequest3); - blockManager.releaseSequence(seq3.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -1605,7 +1549,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) // reuse blocks 0, get new block 8 auto promptLen4 = llmRequest4->getNumTokens(beamIdx); auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq4.getRequestId()); blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow); 
EXPECT_EQ(llmRequest4->getContextCurrentPosition(), tokensPerBlock); EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 8})); @@ -1616,7 +1559,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks); // blocks 8 is stored with [4] and loraTaskId 0 blockManager.releaseBlocks(seq4, llmRequest4); - blockManager.releaseSequence(seq4.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -1630,7 +1572,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) // no reuse, get new block 9, 10, 11 auto promptLen5 = llmRequest5->getNumTokens(beamIdx); auto numContextBlocks5 = tc::ceilDiv(promptLen5, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq5.getRequestId()); blockManager.addSequence(seq5, promptLen5, numContextBlocks5, *llmRequest5, maxAttentionWindow); EXPECT_EQ(llmRequest5->getContextCurrentPosition(), 0); EXPECT_THAT(seq5.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({9, 10, 11})); @@ -1641,7 +1582,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks); // blocks 9, 10, 11 are stored without loraTaskId blockManager.releaseBlocks(seq5, llmRequest5); - blockManager.releaseSequence(seq5.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); } @@ -1701,7 +1641,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) auto constexpr beamIdx = 0; auto promptLen0 = llmRequest0->getNumTokens(beamIdx); auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq0.getRequestId()); blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, 
maxAttentionWindow); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); @@ -1715,7 +1654,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) // blocks 0, 1, 2 are stored for reuse (block 2 contains [(2, 0), (3, 0)] with loraTaskId 1) blockManager.releaseBlocks(seq0, llmRequest0); - blockManager.releaseSequence(seq0.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -1734,7 +1672,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) // no reuse, get new block 3, 4, 5 auto promptLen1 = llmRequest1->getNumTokens(beamIdx); auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq1.getRequestId()); blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 0); EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 5})); @@ -1745,7 +1682,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) // blocks 3, 4, 5 are stored for reuse (block 5 contains [(2, 0), (3, 0)] with loraTaskId 2) blockManager.releaseBlocks(seq1, llmRequest1); - blockManager.releaseSequence(seq1.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -1760,7 +1696,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) promptLen0 = llmRequest0->getNumTokens(beamIdx); numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); // reuse blocks 0, 1 and get new block 6 - blockManager.holdSequence(seq0.getRequestId()); blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); 
llmRequest0->addNewToken(3, beamIdx); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 2 * tokensPerBlock); @@ -1781,7 +1716,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds1); promptLen1 = llmRequest1->getNumTokens(beamIdx); numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq1.getRequestId()); blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), llmRequest1->getNumTokens(beamIdx) - 1); EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 5})); @@ -1790,11 +1724,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks * 2); blockManager.releaseBlocks(seq0, llmRequest0); - blockManager.releaseSequence(seq0.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks); blockManager.releaseBlocks(seq1, llmRequest1); - blockManager.releaseSequence(seq1.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -1814,7 +1746,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) // no reuse, get new block 7, 8, 9 auto promptLen2 = llmRequest2->getNumTokens(beamIdx); auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq2.getRequestId()); blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow); EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 0); EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({7, 8, 9})); @@ -1840,7 +1771,6 @@ 
TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) // reuse block 0, get new block 10, 11 auto promptLen3 = llmRequest3->getNumTokens(beamIdx); auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq3.getRequestId()); blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow); EXPECT_EQ(llmRequest3->getContextCurrentPosition(), tokensPerBlock); EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 10, 11})); @@ -1865,7 +1795,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) // reuse block 3, get new block 12, 13 auto promptLen4 = llmRequest4->getNumTokens(beamIdx); auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq4.getRequestId()); blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow); EXPECT_EQ(llmRequest4->getContextCurrentPosition(), tokensPerBlock); EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 12, 13})); @@ -1878,9 +1807,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) blockManager.releaseBlocks(seq2, llmRequest2); blockManager.releaseBlocks(seq3, llmRequest3); blockManager.releaseBlocks(seq4, llmRequest4); - blockManager.releaseSequence(seq2.getRequestId()); - blockManager.releaseSequence(seq3.getRequestId()); - blockManager.releaseSequence(seq4.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); } @@ -1944,7 +1870,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) auto constexpr beamIdx = 0; auto promptLen0 = llmRequest0->getNumTokens(beamIdx); auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq0.getRequestId()); 
blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); @@ -1960,7 +1885,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) // Release blocks to make them available for reuse blockManager.releaseBlocks(seq0, llmRequest0); - blockManager.releaseSequence(seq0.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -1982,7 +1906,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) // Should NOT reuse blocks despite same tokens, because cache_salt_id is different auto promptLen1 = llmRequest1->getNumTokens(beamIdx); auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq1.getRequestId()); blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 0); // No reuse, starts from scratch EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 5})); @@ -1994,7 +1917,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) // Release blocks blockManager.releaseBlocks(seq1, llmRequest1); - blockManager.releaseSequence(seq1.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -2015,7 +1937,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) // SHOULD reuse blocks because both tokens and cache_salt_id match auto promptLen2 = llmRequest2->getNumTokens(beamIdx); auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq2.getRequestId()); blockManager.addSequence(seq2, promptLen2, numContextBlocks2, 
*llmRequest2, maxAttentionWindow); EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 2 * tokensPerBlock); // Reuse blocks 3,4 EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 6})); @@ -2027,7 +1948,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) // Release blocks blockManager.releaseBlocks(seq2, llmRequest2); - blockManager.releaseSequence(seq2.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); @@ -2049,7 +1969,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) // Should NOT reuse blocks from any previous request because cache_salt_id is different auto promptLen3 = llmRequest3->getNumTokens(beamIdx); auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq3.getRequestId()); blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow); EXPECT_EQ(llmRequest3->getContextCurrentPosition(), 0); // No reuse EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({7, 8, 9})); @@ -2076,7 +1995,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) // Should reuse blocks from request0 (blocks 0,1) because both have no cache_salt_id auto promptLen4 = llmRequest4->getNumTokens(beamIdx); auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq4.getRequestId()); blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow); EXPECT_EQ(llmRequest4->getContextCurrentPosition(), 2 * tokensPerBlock); // Reuse blocks 0,1 EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 10})); @@ -2090,8 +2008,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) // Clean up blockManager.releaseBlocks(seq3, llmRequest3); 
blockManager.releaseBlocks(seq4, llmRequest4); - blockManager.releaseSequence(seq3.getRequestId()); - blockManager.releaseSequence(seq4.getRequestId()); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); } @@ -2196,9 +2112,14 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest) KvCacheRetentionConfig({KvCacheRetentionConfig::TokenRangeRetentionConfig(0, 4, 90), KvCacheRetentionConfig::TokenRangeRetentionConfig(4, 8, 10)}, 20)); + auto kvCacheRetentionconfig = llmRequest0->getKvCacheRetentionConfig(); + if (kvCacheRetentionconfig.has_value()) + { + TLLM_LOG_DEBUG( + "%s%d - KvCacheRetentionConfig = %s", __FILE__, __LINE__, kvCacheRetentionconfig.value().print().c_str()); + } GenerationRequest seq0{0, inputLength0, beamWidth, blockManager.getWindowSizesMetadata()}; auto numContextBlocks0 = tc::ceilDiv(inputLength0, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq0.getRequestId()); blockManager.addSequence(seq0, llmRequest0->getNumTokens(0), numContextBlocks0, *llmRequest0, maxAttentionWindow); // Add another sequence with different tokens, at a low priority @@ -2207,14 +2128,11 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest) auto llmRequest1 = std::make_shared(1, maxNewTokens, inputTokens1, samplingConfig, isStreaming); GenerationRequest seq1{1, inputLength1, beamWidth, blockManager.getWindowSizesMetadata()}; auto numContextBlocks1 = tc::ceilDiv(inputLength1, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq1.getRequestId()); blockManager.addSequence(seq1, llmRequest1->getNumTokens(0), numContextBlocks1, *llmRequest1, maxAttentionWindow); // Release both sequences blockManager.releaseBlocks(seq0, llmRequest0); blockManager.releaseBlocks(seq1, llmRequest1); - blockManager.releaseSequence(seq0.getRequestId()); - blockManager.releaseSequence(seq1.getRequestId()); // Add and then release another sequence auto inputTokens2 = 
std::make_shared(VecTokens{16, 17, 18, 19, 20, 21, 22, 23}); @@ -2224,10 +2142,8 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest) KvCacheRetentionConfig({KvCacheRetentionConfig::TokenRangeRetentionConfig(0, std::nullopt, 20)}, 20)); GenerationRequest seq2{2, inputLength2, beamWidth, blockManager.getWindowSizesMetadata()}; auto numContextBlocks2 = tc::ceilDiv(inputLength2, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq2.getRequestId()); blockManager.addSequence(seq2, llmRequest2->getNumTokens(0), numContextBlocks2, *llmRequest2, maxAttentionWindow); blockManager.releaseBlocks(seq2, llmRequest2); - blockManager.releaseSequence(seq2.getRequestId()); // Check that request 1 blocks were overwritten auto inputTokens3 = std::make_shared(VecTokens{8, 9, 10, 11, 12, 13, 14, 15}); @@ -2235,13 +2151,11 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest) auto llmRequest3 = std::make_shared(3, maxNewTokens, inputTokens3, samplingConfig, isStreaming); GenerationRequest seq3{3, inputLength3, beamWidth, blockManager.getWindowSizesMetadata()}; auto numContextBlocks3 = tc::ceilDiv(inputLength3, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq3.getRequestId()); blockManager.addSequence(seq3, llmRequest3->getNumTokens(0), numContextBlocks3, *llmRequest3, maxAttentionWindow); EXPECT_EQ(llmRequest3->getContextCurrentPosition(), 4); blockManager.releaseBlocks(seq3, llmRequest3); - blockManager.releaseSequence(seq3.getRequestId()); EXPECT_EQ(blockManager.getNumFreeBlocks(), 4); // Check that request 0 blocks weren't overwritten @@ -2250,7 +2164,6 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest) auto llmRequest4 = std::make_shared(4, maxNewTokens, inputTokens4, samplingConfig, isStreaming); GenerationRequest seq4{4, inputLength3, beamWidth, blockManager.getWindowSizesMetadata()}; auto numContextBlocks4 = tc::ceilDiv(inputLength4, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq4.getRequestId()); 
blockManager.addSequence(seq4, llmRequest4->getNumTokens(0), numContextBlocks4, *llmRequest4, maxAttentionWindow); EXPECT_EQ(llmRequest4->getContextCurrentPosition(), 4); @@ -2261,7 +2174,6 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest) auto llmRequest5 = std::make_shared(5, maxNewTokens, inputTokens5, samplingConfig, isStreaming); GenerationRequest seq5{5, inputLength5, beamWidth, blockManager.getWindowSizesMetadata()}; auto numContextBlocks5 = tc::ceilDiv(inputLength5, blockManager.getTokensPerBlock()); - blockManager.holdSequence(seq5.getRequestId()); blockManager.addSequence(seq5, llmRequest5->getNumTokens(0), numContextBlocks5, *llmRequest5, maxAttentionWindow); EXPECT_EQ(llmRequest5->getContextCurrentPosition(), 0); @@ -2299,71 +2211,69 @@ TEST_F(KVCacheManagerTest, KVCacheManagerDecodeBlockPriorityTest) EXPECT_EQ(kvCacheManager.getNumFreeBlocks(), 8); - // Uses 3 blocks 0, 1, 2 which contain [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11] auto inputTokens0 = std::make_shared(VecTokens{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); auto const inputLength0 = static_cast(inputTokens0->size()); auto llmRequest0 = std::make_shared(0, maxNewTokens, inputTokens0, samplingConfig, isStreaming); llmRequest0->setKvCacheRetentionConfig( KvCacheRetentionConfig({KvCacheRetentionConfig::TokenRangeRetentionConfig(0, std::nullopt, 90)}, 90)); kvCacheManager.addSequence(0, inputLength0, beamWidth, llmRequest0); + // reserve blocks 0, 1 and 2 with tokens [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11] and priority 90 // 5 blocks available. EXPECT_EQ(kvCacheManager.getNumFreeBlocks(), 5); - // Add a token to request 0, which occupies a new block 3. + // Add a token to request 0 kvCacheManager.addToken(0); - llmRequest0->addNewToken(0, 0); // block 3 contains [0] + llmRequest0->addNewToken(0, 0); + // reserve block 3 with tokens [0] and priority 90 // 4 blocks left. EXPECT_EQ(kvCacheManager.getNumFreeBlocks(), 4); - // uses up 3 more blocks 4, 5, 6. 
[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23] auto inputTokens1 = std::make_shared(VecTokens{12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); auto const inputLength1 = static_cast(inputTokens1->size()); auto llmRequest1 = std::make_shared(1, maxNewTokens, inputTokens1, samplingConfig, isStreaming); llmRequest1->setKvCacheRetentionConfig( KvCacheRetentionConfig({KvCacheRetentionConfig::TokenRangeRetentionConfig(0, std::nullopt, 5)}, 5)); kvCacheManager.addSequence(1, inputLength1, beamWidth, llmRequest1); + // reserve blocks 4, 5 and 6 with tokens [12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23] and priority 5 // one block left. EXPECT_EQ(kvCacheManager.getNumFreeBlocks(), 1); - // add another token, which occupies another new block + // add another token kvCacheManager.addToken(1); - llmRequest1->addNewToken(0, 0); // block 7 contains [0] + llmRequest1->addNewToken(0, 0); + // reserve block 7 with priority 5 // no block available. EXPECT_EQ(kvCacheManager.getNumFreeBlocks(), 0); // remove both sequences, blocks get stored - // leaf block 3 (priority 90), context blocks 2, 1, 0 (priority 5) (void) kvCacheManager.removeSequence(0, llmRequest0); - // leaf block 7 (priority 5), context blocks 6, 5, 4 (priority 90) + // store blocks 0, 1, 2 with priority 90. block 3 is released without being stored, but still has priority 90. (void) kvCacheManager.removeSequence(1, llmRequest1); + // store blocks 4, 5, 6 with priority 5. block 7 is released without being stored, but still has priority 5. // all blocks are available again. EXPECT_EQ(kvCacheManager.getNumFreeBlocks(), 8); // no reuse, blocks are evicted by new request: - // evict block 7 (lowest priority, first released block) - // evict block 6 (lowest priority, second released block) - // evict block 5 (lowest priority, third released block) - // uses up 3 blocks 7, 6, 5. [24, 25, 26, 27], [28, 29, 30, 31], [32, 33, 34, 35] + // uses up 3 blocks 2, 1, 0. 
[24, 25, 26, 27], [28, 29, 30, 31], [32, 33, 34] auto inputTokens2 = std::make_shared(VecTokens{24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35}); auto const inputLength2 = static_cast(inputTokens2->size()); auto llmRequest2 = std::make_shared(2, maxNewTokens, inputTokens2, samplingConfig, isStreaming); kvCacheManager.addSequence(2, inputLength2, beamWidth, llmRequest2); - // leaf block 2 (priority 35), context blocks 3, 7 (priority 35) - (void) kvCacheManager.removeSequence(2, llmRequest2); + // no reuse, reserve blocks 5, 6, 7 since they have lower priority than 0, 1, 2 and 3. + kvCacheManager.removeSequence(2, llmRequest2); - // Uses 3 blocks 0, 1, 2 which contain [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11] + // reuse blocks 0, 1 and 2 with tokens [0,1,2,3] [4,5,6,7] [8,9,10,11] auto inputTokens3 = std::make_shared(VecTokens{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); auto const inputLength3 = static_cast(inputTokens3->size()); auto llmRequest3 = std::make_shared(3, maxNewTokens, inputTokens3, samplingConfig, isStreaming); kvCacheManager.addSequence(3, inputLength3, beamWidth, llmRequest3); - // Reuse block 0, 1, and partial reuse block 2. (maximum reuse is inputLength - 1) - // Two blocks reused, the third block partially reused. + // Reused 11 out of the 12 input tokens. EXPECT_EQ(llmRequest3->getContextCurrentPosition(), 11); } @@ -2477,15 +2387,18 @@ TEST_F(KVCacheManagerTest, KVCacheManagerDecodeTimedEvictionTest) KvCacheRetentionConfig({KvCacheRetentionConfig::TokenRangeRetentionConfig(0, std::nullopt, 50)}, 50)); // Set all blocks to priority 50. kvCacheManager.addSequence(0, inputLength0, beamWidth, llmRequest0); + // Reserve new blocks 0,1,2 kvCacheManager.storeContextBlocks(*llmRequest0); // Occupy a new block, block 3, adding 3 tokens to block 3. 
// [1, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [0, 0, 0] for (int i = 0; i < 3; i++) { + // i == 0 will reserve new block 3 for generated tokens kvCacheManager.addToken(0); llmRequest0->addNewToken(0, 0); } - (void) kvCacheManager.removeSequence(0, llmRequest0); + kvCacheManager.removeSequence(0, llmRequest0); + // Store blocks 0,1,2,3 with tokens [1,1,2,3] [4,5,6,7] [8,9,10,11] [0,0,0] at priority 50. } { // 12 tokens, occupy 3 blocks 4, 5, 6. @@ -2496,17 +2409,30 @@ TEST_F(KVCacheManagerTest, KVCacheManagerDecodeTimedEvictionTest) llmRequest1->setKvCacheRetentionConfig(KvCacheRetentionConfig( {}, KvCacheRetentionConfig::kMaxRetentionPriority, 20ms)); // Set decode blocks to max priority for 20ms. kvCacheManager.addSequence(1, inputLength1, beamWidth, llmRequest1); + // Reserve new blocks [4,5,6] kvCacheManager.storeContextBlocks(*llmRequest1); - // Occupy a new block, block 3, adding 3 tokens to block 3. - // [1, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [0, 0, 0] + // Occupy a new block, block 7, adding 3 tokens to block 7. + // [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [0, 0, 0] for (int i = 0; i < 3; i++) { + // i == 0 will reserve block 7 for generated tokens kvCacheManager.addToken(1); llmRequest1->addNewToken(0, 0); } - (void) kvCacheManager.removeSequence(1, llmRequest1); + kvCacheManager.removeSequence(1, llmRequest1); + // Store blocks 4,5,6,7 with tokens [0,1,2,3] [4,5,6,7] [8,9,10,11] [1,1,1] + // Assigned priorities are 4(35) 5(35) 6(35) 7(100). Block 7 retains max priority for 20ms + // Free queues at this point are: + // 35 : 4,5,6,8 + // 50 : 0,1,2,3 + // 100 : 7 } + // Allow block 7's max priority to expire, demoting block from priority 100 to 35. 
+ // Note that demoting block 7 priority puts block 7 at the back of the free queue, + // Free queues at this point are: + // 35 : 7,4,5,6,8 + // 50 : 0,1,2,3 std::this_thread::sleep_for(std::chrono::milliseconds(50)); kvCacheManager.refreshBlocks(); @@ -2515,13 +2441,19 @@ TEST_F(KVCacheManagerTest, KVCacheManagerDecodeTimedEvictionTest) auto const inputLength2 = static_cast(inputTokens2->size()); auto llmRequest2 = std::make_shared(2, maxNewTokens, inputTokens2, samplingConfig, isStreaming); kvCacheManager.addSequence(2, inputLength2, beamWidth, llmRequest2); - (void) kvCacheManager.removeSequence(2, llmRequest2); + // No unused blocks available. Evict blocks 8,6 and reserve for new request. + kvCacheManager.removeSequence(2, llmRequest2); + // Store blocks [8,6] for reuse. + // Free queues at this point are: + // 35 : 8,6,7,4,5 + // 50 : 0,1,2,3 // 12 tokens, reusing block 4, 5. Block 6 is overwritten so no reuse. auto inputTokens3 = std::make_shared(VecTokens{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); auto const inputLength3 = static_cast(inputTokens3->size()); auto llmRequest3 = std::make_shared(3, maxNewTokens, inputTokens3, samplingConfig, isStreaming); kvCacheManager.addSequence(3, inputLength3, beamWidth, llmRequest3); + // Reuse block 4,5. 
Evict blocks 7 EXPECT_EQ(llmRequest3->getContextCurrentPosition(), 8); } @@ -2556,37 +2488,79 @@ TEST_F(KVCacheManagerTest, KVCacheManagerSecondaryBlockPrimaryChildTest) 0, stream, maxSequenceLength, true, onboardBlocks); kvCacheManager.allocatePools(false); + auto const& blockManager = kvCacheManager.getBlockManager(); + auto inputTokens0 = std::make_shared(VecTokens{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); auto const inputLength0 = static_cast(inputTokens0->size()); auto llmRequest0 = std::make_shared(0, maxNewTokens, inputTokens0, samplingConfig, isStreaming); - // 12 tokens, get block 0, 1, 2 - // [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11] + TLLM_LOG_DEBUG("%s;%d\n%s", __FILE__, __LINE__, blockManager.printFreeQueues(maxAttentionWindow).c_str()); + // Free queues: + // P : 35 : 0,1,2,3 + // S : 35 : 4,5,6,7 kvCacheManager.addSequence(0, inputLength0, beamWidth, llmRequest0); + TLLM_LOG_DEBUG("%s;%d\n%s", __FILE__, __LINE__, blockManager.printFreeQueues(maxAttentionWindow).c_str()); + // Reserve blocks 0,1,2 for tokens [0,1,2,3] [4,5,6,7] [8,9,10,11] + // Free queues: + // P : 35 : 3 + // S : 35 : 4,5,6,7 (void) kvCacheManager.removeSequence(0, llmRequest0); - // store blocks 0, 1, 2 for reuse ([0,1,2,3], [4,5,6,7], [8,9,10]) + TLLM_LOG_DEBUG("%s;%d\n%s", __FILE__, __LINE__, blockManager.printFreeQueues(maxAttentionWindow).c_str()); + // Store blocks 0,1,2 + // Free queues: + // P : 35 : 3,2,1,0 + // S : 35 : 4,5,6,7 + // Cache: + // 0=[0,1,2,3] -> 1=[4,5,6,7] -> 2=[8,9,10] // Offload the last two blocks of llmRequest0 to secondary memory auto inputTokens1 = std::make_shared(VecTokens{1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); auto const inputLength1 = static_cast(inputTokens1->size()); auto llmRequest1 = std::make_shared(1, maxNewTokens, inputTokens1, samplingConfig, isStreaming); - // Uses blocks 3, 4, 5, block 2 and 1 to be offloaded to secondary - // Block 4 is now in primary (replacing 2) - // Block 5 is now in primary (replacing 1) 
kvCacheManager.addSequence(1, inputLength1, beamWidth, llmRequest1); + TLLM_LOG_DEBUG("%s;%d\n%s", __FILE__, __LINE__, blockManager.printFreeQueues(maxAttentionWindow).c_str()); + // No reuse possible + // Reserve block 3 + // Offload block 2 to 4. Reserve block 4 + // Offload block 1 to 5. Reserve block 5 + // Free queues: + // P : 35 : 0 + // S : 45 : 6,7,2,1 + // Cache: + // 0=[0,1,2,3] -> 1=[4,5,6,7] -> 2=[8,9,10] (void) kvCacheManager.removeSequence(1, llmRequest1); - // store blocks 3, 4, 5 for reuse ([1,1,2,3], [4,5,6,7], [8,9,10]) + TLLM_LOG_DEBUG("%s;%d\n%s", __FILE__, __LINE__, blockManager.printFreeQueues(maxAttentionWindow).c_str()); + // Store blocks 3,4,5 + // Free queues: + // P : 35 : 0,5,4,3 + // S : 45 : 6,7,2,1 + // Cache: + // 0=[0,1,2,3] -> 1=[4,5,6,7] -> 2=[8,9,10] + // 3=[1,1,2,3] -> 4=[4,5,6,7] -> 5=[8,9,10] - // Match the middle block of request 0 - // Uses block 6, block 0 is offloaded to secondary - // Block 6 copies content from block 0 to itselg. auto inputTokens2 = std::make_shared(VecTokens{0, 1, 2, 3}); auto const inputLength2 = static_cast(inputTokens2->size()); auto llmRequest2 = std::make_shared(2, maxNewTokens, inputTokens2, samplingConfig, isStreaming); - // reuse block 0 kvCacheManager.addSequence(2, inputLength2, beamWidth, llmRequest2); + TLLM_LOG_DEBUG("%s;%d\n%s", __FILE__, __LINE__, blockManager.printFreeQueues(maxAttentionWindow).c_str()); + // Partially reuse block 0 by copy. This involves the following steps: + // getFreeBlock to partially copy tokens from block 0 into: + // Offload block 0 to 6 (primary -> secondary) + // return block 6 + // partially copy block 0 to 6 (secondary -> primary) + // reserve block 6 + // Block 0 is now in secondary, return block 6. + // TODO: This round-trip to host memory and back is stupid. + // It happens because the source (0) and destination chosen by getFreeBlock (0) are the same. Check for this. 
+ // Free queues: + // P : 35 : 5,4,3 + // S : 45 : 7,2,1,0 + // Cache: + // 0=[0,1,2,3] -> 1=[4,5,6,7] -> 2=[8,9,10] + // 3=[1,1,2,3] -> 4=[4,5,6,7] -> 5=[8,9,10] EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 3); kvCacheManager.storeContextBlocks(*llmRequest2); + // Nothing stored since we only have 4 prompt tokens and we need at least 5 to get one full reusable block // Add a decode block that matches the contents of seq 0 block 1, add a unique decode block // The 4 tokens added has the same content as block 1. @@ -2595,30 +2569,78 @@ TEST_F(KVCacheManagerTest, KVCacheManagerSecondaryBlockPrimaryChildTest) llmRequest2->addNewToken(token, 0); kvCacheManager.addToken(2); } - // Add 2 more tokens, occupying another block + TLLM_LOG_DEBUG("%s;%d\n%s", __FILE__, __LINE__, blockManager.printFreeQueues(maxAttentionWindow).c_str()); + // Offload block 5 to 7. Reserve block 7 + // Free queues: + // P : 35 : 4,3 + // S : 45 : 2,1,0,5 + // Cache: + // 0=[0,1,2,3] -> 1=[4,5,6,7] -> 2=[8,9,10] + // 3=[1,1,2,3] -> 4=[4,5,6,7] -> 5=[8,9,10] + + // add two more tokens. llmRequest2->addNewToken(0, 0); kvCacheManager.addToken(2); - llmRequest2->addNewToken(0, 0); kvCacheManager.addToken(2); - // The middle block remains in secondary, but the third block is in primary - // FIXME: When removing the sequence, we should observe whether released - // blocks can replace itself as the block reused in the search tree if - // the matching block is currently in secondary memory. We can release the - // block in secondary if so. - // If we do this, then the context current position at the bottom of this - // unit test will be 9 because then the block content [4,5,6,7] can be - // found and reused. + TLLM_LOG_DEBUG("%s;%d\n%s", __FILE__, __LINE__, blockManager.printFreeQueues(maxAttentionWindow).c_str()); + // Evict 2 + // Offload block 4 to 2. 
Reserve block 2 + // Free queues: + // P : 35 : 3 + // S : 45 : 1,0,5,4 + // Cache: + // 0=[0,1,2,3] -> 1=[4,5,6,7] -> nil + // 3=[1,1,2,3] -> 4=[4,5,6,7] -> 5=[8,9,10] + (void) kvCacheManager.removeSequence(2, llmRequest2); + TLLM_LOG_DEBUG("%s;%d\n%s", __FILE__, __LINE__, blockManager.printFreeQueues(maxAttentionWindow).c_str()); + // Store blocks 6,7,2. Blocks 6, 7 are not stored since same state is already stored in tree + // Free queues: + // P : 35 : 3,2,7,6 + // S : 45 : 1,0,5,4 + // Cache: + // 0=[0,1,2,3] -> 1=[4,5,6,7] -> 2=[0] + // 3=[1,1,2,3] -> 4=[4,5,6,7] -> 5=[8,9,10] - // 10 tokens, reusing the block 0 only because when we want to acquire - // the second block, contents of block 3 will be offloaded to block 1. auto inputTokens3 = std::make_shared(VecTokens{0, 1, 2, 3, 4, 5, 6, 7, 0, 0}); auto const inputLength3 = static_cast(inputTokens3->size()); auto llmRequest3 = std::make_shared(3, maxNewTokens, inputTokens3, samplingConfig, isStreaming); kvCacheManager.addSequence(3, inputLength3, beamWidth, llmRequest3); - // Check out FIXME note above. If addressed, this should be 9. - EXPECT_EQ(llmRequest3->getContextCurrentPosition(), 4); + TLLM_LOG_DEBUG("%s;%d\n%s", __FILE__, __LINE__, blockManager.printFreeQueues(maxAttentionWindow).c_str()); + // Reuse blocks 0,1,2 with tokens [0,1,2,3] [4,5,6,7] [0] + // The following steps are involved: + // + // Claim blocks 0,1,2 so they cannot be evicted implicitly + // P : 35 : 3,7,6 + // S : 35 : 5,4 + // + // Matched block 0 + // Block 0 needs onboarding. + // Evict block 3. Block 3 needs offloading. + // Offload block 3 + // Evict block 5 + // Copy block 3 to 5. + // Blocks 3 and 5 swap pointers (block 3 is now secondary, block 5 is free primary) + // Copy block 0 to block 5 + // Blocks 0 and 5 swap pointers (block 0 now primary, block 5 is free secondary) + // P : 35 : 7,6 + // S : 35 : 4,3,5 + // + // Matched block 1 + // Block 1 needs onboarding + // Evict block 7. Block 7 does not need offloading. 
+ // Copy block 1 to 7. + // Blocks 1 and 7 swap pointers (block 1 now primary, block 7 is free secondary) + // P : 35 : 6 + // S : 35 : 4,3,5,7 + // + // Matched block 2 + // Block 2 does not need onboarding + // P : 35 : 6 + // S : 35 : 4,3,5,7 + // + EXPECT_EQ(llmRequest3->getContextCurrentPosition(), 9); } TEST_F(KVCacheManagerTest, KVCacheManagerLeafBlockTest) @@ -2735,6 +2757,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerLeafBlockWithDependentTest) auto llmRequest0 = std::make_shared(requestId0, maxNewTokens, inputTokens0, samplingConfig, isStreaming); kvCacheManager.addSequence(requestId0, inputLength0, beamWidth, llmRequest0); + // Reserve block 0 with tokens [0,1,2,3] kvCacheManager.storeContextBlocks(*llmRequest0); GenerationRequest const& seq0 = kvCacheManager.getSequence(requestId0); @@ -2744,6 +2767,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerLeafBlockWithDependentTest) llmRequest0->addNewToken(i, beamIdx); kvCacheManager.addToken(requestId0); } + // Reserve blocks 1,2 for generated tokens // Verify auto cacheBlockIds0 = seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx); @@ -2752,7 +2776,8 @@ TEST_F(KVCacheManagerTest, KVCacheManagerLeafBlockWithDependentTest) // Raise priority of middle block to prevent offloading auto const& blockManager = kvCacheManager.getBlockManager(); auto middleBlock = blockManager.getBlockById(cacheBlockIds0[1], maxAttentionWindow); - middleBlock->setPriority(75); + middleBlock->setPriority(0); + // Lower priority of block 1 to zero // Create another sequence with one block worth of context tokens (no reuse). 
// 4 tokens, occupying block 3 @@ -2762,51 +2787,60 @@ TEST_F(KVCacheManagerTest, KVCacheManagerLeafBlockWithDependentTest) auto llmRequest1 = std::make_shared(requestId1, maxNewTokens, inputTokens1, samplingConfig, isStreaming); kvCacheManager.addSequence(requestId1, inputLength1, beamWidth, llmRequest1); + // seq1: Reserve block 3 for tokens [100,101,102,103] kvCacheManager.storeContextBlocks(*llmRequest1); GenerationRequest const& seq1 = kvCacheManager.getSequence(requestId1); // Verify that all primary blocks are in use EXPECT_EQ(blockManager.getNumFreeBlocks(), 0); - // Free first sequence - (void) kvCacheManager.removeSequence(requestId0, llmRequest0); + // Free seq0 + kvCacheManager.removeSequence(requestId0, llmRequest0); + // Store blocks 0,1,2 for reuse. + // Free queues are now: + // 0 : 1 + // 35 : 0,2 // Verify that 3 primary blocks are free. // Since block 1 has higher priority, block 2 and 0 will be used first. EXPECT_EQ(blockManager.getNumFreeBlocks(), 3); - // Write one generated token to second sequence. This will prompt block 2 to be offloaded. - // Block 4 will be in primary (replacing block 2) + // Write one generated token to seq1. This will cause block 1 to be evicted + // without offloading since priority is lower than minimum required for offloading. 
llmRequest1->addNewToken(104, beamIdx); kvCacheManager.addToken(requestId1); + // seq1: Reserve block 1 for tokens [104] - // Verify that block 2 has block 1 as parent - auto block2 = blockManager.getBlockById(2, maxAttentionWindow); - EXPECT_TRUE(block2->getPrevBlock() != nullptr); - EXPECT_EQ(block2->getPrevBlock()->getBlockId(), 1); - EXPECT_FALSE(block2->isPrimary()); + // Verify that block 1 was assigned to seq1 and is in primary memory + auto cacheBlockIds1 = seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx); + EXPECT_THAT(cacheBlockIds1, ::testing::ElementsAreArray({3, 1})); + auto block1 = blockManager.getBlockById(1, maxAttentionWindow); + EXPECT_TRUE(block1->isPrimary()); - // Fill block + // Fill block with token 105,106,107 for (int i = 101 + tokensPerBlock; i < 100 + 2 * tokensPerBlock; ++i) { llmRequest1->addNewToken(i, beamIdx); kvCacheManager.addToken(requestId1); } - // Verify - auto cacheBlockIds1 = seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx); - EXPECT_THAT(cacheBlockIds1, ::testing::ElementsAreArray({3, 4})); - - // Write one generated token to second sequence. This will prompt block 0 to be offloaded, - // replacing block 2. + // Write one generated token to second sequence. This will prompt block 2 to be offloaded into block 4. llmRequest1->addNewToken(100 + 2 * tokensPerBlock, beamIdx); kvCacheManager.addToken(requestId1); + // seq1: Reserve block 4 for token [108] + + // Verify + cacheBlockIds1 = seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx); + EXPECT_THAT(cacheBlockIds1, ::testing::ElementsAreArray({3, 1, 4})); + + // Verify that block 2 has no parent (because block 1 was evicted) and is in secondary memory. 
+ auto block2 = blockManager.getBlockById(2, maxAttentionWindow); + EXPECT_TRUE(block2->getPrevBlock() == nullptr); + EXPECT_FALSE(block2->isPrimary()); - // Verify that block 2 is free, has no parent - EXPECT_EQ(block2->getPrevBlock(), nullptr); - // Verify that it is block 0 that is in secondary - auto block0 = blockManager.getBlockById(0, maxAttentionWindow); - EXPECT_FALSE(block0->isPrimary()); + // Verify that block 4 is in primary memory + auto block4 = blockManager.getBlockById(4, maxAttentionWindow); + EXPECT_TRUE(block4->isPrimary()); // Cleanup (void) kvCacheManager.removeSequence(requestId1, llmRequest1); @@ -3372,6 +3406,8 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStream) maxSequenceLength, true, onboardBlocks, CacheType::kSELF, std::nullopt, std::make_unique(1024)); kvCacheManager.allocatePools(false); + // Free queues: + // 35 : 7,6,5,4,3,2,1,0 auto events = getEvents(kvCacheManager); EXPECT_EQ(events.size(), 1); @@ -3385,7 +3421,11 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStream) auto llmRequest0 = std::make_shared(0, 0, inputTokens0, samplingConfig, true); llmRequest0->setLoraTaskId(42); kvCacheManager.addSequence(0, inputTokens0->size(), beamWidth, llmRequest0); + // Reserve blocks 0,1,2 for prefilled tokens [0,1,2,3] [4,5,6,7] [8,9] + // Free queues + // 35 : 7,6,5,4,3 kvCacheManager.storeContextBlocks(*llmRequest0); + // Store full blocks 0,1 with tokens [0,1,2,3] [4,5,6,7] events = getEvents(kvCacheManager); @@ -3397,7 +3437,11 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStream) EXPECT_EQ(std::get(events.front().data).blocks[0].cacheLevel, 0); kvCacheManager.addToken(0); llmRequest0->addNewToken(0, 0); - (void) kvCacheManager.removeSequence(0, llmRequest0); + kvCacheManager.removeSequence(0, llmRequest0); + // Store block 2 with tokens [8,9] + // Release blocks 0,1,2 + // Free queues: + // 35 : 0,1,2,7,6,5,4,3 auto newEvents = getEvents(kvCacheManager); EXPECT_EQ(newEvents.size(), 1); @@ -3414,8 +3458,17 @@ 
TEST_F(KVCacheManagerTest, KVCacheManagerEventStream) auto llmRequest1 = std::make_shared(1, 0, inputTokens1, samplingConfig, true); llmRequest1->setLoraTaskId(42); kvCacheManager.addSequence(1, inputTokens1->size(), beamWidth, llmRequest1); + // Reuse blocks 0,1 for tokens [0,1,2,3] [4,5,6,7] + // Add new block 3 for prefill token [8] + // Free queues: + // 35 : 2,7,6,5,4 kvCacheManager.storeContextBlocks(*llmRequest1); - (void) kvCacheManager.removeSequence(1, llmRequest1); + // Blocks 0,1 already stored + kvCacheManager.removeSequence(1, llmRequest1); + // No new blocks stored. 0,1 already stored, 3 contains no reusable state. + // Release blocks 0,1,3 + // Free queues: + // 35 : 0,1,3,2,7,6,5,4 events = getEvents(kvCacheManager); @@ -3424,10 +3477,21 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStream) auto inputTokens2 = std::make_shared(VecTokens{1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); auto llmRequest2 = std::make_shared(2, 0, inputTokens2, samplingConfig, true); kvCacheManager.addSequence(2, inputTokens2->size(), beamWidth, llmRequest2); + // No reuse possible. Reserve blocks 4,5,6,7 for prefill tokens [1,1,2,3] [4,5,6,7] [8,9,10,11] [12] + // Free queues: + // 35 : 0,1,3,2 auto inputTokens3 = std::make_shared(VecTokens{2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); auto llmRequest3 = std::make_shared(3, 0, inputTokens3, samplingConfig, true); kvCacheManager.addSequence(3, inputTokens3->size(), beamWidth, llmRequest3); + // No reuse possible. 
Reserve blocks [0,1,3,2] for prefill tokens [2,1,2,3] [4,5,6,7] [8,9,10,11] [12] + // Some blocks will be offloaded: + // 2 -> 8 (8 reserved for seq3, 2 now in secondary) + // 3 -> Not offloaded, contains no reusable state + // 1 -> 9 (9 reserved for seq3, 1 now in secondary) + // 0 -> 2 (2 reserved for seq3, 0 now in secondary and 2 evicted) + // Reserved for seq3 = [2,9,3,8] + // No free blocks events = getEvents(kvCacheManager); size_t firstSwapped = std::get(events.front().data).blockHash; @@ -3445,8 +3509,16 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStream) EXPECT_THAT(std::get(events.front().data).blockHashes, ::testing::ElementsAreArray({firstSwapped})); - (void) kvCacheManager.removeSequence(2, llmRequest2); - (void) kvCacheManager.removeSequence(3, llmRequest3); + kvCacheManager.removeSequence(2, llmRequest2); + // Store blocks 4,5,6,7 + // Release blocks 4,5,6,7 + // Free queues: + // 35 : 4,5,6,7 + kvCacheManager.removeSequence(3, llmRequest3); + // Store blocks 2,9,3,8 + // Release blocks 2,9,3,8 + // Free queues: + // 35 : 2,9,3,8,4,5,6,7 events = getEvents(kvCacheManager); EXPECT_EQ(events.size(), 2); @@ -3462,18 +3534,20 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStream) auto llmRequest4 = std::make_shared(4, 0, inputTokens4, samplingConfig, true); llmRequest4->setLoraTaskId(42); kvCacheManager.addSequence(4, inputTokens4->size(), beamWidth, llmRequest4); + // Reuse block 0 with tokens [0,1,2,3] + // This will cause block 0 to be onboarded + // Allocate new blocks 7,6 for prefill tokens [1,1,1,1] [0] + // This will cause blocks 7,6 to be offloaded + // Block 3 will be evicted since it contains no reusable state + // Total tally: Onboard 1 block, offload 2 blocks, remove 1 block.
events = getEvents(kvCacheManager); - // Onboard block 0, in replace, offload block 7 - // Offload block 6, and write content of [1,1,1,1] to block 1 - // Upon freeing up block 1, its child block 2, will be removed from the search tree, - // which is a remove event. - // Offload block 5, in replace onboard block 7, and write content of [0] to block 7. - // In total, there are 2 offloads, 1 onboard, 1 removed, total of 4 events. - // FIXME: For better improvement, when block 1 is overwritten, child blocks - // are removed from the search tree and no longer reusable. Therefore these blocks - // should be the first to be called upon when we want a new block. + // As argued above: + // block 0 was onboarded + // blocks 6,7 were offloaded + // block 3 was removed + // In total, expect 4 events. auto onboardedBlocks = 0; auto offloadedBlocks = 0; auto removedBlocks = 0; @@ -3712,8 +3786,6 @@ TEST_F(KVCacheManagerTest, KVCacheManagerSWAInvalidateReuseTest) GenerationRequest const& seq1 = kvCacheManager.getSequence(/*requestId=*/1); auto const onlyWindowSize = theOnlyWindowSize(kvCacheManager); - EXPECT_FALSE(blockManager.isSequenceValidForStoreForReuse(seq0.getRequestId(), onlyWindowSize)); - EXPECT_TRUE(blockManager.isSequenceValidForStoreForReuse(seq1.getRequestId(), onlyWindowSize)); EXPECT_NO_THROW(kvCacheManager.removeSequence(seq0.getRequestId(), llmRequest0)); EXPECT_NO_THROW(kvCacheManager.removeSequence(seq1.getRequestId(), llmRequest1)); @@ -4152,16 +4224,21 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamWindowSize) events = getEvents(kvCacheManager); - // Expecting only 1 event, storeContextBlock is not called for sliding window. - EXPECT_EQ(events.size(), 1); + // Expecting 2 events now that storeContext is called for SWA. 
+ EXPECT_EQ(events.size(), 2); - EXPECT_EQ(events.back().windowSize, maxAttentionWindow); - EXPECT_TRUE(std::holds_alternative(events.back().data)); + while (!events.empty()) /* pop until empty so every event is checked; an index loop against the shrinking size() stops after half of them */ + { + EXPECT_EQ(events.front().windowSize, slidingWindow); + EXPECT_TRUE(std::holds_alternative(events.front().data)); + events.pop_front(); + } } TEST_F(KVCacheManagerTest, KVCacheTransferManagerConcurrencyTest) { auto const blockSize = 16384; + auto constexpr windowSize = 2; auto bufferManager = tensorrt_llm::runtime::BufferManager(std::make_shared()); auto transferManager = KVCacheTransferManager(bufferManager); @@ -4179,8 +4256,8 @@ TEST_F(KVCacheManagerTest, KVCacheTransferManagerConcurrencyTest) tr::bufferCast(*pool.secondaryPtr)[i] = 1; } - auto primaryBlock = std::make_shared(0, tensorrt_llm::kernels::KVCacheIndex(0, false)); - auto secondaryBlock = std::make_shared(1, tensorrt_llm::kernels::KVCacheIndex(0, true)); + auto primaryBlock = std::make_shared(0, tensorrt_llm::kernels::KVCacheIndex(0, false), windowSize); + auto secondaryBlock = std::make_shared(1, tensorrt_llm::kernels::KVCacheIndex(0, true), windowSize); transferManager.offload(primaryBlock, secondaryBlock, {pool}); primaryBlock->swapMemoryPoolBlockOffset(secondaryBlock); diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 764bd3937d2..8f86285070f 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -409,6 +409,7 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): context_batch = scheduled_batch.context_requests generation_batch = scheduled_batch.generation_requests # allocate KV Cache + self.impl.sync_transfer_manager_with_buffer_manager() for req in context_batch: req_beam_width = req.sampling_config.beam_width if 'cp_type' in self.mapping.cp_config and CpType.STAR == self.mapping.cp_config[ @@ -436,6 +437,7 @@ def prepare_resources(self,
scheduled_batch: ScheduledRequests): block_ids = self.get_cache_indices(req) self.kv_connector_manager.update_state_after_alloc( req, block_ids) + self.impl.refresh_blocks() for req in generation_batch: self.impl.add_token(req.py_request_id)