Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions dawn/src/dawn/AST/ASTExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,23 @@ std::shared_ptr<Expr> ReductionOverNeighborExpr::clone() const {

ArrayRef<std::shared_ptr<Expr>> ReductionOverNeighborExpr::getChildren() const {
return ExprRangeType(operands_);
} // namespace ast
}

void ReductionOverNeighborExpr::setWeight(int idx, std::shared_ptr<Expr> weight) {
DAWN_ASSERT_MSG(hasWeights(), "set weights is only possible if there already are weights");
DAWN_ASSERT_MSG(idx >= 0 && idx < weights_->size(), "weight index out of range");
weights_->at(idx) = weight;
operands_.at(idx + 2) = weight;
}

void ReductionOverNeighborExpr::replaceChildren(const std::shared_ptr<Expr>& oldExpr,
const std::shared_ptr<Expr>& newExpr) {
[[maybe_unused]] bool success = ASTHelper::replaceOperands(oldExpr, newExpr, operands_);
DAWN_ASSERT_MSG((success), ("Expression not found"));
if(operands_.size() > 2) {
weights_ = std::vector<std::shared_ptr<Expr>>(operands_.begin() + 2, operands_.end());
}
}

bool ReductionOverNeighborExpr::equals(const Expr* other, bool compareData) const {
const ReductionOverNeighborExpr* otherPtr = dyn_cast<ReductionOverNeighborExpr>(other);
Expand All @@ -508,7 +524,8 @@ bool ReductionOverNeighborExpr::equals(const Expr* other, bool compareData) cons
return false;
}
for(int i = 0; i < weights_->size(); i++) {
if(*weights_.value().at(i) != *otherPtr->getWeights().value().at(i)) {
if(!(*weights_.value().at(i))
.equals((const ast::Expr*)otherPtr->getWeights().value().at(i).get(), compareData)) {
return false;
}
}
Expand All @@ -521,7 +538,7 @@ bool ReductionOverNeighborExpr::equals(const Expr* other, bool compareData) cons

bool ReductionOverNeighborExpr::isArithmetic() const {
return any_of(ast::ReductionOverNeighborExpr::arithmeticOps,
[&](std::string op) { return op_ == op; });
[&](std::string op) { return op_ == op; });
}

} // namespace ast
Expand Down
5 changes: 5 additions & 0 deletions dawn/src/dawn/AST/ASTExpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include "dawn/Support/Array.h"
#include "dawn/Support/ArrayRef.h"
#include "dawn/Support/Assert.h"
#include "dawn/Support/Type.h"
#include "dawn/Support/UIDGenerator.h"
#include <array>
Expand Down Expand Up @@ -672,10 +673,14 @@ class ReductionOverNeighborExpr : public Expr {
std::vector<ast::LocationType> getNbhChain() const { return iterSpace_; };
ast::LocationType getLhsLocation() const { return iterSpace_.Chain.front(); };
const std::optional<std::vector<std::shared_ptr<Expr>>>& getWeights() const { return weights_; };
bool hasWeights() const { return weights_.has_value(); }
void setWeight(int idx, std::shared_ptr<Expr> weight);
bool getIncludeCenter() const { return iterSpace_.IncludeCenter; };
ast::UnstructuredIterationSpace getIterSpace() const { return iterSpace_; }

ExprRangeType getChildren() const override;
virtual void replaceChildren(const std::shared_ptr<Expr>& oldExpr,
const std::shared_ptr<Expr>& newExpr) override;

static bool classof(const Expr* expr) {
return expr->getKind() == Kind::ReductionOverNeighborExpr;
Expand Down
7 changes: 7 additions & 0 deletions dawn/src/dawn/AST/ASTUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ class ExprReplacer : public ASTVisitorNonConst {
void visit(const std::shared_ptr<StencilCallDeclStmt>& stmt) override {}
void visit(const std::shared_ptr<BoundaryConditionDeclStmt>& stmt) override {}
void visit(const std::shared_ptr<ReductionOverNeighborExpr>& expr) override {
if(expr->hasWeights()) {
for(int i = 0; i < expr->getWeights()->size(); i++) {
if(expr->getWeights()->at(i)->equals(oldExpr_)) {
expr->setWeight(i, newExpr_);
}
}
}
for(const auto& s : expr->getChildren())
s->accept(*this);
}
Expand Down
5 changes: 5 additions & 0 deletions dawn/src/dawn/IIR/ASTStmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include "dawn/AST/ASTExpr.h"
#include "dawn/AST/ASTStmt.h"
#include "dawn/IIR/Accesses.h"
#include <memory>
Expand Down Expand Up @@ -94,6 +95,10 @@ std::shared_ptr<ast::VarDeclStmt> makeVarDeclStmt(Args&&... args) {
std::forward<Args>(args)...);
}
template <typename... Args>
std::shared_ptr<ast::ReductionOverNeighborExpr> makeReductionOverNeighborExpr(Args&&... args) {
return std::make_shared<ast::ReductionOverNeighborExpr>(std::forward<Args>(args)...);
}
template <typename... Args>
std::shared_ptr<ast::VerticalRegionDeclStmt> makeVerticalRegionDeclStmt(Args&&... args) {
return std::make_shared<ast::VerticalRegionDeclStmt>(std::make_unique<IIRStmtData>(),
std::forward<Args>(args)...);
Expand Down
2 changes: 2 additions & 0 deletions dawn/src/dawn/Optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ add_library(DawnOptimizer
PassFixVersionedInputFields.h
PassInlining.cpp
PassInlining.h
PassTemporaryInlining.h
PassTemporaryInlining.cpp
PassIntervalPartitioning.cpp
PassIntervalPartitioning.h
PassLocalVarType.h
Expand Down
7 changes: 7 additions & 0 deletions dawn/src/dawn/Optimizer/Driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
#include "dawn/CodeGen/Driver.h"
#include "dawn/CodeGen/TranslationUnit.h"
#include "dawn/Optimizer/Lowering.h"
#include "dawn/Optimizer/Options.h"
#include "dawn/Optimizer/PassManager.h"
#include "dawn/Optimizer/PassTemporaryInlining.h"
#include "dawn/SIR/SIR.h"
#include "dawn/Support/Exception.h"
#include "dawn/Support/Logger.h"
Expand Down Expand Up @@ -260,6 +262,11 @@ run(const std::map<std::string, std::shared_ptr<iir::StencilInstantiation>>&
// validation check
passManager.pushBackPass<PassValidation>();
break;
case PassGroup::TemporaryInlining:
passManager.pushBackPass<PassTemporaryInlining>();
// validation check
passManager.pushBackPass<PassValidation>();
break;
case PassGroup::Parallel:
DAWN_ASSERT_MSG(false, "The parallel group is only valid for lowering to IIR.");
}
Expand Down
1 change: 1 addition & 0 deletions dawn/src/dawn/Optimizer/Options.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ enum class PassGroup {
SetBlockSize,
DataLocalityMetric,
SetLoopOrder,
TemporaryInlining
};

struct Options {
Expand Down
220 changes: 220 additions & 0 deletions dawn/src/dawn/Optimizer/PassTemporaryInlining.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
//===--------------------------------------------------------------------------------*- C++ -*-===//
// _
// | |
// __| | __ ___ ___ ___
// / _` |/ _` \ \ /\ / / '_ |
// | (_| | (_| |\ V V /| | | |
// \__,_|\__,_| \_/\_/ |_| |_| - Compiler Toolchain
//
//
// This file is distributed under the MIT License (MIT).
// See LICENSE.txt for details.
//
//===------------------------------------------------------------------------------------------===//

#include "dawn/Optimizer/PassTemporaryInlining.h"
#include "dawn/AST/ASTExpr.h"
#include "dawn/AST/ASTStmt.h"
#include "dawn/IIR/ASTExpr.h"
#include "dawn/IIR/ASTMatcher.h"
#include "dawn/IIR/AccessComputation.h"
#include "dawn/IIR/StencilInstantiation.h"
#include "dawn/Support/Logger.h"
#include "dawn/Support/Unreachable.h"
#include <memory>
#include <unordered_map>
#include <variant>

namespace {
void cleanUp(const std::shared_ptr<dawn::iir::StencilInstantiation>& stencilInstantiation) {
for(const auto& multiStage :
dawn::iterateIIROver<dawn::iir::MultiStage>(*stencilInstantiation->getIIR())) {
for(auto curStageIt = multiStage->childrenBegin(); curStageIt != multiStage->childrenEnd();
curStageIt++) {
dawn::iir::Stage& curStage = **curStageIt;
for(auto curDoMethodIt = curStage.childrenBegin(); curDoMethodIt != curStage.childrenEnd();) {
dawn::iir::DoMethod& curDoMethod = **curDoMethodIt;

if(curDoMethod.isEmptyOrNullStmt()) {
DAWN_LOG(INFO) << stencilInstantiation->getName() << ": DoMethod: " << curDoMethod.getID()
<< " has empty body after inlining a temporary, removing";

curDoMethodIt = curStage.childrenErase(curDoMethodIt);
} else {
curDoMethodIt++;
}
}

for(auto& doMethod : curStage.getChildren()) {
doMethod->update(dawn::iir::NodeUpdateType::level);
}
curStage.update(dawn::iir::NodeUpdateType::levelAndTreeAbove);
}

for(auto curStageIt = multiStage->childrenBegin(); curStageIt != multiStage->childrenEnd();) {
dawn::iir::Stage& curStage = **curStageIt;
if(curStage.childrenEmpty()) {
curStageIt = multiStage->childrenErase(curStageIt);
} else {
curStageIt++;
}
}
}
}
} // namespace

namespace dawn {

class AccessesReplacer : public ast::ASTVisitorPostOrder {
const int accessToSubstitute_;
const std::shared_ptr<const ast::Expr> substituteExpr_;

public:
std::shared_ptr<ast::Expr>
postVisitNode(std::shared_ptr<ast::FieldAccessExpr> const& expr) override {
if(iir::getAccessID(expr) == accessToSubstitute_) {
return substituteExpr_->clone();
}
return expr;
}

std::shared_ptr<ast::Expr>
postVisitNode(std::shared_ptr<ast::VarAccessExpr> const& expr) override {
if(iir::getAccessID(expr) == accessToSubstitute_) {
return substituteExpr_->clone();
}
return expr;
}

AccessesReplacer(int accessToSubstitute, const std::shared_ptr<const ast::Expr> substituteExpr)
: accessToSubstitute_(accessToSubstitute), substituteExpr_(substituteExpr) {}
};

bool PassTemporaryInlining::run(
const std::shared_ptr<iir::StencilInstantiation>& stencilInstantiation,
const Options& options) {

using candidate_t =
std::variant<std::shared_ptr<ast::AssignmentExpr>, std::shared_ptr<ast::VarDeclStmt>>;

dawn::iir::ASTMatcher assignMatcher(stencilInstantiation.get());
dawn::iir::ASTMatcher varDeclMatcher(stencilInstantiation.get());
std::vector<std::shared_ptr<ast::Expr>>& assignmentsExprs =
assignMatcher.match(ast::Expr::Kind::AssignmentExpr);
std::vector<std::shared_ptr<ast::Stmt>>& varDeclStmts =
varDeclMatcher.match(ast::Stmt::Kind::VarDeclStmt);

std::unordered_map<int, int> occCount;
for(const auto& varDeclStmtIt : varDeclStmts) {
const auto& cand = std::static_pointer_cast<ast::VarDeclStmt>(varDeclStmtIt);
occCount[iir::getAccessID(cand)]++;
}
for(const auto& assignmentsExpr : assignmentsExprs) {
const auto& cand = std::static_pointer_cast<ast::BinaryOperator>(assignmentsExpr);
if(auto lhs = std::dynamic_pointer_cast<ast::FieldAccessExpr>(cand->getLeft())) {
occCount[iir::getAccessID(lhs)]++;
}
if(auto lhs = std::dynamic_pointer_cast<ast::VarAccessExpr>(cand->getLeft())) {
occCount[iir::getAccessID(lhs)]++;
}
}

std::vector<candidate_t> candidates;
for(const auto& it : assignmentsExprs) {
const auto& cand = std::static_pointer_cast<ast::BinaryOperator>(it);
if(auto lhs = std::dynamic_pointer_cast<ast::FieldAccessExpr>(cand->getLeft())) {
if(stencilInstantiation->getMetaData().isAccessType(iir::FieldAccessType::StencilTemporary,
iir::getAccessID(lhs)) &&
ast::dimension_cast<const ast::UnstructuredFieldDimension&>(
stencilInstantiation->getMetaData()
.getFieldDimensions(iir::getAccessID(lhs))
.getHorizontalFieldDimension())
.isDense() &&
occCount.at(iir::getAccessID(lhs)) == 1) {
candidates.push_back(std::static_pointer_cast<ast::AssignmentExpr>(cand));
DAWN_LOG(INFO) << stencilInstantiation->getName()
<< " inlined computation of temporary Field"
<< stencilInstantiation->getMetaData().getNameFromAccessID(
iir::getAccessID(lhs));
}
}
}
for(const auto& varDeclStmtIt : varDeclStmts) {
const auto& cand = std::static_pointer_cast<ast::VarDeclStmt>(varDeclStmtIt);
if(occCount.at(iir::getAccessID(cand)) == 1 && cand->getInitList().size() == 1) {
candidates.push_back(cand);
DAWN_LOG(INFO) << stencilInstantiation->getName() << " inlined computation of var "
<< stencilInstantiation->getMetaData().getNameFromAccessID(
iir::getAccessID(cand));
}
}

auto getFieldOrVarAccessID = [](const candidate_t cand) -> int {
if(auto assignment = std::get_if<std::shared_ptr<ast::AssignmentExpr>>(&cand)) {
auto lhs = std::static_pointer_cast<ast::AssignmentExpr>(assignment->get()->getLeft());
return iir::getAccessID(lhs);
}
if(auto varDecl = std::get_if<std::shared_ptr<ast::VarDeclStmt>>(&cand)) {
return iir::getAccessID(*varDecl);
}
dawn_unreachable("invalid candidate");
};

auto getRhs = [](const candidate_t cand) -> std::shared_ptr<ast::Expr> {
if(auto assignment = std::get_if<std::shared_ptr<ast::AssignmentExpr>>(&cand)) {
return assignment->get()->getRight();
}
if(const auto& varDecl = std::get_if<std::shared_ptr<ast::VarDeclStmt>>(&cand)) {
DAWN_ASSERT(varDecl->get()->getInitList().size() == 1);
return varDecl->get()->getInitList()[0];
}
dawn_unreachable("invalid candidate");
};

for(auto cand : candidates) {
int accessID = getFieldOrVarAccessID(cand);
AccessesReplacer replacer(accessID, getRhs(cand));

for(auto& doMethod : iterateIIROver<iir::DoMethod>(*stencilInstantiation->getIIR())) {
for(auto stmtIt = doMethod->getAST().getStatements().begin();
stmtIt != doMethod->getAST().getStatements().end();) {

if((*stmtIt)->getKind() == ast::Stmt::Kind::ExprStmt) {
const auto& exprStmt = std::static_pointer_cast<ast::ExprStmt>(*stmtIt);
if(const auto& binaryOp =
std::dynamic_pointer_cast<ast::BinaryOperator>(exprStmt->getExpr())) {
if(iir::getAccessID(binaryOp->getLeft()) == accessID) {
doMethod->getAST().erase(stmtIt);
continue;
}
}
}

if((*stmtIt)->getKind() == ast::Stmt::Kind::VarDeclStmt) {
const auto& varDeclStmt = std::static_pointer_cast<ast::VarDeclStmt>(*stmtIt);
if(iir::getAccessID(varDeclStmt) == accessID) {
doMethod->getAST().erase(stmtIt);
continue;
}
}

(*stmtIt)->acceptAndReplace(replacer);
stmtIt++;
}

computeAccesses(stencilInstantiation->getMetaData(), doMethod->getAST().getStatements());
doMethod->update(iir::NodeUpdateType::levelAndTreeAbove);
}
}

for(auto cand : candidates) {
int accessID = getFieldOrVarAccessID(cand);
stencilInstantiation->getMetaData().removeAccessID(accessID);
}

cleanUp(stencilInstantiation);

return true;
}

} // namespace dawn
Loading