Skip to content

Commit 763269a

Browse files
authored
[ET-VK][Ops] dequantize ops skeleton test framework
Differential Revision: D76267021 Pull Request resolved: #11480
1 parent 9cc9c35 commit 763269a

File tree

2 files changed

+191
-0
lines changed

2 files changed

+191
-0
lines changed
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <gtest/gtest.h>
10+
11+
#include <ATen/ATen.h>
12+
13+
#include <executorch/backends/vulkan/runtime/api/api.h>
14+
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
15+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
16+
17+
#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
18+
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
19+
20+
#include "test_utils.h"
21+
22+
#include <cassert>
23+
#include <iostream>
24+
25+
namespace torch {
26+
namespace executor {
27+
namespace native {
28+
29+
// Forward declarations of the functions we're testing
30+
Tensor& dequantize_per_tensor_out(
31+
const Tensor& input,
32+
double scale,
33+
int64_t zero_point,
34+
int64_t quant_min,
35+
int64_t quant_max,
36+
ScalarType dtype,
37+
executorch::aten::optional<ScalarType> out_dtype,
38+
Tensor& out);
39+
40+
// Per-token dequantize entry point. Declared here only; the definition has to
// be supplied at link time. Unlike the per-tensor declaration, `scale` and
// `zero_points` are tensors and `out_dtype` is a plain (non-optional)
// ScalarType.
Tensor& dequantize_per_token_out(
    const Tensor& input,
    const Tensor& scale,
    const Tensor& zero_points,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    ScalarType out_dtype,
    Tensor& out);
49+
50+
// Wrapper function for dequantize_per_tensor_out without context
51+
Tensor& dequantize_per_tensor_out_no_context(
52+
const Tensor& input,
53+
double scale,
54+
int64_t zero_point,
55+
int64_t quant_min,
56+
int64_t quant_max,
57+
ScalarType dtype,
58+
executorch::aten::optional<ScalarType> out_dtype,
59+
Tensor& out) {
60+
return torch::executor::native::dequantize_per_tensor_out(
61+
input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
62+
}
63+
64+
// Wrapper function for dequantize_per_token_out without context
65+
Tensor& dequantize_per_token_out_no_context(
66+
const Tensor& input,
67+
const Tensor& scale,
68+
const Tensor& zero_points,
69+
int64_t quant_min,
70+
int64_t quant_max,
71+
ScalarType dtype,
72+
ScalarType out_dtype,
73+
Tensor& out) {
74+
return torch::executor::native::dequantize_per_token_out(
75+
input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out);
76+
}
77+
78+
// ATen wrapper for dequantize_per_tensor
79+
at::Tensor dequantize_per_tensor_aten(
80+
const at::Tensor& input,
81+
double scale,
82+
int64_t zero_point,
83+
int64_t quant_min,
84+
int64_t quant_max,
85+
at::ScalarType dtype,
86+
at::ScalarType out_dtype) {
87+
auto out = at::empty_like(input, out_dtype);
88+
// Convert at::ScalarType to executorch::ScalarType
89+
ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype);
90+
ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype);
91+
92+
executorch::aten::optional<ScalarType> opt_et_out_dtype(et_out_dtype);
93+
94+
WRAP_TO_ATEN(dequantize_per_tensor_out_no_context, 7)
95+
(input,
96+
scale,
97+
zero_point,
98+
quant_min,
99+
quant_max,
100+
et_dtype,
101+
opt_et_out_dtype,
102+
out);
103+
return out;
104+
}
105+
106+
// ATen wrapper for dequantize_per_token
107+
at::Tensor dequantize_per_token_aten(
108+
const at::Tensor& input,
109+
const at::Tensor& scale,
110+
const at::Tensor& zero_points,
111+
int64_t quant_min,
112+
int64_t quant_max,
113+
at::ScalarType dtype,
114+
at::ScalarType out_dtype) {
115+
auto out = at::empty_like(input, out_dtype);
116+
// Convert at::ScalarType to executorch::ScalarType
117+
ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype);
118+
ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype);
119+
120+
WRAP_TO_ATEN(dequantize_per_token_out_no_context, 7)
121+
(input,
122+
scale,
123+
zero_points,
124+
quant_min,
125+
quant_max,
126+
et_dtype,
127+
et_out_dtype,
128+
out);
129+
return out;
130+
}
131+
132+
} // namespace native
133+
} // namespace executor
134+
} // namespace torch
135+
136+
void check_dequantize_args(
137+
int64_t quant_min,
138+
int64_t quant_max,
139+
c10::ScalarType in_dtype,
140+
c10::ScalarType out_dtype) {
141+
using namespace vkcompute;
142+
143+
// Check that quant_min <= quant_max
144+
VK_CHECK_COND(
145+
quant_min <= quant_max,
146+
"quant_min must be <= quant_max, got quant_min: ",
147+
quant_min,
148+
" quant_max: ",
149+
quant_max);
150+
151+
// Check that input dtype is a quantized type
152+
switch (in_dtype) {
153+
case c10::kByte:
154+
case c10::kChar:
155+
case c10::kShort:
156+
case c10::kInt:
157+
case c10::kLong:
158+
break;
159+
default:
160+
VK_THROW(
161+
"Unsupported input dtype: ",
162+
scalar_type_name(in_dtype),
163+
" (",
164+
static_cast<int>(in_dtype),
165+
")");
166+
}
167+
168+
// Check that output dtype is a floating point type
169+
switch (out_dtype) {
170+
case c10::kHalf:
171+
case c10::kFloat:
172+
case c10::kDouble:
173+
break;
174+
default:
175+
VK_THROW(
176+
"Unsupported output dtype: ",
177+
scalar_type_name(out_dtype),
178+
" (",
179+
static_cast<int>(out_dtype),
180+
")");
181+
}
182+
}

backends/vulkan/test/op_tests/targets.bzl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,15 @@ def define_common_targets(is_fbcode = False):
186186
"//executorch/extension/aten_util:aten_bridge",
187187
]
188188
)
189+
define_test_targets(
190+
"dequantize_test",
191+
extra_deps = [
192+
":test_utils",
193+
"//executorch/kernels/quantized/cpu:op_dequantize",
194+
"//executorch/extension/tensor:tensor",
195+
"//executorch/extension/aten_util:aten_bridge",
196+
]
197+
)
189198
define_test_targets(
190199
"linear_weight_int4_test",
191200
extra_deps = [

0 commit comments

Comments
 (0)