Skip to content

Commit b429e0a

Browse files
authored
Revert "[Snippets][CPU] Precision enforcement fix (#32914)"
This reverts commit 792ddf3.
1 parent 792ddf3 commit b429e0a

File tree

5 files changed

+49
-60
lines changed

5 files changed

+49
-60
lines changed

src/common/snippets/src/pass/align_element_types.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,17 @@ bool pass::AlignElementTypes::run_on_model(const std::shared_ptr<ov::Model>& m)
4747
std::shared_ptr<ov::Node> consumer = shape_infer_leaf ? shape_infer_leaf : results[i];
4848
auto parent_output = consumer->get_input_source_output(0);
4949

50+
// Snippets supports Transpose only after Parameter or before Result nodes
51+
// So we have to insert Convert before Transpose (if there is) on Subgraph outputs
52+
const auto transpose = ov::as_type_ptr<ov::op::v1::Transpose>(parent_output.get_node_shared_ptr());
53+
if (transpose) {
54+
OPENVINO_ASSERT(
55+
parent_output.get_target_inputs().size() == 1,
56+
"If Result has Transpose on input, this Result must be single consumer of the Transpose");
57+
parent_output = transpose->get_input_source_output(0);
58+
consumer = transpose;
59+
}
60+
5061
// If there is already Convert[needed_in_type->original_type] and this node has only one consumer, we can
5162
// remove the Convert, since the sequence existing Convert[needed_in_type->original_type] -> new
5263
// Convert[original_type->needed_in_type] is redundant
@@ -70,6 +81,9 @@ bool pass::AlignElementTypes::run_on_model(const std::shared_ptr<ov::Model>& m)
7081

7182
consumer->set_argument(0, convert);
7283
consumer->validate_and_infer_types();
84+
if (transpose) {
85+
results[i]->validate_and_infer_types();
86+
}
7387
is_modified = true;
7488
}
7589
}

src/plugins/intel_cpu/src/nodes/subgraph.cpp

Lines changed: 27 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,8 @@
4646

4747
# include "emitters/snippets/x64/cpu_generator.hpp"
4848
# include "executors/x64/subgraph.hpp"
49-
# include "snippets/lowered/port_descriptor.hpp"
5049
# include "snippets/op/brgemm.hpp"
51-
# include "snippets/utils/utils.hpp"
50+
# include "snippets/pass/matmul_to_brgemm.hpp"
5251
# include "transformations/snippets/x64/op/brgemm_utils.hpp"
5352
#elif defined(OPENVINO_ARCH_ARM64)
5453
# include <cpu/aarch64/cpu_isa_traits.hpp>
@@ -87,7 +86,6 @@
8786
# include "snippets/lowered/pass/init_loops.hpp"
8887
# include "snippets/lowered/pass/insert_buffers.hpp"
8988
# include "snippets/lowered/pass/insert_loops.hpp"
90-
# include "snippets/pass/fuse_transpose_brgemm.hpp"
9189
# include "transformations/snippets/common/pass/enforce_precision.hpp"
9290
# include "transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.hpp"
9391
# include "transformations/snippets/x64/pass/eliminate_brgemm_copy_b.hpp"
@@ -554,38 +552,34 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() {
554552

555553
if (any_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) &&
556554
subgraph_attrs->snippet->has_domain_sensitive_ops()) {
557-
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(
558-
Place::After,
559-
ov::snippets::pass::FuseTransposeBrgemm,
560-
pass::EnforcePrecision,
561-
element::f32,
562-
context->getConfig().inferencePrecision,
563-
[](const std::shared_ptr<ov::Node>& op) {
564-
std::set<std::vector<ov::element::Type>> types;
565-
if (ov::is_type<ov::snippets::op::Brgemm>(op)) {
566-
const auto& a_port =
567-
ov::snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(op->input(0));
568-
// WA: We can't perform precision enforcement in case of strided access to A matrix:
569-
// snippets eltwise loops for precision conversion are generated by last 2 dims,
570-
// which are not [M, K] in case of strided access in brgemm A
571-
// There are no limitations for B matrix, since precision conversion is fused in BrgemmCopyB
572-
// Ticket: 177121
573-
if (ov::snippets::utils::is_planar_layout(a_port->get_layout())) {
574-
if (ov::intel_cpu::brgemm_utils::is_fp16_supported()) {
575-
types.insert({ov::element::f16, ov::element::f16});
576-
}
577-
if (ov::intel_cpu::brgemm_utils::is_bf16_supported()) {
578-
types.insert({ov::element::bf16, ov::element::bf16});
579-
}
580-
}
581-
}
582-
return types;
583-
});
584-
// Note: EnforcePrecision might also eliminate Convert pairs (e.g. bf16->f32->bf16),
585-
// so FuseTransposeBrgemm has to be run after it as well
555+
// enforce BF16 precisions to supported operations
556+
// MatMul has to be decomposed to Brgemm operations before enforcement
557+
// Notes:
558+
// - MatMul decomposition will be run later again in case BF16 enforcement did not happen
559+
// - `MatMulToBrgemm` pass fuse `transpose_a` and `transpose_b` from MatMul to inputs of Brgemm as layouts.
560+
// These layouts are resized to ranks of input shapes. But since `Canonicalization` might
561+
// reshape shapes, the pass `MatMulToBrgemm` should be after the pass `Canonicalization` to
562+
// fuse layouts with ranks aligned with updated shapes after RankNormalization insertions.
563+
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After,
564+
ov::snippets::pass::Canonicalization,
565+
ov::snippets::pass::MatMulToBrgemm);
586566
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After,
567+
ov::snippets::pass::MatMulToBrgemm,
587568
pass::EnforcePrecision,
588-
ov::snippets::pass::FuseTransposeBrgemm);
569+
element::f32,
570+
context->getConfig().inferencePrecision,
571+
[](const std::shared_ptr<ov::Node>& op) {
572+
std::set<std::vector<ov::element::Type>> types;
573+
if (ov::is_type<ov::snippets::op::Brgemm>(op)) {
574+
if (ov::intel_cpu::brgemm_utils::is_fp16_supported()) {
575+
types.insert({ov::element::f16, ov::element::f16});
576+
}
577+
if (ov::intel_cpu::brgemm_utils::is_bf16_supported()) {
578+
types.insert({ov::element::bf16, ov::element::bf16});
579+
}
580+
}
581+
return types;
582+
});
589583
}
590584

