-
Notifications
You must be signed in to change notification settings - Fork 63
test: validate fused silu-and-mul fp8 quantization #264
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 1 commit
033e58c
9b3a170
cf62a1b
878333c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -6,6 +6,7 @@ | |||||||||||||||||
|
|
||||||||||||||||||
| import vllm._custom_ops as ops | ||||||||||||||||||
| from tests.kernels.utils import opcheck | ||||||||||||||||||
| from vllm.config import VllmConfig, set_current_vllm_config | ||||||||||||||||||
| from vllm.model_executor.layers.activation import SiluAndMul | ||||||||||||||||||
| from vllm.platforms import current_platform | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -55,18 +56,21 @@ def test_silu_and_mul( | |||||||||||||||||
| torch.cuda.manual_seed(seed) | ||||||||||||||||||
| torch.set_default_device(device) | ||||||||||||||||||
|
|
||||||||||||||||||
| layer = SiluAndMul() | ||||||||||||||||||
| with set_current_vllm_config(VllmConfig()): | ||||||||||||||||||
| layer = SiluAndMul() | ||||||||||||||||||
|
|
||||||||||||||||||
| # Make inputs | ||||||||||||||||||
| scale = (torch.randn((1), device=device, dtype=torch.float32)) | ||||||||||||||||||
| x = torch.randn(num_tokens, hidden_size, dtype=dtype) | ||||||||||||||||||
| # Make inputs | ||||||||||||||||||
| scale = (torch.randn((1), device=device, dtype=torch.float32)) | ||||||||||||||||||
| x = torch.randn(num_tokens, hidden_size, dtype=dtype) | ||||||||||||||||||
|
|
||||||||||||||||||
| ref_out = ref_impl(layer, x, scale) | ||||||||||||||||||
| ops_out = ops_impl(x, scale) | ||||||||||||||||||
| ref_out = ref_impl(layer, x, scale) | ||||||||||||||||||
| ops_out = ops_impl(x, scale) | ||||||||||||||||||
|
|
||||||||||||||||||
| assert ref_out.dtype == quant_dtype | ||||||||||||||||||
| assert ops_out.dtype == quant_dtype | ||||||||||||||||||
| assert ref_out.shape == ops_out.shape | ||||||||||||||||||
| assert torch.allclose(ref_out.to(dtype=torch.float32), | ||||||||||||||||||
| ops_out.to(dtype=torch.float32)) | ||||||||||||||||||
| opcheck(torch.ops._C.silu_and_mul_quant, (ops_out, x, scale)) | ||||||||||||||||||
| assert ref_out.dtype == quant_dtype | ||||||||||||||||||
| assert ops_out.dtype == quant_dtype | ||||||||||||||||||
| assert ref_out.shape == ops_out.shape | ||||||||||||||||||
| assert torch.allclose(ref_out.to(dtype=torch.float32), | ||||||||||||||||||
| ops_out.to(dtype=torch.float32), | ||||||||||||||||||
| atol=1 / 128, | ||||||||||||||||||
| rtol=0) | ||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using a fixed
Suggested change
|
||||||||||||||||||
| opcheck(torch.ops._C.silu_and_mul_quant, (ops_out, x, scale)) | ||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -15,13 +15,41 @@ | |||||||||||||
| from torch._prims_common import TensorLikeType | ||||||||||||||
|
|
||||||||||||||
| from tests.kernels.quant_utils import native_w8a8_block_matmul | ||||||||||||||
| from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType | ||||||||||||||
| try: | ||||||||||||||
| from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType | ||||||||||||||
| except ModuleNotFoundError: | ||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Catching only
Suggested change
|
||||||||||||||
| from enum import Enum | ||||||||||||||
|
|
||||||||||||||
| AttentionBackend = Any | ||||||||||||||
| AttentionMetadata = Any | ||||||||||||||
|
|
||||||||||||||
| class AttentionType(Enum): | ||||||||||||||
| ENCODER_DECODER = "encoder_decoder" | ||||||||||||||
|
|
||||||||||||||
| from vllm.model_executor.layers.activation import SiluAndMul | ||||||||||||||
| from vllm.model_executor.layers.fused_moe.utils import ( | ||||||||||||||
| moe_kernel_quantize_input) | ||||||||||||||
| from vllm.platforms.interface import _Backend | ||||||||||||||
| from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, | ||||||||||||||
| STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) | ||||||||||||||
| try: | ||||||||||||||
| from vllm.platforms.interface import _Backend | ||||||||||||||
| except ImportError: | ||||||||||||||
| _Backend = Any | ||||||||||||||
| try: | ||||||||||||||
| from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, | ||||||||||||||
| STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) | ||||||||||||||
| except ImportError: | ||||||||||||||
| STR_BACKEND_ENV_VAR = "VLLM_ATTENTION_BACKEND" | ||||||||||||||
| STR_FLASH_ATTN_VAL = "FLASH_ATTN" | ||||||||||||||
| STR_XFORMERS_ATTN_VAL = "XFORMERS" | ||||||||||||||
|
|
||||||||||||||
| def make_tensor_with_pad( | ||||||||||||||
| x: list[list[int]], | ||||||||||||||
| max_len: int, | ||||||||||||||
| pad: int, | ||||||||||||||
| dtype: torch.dtype, | ||||||||||||||
| device: Union[torch.device, str], | ||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||
| padded = [item + [pad] * (max_len - len(item)) for item in x] | ||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If any list in
Suggested change
|
||||||||||||||
| return torch.tensor(padded, dtype=dtype, device=device) | ||||||||||||||
|
|
||||||||||||||
| # For now, disable "test_aot_dispatch_dynamic" since there are some | ||||||||||||||
| # bugs related to this test in PyTorch 2.4. | ||||||||||||||
|
|
||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The quantization scale is generated using
torch.randn, which can produce negative or zero values. Quantization scales must be strictly positive. A negative scale will incorrectly flip the signs of the quantized values, and a scale close to zero can cause division-by-zero or overflow issues. Consider using the absolute value of the random tensor and adding a small epsilon to ensure a valid, positive scale.