diff --git a/experimental/lib/Support/FtdImplementation.cpp b/experimental/lib/Support/FtdImplementation.cpp index 1ac75bd076..a6519a7ee5 100644 --- a/experimental/lib/Support/FtdImplementation.cpp +++ b/experimental/lib/Support/FtdImplementation.cpp @@ -676,6 +676,18 @@ getLoopExitCondition(CFGLoop *loop, std::vector *cofactorList, return fLoopExit; } +// Helper: create a conversion to the desired type. Uses +// UnrealizedConversionCastOp as a generic no-op conversion placeholder. +static Value ensureType(PatternRewriter &rewriter, Location loc, Value v, + Type desiredTy) { + if (v.getType() == desiredTy) + return v; + // Use UnrealizedConversionCastOp to produce a Value of `desiredTy`. + auto cast = rewriter.create( + loc, ArrayRef{desiredTy}, v); + return cast.getResult(0); +} + void ftd::addRegenOperandConsumer(PatternRewriter &rewriter, handshake::FuncOp &funcOp, Operation *consumerOp, Value operand) { @@ -726,9 +738,9 @@ void ftd::addRegenOperandConsumer(PatternRewriter &rewriter, auto cstType = rewriter.getIntegerType(1); auto cstAttr = IntegerAttr::get(cstType, 0); - auto createRegenMux = [&](CFGLoop *loop) -> handshake::MuxOp { + auto createRegenMux = [&](CFGLoop *loop) -> Value { rewriter.setInsertionPointToStart(loop->getHeader()); - regeneratedValue.setType(channelifyType(regeneratedValue.getType())); + Location loc = consumerOp->getLoc(); // Determine the loop exit condition: // - If the condition spans multiple cofactors, build a BDD and @@ -738,50 +750,48 @@ void ftd::addRegenOperandConsumer(PatternRewriter &rewriter, std::vector cofactorList; BoolExpression *exitCondition = getLoopExitCondition(loop, &cofactorList, loopInfo, bi); - if (size(cofactorList) > 1) { + if (cofactorList.size() > 1) { BDD *bdd = buildBDD(exitCondition, cofactorList); conditionValue = bddToCircuit(rewriter, bdd, loop->getHeader(), bi, false); } else conditionValue = loop->getExitingBlock()->getTerminator()->getOperand(0); - // Create the false constant to feed `init` - auto constOp = rewriter.create(consumerOp->getLoc(), - cstAttr, startValue); + // Determine data channel type for regeneratedValue + Type dataChanTy = channelifyType(regeneratedValue.getType()); + + auto constOp = + rewriter.create(loc, cstAttr, startValue); constOp->setAttr(FTD_INIT_MERGE, rewriter.getUnitAttr()); + Value initConst = constOp.getResult(); - // Create the `init` operation - SmallVector mergeOperands = {constOp.getResult(), conditionValue}; - auto initMergeOp = rewriter.create(consumerOp->getLoc(), - mergeOperands); + // Ensure condition is in the correct form for merge operands: + Type condChanTy = channelifyType(conditionValue.getType()); + Value condChan = ensureType(rewriter, loc, conditionValue, condChanTy); + + SmallVector mergeOperands = {initConst, condChan}; + auto initMergeOp = rewriter.create(loc, mergeOperands); initMergeOp->setAttr(FTD_INIT_MERGE, rewriter.getUnitAttr()); - // The multiplexer is to be fed by the init block, and takes as inputs the - // regenerated value and the result itself (to be set after) it was created. - auto selectSignal = initMergeOp.getResult(); - selectSignal.setType(channelifyType(selectSignal.getType())); + Value selectResult = initMergeOp.getResult(); + Value chanRegVal = ensureType(rewriter, loc, regeneratedValue, dataChanTy); - SmallVector muxOperands = {regeneratedValue, regeneratedValue}; - auto muxOp = rewriter.create(regeneratedValue.getLoc(), - regeneratedValue.getType(), - selectSignal, muxOperands); + auto muxOp = rewriter.create( + loc, dataChanTy, selectResult, + SmallVector{chanRegVal, chanRegVal}); muxOp->setOperand(2, muxOp->getResult(0)); muxOp->setAttr(FTD_REGEN, rewriter.getUnitAttr()); - return muxOp; + return muxOp.getResult(); }; // For each of the loop, from the outermost to the innermost - for (unsigned i = 0; i < numberOfLoops; i++) { - - // If we are in the innermost loop (thus the iterator is at its end) - // and the consumer is a loop merge, stop + for (unsigned i = 0; i < numberOfLoops; ++i) { if (i == numberOfLoops - 1 && consumerOp->hasAttr(NEW_PHI)) break; - - auto muxOp = createRegenMux(loops[i]); - regeneratedValue = muxOp.getResult(); + Value newReg = createRegenMux(loops[i]); + regeneratedValue = newReg; } // Final replace the usage of the operand in the consumer with the output of