Skip to content

Commit 5e8d4ef

Browse files
committed
[flang] loop interchange from kruse
1 parent a517d8c commit 5e8d4ef

35 files changed

+1362
-70
lines changed

flang/include/flang/Lower/AbstractConverter.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,10 @@ class AbstractConverter {
401401

402402
virtual mlir::StateStack &getStateStack() = 0;
403403

404+
virtual void
405+
genPermutatedLoops(llvm::ArrayRef<Fortran::lower::pft::Evaluation *> doStmts,
406+
Fortran::lower::pft::Evaluation *innermostDo) = 0;
407+
404408
private:
405409
/// Options controlling lowering behavior.
406410
const Fortran::lower::LoweringOptions &loweringOptions;

flang/include/flang/Lower/OpenMP/Utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,11 @@ void collectTileSizesFromOpenMPConstruct(
187187
llvm::SmallVectorImpl<int64_t> &tileSizes,
188188
Fortran::semantics::SemanticsContext &semaCtx);
189189

190+
void collectPermutationFromOpenMPConstruct(
191+
const parser::OpenMPConstruct *ompCons,
192+
llvm::SmallVectorImpl<int64_t> &permuation,
193+
Fortran::semantics::SemanticsContext &semaCtx);
194+
190195
} // namespace omp
191196
} // namespace lower
192197
} // namespace Fortran

flang/lib/Lower/Allocatable.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
#include "flang/Runtime/pointer.h"
3838
#include "flang/Semantics/tools.h"
3939
#include "flang/Semantics/type.h"
40+
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
4041
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
4142
#include "llvm/Support/CommandLine.h"
4243

flang/lib/Lower/Bridge.cpp

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2268,6 +2268,124 @@ class FirConverter : public Fortran::lower::AbstractConverter {
22682268
// so no clean-up needs to be generated for these entities.
22692269
}
22702270

