diff --git a/README.md b/README.md index a947597f..3d156396 100644 --- a/README.md +++ b/README.md @@ -199,6 +199,9 @@ vllm bench serve --model meta-llama/Llama-3.2-1B-Instruct --request-rate 10 --nu > > When kvcached is enabled, there is NO need to set memory utilization limit (e.g., using `--gpu-memory-utilization`) as kvcached will automatically manage the memory. +> [!NOTE] +> **AMD / ROCm:** on ROCm (HIP) builds, kvcached automatically defaults to the **non-contiguous** KV-cache layout. The contiguous layout (the default on NVIDIA) hands vLLM's ROCm attention backend strided per-layer KV tensors it cannot read correctly, whereas non-contiguous matches the layout the backend expects. You can override with `KVCACHED_CONTIGUOUS_LAYOUT=true|false`, but contiguous is not recommended on ROCm. + If you installed kvcached using its source code, you can also do the following: ```bash diff --git a/benchmarks/bench_vmm/Makefile b/benchmarks/bench_vmm/Makefile index 8aac2faa..49d409d2 100644 --- a/benchmarks/bench_vmm/Makefile +++ b/benchmarks/bench_vmm/Makefile @@ -1,9 +1,24 @@ -NVCC ?= nvcc +INC_DIR = ../../csrc/inc + +# Auto-detect backend: set KVCACHED_BACKEND=hip for AMD, defaults to cuda. 
+KVCACHED_BACKEND ?= cuda + +ifeq ($(KVCACHED_BACKEND),hip) + CXX := hipcc + CXXFLAGS = -DKVCACHED_USE_HIP -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 + LDLIBS = -lamdhip64 +else + CXX := nvcc + CXXFLAGS = -DKVCACHED_USE_CUDA + LDLIBS = -lcuda +endif + +CXXFLAGS += -O2 -g -std=c++17 -I$(INC_DIR) all: bench_vmm.bin bench_vmm.bin: bench_vmm.cpp - $(NVCC) $^ -o $@ -O2 -g -lcuda -std=c++17 + $(CXX) $^ -o $@ $(CXXFLAGS) $(LDLIBS) clean: - $(RM) bench_vmm.bin \ No newline at end of file + $(RM) bench_vmm.bin diff --git a/benchmarks/bench_vmm/README.md b/benchmarks/bench_vmm/README.md index 375496c8..4d8d5cf9 100644 --- a/benchmarks/bench_vmm/README.md +++ b/benchmarks/bench_vmm/README.md @@ -1,28 +1,34 @@ # VMM Benchmark -This benchmark measures the latency of various CUDA Virtual Memory Management (VMM) operations. +This benchmark measures the latency of various GPU Virtual Memory Management (VMM) operations on both NVIDIA (CUDA) and AMD (ROCm/HIP) GPUs. ## Description -The tool benchmarks the following CUDA driver API calls: -- `cuMemAddressReserve`: Reserving a virtual address range. -- `cuMemCreate`: Allocating physical memory. -- `cuMemMap`: Mapping physical memory to a virtual address. -- `cuMemSetAccess`: Setting access permissions for a mapped region. -- `cuMemUnmap`: Unmapping physical memory. +The tool benchmarks the following VMM API calls: +- `address_reserve`: Reserving a virtual address range. +- `mem_create`: Allocating physical memory. +- `mem_map`: Mapping physical memory to a virtual address. +- `set_access`: Setting access permissions for a mapped region. +- `mem_unmap`: Unmapping physical memory. It uses multiple CPU threads to issue these commands in parallel and reports latency statistics (average, p50, p90, p99, and max). ## Building the Benchmark -You need a CUDA-enabled GPU and the CUDA Toolkit installed. +You need a GPU with VMM support and the corresponding toolkit installed (CUDA Toolkit or ROCm). 
-Compile the benchmark`: +For NVIDIA GPUs (default): ```bash make ``` +For AMD GPUs: + +```bash +make KVCACHED_BACKEND=hip +``` + ## Running the Benchmark Execute the compiled binary: @@ -36,17 +42,18 @@ The benchmark parameters (number of threads, page size, etc.) are defined as `co ## Sample Output on A100 ``` +Backend: CUDA Total Free Memory: 84.5442GB -====== cuMemMap ElemSz=1 ====== +====== VMM Benchmark ====== -cuMemAddressReserve (8GB) latency: 19 us +address_reserve (8GB) latency: 19 us Benchmarking with 1 threads and 4096 pages of size 2MB. --------------------------------------------------------------------------- Operation avg (us) p50 (us) p90 (us) p99 (us) max (us) --------------------------------------------------------------------------- -cuMemCreate 193.32 195.00 339.00 381.00 493.00 -cuMemMap 1.45 0.00 4.00 5.00 105.00 -cuMemSetAccess 35.99 35.00 42.00 54.00 169.00 -cuMemUnmap 25.63 25.00 27.00 39.00 126.00 +mem_create 193.32 195.00 339.00 381.00 493.00 +mem_map 1.45 0.00 4.00 5.00 105.00 +set_access 35.99 35.00 42.00 54.00 169.00 +mem_unmap 25.63 25.00 27.00 39.00 126.00 ``` diff --git a/benchmarks/bench_vmm/bench_vmm.cpp b/benchmarks/bench_vmm/bench_vmm.cpp index 5357f478..81ea3da7 100644 --- a/benchmarks/bench_vmm/bench_vmm.cpp +++ b/benchmarks/bench_vmm/bench_vmm.cpp @@ -9,10 +9,10 @@ #include #include -#include -#include +#include "gpu_utils.hpp" +#include "gpu_vmm.hpp" -#include "cuda_utils.hpp" +namespace vmm = kvcached::gpu_vmm; static constexpr int kNumThds = 1; static constexpr size_t kPageSize = 2ul << 20; // MB @@ -23,30 +23,27 @@ void print_header(); void print_stats(const std::string &op_name, const std::vector latencies[kNumThds]); -int init_cuda() { - size_t free; - typedef unsigned char ElemType; - CUcontext ctx; - CUdevice dev; - int supportsVMM = 0; +int init_gpu() { + int supports_vmm = 0; - CHECK_RT(cudaFree(0)); + CHECK_GPU(vmm::initialize_runtime()); + CHECK_GPU(vmm::set_device(0)); + int dev_idx = vmm::current_device(); - 
CHECK_DRV(cuInit(0)); - CHECK_DRV(cuDevicePrimaryCtxRetain(&ctx, 0)); - CHECK_DRV(cuCtxSetCurrent(ctx)); - CHECK_DRV(cuCtxGetDevice(&dev)); + size_t free_mem = 0, total_mem = 0; +#if defined(KVCACHED_USE_HIP) + CHECK_GPU(hipMemGetInfo(&free_mem, &total_mem)); +#elif defined(KVCACHED_USE_CUDA) + CHECK_GPU(cudaMemGetInfo(&free_mem, &total_mem)); +#endif - CHECK_DRV(cuMemGetInfo(&free, NULL)); - std::cout << "Total Free Memory: " << (float)free / std::giga::num << "GB" + std::cout << "Backend: " << vmm::backend_name() << std::endl; + std::cout << "Total Free Memory: " << (float)free_mem / std::giga::num << "GB" << std::endl; - CHECK_DRV(cuDeviceGetAttribute( - &supportsVMM, CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED, - dev)); - if (supportsVMM) { - std::cout << "====== cuMemMap ElemSz=" << sizeof(ElemType) - << " ======" << std::endl; + CHECK_DRV(vmm::get_vmm_support(&supports_vmm, dev_idx)); + if (supports_vmm) { + std::cout << "====== VMM Benchmark ======" << std::endl; } else { std::cout << "VMM not supported" << std::endl; } @@ -54,29 +51,20 @@ int init_cuda() { return 0; } -CUdeviceptr alloc_virtual(size_t size) { - CUdeviceptr addr; - CHECK_DRV(cuMemAddressReserve(&addr, size, kPageSize, 0, 0)); +void *alloc_virtual(size_t size) { + void *addr = nullptr; + CHECK_DRV(vmm::address_reserve(&addr, size, kPageSize)); return addr; } -int bench_physical_alloc(std::vector &handles) { +int bench_physical_alloc(std::vector &handles) { std::vector thds; std::vector latencies[kNumThds]; handles.resize(kNumPages); - CUdevice dev; - CHECK_DRV(cuCtxGetDevice(&dev)); - - CUmemAllocationProp prop = { - .type = CU_MEM_ALLOCATION_TYPE_PINNED, - .location = - { - .type = CU_MEM_LOCATION_TYPE_DEVICE, - .id = dev, - }, - }; + int dev_idx = vmm::current_device(); + auto prop = vmm::make_pinned_device_allocation_prop(dev_idx); for (int i = 0; i < kNumThds; i++) { thds.emplace_back([&, tid = i]() { @@ -84,7 +72,7 @@ int bench_physical_alloc(std::vector &handles) { auto 
end_page = kNumPages / kNumThds * (tid + 1); for (size_t page_idx = stt_page; page_idx < end_page; page_idx++) { auto stt = std::chrono::high_resolution_clock::now(); - CHECK_DRV(cuMemCreate(&handles[page_idx], kPageSize, &prop, 0)); + CHECK_DRV(vmm::mem_create(&handles[page_idx], kPageSize, &prop)); auto end = std::chrono::high_resolution_clock::now(); latencies[tid].push_back( std::chrono::duration_cast(end - stt) @@ -97,7 +85,7 @@ int bench_physical_alloc(std::vector &handles) { thd.join(); } - print_stats("cuMemCreate", latencies); + print_stats("mem_create", latencies); return 0; } @@ -152,10 +140,10 @@ void print_stats(const std::string &op_name, << std::setw(15) << max << std::endl; } -int bench_mmap(CUdeviceptr addr, - std::vector &handles) { +int bench_mmap(void *addr, std::vector &handles) { std::vector thds; std::vector latencies[kNumThds]; + char *base = static_cast(addr); for (int i = 0; i < kNumThds; i++) { thds.emplace_back([&, tid = i]() { @@ -163,7 +151,7 @@ int bench_mmap(CUdeviceptr addr, auto end = kNumPages / kNumThds * (tid + 1); for (size_t i = stt; i < end; i++) { auto stt = std::chrono::high_resolution_clock::now(); - CHECK_DRV(cuMemMap(addr + i * kPageSize, kPageSize, 0, handles[i], 0)); + CHECK_DRV(vmm::mem_map(base + i * kPageSize, kPageSize, 0, handles[i])); auto end = std::chrono::high_resolution_clock::now(); latencies[tid].push_back( std::chrono::duration_cast(end - stt) @@ -176,25 +164,18 @@ int bench_mmap(CUdeviceptr addr, thd.join(); } - print_stats("cuMemMap", latencies); + print_stats("mem_map", latencies); return 0; } -int bench_setaccess(CUdeviceptr addr) { +int bench_setaccess(void *addr) { std::vector thds; std::vector latencies[kNumThds]; - CUdevice dev; - - CHECK_DRV(cuCtxGetDevice(&dev)); - CUmemAccessDesc accessDesc{ - .location = - { - .type = CU_MEM_LOCATION_TYPE_DEVICE, - .id = dev, - }, - .flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE, - }; + char *base = static_cast(addr); + + int dev_idx = vmm::current_device(); + 
auto access_desc = vmm::make_device_rw_access_desc(dev_idx); for (int i = 0; i < kNumThds; i++) { thds.emplace_back([&, tid = i]() { @@ -203,7 +184,7 @@ int bench_setaccess(CUdeviceptr addr) { for (size_t i = stt; i < end; i++) { auto stt = std::chrono::high_resolution_clock::now(); CHECK_DRV( - cuMemSetAccess(addr + i * kPageSize, kPageSize, &accessDesc, 1)); + vmm::set_access(base + i * kPageSize, kPageSize, &access_desc, 1)); auto end = std::chrono::high_resolution_clock::now(); latencies[tid].push_back( std::chrono::duration_cast(end - stt) @@ -216,14 +197,15 @@ int bench_setaccess(CUdeviceptr addr) { thd.join(); } - print_stats("cuMemSetAccess", latencies); + print_stats("set_access", latencies); return 0; } -int bench_munmap(CUdeviceptr addr) { +int bench_munmap(void *addr) { std::vector thds; std::vector latencies[kNumThds]; + char *base = static_cast(addr); for (int i = 0; i < kNumThds; i++) { thds.emplace_back([&, tid = i]() { @@ -231,7 +213,7 @@ int bench_munmap(CUdeviceptr addr) { auto end = kNumPages / kNumThds * (tid + 1); for (size_t i = stt; i < end; i++) { auto stt = std::chrono::high_resolution_clock::now(); - CHECK_DRV(cuMemUnmap(addr + i * kPageSize, kPageSize)); + CHECK_DRV(vmm::mem_unmap(base + i * kPageSize, kPageSize)); auto end = std::chrono::high_resolution_clock::now(); latencies[tid].push_back( std::chrono::duration_cast(end - stt) @@ -244,34 +226,32 @@ int bench_munmap(CUdeviceptr addr) { thd.join(); } - print_stats("cuMemUnmap", latencies); + print_stats("mem_unmap", latencies); return 0; } -void free_physical(std::vector &handles) { +void free_physical(std::vector &handles) { for (const auto &handle : handles) { - CHECK_DRV(cuMemRelease(handle)); + CHECK_DRV(vmm::mem_release(handle)); } } -void free_virtual(CUdeviceptr addr) { - CHECK_DRV(cuMemAddressFree(addr, kMemSize)); -} +void free_virtual(void *addr) { CHECK_DRV(vmm::address_free(addr, kMemSize)); } int main() { - init_cuda(); + init_gpu(); auto stt = 
std::chrono::high_resolution_clock::now(); - CUdeviceptr addr = alloc_virtual(kMemSize); + void *addr = alloc_virtual(kMemSize); auto end = std::chrono::high_resolution_clock::now(); auto lat = std::chrono::duration_cast(end - stt).count(); - std::cout << "\ncuMemAddressReserve (" << (kMemSize >> 30) + std::cout << "\naddress_reserve (" << (kMemSize >> 30) << "GB) latency: " << lat << " us\n" << std::endl; - std::vector handles; + std::vector handles; print_header(); bench_physical_alloc(handles); diff --git a/benchmarks/bench_vmm/cuda_utils.hpp b/benchmarks/bench_vmm/cuda_utils.hpp deleted file mode 100644 index 5cb6591a..00000000 --- a/benchmarks/bench_vmm/cuda_utils.hpp +++ /dev/null @@ -1,77 +0,0 @@ -// SPDX-FileCopyrightText: Copyright contributors to the kvcached project -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include -#include - -#include -#include - -#define LOGE(format, ...) \ - fprintf(stdout, "L%d:" format "\n", __LINE__, ##__VA_ARGS__); \ - fflush(stdout); - -#define ASSERT(cond, ...) \ - { \ - if (!(cond)) { \ - LOGE(__VA_ARGS__); \ - assert(0); \ - } \ - } - -#define WARN(cond, ...) 
\ - { \ - if (!(cond)) { \ - LOGE(__VA_ARGS__); \ - } \ - } - -#define DRV_CALL(call) \ - { \ - CUresult result = (call); \ - if (CUDA_SUCCESS != result) { \ - const char *errMsg; \ - cuGetErrorString(result, &errMsg); \ - ASSERT(0, "Error when exec " #call " %s-%d code:%d err:%s", \ - __FUNCTION__, __LINE__, result, errMsg); \ - } \ - } - -#define DRV_CALL_RET(call, status_val) \ - { \ - CUresult result = (call); \ - if (CUDA_SUCCESS != result) { \ - const char *errMsg; \ - cuGetErrorString(result, &errMsg); \ - WARN(0, "Error when exec " #call " %s-%d code:%d err:%s", __FUNCTION__, \ - __LINE__, result, errMsg); \ - } \ - status_val = result; \ - } - -static inline void checkRtError(cudaError_t res, const char *tok, - const char *file, unsigned line) { - if (res != cudaSuccess) { - std::cerr << file << ':' << line << ' ' << tok - << " failed in CUDA runtime (" << (unsigned)res - << "): " << cudaGetErrorString(res) << std::endl; - abort(); - } -} - -#define CHECK_RT(x) checkRtError(x, #x, __FILE__, __LINE__) - -static inline void checkDrvError(CUresult res, const char *tok, - const char *file, unsigned line) { - if (res != CUDA_SUCCESS) { - const char *errStr = nullptr; - (void)cuGetErrorString(res, &errStr); - std::cerr << file << ':' << line << ' ' << tok << " failed in CUDA driver (" - << (unsigned)res << "): " << errStr << std::endl; - abort(); - } -} - -#define CHECK_DRV(x) checkDrvError(x, #x, __FILE__, __LINE__) diff --git a/csrc/allocator.cpp b/csrc/allocator.cpp index 9adb2715..edb4f34c 100644 --- a/csrc/allocator.cpp +++ b/csrc/allocator.cpp @@ -3,15 +3,13 @@ #include #include -#include #include #include "allocator.hpp" #include "constants.hpp" -#include "cuda_utils.hpp" #include "ftensor.hpp" +#include "gpu_utils.hpp" #include "page.hpp" -#include "torch_utils.hpp" namespace kvcached { // Global configurable page size @@ -20,14 +18,24 @@ size_t kPageSize = 2 * 1024 * 1024; // Default 2MB std::unordered_map> FTensorAllocator::g_allocators_; std::mutex 
FTensorAllocator::g_allocator_mutex_; -torch::Device FTensorAllocator::g_device_(torch::kCPU); +c10::Device FTensorAllocator::g_device_(c10::kCPU); bool FTensorAllocator::g_contiguous_layout_ = false; -static inline std::shared_ptr make_shared_page(const torch::Device &dev, +static inline std::shared_ptr make_shared_page(const c10::Device &dev, page_id_t page_id, size_t page_size = 0) { + auto resolve_device_index = [](const c10::Device &device) -> int { + if (device.index() >= 0) { + return device.index(); + } + return gpu_vmm::current_device(); + }; + + // is_cuda() returns true for both NVIDIA (CUDA) and AMD (HIP/ROCm) devices, + // because PyTorch's ROCm build masquerades HIP devices as CUDA. if (dev.is_cuda()) { - return std::make_shared(page_id, dev.index(), page_size); + return std::make_shared(page_id, resolve_device_index(dev), + page_size); } else if (dev.is_cpu()) { return std::make_shared(page_id, page_size); } @@ -35,7 +43,7 @@ static inline std::shared_ptr make_shared_page(const torch::Device &dev, return nullptr; } -static inline size_t get_v_base_offset(const torch::Tensor &tensor) { +static inline size_t get_v_base_offset(const at::Tensor &tensor) { size_t num_eles = tensor.numel() * tensor.element_size(); ASSERT(num_eles % (2 * kPageSize) == 0, "Invalid tensor size: %zu, must be a multiple of 2 * page size %zu", @@ -43,12 +51,12 @@ static inline size_t get_v_base_offset(const torch::Tensor &tensor) { return num_eles / 2; } -FTensorAllocator::FTensorAllocator(const torch::Device &device, +FTensorAllocator::FTensorAllocator(const c10::Device &device, bool contiguous_layout) : dev_(device), num_layers_(0), contiguous_layout_(contiguous_layout), unified_pool_(false), kv_tensor_size_per_layer_(0) { if (dev_.is_cuda()) { - init_cuda_(); + init_gpu_(); } } @@ -83,7 +91,7 @@ void FTensorAllocator::init(const std::string &dev_str, size_t page_size, kPageSize = page_size; } - auto device = torch::Device(dev_str); + auto device = c10::Device(dev_str); 
g_device_ = device; g_contiguous_layout_ = contiguous_layout; g_allocators_[0] = @@ -110,8 +118,8 @@ void FTensorAllocator::shutdown() { g_allocators_.clear(); } -std::vector FTensorAllocator::create_kv_tensors( - size_t size, torch::Dtype dtype, const std::string &dev_str, +std::vector FTensorAllocator::create_kv_tensors( + size_t size, c10::ScalarType dtype, const std::string &dev_str, int64_t num_layers, int64_t num_kv_buffers, bool unified_pool) { std::lock_guard lock(mtx_); @@ -254,10 +262,10 @@ std::string FTensorAllocator::get_anon_tensor_name_() { return std::string(prefix) + std::to_string(counter++); } -std::vector FTensorAllocator::create_kv_tensors_per_layer_( - std::string_view prefix, size_t size, torch::Dtype dtype, +std::vector FTensorAllocator::create_kv_tensors_per_layer_( + std::string_view prefix, size_t size, c10::ScalarType dtype, const std::string &dev_str, int64_t num_layers) { - std::vector ftensors; + std::vector ftensors; for (int64_t i = 0; i < num_layers; i++) { auto name = std::string(prefix) + std::to_string(i); auto tensor = create_ftensor_(size, dtype, dev_str, name); @@ -266,8 +274,8 @@ std::vector FTensorAllocator::create_kv_tensors_per_layer_( return ftensors; } -std::vector FTensorAllocator::create_kv_tensors_contiguous_( - size_t size, torch::Dtype dtype, const std::string &dev_str, +std::vector FTensorAllocator::create_kv_tensors_contiguous_( + size_t size, c10::ScalarType dtype, const std::string &dev_str, int64_t num_layers, size_t compound_page_size) { // In contiguous layout, Python passes per-layer size, and we multiply by // num_layers to get total size @@ -285,16 +293,16 @@ std::vector FTensorAllocator::create_kv_tensors_contiguous_( } /** this function is not thread-safe */ -torch::Tensor FTensorAllocator::create_ftensor_(size_t size, torch::Dtype dtype, - const std::string &dev_str, - std::string name) { +at::Tensor FTensorAllocator::create_ftensor_(size_t size, c10::ScalarType dtype, + const std::string &dev_str, + 
std::string name) { if (name.empty()) name = get_anon_tensor_name_(); if (ftensors_.find(name) != ftensors_.end()) { auto tensor = ftensors_[name].get()->get_tensor(); assert(tensor.numel() * tensor.element_size() == size); - assert(tensor.device() == torch::Device(dev_str)); + assert(tensor.device() == c10::Device(dev_str)); return tensor; } @@ -305,7 +313,7 @@ torch::Tensor FTensorAllocator::create_ftensor_(size_t size, torch::Dtype dtype, } /** this function is not thread-safe */ -void FTensorAllocator::free_ftensor_(torch::Tensor &ftensor) { +void FTensorAllocator::free_ftensor_(at::Tensor &ftensor) { auto name = ftensor.name(); if (ftensors_.find(name) == ftensors_.end()) { return; @@ -313,36 +321,25 @@ void FTensorAllocator::free_ftensor_(torch::Tensor &ftensor) { ftensors_.erase(name); } -void FTensorAllocator::init_cuda_() { - CHECK_RT(cudaFree(0)); +void FTensorAllocator::init_gpu_() { + CHECK_GPU(gpu_vmm::initialize_runtime()); - CUdevice dev; - CHECK_DRV(cuCtxGetDevice(&dev)); + int dev_idx = dev_.index() >= 0 ? dev_.index() : gpu_vmm::current_device(); + CHECK_GPU(gpu_vmm::set_device(dev_idx)); - int supportsVMM = 0; - CHECK_DRV(cuDeviceGetAttribute( - &supportsVMM, CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED, - dev)); - // LOGE("Supports VMM: %d", supportsVMM); - - CUcontext context; - CHECK_DRV(cuCtxGetCurrent(&context)); - - CUmemAllocationProp prop{ - .type = CU_MEM_ALLOCATION_TYPE_PINNED, - .location = - { - .type = CU_MEM_LOCATION_TYPE_DEVICE, - .id = dev, - }, - }; + int supports_vmm = 0; + CHECK_GPU(gpu_vmm::get_vmm_support(&supports_vmm, dev_idx)); + ASSERT(supports_vmm != 0, + "VMM is not supported on %s device %d. 
kvcached requires GPU VMM " + "support.", + gpu_vmm::backend_name(), dev_idx); + auto prop = gpu_vmm::make_pinned_device_allocation_prop(dev_idx); size_t chunk_sz = 0; - CHECK_DRV(cuMemGetAllocationGranularity(&chunk_sz, &prop, - CU_MEM_ALLOC_GRANULARITY_MINIMUM)); + CHECK_GPU(gpu_vmm::get_allocation_granularity(&chunk_sz, &prop)); ASSERT(kPageSize % chunk_sz == 0, - "Invalid page size: %lu must be a multiple of CUDA granularity %lu\n", - kPageSize, chunk_sz); + "Invalid page size: %lu must be a multiple of %s granularity %lu\n", + kPageSize, gpu_vmm::backend_name(), chunk_sz); } } // namespace kvcached diff --git a/csrc/ftensor.cpp b/csrc/ftensor.cpp index c2eae1df..95ad9f23 100644 --- a/csrc/ftensor.cpp +++ b/csrc/ftensor.cpp @@ -4,16 +4,26 @@ #include #include +#include +#include + #include "constants.hpp" -#include "cuda_utils.hpp" #include "ftensor.hpp" +#include "gpu_utils.hpp" #include "page.hpp" namespace kvcached { static std::atomic g_vaddr_allocated_offset = 0; -static inline generic_ptr_t alloc_virtual_mem(const torch::Device &dev, +static inline int resolve_device_index(const c10::Device &dev) { + if (dev.index() >= 0) { + return dev.index(); + } + return gpu_vmm::current_device(); +} + +static inline generic_ptr_t alloc_virtual_mem(const c10::Device &dev, size_t size) { size_t alignment_2mb = 2 * 1024 * 1024; ASSERT(size % alignment_2mb == 0, @@ -21,9 +31,12 @@ static inline generic_ptr_t alloc_virtual_mem(const torch::Device &dev, generic_ptr_t vaddr; size_t offset = g_vaddr_allocated_offset.fetch_add(size); + // is_cuda() returns true for both NVIDIA (CUDA) and AMD (HIP/ROCm) devices, + // because PyTorch's ROCm build masquerades HIP devices as CUDA. 
if (dev.is_cuda()) { - CHECK_DRV(cuMemAddressReserve(reinterpret_cast(&vaddr), size, - alignment_2mb, kStartAddr + offset, 0ULL)); + CHECK_GPU(gpu_vmm::address_reserve( + reinterpret_cast(&vaddr), size, alignment_2mb, + reinterpret_cast(kStartAddr + offset))); } else { vaddr = mmap(reinterpret_cast(kStartAddr + offset), size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); @@ -33,11 +46,12 @@ static inline generic_ptr_t alloc_virtual_mem(const torch::Device &dev, return vaddr; } -static inline std::unique_ptr make_unique_page(const torch::Device &dev, +static inline std::unique_ptr make_unique_page(const c10::Device &dev, page_id_t page_id, size_t page_size = 0) { if (dev.is_cuda()) { - return std::make_unique(page_id, dev.index(), page_size); + return std::make_unique(page_id, resolve_device_index(dev), + page_size); } else if (dev.is_cpu()) { return std::make_unique(page_id, page_size); } @@ -45,8 +59,8 @@ static inline std::unique_ptr make_unique_page(const torch::Device &dev, return nullptr; } -FTensor::FTensor(const std::string &name, size_t size, torch::Dtype dtype, - torch::Device dev, std::shared_ptr zero_page, +FTensor::FTensor(const std::string &name, size_t size, c10::ScalarType dtype, + c10::Device dev, std::shared_ptr zero_page, size_t page_size) : name_(name), vaddr_(nullptr), size_(size), page_size_(page_size > 0 ? 
page_size : kPageSize), dtype_(dtype), @@ -54,28 +68,29 @@ FTensor::FTensor(const std::string &name, size_t size, torch::Dtype dtype, vaddr_ = alloc_virtual_mem(dev_, size_); init_with_zero_(); - auto num_elems = static_cast(size / torch::elementSize(dtype_)); + auto num_elems = static_cast(size / c10::elementSize(dtype_)); auto options = - torch::TensorOptions().dtype(dtype_).device(dev_).requires_grad(false); + at::TensorOptions().dtype(dtype_).device(dev_).requires_grad(false); tensor_ = - torch::from_blob(reinterpret_cast(vaddr_), {num_elems}, options); + at::from_blob(reinterpret_cast(vaddr_), {num_elems}, options); } FTensor::~FTensor() { if (vaddr_) { - CUresult res = cuMemUnmap(reinterpret_cast(vaddr_), size_); - if (res != CUDA_SUCCESS) { - const char *err = nullptr; - (void)cuGetErrorString(res, &err); - LOGGER(ERROR, "cuMemUnmap during FTensor cleanup failed: %s", - err ? err : "unknown"); - } - res = cuMemAddressFree(reinterpret_cast(vaddr_), size_); - if (res != CUDA_SUCCESS) { - const char *err = nullptr; - (void)cuGetErrorString(res, &err); - LOGGER(ERROR, "cuMemAddressFree during FTensor cleanup failed: %s", - err ? err : "unknown"); + if (dev_.is_cuda()) { + // Tolerate stale VMM mappings during teardown: log, do not abort. + auto res = gpu_vmm::mem_unmap(vaddr_, size_); + if (!gpu_vmm::is_success(res)) { + LOGGER(ERROR, "mem_unmap during FTensor cleanup failed: %s", + gpu_vmm::error_string(res)); + } + res = gpu_vmm::address_free(vaddr_, size_); + if (!gpu_vmm::is_success(res)) { + LOGGER(ERROR, "address_free during FTensor cleanup failed: %s", + gpu_vmm::error_string(res)); + } + } else if (dev_.is_cpu()) { + ASSERT(munmap(vaddr_, size_) == 0, "munmap failed."); } } mapping_.clear(); // Free physical page handles after their mappings are gone. 
@@ -93,7 +108,9 @@ bool FTensor::map(offset_t offset) { auto vaddr = reinterpret_cast( reinterpret_cast(vaddr_) + offset); - CHECK_DRV(cuMemUnmap(reinterpret_cast(vaddr), page_size_)); + if (dev_.is_cuda()) { + CHECK_GPU(gpu_vmm::mem_unmap(vaddr, page_size_)); + } mapping_[page_id] = make_unique_page(dev_, page_id, page_size_); mapping_[page_id]->map(vaddr); @@ -111,7 +128,9 @@ bool FTensor::unmap(offset_t offset) { auto vaddr = reinterpret_cast( reinterpret_cast(vaddr_) + offset); - CHECK_DRV(cuMemUnmap(reinterpret_cast(vaddr), page_size_)); + if (dev_.is_cuda()) { + CHECK_GPU(gpu_vmm::mem_unmap(vaddr, page_size_)); + } // Map the zero page instead to ensure memory integrity. map_(zero_page_.get(), offset); @@ -129,16 +148,12 @@ bool FTensor::map_(Page *page, offset_t offset, bool set_access) { } bool FTensor::set_access_(generic_ptr_t addr, size_t size) { - CUmemAccessDesc accessDesc_{ - .location = - { - .type = CU_MEM_LOCATION_TYPE_DEVICE, - .id = dev_.index(), - }, - .flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE, - }; - CHECK_DRV(cuMemSetAccess(reinterpret_cast(addr), size, - &accessDesc_, 1)); + if (!dev_.is_cuda()) { + return true; + } + auto access_desc = + gpu_vmm::make_device_rw_access_desc(resolve_device_index(dev_)); + CHECK_GPU(gpu_vmm::set_access(addr, size, &access_desc, 1)); return true; } diff --git a/csrc/inc/allocator.hpp b/csrc/inc/allocator.hpp index 182837cd..c989c52f 100644 --- a/csrc/inc/allocator.hpp +++ b/csrc/inc/allocator.hpp @@ -10,8 +10,9 @@ #include #include -#include -#include +#include +#include +#include #include "constants.hpp" #include "ftensor.hpp" @@ -21,15 +22,15 @@ namespace kvcached { class FTensorAllocator { public: - FTensorAllocator(const torch::Device &device, bool contiguous_layout); + FTensorAllocator(const c10::Device &device, bool contiguous_layout); ~FTensorAllocator(); // KV cache interfaces. 
- std::vector create_kv_tensors(size_t size, torch::Dtype dtype, - const std::string &dev_str, - int64_t num_layers, - int64_t num_kv_buffers = 2, - bool unified_pool = false); + std::vector create_kv_tensors(size_t size, c10::ScalarType dtype, + const std::string &dev_str, + int64_t num_layers, + int64_t num_kv_buffers = 2, + bool unified_pool = false); bool kv_tensors_created(); bool map_to_kv_tensors(const std::vector &offsets); bool unmap_from_kv_tensors(const std::vector &offsets); @@ -47,31 +48,30 @@ class FTensorAllocator { private: // Raw FTensor interfaces. Must call with lock. static std::string get_anon_tensor_name_(); - std::vector + std::vector create_kv_tensors_per_layer_(std::string_view prefix, size_t size, - torch::Dtype dtype, const std::string &dev_str, - int64_t num_layers); - std::vector - create_kv_tensors_contiguous_(size_t size, torch::Dtype dtype, + c10::ScalarType dtype, + const std::string &dev_str, int64_t num_layers); + std::vector + create_kv_tensors_contiguous_(size_t size, c10::ScalarType dtype, const std::string &dev_str, int64_t num_layers, size_t compound_page_size); - torch::Tensor create_ftensor_(size_t size, torch::Dtype dtype, - const std::string &dev_str, - std::string name = ""); - void free_ftensor_(torch::Tensor &ftensor); + at::Tensor create_ftensor_(size_t size, c10::ScalarType dtype, + const std::string &dev_str, std::string name = ""); + void free_ftensor_(at::Tensor &ftensor); - // CUDA util functions. - void init_cuda_(); + // GPU VMM util functions. + void init_gpu_(); // Multiton: one allocator per group_id. static std::unordered_map> g_allocators_; static std::mutex g_allocator_mutex_; // Device and layout from init(), used to create new group allocators. 
- static torch::Device g_device_; + static c10::Device g_device_; static bool g_contiguous_layout_; - torch::Device dev_; + c10::Device dev_; int64_t num_layers_; bool contiguous_layout_; diff --git a/csrc/inc/ftensor.hpp b/csrc/inc/ftensor.hpp index 0e10b24c..b56d19b6 100644 --- a/csrc/inc/ftensor.hpp +++ b/csrc/inc/ftensor.hpp @@ -6,7 +6,9 @@ #include #include -#include +#include +#include +#include #include "constants.hpp" #include "page.hpp" @@ -16,14 +18,14 @@ namespace kvcached { /* NOTE: FTensorAllocator is thread-safe but FTensor is not. */ class FTensor { public: - FTensor(const std::string &name, size_t size, torch::Dtype dtype, - torch::Device dev, std::shared_ptr zero_page, + FTensor(const std::string &name, size_t size, c10::ScalarType dtype, + c10::Device dev, std::shared_ptr zero_page, size_t page_size = 0); ~FTensor(); bool map(offset_t offset); bool unmap(offset_t offset); - inline torch::Tensor get_tensor() noexcept { return tensor_; } + inline at::Tensor get_tensor() noexcept { return tensor_; } private: bool map_(Page *page, offset_t offset, bool set_access = true); @@ -34,11 +36,11 @@ class FTensor { generic_ptr_t vaddr_; size_t size_; size_t page_size_; - torch::Dtype dtype_; - torch::Device dev_; + c10::ScalarType dtype_; + c10::Device dev_; std::shared_ptr zero_page_; - torch::Tensor tensor_; + at::Tensor tensor_; std::unordered_map> mapping_; }; diff --git a/csrc/inc/cuda_utils.hpp b/csrc/inc/gpu_utils.hpp similarity index 72% rename from csrc/inc/cuda_utils.hpp rename to csrc/inc/gpu_utils.hpp index 5a984d57..655f5096 100644 --- a/csrc/inc/cuda_utils.hpp +++ b/csrc/inc/gpu_utils.hpp @@ -3,15 +3,16 @@ #pragma once -#include -#include - #include +#include +#include #include #include #include #include +#include "gpu_vmm.hpp" + typedef enum { FATAL = 0, ERROR = 1, @@ -94,54 +95,23 @@ static inline pid_t gettid(void) { return (pid_t)syscall(SYS_gettid); } #define WARN(cond, ...) 
\ { \ if (!(cond)) { \ - LOGE(__VA_ARGS__); \ + LOGW(__VA_ARGS__); \ } \ } -#define DRV_CALL(call) \ - { \ - CUresult result = (call); \ - if (CUDA_SUCCESS != result) { \ - const char *errMsg; \ - cuGetErrorString(result, &errMsg); \ - ASSERT(0, "Error when exec " #call " %s-%d code:%d err:%s", \ - __FUNCTION__, __LINE__, result, errMsg); \ - } \ - } +#define DRV_CALL(call) CHECK_GPU(call) #define DRV_CALL_RET(call, status_val) \ { \ - CUresult result = (call); \ - if (CUDA_SUCCESS != result) { \ - const char *errMsg; \ - cuGetErrorString(result, &errMsg); \ + auto result = (call); \ + if (!kvcached::gpu_vmm::is_success(result)) { \ WARN(0, "Error when exec " #call " %s-%d code:%d err:%s", __FUNCTION__, \ - __LINE__, result, errMsg); \ + __LINE__, static_cast(result), \ + kvcached::gpu_vmm::error_string(result)); \ } \ status_val = result; \ } -static inline void checkRtError(cudaError_t res, const char *tok, - const char *file, unsigned line) { - if (res != cudaSuccess) { - std::cerr << file << ':' << line << ' ' << tok - << " failed in CUDA runtime (" << (unsigned)res - << "): " << cudaGetErrorString(res) << std::endl; - abort(); - } -} - -#define CHECK_RT(x) checkRtError(x, #x, __FILE__, __LINE__) - -static inline void checkDrvError(CUresult res, const char *tok, - const char *file, unsigned line) { - if (res != CUDA_SUCCESS) { - const char *errStr = nullptr; - (void)cuGetErrorString(res, &errStr); - std::cerr << file << ':' << line << ' ' << tok << " failed in CUDA driver (" - << (unsigned)res << "): " << errStr << std::endl; - abort(); - } -} - -#define CHECK_DRV(x) checkDrvError(x, #x, __FILE__, __LINE__) +#define CHECK_GPU(x) kvcached::gpu_vmm::check((x), #x, __FILE__, __LINE__) +#define CHECK_RT(x) CHECK_GPU(x) +#define CHECK_DRV(x) CHECK_GPU(x) diff --git a/csrc/inc/gpu_vmm.hpp b/csrc/inc/gpu_vmm.hpp new file mode 100644 index 00000000..86c28c96 --- /dev/null +++ b/csrc/inc/gpu_vmm.hpp @@ -0,0 +1,255 @@ +// SPDX-FileCopyrightText: Copyright contributors to 
the kvcached project +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#if defined(KVCACHED_USE_HIP) +#include +#elif defined(KVCACHED_USE_CUDA) +#include +#include +#else +#error "kvcached requires one of KVCACHED_USE_HIP or KVCACHED_USE_CUDA." +#endif + +namespace kvcached { +namespace gpu_vmm { + +#if defined(KVCACHED_USE_HIP) + +using status_t = hipError_t; +using allocation_handle_t = hipMemGenericAllocationHandle_t; +using allocation_prop_t = hipMemAllocationProp; +using access_desc_t = hipMemAccessDesc; + +inline const char *backend_name() { return "HIP"; } + +inline const char *error_string(status_t status) { + return hipGetErrorString(status); +} + +inline bool is_success(status_t status) { return status == hipSuccess; } + +inline void check(status_t status, const char *tok, const char *file, + unsigned line) { + if (!is_success(status)) { + std::cerr << file << ':' << line << ' ' << tok << " failed in HIP runtime (" + << static_cast(status) << "): " << error_string(status) + << std::endl; + std::abort(); + } +} + +inline status_t initialize_runtime() { return hipInit(0); } + +inline status_t set_device(int dev_idx) { return hipSetDevice(dev_idx); } + +inline int current_device() { + int dev_idx = -1; + check(hipGetDevice(&dev_idx), "hipGetDevice(&dev_idx)", __FILE__, __LINE__); + return dev_idx; +} + +inline status_t mem_get_info(size_t *free_bytes, size_t *total_bytes) { + return hipMemGetInfo(free_bytes, total_bytes); +} + +inline status_t device_synchronize() { return hipDeviceSynchronize(); } + +inline status_t get_vmm_support(int *supports_vmm, int dev_idx) { + return hipDeviceGetAttribute( + supports_vmm, hipDeviceAttributeVirtualMemoryManagementSupported, + dev_idx); +} + +inline allocation_prop_t make_pinned_device_allocation_prop(int dev_idx) { + allocation_prop_t prop{}; + prop.type = hipMemAllocationTypePinned; + prop.requestedHandleType = hipMemHandleTypeNone; + prop.location.type = 
hipMemLocationTypeDevice; + prop.location.id = dev_idx; + return prop; +} + +inline access_desc_t make_device_rw_access_desc(int dev_idx) { + access_desc_t desc{}; + desc.location.type = hipMemLocationTypeDevice; + desc.location.id = dev_idx; + desc.flags = hipMemAccessFlagsProtReadWrite; + return desc; +} + +inline status_t get_allocation_granularity(size_t *granularity, + const allocation_prop_t *prop) { + return hipMemGetAllocationGranularity(granularity, prop, + hipMemAllocationGranularityMinimum); +} + +inline status_t mem_create(allocation_handle_t *handle, size_t size, + const allocation_prop_t *prop) { + return hipMemCreate(handle, size, prop, 0ULL); +} + +inline status_t mem_release(allocation_handle_t handle) { + return hipMemRelease(handle); +} + +inline status_t address_reserve(void **ptr, size_t size, size_t alignment, + void *preferred_addr = nullptr) { + return hipMemAddressReserve(ptr, size, alignment, preferred_addr, 0ULL); +} + +inline status_t address_free(void *ptr, size_t size) { + return hipMemAddressFree(ptr, size); +} + +inline status_t mem_map(void *ptr, size_t size, size_t offset, + allocation_handle_t handle) { + return hipMemMap(ptr, size, offset, handle, 0ULL); +} + +inline status_t mem_unmap(void *ptr, size_t size) { + return hipMemUnmap(ptr, size); +} + +inline status_t set_access(void *ptr, size_t size, const access_desc_t *desc, + size_t count) { + return hipMemSetAccess(ptr, size, desc, count); +} + +#elif defined(KVCACHED_USE_CUDA) + +using drv_status_t = CUresult; +using rt_status_t = cudaError_t; +using allocation_handle_t = CUmemGenericAllocationHandle; +using allocation_prop_t = CUmemAllocationProp; +using access_desc_t = CUmemAccessDesc; + +inline const char *backend_name() { return "CUDA"; } + +inline const char *error_string(drv_status_t status) { + const char *err = nullptr; + (void)cuGetErrorString(status, &err); + return err ? 
err : "unknown CUDA driver error"; +} + +inline const char *error_string(rt_status_t status) { + return cudaGetErrorString(status); +} + +inline bool is_success(drv_status_t status) { return status == CUDA_SUCCESS; } + +inline bool is_success(rt_status_t status) { return status == cudaSuccess; } + +inline void check(drv_status_t status, const char *tok, const char *file, + unsigned line) { + if (!is_success(status)) { + std::cerr << file << ':' << line << ' ' << tok << " failed in CUDA driver (" + << static_cast(status) << "): " << error_string(status) + << std::endl; + std::abort(); + } +} + +inline void check(rt_status_t status, const char *tok, const char *file, + unsigned line) { + if (!is_success(status)) { + std::cerr << file << ':' << line << ' ' << tok + << " failed in CUDA runtime (" << static_cast(status) + << "): " << error_string(status) << std::endl; + std::abort(); + } +} + +inline rt_status_t initialize_runtime() { return cudaFree(0); } + +inline rt_status_t set_device(int dev_idx) { return cudaSetDevice(dev_idx); } + +inline int current_device() { + int dev_idx = -1; + check(cudaGetDevice(&dev_idx), "cudaGetDevice(&dev_idx)", __FILE__, __LINE__); + return dev_idx; +} + +inline rt_status_t mem_get_info(size_t *free_bytes, size_t *total_bytes) { + return cudaMemGetInfo(free_bytes, total_bytes); +} + +inline rt_status_t device_synchronize() { return cudaDeviceSynchronize(); } + +inline drv_status_t get_vmm_support(int *supports_vmm, int dev_idx) { +#if defined(CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED) + constexpr auto attr = CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED; +#else + constexpr auto attr = + CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED; +#endif + return cuDeviceGetAttribute(supports_vmm, attr, + static_cast(dev_idx)); +} + +inline allocation_prop_t make_pinned_device_allocation_prop(int dev_idx) { + allocation_prop_t prop{}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = 
CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = dev_idx; + return prop; +} + +inline access_desc_t make_device_rw_access_desc(int dev_idx) { + access_desc_t desc{}; + desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + desc.location.id = dev_idx; + desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + return desc; +} + +inline drv_status_t get_allocation_granularity(size_t *granularity, + const allocation_prop_t *prop) { + return cuMemGetAllocationGranularity(granularity, prop, + CU_MEM_ALLOC_GRANULARITY_MINIMUM); +} + +inline drv_status_t mem_create(allocation_handle_t *handle, size_t size, + const allocation_prop_t *prop) { + return cuMemCreate(handle, size, prop, 0ULL); +} + +inline drv_status_t mem_release(allocation_handle_t handle) { + return cuMemRelease(handle); +} + +inline drv_status_t address_reserve(void **ptr, size_t size, size_t alignment, + void *preferred_addr = nullptr) { + return cuMemAddressReserve( + reinterpret_cast(ptr), size, alignment, + reinterpret_cast(preferred_addr), 0ULL); +} + +inline drv_status_t address_free(void *ptr, size_t size) { + return cuMemAddressFree(reinterpret_cast(ptr), size); +} + +inline drv_status_t mem_map(void *ptr, size_t size, size_t offset, + allocation_handle_t handle) { + return cuMemMap(reinterpret_cast(ptr), size, offset, handle, + 0ULL); +} + +inline drv_status_t mem_unmap(void *ptr, size_t size) { + return cuMemUnmap(reinterpret_cast(ptr), size); +} + +inline drv_status_t set_access(void *ptr, size_t size, + const access_desc_t *desc, size_t count) { + return cuMemSetAccess(reinterpret_cast(ptr), size, desc, count); +} + +#endif + +} // namespace gpu_vmm +} // namespace kvcached diff --git a/csrc/inc/impl/torch_utils.ipp b/csrc/inc/impl/torch_utils.ipp index deeadf0b..57f3881e 100644 --- a/csrc/inc/impl/torch_utils.ipp +++ b/csrc/inc/impl/torch_utils.ipp @@ -5,40 +5,40 @@ namespace kvcached { -static inline torch::Dtype torch_dtype_cast(const py::object &dtype) { +static inline c10::ScalarType 
torch_dtype_cast(const py::object &dtype) { if (dtype.is(py::module_::import("torch").attr("float32"))) - return torch::kFloat32; + return c10::ScalarType::Float; if (dtype.is(py::module_::import("torch").attr("float64"))) - return torch::kFloat64; + return c10::ScalarType::Double; if (dtype.is(py::module_::import("torch").attr("float16"))) - return torch::kFloat16; + return c10::ScalarType::Half; if (dtype.is(py::module_::import("torch").attr("int32"))) - return torch::kInt32; + return c10::ScalarType::Int; if (dtype.is(py::module_::import("torch").attr("int64"))) - return torch::kInt64; + return c10::ScalarType::Long; if (dtype.is(py::module_::import("torch").attr("int16"))) - return torch::kInt16; + return c10::ScalarType::Short; if (dtype.is(py::module_::import("torch").attr("int8"))) - return torch::kInt8; + return c10::ScalarType::Char; if (dtype.is(py::module_::import("torch").attr("uint8"))) - return torch::kUInt8; + return c10::ScalarType::Byte; if (dtype.is(py::module_::import("torch").attr("bool"))) - return torch::kBool; + return c10::ScalarType::Bool; throw std::runtime_error("Unsupported dtype: " + py::str(dtype).cast()); } -static inline torch::Dtype torch_dtype_from_size(size_t dtype_size) { +static inline c10::ScalarType torch_dtype_from_size(size_t dtype_size) { switch (dtype_size) { case 1: - return torch::kInt8; + return c10::ScalarType::Char; case 2: - return torch::kInt16; + return c10::ScalarType::Short; case 4: - return torch::kInt32; + return c10::ScalarType::Int; case 8: - return torch::kInt64; + return c10::ScalarType::Long; default: throw std::runtime_error("Unsupported dtype size: " + std::to_string(dtype_size)); diff --git a/csrc/inc/mem_info_tracker.hpp b/csrc/inc/mem_info_tracker.hpp index 176d8040..cc316653 100644 --- a/csrc/inc/mem_info_tracker.hpp +++ b/csrc/inc/mem_info_tracker.hpp @@ -13,7 +13,7 @@ #include #include -#include "cuda_utils.hpp" +#include "gpu_utils.hpp" namespace kvcached { diff --git a/csrc/inc/page.hpp 
b/csrc/inc/page.hpp index eb0fc93f..43311cf3 100644 --- a/csrc/inc/page.hpp +++ b/csrc/inc/page.hpp @@ -5,10 +5,8 @@ #include -#include -#include - #include "constants.hpp" +#include "gpu_vmm.hpp" namespace kvcached { @@ -28,9 +26,9 @@ class GPUPage : public Page { private: page_id_t page_id_; - CUdevice dev_; + int dev_idx_; size_t page_size_; - CUmemGenericAllocationHandle handle_; + gpu_vmm::allocation_handle_t handle_; }; class CPUPage : public Page { diff --git a/csrc/inc/page_allocator.hpp b/csrc/inc/page_allocator.hpp index ff52cca2..b8d0405f 100644 --- a/csrc/inc/page_allocator.hpp +++ b/csrc/inc/page_allocator.hpp @@ -4,6 +4,7 @@ #pragma once #include +#include #include #include #include @@ -15,9 +16,6 @@ #include #include -#include -#include - #include "constants.hpp" #include "mem_info_tracker.hpp" #include "page.hpp" diff --git a/csrc/inc/torch_utils.hpp b/csrc/inc/torch_utils.hpp index 17e0fd9a..4702b87f 100644 --- a/csrc/inc/torch_utils.hpp +++ b/csrc/inc/torch_utils.hpp @@ -3,15 +3,15 @@ #pragma once -#include +#include +#include #include -#include namespace kvcached { namespace py = pybind11; -static inline torch::Dtype torch_dtype_cast(const py::object &dtype); -static inline torch::Dtype torch_dtype_from_size(size_t dtype_size); +static inline c10::ScalarType torch_dtype_cast(const py::object &dtype); +static inline c10::ScalarType torch_dtype_from_size(size_t dtype_size); } // namespace kvcached diff --git a/csrc/page.cpp b/csrc/page.cpp index 55d51de9..70a41877 100644 --- a/csrc/page.cpp +++ b/csrc/page.cpp @@ -1,46 +1,26 @@ // SPDX-FileCopyrightText: Copyright contributors to the kvcached project // SPDX-License-Identifier: Apache-2.0 -#include - -#include "constants.hpp" -#include "cuda_utils.hpp" #include "page.hpp" +#include "constants.hpp" +#include "gpu_utils.hpp" namespace kvcached { GPUPage::GPUPage(page_id_t page_id, int dev_idx, size_t page_size) - : page_id_(page_id), dev_(dev_idx), - page_size_(page_size > 0 ? 
page_size : kPageSize), handle_(0) { - // CHECK_DRV(cuCtxGetDevice(&dev_)); - - CUmemAllocationProp prop = { - .type = CU_MEM_ALLOCATION_TYPE_PINNED, - .location = - { - .type = CU_MEM_LOCATION_TYPE_DEVICE, - .id = dev_, - }, - }; - CHECK_DRV(cuMemCreate(&handle_, page_size_, &prop, 0)); + : page_id_(page_id), dev_idx_(dev_idx), + page_size_(page_size > 0 ? page_size : kPageSize), handle_() { + auto prop = gpu_vmm::make_pinned_device_allocation_prop(dev_idx_); + CHECK_GPU(gpu_vmm::mem_create(&handle_, page_size_, &prop)); } -GPUPage::~GPUPage() { CHECK_DRV(cuMemRelease(handle_)); } +GPUPage::~GPUPage() { CHECK_GPU(gpu_vmm::mem_release(handle_)); } bool GPUPage::map(generic_ptr_t vaddr, bool set_access) { - CUmemAccessDesc accessDesc_{ - .location = - { - .type = CU_MEM_LOCATION_TYPE_DEVICE, - .id = dev_, - }, - .flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE, - }; - CHECK_DRV(cuMemMap(reinterpret_cast(vaddr), page_size_, 0, - handle_, 0)); + auto access_desc = gpu_vmm::make_device_rw_access_desc(dev_idx_); + CHECK_GPU(gpu_vmm::mem_map(vaddr, page_size_, 0, handle_)); if (set_access) - CHECK_DRV(cuMemSetAccess(reinterpret_cast(vaddr), page_size_, - &accessDesc_, 1)); + CHECK_GPU(gpu_vmm::set_access(vaddr, page_size_, &access_desc, 1)); return true; } diff --git a/csrc/page_allocator.cpp b/csrc/page_allocator.cpp index 1e3fbc85..81ec98e8 100644 --- a/csrc/page_allocator.cpp +++ b/csrc/page_allocator.cpp @@ -12,9 +12,8 @@ #include #include "allocator.hpp" -#include "cuda_utils.hpp" +#include "gpu_utils.hpp" #include "mem_info_tracker.hpp" -#include "torch_utils.hpp" namespace kvcached { @@ -441,8 +440,8 @@ int64_t PageAllocator::get_num_reserved_pages() const { } int64_t PageAllocator::get_avail_physical_pages() const { - size_t avail_phy_mem_size, total_phy_mem_size; - cudaMemGetInfo(&avail_phy_mem_size, &total_phy_mem_size); + size_t avail_phy_mem_size = 0, total_phy_mem_size = 0; + CHECK_GPU(gpu_vmm::mem_get_info(&avail_phy_mem_size, &total_phy_mem_size)); size_t 
headroom = total_phy_mem_size * (1.0 - gpu_utilization_); avail_phy_mem_size = @@ -668,9 +667,9 @@ void PageAllocator::unmap_pages(const std::vector &page_ids) { // callback broadcast_unmap_callback_(world_size_, offsets); } else { - // Need to synchronize CUDA first in async scheduling mode + // Need to synchronize first in async scheduling mode if (async_sched_) { - CHECK_RT(cudaDeviceSynchronize()); + CHECK_GPU(gpu_vmm::device_synchronize()); } auto allocator = FTensorAllocator::global_allocator(group_id_); bool success = allocator->unmap_from_kv_tensors(offsets); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index b59925da..b4f1d7a5 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -6,9 +6,10 @@ #include #include #include -#include #include +#include + #include "allocator.hpp" #include "constants.hpp" #include "page_allocator.hpp" @@ -27,7 +28,7 @@ void shutdown_kvcached() { FTensorAllocator::shutdown(); } -std::vector +std::vector create_kv_tensors(size_t size, size_t dtype_size, const std::string &dev_str, int64_t num_layers, int64_t num_kv_buffers = 2, int64_t group_id = 0, bool unified_pool = false) { diff --git a/kvcached/integration/sglang/interfaces.py b/kvcached/integration/sglang/interfaces.py index 88b47602..da332273 100644 --- a/kvcached/integration/sglang/interfaces.py +++ b/kvcached/integration/sglang/interfaces.py @@ -8,7 +8,7 @@ from kvcached.kv_cache_manager import KVCacheManager from kvcached.tp_ipc_util import start_worker_listener_thread -from kvcached.utils import CONTIGUOUS_LAYOUT, PAGE_SIZE, get_kvcached_logger +from kvcached.utils import CONTIGUOUS_LAYOUT, PAGE_SIZE, get_kvcached_logger, normalize_gpu_device from kvcached.vmm_ops import ( create_kv_tensors, init_kvcached as _init_kvcached_impl, @@ -38,6 +38,7 @@ def init_kvcached( if device is None: device = f"cuda:{torch.cuda.current_device()}" + device = normalize_gpu_device(device) _init_kvcached_impl(device, PAGE_SIZE, _contiguous_layout) 
_kvcached_initialized = True @@ -93,7 +94,8 @@ def alloc_kv_cache( if len(kvcache_shape) <= 2: raise ValueError(f"Unsupported kv cache shape: {kvcache_shape}") - assert torch.cuda.is_available(), "CUDA is not available." + assert torch.cuda.is_available(), "GPU backend is not available via torch.cuda." + device = normalize_gpu_device(device) # SGLang named it "page" to be consistent with PagedAttention. But we call # it "block" to distinguish a KV cache block and a physical memory page. @@ -220,7 +222,8 @@ def alloc_mamba_states( raise RuntimeError( "kvcached is not initialized. Please call init_kvcached() first.") - assert torch.cuda.is_available(), "CUDA is not available." + assert torch.cuda.is_available(), "GPU backend is not available via torch.cuda." + device = normalize_gpu_device(device) conv_shapes = [tuple(s) for s in cache_params.shape.conv] temporal_shape = tuple(cache_params.shape.temporal) diff --git a/kvcached/integration/sglang/patches.py b/kvcached/integration/sglang/patches.py index bc02b218..8ee71c1c 100644 --- a/kvcached/integration/sglang/patches.py +++ b/kvcached/integration/sglang/patches.py @@ -8,7 +8,7 @@ import inspect import math import types -from typing import Any, List, Optional, Tuple, Union, cast +from typing import Any, Callable, List, Optional, Tuple, Union, cast from kvcached.integration.patch_base import BasePatch, enable_kvcached from kvcached.integration.version_utils import VersionAwarePatch, version_range @@ -22,6 +22,11 @@ logger = get_kvcached_logger() +def _is_supported_gpu_device(device: str) -> bool: + device_str = str(device).lower() + return device_str.startswith("cuda") or device_str.startswith("hip") + + class ElasticAllocatorPatch(VersionAwarePatch, BasePatch): """Inject ElasticTokenToKVPoolAllocator into SGLang's allocator module""" @@ -71,8 +76,11 @@ def __init__(self, size: int, dtype, device: str, kvcache, *args, **kwargs) -> N super().__init__(size, 1, dtype, device, kvcache, *args, **kwargs) if not 
hasattr(kvcache, "kvcached_allocator"): raise ValueError("ElasticTokenToKVPoolAllocator requires elastic MHA pool") - if "cuda" not in device: - raise ValueError("ElasticTokenToKVPoolAllocator only supports cuda device") + if not _is_supported_gpu_device(device): + raise ValueError( + "ElasticTokenToKVPoolAllocator only supports GPU " + "devices (cuda/hip)" + ) self.kvcached_allocator = kvcache.kvcached_allocator logger.info( f"[kvcached] ElasticTokenToKVPoolAllocator in use: size={size} " @@ -158,9 +166,10 @@ def __init__( raise ValueError( "ElasticPagedTokenToKVPoolAllocator requires elastic MHA pool" ) - if "cuda" not in device: + if not _is_supported_gpu_device(device): raise ValueError( - "ElasticPagedTokenToKVPoolAllocator only supports cuda device" + "ElasticPagedTokenToKVPoolAllocator only supports GPU " + "devices (cuda/hip)" ) self.kvcached_allocator = kvcache.kvcached_allocator self.num_pages = size // page_size @@ -449,8 +458,10 @@ def _create_buffers(self): # Initialize kvcached with overlap scheduling to be conservative kvi.init_kvcached(tp_rank=tp_rank, world_size=tp_size, pp_rank=pp_rank, async_sched=True) - if "cuda" not in self.device: - raise ValueError("ElasticMHATokenToKVPool only supports cuda device") + if not _is_supported_gpu_device(self.device): + raise ValueError( + "ElasticMHATokenToKVPool only supports GPU devices " + "(cuda/hip)") _kv_mha: Tuple[List[Any], List[Any]] = cast( Tuple[List[Any], List[Any]], kvi.alloc_kv_cache( @@ -612,8 +623,10 @@ def __init__( kvi.init_kvcached(tp_rank=tp_rank, world_size=tp_size, pp_rank=pp_rank, async_sched=True) - if "cuda" not in device: - raise ValueError("ElasticMLATokenToKVPool only supports cuda device") + if not _is_supported_gpu_device(device): + raise ValueError( + "ElasticMLATokenToKVPool only supports GPU devices " + "(cuda/hip)") self.kv_buffer = cast( List[torch.Tensor], kvi.alloc_kv_cache( @@ -861,9 +874,9 @@ def __init__( pp_rank=pp_rank, async_sched=True, ) - if "cuda" not in device: + 
if not _is_supported_gpu_device(device): raise ValueError( - "ElasticMambaPool only supports cuda device") + "ElasticMambaPool only supports GPU devices (cuda/hip)") self._group_id = ElasticMambaPool._next_group_id ElasticMambaPool._next_group_id += 1 @@ -1213,41 +1226,74 @@ def apply(self, sched_mod: types.ModuleType) -> bool: @version_range(SGLANG_ALL_RANGE) def patch_scheduler_memory_leak(self, sched_mod: types.ModuleType) -> bool: - """Patch scheduler to suppress memory leak check when kvcached is enabled""" + """Patch scheduler to suppress memory leak check when kvcached is enabled. + + kvcached maps physical KV pages lazily, so SGLang's static-pool + invariant (total == available + in-use) does not hold and its leak + detector would raise spuriously. We neutralize the leak *raisers*. + + Older SGLang keeps the whole check in a single Scheduler method whose + source mentions ``token_to_kv_pool_allocator``. Newer SGLang + (>=0.5.11) moved it into ``SchedulerRuntimeCheckerMixin`` and split it + across several small methods (e.g. ``_check_req_pool`` raises directly, + ``_report_leak`` is the generic choke point for *token/KV* pool leaks). + + We suppress only the leak checks for pools kvcached actually manages + (the KV / token pools). A check that is specific to + ``req_to_token_pool`` is deliberately left intact -- kvcached does not + manage the request pool, its invariant still holds, and silencing it + would hide a genuine request-pool leak. The old single-method layout + (which names ``token_to_kv_pool_allocator``) and the new generic + reporter (which names no pool, and is only ever called for the token + pools) are both kept; only the req-pool-specific check is skipped. 
+ """ Scheduler = self._get_target_class(sched_mod) if Scheduler is None: return False - target_method_name: Union[str, None] = None + target_method_names: List[str] = [] for name, fn in inspect.getmembers(Scheduler, predicate=inspect.isfunction): try: src = inspect.getsource(fn) except Exception: continue - if "token_to_kv_pool_allocator memory leak detected!" in src or ( - "memory leak detected" in src and "token_to_kv_pool_allocator" in src - ): - target_method_name = name - break + if "memory leak detected" not in src: + continue + # Skip a check that is specific to the request pool, which kvcached + # does not manage. The generic reporter names no pool (so it is not + # excluded) and the legacy combined check names the KV allocator. + if "req_to_token_pool" in src and "token_to_kv_pool" not in src: + continue + target_method_names.append(name) - if target_method_name is None: + if not target_method_names: self.logger.debug("No memory leak detection method found in Scheduler") return False - original = getattr(Scheduler, target_method_name) - if self._is_already_patched(original): - self.logger.debug("Scheduler memory leak check already patched") - return True + def _make_wrapped(original: Callable[..., Any]) -> Callable[..., Any]: + def _wrapped(self, *args: Any, **kwargs: Any): + # Disable memory leak detection when ENABLE_KVCACHED is set + if enable_kvcached(): + return + return original(self, *args, **kwargs) + + return _wrapped + + patched_any = False + for target_method_name in target_method_names: + original = getattr(Scheduler, target_method_name) + if self._is_already_patched(original): + self.logger.debug( + f"Scheduler.{target_method_name} leak check already patched") + patched_any = True + continue - def _wrapped(self, *args: Any, **kwargs: Any): - # Disable memory leak detection when ENABLE_KVCACHED is set - if enable_kvcached(): - return - return original(self, *args, **kwargs) + wrapped = _make_wrapped(original) + self._mark_as_patched(wrapped) + 
setattr(Scheduler, target_method_name, wrapped) + patched_any = True - self._mark_as_patched(_wrapped) - setattr(Scheduler, target_method_name, _wrapped) - return True + return patched_any class RadixCacheLimitPatch(VersionAwarePatch, BasePatch): diff --git a/kvcached/integration/version_utils.py b/kvcached/integration/version_utils.py index 0de036a2..ed65fc1e 100644 --- a/kvcached/integration/version_utils.py +++ b/kvcached/integration/version_utils.py @@ -154,6 +154,16 @@ def detect_version(self, library_name: str, force_refresh: bool = False) -> Opti except Exception as e: self.logger.warning(f"Error detecting version for {library_name}: {e}") + # Fallback to installed-package metadata. Some builds (e.g. source + # builds of SGLang) don't expose a module-level __version__, but the + # distribution metadata still carries the version. + if detected_version is None: + try: + import importlib.metadata as _md + detected_version = _md.version(library_name) + except Exception: + pass + self._version_cache[library_name] = detected_version return detected_version diff --git a/kvcached/integration/vllm/interfaces.py b/kvcached/integration/vllm/interfaces.py index 1168436e..ce2e5ae3 100644 --- a/kvcached/integration/vllm/interfaces.py +++ b/kvcached/integration/vllm/interfaces.py @@ -8,7 +8,7 @@ from kvcached.kv_cache_manager import KVCacheManager from kvcached.tp_ipc_util import start_worker_listener_thread -from kvcached.utils import CONTIGUOUS_LAYOUT, PAGE_SIZE, get_kvcached_logger +from kvcached.utils import CONTIGUOUS_LAYOUT, PAGE_SIZE, get_kvcached_logger, normalize_gpu_device from kvcached.vmm_ops import ( create_kv_tensors, init_kvcached as _init_kvcached_impl, @@ -55,6 +55,7 @@ def init_kvcached( if device is None: device = f"cuda:{torch.cuda.current_device()}" + device = normalize_gpu_device(device) _init_kvcached_impl(device, PAGE_SIZE, _contiguous_layout) _kvcached_initialized = True @@ -193,7 +194,8 @@ def alloc_kv_cache( requested_num_blocks = 
kvcache_shape[blocks_dim_idx] - assert torch.cuda.is_available(), "CUDA is not available." + assert torch.cuda.is_available(), "GPU backend is not available via torch.cuda." + device = normalize_gpu_device(device) # --- Compute per-layer memory budget and number of blocks --- gpu_mem_bytes = torch.cuda.get_device_properties(device).total_memory diff --git a/kvcached/utils.py b/kvcached/utils.py index 596f0172..7d2de7a6 100644 --- a/kvcached/utils.py +++ b/kvcached/utils.py @@ -139,8 +139,30 @@ def _get_page_size() -> int: # Used by both SGLang (RadixCacheLimitPatch) and vLLM (ElasticBlockPool), # which converts to blocks internally via MAX_CACHED_TOKENS // block_size. MAX_CACHED_TOKENS = int(os.getenv("KVCACHED_MAX_CACHED_TOKENS", "16000")) -CONTIGUOUS_LAYOUT = os.getenv("KVCACHED_CONTIGUOUS_LAYOUT", - "true").lower() == "true" + + +def _default_contiguous_layout() -> bool: + """Default KV-cache layout: contiguous on CUDA, non-contiguous on HIP/ROCm. + + An explicit ``KVCACHED_CONTIGUOUS_LAYOUT`` always wins. Otherwise we pick + non-contiguous on ROCm: the contiguous (compound-page) layout hands the + attention backend strided/interleaved per-layer KV tensors, which vLLM's + ROCm attention path (``split_kv_cache`` + paged kernels) reads incorrectly, + whereas CUDA's FlashAttention/FlashInfer tolerate it. + """ + explicit = os.getenv("KVCACHED_CONTIGUOUS_LAYOUT") + if explicit is not None: + return explicit.lower() == "true" + try: + import torch + if getattr(torch.version, "hip", None): + return False # ROCm/HIP: non-contiguous is required for correctness + except Exception: + pass + return True + + +CONTIGUOUS_LAYOUT = _default_contiguous_layout() DEFAULT_IPC_NAME = _obtain_default_ipc_name() SHM_DIR = "/dev/shm" @@ -158,6 +180,19 @@ def _get_page_size() -> int: _COLOR_RESET = "\033[0m" +def normalize_gpu_device(device: str) -> str: + """Map a ``hip[:N]`` device string to ``cuda[:N]``. 
+ + PyTorch-ROCm and the C++ extension (``c10::Device``) address AMD GPUs as + ``cuda``; kvcached's integration accepts ``hip`` strings, so normalize them + before handing the device to any ``torch.cuda`` API or ``create_kv_tensors``. + """ + dev = str(device) + if dev.lower().startswith("hip"): + return "cuda" + dev[3:] + return dev + + def align_to(x: int, a: int) -> int: return (x + a - 1) // a * a diff --git a/setup.py b/setup.py index 3ea4ce83..f6dc21ec 100644 --- a/setup.py +++ b/setup.py @@ -16,6 +16,7 @@ import torch from torch.utils.cpp_extension import ( BuildExtension, + CppExtension, CUDAExtension, include_paths, library_paths, @@ -45,21 +46,63 @@ def get_extensions(): # Get the C++ ABI flag from PyTorch cxx_abi = torch._C._GLIBCXX_USE_CXX11_ABI + is_hip_build = bool(getattr(torch.version, "hip", None)) + is_cuda_build = bool(getattr(torch.version, "cuda", None)) + if is_hip_build: + backend_define = "-DKVCACHED_USE_HIP" + backend_name = "HIP/ROCm" + elif is_cuda_build: + backend_define = "-DKVCACHED_USE_CUDA" + backend_name = "CUDA" + else: + raise RuntimeError( + "Unable to determine GPU backend from PyTorch. " + "Expected either torch.version.hip or torch.version.cuda." + ) + extra_compile_args = [ - "-std=c++17", f"-D_GLIBCXX_USE_CXX11_ABI={int(cxx_abi)}" + "-std=c++17", + f"-D_GLIBCXX_USE_CXX11_ABI={int(cxx_abi)}", + backend_define, ] - vmm_ops_module = CUDAExtension( - "kvcached.vmm_ops", - csrc_files, - include_dirs=include_paths() + [os.path.join(CSRC_PATH, "inc")], - library_dirs=library_paths(), - libraries=["torch", "torch_cpu", "torch_python", "cuda"], - extra_compile_args={ - "cxx": extra_compile_args, - "nvcc": extra_compile_args - }, - ) + ext_include_dirs = include_paths(device_type="cuda") + [ + os.path.join(CSRC_PATH, "inc") + ] + ext_library_dirs = library_paths(device_type="cuda") + + if is_hip_build: + # HIP builds: use CppExtension to avoid PyTorch's hipify step. 
+ # Our code already handles HIP natively via gpu_vmm.hpp conditional + # compilation, so hipify is unnecessary and breaks torch headers. + extra_compile_args.extend([ + "-D__HIP_PLATFORM_AMD__=1", + "-DUSE_ROCM=1", + ]) + ext_libraries = ["amdhip64"] + vmm_ops_module = CppExtension( + "kvcached.vmm_ops", + csrc_files, + include_dirs=ext_include_dirs, + library_dirs=ext_library_dirs, + libraries=ext_libraries, + extra_compile_args={"cxx": extra_compile_args}, + ) + else: + # CUDA driver APIs require libcuda for cuMem* symbols. + ext_libraries = ["cuda"] + vmm_ops_module = CUDAExtension( + "kvcached.vmm_ops", + csrc_files, + include_dirs=ext_include_dirs, + library_dirs=ext_library_dirs, + libraries=ext_libraries, + extra_compile_args={ + "cxx": extra_compile_args, + "nvcc": extra_compile_args, + }, + ) + print(f"Building kvcached.vmm_ops with backend: {backend_name}") return [vmm_ops_module], {"build_ext": BuildExtension} diff --git a/tests/test_elastic_serving.py b/tests/test_elastic_serving.py new file mode 100644 index 00000000..738c1648 --- /dev/null +++ b/tests/test_elastic_serving.py @@ -0,0 +1,169 @@ +# SPDX-FileCopyrightText: Copyright contributors to the kvcached project +# SPDX-License-Identifier: Apache-2.0 + +"""End-to-end KV-cache elasticity under load (vLLM offline engine). + +Complements ``test_kvcache_manager.py`` (which exercises the manager-level +``resize``/``trim`` APIs directly) by driving the *real* engine and watching the +physically mapped KV footprint grow and shrink through the /dev/shm IPC that +``kvtop``/``kvctl`` read. + +Phases: + 1. idle baseline -> small mapped footprint (lazy) + 2. heavy batch -> footprint GROWS (mem_map on demand) + 3. drain (idle) -> footprint falls as freed blocks are unmapped + 4. forced limit cut -> kvctl-style limit cut (informational; see note) + 5. 
recover + check -> engine healthy after shrink, output unchanged
+
+Validated on AMD MI300X (ROCm/HIP) to confirm the hipMemMap (grow) and
+hipMemUnmap (shrink) paths; runs on NVIDIA too (device "cuda:0").
+
+Run inside the engine venv with kvcached enabled:
+    ENABLE_KVCACHED=true VLLM_USE_V1=1 python tests/test_elastic_serving.py
+
+Note: prefix caching MUST be off (enable_prefix_caching=False) or finished
+requests keep their KV resident and no shrink is observable. The forced
+limit-cut phase is informational only -- with the natural drain already
+reclaiming freed pages, it does not independently exercise eviction of *held*
+(prefix-cached) blocks; that multi-tenant giveback path needs a dedicated test.
+"""
+import glob
+import hashlib
+import os
+import threading
+import time
+from typing import Optional
+
+from kvcached.cli.utils import get_kv_cache_limit, update_kv_cache_limit
+
+MODEL = os.getenv("KVCACHED_TEST_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
+MB = 1024 * 1024
+
+
+def list_segments():
+    """Return the basenames of every /dev/shm entry matching ``kvcached_*``."""
+    return {os.path.basename(p) for p in glob.glob("/dev/shm/kvcached_*")}
+
+
+def read_seg(name):
+    """Read segment ``name`` as ``(total_size, used_size, prealloc_size)``.
+
+    Returns None when ``get_kv_cache_limit(name)`` returns None.
+    """
+    mi = get_kv_cache_limit(name)
+    return None if mi is None else (mi.total_size, mi.used_size, mi.prealloc_size)
+
+
+def fmt(v):
+    """Render a byte count in MB, or an "n/a" placeholder when ``v`` is None."""
+    return f"{v / MB:8.1f} MB" if v is not None else " n/a"
+
+
+samples: list[tuple[float, int, int, int]] = []  # (t, total, used, prealloc)
+seg_name: list[Optional[str]] = [None]  # this run's segment, set after init
+stop = threading.Event()
+
+
+def sampler(t0):
+    """Background loop that records segment usage into ``samples``.
+
+    Every 0.2 s, until ``stop`` is set, appends
+    ``(seconds since t0, total, used, prealloc)``. Ticks where the segment
+    name is not yet known, or the segment cannot be read, are skipped.
+    """
+    while not stop.is_set():
+        nm = seg_name[0]
+        if nm is not None:
+            v = read_seg(nm)
+            if v is not None:
+                samples.append((time.time() - t0, *v))
+        time.sleep(0.2)
+
+
+def used_now():
+    """Return the segment's current ``used_size``.
+
+    Returns None when the segment name is unset or the read yields nothing.
+    """
+    nm = seg_name[0]
+    v = read_seg(nm) if nm else None
+    return v[1] if v else None
+
+
+def peak_used(t_lo, t_hi):
+    """Return the largest sampled ``used`` value with ``t_lo <= t <= t_hi``.
+
+    Times are seconds relative to the sampler's ``t0``. Returns None when no
+    sample falls inside the window.
+    """
+    return max((u for (t, _t, u, _p) in samples if t_lo <= t <= t_hi),
+               default=None)
+
+
+def main():
+    """Run the grow / drain / limit-cut / recover cycle and assert the verdict.
+
+    Raises AssertionError when no IPC segment is detected, when the segment
+    cannot be read before the limit cut, or when any of the grew / shrank /
+    correct checks fails.
+    """
+    pre = list_segments()
+    t0 = time.time()
+    sampler_thread = threading.Thread(target=sampler, args=(t0,), daemon=True)
+    sampler_thread.start()
+
+    print("=== building offline vLLM engine (kvcached) ===", flush=True)
+    from vllm import LLM, SamplingParams
+    llm = LLM(
+        model=MODEL,
+        enforce_eager=True,
+        gpu_memory_utilization=0.40,
+        max_model_len=8192,
+        enable_prefix_caching=False,  # required: else freed KV stays resident
+        disable_log_stats=True,
+    )
+
+    # The engine's segment is whichever name appeared since `pre` was taken;
+    # poll for up to ~10 s (50 x 0.2 s).
+    for _ in range(50):
+        new = list_segments() - pre
+        if new:
+            seg_name[0] = sorted(new)[0]
+            break
+        time.sleep(0.2)
+    print(f"[ipc] segment: {seg_name[0]}", flush=True)
+    assert seg_name[0] is not None, "no kvcached IPC segment detected"
+
+    det = SamplingParams(temperature=0.0, max_tokens=24)
+    base_txt = llm.generate(["The capital of France is"], det)[0].outputs[0].text
+    base_md5 = hashlib.md5(base_txt.encode()).hexdigest()[:10]
+    print(f"[correctness] baseline md5={base_md5} :: {base_txt!r}", flush=True)
+
+    time.sleep(3.0)
+    base_used = used_now()
+    print(f"\n[PHASE 1] idle baseline used={fmt(base_used)}", flush=True)
+
+    print("[PHASE 2] heavy batch (grow) ...", flush=True)
+    prompts = [f"Write a long, detailed essay number {i} about distributed systems, "
+               f"GPU memory management, and virtual memory paging."
+               for i in range(128)]
+    load = SamplingParams(temperature=0.7, max_tokens=1024, seed=1234)
+    t_lo = time.time() - t0
+    llm.generate(prompts, load)
+    t_hi = time.time() - t0
+    grow_peak = peak_used(t_lo, t_hi)
+    print(f"[PHASE 2] peak used during load = {fmt(grow_peak)}", flush=True)
+
+    # Give the freed pages 18 s to drain before measuring.
+    time.sleep(18.0)
+    drained = used_now()
+    print(f"[PHASE 3] after drain used={fmt(drained)}", flush=True)
+
+    cur = read_seg(seg_name[0])
+    assert cur is not None, "kvcached segment unreadable before limit cut"
+    total_before = cur[0]
+    # Half of the observed peak, but never below 256 MB.
+    small_limit = max(int((grow_peak or 0) // 2), 256 * MB)
+    print(f"\n[PHASE 4] limit {fmt(total_before)} -> {fmt(small_limit)} "
+          f"(informational)", flush=True)
+    try:
+        update_kv_cache_limit(seg_name[0], small_limit)
+        time.sleep(10.0)
+        # A failed read is reported as n/a rather than crashing mid-cut.
+        cur2 = read_seg(seg_name[0]) or (None, None, None)
+        print(f"[PHASE 4] after cut total={fmt(cur2[0])} used={fmt(cur2[1])} "
+              f"prealloc={fmt(cur2[2])}", flush=True)
+    finally:
+        # Always put the original limit back, even if the cut phase raised.
+        update_kv_cache_limit(seg_name[0], total_before)
+
+    time.sleep(2.0)
+    txt2 = llm.generate(["The capital of France is"], det)[0].outputs[0].text
+    md5_2 = hashlib.md5(txt2.encode()).hexdigest()[:10]
+    print(f"\n[PHASE 5] post-shrink md5={md5_2} :: {txt2!r}", flush=True)
+
+    stop.set()
+    # Wait for the sampler to notice `stop` instead of sleeping a fixed time.
+    sampler_thread.join(timeout=1.0)
+
+    grew = (grow_peak or 0) > (base_used or 0) * 1.5
+    shrank = drained is not None and grow_peak is not None and drained < grow_peak
+    correct = md5_2 == base_md5
+    print("\n==================== VERDICT ====================", flush=True)
+    print(f" baseline used : {fmt(base_used)}")
+    print(f" peak used : {fmt(grow_peak)}")
+    print(f" drained used : {fmt(drained)}")
+    print(f" GREW under load ........ {'PASS' if grew else 'FAIL'}")
+    print(f" SHRANK on free ......... {'PASS' if shrank else 'FAIL'}")
+    print(f" CORRECT after cycle .... {'PASS' if correct else 'FAIL'} "
+          f"(base={base_md5} post={md5_2})")
+    print("=================================================", flush=True)
+    assert grew and shrank and correct, "elasticity check failed"
+
+
+if __name__ == "__main__":
+    main()