diff --git a/xls/dslx/type_system_v2/BUILD b/xls/dslx/type_system_v2/BUILD index a8dab19523..7866847a42 100644 --- a/xls/dslx/type_system_v2/BUILD +++ b/xls/dslx/type_system_v2/BUILD @@ -586,7 +586,9 @@ cc_library( "//xls/dslx/type_system:type_info", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", ], ) diff --git a/xls/dslx/type_system_v2/evaluator.cc b/xls/dslx/type_system_v2/evaluator.cc index 182a6b21c9..162ca663c2 100644 --- a/xls/dslx/type_system_v2/evaluator.cc +++ b/xls/dslx/type_system_v2/evaluator.cc @@ -15,15 +15,19 @@ #include "xls/dslx/type_system_v2/evaluator.h" #include +#include #include #include #include +#include #include #include #include "absl/container/flat_hash_map.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "xls/common/casts.h" #include "xls/common/status/status_macros.h" #include "xls/dslx/constexpr_evaluator.h" @@ -42,6 +46,18 @@ namespace xls::dslx { namespace { +template +class ToSigned { + public: + typedef std::make_signed_t type; +}; + +template <> +class ToSigned { + public: + typedef bool type; +}; + class EvaluatorImpl : public Evaluator { public: EvaluatorImpl(InferenceTable& table, Module& module, ImportData& import_data, @@ -57,43 +73,22 @@ class EvaluatorImpl : public Evaluator { absl::StatusOr EvaluateBoolOrExpr( std::optional parametric_context, std::variant value_or_expr) override { - if (std::holds_alternative(value_or_expr)) { - return std::get(value_or_expr); - } - const Expr* expr = std::get(value_or_expr); - std::optional value = - FastEvaluate(parametric_context, BuiltinType::kBool, expr); - - if (!value.has_value()) { - XLS_RETURN_IF_ERROR( - converter_.ConvertSubtree(expr, std::nullopt, parametric_context)); - - XLS_ASSIGN_OR_RETURN( - value, - Evaluate(ParametricContextScopedExpr( - parametric_context, - CreateBoolAnnotation(*expr->owner(), expr->span()), expr))); - - if (expr->kind() == AstNodeKind::kNumber) { - literal_cache_.emplace( - std::make_pair(BuiltinType::kBool, expr->ToString()), *value); - } - } - return value->GetBitValueUnsigned(); + return EvaluateValueOrExpr(parametric_context, + value_or_expr); } - absl::StatusOr EvaluateU32OrExpr( + absl::StatusOr EvaluateU32OrExpr( std::optional parametric_context, std::variant value_or_expr) override { - return Evaluate32BitIntOrExpr(parametric_context, value_or_expr, - /*is_signed=*/false); + return EvaluateValueOrExpr(parametric_context, + value_or_expr); } - absl::StatusOr EvaluateS32OrExpr( + absl::StatusOr EvaluateS32OrExpr( std::optional parametric_context, std::variant value_or_expr) override { - return Evaluate32BitIntOrExpr(parametric_context, value_or_expr, - /*is_signed=*/true); + return EvaluateValueOrExpr(parametric_context, + value_or_expr); } absl::StatusOr Evaluate( @@ -156,17 +151,31 @@ class EvaluatorImpl : public Evaluator { } private: - absl::StatusOr Evaluate32BitIntOrExpr( + template + absl::StatusOr CheckedCast(ValueT value) { + if (value >= static_cast::type>( + std::numeric_limits::min()) && + value <= std::numeric_limits::max()) { + return static_cast(value); + } + // We expect overflows from actual user code be detected at type + // unification, so this is not supposed to be reachable. + return absl::InternalError( + absl::StrCat("Evaluator overflow detected: `", value, + "` cannot be represented in target type.")); + } + + template + absl::StatusOr EvaluateValueOrExpr( std::optional parametric_context, - std::variant value_or_expr, bool is_signed) { - if (std::holds_alternative(value_or_expr)) { - return std::get(value_or_expr); + std::variant value_or_expr) { + if (ValueT* value = std::get_if(&value_or_expr)) { + return CheckedCast(*value); } const Expr* expr = std::get(value_or_expr); - const BuiltinType type = is_signed ? BuiltinType::kS32 : BuiltinType::kU32; std::optional value = - FastEvaluate(parametric_context, type, expr); + FastEvaluate(parametric_context, BuiltinT, expr); if (!value.has_value()) { XLS_RETURN_IF_ERROR(converter_.ConvertSubtree( @@ -175,26 +184,28 @@ class EvaluatorImpl : public Evaluator { std::optional type_annotation = table_.GetTypeAnnotation(expr); if (!type_annotation.has_value()) { - type_annotation = - is_signed ? CreateS32Annotation(*expr->owner(), expr->span()) - : CreateU32Annotation(*expr->owner(), expr->span()); + type_annotation = expr->owner()->Make( + expr->span(), BuiltinT, + expr->owner()->GetOrCreateBuiltinNameDef(BuiltinT)); } XLS_ASSIGN_OR_RETURN( value, Evaluate(ParametricContextScopedExpr(parametric_context, *type_annotation, expr))); if (expr->kind() == AstNodeKind::kNumber) { - literal_cache_.emplace(std::make_pair(type, expr->ToString()), *value); + literal_cache_.emplace(std::make_pair(BuiltinT, expr->ToString()), + *value); } } - int64_t result; - if (value->IsSigned()) { - XLS_ASSIGN_OR_RETURN(result, value->GetBitValueSigned()); + XLS_ASSIGN_OR_RETURN(bool signedness, GetBuiltinTypeSignedness(BuiltinT)); + if (signedness) { + XLS_ASSIGN_OR_RETURN(int64_t result, value->GetBitValueSigned()); + return CheckedCast(result); } else { - XLS_ASSIGN_OR_RETURN(result, value->GetBitValueUnsigned()); + XLS_ASSIGN_OR_RETURN(uint64_t result, value->GetBitValueUnsigned()); + return CheckedCast(result); } - return result; } // Evaluates the given `Expr` if there is a faster way to do so than using diff --git a/xls/dslx/type_system_v2/evaluator.h b/xls/dslx/type_system_v2/evaluator.h index 55ff2e3b4c..ad1c41790f 100644 --- a/xls/dslx/type_system_v2/evaluator.h +++ b/xls/dslx/type_system_v2/evaluator.h @@ -44,14 +44,16 @@ class Evaluator { std::variant value_or_expr) = 0; // Returns the unsigned 32-bit value of `value_or_expr` if it holds a value; - // otherwise, evaluates it in the given parametric context. - virtual absl::StatusOr EvaluateU32OrExpr( + // otherwise, evaluates it in the given parametric context. Returns an error + // if it is not a constexpr, or if the result cannot fit. + virtual absl::StatusOr EvaluateU32OrExpr( std::optional parametric_context, std::variant value_or_expr) = 0; // Returns the signed 32-bit value of `value_or_expr` if it holds a value; - // otherwise, evaluates it in the given parametric context. - virtual absl::StatusOr EvaluateS32OrExpr( + // otherwise, evaluates it in the given parametric context. Returns an error + // if it is not a constexpr, or if the result cannot fit. + virtual absl::StatusOr EvaluateS32OrExpr( std::optional parametric_context, std::variant value_or_expr) = 0; diff --git a/xls/dslx/type_system_v2/inference_table_converter_impl.cc b/xls/dslx/type_system_v2/inference_table_converter_impl.cc index 85af1894ae..cb6192daae 100644 --- a/xls/dslx/type_system_v2/inference_table_converter_impl.cc +++ b/xls/dslx/type_system_v2/inference_table_converter_impl.cc @@ -1193,8 +1193,9 @@ class InferenceTableConverterImpl : public InferenceTableConverter, return std::make_unique(std::move(member_types)); } if (const auto* array = CastToNonBitsArrayTypeAnnotation(annotation)) { - XLS_ASSIGN_OR_RETURN(int64_t size, evaluator_->EvaluateU32OrExpr( - parametric_context, array->dim())); + XLS_ASSIGN_OR_RETURN( + uint32_t size, + evaluator_->EvaluateU32OrExpr(parametric_context, array->dim())); XLS_ASSIGN_OR_RETURN( std::unique_ptr element_type, Concretize(array->element_type(), parametric_context)); @@ -1224,8 +1225,8 @@ class InferenceTableConverterImpl : public InferenceTableConverter, std::move(payload_type), channel->direction()); if (channel->dims().has_value()) { for (Expr* dim : *channel->dims()) { - XLS_ASSIGN_OR_RETURN(int64_t size, evaluator_->EvaluateU32OrExpr( - parametric_context, dim)); + XLS_ASSIGN_OR_RETURN(uint32_t size, evaluator_->EvaluateU32OrExpr( + parametric_context, dim)); type = std::make_unique( std::move(type), TypeDim(InterpValue::MakeU32(size))); } diff --git a/xls/dslx/type_system_v2/type_annotation_resolver.cc b/xls/dslx/type_system_v2/type_annotation_resolver.cc index 158bfc315d..b5bbc13702 100644 --- a/xls/dslx/type_system_v2/type_annotation_resolver.cc +++ b/xls/dslx/type_system_v2/type_annotation_resolver.cc @@ -745,7 +745,7 @@ class StatefulResolver : public TypeAnnotationResolver { const auto* width_slice = std::get(slice_type->slice()); StartAndWidthExprs start_and_width; - absl::StatusOr constexpr_start = + absl::StatusOr constexpr_start = evaluator_.EvaluateU32OrExpr(parametric_context, width_slice->start()); if (constexpr_start.ok()) { start_and_width.start = *constexpr_start; @@ -780,15 +780,8 @@ class StatefulResolver : public TypeAnnotationResolver { } if (constexpr_start.ok()) { - // If start index is a signed value or a negative number literal it would - // have a signed type annotation which contradicts with the type - // annotation of a widthslice `uN[width]` and it would have been caught - // earlier at the unification of the index itself, so start index is - // expected to be unsigned, and the only reason that constexpr_start may - // be negative is that the value being evaluated is a uint64_t with MSB - // set, which overflows when casted to int64_t. It is obvious that a start - // index of 2^63 or greater is always out of range. - if (*constexpr_start < 0 || *constexpr_start + width > source_size) { + if (*constexpr_start > source_size || + width > source_size - *constexpr_start) { // In v2, if the start happens to be constexpr and makes the width too // far, there is an added warning that is not in v1. warning_collector_.Add(