2271+
void
2272+
genPermutatedLoops(llvm::ArrayRef<Fortran::lower::pft::Evaluation *> doStmts,
2273+
Fortran::lower::pft::Evaluation *innermostDo) override {
2274+
// Fortran::lower::pft::Evaluation &eval = getEval();
2275+
// bool unstructuredContext = eval.lowerAsUnstructured();
2276+
2277+
llvm::SmallVector<mlir::Block *> headerBlocks;
2278+
llvm::SmallVector<IncrementLoopNestInfo, 1> loopInfos;
2279+
2280+
auto enterLoop = [&](Fortran::lower::pft::Evaluation &eval) {
2281+
bool unstructuredContext = eval.lowerAsUnstructured();
2282+
2283+
// Collect loop nest information.
2284+
// Generate begin loop code directly for infinite and while loops.
2285+
Fortran::lower::pft::Evaluation &doStmtEval =
2286+
eval.getFirstNestedEvaluation();
2287+
auto *doStmt = doStmtEval.getIf<Fortran::parser::NonLabelDoStmt>();
2288+
const auto &loopControl =
2289+
std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t);
2290+
mlir::Block *preheaderBlock = doStmtEval.block;
2291+
mlir::Block *beginBlock =
2292+
preheaderBlock ? preheaderBlock : builder->getBlock();
2293+
auto createNextBeginBlock = [&]() {
2294+
// Step beginBlock through unstructured preheader, header, and mask
2295+
// blocks, created in outermost to innermost order.
2296+
return beginBlock = beginBlock->splitBlock(beginBlock->end());
2297+
};
2298+
mlir::Block *headerBlock =
2299+
unstructuredContext ? createNextBeginBlock() : nullptr;
2300+
headerBlocks.push_back(headerBlock);
2301+
mlir::Block *bodyBlock = doStmtEval.lexicalSuccessor->block;
2302+
mlir::Block *exitBlock = doStmtEval.parentConstruct->constructExit->block;
2303+
IncrementLoopNestInfo &incrementLoopNestInfo = loopInfos.emplace_back();
2304+
const Fortran::parser::ScalarLogicalExpr *whileCondition = nullptr;
2305+
bool infiniteLoop = !loopControl.has_value();
2306+
if (infiniteLoop) {
2307+
assert(unstructuredContext && "infinite loop must be unstructured");
2308+
startBlock(headerBlock);
2309+
} else if ((whileCondition =
2310+
std::get_if<Fortran::parser::ScalarLogicalExpr>(
2311+
&loopControl->u))) {
2312+
assert(unstructuredContext && "while loop must be unstructured");
2313+
maybeStartBlock(preheaderBlock); // no block or empty block
2314+
startBlock(headerBlock);
2315+
genConditionalBranch(*whileCondition, bodyBlock, exitBlock);
2316+
} else if (const auto *bounds =
2317+
std::get_if<Fortran::parser::LoopControl::Bounds>(
2318+
&loopControl->u)) {
2319+
// Non-concurrent increment loop.
2320+
IncrementLoopInfo &info = incrementLoopNestInfo.emplace_back(
2321+
*bounds->name.thing.symbol, bounds->lower, bounds->upper,
2322+
bounds->step);
2323+
if (unstructuredContext) {
2324+
maybeStartBlock(preheaderBlock);
2325+
info.hasRealControl = info.loopVariableSym->GetType()->IsNumeric(
2326+
Fortran::common::TypeCategory::Real);
2327+
info.headerBlock = headerBlock;
2328+
info.bodyBlock = bodyBlock;
2329+
info.exitBlock = exitBlock;
2330+
}
2331+
} else {
2332+
llvm_unreachable("Cannot permute DO CONCURRENT");
2333+
}
2334+
2335+
// Increment loop begin code. (Infinite/while code was already generated.)
2336+
if (!infiniteLoop && !whileCondition)
2337+
genFIRIncrementLoopBegin(incrementLoopNestInfo, doStmtEval.dirs);
2338+
};
2339+
2340+
auto leaveLoop = [&](Fortran::lower::pft::Evaluation &eval,
2341+
mlir::Block *headerBlock,
2342+
IncrementLoopNestInfo &incrementLoopNestInfo) {
2343+
bool unstructuredContext = eval.lowerAsUnstructured();
2344+
2345+
Fortran::lower::pft::Evaluation &doStmtEval =
2346+
eval.getFirstNestedEvaluation();
2347+
auto *doStmt = doStmtEval.getIf<Fortran::parser::NonLabelDoStmt>();
2348+
2349+
const auto &loopControl =
2350+
std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t);
2351+
bool infiniteLoop = !loopControl.has_value();
2352+
const Fortran::parser::ScalarLogicalExpr *whileCondition =
2353+
std::get_if<Fortran::parser::ScalarLogicalExpr>(&loopControl->u);
2354+
2355+
auto iter = std::prev(eval.getNestedEvaluations().end());
2356+
2357+
// An EndDoStmt in unstructured code may start a new block.
2358+
Fortran::lower::pft::Evaluation &endDoEval = *iter;
2359+
assert(endDoEval.getIf<Fortran::parser::EndDoStmt>() && "no enddo stmt");
2360+
if (unstructuredContext)
2361+
maybeStartBlock(endDoEval.block);
2362+
2363+
// Loop end code.
2364+
if (infiniteLoop || whileCondition)
2365+
genBranch(headerBlock);
2366+
else
2367+
genFIRIncrementLoopEnd(incrementLoopNestInfo);
2368+
2369+
// This call may generate a branch in some contexts.
2370+
genFIR(endDoEval, unstructuredContext);
2371+
};
2372+
2373+
for (auto l : doStmts)
2374+
enterLoop(*l);
2375+
2376+
// Loop body code.
2377+
bool innermostUnstructuredContext = innermostDo->lowerAsUnstructured();
2378+
2379+
auto iter = innermostDo->getNestedEvaluations().begin();
2380+
for (auto end = --innermostDo->getNestedEvaluations().end(); iter != end;
2381+
++iter)
2382+
genFIR(*iter, innermostUnstructuredContext);
2383+
2384+
for (auto &&[l, headerBlock, li] :
2385+
llvm::zip_equal(doStmts, headerBlocks, loopInfos))
2386+
leaveLoop(*l, headerBlock, li);
2387+
}
2388+
22712389
void attachInlineAttributes(
22722390
mlir::Operation &op,
22732391
const llvm::ArrayRef<const Fortran::parser::CompilerDirective *> &dirs) {

flang/lib/Lower/OpenMP/Decomposer.cpp

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,49 @@ ConstructQueue buildConstructQueue(
9898
return decompose.output;
9999
}
100100

101+
// from clang
102+
// There is a copy in check-omp-loops.cpp
103+
static bool isOpenMPLoopTransformationDirective(llvm::omp::Directive dir) {
104+
switch (dir) {
105+
// TODO case llvm::omp::Directive::OMPD_flatten:
106+
case llvm::omp::Directive::OMPD_fuse:
107+
case llvm::omp::Directive::OMPD_interchange:
108+
// case llvm::omp::Directive::OMPD_nothing:
109+
case llvm::omp::Directive::OMPD_reverse:
110+
// TODO case llvm::omp::Directive::OMPD_split:
111+
case llvm::omp::Directive::OMPD_stripe:
112+
case llvm::omp::Directive::OMPD_tile:
113+
case llvm::omp::Directive::OMPD_unroll:
114+
return true;
115+
default:
116+
return false;
117+
}
118+
}
119+
120+
llvm::iterator_range<ConstructQueue::const_iterator> getNonTransformQueue(
121+
llvm::iterator_range<ConstructQueue::const_iterator> range) {
122+
// remove trailing loop transformations
123+
auto b = range.begin();
124+
auto e = range.end();
125+
while (e != b) {
126+
auto e2 = e - 1;
127+
if (!isOpenMPLoopTransformationDirective(e2->id))
128+
break;
129+
e = e2;
130+
}
131+
132+
return llvm::make_range(b, e);
133+
}
134+
101135
bool matchLeafSequence(ConstructQueue::const_iterator item,
102136
const ConstructQueue &queue,
103137
llvm::omp::Directive directive) {
104138
llvm::ArrayRef<llvm::omp::Directive> leafDirs =
105139
llvm::omp::getLeafConstructsOrSelf(directive);
106140

107-
for (auto [dir, leaf] :
108-
llvm::zip_longest(leafDirs, llvm::make_range(item, queue.end()))) {
141+
for (auto [dir, leaf] : llvm::zip_longest(
142+
leafDirs,
143+
getNonTransformQueue(llvm::make_range(item, queue.end())))) {
109144
if (!dir.has_value() || !leaf.has_value())
110145
return false;
111146

flang/lib/Lower/OpenMP/Decomposer.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ bool isLastItemInQueue(ConstructQueue::const_iterator item,
5757
bool matchLeafSequence(ConstructQueue::const_iterator item,
5858
const ConstructQueue &queue,
5959
llvm::omp::Directive directive);
60+
61+
llvm::iterator_range<ConstructQueue::const_iterator> getNonTransformQueue(
62+
llvm::iterator_range<ConstructQueue::const_iterator> range);
63+
6064
} // namespace Fortran::lower::omp
6165

6266
#endif // FORTRAN_LOWER_OPENMP_DECOMPOSER_H

0 commit comments

Comments
 (0)