
Commit 1ba8005

[GPU] Disabling redundant copying of constant weights (openvinotoolkit#18949)
1 parent 845bbfc commit 1ba8005

36 files changed: +754, -153 lines

src/common/transformations/include/transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp (+11)

@@ -14,6 +14,7 @@ namespace pass {
 class TRANSFORMATIONS_API EnableDecompressionConvertConstantFolding;
 class TRANSFORMATIONS_API DisableDecompressionConvertConstantFolding;
 class TRANSFORMATIONS_API KeepConstAndDecompression;
+class TRANSFORMATIONS_API KeepConstantsPrecisionAndAddConverts;
 
 }  // namespace pass
 }  // namespace ov
@@ -47,3 +48,13 @@ class ov::pass::KeepConstAndDecompression : public MatcherPass {
     OPENVINO_RTTI("KeepConstAndDecompression", "0");
     KeepConstAndDecompression();
 };
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief Prevents Consts precision conversion and adds Convert with disabled ConstantFolding
+ */
+class ov::pass::KeepConstantsPrecisionAndAddConverts : public MatcherPass {
+public:
+    OPENVINO_RTTI("KeepConstantsPrecisionAndAddConverts", "0");
+    KeepConstantsPrecisionAndAddConverts();
+};
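The unit tests added later in this commit show how the pass is wired up; as a minimal sketch (assuming `model` is a std::shared_ptr<ov::Model>, and with a placeholder callback in place of the plugin's real one):

    // Register the pass; the pass-config callback lets a plugin skip selected Constants.
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::KeepConstantsPrecisionAndAddConverts>();
    manager.get_pass_config()->set_callback<ov::pass::KeepConstantsPrecisionAndAddConverts>(
        [](const std::shared_ptr<const ov::Node>& node) -> bool {
            return false;  // returning true would skip this Constant (see transformation_callback in the pass body)
        });
    manager.run_passes(model);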
src/common/transformations/include/transformations/rt_info/keep_const_precision.hpp (new file, +35)

@@ -0,0 +1,35 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "openvino/core/node.hpp"
+#include "openvino/core/runtime_attribute.hpp"
+#include "transformations_visibility.hpp"
+
+namespace ov {
+
+TRANSFORMATIONS_API void enable_keep_const_precision(const std::shared_ptr<Node>& node);
+
+TRANSFORMATIONS_API void disable_keep_const_precision(const std::shared_ptr<Node>& node);
+
+TRANSFORMATIONS_API bool is_keep_const_precision(const std::shared_ptr<const Node>& node);
+
+/**
+ * @ingroup ie_runtime_attr_api
+ * @brief KeepConstPrecision class represents runtime info attribute that marks a Constant
+ * as prohibited to fuse precision in ConvertPrecision
+ */
+class TRANSFORMATIONS_API KeepConstPrecision : public RuntimeAttribute {
+public:
+    OPENVINO_RTTI("keep_const_precision", "0");
+
+    KeepConstPrecision() = default;
+
+    bool is_copyable() const override {
+        return false;
+    }
+};
+
+}  // namespace ov
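The three free functions make this a simple marker API; a small usage sketch, assuming `weights` is a std::shared_ptr<ov::Node> pointing at a Constant:

    ov::enable_keep_const_precision(weights);          // stores KeepConstPrecision in the node's rt_info
    bool kept = ov::is_keep_const_precision(weights);  // now true; ConvertPrecision will leave the node alone
    ov::disable_keep_const_precision(weights);         // erases the attribute again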

src/common/transformations/include/transformations/rt_info/keep_fp16_const.hpp (-35)

This file was deleted.

src/common/transformations/src/transformations/convert_precision.cpp (+3, -3)

@@ -26,7 +26,7 @@
 #include "transformations/fp16_compression/mark_subgraphs_to_keep_in_mixed_precision.hpp"
 #include "transformations/rt_info/decompression.hpp"
 #include "transformations/rt_info/disable_fp16_compression.hpp"
-#include "transformations/rt_info/keep_fp16_const.hpp"
+#include "transformations/rt_info/keep_const_precision.hpp"
 #include "transformations/utils/utils.hpp"
 
 using namespace ov;
@@ -1125,8 +1125,8 @@ std::shared_ptr<Node> convert_low_precisions_int(std::shared_ptr<opset4::Constant>
 bool fuse_type_to_constant(const std::shared_ptr<ov::Node>& node,
                            const precisions_map& precisions,
                            const std::vector<Input<Node>>& consumers) {
-    // Consts marked with disable_constant_folding should be kept in f16 until they reach the plugin
-    if (is_keep_fp16_const(node))
+    // Consts marked with is_keep_const_precision should be kept in their own precision until they reach the plugin
+    if (is_keep_const_precision(node))
         return false;
 
     auto from = node->get_element_type();
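The net effect is that a marked Constant survives ConvertPrecision with its original type; a sketch of the expected behaviour, assuming the Constant was marked beforehand:

    auto weights = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{2, 2}, {1.f, 2.f, 3.f, 4.f});
    ov::enable_keep_const_precision(weights);
    // ... run pass::ConvertPrecision with {{ov::element::f32, ov::element::f16}} over the model ...
    // fuse_type_to_constant() returns false for the marked node, so its type is unchanged:
    assert(weights->get_element_type() == ov::element::f32);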

src/common/transformations/src/transformations/fp16_compression/mark_decompression_convert_constant_folding.cpp (+47, -2)

@@ -5,14 +5,16 @@
 #include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp"
 
 #include "itt.hpp"
+#include "openvino/core/rt_info.hpp"
 #include "openvino/op/constant.hpp"
 #include "openvino/op/convert.hpp"
 #include "openvino/op/matmul.hpp"
 #include "openvino/pass/pattern/op/wrap_type.hpp"
 #include "transformations/rt_info/decompression.hpp"
 #include "transformations/rt_info/disable_constant_folding.hpp"
+#include "transformations/rt_info/disable_fp16_compression.hpp"
 #include "transformations/rt_info/is_shape_subgraph.hpp"
-#include "transformations/rt_info/keep_fp16_const.hpp"
+#include "transformations/rt_info/keep_const_precision.hpp"
 
 using namespace ov;
 
@@ -67,10 +69,53 @@ pass::KeepConstAndDecompression::KeepConstAndDecompression() {
 
         if (!is_type<ov::op::v0::Constant>(node->input_value(0).get_node_shared_ptr()))
            return false;
-        enable_keep_fp16_const(node->input_value(0).get_node_shared_ptr());
+        enable_keep_const_precision(node->input_value(0).get_node_shared_ptr());
 
         return false;
     };
     auto m = std::make_shared<pattern::Matcher>(node_pattern, matcher_name);
     register_matcher(m, callback);
 }
+
+pass::KeepConstantsPrecisionAndAddConverts::KeepConstantsPrecisionAndAddConverts() {
+    MATCHER_SCOPE(KeepConstantsPrecisionAndAddConverts);
+    auto const_pattern = pattern::wrap_type<ov::op::v0::Constant>();
+
+    matcher_pass_callback callback = [=](pattern::Matcher& m) {
+        auto const_node = m.get_match_root();
+
+        if (transformation_callback(const_node)) {
+            return false;
+        }
+
+        enable_keep_const_precision(const_node);
+
+        const auto& constant_target_inputs = const_node->get_output_target_inputs(0);
+        const auto& next_node = constant_target_inputs.begin()->get_node()->shared_from_this();
+        if (is_type<ov::op::v0::Convert>(next_node)) {
+            disable_constant_folding(next_node);
+            if (is_decompression(next_node)) {
+                unmark_as_decompression(next_node);
+            }
+            return true;
+        }
+
+        auto convert = std::make_shared<ov::op::v0::Convert>(const_node, const_node->get_element_type());
+        convert->set_friendly_name(const_node->get_friendly_name());
+
+        std::string postfix = const_node->get_element_type() == ov::element::f32 ? "compression" : "decompression";
+        const_node->set_friendly_name(const_node->get_friendly_name() + "_postponed_" + postfix);
+
+        ov::copy_runtime_info(const_node, convert);
+        disable_constant_folding(convert);
+
+        for (const auto& target_input : constant_target_inputs) {
+            target_input.replace_source_output(convert);
+        }
+
+        return true;
+    };
+
+    auto m = std::make_shared<pass::pattern::Matcher>(const_pattern, matcher_name);
+    this->register_matcher(m, callback);
+}
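In graph terms, the new pass keeps the Constant in its original precision and postpones the actual conversion by inserting a self-typed Convert whose folding is disabled. The manual equivalent of the fall-through branch above is roughly (a sketch, not plugin code; `weights` and `matmul` are assumed nodes):

    // Before: weights(f32) --> matmul
    // After:  weights(f32, rt_info: keep_const_precision)
    //           --> convert(f32 -> f32, constant folding disabled) --> matmul
    auto convert = std::make_shared<ov::op::v0::Convert>(weights, weights->get_element_type());
    ov::disable_constant_folding(convert);  // keeps the Convert alive until the plugin decides the final precision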
src/common/transformations/src/transformations/rt_info/keep_const_precision.cpp (new file, +20)

@@ -0,0 +1,20 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/rt_info/keep_const_precision.hpp"
+
+void ov::enable_keep_const_precision(const std::shared_ptr<Node>& node) {
+    auto& rt_info = node->get_rt_info();
+    rt_info[KeepConstPrecision::get_type_info_static()] = KeepConstPrecision{};
+}
+
+void ov::disable_keep_const_precision(const std::shared_ptr<Node>& node) {
+    auto& rt_info = node->get_rt_info();
+    rt_info.erase(KeepConstPrecision::get_type_info_static());
+}
+
+bool ov::is_keep_const_precision(const std::shared_ptr<const Node>& node) {
+    const auto& rt_info = node->get_rt_info();
+    return rt_info.count(KeepConstPrecision::get_type_info_static());
+}

src/common/transformations/src/transformations/rt_info/keep_fp16_const.cpp (-20)

This file was deleted.
New test file (+86):

@@ -0,0 +1,86 @@
+// Copyright (C) 2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "common_test_utils/ov_test_utils.hpp"
+#include "openvino/opsets/opset1.hpp"
+#include "transformations/convert_precision.hpp"
+#include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp"
+#include "transformations/utils/utils.hpp"
+
+using namespace testing;
+using namespace ov;
+using namespace ov::opset1;
+using const_node_ptr = const std::shared_ptr<const Node>;
+
+TEST_F(TransformationTestsF, KeepConstantsPrecisionAndAddConvertsTestBase) {
+    {
+        auto input = std::make_shared<Parameter>(element::f32, Shape{3, 2, 2});
+        auto weights = Constant::create(element::f32, Shape{1, 2, 2}, {1});
+        auto matmul = std::make_shared<MatMul>(input, weights);
+
+        model = std::make_shared<Model>(NodeVector{matmul}, ParameterVector{input});
+
+        manager.register_pass<pass::KeepConstantsPrecisionAndAddConverts>();
+        manager.get_pass_config()->set_callback<pass::KeepConstantsPrecisionAndAddConverts>(
+            [](const_node_ptr& node) -> bool {
+                auto next_node = node->get_output_target_inputs(0).begin()->get_node();
+                if (is_type<op::v0::Convert>(next_node)) {
+                    next_node = next_node->get_output_target_inputs(0).begin()->get_node();
+                }
+                return !is_type<op::v0::MatMul>(next_node);
+            });
+
+        const precisions_map precisions = {{element::f32, element::f16}};
+        const type_to_fuse_map empty_fuse_map = {};
+        const bool keep_precision_sensitive_in_fp32_1 = true;
+        manager.register_pass<pass::ConvertPrecision>(precisions, empty_fuse_map, keep_precision_sensitive_in_fp32_1);
+    }
+    {
+        auto input = std::make_shared<Parameter>(element::f16, Shape{3, 2, 2});
+        auto weights = Constant::create(element::f32, Shape{1, 2, 2}, {1});
+        auto convert_weights = std::make_shared<Convert>(weights, element::f16);
+        auto matmul = std::make_shared<MatMul>(input, convert_weights);
+
+        model_ref = std::make_shared<Model>(NodeVector{matmul}, ParameterVector{input});
+    }
+}
+
+TEST_F(TransformationTestsF, KeepConstantsPrecisionAndAddConvertsTestWithCompressedConvert) {
+    {
+        auto input = std::make_shared<Parameter>(element::f16, Shape{3, 2, 2});
+        auto weights = Constant::create(element::f32, Shape{1, 2, 2}, {1});
+        auto convert_weights = std::make_shared<Convert>(weights, element::f16);
+        mark_as_decompression(convert_weights);
+        auto matmul = std::make_shared<MatMul>(input, convert_weights);
+
+        model = std::make_shared<Model>(NodeVector{matmul}, ParameterVector{input});
+
+        manager.register_pass<pass::KeepConstantsPrecisionAndAddConverts>();
+        manager.get_pass_config()->set_callback<pass::KeepConstantsPrecisionAndAddConverts>(
+            [](const_node_ptr& node) -> bool {
+                auto next_node = node->get_output_target_inputs(0).begin()->get_node();
+                if (is_type<op::v0::Convert>(next_node)) {
+                    next_node = next_node->get_output_target_inputs(0).begin()->get_node();
+                }
+                return !is_type<op::v0::MatMul>(next_node);
+            });
+
+        const precisions_map precisions = {{element::f32, element::f16}};
+        const type_to_fuse_map empty_fuse_map = {};
+        const bool keep_precision_sensitive_in_fp32_1 = true;
+        manager.register_pass<pass::ConvertPrecision>(precisions, empty_fuse_map, keep_precision_sensitive_in_fp32_1);
+    }
+    {
+        auto input = std::make_shared<Parameter>(element::f16, Shape{3, 2, 2});
+        auto weights = Constant::create(element::f32, Shape{1, 2, 2}, {1});
+        auto convert_weights = std::make_shared<Convert>(weights, element::f16);
+        auto matmul = std::make_shared<MatMul>(input, convert_weights);
+
+        model_ref = std::make_shared<Model>(NodeVector{matmul}, ParameterVector{input});
+    }
+}

src/plugins/intel_gpu/include/intel_gpu/graph/program.hpp (+16, -14)

@@ -35,20 +35,22 @@ class ICompilationContext;
 struct program {
     using ptr = std::shared_ptr<program>;
     using cptr = std::shared_ptr<const program>;
-    friend class calculate_prior_boxes; // to be removed when possible
-    friend class graph_initializations; // to be removed when possible
-    friend class prepare_padding; // to be removed when possible
-    friend class propagate_constants; // to be removed when possible
-    friend class pre_replace_deconv; // to be removed when possible
-    friend class prepare_primitive_fusing; // to be removed when possible
-    friend class prepare_quantization; // to be removed when possible
-    friend class prepare_conv_eltw_fusing; // to be removed when possible
-    friend class reorder_inputs; // to be removed when possible
-    friend class remove_redundant_reorders; // to be removed when possible
-    friend class post_optimize_weights; // to be removed when possible
-    friend class program_wrapper; // this class is intended to extend the interface of program for
-                                  // the usage within tests_core_internal project only
-    friend class prepare_primitive_fusing_through; // to be removed when possible
+    friend class calculate_prior_boxes;             // to be removed when possible
+    friend class graph_initializations;             // to be removed when possible
+    friend class prepare_padding;                   // to be removed when possible
+    friend class propagate_constants;               // to be removed when possible
+    friend class pre_replace_deconv;                // to be removed when possible
+    friend class prepare_primitive_fusing;          // to be removed when possible
+    friend class prepare_quantization;              // to be removed when possible
+    friend class prepare_conv_eltw_fusing;          // to be removed when possible
+    friend class reorder_inputs;                    // to be removed when possible
+    friend class remove_redundant_reorders;         // to be removed when possible
+    friend class post_optimize_weights;             // to be removed when possible
+    friend class prepare_primitive_fusing_through;  // to be removed when possible
+    friend class reorder_transfer;                  // to be removed when possible
+    friend class fuse_constant_transposes;          // to be removed when possible
+    friend class program_wrapper;                   // this class is intended to extend the interface of program for
+                                                    // the usage within tests_core_internal project only
 public:
     struct nodes_ordering {
     public:

src/plugins/intel_gpu/include/intel_gpu/runtime/format.hpp (+14, -2)

@@ -80,6 +80,7 @@ struct format {
         bfvuwzyx,   ///< 8d tensor
         yxfb,       ///< batch first, feature and than spatials
         byxf,       ///< used in bitmaps, input from user i.e b images of RGB format
+        fbyx,
         fyxb,       ///< format not used inside clDNN, but supported in reorder as extension
         bzyxf,
         byfx,       ///< To be used when onednn gemm allows permute fusing in transformer network. Not for normal use from cldnn.
@@ -341,8 +342,9 @@ struct format {
         return (fmt == yxfb || fmt == byxf ||
                 fmt == byfx || fmt == bxfy ||
                 fmt == bfyx || fmt == fyxb ||
-                fmt == bfzyx || fmt == bfwzyx ||
-                fmt == bfuwzyx || fmt == bfvuwzyx);
+                fmt == fbyx || fmt == bfzyx ||
+                fmt == bfwzyx || fmt == bfuwzyx ||
+                fmt == bfvuwzyx);
     }
 
     static format get_default_format(size_t rank, bool is_weights = false, bool is_grouped = false);
@@ -352,6 +354,14 @@ struct format {
 
     static const std::vector<std::pair<size_t, int>> per_axis_block_size(format fmt);
 
+    static format find_format(const std::vector<uint64_t>& order,
+                              const std::vector<std::pair<size_t, int>>& block_sizes,
+                              bool is_weights = false,
+                              bool is_grouped = false,
+                              bool is_image_2d = false,
+                              bool is_winograd = false,
+                              bool is_nv12 = false);
+
     /// @brief Checks if @p format is of grouped type
     static bool is_grouped(type fmt) { return group_num(fmt) != 0; }
     /// @brief Checks if @p format is of image type
@@ -373,6 +383,8 @@ struct format {
     size_t spatial_num() const { return traits(value).spatial_num; }
     /// @brief Returns number of group dimensions.
     size_t group_num() const { return traits(value).group_num; }
+    /// @brief Returns an order of dimensions.
+    const std::vector<uint64_t>& dims_order() const { return traits(value)._order; }
     /// @brief Returns an order of dimensions in form of string.
     const std::string& order() const { return traits(value).order; }
     /// @brief Returns an internal orders of dimensions form of string.
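find_format complements per_axis_block_size and the new dims_order getter: given a dimension order and per-axis block sizes, it looks up the matching format entry. A hedged usage sketch (assuming the cldnn namespace used throughout intel_gpu, and assuming a plain {0, 1, 2, 3} order with no blocking resolves to bfyx):

    // Hypothetical lookup of a 4D planar layout with no blocked axes.
    auto fmt = cldnn::format::find_format({0, 1, 2, 3}, /*block_sizes=*/{});
    // Under the assumption above, fmt would compare equal to cldnn::format::bfyx.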
