Skip to content

Commit 60c61ad

Browse files
v-Golubevchenhu-wang
authored andcommitted
WA: avoid precision enforcement if Brgemm A matrix has strided access
1 parent 3bd45e3 commit 60c61ad

File tree

1 file changed

+29
-17
lines changed

1 file changed

+29
-17
lines changed

src/plugins/intel_cpu/src/nodes/subgraph.cpp

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -553,25 +553,37 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() {
553553

554554
if (any_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) &&
555555
subgraph_attrs->snippet->has_domain_sensitive_ops()) {
556-
// MatMul has to be decomposed to Brgemm operations,
557-
// and transposes on inputs/outputs should be fused in the brgemm before enforcement
556+
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(
557+
Place::After,
558+
ov::snippets::pass::FuseTransposeBrgemm,
559+
pass::EnforcePrecision,
560+
element::f32,
561+
context->getConfig().inferencePrecision,
562+
[](const std::shared_ptr<ov::Node>& op) {
563+
std::set<std::vector<ov::element::Type>> types;
564+
if (ov::is_type<ov::snippets::op::Brgemm>(op)) {
565+
const auto& a_port =
566+
ov::snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(op->input(0));
567+
// WA: We can't perform precision enforcement in case of strided access to A matrix:
568+
// snippets eltwise loops for precision conversion are generated by last 2 dims,
569+
// which are not [M, K] in case of strided access in brgemm A
570+
// There are no limitations for B matrix, since precision conversion is fused in BrgemmCopyB
571+
if (ov::snippets::utils::is_planar_layout(a_port->get_layout())) {
572+
if (ov::intel_cpu::brgemm_utils::is_fp16_supported()) {
573+
types.insert({ov::element::f16, ov::element::f16});
574+
}
575+
if (ov::intel_cpu::brgemm_utils::is_bf16_supported()) {
576+
types.insert({ov::element::bf16, ov::element::bf16});
577+
}
578+
}
579+
}
580+
return types;
581+
});
582+
// Note: EnforcePrecision might also eliminate Convert pairs (e.g. bf16->f32->bf16),
583+
// so FuseTransposeBrgemm has to be run after it as well
558584
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After,
559-
ov::snippets::pass::FuseTransposeBrgemm,
560585
pass::EnforcePrecision,
561-
element::f32,
562-
context->getConfig().inferencePrecision,
563-
[](const std::shared_ptr<ov::Node>& op) {
564-
std::set<std::vector<ov::element::Type>> types;
565-
if (ov::is_type<ov::snippets::op::Brgemm>(op)) {
566-
if (ov::intel_cpu::brgemm_utils::is_fp16_supported()) {
567-
types.insert({ov::element::f16, ov::element::f16});
568-
}
569-
if (ov::intel_cpu::brgemm_utils::is_bf16_supported()) {
570-
types.insert({ov::element::bf16, ov::element::bf16});
571-
}
572-
}
573-
return types;
574-
});
586+
ov::snippets::pass::FuseTransposeBrgemm);
575587
}
576588

577589
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before,

0 commit comments

Comments
 (0)