Skip to content

Commit ec10b65

Browse files
committed
[CombToSynth] Compute Kogge-Stone prefix tree lazily in unsigned comparison lowering
Add LazyKoggeStonePrefixTree class for on-demand computation of prefix values, use lazy evaluation in ICmpOp conversion to avoid computing all intermediate prefix values.
1 parent e589add commit ec10b65

File tree

2 files changed

+91
-14
lines changed

2 files changed

+91
-14
lines changed

integration_test/circt-synth/comb-lowering-compare.mlir

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,12 @@ hw.module @icmp_unsigned_sklanskey(in %lhs: i3, in %rhs: i3, out out_ugt: i1, ou
2525

2626
// RUN: circt-lec %t.mlir %s -c1=icmp_unsigned_kogge_stone -c2=icmp_unsigned_kogge_stone --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_ICMP_UNSIGNED_KOGGE_STONE
2727
// COMB_ICMP_UNSIGNED_KOGGE_STONE: c1 == c2
28-
hw.module @icmp_unsigned_kogge_stone(in %lhs: i3, in %rhs: i3, out out_ugt: i1, out out_uge: i1, out out_ult: i1, out out_ule: i1) {
29-
%ugt = comb.icmp ugt %lhs, %rhs {synth.test.arch = "KOGGE-STONE"} : i3
30-
%uge = comb.icmp uge %lhs, %rhs {synth.test.arch = "KOGGE-STONE"} : i3
31-
%ult = comb.icmp ult %lhs, %rhs {synth.test.arch = "KOGGE-STONE"} : i3
32-
%ule = comb.icmp ule %lhs, %rhs {synth.test.arch = "KOGGE-STONE"} : i3
28+
// Use slightly larger width to verify the lazy prefix tree logic
29+
hw.module @icmp_unsigned_kogge_stone(in %lhs: i14, in %rhs: i14, out out_ugt: i1, out out_uge: i1, out out_ult: i1, out out_ule: i1) {
30+
%ugt = comb.icmp ugt %lhs, %rhs {synth.test.arch = "KOGGE-STONE"} : i14
31+
%uge = comb.icmp uge %lhs, %rhs {synth.test.arch = "KOGGE-STONE"} : i14
32+
%ult = comb.icmp ult %lhs, %rhs {synth.test.arch = "KOGGE-STONE"} : i14
33+
%ule = comb.icmp ule %lhs, %rhs {synth.test.arch = "KOGGE-STONE"} : i14
3334
hw.output %ugt, %uge, %ult, %ule : i1, i1, i1, i1
3435
}
3536

lib/Conversion/CombToSynth/CombToSynth.cpp

Lines changed: 85 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,7 @@ void lowerBrentKungPrefixTree(OpBuilder &builder, Location loc,
633633

634634
// Group generate: g_i OR (p_i AND g_j)
635635
Value andPG = comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
636+
comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
636637
gPrefixNew[i] = comb::OrOp::create(builder, loc, gPrefix[i], andPG);
637638

638639
// Group propagate: p_i AND p_j
@@ -648,8 +649,10 @@ void lowerBrentKungPrefixTree(OpBuilder &builder, Location loc,
648649
int64_t j = i - stride;
649650

650651
// Group generate: g_i OR (p_i AND g_j)
651-
Value andPG = comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
652-
gPrefixNew[i] = comb::OrOp::create(builder, loc, gPrefix[i], andPG);
652+
Value propagateAndGenerate =
653+
comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
654+
gPrefixNew[i] =
655+
comb::OrOp::create(builder, loc, gPrefix[i], propagateAndGenerate);
653656

654657
// Group propagate: p_i AND p_j
655658
pPrefixNew[i] = comb::AndOp::create(builder, loc, pPrefix[i], pPrefix[j]);
@@ -700,6 +703,67 @@ void lowerBrentKungPrefixTree(OpBuilder &builder, Location loc,
700703
});
701704
}
702705

706+
// TODO: Generalize to other parallel prefix trees.
707+
class LazyKoggeStonePrefixTree {
708+
public:
709+
LazyKoggeStonePrefixTree(OpBuilder &builder, Location loc, int64_t width,
710+
ArrayRef<Value> pPrefix, ArrayRef<Value> gPrefix)
711+
: builder(builder), loc(loc), width(width) {
712+
assert(width > 0 && "width must be positive");
713+
for (size_t i = 0; i < static_cast<size_t>(width); ++i)
714+
prefixCache[{0, i}] = {pPrefix[i], gPrefix[i]};
715+
}
716+
717+
// Get the final group and propagate values for bit i.
718+
std::pair<Value, Value> getFinal(int64_t i) {
719+
assert(i >= 0 && i < width && "i out of bounds");
720+
// Final level is ceil(log2(width)) in Kogge-Stone.
721+
return getGroupAndPropagate(llvm::Log2_64_Ceil(width), i);
722+
}
723+
724+
private:
725+
// Recursively get the group and propagate values for bit i at level `level`.
726+
// Level 0 is the initial level with the input propagate and generate values.
727+
// Level n computes the group and propagate values for a stride of 2^(n-1).
728+
// Uses memoization to cache intermediate results.
729+
std::pair<Value, Value> getGroupAndPropagate(int64_t level, int64_t i);
730+
OpBuilder &builder;
731+
Location loc;
732+
int64_t width;
733+
DenseMap<std::pair<int64_t, int64_t>, std::pair<Value, Value>> prefixCache;
734+
};
735+
736+
std::pair<Value, Value>
737+
LazyKoggeStonePrefixTree::getGroupAndPropagate(int64_t level, int64_t i) {
738+
assert(i < static_cast<int64_t>(width) && "i out of bounds");
739+
auto key = std::make_pair(level, i);
740+
auto it = prefixCache.find(key);
741+
if (it != prefixCache.end())
742+
return it->second;
743+
744+
assert(level > 0 && "level must be positive");
745+
746+
int64_t previousStride = 1ULL << (level - 1);
747+
if (i < previousStride) {
748+
// No dependency, just copy from the previous level.
749+
auto [propagateI, generateI] = getGroupAndPropagate(level - 1, i);
750+
prefixCache[key] = {propagateI, generateI};
751+
return prefixCache[key];
752+
}
753+
// Get the dependency index.
754+
int64_t j = i - previousStride;
755+
auto [propagateI, generateI] = getGroupAndPropagate(level - 1, i);
756+
auto [propagateJ, generateJ] = getGroupAndPropagate(level - 1, j);
757+
// Group generate: g_i OR (p_i AND g_j)
758+
Value andPG = comb::AndOp::create(builder, loc, propagateI, generateJ);
759+
Value newGenerate = comb::OrOp::create(builder, loc, generateI, andPG);
760+
// Group propagate: p_i AND p_j
761+
Value newPropagate =
762+
comb::AndOp::create(builder, loc, propagateI, propagateJ);
763+
prefixCache[key] = {newPropagate, newGenerate};
764+
return prefixCache[key];
765+
}
766+
703767
template <bool lowerToMIG>
704768
struct CombAddOpConversion : OpConversionPattern<AddOp> {
705769
using OpConversionPattern<AddOp>::OpConversionPattern;
@@ -1080,37 +1144,49 @@ struct CombICmpOpConversion : OpConversionPattern<ICmpOp> {
10801144
// need the final result. Optimizing this to skip intermediate computations
10811145
// is non-trivial because each iteration depends on results from previous
10821146
// iterations. We rely on DCE passes to remove unused operations.
1083-
// TODO: Lazily compute only the required prefix values.
1147+
// TODO: Lazily compute only the required prefix values. Kogge-Stone is
1148+
// already implemented in a lazy manner below, but other architectures can
1149+
// also be optimized.
10841150
static Value computePrefixComparison(ConversionPatternRewriter &rewriter,
10851151
Location loc, SmallVector<Value> pPrefix,
10861152
SmallVector<Value> gPrefix,
10871153
bool includeEq, AdderArchitecture arch) {
10881154
auto width = pPrefix.size();
1155+
Value finalGroup, finalPropagate;
10891156
// Apply the appropriate prefix tree algorithm
10901157
switch (arch) {
10911158
case AdderArchitecture::RippleCarry:
10921159
llvm_unreachable("Ripple-Carry should be handled separately");
10931160
break;
1094-
case AdderArchitecture::Sklanskey:
1161+
case AdderArchitecture::Sklanskey: {
10951162
lowerSklanskeyPrefixTree(rewriter, loc, pPrefix, gPrefix);
1163+
finalGroup = gPrefix[width - 1];
1164+
finalPropagate = pPrefix[width - 1];
10961165
break;
1166+
}
10971167
case AdderArchitecture::KoggeStone:
1098-
lowerKoggeStonePrefixTree(rewriter, loc, pPrefix, gPrefix);
1168+
// Use lazy Kogge-Stone implementation to avoid computing all
1169+
// intermediate prefix values.
1170+
std::tie(finalPropagate, finalGroup) =
1171+
LazyKoggeStonePrefixTree(rewriter, loc, width, pPrefix, gPrefix)
1172+
.getFinal(width - 1);
10991173
break;
1100-
case AdderArchitecture::BrentKung:
1174+
case AdderArchitecture::BrentKung: {
11011175
lowerBrentKungPrefixTree(rewriter, loc, pPrefix, gPrefix);
1176+
finalGroup = gPrefix[width - 1];
1177+
finalPropagate = pPrefix[width - 1];
11021178
break;
11031179
}
1180+
}
11041181

11051182
// Final result: gPrefix[width-1] gives us "a < b"
11061183
if (includeEq) {
11071184
// a <= b iff (a < b) OR (a == b)
11081185
// a == b iff pPrefix[width-1] (all bits are equal)
1109-
return comb::OrOp::create(rewriter, loc, gPrefix[width - 1],
1110-
pPrefix[width - 1]);
1186+
return comb::OrOp::create(rewriter, loc, finalGroup, finalPropagate);
11111187
}
11121188
// a < b iff gPrefix[width-1]
1113-
return gPrefix[width - 1];
1189+
return finalGroup;
11141190
}
11151191

11161192
// Construct an unsigned comparator using either ripple-carry or

0 commit comments

Comments
 (0)