Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ vllm bench serve --model meta-llama/Llama-3.2-1B-Instruct --request-rate 10 --nu
>
> When kvcached is enabled, there is NO need to set a memory utilization limit (e.g., using `--gpu-memory-utilization`), as kvcached will automatically manage the memory.

> [!NOTE]
> **AMD / ROCm:** on ROCm (HIP) builds, kvcached automatically defaults to the **non-contiguous** KV-cache layout. The contiguous layout (the default on NVIDIA) hands vLLM's ROCm attention backend strided per-layer KV tensors it cannot read correctly, whereas non-contiguous matches the layout the backend expects. You can override with `KVCACHED_CONTIGUOUS_LAYOUT=true|false`, but contiguous is not recommended on ROCm.

If you installed kvcached using its source code, you can also do the following:

```bash
Expand Down
21 changes: 18 additions & 3 deletions benchmarks/bench_vmm/Makefile
Original file line number Diff line number Diff line change
@@ -1,9 +1,24 @@
NVCC ?= nvcc
INC_DIR = ../../csrc/inc

# Select the backend: set KVCACHED_BACKEND=hip for AMD; defaults to cuda (no auto-detection is performed).
KVCACHED_BACKEND ?= cuda

ifeq ($(KVCACHED_BACKEND),hip)
CXX := hipcc
CXXFLAGS = -DKVCACHED_USE_HIP -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1
LDLIBS = -lamdhip64
else
CXX := nvcc
CXXFLAGS = -DKVCACHED_USE_CUDA
LDLIBS = -lcuda
endif

CXXFLAGS += -O2 -g -std=c++17 -I$(INC_DIR)

all: bench_vmm.bin

bench_vmm.bin: bench_vmm.cpp
$(NVCC) $^ -o $@ -O2 -g -lcuda -std=c++17
$(CXX) $^ -o $@ $(CXXFLAGS) $(LDLIBS)

clean:
$(RM) bench_vmm.bin
$(RM) bench_vmm.bin
37 changes: 22 additions & 15 deletions benchmarks/bench_vmm/README.md
Original file line number Diff line number Diff line change
@@ -1,28 +1,34 @@
# VMM Benchmark

This benchmark measures the latency of various CUDA Virtual Memory Management (VMM) operations.
This benchmark measures the latency of various GPU Virtual Memory Management (VMM) operations on both NVIDIA (CUDA) and AMD (ROCm/HIP) GPUs.

## Description

The tool benchmarks the following CUDA driver API calls:
- `cuMemAddressReserve`: Reserving a virtual address range.
- `cuMemCreate`: Allocating physical memory.
- `cuMemMap`: Mapping physical memory to a virtual address.
- `cuMemSetAccess`: Setting access permissions for a mapped region.
- `cuMemUnmap`: Unmapping physical memory.
The tool benchmarks the following VMM API calls:
- `address_reserve`: Reserving a virtual address range.
- `mem_create`: Allocating physical memory.
- `mem_map`: Mapping physical memory to a virtual address.
- `set_access`: Setting access permissions for a mapped region.
- `mem_unmap`: Unmapping physical memory.

It uses multiple CPU threads to issue these commands in parallel and reports latency statistics (average, p50, p90, p99, and max).

## Building the Benchmark

You need a CUDA-enabled GPU and the CUDA Toolkit installed.
You need a GPU with VMM support and the corresponding toolkit installed (CUDA Toolkit or ROCm).

Compile the benchmark:
For NVIDIA GPUs (default):

```bash
make
```

For AMD GPUs:

```bash
make KVCACHED_BACKEND=hip
```

## Running the Benchmark

Execute the compiled binary:
Expand All @@ -36,17 +42,18 @@ The benchmark parameters (number of threads, page size, etc.) are defined as `co
## Sample Output on A100

```
Backend: CUDA
Total Free Memory: 84.5442GB
====== cuMemMap ElemSz=1 ======
====== VMM Benchmark ======

cuMemAddressReserve (8GB) latency: 19 us
address_reserve (8GB) latency: 19 us

Benchmarking with 1 threads and 4096 pages of size 2MB.
---------------------------------------------------------------------------
Operation avg (us) p50 (us) p90 (us) p99 (us) max (us)
---------------------------------------------------------------------------
cuMemCreate 193.32 195.00 339.00 381.00 493.00
cuMemMap 1.45 0.00 4.00 5.00 105.00
cuMemSetAccess 35.99 35.00 42.00 54.00 169.00
cuMemUnmap 25.63 25.00 27.00 39.00 126.00
mem_create 193.32 195.00 339.00 381.00 493.00
mem_map 1.45 0.00 4.00 5.00 105.00
set_access 35.99 35.00 42.00 54.00 169.00
mem_unmap 25.63 25.00 27.00 39.00 126.00
```
118 changes: 49 additions & 69 deletions benchmarks/bench_vmm/bench_vmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
#include <thread>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>
#include "gpu_utils.hpp"
#include "gpu_vmm.hpp"

