Skip to content

Commit 4d6e9da

Browse files
huangjdcopybara-github
authored andcommitted
Refactor Evaluator.cc to guard against implicit casting of integer type values. This resolves overflow issues with widthslice
PiperOrigin-RevId: 808361484
1 parent 6fe03f9 commit 4d6e9da

File tree

5 files changed

+71
-62
lines changed

5 files changed

+71
-62
lines changed

xls/dslx/type_system_v2/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,9 @@ cc_library(
586586
"//xls/dslx/type_system:type_info",
587587
"@com_google_absl//absl/container:flat_hash_map",
588588
"@com_google_absl//absl/log",
589+
"@com_google_absl//absl/status",
589590
"@com_google_absl//absl/status:statusor",
591+
"@com_google_absl//absl/strings",
590592
],
591593
)
592594

xls/dslx/type_system_v2/evaluator.cc

Lines changed: 55 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,19 @@
1515
#include "xls/dslx/type_system_v2/evaluator.h"
1616

1717
#include <cstdint>
18+
#include <limits>
1819
#include <memory>
1920
#include <optional>
2021
#include <string>
22+
#include <type_traits>
2123
#include <utility>
2224
#include <variant>
2325

2426
#include "absl/container/flat_hash_map.h"
2527
#include "absl/log/log.h"
28+
#include "absl/status/status.h"
2629
#include "absl/status/statusor.h"
30+
#include "absl/strings/str_cat.h"
2731
#include "xls/common/casts.h"
2832
#include "xls/common/status/status_macros.h"
2933
#include "xls/dslx/constexpr_evaluator.h"
@@ -42,6 +46,18 @@
4246
namespace xls::dslx {
4347
namespace {
4448

49+
template <typename T>
50+
class ToSigned {
51+
public:
52+
typedef std::make_signed_t<T> type;
53+
};
54+
55+
template <>
56+
class ToSigned<bool> {
57+
public:
58+
typedef bool type;
59+
};
60+
4561
class EvaluatorImpl : public Evaluator {
4662
public:
4763
EvaluatorImpl(InferenceTable& table, Module& module, ImportData& import_data,
@@ -57,43 +73,22 @@ class EvaluatorImpl : public Evaluator {
5773
absl::StatusOr<bool> EvaluateBoolOrExpr(
5874
std::optional<const ParametricContext*> parametric_context,
5975
std::variant<bool, const Expr*> value_or_expr) override {
60-
if (std::holds_alternative<bool>(value_or_expr)) {
61-
return std::get<bool>(value_or_expr);
62-
}
63-
const Expr* expr = std::get<const Expr*>(value_or_expr);
64-
std::optional<InterpValue> value =
65-
FastEvaluate(parametric_context, BuiltinType::kBool, expr);
66-
67-
if (!value.has_value()) {
68-
XLS_RETURN_IF_ERROR(
69-
converter_.ConvertSubtree(expr, std::nullopt, parametric_context));
70-
71-
XLS_ASSIGN_OR_RETURN(
72-
value,
73-
Evaluate(ParametricContextScopedExpr(
74-
parametric_context,
75-
CreateBoolAnnotation(*expr->owner(), expr->span()), expr)));
76-
77-
if (expr->kind() == AstNodeKind::kNumber) {
78-
literal_cache_.emplace(
79-
std::make_pair(BuiltinType::kBool, expr->ToString()), *value);
80-
}
81-
}
82-
return value->GetBitValueUnsigned();
76+
return EvaluateValueOrExpr<bool, BuiltinType::kBool>(parametric_context,
77+
value_or_expr);
8378
}
8479

85-
absl::StatusOr<int64_t> EvaluateU32OrExpr(
80+
absl::StatusOr<uint32_t> EvaluateU32OrExpr(
8681
std::optional<const ParametricContext*> parametric_context,
8782
std::variant<int64_t, const Expr*> value_or_expr) override {
88-
return Evaluate32BitIntOrExpr(parametric_context, value_or_expr,
89-
/*is_signed=*/false);
83+
return EvaluateValueOrExpr<uint32_t, BuiltinType::kU32>(parametric_context,
84+
value_or_expr);
9085
}
9186

92-
absl::StatusOr<int64_t> EvaluateS32OrExpr(
87+
absl::StatusOr<int32_t> EvaluateS32OrExpr(
9388
std::optional<const ParametricContext*> parametric_context,
9489
std::variant<int64_t, const Expr*> value_or_expr) override {
95-
return Evaluate32BitIntOrExpr(parametric_context, value_or_expr,
96-
/*is_signed=*/true);
90+
return EvaluateValueOrExpr<int32_t, BuiltinType::kS32>(parametric_context,
91+
value_or_expr);
9792
}
9893

9994
absl::StatusOr<InterpValue> Evaluate(
@@ -156,17 +151,31 @@ class EvaluatorImpl : public Evaluator {
156151
}
157152

158153
private:
159-
absl::StatusOr<int64_t> Evaluate32BitIntOrExpr(
154+
template <typename ResultT, typename ValueT>
155+
absl::StatusOr<ResultT> CheckedCast(ValueT value) {
156+
if (value >= static_cast<typename ToSigned<ResultT>::type>(
157+
std::numeric_limits<ResultT>::min()) &&
158+
value <= std::numeric_limits<ResultT>::max()) {
159+
return static_cast<ResultT>(value);
160+
}
161+
// We expect overflows from actual user code be detected at type
162+
// unification, so this is not supposed to be reachable.
163+
return absl::InternalError(
164+
absl::StrCat("Evaluator overflow detected: `", value,
165+
"` cannot be represented in target type."));
166+
}
167+
168+
template <typename ResultT, BuiltinType BuiltinT, typename ValueT>
169+
absl::StatusOr<ResultT> EvaluateValueOrExpr(
160170
std::optional<const ParametricContext*> parametric_context,
161-
std::variant<int64_t, const Expr*> value_or_expr, bool is_signed) {
162-
if (std::holds_alternative<int64_t>(value_or_expr)) {
163-
return std::get<int64_t>(value_or_expr);
171+
std::variant<ValueT, const Expr*> value_or_expr) {
172+
if (ValueT* value = std::get_if<ValueT>(&value_or_expr)) {
173+
return CheckedCast<ResultT>(*value);
164174
}
165175

166176
const Expr* expr = std::get<const Expr*>(value_or_expr);
167-
const BuiltinType type = is_signed ? BuiltinType::kS32 : BuiltinType::kU32;
168177
std::optional<InterpValue> value =
169-
FastEvaluate(parametric_context, type, expr);
178+
FastEvaluate(parametric_context, BuiltinT, expr);
170179

171180
if (!value.has_value()) {
172181
XLS_RETURN_IF_ERROR(converter_.ConvertSubtree(
@@ -175,26 +184,28 @@ class EvaluatorImpl : public Evaluator {
175184
std::optional<const TypeAnnotation*> type_annotation =
176185
table_.GetTypeAnnotation(expr);
177186
if (!type_annotation.has_value()) {
178-
type_annotation =
179-
is_signed ? CreateS32Annotation(*expr->owner(), expr->span())
180-
: CreateU32Annotation(*expr->owner(), expr->span());
187+
type_annotation = expr->owner()->Make<BuiltinTypeAnnotation>(
188+
expr->span(), BuiltinT,
189+
expr->owner()->GetOrCreateBuiltinNameDef(BuiltinT));
181190
}
182191
XLS_ASSIGN_OR_RETURN(
183192
value, Evaluate(ParametricContextScopedExpr(parametric_context,
184193
*type_annotation, expr)));
185194

186195
if (expr->kind() == AstNodeKind::kNumber) {
187-
literal_cache_.emplace(std::make_pair(type, expr->ToString()), *value);
196+
literal_cache_.emplace(std::make_pair(BuiltinT, expr->ToString()),
197+
*value);
188198
}
189199
}
190200

191-
int64_t result;
192-
if (value->IsSigned()) {
193-
XLS_ASSIGN_OR_RETURN(result, value->GetBitValueSigned());
201+
XLS_ASSIGN_OR_RETURN(bool signedness, GetBuiltinTypeSignedness(BuiltinT));
202+
if (signedness) {
203+
XLS_ASSIGN_OR_RETURN(int64_t result, value->GetBitValueSigned());
204+
return CheckedCast<ResultT>(result);
194205
} else {
195-
XLS_ASSIGN_OR_RETURN(result, value->GetBitValueUnsigned());
206+
XLS_ASSIGN_OR_RETURN(uint64_t result, value->GetBitValueUnsigned());
207+
return CheckedCast<ResultT>(result);
196208
}
197-
return result;
198209
}
199210

200211
// Evaluates the given `Expr` if there is a faster way to do so than using

xls/dslx/type_system_v2/evaluator.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,16 @@ class Evaluator {
4444
std::variant<bool, const Expr*> value_or_expr) = 0;
4545

4646
// Returns the unsigned 32-bit value of `value_or_expr` if it holds a value;
47-
// otherwise, evaluates it in the given parametric context.
48-
virtual absl::StatusOr<int64_t> EvaluateU32OrExpr(
47+
// otherwise, evaluates it in the given parametric context. Returns an error
48+
// if it is not a constexpr, or if the result cannot fit.
49+
virtual absl::StatusOr<uint32_t> EvaluateU32OrExpr(
4950
std::optional<const ParametricContext*> parametric_context,
5051
std::variant<int64_t, const Expr*> value_or_expr) = 0;
5152

5253
// Returns the signed 32-bit value of `value_or_expr` if it holds a value;
53-
// otherwise, evaluates it in the given parametric context.
54-
virtual absl::StatusOr<int64_t> EvaluateS32OrExpr(
54+
// otherwise, evaluates it in the given parametric context. Returns an error
55+
// if it is not a constexpr, or if the result cannot fit.
56+
virtual absl::StatusOr<int32_t> EvaluateS32OrExpr(
5557
std::optional<const ParametricContext*> parametric_context,
5658
std::variant<int64_t, const Expr*> value_or_expr) = 0;
5759

xls/dslx/type_system_v2/inference_table_converter_impl.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,8 +1193,9 @@ class InferenceTableConverterImpl : public InferenceTableConverter,
11931193
return std::make_unique<TupleType>(std::move(member_types));
11941194
}
11951195
if (const auto* array = CastToNonBitsArrayTypeAnnotation(annotation)) {
1196-
XLS_ASSIGN_OR_RETURN(int64_t size, evaluator_->EvaluateU32OrExpr(
1197-
parametric_context, array->dim()));
1196+
XLS_ASSIGN_OR_RETURN(
1197+
uint32_t size,
1198+
evaluator_->EvaluateU32OrExpr(parametric_context, array->dim()));
11981199
XLS_ASSIGN_OR_RETURN(
11991200
std::unique_ptr<Type> element_type,
12001201
Concretize(array->element_type(), parametric_context));
@@ -1224,8 +1225,8 @@ class InferenceTableConverterImpl : public InferenceTableConverter,
12241225
std::move(payload_type), channel->direction());
12251226
if (channel->dims().has_value()) {
12261227
for (Expr* dim : *channel->dims()) {
1227-
XLS_ASSIGN_OR_RETURN(int64_t size, evaluator_->EvaluateU32OrExpr(
1228-
parametric_context, dim));
1228+
XLS_ASSIGN_OR_RETURN(uint32_t size, evaluator_->EvaluateU32OrExpr(
1229+
parametric_context, dim));
12291230
type = std::make_unique<ArrayType>(
12301231
std::move(type), TypeDim(InterpValue::MakeU32(size)));
12311232
}

xls/dslx/type_system_v2/type_annotation_resolver.cc

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ class StatefulResolver : public TypeAnnotationResolver {
745745

746746
const auto* width_slice = std::get<WidthSlice*>(slice_type->slice());
747747
StartAndWidthExprs start_and_width;
748-
absl::StatusOr<int64_t> constexpr_start =
748+
absl::StatusOr<uint32_t> constexpr_start =
749749
evaluator_.EvaluateU32OrExpr(parametric_context, width_slice->start());
750750
if (constexpr_start.ok()) {
751751
start_and_width.start = *constexpr_start;
@@ -780,15 +780,8 @@ class StatefulResolver : public TypeAnnotationResolver {
780780
}
781781

782782
if (constexpr_start.ok()) {
783-
// If start index is a signed value or a negative number literal it would
784-
// have a signed type annotation which contradicts with the type
785-
// annotation of a widthslice `uN[width]` and it would have been caught
786-
// earlier at the unification of the index itself, so start index is
787-
// expected to be unsigned, and the only reason that constexpr_start may
788-
// be negative is that the value being evaluated is a uint64_t with MSB
789-
// set, which overflows when casted to int64_t. It is obvious that a start
790-
// index of 2^63 or greater is always out of range.
791-
if (*constexpr_start < 0 || *constexpr_start + width > source_size) {
783+
if (*constexpr_start > source_size ||
784+
width > source_size - *constexpr_start) {
792785
// In v2, if the start happens to be constexpr and makes the width too
793786
// far, there is an added warning that is not in v1.
794787
warning_collector_.Add(

0 commit comments

Comments
 (0)