diff --git a/src/common/snippets/include/snippets/lowered/expression.hpp b/src/common/snippets/include/snippets/lowered/expression.hpp index f56949474d8e45..597b1154f6983f 100644 --- a/src/common/snippets/include/snippets/lowered/expression.hpp +++ b/src/common/snippets/include/snippets/lowered/expression.hpp @@ -71,6 +71,8 @@ class Expression : public std::enable_shared_from_this { return m_output_port_descriptors; } + ExpressionPtr get_input_expr_ptr(size_t i) const; + size_t get_input_count() const { return m_input_port_connectors.size(); } diff --git a/src/common/snippets/src/lowered/expression.cpp b/src/common/snippets/src/lowered/expression.cpp index 15fde35f7039cd..61f16506490a0f 100644 --- a/src/common/snippets/src/lowered/expression.cpp +++ b/src/common/snippets/src/lowered/expression.cpp @@ -47,27 +47,31 @@ Expression::Expression(const std::shared_ptr& n, } const PortConnectorPtr& Expression::get_input_port_connector(size_t i) const { - assert(i < m_input_port_connectors.size() && - "Failed to get input port connector: target input port must be less than input count!"); + OPENVINO_ASSERT(i < m_input_port_connectors.size() && + "Failed to get input port connector: target input port must be less than input count!"); return m_input_port_connectors[i]; } const PortConnectorPtr& Expression::get_output_port_connector(size_t i) const { - assert(i < m_output_port_connectors.size() && - "Failed to get output port connector: target output port must be less than output count!"); + OPENVINO_ASSERT(i < m_output_port_connectors.size() && + "Failed to get output port connector: target output port must be less than output count!"); return m_output_port_connectors[i]; } const PortDescriptorPtr& Expression::get_input_port_descriptor(size_t i) const { - assert(i < m_input_port_descriptors.size() && - "Failed to get input port descriptor: target input port must be less than input count!"); + OPENVINO_ASSERT(i < m_input_port_descriptors.size() && + "Failed to get input port descriptor: target input port must be less than input count!"); return m_input_port_descriptors[i]; } const PortDescriptorPtr& Expression::get_output_port_descriptor(size_t i) const { - assert(i < m_output_port_descriptors.size() && - "Failed to get output port descriptor: target output port must be less than output count!"); + OPENVINO_ASSERT(i < m_output_port_descriptors.size() && + "Failed to get output port descriptor: target output port must be less than output count!"); return m_output_port_descriptors[i]; } +ExpressionPtr Expression::get_input_expr_ptr(size_t i) const { + return get_input_port_connector(i)->get_source().get_expr(); +} + std::shared_ptr Expression::get_node() const { if (!m_source_node) { OPENVINO_THROW("An attempt to get uninitialized node from lowered expression"); diff --git a/src/common/snippets/src/lowered/pass/assign_registers.cpp b/src/common/snippets/src/lowered/pass/assign_registers.cpp index ad99a3bc550299..5fa9ba19f84b4a 100644 --- a/src/common/snippets/src/lowered/pass/assign_registers.cpp +++ b/src/common/snippets/src/lowered/pass/assign_registers.cpp @@ -92,8 +92,7 @@ AssignRegisters::RegMap AssignRegisters::assign_regs_manually(const LinearIR& li const auto parent = tensor->get_source(); const auto parent_expr = parent.get_expr(); if (ov::is_type(parent_expr->get_node())) { - if (ov::is_type( - parent_expr->get_input_port_connector(0)->get_source().get_expr()->get_node())) { + if (ov::is_type(parent_expr->get_input_expr_ptr(0)->get_node())) { manually_assigned[parent.get_descriptor_ptr()->get_reg()] = manually_assigned[parent_expr->get_input_port_descriptor(0)->get_reg()] = assigned; } diff --git a/src/common/snippets/src/lowered/pass/clean_repeated_ptr_shifts.cpp b/src/common/snippets/src/lowered/pass/clean_repeated_ptr_shifts.cpp index e690966a1f9895..82744437ed5fe5 100644 --- a/src/common/snippets/src/lowered/pass/clean_repeated_ptr_shifts.cpp +++ b/src/common/snippets/src/lowered/pass/clean_repeated_ptr_shifts.cpp @@ -27,7 +27,6 @@ bool CleanRepeatedDataPointerShifts::reuse_increments(const LoopManagerPtr& loop return false; } - const auto& loop_connectors = loop_end_expr->get_input_port_connectors(); const auto input_count = loop_end->get_input_num(); const auto output_count = loop_end->get_output_num(); @@ -40,7 +39,7 @@ bool CleanRepeatedDataPointerShifts::reuse_increments(const LoopManagerPtr& loop // Load_0 Load_1 std::set read_data_exprs; for (size_t i = 0; i < input_count; ++i) { - const auto& parent_output = loop_connectors[i]->get_source().get_expr(); + const auto& parent_output = loop_end_expr->get_input_expr_ptr(i); if (const auto buffer_expr = ov::as_type_ptr(parent_output)) { // If Buffer is missed in set, Just save - it's first meeting if (buffers_groups.count(buffer_expr->get_reg_group()) == 0) { @@ -60,6 +59,8 @@ bool CleanRepeatedDataPointerShifts::reuse_increments(const LoopManagerPtr& loop } } } + + const auto& loop_connectors = loop_end_expr->get_input_port_connectors(); for (size_t i = 0; i < output_count; ++i) { const auto consumer_inputs = loop_connectors[input_count + i]->get_consumers(); size_t buffer_count = 0; diff --git a/src/common/snippets/src/lowered/pass/extract_loop_invariants.cpp b/src/common/snippets/src/lowered/pass/extract_loop_invariants.cpp index 33d62edce30232..83caa4d55e36f7 100644 --- a/src/common/snippets/src/lowered/pass/extract_loop_invariants.cpp +++ b/src/common/snippets/src/lowered/pass/extract_loop_invariants.cpp @@ -15,6 +15,7 @@ #include "openvino/core/type.hpp" #include "snippets/itt.hpp" #include "snippets/lowered/expression.hpp" +#include "snippets/lowered/expression_port.hpp" #include "snippets/lowered/linear_ir.hpp" #include "snippets/lowered/loop_info.hpp" #include "snippets/lowered/loop_manager.hpp" @@ -94,7 +95,7 @@ bool is_extraction_applicable(const ExpressionPtr& expr, const UnifiedLoopInfoPt } for (size_t i = 0; i < input_port_size; ++i) { - const auto& parent = expr->get_input_port_connector(i)->get_source().get_expr(); + const auto& parent = expr->get_input_expr_ptr(i); bool parent_scalar_with_single_consumer = ov::is_type(parent->get_node()) && parent->get_output_port_connector(0)->get_consumers().size() == 1; const auto& is_loop_port = inner_loop_info->is_loop_port(expr_input_ports[i]); @@ -199,7 +200,7 @@ bool extract_from_loop(const size_t& inner_loop_id, LinearIR& linear_ir) { // extract scalar on inputs if there are for (size_t i = 0; i < port_expr->get_input_count(); ++i) { - auto parent = port_expr->get_input_port_connector(i)->get_source().get_expr(); + auto parent = port_expr->get_input_expr_ptr(i); if (ov::is_type(parent->get_node())) { extract_expr(parent, linear_ir, inner_loop_begin_pos, inner_loop_end_pos); } diff --git a/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp b/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp index fd833e313cd11d..d4b041a983c41c 100644 --- a/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp +++ b/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp @@ -53,7 +53,7 @@ bool LoadMoveBroadcastToBroadcastLoad::run(LinearIR& linear_ir, continue; } - const auto& load_parent_node = load_expr->get_input_port_connector(0)->get_source().get_expr()->get_node(); + const auto& load_parent_node = load_expr->get_input_expr_ptr(0)->get_node(); const auto& outshape = move_broadcast->get_output_partial_shape(0); const auto broadcastload = std::make_shared(load_parent_node, *outshape.rbegin(), load->get_offset()); diff --git a/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp b/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp index 0e2adc57460c95..797737ba69ba6d 100644 --- a/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp +++ b/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp @@ -161,7 +161,7 @@ std::unordered_set MHAParallelWAOptimizer::find_unsqueezed_params( std::unordered_set visited; for (const auto& brgemm : brgemms) { - const auto& brgemm_b_input = brgemm->get_input_port_connector(1)->get_source().get_expr(); + const auto& brgemm_b_input = brgemm->get_input_expr_ptr(1); utils::visit_path(brgemm_b_input, visited, add_param, true); } return unsqueezed_params; diff --git a/src/common/snippets/src/lowered/pass/move_result_out_of_loop.cpp b/src/common/snippets/src/lowered/pass/move_result_out_of_loop.cpp index 9686300bf6c989..ade70473896501 100644 --- a/src/common/snippets/src/lowered/pass/move_result_out_of_loop.cpp +++ b/src/common/snippets/src/lowered/pass/move_result_out_of_loop.cpp @@ -33,8 +33,7 @@ bool MoveResultOutOfLoop::run(LinearIR& linear_ir) { continue; } - const auto& input_connector = expr->get_input_port_connector(0); - const auto& parent_expr = input_connector->get_source().get_expr(); + const auto& parent_expr = expr->get_input_expr_ptr(0); const auto& parent_loop_ids = parent_expr->get_loop_ids(); // Parent is out of Loop: just verify that Result is after Parent diff --git a/src/common/snippets/src/lowered/pass/serialize_data_flow.cpp b/src/common/snippets/src/lowered/pass/serialize_data_flow.cpp index 3d09696d857ada..0a213693ce63d7 100644 --- a/src/common/snippets/src/lowered/pass/serialize_data_flow.cpp +++ b/src/common/snippets/src/lowered/pass/serialize_data_flow.cpp @@ -38,7 +38,7 @@ bool SerializeDataFlow::run(const LinearIR& linear_ir) { const auto node = expr->get_node(); ov::OutputVector inputs(expr->get_input_count()); for (size_t i = 0; i < expr->get_input_count(); ++i) { - const auto& input_expr = expr->get_input_port_connector(i)->get_source().get_expr(); + const auto& input_expr = expr->get_input_expr_ptr(i); OPENVINO_ASSERT(ops_map.count(input_expr), "input node wasn't found during serialization"); inputs[i] = ops_map[input_expr]->output(expr->get_input_port_connector(i)->get_source().get_index()); } diff --git a/src/common/snippets/src/lowered/pass/validate.cpp b/src/common/snippets/src/lowered/pass/validate.cpp index d458e17cf4d213..3601a19a3d77fa 100644 --- a/src/common/snippets/src/lowered/pass/validate.cpp +++ b/src/common/snippets/src/lowered/pass/validate.cpp @@ -113,7 +113,7 @@ void validate_loop_end(const ExpressionPtr& expr, const LinearIR& linear_ir) { OPENVINO_ASSERT(loop_begin != nullptr, "LoopEnd must be connected to the LoopBegin"); const auto num_inputs = expr->get_input_count(); OPENVINO_ASSERT(num_inputs >= 1, "LoopEnd expression must have at least 1 input"); - OPENVINO_ASSERT(expr->get_input_port_connector(num_inputs - 1)->get_source().get_expr()->get_node() == loop_begin, + OPENVINO_ASSERT(expr->get_input_expr_ptr(num_inputs - 1)->get_node() == loop_begin, "LoopEnd expression must have LoopBegin attached to the last connector"); const auto& loop_manager = linear_ir.get_loop_manager(); diff --git a/src/common/snippets/src/utils/utils.cpp b/src/common/snippets/src/utils/utils.cpp index c946e1ed7b628c..ab9164061ed545 100644 --- a/src/common/snippets/src/utils/utils.cpp +++ b/src/common/snippets/src/utils/utils.cpp @@ -343,19 +343,17 @@ std::vector get_first_parent_shape_infer_expr_seq(const if (current_exp->get_input_count() == 0) { return shape_infer_exprs; } - auto input = current_exp->get_input_port_connector(0); - auto first_parent = input->get_source().get_expr(); + auto first_parent = current_exp->get_input_expr_ptr(0); while (op::Subgraph::is_shape_infer_op(first_parent->get_node())) { shape_infer_exprs.push_back(first_parent); current_exp = first_parent; if (current_exp->get_input_count() == 0) { break; } - input = current_exp->get_input_port_connector(0); - first_parent = input->get_source().get_expr(); + first_parent = current_exp->get_input_expr_ptr(0); if (!ov::is_type(first_parent->get_node())) { // there are maybe some loopEnd consumers of store as well for loop code gen purpose - OPENVINO_ASSERT(input->get_consumers().size() == 1, + OPENVINO_ASSERT(current_exp->get_input_port_connector(0)->get_consumers().size() == 1, "Shape infer ops are supposed to be the only consumer if it doesn't consume a store ops."); } } diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_loop_emitters.cpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_loop_emitters.cpp index 16c0eb9e2688a0..420ff96b295258 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_loop_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_loop_emitters.cpp @@ -133,7 +133,7 @@ jit_loop_end_emitter::jit_loop_end_emitter(dnnl::impl::cpu::aarch64::jit_generat ov::snippets::lowered::ExpressionPtr jit_loop_end_emitter::get_loop_begin_expr( const ov::snippets::lowered::ExpressionPtr& expr) { - auto begin_expr = expr->get_input_port_connectors().back()->get_source().get_expr(); + auto begin_expr = expr->get_input_expr_ptr(expr->get_input_count() - 1); OV_CPU_JIT_EMITTER_ASSERT(ov::is_type(begin_expr->get_node()), "LoopEnd expression must have th last port connector to LoopBegin"); return begin_expr; diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp index 7f5cb5a689474c..c199a1942f27f1 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp @@ -77,7 +77,7 @@ jit_memory_emitter::jit_memory_emitter(jit_generator* h, size_t jit_memory_emitter::get_parent_buffer_cluster_id(const ov::snippets::lowered::ExpressionPtr& expr) { OV_CPU_JIT_EMITTER_ASSERT(expr->get_input_count() == 1, "MemoryAccess must have one parent"); - const auto& parent_expr = expr->get_input_port_connector(0)->get_source().get_expr(); + const auto& parent_expr = expr->get_input_expr_ptr(0); if (const auto buffer = ov::as_type_ptr(parent_expr)) { return buffer->get_cluster_id(); } diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_loop_base_emitters.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_loop_base_emitters.cpp index 7941af2d742c4a..b737f3dadb1f4f 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_loop_base_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_loop_base_emitters.cpp @@ -135,7 +135,7 @@ jit_loop_end_base_emitter::jit_loop_end_base_emitter(jit_generator_t* h, ov::snippets::lowered::ExpressionPtr jit_loop_end_base_emitter::get_loop_begin_expr( const ov::snippets::lowered::ExpressionPtr& expr) { - auto begin_expr = expr->get_input_port_connectors().back()->get_source().get_expr(); + auto begin_expr = expr->get_input_expr_ptr(expr->get_input_count() - 1); OV_CPU_JIT_EMITTER_ASSERT(ov::is_type(begin_expr->get_node()), "LoopEnd expression must have the last port connector to LoopBegin"); return begin_expr; diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_memory_emitters.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_memory_emitters.cpp index 4e3ab4de383d1b..87bce870697263 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_memory_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_memory_emitters.cpp @@ -81,7 +81,7 @@ size_t jit_memory_emitter::aux_gprs_count() const { size_t jit_memory_emitter::get_parent_buffer_cluster_id(const ov::snippets::lowered::ExpressionPtr& expr) { OV_CPU_JIT_EMITTER_ASSERT(expr->get_input_port_connectors().size() == 1, "MemoryAccess must have one parent"); - const auto& parent_expr = expr->get_input_port_connector(0)->get_source().get_expr(); + const auto& parent_expr = expr->get_input_expr_ptr(0); if (const auto buffer = ov::as_type_ptr(parent_expr)) { return buffer->get_cluster_id(); } diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_reg_spill_emitters.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_reg_spill_emitters.cpp index cb46b7f0f4be6f..3e79ae2850e875 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_reg_spill_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_reg_spill_emitters.cpp @@ -70,7 +70,7 @@ jit_reg_spill_end_emitter::jit_reg_spill_end_emitter(dnnl::impl::cpu::x64::jit_g in_out_type_ = emitter_in_out_map::gpr_to_gpr; OV_CPU_JIT_EMITTER_ASSERT(ov::is_type(expr->get_node()) && expr->get_input_count() > 0, "Invalid expression in RegSpillEnd emitter"); - const auto& parent_expr = expr->get_input_port_connector(0)->get_source().get_expr(); + const auto& parent_expr = expr->get_input_expr_ptr(0); const auto& reg_spill_begin_emitter = std::dynamic_pointer_cast(parent_expr->get_emitter()); OV_CPU_JIT_EMITTER_ASSERT(reg_spill_begin_emitter, "Failed to obtain reg_spill_begin emitter"); diff --git a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/op/gemm_utils.cpp b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/op/gemm_utils.cpp index 333c9ab43502a3..205db46e9c7472 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/op/gemm_utils.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/op/gemm_utils.cpp @@ -20,14 +20,14 @@ namespace ov::intel_cpu::aarch64::gemm_utils::repacking { ov::snippets::lowered::ExpressionPtr get_copy_b_expr(const ov::snippets::lowered::ExpressionPtr& gemm_expr) { OPENVINO_ASSERT(ov::is_type(gemm_expr->get_node()), "get_copy_b_expr must be called only for GemmCPU node"); - auto b_input_expr = gemm_expr->get_input_port_connector(1)->get_source().get_expr(); + auto b_input_expr = gemm_expr->get_input_expr_ptr(1); if (ov::is_type(b_input_expr->get_node())) { return b_input_expr; } if (ov::is_type(b_input_expr)) { OPENVINO_ASSERT(b_input_expr->get_input_count() == 1, "RepackedWeightsBufferExpression on gemm's B input must has one input"); - auto input_buffer_expr = b_input_expr->get_input_port_connector(0)->get_source().get_expr(); + auto input_buffer_expr = b_input_expr->get_input_expr_ptr(0); if (ov::is_type(input_buffer_expr->get_node())) { return input_buffer_expr; } diff --git a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/lowered/expressions/gemm_copy_b_buffer_expressions.cpp b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/lowered/expressions/gemm_copy_b_buffer_expressions.cpp index 5b61de7ffe18f9..f53de4a168d083 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/lowered/expressions/gemm_copy_b_buffer_expressions.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/lowered/expressions/gemm_copy_b_buffer_expressions.cpp @@ -47,7 +47,7 @@ void RepackedWeightsBufferExpression::validate() const { void RepackedWeightsBufferExpression::init_allocation_size( [[maybe_unused]] const std::shared_ptr& loop_manager, [[maybe_unused]] size_t allocation_rank) { - const auto& parent_expr = get_input_port_connector(0)->get_source().get_expr(); + const auto& parent_expr = get_input_expr_ptr(0); const auto& in_shape = ov::snippets::utils::get_planar_vdims(parent_expr->get_input_port(0)); OPENVINO_ASSERT(in_shape.size() >= 2 && allocation_rank >= 2, "GemmCopyB should has at least 2 rank tensor"); const auto& element_type = get_node()->get_input_element_type(0); diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp index 6d8068ba3ae050..0c9fb297c0afa6 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp @@ -248,14 +248,14 @@ ov::snippets::op::Subgraph::BlockedShape get_wei_blocked_shape(const ov::snippet ov::snippets::lowered::ExpressionPtr get_copy_b_expr(const ov::snippets::lowered::ExpressionPtr& brgemm_expr) { OPENVINO_ASSERT(ov::is_type(brgemm_expr->get_node()), "get_copy_b_expr must be called only for BrgemmCPU node"); - auto b_input_expr = brgemm_expr->get_input_port_connector(1)->get_source().get_expr(); + auto b_input_expr = brgemm_expr->get_input_expr_ptr(1); if (ov::is_type(b_input_expr->get_node())) { return b_input_expr; } if (ov::is_type(b_input_expr)) { OPENVINO_ASSERT(b_input_expr->get_input_count() >= 1, "RepackedWeightsBufferExpression on brgemm's B input must have at least one input"); - auto input_buffer_expr = b_input_expr->get_input_port_connector(0)->get_source().get_expr(); + auto input_buffer_expr = b_input_expr->get_input_expr_ptr(0); if (ov::is_type(input_buffer_expr->get_node())) { return input_buffer_expr; } diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp index ee06332e0d6a3b..629743d30fe65b 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp @@ -50,7 +50,7 @@ std::shared_ptr BrgemmCPUBlocking::DummyPass: LinearIR::constExprIt BrgemmCPUBlocking::move_new_memory_buffer(LinearIR& linear_ir, const LinearIR::constExprIt& brgemm_it) { const auto& brgemm_expr = brgemm_it->get(); - const auto wsp_expr = brgemm_expr->get_input_port_connector(2)->get_source().get_expr(); + const auto wsp_expr = brgemm_expr->get_input_expr_ptr(2); const auto wsp_buffer = ov::as_type_ptr(wsp_expr); OPENVINO_ASSERT(wsp_buffer && wsp_buffer->is_independent_memory(), "Incorrect Scratchpad buffer for Brgemm AMX"); // If scratchpad with temp memory is not explicitly before Brgemm, need to move to there. diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/expressions/brgemm_copy_b_buffer_expressions.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/expressions/brgemm_copy_b_buffer_expressions.cpp index 856d4556dbe1b5..22033e42429cb0 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/expressions/brgemm_copy_b_buffer_expressions.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/expressions/brgemm_copy_b_buffer_expressions.cpp @@ -45,7 +45,7 @@ void RepackedWeightsBufferExpression::validate() const { void RepackedWeightsBufferExpression::init_allocation_size( [[maybe_unused]] const std::shared_ptr& loop_manager, [[maybe_unused]] size_t allocation_rank) { - const auto& parent_expr = get_input_port_connector(0)->get_source().get_expr(); + const auto& parent_expr = get_input_expr_ptr(0); const auto& brgemm_copy_b = ov::as_type_ptr(parent_expr->get_node()); OPENVINO_ASSERT(brgemm_copy_b, "RepackedWeightsBufferExpression expects BrgemmCopyB as parent expression"); const auto& brgemm_config = brgemm_copy_b->get_config(); @@ -88,7 +88,7 @@ void CompensationsBufferExpression::validate() const { void CompensationsBufferExpression::init_allocation_size( [[maybe_unused]] const std::shared_ptr& loop_manager, [[maybe_unused]] size_t allocation_rank) { - const auto& parent_expr = get_input_port_connector(0)->get_source().get_expr(); + const auto& parent_expr = get_input_expr_ptr(0); const auto& brgemm_copy_b = ov::as_type_ptr(parent_expr->get_node()); OPENVINO_ASSERT(brgemm_copy_b, "RepackedWeightsBufferExpression expects BrgemmCopyB as parent expression"); const auto& brgemm_config = brgemm_copy_b->get_config(); diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/insert_brgemm_copy_buffers.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/insert_brgemm_copy_buffers.cpp index ae645a1e448290..ef1929f3e1b425 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/insert_brgemm_copy_buffers.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/insert_brgemm_copy_buffers.cpp @@ -108,8 +108,8 @@ bool InsertBrgemmCopyBuffers::run(LinearIR& linear_ir, LinearIR::constExprIt beg } if (brgemm_config.is_amx()) { - const auto& scratch_expr = ov::as_type_ptr( - brgemm_expr->get_input_port_connector(2)->get_source().get_expr()); + const auto& scratch_expr = + ov::as_type_ptr(brgemm_expr->get_input_expr_ptr(2)); update_scratchpad(brgemm_expr, scratch_expr); modified = true; } diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/parallelize_gated_mlp_n_loops.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/parallelize_gated_mlp_n_loops.cpp index 2898b5ba62f54f..b913e1a79140ef 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/parallelize_gated_mlp_n_loops.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/parallelize_gated_mlp_n_loops.cpp @@ -33,9 +33,8 @@ bool ParallelizeGatedMlpNLoops::run(LinearIR& linear_ir, LinearIR::constExprIt b } // Gated MLP pattern contains 3 brgemms, 2 first brgemms have the same A input - const bool is_gated_mlp = brgemm_expressions.size() == 3 && - brgemm_expressions[0]->get_input_port_connector(0)->get_source().get_expr() == - brgemm_expressions[1]->get_input_port_connector(0)->get_source().get_expr(); + const bool is_gated_mlp = brgemm_expressions.size() == 3 && brgemm_expressions[0]->get_input_expr_ptr(0) == + brgemm_expressions[1]->get_input_expr_ptr(0); if (!is_gated_mlp) { return false; }