Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion xls/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -842,8 +842,8 @@ xls_pass(
hdrs = ["strength_reduction_pass.h"],
pass_class = "StrengthReductionPass",
deps = [
":lazy_ternary_query_engine",
":optimization_pass",
":partial_info_query_engine",
":pass_base",
":query_engine",
":stateless_query_engine",
Expand Down
68 changes: 50 additions & 18 deletions xls/passes/strength_reduction_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
#include "xls/ir/op.h"
#include "xls/ir/ternary.h"
#include "xls/ir/value.h"
#include "xls/passes/lazy_ternary_query_engine.h"
#include "xls/passes/optimization_pass.h"
#include "xls/passes/partial_info_query_engine.h"
#include "xls/passes/pass_base.h"
#include "xls/passes/query_engine.h"
#include "xls/passes/stateless_query_engine.h"
Expand All @@ -64,21 +64,57 @@ absl::StatusOr<bool> MaybeSinkOperationIntoSelect(
std::distance(operands.begin(), absl::c_find(operands, select_val));
XLS_RET_CHECK_NE(argument_idx, operands.size())
<< select_val->ToString() << " is not an argument of " << node;
// We need both an unknown select and all other operands to be fully known.
// We need both an unknown select and all other operands to be fully known (or
// for the operation to otherwise simplify through constant-folding).
bool non_select_operands_are_constant = absl::c_all_of(
operands,
[&](Node* n) { return n == select_val || query_engine.IsFullyKnown(n); });
// We don't want to make the select mux wider unless we are pretty sure that
// benefit is worth it.
static constexpr std::array<Op, 14> kExpensiveOps{
Op::kAdd, Op::kSub, Op::kSDiv, Op::kUDiv, Op::kUMod, Op::kSMod, Op::kUMul,
Op::kSMul, Op::kUMulp, Op::kSMulp, Op::kShll, Op::kShra, Op::kShrl,
// Encode of a non-constant is quite slow.
Op::kEncode};
bool sunk_operation_would_simplify = non_select_operands_are_constant;
switch (node->op()) {
case Op::kArrayIndex:
sunk_operation_would_simplify |=
node->As<ArrayIndex>()->indices().size() == 1 &&
node->As<ArrayIndex>()->indices()[0] == select_val;
break;
case Op::kArraySlice:
sunk_operation_would_simplify |=
node->As<ArraySlice>()->start() == select_val;
break;
case Op::kBitSliceUpdate:
sunk_operation_would_simplify |=
node->As<BitSliceUpdate>()->start() == select_val;
break;
case Op::kDynamicBitSlice:
sunk_operation_would_simplify |=
node->As<DynamicBitSlice>()->start() == select_val;
break;
case Op::kSel:
sunk_operation_would_simplify |=
node->As<Select>()->selector() == select_val;
break;
case Op::kPrioritySel:
sunk_operation_would_simplify |=
node->As<PrioritySelect>()->selector() == select_val;
break;
case Op::kShll:
case Op::kShra:
case Op::kShrl: {
// All shifts simplify to wires if the shift amount is known.
sunk_operation_would_simplify |= node->operand(1) == select_val;
break;
}
default:
break;
}
// If there are no significant savings, we don't want to make the select mux
// wider.
static constexpr std::array<Op, 6> kCheapWideningOps{
Op::kArray, Op::kArrayConcat, Op::kConcat,
Op::kSignExt, Op::kTuple, Op::kZeroExt};
bool sink_would_improve_ir = node->GetType()->GetFlatBitCount() <=
select_val->GetType()->GetFlatBitCount() ||
node->OpIn(kExpensiveOps);
if (non_select_operands_are_constant && sink_would_improve_ir) {
!node->OpIn(kCheapWideningOps);
if (sunk_operation_would_simplify && sink_would_improve_ir) {
std::vector<Node*> new_cases;
new_cases.reserve(select_val->cases().size());
std::optional<Node*> new_default;
Expand Down Expand Up @@ -111,12 +147,6 @@ absl::StatusOr<bool> MaybeSinkOperationIntoSelect(
absl::StatusOr<bool> StrengthReduceNode(Node* node,
const QueryEngine& query_engine,
int64_t opt_level) {
if (!std::all_of(node->operands().begin(), node->operands().end(),
[](Node* n) { return n->GetType()->IsBits(); }) ||
!node->GetType()->IsBits()) {
return false;
}

if (NarrowingEnabled(opt_level) &&
// Don't replace unused nodes. We don't want to add nodes when they will
// get DCEd later. This can lead to an infinite loop between strength
Expand Down Expand Up @@ -372,6 +402,7 @@ absl::StatusOr<bool> StrengthReduceNode(Node* node,
// 0 | 0 1
// x 1 | 1 0
if ((node->op() == Op::kAdd || node->op() == Op::kNe) &&
node->operand(0)->GetType()->IsBits() &&
node->operand(0)->BitCountOrDie() == 1) {
XLS_RETURN_IF_ERROR(
node->ReplaceUsesWithNew<NaryOp>(
Expand Down Expand Up @@ -419,6 +450,7 @@ absl::StatusOr<bool> StrengthReduceNode(Node* node,
// Eq(x, 0b00) => x_0 == 0 & x_1 == 0 => ~x_0 & ~x_1 => ~(x_0 | x_1)
// where bits(x) <= 2
if (NarrowingEnabled(opt_level) && node->op() == Op::kEq &&
node->operand(0)->GetType()->IsBits() &&
node->operand(0)->BitCountOrDie() == 2 &&
query_engine.IsAllZeros(node->operand(1))) {
FunctionBase* f = node->function_base();
Expand Down Expand Up @@ -743,7 +775,7 @@ absl::StatusOr<bool> StrengthReductionPass::RunOnFunctionBaseInternal(
PassResults* results, OptimizationContext& context) const {
auto query_engine = UnionQueryEngine::Of(
StatelessQueryEngine(),
GetSharedQueryEngine<LazyTernaryQueryEngine>(context, f));
GetSharedQueryEngine<PartialInfoQueryEngine>(context, f));

XLS_RETURN_IF_ERROR(query_engine.Populate(f).status());

Expand Down
90 changes: 89 additions & 1 deletion xls/passes/strength_reduction_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ using ::absl_testing::IsOkAndHolds;
using ::xls::solvers::z3::ScopedVerifyEquivalence;

using ::testing::_;
using ::testing::AnyOf;
using ::testing::Each;
using ::testing::UnorderedElementsAre;
using ::testing::VariantWith;
Expand Down Expand Up @@ -428,7 +429,8 @@ TEST_F(StrengthReductionPassTest, ArithToSelect) {
ASSERT_THAT(Run(f), IsOkAndHolds(true));
// Actual verification of result is done by semantics test.
EXPECT_THAT(f->return_value()->operands(),
Each(m::Select(m::Eq(), {m::Literal(), m::Literal()})));
Each(AnyOf(m::Literal(),
m::Select(m::Eq(), {m::Literal(), m::Literal()}))));
}

TEST_F(StrengthReductionPassTest, ArithToSelectOnlyWithOneBit) {
Expand Down Expand Up @@ -936,6 +938,92 @@ TEST_F(StrengthReductionPassTest, SubNoSplitIfBorrowMustPropagate) {
ASSERT_THAT(Run(f), IsOkAndHolds(false));
}

// A dynamic_bit_slice whose start operand is a select between two literals
// should be rewritten as a select (on the same selector) between two slices,
// each with a literal start.
TEST_F(StrengthReductionPassTest, SinkDynamicBitSliceWithSelectStart) {
  auto package = CreatePackage();
  FunctionBuilder builder(TestName(), package.get());

  // Inputs: the 32-bit value being sliced and the 1-bit selector.
  BValue data = builder.Param("data", package->GetBitsType(32));
  BValue s = builder.Param("s", package->GetBitsType(1));

  // The slice start is chosen between the literals 2 and 4.
  BValue first_start = builder.Literal(UBits(2, 32));
  BValue second_start = builder.Literal(UBits(4, 32));
  BValue start = builder.Select(s, {first_start, second_start});

  // Last node added; this becomes the function's return value.
  builder.DynamicBitSlice(data, start, 16);
  XLS_ASSERT_OK_AND_ASSIGN(Function * f, builder.Build());

  ScopedRecordIr sri(package.get());
  ScopedVerifyEquivalence sve(f);
  ASSERT_THAT(Run(f), IsOkAndHolds(true));

  // Expected shape: sel(s, {slice(data, 2), slice(data, 4)}).
  auto slice_from_2 =
      m::DynamicBitSlice(m::Param("data"), m::Literal(2), /*width=*/16);
  auto slice_from_4 =
      m::DynamicBitSlice(m::Param("data"), m::Literal(4), /*width=*/16);
  EXPECT_THAT(f->return_value(),
              m::Select(m::Param("s"), {slice_from_2, slice_from_4}));
}

// A left shift whose shift amount is a select between two literals should be
// rewritten as a select (on the same selector) between two shifts, each by a
// literal amount.
TEST_F(StrengthReductionPassTest, SinkShiftWithSelectShiftAmount) {
  auto package = CreatePackage();
  FunctionBuilder builder(TestName(), package.get());

  // Inputs: the 32-bit value being shifted and the 1-bit selector.
  BValue data = builder.Param("data", package->GetBitsType(32));
  BValue s = builder.Param("s", package->GetBitsType(1));

  // The shift amount is chosen between the literals 2 and 4.
  BValue first_amount = builder.Literal(UBits(2, 32));
  BValue second_amount = builder.Literal(UBits(4, 32));
  BValue shift_amt = builder.Select(s, {first_amount, second_amount});

  // Last node added; this becomes the function's return value.
  builder.Shll(data, shift_amt);
  XLS_ASSERT_OK_AND_ASSIGN(Function * f, builder.Build());

  ScopedRecordIr sri(package.get());
  ScopedVerifyEquivalence sve(f);
  ASSERT_THAT(Run(f), IsOkAndHolds(true));

  // Expected shape: sel(s, {shll(data, 2), shll(data, 4)}).
  auto shift_by_2 = m::Shll(m::Param("data"), m::Literal(2));
  auto shift_by_4 = m::Shll(m::Param("data"), m::Literal(4));
  EXPECT_THAT(f->return_value(),
              m::Select(m::Param("s"), {shift_by_2, shift_by_4}));
}

// A one-dimensional array_index whose index is a select between two literals
// should be rewritten as a select (on the same selector) between two
// array_index operations, each with a literal index.
TEST_F(StrengthReductionPassTest, SinkArrayIndexWithSelectIndex) {
  auto package = CreatePackage();
  FunctionBuilder builder(TestName(), package.get());

  // Inputs: an 8-element array of 32-bit values and the 1-bit selector.
  Type* element_type = package->GetBitsType(32);
  Type* array_type = package->GetArrayType(8, element_type);
  BValue arr = builder.Param("arr", array_type);
  BValue s = builder.Param("s", package->GetBitsType(1));

  // The index is chosen between the literals 2 and 4.
  BValue first_index = builder.Literal(UBits(2, 32));
  BValue second_index = builder.Literal(UBits(4, 32));
  BValue index = builder.Select(s, {first_index, second_index});

  // Last node added; this becomes the function's return value.
  builder.ArrayIndex(arr, {index});
  XLS_ASSERT_OK_AND_ASSIGN(Function * f, builder.Build());

  ScopedRecordIr sri(package.get());
  ScopedVerifyEquivalence sve(f);
  ASSERT_THAT(Run(f), IsOkAndHolds(true));

  // Expected shape: sel(s, {arr[2], arr[4]}).
  auto element_2 = m::ArrayIndex(m::Param("arr"), {m::Literal(2)});
  auto element_4 = m::ArrayIndex(m::Param("arr"), {m::Literal(4)});
  EXPECT_THAT(f->return_value(),
              m::Select(m::Param("s"), {element_2, element_4}));
}

// A four-way select whose selector is itself a select between two literals
// should be rewritten as an outer select (on the inner selector's selector)
// between two copies of the four-way select, each with a literal selector.
TEST_F(StrengthReductionPassTest, SinkSelectWithSelectSelector) {
  auto package = CreatePackage();
  FunctionBuilder builder(TestName(), package.get());

  // Inputs: the 1-bit selector followed by four 32-bit case values.
  Type* case_type = package->GetBitsType(32);
  BValue s = builder.Param("s", package->GetBitsType(1));
  BValue case0 = builder.Param("case0", case_type);
  BValue case1 = builder.Param("case1", case_type);
  BValue case2 = builder.Param("case2", case_type);
  BValue case3 = builder.Param("case3", case_type);

  // The 2-bit selector is chosen between the literals 3 and 1.
  BValue first_choice = builder.Literal(UBits(3, 2));
  BValue second_choice = builder.Literal(UBits(1, 2));
  BValue selector = builder.Select(s, {first_choice, second_choice});

  // Last node added; this becomes the function's return value.
  builder.Select(selector, {case0, case1, case2, case3});
  XLS_ASSERT_OK_AND_ASSIGN(Function * f, builder.Build());

  ScopedRecordIr sri(package.get());
  ScopedVerifyEquivalence sve(f);
  ASSERT_THAT(Run(f), IsOkAndHolds(true));

  // Builds the matcher for the four-way select driven by `selector_matcher`.
  auto four_way_select = [](auto selector_matcher) {
    return m::Select(selector_matcher,
                     {m::Param("case0"), m::Param("case1"), m::Param("case2"),
                      m::Param("case3")});
  };
  // Expected shape: sel(s, {sel(3, cases...), sel(1, cases...)}).
  EXPECT_THAT(f->return_value(),
              m::Select(m::Param("s"), {four_way_select(m::Literal(3)),
                                        four_way_select(m::Literal(1))}));
}

void IrFuzzStrengthReduction(FuzzPackageWithArgs fuzz_package_with_args) {
StrengthReductionPass pass;
OptimizationPassChangesOutputs(std::move(fuzz_package_with_args), pass);
Expand Down