Binary file added tests/unit_tests/ops/data/awq/input.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/awq/output.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/awq/qweight.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/awq/qzeros.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/awq/scales.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/fp8/linear_input.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/fp8/linear_output.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/fp8/linear_weight.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/fp8/moe_output.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/fused_moe/output.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/gptq/input.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/gptq/output.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/gptq/qweight.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/gptq/qzeros.pt
Binary file not shown.
Binary file added tests/unit_tests/ops/data/gptq/scales.pt
Binary file not shown.
55 changes: 55 additions & 0 deletions tests/unit_tests/ops/test_hpu_awq.py
@@ -0,0 +1,55 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import habana_frameworks.torch as htorch
from utils import get_data_path
from vllm_gaudi.ops.hpu_awq import AWQHPULinearMethod, AWQHPUConfig
from vllm_gaudi.utils import HPUCompileConfig
from vllm.model_executor.layers.linear import RowParallelLinear


def test_awq_linear_method(dist_init):
config = {"bits": 4, "group_size": 128, "zero_point": True}
oot_quant_config = AWQHPUConfig.from_config(config)

# Prepare linear layer with oot AWQHPULinearMethod
oot_op = RowParallelLinear(input_size=256,
output_size=128,
bias=False,
input_is_parallel=True,
skip_bias_add=False,
params_dtype=torch.bfloat16,
reduce_results=True,
quant_config=oot_quant_config,
return_bias=False,
disable_tp=False).to("hpu")
assert isinstance(oot_op.quant_method, AWQHPULinearMethod)

# qweight, qzeros and scales were extracted from the first RowParallelLinear layer of TheBloke/Llama-2-7B-Chat-AWQ
# (with adjusted shapes, to make the tensors smaller)
qweight = torch.load(get_data_path("data/awq/qweight.pt"), weights_only=False, map_location="hpu")
oot_op.qweight.copy_(qweight)
qzeros = torch.load(get_data_path("data/awq/qzeros.pt"), weights_only=False, map_location="hpu")
oot_op.qzeros.copy_(qzeros)
scales = torch.load(get_data_path("data/awq/scales.pt"), weights_only=False, map_location="hpu").to(torch.bfloat16)
oot_op.scales.copy_(scales)

oot_op.quant_method.process_weights_after_loading(oot_op)

if not htorch.utils.internal.is_lazy():
compile_config = HPUCompileConfig()
oot_op = torch.compile(oot_op, **compile_config.get_compile_args())

# Input and expected output
# The output tensor holds the data returned by the CUDA implementation of AWQLinearMethod for the given input
# (AWQLinearMethod was run offline with the same input as below to produce ref_output)
input = torch.load(get_data_path("data/awq/input.pt"), weights_only=False, map_location="hpu").to(torch.bfloat16)
ref_output = torch.load(get_data_path("data/awq/output.pt"), weights_only=False,
map_location="hpu").to(torch.bfloat16)

# Execute layer
out = oot_op(input)

# Check correctness
torch.testing.assert_close(ref_output, out, atol=1e-3, rtol=1e-3)
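
# A note on the tensor layout exercised above (an informal sketch, assuming the standard
# AWQ GEMM packing used by vLLM, i.e. pack_factor = 32 // bits = 8 int4 values per int32):
#   qweight: (input_size, output_size // pack_factor)               -> (256, 16)
#   qzeros:  (input_size // group_size, output_size // pack_factor) -> (2, 16)
#   scales:  (input_size // group_size, output_size)                -> (2, 128)
# With group_size=128, each block of 128 input channels shares one scale and one packed
# zero point per output channel.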
304 changes: 304 additions & 0 deletions tests/unit_tests/ops/test_hpu_compressed_tensors.py
@@ -0,0 +1,304 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import habana_frameworks.torch as htorch
from utils import get_data_path
from unittest.mock import MagicMock
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import CompressedTensorsConfig
from vllm_gaudi.ops.hpu_compressed_tensors import (HPUCompressedTensorsLinearMethod, HPUCompressedTensorsW8A8Fp8,
HPUCompressedTensorsWNA16, HPUCompressedTensorsWNA16MoEMethod)
from vllm_gaudi.utils import HPUCompileConfig
from vllm.forward_context import override_forward_context
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.fused_moe.layer import FusedMoE


def test_compressed_tensors_linear_method_w8a8fp8(dist_init):
config = {
'config_groups': {
'group_0': {
'input_activations': {
'block_structure': None,
'dynamic': True,
'group_size': None,
'num_bits': 8,
'observer': 'memoryless',
'observer_kwargs': {},
'strategy': 'token',
'symmetric': True,
'type': 'float'
},
'output_activations': None,
'targets': ['Linear'],
'weights': {
'block_structure': None,
'dynamic': False,
'group_size': None,
'num_bits': 8,
'observer': 'minmax',
'observer_kwargs': {},
'strategy': 'channel',
'symmetric': True,
'type': 'float'
}
}
},
'format': 'naive-quantized',
'global_compression_ratio': 1.239290831149584,
'ignore': [],
'kv_cache_scheme': None,
'quant_method': 'compressed-tensors',
'quantization_status': 'frozen'
}
oot_quant_config = CompressedTensorsConfig.from_config(config)

# Prepare linear layer with oot CompressedTensorsLinearMethod
# with HPUCompressedTensorsW8A8Fp8 scheme
oot_op = RowParallelLinear(input_size=256,
output_size=256,
bias=False,
input_is_parallel=True,
skip_bias_add=False,
params_dtype=torch.bfloat16,
reduce_results=True,
quant_config=oot_quant_config,
return_bias=False,
disable_tp=False).to("hpu")
assert isinstance(oot_op.quant_method, HPUCompressedTensorsLinearMethod)
assert isinstance(oot_op.scheme, HPUCompressedTensorsW8A8Fp8)

# weight and weight_scale were extracted from the first RowParallelLinear
# layer of RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8-dynamic
# (with adjusted shapes, to make the tensors smaller)
weight = torch.load(get_data_path("data/compressed_tensors/linear_w8a8fp8_weight.pt"),
weights_only=False,
map_location="hpu")
oot_op.weight.copy_(weight)
weight_scale = torch.load(get_data_path("data/compressed_tensors/linear_w8a8fp8_weight_scale.pt"),
weights_only=False,
map_location="hpu")
oot_op.weight_scale.copy_(weight_scale)

oot_op.quant_method.process_weights_after_loading(oot_op)

if not htorch.utils.internal.is_lazy():
compile_config = HPUCompileConfig()
oot_op = torch.compile(oot_op, **compile_config.get_compile_args())

# Input and expected output
# The output tensor holds the data returned by the CUDA impl of CompressedTensorsLinearMethod for the given input
# (CompressedTensorsLinearMethod was run offline with the same input as below to produce ref_output)
input = torch.load(get_data_path("data/compressed_tensors/linear_w8a8fp8_input.pt"),
weights_only=False,
map_location="hpu")
ref_output = torch.load(get_data_path("data/compressed_tensors/linear_w8a8fp8_output.pt"),
weights_only=False,
map_location="hpu")

# Execute layer
out = oot_op(input)

# Check correctness
torch.testing.assert_close(ref_output, out, atol=1e-3, rtol=1e-3)
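
# A minimal reference sketch (not used by the test above) of what the channel-wise
# W8A8-FP8 scheme computes, ignoring the dynamic per-token activation quantization and any
# HPU-specific kernel details. The helper name and signature are illustrative only and are
# not part of vllm_gaudi.
def _reference_w8a8fp8_linear(x: torch.Tensor, weight_fp8: torch.Tensor,
                              weight_scale: torch.Tensor) -> torch.Tensor:
    # weight_fp8: (out_features, in_features) stored as float8_e4m3fn
    # weight_scale: (out_features, 1) per-output-channel scale
    w = weight_fp8.to(torch.bfloat16) * weight_scale.to(torch.bfloat16)
    return x.to(torch.bfloat16) @ w.t()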


def test_compressed_tensors_linear_method_wna16(dist_init):
config = {
'config_groups': {
'group_0': {
'input_activations': None,
'output_activations': None,
'targets': ['Linear'],
'weights': {
'actorder': 'weight',
'block_structure': None,
'dynamic': False,
'group_size': 128,
'num_bits': 4,
'observer': 'minmax',
'observer_kwargs': {},
'strategy': 'group',
'symmetric': False,
'type': 'int'
}
}
},
'format': 'pack-quantized',
'global_compression_ratio': None,
'ignore': [],
'kv_cache_scheme': None,
'quant_method': 'compressed-tensors',
'quantization_status': 'compressed'
}
oot_quant_config = CompressedTensorsConfig.from_config(config)

# Prepare linear layer with oot CompressedTensorsLinearMethod
# with HPUCompressedTensorsWNA16 scheme
oot_op = RowParallelLinear(input_size=256,
output_size=256,
bias=False,
input_is_parallel=True,
skip_bias_add=False,
params_dtype=torch.bfloat16,
reduce_results=True,
quant_config=oot_quant_config,
return_bias=False,
disable_tp=False).to("hpu")
assert isinstance(oot_op.quant_method, HPUCompressedTensorsLinearMethod)
assert isinstance(oot_op.scheme, HPUCompressedTensorsWNA16)

# Weights were extracted from the first RowParallelLinear layer of RedHatAI/Qwen3-8B-quantized.w4a16
# (with adjusted shapes, to make the tensors smaller)
weight_packed = torch.load(get_data_path("data/compressed_tensors/linear_wna16_weight_packed.pt"),
weights_only=False,
map_location="hpu")
oot_op.weight_packed.copy_(weight_packed)
weight_scale = torch.load(get_data_path("data/compressed_tensors/linear_wna16_weight_scale.pt"),
weights_only=False,
map_location="hpu")
oot_op.weight_scale.copy_(weight_scale)
weight_zero_point = torch.load(get_data_path("data/compressed_tensors/linear_wna16_weight_zero_point.pt"),
weights_only=False,
map_location="hpu")
oot_op.weight_zero_point.copy_(weight_zero_point)
oot_op.weight_shape.data = torch.tensor([256, 256], device='hpu:0')

oot_op.quant_method.process_weights_after_loading(oot_op)

if not htorch.utils.internal.is_lazy():
compile_config = HPUCompileConfig()
oot_op = torch.compile(oot_op, **compile_config.get_compile_args())

# Input and expected output
# The output tensor holds the data returned by the CUDA impl of CompressedTensorsLinearMethod for the given input
# (CompressedTensorsLinearMethod was run offline with the same input as below to produce ref_output)
input = torch.load(get_data_path("data/compressed_tensors/linear_wna16_input.pt"),
weights_only=False,
map_location="hpu")
ref_output = torch.load(get_data_path("data/compressed_tensors/linear_wna16_output.pt"),
weights_only=False,
map_location="hpu")

# Execute layer
out = oot_op(input)

# Check correctness
torch.testing.assert_close(ref_output, out, atol=1e-3, rtol=1e-3)
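
# A minimal reference sketch (not used by the test above) of group-wise WNA16 dequantization
# on already-unpacked integer codes. With group_size g, each block of g input channels shares
# one scale and (for the asymmetric scheme above) one zero point:
#   W[o, i] = (q[o, i] - zp[o, i // g]) * s[o, i // g]
# The helper is illustrative only; the packed-int32 layout handled by HPUCompressedTensorsWNA16
# is not reproduced here.
def _reference_wna16_dequant(q: torch.Tensor, scales: torch.Tensor, zero_points: torch.Tensor,
                             group_size: int) -> torch.Tensor:
    # q: (out_features, in_features) integer codes
    # scales / zero_points: (out_features, in_features // group_size)
    s = torch.repeat_interleave(scales, group_size, dim=1)
    zp = torch.repeat_interleave(zero_points, group_size, dim=1)
    return (q.to(s.dtype) - zp.to(s.dtype)) * s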


def test_compressed_tensors_wna16_moe_method(dist_init):
config = {
'config_groups': {
'group_0': {
'input_activations': None,
'output_activations': None,
'targets': ['Linear'],
'weights': {
'actorder': 'weight',
'block_structure': None,
'dynamic': False,
'group_size': 128,
'num_bits': 4,
'observer': 'minmax',
'observer_kwargs': {},
'strategy': 'group',
'symmetric': True,
'type': 'int'
}
}
},
'format': 'pack-quantized',
'global_compression_ratio': None,
'ignore': [],
'kv_cache_scheme': None,
'quant_method': 'compressed-tensors',
'quantization_status': 'compressed'
}
oot_quant_config = CompressedTensorsConfig.from_config(config)

# Prepare FusedMoE layer with oot HPUCompressedTensorsWNA16MoEMethod
oot_op = FusedMoE(num_experts=128,
top_k=8,
hidden_size=512,
intermediate_size=256,
params_dtype=torch.bfloat16,
reduce_results=True,
renormalize=True,
use_grouped_topk=False,
num_expert_group=None,
topk_group=None,
quant_config=oot_quant_config,
tp_size=None,
ep_size=None,
dp_size=None,
custom_routing_function=None,
scoring_func="softmax",
routed_scaling_factor=1.0,
e_score_correction_bias=None,
apply_router_weight_on_input=False,
activation="silu",
enable_eplb=False,
num_redundant_experts=0,
has_bias=False,
is_sequence_parallel=False,
zero_expert_num=0,
zero_expert_type=None).to("hpu")
assert isinstance(oot_op.quant_method, HPUCompressedTensorsWNA16MoEMethod)

# Weights were extracted from the first FusedMoE layer of RedHatAI/Qwen3-30B-A3B-quantized.w4a16
# (with adjusted shapes, to make the tensors smaller)
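# Each saved tensor below holds a single expert's weights, so it is transposed (swapaxes)
# into the layout the layer's parameters expect and replicated across all 128 experts
# before being copied in.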
w2_weight_packed = torch.load(get_data_path("data/compressed_tensors/moe_wna16_w2_weight_packed.pt"),
weights_only=False,
map_location="hpu")
w2_weight_packed = torch.swapaxes(w2_weight_packed, 0, 1).repeat(128, 1, 1)
oot_op.w2_weight_packed.copy_(w2_weight_packed)
w13_weight_packed = torch.load(get_data_path("data/compressed_tensors/moe_wna16_w13_weight_packed.pt"),
weights_only=False,
map_location="hpu")
w13_weight_packed = torch.swapaxes(w13_weight_packed, 0, 1).repeat(128, 1, 1)
oot_op.w13_weight_packed.copy_(w13_weight_packed)

w2_weight_scale = torch.load(get_data_path("data/compressed_tensors/moe_wna16_w2_weight_scale.pt"),
weights_only=False,
map_location="hpu")
w2_weight_scale = torch.swapaxes(w2_weight_scale, 0, 1).repeat(128, 1, 1)
oot_op.w2_weight_scale.copy_(w2_weight_scale)
w13_weight_scale = torch.load(get_data_path("data/compressed_tensors/moe_wna16_w13_weight_scale.pt"),
weights_only=False,
map_location="hpu")
w13_weight_scale = torch.swapaxes(w13_weight_scale, 0, 1).repeat(128, 1, 1)
oot_op.w13_weight_scale.copy_(w13_weight_scale)

w2_weight_shape = torch.tensor([512, 256], dtype=torch.bfloat16, device="hpu")
oot_op.w2_weight_shape.copy_(w2_weight_shape.repeat(128, 1))
w13_weight_shape = torch.tensor([256, 512], dtype=torch.bfloat16, device="hpu")
oot_op.w13_weight_shape.copy_(w13_weight_shape.repeat(128, 1))

oot_op.quant_method.process_weights_after_loading(oot_op)

if not htorch.utils.internal.is_lazy():
compile_config = HPUCompileConfig()
oot_op = torch.compile(oot_op, **compile_config.get_compile_args())

# Input and expected output
# The output tensor holds the data returned by the CUDA impl of CompressedTensorsWNA16MarlinMoEMethod for the given input
# (CompressedTensorsWNA16MarlinMoEMethod was run offline with the same input as below to produce ref_output)
hidden_states = torch.load(get_data_path("data/compressed_tensors/moe_wna16_input_hidden_states.pt"),
weights_only=False,
map_location="hpu")
router_logits = torch.load(get_data_path("data/compressed_tensors/moe_wna16_input_router_logits.pt"),
weights_only=False,
map_location="hpu")
ref_output = torch.load(get_data_path("data/compressed_tensors/moe_wna16_output.pt"),
weights_only=False,
map_location="hpu")

# Execute layer
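# The forward context is mocked with dp_metadata=None so that forward_impl sees no
# data-parallel metadata and runs as a plain single-rank call.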
mock_ctx = MagicMock(spec=["dp_metadata"])
mock_ctx.dp_metadata = None
with override_forward_context(mock_ctx):
out = oot_op.forward_impl(hidden_states, router_logits)

# Check correctness
torch.testing.assert_close(ref_output, out, atol=1e-4, rtol=1e-4)