|
46 | 46 |
|
47 | 47 | # include "emitters/snippets/x64/cpu_generator.hpp" |
48 | 48 | # include "executors/x64/subgraph.hpp" |
49 | | -# include "snippets/lowered/port_descriptor.hpp" |
50 | 49 | # include "snippets/op/brgemm.hpp" |
51 | | -# include "snippets/utils/utils.hpp" |
| 50 | +# include "snippets/pass/matmul_to_brgemm.hpp" |
52 | 51 | # include "transformations/snippets/x64/op/brgemm_utils.hpp" |
53 | 52 | #elif defined(OPENVINO_ARCH_ARM64) |
54 | 53 | # include <cpu/aarch64/cpu_isa_traits.hpp> |
|
87 | 86 | # include "snippets/lowered/pass/init_loops.hpp" |
88 | 87 | # include "snippets/lowered/pass/insert_buffers.hpp" |
89 | 88 | # include "snippets/lowered/pass/insert_loops.hpp" |
90 | | -# include "snippets/pass/fuse_transpose_brgemm.hpp" |
91 | 89 | # include "transformations/snippets/common/pass/enforce_precision.hpp" |
92 | 90 | # include "transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.hpp" |
93 | 91 | # include "transformations/snippets/x64/pass/eliminate_brgemm_copy_b.hpp" |
@@ -554,38 +552,34 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() { |
554 | 552 |
|
555 | 553 | if (any_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) && |
556 | 554 | subgraph_attrs->snippet->has_domain_sensitive_ops()) { |
557 | | - SNIPPETS_REGISTER_PASS_RELATIVE_X86_64( |
558 | | - Place::After, |
559 | | - ov::snippets::pass::FuseTransposeBrgemm, |
560 | | - pass::EnforcePrecision, |
561 | | - element::f32, |
562 | | - context->getConfig().inferencePrecision, |
563 | | - [](const std::shared_ptr<ov::Node>& op) { |
564 | | - std::set<std::vector<ov::element::Type>> types; |
565 | | - if (ov::is_type<ov::snippets::op::Brgemm>(op)) { |
566 | | - const auto& a_port = |
567 | | - ov::snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(op->input(0)); |
568 | | - // WA: We can't perform precision enforcement in case of strided access to A matrix: |
569 | | - // snippets eltwise loops for precision conversion are generated by last 2 dims, |
570 | | - // which are not [M, K] in case of strided access in brgemm A |
571 | | - // There are no limitations for B matrix, since precision conversion is fused in BrgemmCopyB |
572 | | - // Ticket: 177121 |
573 | | - if (ov::snippets::utils::is_planar_layout(a_port->get_layout())) { |
574 | | - if (ov::intel_cpu::brgemm_utils::is_fp16_supported()) { |
575 | | - types.insert({ov::element::f16, ov::element::f16}); |
576 | | - } |
577 | | - if (ov::intel_cpu::brgemm_utils::is_bf16_supported()) { |
578 | | - types.insert({ov::element::bf16, ov::element::bf16}); |
579 | | - } |
580 | | - } |
581 | | - } |
582 | | - return types; |
583 | | - }); |
584 | | - // Note: EnforcePrecision might also eliminate Convert pairs (e.g. bf16->f32->bf16), |
585 | | - // so FuseTransposeBrgemm has to be run after it as well |
| 555 | + // Enforce the low inference precision (bf16/f16) on supported operations. |
| 556 | + // MatMul has to be decomposed to Brgemm operations before the enforcement. |
| 557 | + // Notes: |
| 558 | + // - MatMul decomposition will be run again later for the case when the enforcement has not happened. |
| 559 | + // - The `MatMulToBrgemm` pass fuses `transpose_a` and `transpose_b` from MatMul into the inputs of Brgemm as layouts. |
| 560 | + // These layouts are resized to the ranks of the input shapes. But since `Canonicalization` might |
| 561 | + // reshape the shapes, `MatMulToBrgemm` must run after `Canonicalization` so that the fused |
| 562 | + // layouts have ranks aligned with the updated shapes after RankNormalization insertions. |
| 563 | + SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After, |
| 564 | + ov::snippets::pass::Canonicalization, |
| 565 | + ov::snippets::pass::MatMulToBrgemm); |
586 | 566 | SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After, |
| 567 | + ov::snippets::pass::MatMulToBrgemm, |
587 | 568 | pass::EnforcePrecision, |
588 | | - ov::snippets::pass::FuseTransposeBrgemm); |
| 569 | + element::f32, |
| 570 | + context->getConfig().inferencePrecision, |
| 571 | + [](const std::shared_ptr<ov::Node>& op) { |
| 572 | + std::set<std::vector<ov::element::Type>> types; |
| 573 | + if (ov::is_type<ov::snippets::op::Brgemm>(op)) { |
| 574 | + if (ov::intel_cpu::brgemm_utils::is_fp16_supported()) { |
| 575 | + types.insert({ov::element::f16, ov::element::f16}); |
| 576 | + } |
| 577 | + if (ov::intel_cpu::brgemm_utils::is_bf16_supported()) { |
| 578 | + types.insert({ov::element::bf16, ov::element::bf16}); |
| 579 | + } |
| 580 | + } |
| 581 | + return types; |
| 582 | + }); |
589 | 583 | } |
590 | 584 |
|
591 | 585 | SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before, |
|
0 commit comments