Skip to content

Commit 26f24c9

Browse files
ColinPeppler authored and facebook-github-bot committed
dequantize_fp8_cache_kernel: Move D=128 device-side-assertion check to host
Summary: ## What Move the device-side assertions to the host since all the kernels share the same assertion. ## Why When running evals with symmetric quantization, I ran into the following error. > CUDA error: too many resources requested for launch It failed with this launch configuration: blockDim = (32, 32) = 1024 threads per block. - `$ cuobjdump --dump-resource-usage kv_cache.cu.pic.o.sm_90.cubin | c++filt | grep -A 1 'dequantize_fp8_cache_kernel'` gives me - `void fbgemm_gpu::dequantize_fp8_cache_kernel<true, true>... REG:66` - P1908720668 - That means one threadblock has 66 * 1024 = 67584 registers which exceeds the limit of 65,536. Differential Revision: D82320518
1 parent 016a79a commit 26f24c9

File tree

2 files changed

+47
-5
lines changed

2 files changed

+47
-5
lines changed

fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache_dequantize.cu

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,6 @@ __global__ void dequantize_fp8_cache_kernel(
188188
auto MAX_T = cache_K.size(1);
189189
auto D_H = cache_K_dq.size(3);
190190
auto D_H_q = cache_K.size(3);
191-
// TODO: support D_H < 128 for small model used in testing.
192-
CUDA_KERNEL_ASSERT(D_H == 128);
193191
const uint8_t offset_bytes = (ExternalQParam) ? 0 : 4;
194192
CUDA_KERNEL_ASSERT(D_H_q - D_H == offset_bytes);
195193

@@ -301,8 +299,6 @@ __global__ void dequantize_fp8_cache_kernel(
301299
auto MAX_T = cache_K.size(1);
302300
auto D_H = cache_K_dq.size(3);
303301
auto D_H_q = cache_K.size(3);
304-
// TODO: support D_H < 128 for small model used in testing.
305-
CUDA_KERNEL_ASSERT(D_H == 128);
306302
const uint8_t offset_bytes = (ExternalQParam) ? 0 : 4;
307303
CUDA_KERNEL_ASSERT(D_H_q - D_H == offset_bytes);
308304

@@ -401,7 +397,6 @@ __global__ void dequantize_fp8_cache_kernel_paged(
401397
auto N_KVH = cache_K.size(2);
402398
auto D_H = cache_K_dq.size(3);
403399
auto D_H_q = cache_K.size(3);
404-
CUDA_KERNEL_ASSERT(D_H == 128);
405400

406401
auto b = blockIdx.x;
407402
// only need to dequantize this far.
@@ -518,6 +513,9 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
518513
}
519514
auto D_H = (D_HQ - fp8_qparam_offset);
520515

516+
// TODO: support D_H < 128 for small model used in testing.
517+
TORCH_CHECK(D_H == 128, "D_H must be 128, got ", D_H);
518+
521519
// TODO:
522520
// The below allocates Tensors that have the same shape as cache_K and
523521
// cache_V to store their dequantize results. For paged KV cache, this can

fbgemm_gpu/experimental/gen_ai/test/kv_cache/kv_cache_test.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,50 @@ def test_fp8_kv_cache(self, MAX_T: int, N_KVH_L: int) -> None:
357357
cache_v[:, :T], cache_v_bf16[:, :T], atol=1.0e-2, rtol=5.0e-2
358358
)
359359

360+
@settings(deadline=None)
361+
@unittest.skipIf(
362+
not torch.cuda.is_available()
363+
or (
364+
torch.version.cuda
365+
and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9
366+
)
367+
or (torch.version.hip and torch.version.hip < "6.2")
368+
or not HAS_XFORMERS,
369+
"Skip when H100 is not available or MI300 is not available",
370+
)
371+
def test_dequantize_fp8_cache_too_many_resources_for_launch(self) -> None:
372+
# With heavy register usage, dequantize_fp8_cache can fail with
373+
# CUDA error: too many resources requested for launch
374+
device = "cuda"
375+
376+
# Shapes/dtypes
377+
B, MAX_T, N_KVH, D = 1, 139_264, 1, 128
378+
379+
cache_k = torch.randint(
380+
low=0, high=256, size=(B, MAX_T, N_KVH, D), dtype=torch.uint8, device=device
381+
)
382+
cache_v = torch.randint(
383+
low=0, high=256, size=(B, MAX_T, N_KVH, D), dtype=torch.uint8, device=device
384+
)
385+
386+
# Per-token qparams (symmetric=True implies zp=0)
387+
qparam_k = torch.zeros((B, MAX_T, N_KVH, 1), dtype=torch.int32, device=device)
388+
qparam_v = torch.zeros((B, MAX_T, N_KVH, 1), dtype=torch.int32, device=device)
389+
390+
# Sequence length (single int32)
391+
seq_len = torch.tensor([MAX_T], dtype=torch.int32, device=device)
392+
393+
torch.ops.fbgemm.dequantize_fp8_cache( # type: ignore[reportCallIssue]
394+
cache_k,
395+
cache_v,
396+
seq_len,
397+
qparam_k=qparam_k,
398+
qparam_v=qparam_v,
399+
block_tables=None,
400+
page_size=0,
401+
symmetric=True,
402+
)
403+
360404
@settings(deadline=None)
361405
@given(
362406
MAX_T=st.sampled_from([8000, 16384]),

0 commit comments

Comments
 (0)