@@ -633,6 +633,7 @@ void lowerBrentKungPrefixTree(OpBuilder &builder, Location loc,
633633
634634 // Group generate: g_i OR (p_i AND g_j)
635635 Value andPG = comb::AndOp::create (builder, loc, pPrefix[i], gPrefix [j]);
636+ comb::AndOp::create (builder, loc, pPrefix[i], gPrefix [j]);
636637 gPrefixNew [i] = comb::OrOp::create (builder, loc, gPrefix [i], andPG);
637638
638639 // Group propagate: p_i AND p_j
@@ -648,8 +649,10 @@ void lowerBrentKungPrefixTree(OpBuilder &builder, Location loc,
648649 int64_t j = i - stride;
649650
650651 // Group generate: g_i OR (p_i AND g_j)
651- Value andPG = comb::AndOp::create (builder, loc, pPrefix[i], gPrefix [j]);
652- gPrefixNew [i] = comb::OrOp::create (builder, loc, gPrefix [i], andPG);
652+ Value propagateAndGenerate =
653+ comb::AndOp::create (builder, loc, pPrefix[i], gPrefix [j]);
654+ gPrefixNew [i] =
655+ comb::OrOp::create (builder, loc, gPrefix [i], propagateAndGenerate);
653656
654657 // Group propagate: p_i AND p_j
655658 pPrefixNew[i] = comb::AndOp::create (builder, loc, pPrefix[i], pPrefix[j]);
@@ -700,6 +703,67 @@ void lowerBrentKungPrefixTree(OpBuilder &builder, Location loc,
700703 });
701704}
702705
706+ // TODO: Generalize to other parallel prefix trees.
707+ class LazyKoggeStonePrefixTree {
708+ public:
709+ LazyKoggeStonePrefixTree (OpBuilder &builder, Location loc, int64_t width,
710+ ArrayRef<Value> pPrefix, ArrayRef<Value> gPrefix )
711+ : builder(builder), loc(loc), width(width) {
712+ assert (width > 0 && " width must be positive" );
713+ for (size_t i = 0 ; i < static_cast <size_t >(width); ++i)
714+ prefixCache[{0 , i}] = {pPrefix[i], gPrefix [i]};
715+ }
716+
717+ // Get the final group and propagate values for bit i.
718+ std::pair<Value, Value> getFinal (int64_t i) {
719+ assert (i >= 0 && i < width && " i out of bounds" );
720+ // Final level is ceil(log2(width)) in Kogge-Stone.
721+ return getGroupAndPropagate (llvm::Log2_64_Ceil (width), i);
722+ }
723+
724+ private:
725+ // Recursively get the group and propagate values for bit i at level `level`.
726+ // Level 0 is the initial level with the input propagate and generate values.
727+ // Level n computes the group and propagate values for a stride of 2^(n-1).
728+ // Uses memoization to cache intermediate results.
729+ std::pair<Value, Value> getGroupAndPropagate (int64_t level, int64_t i);
730+ OpBuilder &builder;
731+ Location loc;
732+ int64_t width;
733+ DenseMap<std::pair<int64_t , int64_t >, std::pair<Value, Value>> prefixCache;
734+ };
735+
736+ std::pair<Value, Value>
737+ LazyKoggeStonePrefixTree::getGroupAndPropagate (int64_t level, int64_t i) {
738+ assert (i < static_cast <int64_t >(width) && " i out of bounds" );
739+ auto key = std::make_pair (level, i);
740+ auto it = prefixCache.find (key);
741+ if (it != prefixCache.end ())
742+ return it->second ;
743+
744+ assert (level > 0 && " level must be positive" );
745+
746+ int64_t previousStride = 1ULL << (level - 1 );
747+ if (i < previousStride) {
748+ // No dependency, just copy from the previous level.
749+ auto [propagateI, generateI] = getGroupAndPropagate (level - 1 , i);
750+ prefixCache[key] = {propagateI, generateI};
751+ return prefixCache[key];
752+ }
753+ // Get the dependency index.
754+ int64_t j = i - previousStride;
755+ auto [propagateI, generateI] = getGroupAndPropagate (level - 1 , i);
756+ auto [propagateJ, generateJ] = getGroupAndPropagate (level - 1 , j);
757+ // Group generate: g_i OR (p_i AND g_j)
758+ Value andPG = comb::AndOp::create (builder, loc, propagateI, generateJ);
759+ Value newGenerate = comb::OrOp::create (builder, loc, generateI, andPG);
760+ // Group propagate: p_i AND p_j
761+ Value newPropagate =
762+ comb::AndOp::create (builder, loc, propagateI, propagateJ);
763+ prefixCache[key] = {newPropagate, newGenerate};
764+ return prefixCache[key];
765+ }
766+
703767template <bool lowerToMIG>
704768struct CombAddOpConversion : OpConversionPattern<AddOp> {
705769 using OpConversionPattern<AddOp>::OpConversionPattern;
@@ -1080,37 +1144,49 @@ struct CombICmpOpConversion : OpConversionPattern<ICmpOp> {
10801144 // need the final result. Optimizing this to skip intermediate computations
10811145 // is non-trivial because each iteration depends on results from previous
10821146 // iterations. We rely on DCE passes to remove unused operations.
1083- // TODO: Lazily compute only the required prefix values.
1147+ // TODO: Lazily compute only the required prefix values. Kogge-Stone is
1148+ // already implemented in a lazy manner below, but other architectures can
1149+ // also be optimized.
10841150 static Value computePrefixComparison (ConversionPatternRewriter &rewriter,
10851151 Location loc, SmallVector<Value> pPrefix,
10861152 SmallVector<Value> gPrefix ,
10871153 bool includeEq, AdderArchitecture arch) {
10881154 auto width = pPrefix.size ();
1155+ Value finalGroup, finalPropagate;
10891156 // Apply the appropriate prefix tree algorithm
10901157 switch (arch) {
10911158 case AdderArchitecture::RippleCarry:
10921159 llvm_unreachable (" Ripple-Carry should be handled separately" );
10931160 break ;
1094- case AdderArchitecture::Sklanskey:
1161+ case AdderArchitecture::Sklanskey: {
10951162 lowerSklanskeyPrefixTree (rewriter, loc, pPrefix, gPrefix );
1163+ finalGroup = gPrefix [width - 1 ];
1164+ finalPropagate = pPrefix[width - 1 ];
10961165 break ;
1166+ }
10971167 case AdderArchitecture::KoggeStone:
1098- lowerKoggeStonePrefixTree (rewriter, loc, pPrefix, gPrefix );
1168+ // Use lazy Kogge-Stone implementation to avoid computing all
1169+ // intermediate prefix values.
1170+ std::tie (finalPropagate, finalGroup) =
1171+ LazyKoggeStonePrefixTree (rewriter, loc, width, pPrefix, gPrefix )
1172+ .getFinal (width - 1 );
10991173 break ;
1100- case AdderArchitecture::BrentKung:
1174+ case AdderArchitecture::BrentKung: {
11011175 lowerBrentKungPrefixTree (rewriter, loc, pPrefix, gPrefix );
1176+ finalGroup = gPrefix [width - 1 ];
1177+ finalPropagate = pPrefix[width - 1 ];
11021178 break ;
11031179 }
1180+ }
11041181
11051182 // Final result: gPrefix[width-1] gives us "a < b"
11061183 if (includeEq) {
11071184 // a <= b iff (a < b) OR (a == b)
11081185 // a == b iff pPrefix[width-1] (all bits are equal)
1109- return comb::OrOp::create (rewriter, loc, gPrefix [width - 1 ],
1110- pPrefix[width - 1 ]);
1186+ return comb::OrOp::create (rewriter, loc, finalGroup, finalPropagate);
11111187 }
11121188 // a < b iff gPrefix[width-1]
1113- return gPrefix [width - 1 ] ;
1189+ return finalGroup ;
11141190 }
11151191
11161192 // Construct an unsigned comparator using either ripple-carry or
0 commit comments