Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions xls/dslx/type_system_v2/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,9 @@ cc_library(
"//xls/dslx/type_system:type_info",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)

Expand Down
99 changes: 55 additions & 44 deletions xls/dslx/type_system_v2/evaluator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,19 @@
#include "xls/dslx/type_system_v2/evaluator.h"

#include <cstdint>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>

#include "absl/container/flat_hash_map.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "xls/common/casts.h"
#include "xls/common/status/status_macros.h"
#include "xls/dslx/constexpr_evaluator.h"
Expand All @@ -42,6 +46,18 @@
namespace xls::dslx {
namespace {

template <typename T>
class ToSigned {
public:
typedef std::make_signed_t<T> type;
};

template <>
class ToSigned<bool> {
public:
typedef bool type;
};

class EvaluatorImpl : public Evaluator {
public:
EvaluatorImpl(InferenceTable& table, Module& module, ImportData& import_data,
Expand All @@ -57,43 +73,22 @@ class EvaluatorImpl : public Evaluator {
absl::StatusOr<bool> EvaluateBoolOrExpr(
std::optional<const ParametricContext*> parametric_context,
std::variant<bool, const Expr*> value_or_expr) override {
if (std::holds_alternative<bool>(value_or_expr)) {
return std::get<bool>(value_or_expr);
}
const Expr* expr = std::get<const Expr*>(value_or_expr);
std::optional<InterpValue> value =
FastEvaluate(parametric_context, BuiltinType::kBool, expr);

if (!value.has_value()) {
XLS_RETURN_IF_ERROR(
converter_.ConvertSubtree(expr, std::nullopt, parametric_context));

XLS_ASSIGN_OR_RETURN(
value,
Evaluate(ParametricContextScopedExpr(
parametric_context,
CreateBoolAnnotation(*expr->owner(), expr->span()), expr)));

if (expr->kind() == AstNodeKind::kNumber) {
literal_cache_.emplace(
std::make_pair(BuiltinType::kBool, expr->ToString()), *value);
}
}
return value->GetBitValueUnsigned();
return EvaluateValueOrExpr<bool, BuiltinType::kBool>(parametric_context,
value_or_expr);
}

absl::StatusOr<int64_t> EvaluateU32OrExpr(
absl::StatusOr<uint32_t> EvaluateU32OrExpr(
std::optional<const ParametricContext*> parametric_context,
std::variant<int64_t, const Expr*> value_or_expr) override {
return Evaluate32BitIntOrExpr(parametric_context, value_or_expr,
/*is_signed=*/false);
return EvaluateValueOrExpr<uint32_t, BuiltinType::kU32>(parametric_context,
value_or_expr);
}

absl::StatusOr<int64_t> EvaluateS32OrExpr(
absl::StatusOr<int32_t> EvaluateS32OrExpr(
std::optional<const ParametricContext*> parametric_context,
std::variant<int64_t, const Expr*> value_or_expr) override {
return Evaluate32BitIntOrExpr(parametric_context, value_or_expr,
/*is_signed=*/true);
return EvaluateValueOrExpr<int32_t, BuiltinType::kS32>(parametric_context,
value_or_expr);
}

absl::StatusOr<InterpValue> Evaluate(
Expand Down Expand Up @@ -156,17 +151,31 @@ class EvaluatorImpl : public Evaluator {
}

private:
absl::StatusOr<int64_t> Evaluate32BitIntOrExpr(
template <typename ResultT, typename ValueT>
absl::StatusOr<ResultT> CheckedCast(ValueT value) {
if (value >= static_cast<typename ToSigned<ResultT>::type>(
std::numeric_limits<ResultT>::min()) &&
value <= std::numeric_limits<ResultT>::max()) {
return static_cast<ResultT>(value);
}
// We expect overflows from actual user code be detected at type
// unification, so this is not supposed to be reachable.
return absl::InternalError(
absl::StrCat("Evaluator overflow detected: `", value,
"` cannot be represented in target type."));
}

template <typename ResultT, BuiltinType BuiltinT, typename ValueT>
absl::StatusOr<ResultT> EvaluateValueOrExpr(
std::optional<const ParametricContext*> parametric_context,
std::variant<int64_t, const Expr*> value_or_expr, bool is_signed) {
if (std::holds_alternative<int64_t>(value_or_expr)) {
return std::get<int64_t>(value_or_expr);
std::variant<ValueT, const Expr*> value_or_expr) {
if (ValueT* value = std::get_if<ValueT>(&value_or_expr)) {
return CheckedCast<ResultT>(*value);
}

const Expr* expr = std::get<const Expr*>(value_or_expr);
const BuiltinType type = is_signed ? BuiltinType::kS32 : BuiltinType::kU32;
std::optional<InterpValue> value =
FastEvaluate(parametric_context, type, expr);
FastEvaluate(parametric_context, BuiltinT, expr);

if (!value.has_value()) {
XLS_RETURN_IF_ERROR(converter_.ConvertSubtree(
Expand All @@ -175,26 +184,28 @@ class EvaluatorImpl : public Evaluator {
std::optional<const TypeAnnotation*> type_annotation =
table_.GetTypeAnnotation(expr);
if (!type_annotation.has_value()) {
type_annotation =
is_signed ? CreateS32Annotation(*expr->owner(), expr->span())
: CreateU32Annotation(*expr->owner(), expr->span());
type_annotation = expr->owner()->Make<BuiltinTypeAnnotation>(
expr->span(), BuiltinT,
expr->owner()->GetOrCreateBuiltinNameDef(BuiltinT));
}
XLS_ASSIGN_OR_RETURN(
value, Evaluate(ParametricContextScopedExpr(parametric_context,
*type_annotation, expr)));

if (expr->kind() == AstNodeKind::kNumber) {
literal_cache_.emplace(std::make_pair(type, expr->ToString()), *value);
literal_cache_.emplace(std::make_pair(BuiltinT, expr->ToString()),
*value);
}
}

int64_t result;
if (value->IsSigned()) {
XLS_ASSIGN_OR_RETURN(result, value->GetBitValueSigned());
XLS_ASSIGN_OR_RETURN(bool signedness, GetBuiltinTypeSignedness(BuiltinT));
if (signedness) {
XLS_ASSIGN_OR_RETURN(int64_t result, value->GetBitValueSigned());
return CheckedCast<ResultT>(result);
} else {
XLS_ASSIGN_OR_RETURN(result, value->GetBitValueUnsigned());
XLS_ASSIGN_OR_RETURN(uint64_t result, value->GetBitValueUnsigned());
return CheckedCast<ResultT>(result);
}
return result;
}

// Evaluates the given `Expr` if there is a faster way to do so than using
Expand Down
10 changes: 6 additions & 4 deletions xls/dslx/type_system_v2/evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,16 @@ class Evaluator {
std::variant<bool, const Expr*> value_or_expr) = 0;

// Returns the unsigned 32-bit value of `value_or_expr` if it holds a value;
// otherwise, evaluates it in the given parametric context.
virtual absl::StatusOr<int64_t> EvaluateU32OrExpr(
// otherwise, evaluates it in the given parametric context. Returns an error
// if it is not a constexpr, or if the result cannot fit.
virtual absl::StatusOr<uint32_t> EvaluateU32OrExpr(
std::optional<const ParametricContext*> parametric_context,
std::variant<int64_t, const Expr*> value_or_expr) = 0;

// Returns the signed 32-bit value of `value_or_expr` if it holds a value;
// otherwise, evaluates it in the given parametric context.
virtual absl::StatusOr<int64_t> EvaluateS32OrExpr(
// otherwise, evaluates it in the given parametric context. Returns an error
// if it is not a constexpr, or if the result cannot fit.
virtual absl::StatusOr<int32_t> EvaluateS32OrExpr(
std::optional<const ParametricContext*> parametric_context,
std::variant<int64_t, const Expr*> value_or_expr) = 0;

Expand Down
9 changes: 5 additions & 4 deletions xls/dslx/type_system_v2/inference_table_converter_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1193,8 +1193,9 @@ class InferenceTableConverterImpl : public InferenceTableConverter,
return std::make_unique<TupleType>(std::move(member_types));
}
if (const auto* array = CastToNonBitsArrayTypeAnnotation(annotation)) {
XLS_ASSIGN_OR_RETURN(int64_t size, evaluator_->EvaluateU32OrExpr(
parametric_context, array->dim()));
XLS_ASSIGN_OR_RETURN(
uint32_t size,
evaluator_->EvaluateU32OrExpr(parametric_context, array->dim()));
XLS_ASSIGN_OR_RETURN(
std::unique_ptr<Type> element_type,
Concretize(array->element_type(), parametric_context));
Expand Down Expand Up @@ -1224,8 +1225,8 @@ class InferenceTableConverterImpl : public InferenceTableConverter,
std::move(payload_type), channel->direction());
if (channel->dims().has_value()) {
for (Expr* dim : *channel->dims()) {
XLS_ASSIGN_OR_RETURN(int64_t size, evaluator_->EvaluateU32OrExpr(
parametric_context, dim));
XLS_ASSIGN_OR_RETURN(uint32_t size, evaluator_->EvaluateU32OrExpr(
parametric_context, dim));
type = std::make_unique<ArrayType>(
std::move(type), TypeDim(InterpValue::MakeU32(size)));
}
Expand Down
13 changes: 3 additions & 10 deletions xls/dslx/type_system_v2/type_annotation_resolver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,7 @@ class StatefulResolver : public TypeAnnotationResolver {

const auto* width_slice = std::get<WidthSlice*>(slice_type->slice());
StartAndWidthExprs start_and_width;
absl::StatusOr<int64_t> constexpr_start =
absl::StatusOr<uint32_t> constexpr_start =
evaluator_.EvaluateU32OrExpr(parametric_context, width_slice->start());
if (constexpr_start.ok()) {
start_and_width.start = *constexpr_start;
Expand Down Expand Up @@ -780,15 +780,8 @@ class StatefulResolver : public TypeAnnotationResolver {
}

if (constexpr_start.ok()) {
// If start index is a signed value or a negative number literal it would
// have a signed type annotation which contradicts with the type
// annotation of a widthslice `uN[width]` and it would have been caught
// earlier at the unification of the index itself, so start index is
// expected to be unsigned, and the only reason that constexpr_start may
// be negative is that the value being evaluated is a uint64_t with MSB
// set, which overflows when casted to int64_t. It is obvious that a start
// index of 2^63 or greater is always out of range.
if (*constexpr_start < 0 || *constexpr_start + width > source_size) {
if (*constexpr_start > source_size ||
width > source_size - *constexpr_start) {
// In v2, if the start happens to be constexpr and makes the width too
// far, there is an added warning that is not in v1.
warning_collector_.Add(
Expand Down
Loading