Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/common/snippets/include/snippets/lowered/expression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class Expression : public std::enable_shared_from_this<Expression> {
return m_output_port_descriptors;
}

ExpressionPtr get_input_expr_ptr(size_t i) const;

size_t get_input_count() const {
return m_input_port_connectors.size();
}
Expand Down
20 changes: 12 additions & 8 deletions src/common/snippets/src/lowered/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,27 +47,31 @@ Expression::Expression(const std::shared_ptr<Node>& n,
}

const PortConnectorPtr& Expression::get_input_port_connector(size_t i) const {
assert(i < m_input_port_connectors.size() &&
"Failed to get input port connector: target input port must be less than input count!");
OPENVINO_ASSERT(i < m_input_port_connectors.size() &&
"Failed to get input port connector: target input port must be less than input count!");
return m_input_port_connectors[i];
}
const PortConnectorPtr& Expression::get_output_port_connector(size_t i) const {
assert(i < m_output_port_connectors.size() &&
"Failed to get output port connector: target output port must be less than output count!");
OPENVINO_ASSERT(i < m_output_port_connectors.size() &&
"Failed to get output port connector: target output port must be less than output count!");
return m_output_port_connectors[i];
}

const PortDescriptorPtr& Expression::get_input_port_descriptor(size_t i) const {
assert(i < m_input_port_descriptors.size() &&
"Failed to get input port descriptor: target input port must be less than input count!");
OPENVINO_ASSERT(i < m_input_port_descriptors.size() &&
"Failed to get input port descriptor: target input port must be less than input count!");
return m_input_port_descriptors[i];
}
const PortDescriptorPtr& Expression::get_output_port_descriptor(size_t i) const {
assert(i < m_output_port_descriptors.size() &&
"Failed to get output port descriptor: target output port must be less than output count!");
OPENVINO_ASSERT(i < m_output_port_descriptors.size() &&
"Failed to get output port descriptor: target output port must be less than output count!");
return m_output_port_descriptors[i];
}

ExpressionPtr Expression::get_input_expr_ptr(size_t i) const {
return get_input_port_connector(i)->get_source().get_expr();
}

