Skip to content

Commit 1f024fc

Browse files
authored
[ET] enabling half dtype input for quantization
Differential Revision: D76053764 Pull Request resolved: #11479
1 parent 4599518 commit 1f024fc

File tree

2 files changed

+67
-2
lines changed

2 files changed

+67
-2
lines changed

kernels/quantized/cpu/op_quantize.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ Tensor& quantize_per_tensor_out(
150150
break;
151151

152152
switch (input.scalar_type()) {
153-
ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE);
153+
ET_FORALL_FLOATH_TYPES(CALCULATE_FLOAT_TYPE);
154154
default:
155155
ET_CHECK_MSG(
156156
false,
@@ -346,7 +346,7 @@ Tensor& quantize_per_channel_out(
346346
break;
347347

348348
switch (input.scalar_type()) {
349-
ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE);
349+
ET_FORALL_FLOATH_TYPES(CALCULATE_FLOAT_TYPE);
350350
default:
351351
ET_CHECK_MSG(
352352
false,

kernels/quantized/test/op_quantize_test.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,32 @@ void test_dtype() {
4949
EXPECT_TENSOR_EQ(out, expected);
5050
}
5151

52+
template <ScalarType INPUT_DTYPE>
53+
void test_input_dtype() {
54+
TensorFactory<INPUT_DTYPE> tf_input;
55+
56+
Tensor input = tf_input.full({3, 5}, 4);
57+
double scale = 0.5;
58+
int64_t zero_point = 108;
59+
int64_t quant_min = 0;
60+
int64_t quant_max = 127;
61+
62+
TensorFactory<ScalarType::Char> tfo;
63+
Tensor out = tfo.zeros({3, 5});
64+
// 4 / 0.5 + 108 = 116
65+
Tensor expected = tfo.full({3, 5}, 116);
66+
quantize_per_tensor_out(
67+
input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out);
68+
69+
EXPECT_TENSOR_EQ(out, expected);
70+
}
71+
72+
TEST(OpQuantizeOutTest, AllInputDtypesSupported) {
73+
test_input_dtype<ScalarType::Float>();
74+
test_input_dtype<ScalarType::Half>();
75+
test_input_dtype<ScalarType::Double>();
76+
}
77+
5278
TEST(OpQuantizeOutTest, AllDtypesSupported) {
5379
test_dtype<ScalarType::Byte>();
5480
test_dtype<ScalarType::Char>();
@@ -58,6 +84,45 @@ TEST(OpQuantizeOutTest, AllDtypesSupported) {
5884
test_dtype<ScalarType::Int>();
5985
}
6086

87+
TEST(OpQuantizeOutTest, DoubleInputTest) {
88+
TensorFactory<ScalarType::Double> tf_double;
89+
90+
// Test with a more complex value that might have precision differences
91+
Tensor input = tf_double.full({2, 3}, 3.14159265359);
92+
double scale = 0.01;
93+
int64_t zero_point = -100;
94+
int64_t quant_min = 0;
95+
int64_t quant_max = 255;
96+
97+
TensorFactory<ScalarType::Byte> tfo;
98+
Tensor out = tfo.zeros({2, 3});
99+
// 3.14159265359 / 0.01 - 100 = 214.159265359
100+
Tensor expected = tfo.full({2, 3}, 214);
101+
quantize_per_tensor_out(
102+
input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out);
103+
104+
EXPECT_TENSOR_EQ(out, expected);
105+
}
106+
107+
TEST(OpQuantizeOutTest, HalfInputTest) {
108+
TensorFactory<ScalarType::Half> tf_half;
109+
110+
Tensor input = tf_half.full({2, 3}, 2.5);
111+
double scale = 0.5;
112+
int64_t zero_point = 10;
113+
int64_t quant_min = -128;
114+
int64_t quant_max = 127;
115+
116+
TensorFactory<ScalarType::Char> tfo;
117+
Tensor out = tfo.zeros({2, 3});
118+
// 2.5 / 0.5 + 10 = 15
119+
Tensor expected = tfo.full({2, 3}, 15);
120+
quantize_per_tensor_out(
121+
input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out);
122+
123+
EXPECT_TENSOR_EQ(out, expected);
124+
}
125+
61126
TEST(OpQuantizeOutTest, TensorArgOverload) {
62127
TensorFactory<ScalarType::Float> tf_float;
63128
TensorFactory<ScalarType::Double> tf_double;

0 commit comments

Comments
 (0)