#include "cuda_utils.hpp"
namespace vmm = kvcached::gpu_vmm;

static constexpr int kNumThds = 1;
static constexpr size_t kPageSize = 2ul << 20; // MB
Expand All @@ -23,68 +23,56 @@ void print_header();
void print_stats(const std::string &op_name,
const std::vector<double> latencies[kNumThds]);

int init_cuda() {
size_t free;
typedef unsigned char ElemType;
CUcontext ctx;
CUdevice dev;
int supportsVMM = 0;
int init_gpu() {
int supports_vmm = 0;

CHECK_RT(cudaFree(0));
CHECK_GPU(vmm::initialize_runtime());
CHECK_GPU(vmm::set_device(0));
int dev_idx = vmm::current_device();

CHECK_DRV(cuInit(0));
CHECK_DRV(cuDevicePrimaryCtxRetain(&ctx, 0));
CHECK_DRV(cuCtxSetCurrent(ctx));
CHECK_DRV(cuCtxGetDevice(&dev));
size_t free_mem = 0, total_mem = 0;
#if defined(KVCACHED_USE_HIP)
CHECK_GPU(hipMemGetInfo(&free_mem, &total_mem));
#elif defined(KVCACHED_USE_CUDA)
CHECK_GPU(cudaMemGetInfo(&free_mem, &total_mem));
#endif

CHECK_DRV(cuMemGetInfo(&free, NULL));
std::cout << "Total Free Memory: " << (float)free / std::giga::num << "GB"
std::cout << "Backend: " << vmm::backend_name() << std::endl;
std::cout << "Total Free Memory: " << (float)free_mem / std::giga::num << "GB"
<< std::endl;

CHECK_DRV(cuDeviceGetAttribute(
&supportsVMM, CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED,
dev));
if (supportsVMM) {
std::cout << "====== cuMemMap ElemSz=" << sizeof(ElemType)
<< " ======" << std::endl;
CHECK_DRV(vmm::get_vmm_support(&supports_vmm, dev_idx));
if (supports_vmm) {
std::cout << "====== VMM Benchmark ======" << std::endl;
} else {
std::cout << "VMM not supported" << std::endl;
}

return 0;
}

CUdeviceptr alloc_virtual(size_t size) {
CUdeviceptr addr;
CHECK_DRV(cuMemAddressReserve(&addr, size, kPageSize, 0, 0));
void *alloc_virtual(size_t size) {
void *addr = nullptr;
CHECK_DRV(vmm::address_reserve(&addr, size, kPageSize));
return addr;
}

int bench_physical_alloc(std::vector<CUmemGenericAllocationHandle> &handles) {
int bench_physical_alloc(std::vector<vmm::allocation_handle_t> &handles) {
std::vector<std::thread> thds;
std::vector<double> latencies[kNumThds];

handles.resize(kNumPages);

CUdevice dev;
CHECK_DRV(cuCtxGetDevice(&dev));

CUmemAllocationProp prop = {
.type = CU_MEM_ALLOCATION_TYPE_PINNED,
.location =
{
.type = CU_MEM_LOCATION_TYPE_DEVICE,
.id = dev,
},
};
int dev_idx = vmm::current_device();
auto prop = vmm::make_pinned_device_allocation_prop(dev_idx);

for (int i = 0; i < kNumThds; i++) {
thds.emplace_back([&, tid = i]() {
auto stt_page = kNumPages / kNumThds * tid;
auto end_page = kNumPages / kNumThds * (tid + 1);
for (size_t page_idx = stt_page; page_idx < end_page; page_idx++) {
auto stt = std::chrono::high_resolution_clock::now();
CHECK_DRV(cuMemCreate(&handles[page_idx], kPageSize, &prop, 0));
CHECK_DRV(vmm::mem_create(&handles[page_idx], kPageSize, &prop));
auto end = std::chrono::high_resolution_clock::now();
latencies[tid].push_back(
std::chrono::duration_cast<std::chrono::microseconds>(end - stt)
Expand All @@ -97,7 +85,7 @@ int bench_physical_alloc(std::vector<CUmemGenericAllocationHandle> &handles) {
thd.join();
}

print_stats("cuMemCreate", latencies);
print_stats("mem_create", latencies);

return 0;
}
Expand Down Expand Up @@ -152,18 +140,18 @@ void print_stats(const std::string &op_name,
<< std::setw(15) << max << std::endl;
}

int bench_mmap(CUdeviceptr addr,
std::vector<CUmemGenericAllocationHandle> &handles) {
int bench_mmap(void *addr, std::vector<vmm::allocation_handle_t> &handles) {
std::vector<std::thread> thds;
std::vector<double> latencies[kNumThds];
char *base = static_cast<char *>(addr);

for (int i = 0; i < kNumThds; i++) {
thds.emplace_back([&, tid = i]() {
auto stt = kNumPages / kNumThds * tid;
auto end = kNumPages / kNumThds * (tid + 1);
for (size_t i = stt; i < end; i++) {
auto stt = std::chrono::high_resolution_clock::now();
CHECK_DRV(cuMemMap(addr + i * kPageSize, kPageSize, 0, handles[i], 0));
CHECK_DRV(vmm::mem_map(base + i * kPageSize, kPageSize, 0, handles[i]));
auto end = std::chrono::high_resolution_clock::now();
latencies[tid].push_back(
std::chrono::duration_cast<std::chrono::microseconds>(end - stt)
Expand All @@ -176,25 +164,18 @@ int bench_mmap(CUdeviceptr addr,
thd.join();
}

print_stats("cuMemMap", latencies);
print_stats("mem_map", latencies);

return 0;
}

int bench_setaccess(CUdeviceptr addr) {
int bench_setaccess(void *addr) {
std::vector<std::thread> thds;
std::vector<double> latencies[kNumThds];
CUdevice dev;

CHECK_DRV(cuCtxGetDevice(&dev));
CUmemAccessDesc accessDesc{
.location =
{
.type = CU_MEM_LOCATION_TYPE_DEVICE,
.id = dev,
},
.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE,
};
char *base = static_cast<char *>(addr);

int dev_idx = vmm::current_device();
auto access_desc = vmm::make_device_rw_access_desc(dev_idx);

for (int i = 0; i < kNumThds; i++) {
thds.emplace_back([&, tid = i]() {
Expand All @@ -203,7 +184,7 @@ int bench_setaccess(CUdeviceptr addr) {
for (size_t i = stt; i < end; i++) {
auto stt = std::chrono::high_resolution_clock::now();
CHECK_DRV(
cuMemSetAccess(addr + i * kPageSize, kPageSize, &accessDesc, 1));
vmm::set_access(base + i * kPageSize, kPageSize, &access_desc, 1));
auto end = std::chrono::high_resolution_clock::now();
latencies[tid].push_back(
std::chrono::duration_cast<std::chrono::microseconds>(end - stt)
Expand All @@ -216,22 +197,23 @@ int bench_setaccess(CUdeviceptr addr) {
thd.join();
}

print_stats("cuMemSetAccess", latencies);
print_stats("set_access", latencies);

return 0;
}

int bench_munmap(CUdeviceptr addr) {
int bench_munmap(void *addr) {
std::vector<std::thread> thds;
std::vector<double> latencies[kNumThds];
char *base = static_cast<char *>(addr);

for (int i = 0; i < kNumThds; i++) {
thds.emplace_back([&, tid = i]() {
auto stt = kNumPages / kNumThds * tid;
auto end = kNumPages / kNumThds * (tid + 1);
for (size_t i = stt; i < end; i++) {
auto stt = std::chrono::high_resolution_clock::now();
CHECK_DRV(cuMemUnmap(addr + i * kPageSize, kPageSize));
CHECK_DRV(vmm::mem_unmap(base + i * kPageSize, kPageSize));
auto end = std::chrono::high_resolution_clock::now();
latencies[tid].push_back(
std::chrono::duration_cast<std::chrono::microseconds>(end - stt)
Expand All @@ -244,34 +226,32 @@ int bench_munmap(CUdeviceptr addr) {
thd.join();
}

print_stats("cuMemUnmap", latencies);
print_stats("mem_unmap", latencies);

return 0;
}

void free_physical(std::vector<CUmemGenericAllocationHandle> &handles) {
void free_physical(std::vector<vmm::allocation_handle_t> &handles) {
for (const auto &handle : handles) {
CHECK_DRV(cuMemRelease(handle));
CHECK_DRV(vmm::mem_release(handle));
}
}

void free_virtual(CUdeviceptr addr) {
CHECK_DRV(cuMemAddressFree(addr, kMemSize));
}
void free_virtual(void *addr) { CHECK_DRV(vmm::address_free(addr, kMemSize)); }

int main() {
init_cuda();
init_gpu();

auto stt = std::chrono::high_resolution_clock::now();
CUdeviceptr addr = alloc_virtual(kMemSize);
void *addr = alloc_virtual(kMemSize);
auto end = std::chrono::high_resolution_clock::now();
auto lat =
std::chrono::duration_cast<std::chrono::microseconds>(end - stt).count();
std::cout << "\ncuMemAddressReserve (" << (kMemSize >> 30)
std::cout << "\naddress_reserve (" << (kMemSize >> 30)
<< "GB) latency: " << lat << " us\n"
<< std::endl;

std::vector<CUmemGenericAllocationHandle> handles;
std::vector<vmm::allocation_handle_t> handles;

print_header();
bench_physical_alloc(handles);
Expand Down
Loading
Loading