591585
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before,

src/plugins/intel_cpu/src/transformations/snippets/x64/pass/eliminate_brgemm_copy_b.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,6 @@ bool pass::EliminateBrgemmCopyB::run_on_model(const std::shared_ptr<ov::Model>&
6666
// Since repacking is moved out of Subgraph body,
6767
// the rest of the weights subgraph must be updated with the precision after repacking
6868
param->set_element_type(copy_b_node->get_config().wei_dt());
69-
// Note: validation is called manually since set_element_type doesn't update output element type
70-
param->validate_and_infer_types();
7169
if (pattern_map.count(m_rank_norm)) {
7270
pattern_map.at(m_rank_norm).get_node_shared_ptr()->validate_and_infer_types();
7371
}

src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -160,11 +160,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16_4D,
160160
MHA,
161161
::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
162162
::testing::ValuesIn(precision_bf16_if_supported(4)),
163-
::testing::Values(ov::element::bf16),
163+
::testing::Values(ov::element::f32),
164164
::testing::Values(false),
165165
::testing::Values(MHA::default_thread_count),
166-
::testing::Values(3), // decomposed Transpose + MHA + 1 Transpose on output
167-
::testing::Values(2), // decomposed Transpose + MHA
166+
::testing::Values(8), // decomposed Transpose + MHA + 5 Converts + 1 Transpose on output
167+
::testing::Values(6), // MHA + 5 Converts on inputs and output
168168
::testing::Values(ov::test::utils::DEVICE_CPU),
169169
::testing::Values(CPUTestUtils::empty_plugin_config)),
170170
MHA::getTestCaseName);
@@ -182,19 +182,6 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16,
182182
::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)),
183183
MHA::getTestCaseName);
184184

185-
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16_f32_in_prc,
186-
MHA,
187-
::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
188-
::testing::ValuesIn(precision_f32(4)),
189-
::testing::Values(ov::element::f32),
190-
::testing::ValuesIn({false}),
191-
::testing::Values(MHA::default_thread_count),
192-
::testing::Values(3), // decomposed Transpose + MHA + 1 Transpose on output
193-
::testing::Values(2), // decomposed Transpose + MHA
194-
::testing::Values(ov::test::utils::DEVICE_CPU),
195-
::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)),
196-
MHA::getTestCaseName);
197-
198185
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_FP16_4D_Without_Multiply,
199186
MHA,
200187
::testing::Combine(::testing::ValuesIn(transposedShape_4D()),

src/tests/functional/plugin/shared/src/snippets/mha.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,8 @@ void MHABase::generate_inputs(const std::vector<ov::Shape>& targetInputStaticSha
2828
const auto& model_input = model_inputs[i];
2929
ov::Tensor tensor;
3030
ov::test::utils::InputGenerateData in_data;
31-
const bool bf16_precision =
32-
configuration.at(ov::hint::inference_precision.name()).as<ov::element::Type>() == ov::element::bf16 ||
33-
model_input.get_element_type() == ov::element::bf16;
3431
// To avoid big relative errors in the vicinity of zero, only positive values are generated for bf16 precision
35-
in_data.start_from = bf16_precision ? 0 : -1;
32+
in_data.start_from = model_input.get_element_type() == ov::element::bf16 ? 0 : -1;
3633
in_data.range = 2;
3734
in_data.resolution = 256;
3835
tensor =
@@ -58,17 +55,16 @@ void MHABase::SetUp() {
5855
setInferenceType(prc);
5956
}
6057

61-
void MHABase::init_thresholds() {
58+
void MHABase::init_thresholds() {
6259
// Note: Libxsmm calculates Exp in a slightly different way, so the abs values might differ a bit. Ticket: 130699
6360
#ifdef SNIPPETS_LIBXSMM_TPP
6461
abs_threshold = 1e-6;
6562
#endif
66-
auto infer_precision = configuration.at(ov::hint::inference_precision.name()).as<ov::element::Type>();
67-
if (infer_precision == ov::element::bf16)
63+
if (inType == ov::element::bf16)
6864
rel_threshold = 0.05f;
69-
if (infer_precision == ov::element::f16)
65+
if (inType == ov::element::f16)
7066
abs_threshold = 2e-2;
71-
}
67+
}
7268

7369
std::string MHA::getTestCaseName(const testing::TestParamInfo<ov::test::snippets::MHAParams>& obj) {
7470
const auto& [input_shapes,

0 commit comments

Comments
 (0)