@@ -843,16 +843,6 @@ ov::Tensor generate(const std::shared_ptr<ov::op::v10::IsFinite>& node,
843843 return tensor;
844844}
845845
846- ov::Tensor generate (const std::shared_ptr<ov::op::v10::IsNaN>& node,
847- size_t port,
848- const ov::element::Type& elemType,
849- const ov::Shape& targetShape,
850- std::shared_ptr<InputGenerateData> inGenRangeData = nullptr ) {
851- ov::Tensor tensor{elemType, targetShape};
852- comparison::fill_tensor (tensor);
853- return tensor;
854- }
855-
856846namespace is_inf {
857847template <typename T>
858848void fill_tensor (ov::Tensor& tensor) {
@@ -877,6 +867,20 @@ void fill_tensor(ov::Tensor& tensor) {
877867}
878868} // namespace is_inf
879869
870+ ov::Tensor generate (const std::shared_ptr<ov::op::v10::IsNaN>& node,
871+ size_t port,
872+ const ov::element::Type& elemType,
873+ const ov::Shape& targetShape,
874+ std::shared_ptr<InputGenerateData> inGenRangeData = nullptr ) {
875+ ov::Tensor tensor{elemType, targetShape};
876+ if (elemType == ov::element::f16 ) {
877+ is_inf::fill_tensor<ov::float16>(tensor);
878+ } else {
879+ is_inf::fill_tensor<float >(tensor);
880+ }
881+ return tensor;
882+ }
883+
880884ov::Tensor generate (const std::shared_ptr<ov::op::v10::IsInf>& node,
881885 size_t port,
882886 const ov::element::Type& elemType,
0 commit comments