std::shared_ptr<Node> Expression::get_node() const {
if (!m_source_node) {
OPENVINO_THROW("An attempt to get uninitialized node from lowered expression");
Expand Down
3 changes: 1 addition & 2 deletions src/common/snippets/src/lowered/pass/assign_registers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ AssignRegisters::RegMap AssignRegisters::assign_regs_manually(const LinearIR& li
const auto parent = tensor->get_source();
const auto parent_expr = parent.get_expr();
if (ov::is_type<op::Fill>(parent_expr->get_node())) {
if (ov::is_type<op::VectorBuffer>(
parent_expr->get_input_port_connector(0)->get_source().get_expr()->get_node())) {
if (ov::is_type<op::VectorBuffer>(parent_expr->get_input_expr_ptr(0)->get_node())) {
manually_assigned[parent.get_descriptor_ptr()->get_reg()] =
manually_assigned[parent_expr->get_input_port_descriptor(0)->get_reg()] = assigned;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ bool CleanRepeatedDataPointerShifts::reuse_increments(const LoopManagerPtr& loop
return false;
}

const auto& loop_connectors = loop_end_expr->get_input_port_connectors();
const auto input_count = loop_end->get_input_num();
const auto output_count = loop_end->get_output_num();

Expand All @@ -40,7 +39,7 @@ bool CleanRepeatedDataPointerShifts::reuse_increments(const LoopManagerPtr& loop
// Load_0 Load_1
std::set<ExpressionPtr> read_data_exprs;
for (size_t i = 0; i < input_count; ++i) {
const auto& parent_output = loop_connectors[i]->get_source().get_expr();
const auto& parent_output = loop_end_expr->get_input_expr_ptr(i);
if (const auto buffer_expr = ov::as_type_ptr<BufferExpression>(parent_output)) {
// If Buffer is missed in set, Just save - it's first meeting
if (buffers_groups.count(buffer_expr->get_reg_group()) == 0) {
Expand All @@ -60,6 +59,8 @@ bool CleanRepeatedDataPointerShifts::reuse_increments(const LoopManagerPtr& loop
}
}
}

const auto& loop_connectors = loop_end_expr->get_input_port_connectors();
for (size_t i = 0; i < output_count; ++i) {
const auto consumer_inputs = loop_connectors[input_count + i]->get_consumers();
size_t buffer_count = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "openvino/core/type.hpp"
#include "snippets/itt.hpp"
#include "snippets/lowered/expression.hpp"
#include "snippets/lowered/expression_port.hpp"
#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/loop_info.hpp"
#include "snippets/lowered/loop_manager.hpp"
Expand Down Expand Up @@ -94,7 +95,7 @@ bool is_extraction_applicable(const ExpressionPtr& expr, const UnifiedLoopInfoPt
}

for (size_t i = 0; i < input_port_size; ++i) {
const auto& parent = expr->get_input_port_connector(i)->get_source().get_expr();
const auto& parent = expr->get_input_expr_ptr(i);
bool parent_scalar_with_single_consumer = ov::is_type<snippets::op::Scalar>(parent->get_node()) &&
parent->get_output_port_connector(0)->get_consumers().size() == 1;
const auto& is_loop_port = inner_loop_info->is_loop_port(expr_input_ports[i]);
Expand Down Expand Up @@ -199,7 +200,7 @@ bool extract_from_loop(const size_t& inner_loop_id, LinearIR& linear_ir) {

// extract scalar on inputs if there are
for (size_t i = 0; i < port_expr->get_input_count(); ++i) {
auto parent = port_expr->get_input_port_connector(i)->get_source().get_expr();
auto parent = port_expr->get_input_expr_ptr(i);
if (ov::is_type<snippets::op::Scalar>(parent->get_node())) {
extract_expr(parent, linear_ir, inner_loop_begin_pos, inner_loop_end_pos);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ bool LoadMoveBroadcastToBroadcastLoad::run(LinearIR& linear_ir,
continue;
}

const auto& load_parent_node = load_expr->get_input_port_connector(0)->get_source().get_expr()->get_node();
const auto& load_parent_node = load_expr->get_input_expr_ptr(0)->get_node();
const auto& outshape = move_broadcast->get_output_partial_shape(0);
const auto broadcastload =
std::make_shared<snippets::op::BroadcastLoad>(load_parent_node, *outshape.rbegin(), load->get_offset());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ std::unordered_set<size_t> MHAParallelWAOptimizer::find_unsqueezed_params(

std::unordered_set<lowered::ExpressionPtr> visited;
for (const auto& brgemm : brgemms) {
const auto& brgemm_b_input = brgemm->get_input_port_connector(1)->get_source().get_expr();
const auto& brgemm_b_input = brgemm->get_input_expr_ptr(1);
utils::visit_path(brgemm_b_input, visited, add_param, true);
}
return unsqueezed_params;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ bool MoveResultOutOfLoop::run(LinearIR& linear_ir) {
continue;
}

const auto& input_connector = expr->get_input_port_connector(0);
const auto& parent_expr = input_connector->get_source().get_expr();
const auto& parent_expr = expr->get_input_expr_ptr(0);
const auto& parent_loop_ids = parent_expr->get_loop_ids();

// Parent is out of Loop: just verify that Result is after Parent
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ bool SerializeDataFlow::run(const LinearIR& linear_ir) {
const auto node = expr->get_node();
ov::OutputVector inputs(expr->get_input_count());
for (size_t i = 0; i < expr->get_input_count(); ++i) {
const auto& input_expr = expr->get_input_port_connector(i)->get_source().get_expr();
const auto& input_expr = expr->get_input_expr_ptr(i);
OPENVINO_ASSERT(ops_map.count(input_expr), "input node wasn't found during serialization");
inputs[i] = ops_map[input_expr]->output(expr->get_input_port_connector(i)->get_source().get_index());
}
Expand Down
2 changes: 1 addition & 1 deletion src/common/snippets/src/lowered/pass/validate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ void validate_loop_end(const ExpressionPtr& expr, const LinearIR& linear_ir) {
OPENVINO_ASSERT(loop_begin != nullptr, "LoopEnd must be connected to the LoopBegin");
const auto num_inputs = expr->get_input_count();
OPENVINO_ASSERT(num_inputs >= 1, "LoopEnd expression must have at least 1 input");
OPENVINO_ASSERT(expr->get_input_port_connector(num_inputs - 1)->get_source().get_expr()->get_node() == loop_begin,
OPENVINO_ASSERT(expr->get_input_expr_ptr(num_inputs - 1)->get_node() == loop_begin,
"LoopEnd expression must have LoopBegin attached to the last connector");

const auto& loop_manager = linear_ir.get_loop_manager();
Expand Down
8 changes: 3 additions & 5 deletions src/common/snippets/src/utils/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,19 +343,17 @@ std::vector<lowered::ExpressionPtr> get_first_parent_shape_infer_expr_seq(const
if (current_exp->get_input_count() == 0) {
return shape_infer_exprs;
}
auto input = current_exp->get_input_port_connector(0);
auto first_parent = input->get_source().get_expr();
auto first_parent = current_exp->get_input_expr_ptr(0);
while (op::Subgraph::is_shape_infer_op(first_parent->get_node())) {
shape_infer_exprs.push_back(first_parent);
current_exp = first_parent;
if (current_exp->get_input_count() == 0) {
break;
}
input = current_exp->get_input_port_connector(0);
first_parent = input->get_source().get_expr();
first_parent = current_exp->get_input_expr_ptr(0);
if (!ov::is_type<snippets::op::Store>(first_parent->get_node())) {
// there are maybe some loopEnd consumers of store as well for loop code gen purpose
OPENVINO_ASSERT(input->get_consumers().size() == 1,
OPENVINO_ASSERT(current_exp->get_input_port_connector(0)->get_consumers().size() == 1,
"Shape infer ops are supposed to be the only consumer if it doesn't consume a store ops.");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ jit_loop_end_emitter::jit_loop_end_emitter(dnnl::impl::cpu::aarch64::jit_generat

ov::snippets::lowered::ExpressionPtr jit_loop_end_emitter::get_loop_begin_expr(
const ov::snippets::lowered::ExpressionPtr& expr) {
auto begin_expr = expr->get_input_port_connectors().back()->get_source().get_expr();
auto begin_expr = expr->get_input_expr_ptr(expr->get_input_count() - 1);
OV_CPU_JIT_EMITTER_ASSERT(ov::is_type<snippets::op::LoopBegin>(begin_expr->get_node()),
"LoopEnd expression must have th last port connector to LoopBegin");
return begin_expr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ jit_memory_emitter::jit_memory_emitter(jit_generator* h,

size_t jit_memory_emitter::get_parent_buffer_cluster_id(const ov::snippets::lowered::ExpressionPtr& expr) {
OV_CPU_JIT_EMITTER_ASSERT(expr->get_input_count() == 1, "MemoryAccess must have one parent");
const auto& parent_expr = expr->get_input_port_connector(0)->get_source().get_expr();
const auto& parent_expr = expr->get_input_expr_ptr(0);
if (const auto buffer = ov::as_type_ptr<ov::snippets::lowered::BufferExpression>(parent_expr)) {
return buffer->get_cluster_id();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ jit_loop_end_base_emitter::jit_loop_end_base_emitter(jit_generator_t* h,

ov::snippets::lowered::ExpressionPtr jit_loop_end_base_emitter::get_loop_begin_expr(
const ov::snippets::lowered::ExpressionPtr& expr) {
auto begin_expr = expr->get_input_port_connectors().back()->get_source().get_expr();
auto begin_expr = expr->get_input_expr_ptr(expr->get_input_count() - 1);
OV_CPU_JIT_EMITTER_ASSERT(ov::is_type<snippets::op::LoopBegin>(begin_expr->get_node()),
"LoopEnd expression must have the last port connector to LoopBegin");
return begin_expr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ size_t jit_memory_emitter::aux_gprs_count() const {

size_t jit_memory_emitter::get_parent_buffer_cluster_id(const ov::snippets::lowered::ExpressionPtr& expr) {
OV_CPU_JIT_EMITTER_ASSERT(expr->get_input_port_connectors().size() == 1, "MemoryAccess must have one parent");
const auto& parent_expr = expr->get_input_port_connector(0)->get_source().get_expr();
const auto& parent_expr = expr->get_input_expr_ptr(0);
if (const auto buffer = ov::as_type_ptr<ov::snippets::lowered::BufferExpression>(parent_expr)) {
return buffer->get_cluster_id();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ jit_reg_spill_end_emitter::jit_reg_spill_end_emitter(dnnl::impl::cpu::x64::jit_g
in_out_type_ = emitter_in_out_map::gpr_to_gpr;
OV_CPU_JIT_EMITTER_ASSERT(ov::is_type<snippets::op::RegSpillEnd>(expr->get_node()) && expr->get_input_count() > 0,
"Invalid expression in RegSpillEnd emitter");
const auto& parent_expr = expr->get_input_port_connector(0)->get_source().get_expr();
const auto& parent_expr = expr->get_input_expr_ptr(0);
const auto& reg_spill_begin_emitter =
std::dynamic_pointer_cast<jit_reg_spill_begin_emitter>(parent_expr->get_emitter());
OV_CPU_JIT_EMITTER_ASSERT(reg_spill_begin_emitter, "Failed to obtain reg_spill_begin emitter");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ namespace ov::intel_cpu::aarch64::gemm_utils::repacking {
ov::snippets::lowered::ExpressionPtr get_copy_b_expr(const ov::snippets::lowered::ExpressionPtr& gemm_expr) {
OPENVINO_ASSERT(ov::is_type<GemmCPU>(gemm_expr->get_node()),
"get_copy_b_expr must be called only for GemmCPU node");
auto b_input_expr = gemm_expr->get_input_port_connector(1)->get_source().get_expr();
auto b_input_expr = gemm_expr->get_input_expr_ptr(1);
if (ov::is_type<GemmCopyB>(b_input_expr->get_node())) {
return b_input_expr;
}
if (ov::is_type<RepackedWeightsBufferExpression>(b_input_expr)) {
OPENVINO_ASSERT(b_input_expr->get_input_count() == 1,
"RepackedWeightsBufferExpression on gemm's B input must has one input");
auto input_buffer_expr = b_input_expr->get_input_port_connector(0)->get_source().get_expr();
auto input_buffer_expr = b_input_expr->get_input_expr_ptr(0);
if (ov::is_type<GemmCopyB>(input_buffer_expr->get_node())) {
return input_buffer_expr;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void RepackedWeightsBufferExpression::validate() const {
void RepackedWeightsBufferExpression::init_allocation_size(
[[maybe_unused]] const std::shared_ptr<snippets::lowered::LoopManager>& loop_manager,
[[maybe_unused]] size_t allocation_rank) {
const auto& parent_expr = get_input_port_connector(0)->get_source().get_expr();
const auto& parent_expr = get_input_expr_ptr(0);
const auto& in_shape = ov::snippets::utils::get_planar_vdims(parent_expr->get_input_port(0));
OPENVINO_ASSERT(in_shape.size() >= 2 && allocation_rank >= 2, "GemmCopyB should has at least 2 rank tensor");
const auto& element_type = get_node()->get_input_element_type(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,14 +248,14 @@ ov::snippets::op::Subgraph::BlockedShape get_wei_blocked_shape(const ov::snippet
ov::snippets::lowered::ExpressionPtr get_copy_b_expr(const ov::snippets::lowered::ExpressionPtr& brgemm_expr) {
OPENVINO_ASSERT(ov::is_type<BrgemmCPU>(brgemm_expr->get_node()),
"get_copy_b_expr must be called only for BrgemmCPU node");
auto b_input_expr = brgemm_expr->get_input_port_connector(1)->get_source().get_expr();
auto b_input_expr = brgemm_expr->get_input_expr_ptr(1);
if (ov::is_type<BrgemmCopyB>(b_input_expr->get_node())) {
return b_input_expr;
}
if (ov::is_type<RepackedWeightsBufferExpression>(b_input_expr)) {
OPENVINO_ASSERT(b_input_expr->get_input_count() >= 1,
"RepackedWeightsBufferExpression on brgemm's B input must have at least one input");
auto input_buffer_expr = b_input_expr->get_input_port_connector(0)->get_source().get_expr();
auto input_buffer_expr = b_input_expr->get_input_expr_ptr(0);
if (ov::is_type<BrgemmCopyB>(input_buffer_expr->get_node())) {
return input_buffer_expr;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ std::shared_ptr<snippets::lowered::pass::PassBase> BrgemmCPUBlocking::DummyPass:
LinearIR::constExprIt BrgemmCPUBlocking::move_new_memory_buffer(LinearIR& linear_ir,
const LinearIR::constExprIt& brgemm_it) {
const auto& brgemm_expr = brgemm_it->get();
const auto wsp_expr = brgemm_expr->get_input_port_connector(2)->get_source().get_expr();
const auto wsp_expr = brgemm_expr->get_input_expr_ptr(2);
const auto wsp_buffer = ov::as_type_ptr<ov::snippets::lowered::BufferExpression>(wsp_expr);
OPENVINO_ASSERT(wsp_buffer && wsp_buffer->is_independent_memory(), "Incorrect Scratchpad buffer for Brgemm AMX");
// If scratchpad with temp memory is not explicitly before Brgemm, need to move to there.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void RepackedWeightsBufferExpression::validate() const {
void RepackedWeightsBufferExpression::init_allocation_size(
[[maybe_unused]] const std::shared_ptr<snippets::lowered::LoopManager>& loop_manager,
[[maybe_unused]] size_t allocation_rank) {
const auto& parent_expr = get_input_port_connector(0)->get_source().get_expr();
const auto& parent_expr = get_input_expr_ptr(0);
const auto& brgemm_copy_b = ov::as_type_ptr<ov::intel_cpu::BrgemmCopyB>(parent_expr->get_node());
OPENVINO_ASSERT(brgemm_copy_b, "RepackedWeightsBufferExpression expects BrgemmCopyB as parent expression");
const auto& brgemm_config = brgemm_copy_b->get_config();
Expand Down Expand Up @@ -88,7 +88,7 @@ void CompensationsBufferExpression::validate() const {
void CompensationsBufferExpression::init_allocation_size(
[[maybe_unused]] const std::shared_ptr<snippets::lowered::LoopManager>& loop_manager,
[[maybe_unused]] size_t allocation_rank) {
const auto& parent_expr = get_input_port_connector(0)->get_source().get_expr();
const auto& parent_expr = get_input_expr_ptr(0);
const auto& brgemm_copy_b = ov::as_type_ptr<ov::intel_cpu::BrgemmCopyB>(parent_expr->get_node());
OPENVINO_ASSERT(brgemm_copy_b, "RepackedWeightsBufferExpression expects BrgemmCopyB as parent expression");
const auto& brgemm_config = brgemm_copy_b->get_config();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ bool InsertBrgemmCopyBuffers::run(LinearIR& linear_ir, LinearIR::constExprIt beg
}

if (brgemm_config.is_amx()) {
const auto& scratch_expr = ov::as_type_ptr<ov::snippets::lowered::BufferExpression>(
brgemm_expr->get_input_port_connector(2)->get_source().get_expr());
const auto& scratch_expr =
ov::as_type_ptr<ov::snippets::lowered::BufferExpression>(brgemm_expr->get_input_expr_ptr(2));
update_scratchpad(brgemm_expr, scratch_expr);
modified = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@ bool ParallelizeGatedMlpNLoops::run(LinearIR& linear_ir, LinearIR::constExprIt b
}

// Gated MLP pattern contains 3 brgemms, 2 first brgemms have the same A input
const bool is_gated_mlp = brgemm_expressions.size() == 3 &&
brgemm_expressions[0]->get_input_port_connector(0)->get_source().get_expr() ==
brgemm_expressions[1]->get_input_port_connector(0)->get_source().get_expr();
const bool is_gated_mlp = brgemm_expressions.size() == 3 && brgemm_expressions[0]->get_input_expr_ptr(0) ==
brgemm_expressions[1]->get_input_expr_ptr(0);
if (!is_gated_mlp) {
return false;
}
Expand Down
Loading