@@ -49,6 +49,32 @@ void test_dtype() {
49
49
EXPECT_TENSOR_EQ (out, expected);
50
50
}
51
51
52
+ template <ScalarType INPUT_DTYPE>
53
+ void test_input_dtype () {
54
+ TensorFactory<INPUT_DTYPE> tf_input;
55
+
56
+ Tensor input = tf_input.full ({3 , 5 }, 4 );
57
+ double scale = 0.5 ;
58
+ int64_t zero_point = 108 ;
59
+ int64_t quant_min = 0 ;
60
+ int64_t quant_max = 127 ;
61
+
62
+ TensorFactory<ScalarType::Char> tfo;
63
+ Tensor out = tfo.zeros ({3 , 5 });
64
+ // 4 / 0.5 + 108 = 116
65
+ Tensor expected = tfo.full ({3 , 5 }, 116 );
66
+ quantize_per_tensor_out (
67
+ input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out);
68
+
69
+ EXPECT_TENSOR_EQ (out, expected);
70
+ }
71
+
72
+ TEST (OpQuantizeOutTest, AllInputDtypesSupported) {
73
+ test_input_dtype<ScalarType::Float>();
74
+ test_input_dtype<ScalarType::Half>();
75
+ test_input_dtype<ScalarType::Double>();
76
+ }
77
+
52
78
TEST (OpQuantizeOutTest, AllDtypesSupported) {
53
79
test_dtype<ScalarType::Byte>();
54
80
test_dtype<ScalarType::Char>();
@@ -58,6 +84,45 @@ TEST(OpQuantizeOutTest, AllDtypesSupported) {
58
84
test_dtype<ScalarType::Int>();
59
85
}
60
86
87
+ TEST (OpQuantizeOutTest, DoubleInputTest) {
88
+ TensorFactory<ScalarType::Double> tf_double;
89
+
90
+ // Test with a more complex value that might have precision differences
91
+ Tensor input = tf_double.full ({2 , 3 }, 3.14159265359 );
92
+ double scale = 0.01 ;
93
+ int64_t zero_point = -100 ;
94
+ int64_t quant_min = 0 ;
95
+ int64_t quant_max = 255 ;
96
+
97
+ TensorFactory<ScalarType::Byte> tfo;
98
+ Tensor out = tfo.zeros ({2 , 3 });
99
+ // 3.14159265359 / 0.01 - 100 = 214.159265359
100
+ Tensor expected = tfo.full ({2 , 3 }, 214 );
101
+ quantize_per_tensor_out (
102
+ input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out);
103
+
104
+ EXPECT_TENSOR_EQ (out, expected);
105
+ }
106
+
107
+ TEST (OpQuantizeOutTest, HalfInputTest) {
108
+ TensorFactory<ScalarType::Half> tf_half;
109
+
110
+ Tensor input = tf_half.full ({2 , 3 }, 2.5 );
111
+ double scale = 0.5 ;
112
+ int64_t zero_point = 10 ;
113
+ int64_t quant_min = -128 ;
114
+ int64_t quant_max = 127 ;
115
+
116
+ TensorFactory<ScalarType::Char> tfo;
117
+ Tensor out = tfo.zeros ({2 , 3 });
118
+ // 2.5 / 0.5 + 10 = 15
119
+ Tensor expected = tfo.full ({2 , 3 }, 15 );
120
+ quantize_per_tensor_out (
121
+ input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out);
122
+
123
+ EXPECT_TENSOR_EQ (out, expected);
124
+ }
125
+
61
126
TEST (OpQuantizeOutTest, TensorArgOverload) {
62
127
TensorFactory<ScalarType::Float> tf_float;
63
128
TensorFactory<ScalarType::Double> tf_double;
0 commit comments