|
| 1 | +// Copyright (C) 2023 Intel Corporation |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | +// |
| 4 | + |
| 5 | +#include <gtest/gtest.h> |
| 6 | + |
| 7 | +#include <memory> |
| 8 | + |
| 9 | +#include "common_test_utils/ov_test_utils.hpp" |
| 10 | +#include "openvino/opsets/opset1.hpp" |
| 11 | +#include "transformations/convert_precision.hpp" |
| 12 | +#include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp" |
| 13 | +#include "transformations/utils/utils.hpp" |
| 14 | + |
| 15 | +using namespace testing; |
| 16 | +using namespace ov; |
| 17 | +using namespace ov::opset1; |
| 18 | +using const_node_ptr = const std::shared_ptr<const Node>; |
| 19 | + |
| 20 | +TEST_F(TransformationTestsF, KeepConstantsPrecisionAndAddConvertsTestBase) { |
| 21 | + { |
| 22 | + auto input = std::make_shared<Parameter>(element::f32, Shape{3, 2, 2}); |
| 23 | + auto weights = Constant::create(element::f32, Shape{1, 2, 2}, {1}); |
| 24 | + auto matmul = std::make_shared<MatMul>(input, weights); |
| 25 | + |
| 26 | + model = std::make_shared<Model>(NodeVector{matmul}, ParameterVector{input}); |
| 27 | + |
| 28 | + manager.register_pass<pass::KeepConstantsPrecisionAndAddConverts>(); |
| 29 | + manager.get_pass_config()->set_callback<pass::KeepConstantsPrecisionAndAddConverts>( |
| 30 | + [](const_node_ptr& node) -> bool { |
| 31 | + auto next_node = node->get_output_target_inputs(0).begin()->get_node(); |
| 32 | + if (is_type<op::v0::Convert>(next_node)) { |
| 33 | + next_node = next_node->get_output_target_inputs(0).begin()->get_node(); |
| 34 | + } |
| 35 | + return !is_type<op::v0::MatMul>(next_node); |
| 36 | + }); |
| 37 | + |
| 38 | + const precisions_map precisions = {{element::f32, element::f16}}; |
| 39 | + const type_to_fuse_map empty_fuse_map = {}; |
| 40 | + const bool keep_precision_sensitive_in_fp32_1 = true; |
| 41 | + manager.register_pass<pass::ConvertPrecision>(precisions, empty_fuse_map, keep_precision_sensitive_in_fp32_1); |
| 42 | + } |
| 43 | + { |
| 44 | + auto input = std::make_shared<Parameter>(element::f16, Shape{3, 2, 2}); |
| 45 | + auto weights = Constant::create(element::f32, Shape{1, 2, 2}, {1}); |
| 46 | + auto convert_weights = std::make_shared<Convert>(weights, element::f16); |
| 47 | + auto matmul = std::make_shared<MatMul>(input, convert_weights); |
| 48 | + |
| 49 | + model_ref = std::make_shared<Model>(NodeVector{matmul}, ParameterVector{input}); |
| 50 | + } |
| 51 | +} |
| 52 | + |
| 53 | +TEST_F(TransformationTestsF, KeepConstantsPrecisionAndAddConvertsTestWithCompressedConvert) { |
| 54 | + { |
| 55 | + auto input = std::make_shared<Parameter>(element::f16, Shape{3, 2, 2}); |
| 56 | + auto weights = Constant::create(element::f32, Shape{1, 2, 2}, {1}); |
| 57 | + auto convert_weights = std::make_shared<Convert>(weights, element::f16); |
| 58 | + mark_as_decompression(convert_weights); |
| 59 | + auto matmul = std::make_shared<MatMul>(input, convert_weights); |
| 60 | + |
| 61 | + model = std::make_shared<Model>(NodeVector{matmul}, ParameterVector{input}); |
| 62 | + |
| 63 | + manager.register_pass<pass::KeepConstantsPrecisionAndAddConverts>(); |
| 64 | + manager.get_pass_config()->set_callback<pass::KeepConstantsPrecisionAndAddConverts>( |
| 65 | + [](const_node_ptr& node) -> bool { |
| 66 | + auto next_node = node->get_output_target_inputs(0).begin()->get_node(); |
| 67 | + if (is_type<op::v0::Convert>(next_node)) { |
| 68 | + next_node = next_node->get_output_target_inputs(0).begin()->get_node(); |
| 69 | + } |
| 70 | + return !is_type<op::v0::MatMul>(next_node); |
| 71 | + }); |
| 72 | + |
| 73 | + const precisions_map precisions = {{element::f32, element::f16}}; |
| 74 | + const type_to_fuse_map empty_fuse_map = {}; |
| 75 | + const bool keep_precision_sensitive_in_fp32_1 = true; |
| 76 | + manager.register_pass<pass::ConvertPrecision>(precisions, empty_fuse_map, keep_precision_sensitive_in_fp32_1); |
| 77 | + } |
| 78 | + { |
| 79 | + auto input = std::make_shared<Parameter>(element::f16, Shape{3, 2, 2}); |
| 80 | + auto weights = Constant::create(element::f32, Shape{1, 2, 2}, {1}); |
| 81 | + auto convert_weights = std::make_shared<Convert>(weights, element::f16); |
| 82 | + auto matmul = std::make_shared<MatMul>(input, convert_weights); |
| 83 | + |
| 84 | + model_ref = std::make_shared<Model>(NodeVector{matmul}, ParameterVector{input}); |
| 85 | + } |
| 86 | +} |
0 commit comments