Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ set(PYPTO_SOURCES
src/ir/transforms/utils/parent_stmt_analysis.cpp
src/ir/transforms/utils/stmt_dependency_analysis.cpp
src/ir/transforms/utils/transform_utils.cpp
src/ir/transforms/utils/wrapper_call_utils.cpp
src/ir/transforms/visitor.cpp

# IR - Reporter
Expand Down
83 changes: 83 additions & 0 deletions include/pypto/ir/transforms/utils/wrapper_call_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright (c) PyPTO Contributors.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
* -----------------------------------------------------------------------------------------------------------
*/

#ifndef PYPTO_IR_TRANSFORMS_UTILS_WRAPPER_CALL_UTILS_H_
#define PYPTO_IR_TRANSFORMS_UTILS_WRAPPER_CALL_UTILS_H_

#include <string>
#include <vector>

#include "pypto/ir/expr.h"
#include "pypto/ir/function.h"
#include "pypto/ir/program.h"

namespace pypto {
namespace ir {

/**
* @brief Result of a wrapper / inner-call lookup.
*
* Both fields are nullptr if no matching call was found.
*/
struct WrapperCallInfo {
CallPtr inner_call;
FunctionPtr inner_callee;
};

/**
* @brief Find the first non-builtin Call inside @p wrapper that resolves to a
* Function in @p program.
*
* "Non-builtin" here means the Call's op is a GlobalVar that names an
* existing user-level Function in the program. Builtin op calls
* (`tile.*`, `tensor.*`, `system.*`) carry no GlobalVar and are skipped.
*
* @return {call, callee} for the first match, or {nullptr, nullptr} if none.
*/
WrapperCallInfo FindFirstInnerCall(const FunctionPtr& wrapper, const ProgramPtr& program);

/**
* @brief Result of a Group-function callee scan.
*
* `aic_name` / `aiv_name` are the names of the first AIC / AIV callees
* encountered (empty if none). `inner_call` / `inner_callee` point at the
* first AIC, AIV, or InCore call (priority AIC > AIV > InCore) — the call
* whose argument shape orchestration codegen reorders against.
*/
struct GroupCalleeInfo {
std::string aic_name;
std::string aiv_name;
CallPtr inner_call;
FunctionPtr inner_callee;
};

/**
* @brief Group-specific scan: locate the AIC / AIV callees and the first
* InCore-variant inner call inside @p group_func.
*
* @return aggregated info; any field may be empty / nullptr if not present.
*/
GroupCalleeInfo FindGroupCallees(const FunctionPtr& group_func, const ProgramPtr& program);

/**
* @brief Collect every Call inside @p wrapper that resolves to a Function
* of a non-Orchestration, non-Opaque type.
*
* Used by cross-function direction propagation in `ComputeGroupEffectiveDirections`.
* Visits the body in order; each inner Call appears once even if its callee is
* called from multiple sites.
*/
std::vector<WrapperCallInfo> CollectInnerCalls(const FunctionPtr& wrapper, const ProgramPtr& program);

} // namespace ir
} // namespace pypto

#endif // PYPTO_IR_TRANSFORMS_UTILS_WRAPPER_CALL_UTILS_H_
28 changes: 4 additions & 24 deletions src/codegen/orchestration/orchestration_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "pypto/ir/stmt.h"
#include "pypto/ir/transforms/base/visitor.h"
#include "pypto/ir/transforms/utils/auto_name_utils.h"
#include "pypto/ir/transforms/utils/wrapper_call_utils.h"
#include "pypto/ir/type.h"

namespace pypto {
Expand Down Expand Up @@ -336,35 +337,14 @@ std::vector<ParamDirection> ComputeGroupEffectiveDirections(const FunctionPtr& g
return declared;
}

class InnerCallFinder : public IRVisitor {
public:
explicit InnerCallFinder(const ProgramPtr& program) : program_(program) {}
const ProgramPtr& program_;
std::vector<std::pair<CallPtr, FunctionPtr>> inner_calls;

protected:
void VisitExpr_(const CallPtr& call) override {
if (auto gv = As<GlobalVar>(call->op_)) {
auto callee = program_->GetFunction(gv->name_);
if (callee && callee->func_type_ != FunctionType::Orchestration &&
callee->func_type_ != FunctionType::Opaque) {
inner_calls.emplace_back(call, callee);
return;
}
}
IRVisitor::VisitExpr_(call);
}
};

InnerCallFinder finder(program);
finder.VisitStmt(func->body_);
if (!finder.inner_calls.empty()) {
auto inner_calls = ir::CollectInnerCalls(func, program);
if (!inner_calls.empty()) {
std::unordered_map<const Var*, size_t> param_to_index;
for (size_t i = 0; i < func->params_.size(); ++i) {
param_to_index[func->params_[i].get()] = i;
}

for (const auto& [inner_call, inner_callee] : finder.inner_calls) {
for (const auto& [inner_call, inner_callee] : inner_calls) {
const auto& inner_args = inner_call->args_;
std::vector<ParamDirection> inner_dirs;
if (inner_callee->func_type_ == FunctionType::Group ||
Expand Down
71 changes: 6 additions & 65 deletions src/codegen/orchestration/orchestration_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
#include "pypto/ir/transforms/utils/op_predicates.h"
#include "pypto/ir/transforms/utils/transform_utils.h"
#include "pypto/ir/transforms/utils/var_collectors.h"
#include "pypto/ir/transforms/utils/wrapper_call_utils.h"
#include "pypto/ir/type.h"

namespace pypto {
Expand Down Expand Up @@ -899,76 +900,16 @@ class OrchestrationStmtCodegen : public CodegenBase {
};

WrapperCallInfo FindWrapperInnerCall(const FunctionPtr& wrapper_func) {
class InnerCallFinder : public IRVisitor {
public:
explicit InnerCallFinder(const ProgramPtr& program) : program_(program) {}
const ProgramPtr& program_;
CallPtr inner_call;
FunctionPtr inner_callee;

protected:
void VisitExpr_(const CallPtr& call) override {
if (inner_call) return;
if (auto gv = As<GlobalVar>(call->op_)) {
auto callee = program_->GetFunction(gv->name_);
if (callee) {
inner_call = call;
inner_callee = callee;
return;
}
}
IRVisitor::VisitExpr_(call);
}
};

InnerCallFinder finder(program_);
finder.VisitStmt(wrapper_func->body_);
return {std::move(finder.inner_call), std::move(finder.inner_callee)};
auto info = ir::FindFirstInnerCall(wrapper_func, program_);
return {std::move(info.inner_call), std::move(info.inner_callee)};
}

/// Walk the Group function body to find the AIC and AIV callee names
/// and the inner InCore call (needed for param reordering).
GroupCalleeInfo FindGroupCallees(const FunctionPtr& group_func) {
class CalleeFinder : public IRVisitor {
public:
explicit CalleeFinder(const ProgramPtr& program) : program_(program) {}
const ProgramPtr& program_;
std::string aic_name;
std::string aiv_name;
CallPtr inner_call;
FunctionPtr inner_callee;

protected:
void VisitExpr_(const CallPtr& call) override {
if (auto gv = As<GlobalVar>(call->op_)) {
auto callee = program_->GetFunction(gv->name_);
if (callee) {
if (callee->func_type_ == FunctionType::AIC && aic_name.empty()) {
aic_name = callee->name_;
if (!inner_call) {
inner_call = call;
inner_callee = callee;
}
} else if (callee->func_type_ == FunctionType::AIV && aiv_name.empty()) {
aiv_name = callee->name_;
if (!inner_call) {
inner_call = call;
inner_callee = callee;
}
} else if (callee->func_type_ == FunctionType::InCore && !inner_call) {
inner_call = call;
inner_callee = callee;
}
}
}
IRVisitor::VisitExpr_(call);
}
};

CalleeFinder finder(program_);
finder.VisitStmt(group_func->body_);
return {std::move(finder.aic_name), std::move(finder.aiv_name), std::move(finder.inner_call),
std::move(finder.inner_callee)};
auto info = ir::FindGroupCallees(group_func, program_);
return {std::move(info.aic_name), std::move(info.aiv_name), std::move(info.inner_call),
std::move(info.inner_callee)};
}

/// Build task params for a wrapper function call, reordered to match the
Expand Down
110 changes: 110 additions & 0 deletions src/ir/transforms/utils/wrapper_call_utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright (c) PyPTO Contributors.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
* -----------------------------------------------------------------------------------------------------------
*/

#include "pypto/ir/transforms/utils/wrapper_call_utils.h"

#include <functional>
#include <utility>
#include <vector>

#include "pypto/ir/kind_traits.h"
#include "pypto/ir/transforms/base/visitor.h"

namespace pypto {
namespace ir {

namespace {

/// Shared scaffold: visit every Call in the body, resolve its op via
/// `GlobalVar` lookup, invoke @p on_match for each resolved (call, callee)
/// pair. Returning `true` from @p on_match terminates the walk early.
class CallVisitor : public IRVisitor {
public:
using OnMatchFn = std::function<bool(const CallPtr&, const FunctionPtr&)>;

CallVisitor(const ProgramPtr& program, OnMatchFn on_match)
: program_(program), on_match_(std::move(on_match)) {}

protected:
void VisitExpr_(const CallPtr& call) override {
if (stop_) return;
if (auto gv = As<GlobalVar>(call->op_)) {
if (auto callee = program_->GetFunction(gv->name_)) {
if (on_match_(call, callee)) {
stop_ = true;
return;
}
}
}
IRVisitor::VisitExpr_(call);
}

private:
const ProgramPtr& program_;
OnMatchFn on_match_;
bool stop_ = false;
};

} // namespace

WrapperCallInfo FindFirstInnerCall(const FunctionPtr& wrapper, const ProgramPtr& program) {
WrapperCallInfo info;
if (!wrapper || !wrapper->body_ || !program) return info;
CallVisitor visitor(program, [&](const CallPtr& call, const FunctionPtr& callee) {
info.inner_call = call;
info.inner_callee = callee;
return true; // first match wins; stop the walk
});
visitor.VisitStmt(wrapper->body_);
return info;
}

GroupCalleeInfo FindGroupCallees(const FunctionPtr& group_func, const ProgramPtr& program) {
GroupCalleeInfo info;
if (!group_func || !group_func->body_ || !program) return info;
CallVisitor visitor(program, [&](const CallPtr& call, const FunctionPtr& callee) {
if (callee->func_type_ == FunctionType::AIC && info.aic_name.empty()) {
info.aic_name = callee->name_;
if (!info.inner_call) {
info.inner_call = call;
info.inner_callee = callee;
}
} else if (callee->func_type_ == FunctionType::AIV && info.aiv_name.empty()) {
info.aiv_name = callee->name_;
if (!info.inner_call) {
info.inner_call = call;
info.inner_callee = callee;
}
} else if (callee->func_type_ == FunctionType::InCore && !info.inner_call) {
info.inner_call = call;
info.inner_callee = callee;
}
return false; // collect all matches
Comment thread
lyfne123 marked this conversation as resolved.
});
visitor.VisitStmt(group_func->body_);
return info;
}

std::vector<WrapperCallInfo> CollectInnerCalls(const FunctionPtr& wrapper, const ProgramPtr& program) {
std::vector<WrapperCallInfo> result;
if (!wrapper || !wrapper->body_ || !program) return result;
CallVisitor visitor(program, [&](const CallPtr& call, const FunctionPtr& callee) {
if (callee->func_type_ != FunctionType::Orchestration && callee->func_type_ != FunctionType::Opaque) {
result.push_back({call, callee});
}
return false;
});
visitor.VisitStmt(wrapper->body_);
return result;
}

} // namespace ir
} // namespace pypto
Loading