@@ -553,25 +553,37 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() {
553553
554554 if (any_of (context->getConfig ().inferencePrecision , ov::element::bf16 , ov::element::f16 ) &&
555555 subgraph_attrs->snippet ->has_domain_sensitive_ops ()) {
556- // MatMul has to be decomposed to Brgemm operations,
557- // and transposes on inputs/outputs should be fused in the brgemm before enforcement
556+ SNIPPETS_REGISTER_PASS_RELATIVE_X86_64 (
557+ Place::After,
558+ ov::snippets::pass::FuseTransposeBrgemm,
559+ pass::EnforcePrecision,
560+ element::f32 ,
561+ context->getConfig ().inferencePrecision ,
562+ [](const std::shared_ptr<ov::Node>& op) {
563+ std::set<std::vector<ov::element::Type>> types;
564+ if (ov::is_type<ov::snippets::op::Brgemm>(op)) {
565+ const auto & a_port =
566+ ov::snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr (op->input (0 ));
567+ // WA: We can't perform precision enforcement in case of strided access to A matrix:
568+ // snippets eltwise loops for precision conversion are generated by last 2 dims,
569+ // which are not [M, K] in case of strided access in brgemm A
570+ // There are no limitations for B matrix, since precision conversion is fused in BrgemmCopyB
571+ if (ov::snippets::utils::is_planar_layout (a_port->get_layout ())) {
572+ if (ov::intel_cpu::brgemm_utils::is_fp16_supported ()) {
573+ types.insert ({ov::element::f16 , ov::element::f16 });
574+ }
575+ if (ov::intel_cpu::brgemm_utils::is_bf16_supported ()) {
576+ types.insert ({ov::element::bf16 , ov::element::bf16 });
577+ }
578+ }
579+ }
580+ return types;
581+ });
582+ // Note: EnforcePrecision might also eliminate Convert pairs (e.g. bf16->f32->bf16),
583+ // so FuseTransposeBrgemm has to be run after it as well
558584 SNIPPETS_REGISTER_PASS_RELATIVE_X86_64 (Place::After,
559- ov::snippets::pass::FuseTransposeBrgemm,
560585 pass::EnforcePrecision,
561- element::f32 ,
562- context->getConfig ().inferencePrecision ,
563- [](const std::shared_ptr<ov::Node>& op) {
564- std::set<std::vector<ov::element::Type>> types;
565- if (ov::is_type<ov::snippets::op::Brgemm>(op)) {
566- if (ov::intel_cpu::brgemm_utils::is_fp16_supported ()) {
567- types.insert ({ov::element::f16 , ov::element::f16 });
568- }
569- if (ov::intel_cpu::brgemm_utils::is_bf16_supported ()) {
570- types.insert ({ov::element::bf16 , ov::element::bf16 });
571- }
572- }
573- return types;
574- });
586+ ov::snippets::pass::FuseTransposeBrgemm);
575587 }
576588
577589 SNIPPETS_REGISTER_PASS_RELATIVE_X86_64 (Place::Before,
0 commit comments