Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ set(PYPTO_SOURCES
src/ir/transforms/utils/parent_stmt_analysis.cpp
src/ir/transforms/utils/stmt_dependency_analysis.cpp
src/ir/transforms/utils/transform_utils.cpp
src/ir/transforms/utils/wrapper_call_utils.cpp
src/ir/transforms/visitor.cpp

# IR - Reporter
Expand Down
87 changes: 87 additions & 0 deletions include/pypto/ir/transforms/utils/wrapper_call_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright (c) PyPTO Contributors.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
* -----------------------------------------------------------------------------------------------------------
*/

#ifndef PYPTO_IR_TRANSFORMS_UTILS_WRAPPER_CALL_UTILS_H_
#define PYPTO_IR_TRANSFORMS_UTILS_WRAPPER_CALL_UTILS_H_

#include <string>
#include <vector>

#include "pypto/ir/expr.h"
#include "pypto/ir/function.h"
#include "pypto/ir/program.h"

namespace pypto {
namespace ir {

/**
* @brief Result of a wrapper / inner-call lookup.
*
* Both fields are nullptr if no matching call was found.
*/
struct WrapperCallInfo {
CallPtr inner_call;
FunctionPtr inner_callee;
};

/**
* @brief Find the first non-builtin Call inside @p wrapper that resolves to a
* Function in @p program.
*
* "Non-builtin" here means the Call's op is a GlobalVar that names an
* existing user-level Function in the program. Builtin op calls
* (`tile.*`, `tensor.*`, `system.*`) carry no GlobalVar and are skipped.
*
* @return {call, callee} for the first match, or {nullptr, nullptr} if none.
*/
WrapperCallInfo FindFirstInnerCall(const FunctionPtr& wrapper, const ProgramPtr& program);

/**
* @brief Result of a Group-function callee scan.
*
* - `aic_name` / `aiv_name` — the names of the first AIC / AIV callees
* encountered (empty if none).
* - `inner_call` / `inner_callee` — the **first** AIC, AIV, or InCore call
* in source order, regardless of type. Used by orchestration codegen as
* the parameter-order reference for wrapper arg reconciliation. After
* `ExpandMixedKernel`, Group bodies are emitted as `AIC → AIV` so the
* AIC call is naturally first in practice; the function does not enforce
* a type priority.
*/
struct GroupCalleeInfo {
std::string aic_name;
std::string aiv_name;
CallPtr inner_call;
FunctionPtr inner_callee;
};

/**
* @brief Group-specific scan: locate the AIC / AIV callees and the first
* AIC/AIV/InCore inner call inside @p group_func.
*
* @return aggregated info; any field may be empty / nullptr if not present.
*/
GroupCalleeInfo FindGroupCallees(const FunctionPtr& group_func, const ProgramPtr& program);

/**
* @brief Collect every Call inside @p wrapper that resolves to a Function
* of a non-Orchestration, non-Opaque type.
*
* Used by cross-function direction propagation in `ComputeGroupEffectiveDirections`.
* Visits the body in order; each inner Call appears once even if its callee is
* called from multiple sites.
*/
std::vector<WrapperCallInfo> CollectInnerCalls(const FunctionPtr& wrapper, const ProgramPtr& program);

} // namespace ir
} // namespace pypto

#endif // PYPTO_IR_TRANSFORMS_UTILS_WRAPPER_CALL_UTILS_H_
28 changes: 4 additions & 24 deletions src/codegen/orchestration/orchestration_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "pypto/ir/stmt.h"
#include "pypto/ir/transforms/base/visitor.h"
#include "pypto/ir/transforms/utils/auto_name_utils.h"
#include "pypto/ir/transforms/utils/wrapper_call_utils.h"
#include "pypto/ir/type.h"

namespace pypto {
Expand Down Expand Up @@ -336,35 +337,14 @@ std::vector<ParamDirection> ComputeGroupEffectiveDirections(const FunctionPtr& g
return declared;
}

class InnerCallFinder : public IRVisitor {
public:
explicit InnerCallFinder(const ProgramPtr& program) : program_(program) {}
const ProgramPtr& program_;
std::vector<std::pair<CallPtr, FunctionPtr>> inner_calls;

protected:
void VisitExpr_(const CallPtr& call) override {
if (auto gv = As<GlobalVar>(call->op_)) {
auto callee = program_->GetFunction(gv->name_);
if (callee && callee->func_type_ != FunctionType::Orchestration &&
callee->func_type_ != FunctionType::Opaque) {
inner_calls.emplace_back(call, callee);
return;
}
}
IRVisitor::VisitExpr_(call);
}
};

InnerCallFinder finder(program);
finder.VisitStmt(func->body_);
if (!finder.inner_calls.empty()) {
auto inner_calls = ir::CollectInnerCalls(func, program);
if (!inner_calls.empty()) {
std::unordered_map<const Var*, size_t> param_to_index;
for (size_t i = 0; i < func->params_.size(); ++i) {
param_to_index[func->params_[i].get()] = i;
}

for (const auto& [inner_call, inner_callee] : finder.inner_calls) {
for (const auto& [inner_call, inner_callee] : inner_calls) {
const auto& inner_args = inner_call->args_;
std::vector<ParamDirection> inner_dirs;
if (inner_callee->func_type_ == FunctionType::Group ||
Expand Down
125 changes: 24 additions & 101 deletions src/codegen/orchestration/orchestration_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
#include "pypto/ir/transforms/utils/op_predicates.h"
#include "pypto/ir/transforms/utils/transform_utils.h"
#include "pypto/ir/transforms/utils/var_collectors.h"
#include "pypto/ir/transforms/utils/wrapper_call_utils.h"
#include "pypto/ir/type.h"

namespace pypto {
Expand All @@ -56,43 +57,25 @@ namespace codegen {
using namespace pypto::ir; // NOLINT(build/namespaces)

CoreType InferFunctionCoreType(const FunctionPtr& func) {
if (func->func_type_ == FunctionType::AIC) return CoreType::CUBE;
if (func->func_type_ == FunctionType::AIV) return CoreType::VECTOR;

class CoreTypeCollector : public IRVisitor {
public:
bool has_cube_ = false;
bool has_vector_ = false;

void VisitExpr_(const CallPtr& call) override {
for (const auto& arg : call->args_) {
if (auto tile = As<TileType>(arg->GetType())) {
auto memory_space = tile->GetMemorySpace();
if (!memory_space.has_value()) {
continue;
}
if (IsCubeMemorySpace(*memory_space)) {
has_cube_ = true;
} else if (*memory_space == MemorySpace::Vec) {
has_vector_ = true;
}
}
}
IRVisitor::VisitExpr_(call);
}
};

CoreTypeCollector collector;
collector.VisitStmt(func->body_);

CHECK(!(collector.has_cube_ && collector.has_vector_))
<< "Function " << func->name_ << " contains both CUBE and VECTOR memory spaces. "
<< "A function can only use one core type.";

if (collector.has_cube_) {
return CoreType::CUBE;
// After ExpandMixedKernel runs (part of every Default / DebugTileOptimization
// pipeline), every InCore function reaching codegen has been split into AIC,
// AIV, or Group / Spmd wrappers. The two callers of this function
// (GenerateFunctionCallCode and GenerateSpmdCallCode) both filter Spmd /
// Group out before invoking it. Tests that bypass the pipeline must declare
// their kernels with the appropriate AIC / AIV type explicitly so codegen
// sees the concrete core type without re-deriving from body memory spaces.
Comment thread
coderabbitai[bot] marked this conversation as resolved.
switch (func->func_type_) {
case FunctionType::AIC:
return CoreType::CUBE;
case FunctionType::AIV:
return CoreType::VECTOR;
default:
INTERNAL_UNREACHABLE_SPAN(func->span_)
<< "InferFunctionCoreType expects AIC or AIV (Spmd/Group are filtered upstream); got "
<< FunctionTypeToString(func->func_type_) << " on function '" << func->name_
<< "'. Either run ExpandMixedKernel before codegen or declare the function "
<< "with @pl.function(type=pl.FunctionType.AIC|AIV) directly.";
}
return CoreType::VECTOR;
}

namespace {
Expand Down Expand Up @@ -899,76 +882,16 @@ class OrchestrationStmtCodegen : public CodegenBase {
};

WrapperCallInfo FindWrapperInnerCall(const FunctionPtr& wrapper_func) {
class InnerCallFinder : public IRVisitor {
public:
explicit InnerCallFinder(const ProgramPtr& program) : program_(program) {}
const ProgramPtr& program_;
CallPtr inner_call;
FunctionPtr inner_callee;

protected:
void VisitExpr_(const CallPtr& call) override {
if (inner_call) return;
if (auto gv = As<GlobalVar>(call->op_)) {
auto callee = program_->GetFunction(gv->name_);
if (callee) {
inner_call = call;
inner_callee = callee;
return;
}
}
IRVisitor::VisitExpr_(call);
}
};

InnerCallFinder finder(program_);
finder.VisitStmt(wrapper_func->body_);
return {std::move(finder.inner_call), std::move(finder.inner_callee)};
auto info = ir::FindFirstInnerCall(wrapper_func, program_);
return {std::move(info.inner_call), std::move(info.inner_callee)};
}

/// Walk the Group function body to find the AIC and AIV callee names
/// and the inner InCore call (needed for param reordering).
GroupCalleeInfo FindGroupCallees(const FunctionPtr& group_func) {
class CalleeFinder : public IRVisitor {
public:
explicit CalleeFinder(const ProgramPtr& program) : program_(program) {}
const ProgramPtr& program_;
std::string aic_name;
std::string aiv_name;
CallPtr inner_call;
FunctionPtr inner_callee;

protected:
void VisitExpr_(const CallPtr& call) override {
if (auto gv = As<GlobalVar>(call->op_)) {
auto callee = program_->GetFunction(gv->name_);
if (callee) {
if (callee->func_type_ == FunctionType::AIC && aic_name.empty()) {
aic_name = callee->name_;
if (!inner_call) {
inner_call = call;
inner_callee = callee;
}
} else if (callee->func_type_ == FunctionType::AIV && aiv_name.empty()) {
aiv_name = callee->name_;
if (!inner_call) {
inner_call = call;
inner_callee = callee;
}
} else if (callee->func_type_ == FunctionType::InCore && !inner_call) {
inner_call = call;
inner_callee = callee;
}
}
}
IRVisitor::VisitExpr_(call);
}
};

CalleeFinder finder(program_);
finder.VisitStmt(group_func->body_);
return {std::move(finder.aic_name), std::move(finder.aiv_name), std::move(finder.inner_call),
std::move(finder.inner_callee)};
auto info = ir::FindGroupCallees(group_func, program_);
return {std::move(info.aic_name), std::move(info.aiv_name), std::move(info.inner_call),
std::move(info.inner_callee)};
}

/// Build task params for a wrapper function call, reordered to match the
Expand Down
Loading
Loading