Skip to content

Commit 1bb6517

Browse files
vishakha-nervana authored and sayantan-nervana committed
Vishakh1/round mode br (#353)
1 parent 605c82e commit 1bb6517

File tree

2 files changed

+77
-7
lines changed

2 files changed

+77
-7
lines changed

ngraph_bridge/ngraph_builder.cc

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3327,8 +3327,17 @@ static Status TranslateQuantizeAndDequantizeV2Op(
33273327
op->name(), ng_r_et, ng::Shape(), std::vector<float>({scale}));
33283328
auto ng_offset = ConstructNgNode<ng::op::Constant>(
33293329
op->name(), ng_q_et, ng::Shape(), std::vector<int>({0}));
3330-
ng::op::Quantize::RoundMode ng_round_mode =
3331-
ng::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_INFINITY;
3330+
ng::op::Quantize::RoundMode ng_round_mode;
3331+
string round_mode_string;
3332+
TF_RETURN_IF_ERROR(
3333+
GetNodeAttr(op->attrs(), "round_mode", &round_mode_string));
3334+
if (round_mode_string == "HALF_UP") {
3335+
ng_round_mode = ng::op::Quantize::RoundMode::ROUND_NEAREST_UPWARD;
3336+
} else if (round_mode_string == "HALF_TO_EVEN") {
3337+
ng_round_mode = ng::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
3338+
} else {
3339+
return errors::Internal("Tensorflow Rounding Mode not supported by Ngraph");
3340+
}
33323341
auto ng_quant = ConstructNgNode<ng::op::Quantize>(
33333342
op->name(), ng_input, ng_scale, ng_offset, ng_q_et, ng::AxisSet(),
33343343
ng_round_mode);
@@ -3608,10 +3617,17 @@ static Status TranslateQuantizeV2Op(const Node* op,
36083617
ng::element::Type ng_et;
36093618
TF_RETURN_IF_ERROR(TFDataTypeToNGraphElementType(dtype, &ng_et));
36103619

3611-
// TODO: Only RoundMode = ROUND_NEAREST_TOWARD_EVEN is supported, for now.
3612-
// Support other modes later
3613-
ng::op::Quantize::RoundMode ng_round_mode =
3614-
ng::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
3620+
ng::op::Quantize::RoundMode ng_round_mode;
3621+
string round_mode_string;
3622+
TF_RETURN_IF_ERROR(
3623+
GetNodeAttr(op->attrs(), "round_mode", &round_mode_string));
3624+
if (round_mode_string == "HALF_UP") {
3625+
ng_round_mode = ng::op::Quantize::RoundMode::ROUND_NEAREST_UPWARD;
3626+
} else if (round_mode_string == "HALF_TO_EVEN") {
3627+
ng_round_mode = ng::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
3628+
} else {
3629+
return errors::Internal("Tensorflow Rounding Mode not supported by Ngraph");
3630+
}
36153631

36163632
auto ng_node = ng::builder::ScaledQuantize(ng_input, ng_min, ng_max, ng_et,
36173633
ng::AxisSet(), ng_round_mode);

test/test_array_ops.cpp

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -797,7 +797,61 @@ TEST(ArrayOps, QuantizeAndDequantizeV2x8xtruexfalse) {
797797
output_datatypes, sess_run_fetchoutputs);
798798

799799
opexecuter.RunTest();
800-
} // end of test op QuantizeAndDequantizeV2x8xtruexfalse
800+
}
801+
802+
TEST(ArrayOps, QuantizeAndDequantizeV2RoundingMode1) {
803+
Scope root = Scope::NewRootScope();
804+
int dim1 = 2;
805+
int dim2 = 3;
806+
807+
Tensor A(DT_FLOAT, TensorShape({dim1, dim2}));
808+
AssignInputValues<float>(A, {0.9, 3.4, 2.6, 5.4, 4.2, 4.5});
809+
810+
auto attrs = ops::QuantizeAndDequantizeV2::Attrs();
811+
attrs.num_bits_ = 8;
812+
attrs.range_given_ = true;
813+
attrs.signed_input_ = true;
814+
attrs.round_mode_ = "HALF_UP";
815+
816+
vector<int> static_input_indexes = {1, 2};
817+
ops::QuantizeAndDequantizeV2 R =
818+
ops::QuantizeAndDequantizeV2(root, A, 0.0f, 127.0f, attrs);
819+
820+
vector<DataType> output_datatypes = {DT_FLOAT};
821+
822+
std::vector<Output> sess_run_fetchoutputs = {R.output};
823+
OpExecuter opexecuter(root, "QuantizeAndDequantizeV2", static_input_indexes,
824+
output_datatypes, sess_run_fetchoutputs);
825+
826+
opexecuter.RunTest();
827+
}
828+
829+
TEST(ArrayOps, QuantizeAndDequantizeV2RoundingMode2) {
830+
Scope root = Scope::NewRootScope();
831+
int dim1 = 2;
832+
int dim2 = 3;
833+
834+
Tensor A(DT_FLOAT, TensorShape({dim1, dim2}));
835+
AssignInputValues<float>(A, {0.9, 3.4, 2.6, 5.4, 4.2, 4.5});
836+
837+
auto attrs = ops::QuantizeAndDequantizeV2::Attrs();
838+
attrs.num_bits_ = 8;
839+
attrs.range_given_ = true;
840+
attrs.signed_input_ = true;
841+
attrs.round_mode_ = "HALF_TO_EVEN";
842+
843+
vector<int> static_input_indexes = {1, 2};
844+
ops::QuantizeAndDequantizeV2 R =
845+
ops::QuantizeAndDequantizeV2(root, A, 0.0f, 127.0f, attrs);
846+
847+
vector<DataType> output_datatypes = {DT_FLOAT};
848+
849+
std::vector<Output> sess_run_fetchoutputs = {R.output};
850+
OpExecuter opexecuter(root, "QuantizeAndDequantizeV2", static_input_indexes,
851+
output_datatypes, sess_run_fetchoutputs);
852+
853+
opexecuter.RunTest();
854+
} // end of test op QuantizeAndDequantizeV2RoundingMode2
801855

802856
// CPU only supports QuantizedConcat with DT_QINT32 and DT_QUINT8
803857
TEST(ArrayOps, QuantizedConcat) {

0 commit comments

Comments
 (0)