From 1ccac4214f9efd7a8423b78899a473ff16095067 Mon Sep 17 00:00:00 2001 From: Sirui Lu Date: Fri, 12 Sep 2025 16:29:25 -0700 Subject: [PATCH 1/9] Add ReplaceInvocationsInModule --- xls/dslx/replace_invocations.cc | 505 +++++++++++ xls/dslx/replace_invocations.h | 75 ++ xls/dslx/replace_invocations_test.cc | 1236 ++++++++++++++++++++++++++ 3 files changed, 1816 insertions(+) create mode 100644 xls/dslx/replace_invocations.cc create mode 100644 xls/dslx/replace_invocations.h create mode 100644 xls/dslx/replace_invocations_test.cc diff --git a/xls/dslx/replace_invocations.cc b/xls/dslx/replace_invocations.cc new file mode 100644 index 0000000000..6639ae83e5 --- /dev/null +++ b/xls/dslx/replace_invocations.cc @@ -0,0 +1,505 @@ +// Copyright 2025 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xls/dslx/replace_invocations.h" + +#include +#include + +#include "absl/container/btree_set.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "xls/common/status/ret_check.h" +#include "xls/dslx/frontend/ast.h" +#include "xls/dslx/frontend/ast_cloner.h" +#include "xls/dslx/frontend/ast_utils.h" +#include "xls/dslx/frontend/module.h" +#include "xls/dslx/import_data.h" +#include "xls/dslx/parse_and_typecheck.h" +#include "xls/dslx/type_system/parametric_env.h" +#include "xls/dslx/type_system/type_info.h" +#include "xls/dslx/type_system/typecheck_module.h" +#include "xls/dslx/warning_collector.h" +#include "xls/ir/bits.h" + +namespace xls::dslx { + +namespace { + +template +struct Overloaded : T... { + using T::operator()...; +}; +template +Overloaded(T...) -> Overloaded; + +bool MatchesCalleeEnv(const InvocationData& data, + const std::optional& want_env) { + if (!want_env.has_value()) { + return true; + } + if (data.env_to_callee_data().empty()) { + return want_env->empty(); + } + for (const auto& kv : data.env_to_callee_data()) { + const InvocationCalleeData& callee_data = kv.second; + if (callee_data.callee_bindings == *want_env) { + return true; + } + } + return false; +} + +const InvocationRewriteRule* FindMatchingRule( + const InvocationData& data, absl::Span rules) { + for (const InvocationRewriteRule& r : rules) { + if (data.callee() == r.from_callee && + MatchesCalleeEnv(data, r.match_callee_env)) { + return &r; + } + } + return nullptr; +} + +absl::flat_hash_map BuildNameDefMap( + const absl::flat_hash_map& old_to_new) { + absl::flat_hash_map name_def_map; + for (const auto& kv : old_to_new) { + if (auto* old_nd = dynamic_cast(kv.first)) { + name_def_map.emplace(old_nd, down_cast(kv.second)); + } + } + return name_def_map; +} + +absl::StatusOr CloneExprIntoModule( + Expr* e, Module* target_module, + const absl::flat_hash_map& old_to_new, + const absl::flat_hash_map& name_def_map) { + auto it = old_to_new.find(e); + if (it != old_to_new.end()) { + return down_cast(it->second); + } + XLS_ASSIGN_OR_RETURN( + auto pairs, CloneAstAndGetAllPairs( + /*root=*/e, + /*target_module=*/std::optional{target_module}, + /*replacer=*/NameRefReplacer(&name_def_map))); + auto it_cloned = pairs.find(e); + XLS_RET_CHECK(it_cloned != pairs.end()); + return down_cast(it_cloned->second); +} + +absl::StatusOr MakeColonRefSubjectFromTypeRef( + TypeRef* type_ref, const Span& inv_span, Module* target_module, + const absl::flat_hash_map& old_to_new) { + using ReturnT = absl::StatusOr; + return absl::visit(Overloaded{ + [&](ColonRef* old_cref) -> ReturnT { + auto it = old_to_new.find(old_cref); + XLS_RET_CHECK(it != old_to_new.end()); + auto* new_cref = down_cast(it->second); + return ColonRef::Subject(new_cref); + }, + [&](EnumDef* old_enum) -> ReturnT { + const NameDef* old_nd = old_enum->name_def(); + auto it = old_to_new.find(old_nd); + XLS_RET_CHECK(it != old_to_new.end()); + auto* new_nd = down_cast(it->second); + NameRef* nr = target_module->Make( + inv_span, new_nd->identifier(), new_nd, + /*in_parens=*/false); + return ColonRef::Subject(nr); + }, + [&](TypeAlias* ta) -> ReturnT { + auto it = old_to_new.find(&ta->name_def()); + XLS_RET_CHECK(it != old_to_new.end()); + auto* new_nd = down_cast(it->second); + NameRef* nr = target_module->Make( + inv_span, new_nd->identifier(), new_nd, + /*in_parens=*/false); + return ColonRef::Subject(nr); + }, + [&](UseTreeEntry* ute) -> ReturnT { + std::optional leaf = ute->GetLeafNameDef(); + XLS_RET_CHECK(leaf.has_value()); + auto it = old_to_new.find(*leaf); + XLS_RET_CHECK(it != old_to_new.end()); + auto* new_nd = down_cast(it->second); + NameRef* nr = target_module->Make( + inv_span, new_nd->identifier(), new_nd, + /*in_parens=*/false); + return ColonRef::Subject(nr); + }, + [&](auto*) -> ReturnT { + return absl::InvalidArgumentError( + "Unsupported enum type reference form"); + }, + }, + type_ref->type_definition()); +} + +absl::StatusOr DealiasTypeDefinition( + TypeRef* tr, TypeInfo& type_info) { + XLS_ASSIGN_OR_RETURN(TypeInfo::TypeSource ts, + type_info.ResolveTypeDefinition(tr->type_definition())); + while (std::holds_alternative(ts.definition)) { + TypeAlias* ta = std::get(ts.definition); + auto* trta2 = dynamic_cast(&ta->type_annotation()); + if (trta2 == nullptr) { + return absl::InvalidArgumentError( + "Unsupported type alias in explicit replacement (non-TypeRef type)"); + } + TypeRef* tr2 = trta2->type_ref(); + XLS_ASSIGN_OR_RETURN( + ts, type_info.ResolveTypeDefinition(tr2->type_definition())); + } + return ts; +} + +absl::StatusOr ResolveEnumMemberName(const EnumDef* enum_def, + const InterpValue& iv, + TypeInfo& type_info) { + XLS_ASSIGN_OR_RETURN(Bits want_bits, iv.GetBits()); + for (const EnumMember& em : enum_def->values()) { + TypeInfo* ti_used = &type_info; + if (em.value->owner() != type_info.module()) { + std::optional imported = + type_info.GetImportedTypeInfo(em.value->owner()); + if (imported.has_value() && *imported != nullptr) { + ti_used = *imported; + } + } + std::optional mv = ti_used->GetConstExprOption(em.value); + if (!mv.has_value()) { + continue; + } + absl::StatusOr mb = mv->GetBits(); + if (mb.ok() && *mb == want_bits) { + return std::string(em.name_def->identifier()); + } + } + return absl::InvalidArgumentError( + "No matching enum member for provided value"); +} + +// Validates that the provided explicit env keys are a subset of the callee's +// parametric bindings, and that all required (non-defaulted) bindings are +// supplied. Returns InvalidArgumentError with messages matching the inline +// checks previously performed in BuildExplicitParametricsFromEnv. +static absl::Status ValidateExplicitEnvAgainstCallee( + const absl::flat_hash_map& env_map, + const Function* to_callee) { + absl::btree_set callee_keys; + for (const ParametricBinding* pb : to_callee->parametric_bindings()) { + callee_keys.insert(pb->identifier()); + } + + std::vector unknown_keys; + unknown_keys.reserve(env_map.size()); + for (const auto& kv : env_map) { + if (!callee_keys.contains(kv.first)) { + unknown_keys.push_back(kv.first); + } + } + if (!unknown_keys.empty()) { + std::string listed; + for (size_t i = 0; i < unknown_keys.size(); ++i) { + absl::StrAppend(&listed, (i == 0 ? "" : ", "), "`", unknown_keys[i], "`"); + } + return absl::InvalidArgumentError( + absl::StrCat("Unknown binding(s) ", listed, " for replacement callee `", + to_callee->name_def()->identifier(), "`")); + } + + for (const ParametricBinding* pb : to_callee->parametric_bindings()) { + if (!pb->expr() && !env_map.contains(pb->identifier())) { + return absl::InvalidArgumentError( + absl::StrCat("Missing required binding `", pb->identifier(), + "` for replacement callee")); + } + } + return absl::OkStatus(); +} + +// Builds an expression representing the enum member corresponding to the +// provided InterpValue for the given enum TypeRef. +absl::StatusOr BuildEnumParametricExpr( + TypeRef* tr, const EnumDef* enum_def, const InterpValue& iv, + TypeInfo& type_info, const Span& inv_span, Module* target_module, + const absl::flat_hash_map& old_to_new) { + XLS_ASSIGN_OR_RETURN(std::string member_name, + ResolveEnumMemberName(enum_def, iv, type_info)); + XLS_ASSIGN_OR_RETURN( + ColonRef::Subject subject, + MakeColonRefSubjectFromTypeRef(tr, inv_span, target_module, old_to_new)); + ColonRef* cref = target_module->Make(inv_span, subject, member_name, + /*in_parens=*/false); + return static_cast(cref); +} + +// Dispatch point for future non-enum TypeRef support. Currently only enums are +// supported; other kinds will return an error via the enum helper. +absl::StatusOr BuildParametricExprForTypeRef( + TypeRef* tr, const InterpValue& iv, TypeInfo& type_info, + const Span& inv_span, Module* target_module, + const absl::flat_hash_map& old_to_new) { + XLS_ASSIGN_OR_RETURN(TypeInfo::TypeSource ts_final, + DealiasTypeDefinition(tr, type_info)); + if (std::holds_alternative(ts_final.definition)) { + auto* enum_def = std::get(ts_final.definition); + return BuildEnumParametricExpr(tr, enum_def, iv, type_info, inv_span, + target_module, old_to_new); + } + return absl::InvalidArgumentError( + "Unsupported parametric TypeRef in explicit replacement (only enums " + "supported at this time)"); +} + +absl::StatusOr> BuildExplicitParametricsFromEnv( + const ParametricEnv& env, const Function* to_callee, TypeInfo& type_info, + const Span& inv_span, Module* target_module, + const absl::flat_hash_map& old_to_new) { + std::vector result; + absl::flat_hash_map env_map = env.ToMap(); + XLS_RETURN_IF_ERROR(ValidateExplicitEnvAgainstCallee(env_map, to_callee)); + for (const ParametricBinding* pb : to_callee->parametric_bindings()) { + auto it = env_map.find(pb->identifier()); + if (it == env_map.end()) { + continue; + } + const InterpValue& iv = it->second; + TypeAnnotation* ann = pb->type_annotation(); + if (ann == nullptr) { + return absl::InvalidArgumentError(absl::StrCat( + "Parametric binding `", pb->identifier(), + "` lacks a type annotation; explicit replacement not supported")); + } + if (auto* bta = dynamic_cast(ann)) { + if (!iv.IsBits()) { + return absl::InvalidArgumentError(absl::StrCat( + "Parametric `", pb->identifier(), + "` expected bits value for builtin type ", bta->ToString())); + } + std::string digits = iv.ToString(/*humanize=*/true); + BuiltinNameDef* bnd = + target_module->GetOrCreateBuiltinNameDef(bta->builtin_type()); + BuiltinTypeAnnotation* typed = target_module->Make( + inv_span, bta->builtin_type(), bnd); + Number* num = target_module->Make(inv_span, digits, + NumberKind::kOther, typed, + /*in_parens=*/false); + result.push_back(static_cast(num)); + continue; + } + if (auto* trta = dynamic_cast(ann)) { + TypeRef* tr = trta->type_ref(); + XLS_ASSIGN_OR_RETURN(Expr * enum_expr, BuildParametricExprForTypeRef( + tr, iv, type_info, inv_span, + target_module, old_to_new)); + result.push_back(enum_expr); + continue; + } + return absl::InvalidArgumentError( + absl::StrCat("Unsupported parametric type annotation for `", + pb->identifier(), "` in explicit replacement")); + } + return result; +} + +absl::StatusOr> RetainExplicitParametrics( + const Invocation& inv, Module* target_module, + const absl::flat_hash_map& old_to_new, + const absl::flat_hash_map& name_def_map) { + std::vector new_parametrics; + for (const ExprOrType& eot : inv.explicit_parametrics()) { + if (std::holds_alternative(eot)) { + XLS_ASSIGN_OR_RETURN( + Expr * e, CloneExprIntoModule(std::get(eot), target_module, + old_to_new, name_def_map)); + new_parametrics.push_back(e); + } else { + TypeAnnotation* ta = std::get(eot); + auto it = old_to_new.find(ta); + if (it != old_to_new.end()) { + new_parametrics.push_back(down_cast(it->second)); + continue; + } + XLS_ASSIGN_OR_RETURN( + auto pairs, + CloneAstAndGetAllPairs( + /*root=*/ta, + /*target_module=*/std::optional{target_module}, + /*replacer=*/NameRefReplacer(&name_def_map))); + auto it_cloned = pairs.find(ta); + XLS_RET_CHECK(it_cloned != pairs.end()); + new_parametrics.push_back(down_cast(it_cloned->second)); + } + } + return new_parametrics; +} + +absl::StatusOr TypecheckAndInstallCloned( + std::unique_ptr cloned, const TypecheckedModule& tm, + ImportData& import_data, std::string_view install_subject) { + WarningCollector warnings(import_data.enabled_warnings()); + XLS_ASSIGN_OR_RETURN( + std::unique_ptr module_info, + TypecheckModule(std::move(cloned), tm.module->fs_path().value(), + &import_data, &warnings)); + XLS_ASSIGN_OR_RETURN(ImportTokens subject, + ImportTokens::FromString(install_subject)); + XLS_ASSIGN_OR_RETURN(ModuleInfo * stored, + import_data.Put(subject, std::move(module_info))); + return TypecheckedModule{ + .module = &stored->module(), + .type_info = stored->type_info(), + .warnings = std::move(warnings), + }; +} + +} // namespace + +absl::StatusOr ReplaceInvocationsInModule( + const TypecheckedModule& tm, absl::Span callers, + absl::Span rules, ImportData& import_data, + std::string_view install_subject) { + const Module& module = *tm.module; + TypeInfo& type_info = *tm.type_info; + XLS_RET_CHECK(!callers.empty()); + XLS_RET_CHECK(!rules.empty()); + for (const Function* f : callers) { + XLS_RET_CHECK_NE(f, nullptr); + XLS_RET_CHECK_EQ(f->owner(), &module); + } + for (const InvocationRewriteRule& r : rules) { + XLS_RET_CHECK_NE(r.from_callee, nullptr); + XLS_RET_CHECK_NE(r.to_callee, nullptr); + XLS_RET_CHECK_EQ(r.from_callee->owner(), &module); + XLS_RET_CHECK_EQ(r.to_callee->owner(), &module); + if (r.match_callee_env.has_value()) { + const ParametricEnv& me = *r.match_callee_env; + const auto& pbs = r.from_callee->parametric_bindings(); + absl::btree_set want_keys = me.GetKeySet(); + absl::btree_set callee_keys; + for (const ParametricBinding* pb : pbs) { + callee_keys.insert(pb->identifier()); + } + if (want_keys != callee_keys) { + return absl::InvalidArgumentError( + "match_callee_env keys do not match callee parametric names"); + } + } + } + + CloneReplacer replacer = + [&type_info, callers, rules]( + const AstNode* node, Module* target_module, + const absl::flat_hash_map& old_to_new) + -> absl::StatusOr> { + const Invocation* inv = dynamic_cast(node); + if (inv == nullptr) { + return std::nullopt; + } + auto is_within_any_caller = [&]() -> bool { + for (const Function* c : callers) { + if (ContainedWithinFunction(*inv, *c)) { + return true; + } + } + return false; + }(); + if (!is_within_any_caller) { + return std::nullopt; + } + + std::optional data_opt = + type_info.GetRootInvocationData(inv); + if (!data_opt.has_value()) { + return std::nullopt; + } + const InvocationData* data = *data_opt; + + const InvocationRewriteRule* matched_rule = FindMatchingRule(*data, rules); + if (matched_rule == nullptr) { + return std::nullopt; + } + + const NameDef* old_target = matched_rule->to_callee->name_def(); + auto it_nd = old_to_new.find(old_target); + XLS_RET_CHECK(it_nd != old_to_new.end()); + auto* new_target = down_cast(it_nd->second); + NameRef* new_callee = target_module->Make( + inv->callee()->span(), new_target->identifier(), new_target, + inv->callee()->in_parens()); + + // Pre-built the name def map to avoid rebuilding it for each invocation. + absl::flat_hash_map name_def_map = + BuildNameDefMap(old_to_new); + + auto clone_expr_into = [&](Expr* e) -> absl::StatusOr { + return CloneExprIntoModule(e, target_module, old_to_new, name_def_map); + }; + + std::vector new_args; + new_args.reserve(inv->args().size()); + for (Expr* arg : inv->args()) { + XLS_ASSIGN_OR_RETURN(Expr * cloned, clone_expr_into(arg)); + new_args.push_back(cloned); + } + + std::vector new_parametrics; + + if (matched_rule->to_callee_env.has_value()) { + if (!matched_rule->to_callee_env->empty()) { + XLS_ASSIGN_OR_RETURN( + new_parametrics, + BuildExplicitParametricsFromEnv( + *matched_rule->to_callee_env, matched_rule->to_callee, + type_info, inv->span(), target_module, old_to_new)); + } + } else { + XLS_ASSIGN_OR_RETURN(new_parametrics, + RetainExplicitParametrics(*inv, target_module, + old_to_new, name_def_map)); + } + + std::optional new_origin = std::nullopt; + Invocation* replacement = target_module->Make( + inv->span(), new_callee, std::move(new_args), + std::move(new_parametrics), inv->in_parens(), new_origin); + return replacement; + }; + + XLS_ASSIGN_OR_RETURN(std::unique_ptr cloned, + CloneModule(module, std::move(replacer))); + XLS_RET_CHECK_OK(VerifyClone(&module, cloned.get(), *module.file_table())); + + return TypecheckAndInstallCloned(std::move(cloned), tm, import_data, + install_subject); +} + +absl::StatusOr ReplaceInvocationsInModule( + const TypecheckedModule& tm, const Function* caller, + const InvocationRewriteRule& rule, ImportData& import_data, + std::string_view install_subject) { + const Function* callers_arr[] = {caller}; + const InvocationRewriteRule rules_arr[] = {rule}; + return ReplaceInvocationsInModule(tm, absl::MakeSpan(callers_arr), + absl::MakeSpan(rules_arr), import_data, + install_subject); +} + +} // namespace xls::dslx diff --git a/xls/dslx/replace_invocations.h b/xls/dslx/replace_invocations.h new file mode 100644 index 0000000000..5320734d7a --- /dev/null +++ b/xls/dslx/replace_invocations.h @@ -0,0 +1,75 @@ +// Copyright 2025 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef XLS_DSLX_REPLACE_INVOCATIONS_H_ +#define XLS_DSLX_REPLACE_INVOCATIONS_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xls/dslx/frontend/ast.h" +#include "xls/dslx/import_data.h" +#include "xls/dslx/parse_and_typecheck.h" // For TypecheckedModule +#include "xls/dslx/type_system/parametric_env.h" + +namespace xls::dslx { + +// Describes a rule for rewriting invocations irrespective of caller. +// Used by the bulk API where a set of callers is provided separately. +struct InvocationRewriteRule { + // Replace invocations whose resolved callee is exactly this function. + const Function* from_callee = nullptr; + + // The function that should be used as the new callee. + const Function* to_callee = nullptr; + + // Optional filter: only invocations whose callee-side ParametricEnv equals + // this value will be replaced. If not set, matches all instantiations. + std::optional match_callee_env; + + // Optional explicit env for the replacement callee. If not set, retains the + // original invocation's explicit parametrics. If set to an empty env, emits + // no explicit parametrics (rely on deduction). + std::optional to_callee_env; +}; + +// Returns a cloned module where invocations inside `callers` that resolve to +// a rule's `from_callee` (and optionally match the specified callee-side +// parametric environment) have their callee expression replaced so they invoke +// the corresponding `to_callee` instead. +// +// Bulk variant: applies multiple rewrite rules across a set of caller +// functions. The first rule that matches an invocation is applied. The +// `type_info` is consulted to resolve parametric binding information for each +// invocation. +absl::StatusOr ReplaceInvocationsInModule( + const TypecheckedModule& tm, absl::Span callers, + absl::Span rules, ImportData& import_data, + std::string_view install_subject); + +// Non-bulk convenience overload that delegates to the bulk variant. +absl::StatusOr ReplaceInvocationsInModule( + const TypecheckedModule& tm, const Function* caller, + const InvocationRewriteRule& rule, ImportData& import_data, + std::string_view install_subject); + +} // namespace xls::dslx + +#endif // XLS_DSLX_REPLACE_INVOCATIONS_H_ + + + + diff --git a/xls/dslx/replace_invocations_test.cc b/xls/dslx/replace_invocations_test.cc new file mode 100644 index 0000000000..0fdfe5b5e4 --- /dev/null +++ b/xls/dslx/replace_invocations_test.cc @@ -0,0 +1,1236 @@ +// Copyright 2025 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xls/dslx/replace_invocations.h" + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "xls/common/status/matchers.h" +#include "xls/dslx/create_import_data.h" +#include "xls/dslx/default_dslx_stdlib_path.h" +#include "xls/dslx/frontend/ast.h" +#include "xls/dslx/frontend/ast_utils.h" +#include "xls/dslx/frontend/module.h" +#include "xls/dslx/parse_and_typecheck.h" +#include "xls/dslx/type_system/parametric_env.h" +#include "xls/dslx/type_system/type_info.h" +#include "xls/dslx/virtualizable_file_system.h" +#include "xls/ir/bits.h" + +namespace xls::dslx { +namespace { + +using ::absl_testing::StatusIs; + +struct PT { + std::unique_ptr import_data; + TypecheckedModule tm; +}; + +absl::StatusOr ParseTypecheck(std::string text) { + std::filesystem::path stdlib = std::string(::xls::kDefaultDslxStdlibPath); + auto import_data = std::make_unique(CreateImportData( + stdlib, /*additional_search_paths=*/std::vector{}, + kAllWarningsSet, std::make_unique())); + XLS_ASSIGN_OR_RETURN( + TypecheckedModule tm, + ParseAndTypecheck(text, /*path=*/"test.x", /*module_name=*/"test", + import_data.get())); + return PT{.import_data = std::move(import_data), .tm = std::move(tm)}; +} + +TEST(ReplaceInvocationsTest, NonParametricSimpleReplacement) { + const std::string kText = R"(// test +fn a(x: u32) -> u32 { x + u32:1 } +fn b(x: u32) -> u32 { x + u32:2 } +fn caller(x: u32) -> u32 { b(x) + b(x) } +)"; + XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + Module* m = pt.tm.module; + + ASSERT_NE(m->GetFunction("caller"), std::nullopt); + ASSERT_NE(m->GetFunction("a"), std::nullopt); + ASSERT_NE(m->GetFunction("b"), std::nullopt); + Function* caller = m->GetFunction("caller").value(); + Function* a = m->GetFunction("a").value(); + Function* b = m->GetFunction("b").value(); + + InvocationRewriteRule rule; + rule.from_callee = b; + rule.to_callee = a; + + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule new_tm, + ReplaceInvocationsInModule(pt.tm, caller, rule, *pt.import_data, + "test.rw")); + Module* new_module = new_tm.module; + + ASSERT_NE(new_module->GetFunction("caller"), std::nullopt); + Function* caller_new = new_module->GetFunction("caller").value(); + + int b_uses = 0; + int a_uses = 0; + XLS_ASSERT_OK_AND_ASSIGN( + auto nodes, CollectUnder(caller_new->body(), /*want_types=*/false)); + for (AstNode* n : nodes) { + auto* inv = dynamic_cast(n); + if (inv == nullptr) continue; + std::string callee_s = inv->callee()->ToString(); + if (callee_s == "a") a_uses++; + if (callee_s == "b") b_uses++; + } + EXPECT_EQ(a_uses, 2); + EXPECT_EQ(b_uses, 0); +} + +TEST(ReplaceInvocationsTest, ParametricFilterMatchesOnlyOne) { + const std::string kText = R"(// test +fn id(x: uN[N]) -> uN[N] { x } +fn id2(x: uN[N]) -> uN[N] { x } +fn caller() -> (u8, u16) { + let y8 = id(u8:1); + let y16 = id(u16:2); + (y8, y16) +} +)"; + XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + Module* m = pt.tm.module; + + ASSERT_NE(m->GetFunction("caller"), std::nullopt); + ASSERT_NE(m->GetFunction("id"), std::nullopt); + ASSERT_NE(m->GetFunction("id2"), std::nullopt); + Function* caller = m->GetFunction("caller").value(); + Function* id = m->GetFunction("id").value(); + Function* id2 = m->GetFunction("id2").value(); + + InvocationRewriteRule rule; + rule.from_callee = id; + rule.to_callee = id2; + rule.match_callee_env = + ParametricEnv(absl::flat_hash_map{ + {"N", InterpValue::MakeUBits(32, 8)}}); + + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule new_tm, + ReplaceInvocationsInModule(pt.tm, caller, rule, *pt.import_data, + "test.rw")); + Module* new_module = new_tm.module; + + ASSERT_NE(new_module->GetFunction("caller"), std::nullopt); + Function* caller_new = new_module->GetFunction("caller").value(); + + int id_uses = 0; + int id2_uses = 0; + XLS_ASSERT_OK_AND_ASSIGN( + auto nodes, CollectUnder(caller_new->body(), /*want_types=*/false)); + for (AstNode* n : nodes) { + auto* inv = dynamic_cast(n); + if (inv == nullptr) continue; + std::string callee_s = inv->callee()->ToString(); + if (callee_s == "id") id_uses++; + if (callee_s == "id2") id2_uses++; + } + EXPECT_EQ(id2_uses, 1); + EXPECT_EQ(id_uses, 1); +} + +TEST(ReplaceInvocationsTest, + ParametricReplacementNoToEnvRetainsExplicitParams) { + const std::string kText = R"(// test +fn id(x: uN[N]) -> uN[N] { x } +fn id2(x: uN[N]) -> uN[N] { x } +fn caller() -> u8 { id(u8:1) } +)"; + XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + Module* m = pt.tm.module; + + Function* caller = m->GetFunction("caller").value(); + Function* id = m->GetFunction("id").value(); + Function* id2 = m->GetFunction("id2").value(); + + InvocationRewriteRule rule; + rule.from_callee = id; + rule.to_callee = id2; + rule.match_callee_env = + ParametricEnv(absl::flat_hash_map{ + {"N", InterpValue::MakeUBits(32, 8)}}); + + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule new_tm, + ReplaceInvocationsInModule(pt.tm, caller, rule, *pt.import_data, + "test.rw")); + Module* new_module = new_tm.module; + + Function* caller_new = new_module->GetFunction("caller").value(); + int num_checked = 0; + XLS_ASSERT_OK_AND_ASSIGN( + auto nodes, CollectUnder(caller_new->body(), /*want_types=*/false)); + for (AstNode* n : nodes) { + auto* inv = dynamic_cast(n); + if (inv == nullptr) continue; + if (inv->callee()->ToString() != "id2") continue; + EXPECT_FALSE(inv->explicit_parametrics().empty()); + num_checked++; + } + EXPECT_EQ(num_checked, 1); +} + +TEST(ReplaceInvocationsTest, + ParametricReplacementEmptyToEnvDropsExplicitParams) { + const std::string kText = R"(// test +fn id(x: uN[N]) -> uN[N] { x } +fn id2(x: uN[N]) -> uN[N] { x } +fn caller() -> u8 { id(u8:1) } +)"; + XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + Module* m = pt.tm.module; + + Function* caller = m->GetFunction("caller").value(); + Function* id = m->GetFunction("id").value(); + Function* id2 = m->GetFunction("id2").value(); + + InvocationRewriteRule rule; + rule.from_callee = id; + rule.to_callee = id2; + rule.to_callee_env = ParametricEnv(); + + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule new_tm, + ReplaceInvocationsInModule(pt.tm, caller, rule, *pt.import_data, + "test.rw")); + Module* new_module = new_tm.module; + + Function* caller_new = new_module->GetFunction("caller").value(); + int num_checked = 0; + XLS_ASSERT_OK_AND_ASSIGN( + auto nodes, CollectUnder(caller_new->body(), /*want_types=*/false)); + for (AstNode* n : nodes) { + auto* inv = dynamic_cast(n); + if (inv == nullptr) continue; + if (inv->callee()->ToString() != "id2") continue; + EXPECT_TRUE(inv->explicit_parametrics().empty()); + num_checked++; + } + EXPECT_EQ(num_checked, 1); +} + +TEST(ReplaceInvocationsTest, ParametricReplacementWithDeductionWorks) { + const std::string kText = R"(// test +fn id(x: uN[N]) -> uN[N] { x } +fn id2(x: uN[N]) -> uN[N] { x } +fn caller(x: u32) -> u32 { id(x) } +)"; + XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + Module* m = pt.tm.module; + + Function* caller = m->GetFunction("caller").value(); + Function* id = m->GetFunction("id").value(); + Function* id2 = m->GetFunction("id2").value(); + + InvocationRewriteRule rule; + rule.from_callee = id; + rule.to_callee = id2; + rule.match_callee_env = + ParametricEnv(absl::flat_hash_map{ + {"N", InterpValue::MakeUBits(32, 32)}}); + + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule new_tm, + ReplaceInvocationsInModule(pt.tm, caller, rule, *pt.import_data, + "test.rw")); + Module* new_module = new_tm.module; + Function* caller_new = new_module->GetFunction("caller").value(); + XLS_ASSERT_OK_AND_ASSIGN( + auto nodes, CollectUnder(caller_new->body(), /*want_types=*/false)); + int num_id2 = 0; + for (AstNode* n : nodes) { + auto* inv = dynamic_cast(n); + if (inv == nullptr) continue; + if (inv->callee()->ToString() == "id2") { + EXPECT_TRUE(inv->explicit_parametrics().empty()); + num_id2++; + } + } + EXPECT_EQ(num_id2, 1); +} + +TEST(ReplaceInvocationsTest, EmptyMatchEnvMatchesOnlyNonParamCallee) { + const std::string kText = R"(// test +fn a(x: u32) -> u32 { x + u32:1 } +fn b(x: u32) -> u32 { x + u32:2 } +fn caller(x: u32) -> u32 { b(x) + b(x) } +)"; + XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + Module* m = pt.tm.module; + + Function* caller = m->GetFunction("caller").value(); + Function* a = m->GetFunction("a").value(); + Function* b = m->GetFunction("b").value(); + + InvocationRewriteRule rule; + rule.from_callee = b; + rule.to_callee = a; + rule.match_callee_env = ParametricEnv(); + + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule new_tm, + ReplaceInvocationsInModule(pt.tm, caller, rule, *pt.import_data, + "test.rw")); + Module* new_module = new_tm.module; + Function* caller_new = new_module->GetFunction("caller").value(); + int a_uses = 0; + int b_uses = 0; + XLS_ASSERT_OK_AND_ASSIGN( + auto nodes, CollectUnder(caller_new->body(), /*want_types=*/false)); + for (AstNode* n : nodes) { + auto* inv = dynamic_cast(n); + if (inv == nullptr) continue; + std::string callee_s = inv->callee()->ToString(); + if (callee_s == "a") a_uses++; + if (callee_s == "b") b_uses++; + } + EXPECT_EQ(a_uses, 2); + EXPECT_EQ(b_uses, 0); +} + +TEST(ReplaceInvocationsTest, EmptyMatchEnvDoesNotMatchParametricCallee) { + const std::string kText = R"(// test +fn id(x: uN[N]) -> uN[N] { x } +fn caller() -> (u8, u16) { (id(u8:1), id(u16:2)) } +fn id2(x: uN[N]) -> uN[N] { x } +)"; + XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + Module* m = pt.tm.module; + + Function* caller = m->GetFunction("caller").value(); + Function* id = m->GetFunction("id").value(); + Function* id2 = m->GetFunction("id2").value(); + + InvocationRewriteRule rule; + rule.from_callee = id; + rule.to_callee = id2; + rule.match_callee_env = ParametricEnv(); + + auto status_or = + ReplaceInvocationsInModule(pt.tm, caller, rule, *pt.import_data, + "test.rw"); + EXPECT_THAT(status_or, StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ReplaceInvocationsTest, ParametricFilterMatchesOnlyOne_Deduced) { + const std::string kText = R"(// test +fn id(x: uN[N]) -> uN[N] { x } +fn id2(x: uN[N]) -> uN[N] { x } +fn caller() -> (u8, u16) { + let a: u8 = u8:1; + let b: u16 = u16:2; + (id(a), id(b)) +} +)"; + XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + Module* m = pt.tm.module; + + Function* caller = m->GetFunction("caller").value(); + Function* id = m->GetFunction("id").value(); + Function* id2 = m->GetFunction("id2").value(); + + InvocationRewriteRule rule2; + rule2.from_callee = id; + rule2.to_callee = id2; + rule2.match_callee_env = + ParametricEnv(absl::flat_hash_map{ + {"N", InterpValue::MakeUBits(32, 8)}}); + + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule new_tm, + ReplaceInvocationsInModule(pt.tm, caller, rule2, *pt.import_data, + "test.rw")); + Module* new_module = new_tm.module; + Function* caller_new = new_module->GetFunction("caller").value(); + int id_uses = 0; + int id2_uses = 0; + XLS_ASSERT_OK_AND_ASSIGN( + auto nodes, CollectUnder(caller_new->body(), /*want_types=*/false)); + for (AstNode* n : nodes) { + auto* inv = dynamic_cast(n); + if (inv == nullptr) continue; + std::string callee_s = inv->callee()->ToString(); + if (callee_s == "id") id_uses++; + if (callee_s == "id2") id2_uses++; + } + EXPECT_EQ(id2_uses, 1); + EXPECT_EQ(id_uses, 1); +} + +TEST(ReplaceInvocationsTest, ParametricEnumFilterAndExplicitReplacement) { + const std::string kText = R"(// test +enum E : u2 { + A = 0, + B = 1, +} +fn f(x: u32) -> u32 { x } +fn g(x: u32) -> u32 { x } +fn caller(x: u32) -> (u32, u32) { + let a = f(x); + let b = f(x); + (a, b) +} +)"; + XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + Module* m = pt.tm.module; + + Function* caller = m->GetFunction("caller").value(); + Function* f = m->GetFunction("f").value(); + Function* g = m->GetFunction("g").value(); + + XLS_ASSERT_OK_AND_ASSIGN(TypeDefinition td, m->GetTypeDefinition("E")); + EnumDef* e_def = std::get(td); + + InterpValue enum_a = + InterpValue::MakeEnum(xls::UBits(/*value=*/0, /*bit_count=*/2), + /*is_signed=*/false, e_def); + + InvocationRewriteRule rule; + rule.from_callee = f; + rule.to_callee = g; + rule.match_callee_env = ParametricEnv( + absl::flat_hash_map{{"N", enum_a}}); + rule.to_callee_env = rule.match_callee_env; + + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule new_tm, + ReplaceInvocationsInModule(pt.tm, caller, rule, *pt.import_data, + "test.rw")); + Module* new_module = new_tm.module; + Function* caller_new = new_module->GetFunction("caller").value(); + int count_f = 0; + int count_g = 0; + XLS_ASSERT_OK_AND_ASSIGN( + auto nodes, CollectUnder(caller_new->body(), /*want_types=*/false)); + for (AstNode* n : nodes) { + auto* inv = dynamic_cast(n); + if (inv == nullptr) continue; + std::string callee_s = inv->callee()->ToString(); + if (callee_s == "f") count_f++; + if (callee_s == "g") { + count_g++; + EXPECT_EQ(inv->explicit_parametrics().size(), 1); + } + } + EXPECT_EQ(count_g, 1); + EXPECT_EQ(count_f, 1); +} + +TEST(ReplaceInvocationsTest, ParametricEnumAliasMatchWorks) { + const std::string kText = R"(// test +enum E : u2 { + A = 1, + B = 1, +} +fn f(x: u32) -> u32 { x } +fn g(x: u32) -> u32 { x } +fn caller(x: u32) -> u32 { f(x) } +)"; + XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + Module* m = pt.tm.module; + + Function* caller = m->GetFunction("caller").value(); + Function* f = m->GetFunction("f").value(); + Function* g = m->GetFunction("g").value(); + + XLS_ASSERT_OK_AND_ASSIGN(TypeDefinition td, m->GetTypeDefinition("E")); + EnumDef* e_def = std::get(td); + InterpValue enum_a = + InterpValue::MakeEnum(xls::UBits(/*value=*/1, /*bit_count=*/2), + /*is_signed=*/false, e_def); + + InvocationRewriteRule rule; + rule.from_callee = f; + rule.to_callee = g; + rule.match_callee_env = ParametricEnv( + absl::flat_hash_map{{"N", enum_a}}); + + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule new_tm, + ReplaceInvocationsInModule(pt.tm, caller, rule, *pt.import_data, + "test.rw")); + Module* new_module = new_tm.module; + Function* caller_new = new_module->GetFunction("caller").value(); + int count_f = 0; + int count_g = 0; + XLS_ASSERT_OK_AND_ASSIGN( + auto nodes, CollectUnder(caller_new->body(), /*want_types=*/false)); + for (AstNode* n : nodes) { + auto* inv = dynamic_cast(n); + if (inv == nullptr) continue; + std::string callee_s = inv->callee()->ToString(); + if (callee_s == "f") count_f++; + if (callee_s == "g") { + count_g++; + EXPECT_EQ(inv->explicit_parametrics().size(), 1); + } + } + EXPECT_EQ(count_g, 1); + EXPECT_EQ(count_f, 0); +} + +TEST(ReplaceInvocationsTest, ParametricEnumExplicitReplacementCrossModule) { + const std::string kBase = R"(// base.x +pub enum E : u2 { + A = 0, + B = 1, +} +)"; + const std::string kTest = R"(// test.x +import base; +fn f(x: u32) -> u32 { x } +fn g(x: u32) -> u32 { x } +fn caller(x: u32) -> u32 { f(x) } +)"; + + const std::string base_path = "/mem/base.x"; + const std::string test_path = "/mem/test.x"; + absl::flat_hash_map files; + files.emplace(base_path, kBase); + files.emplace(test_path, kTest); + auto vfs = std::make_unique(std::move(files), std::filesystem::path("/mem")); + + auto import_data = std::make_unique( + CreateImportDataForTest(std::move(vfs), kAllWarningsSet)); + + XLS_ASSERT_OK_AND_ASSIGN(TypecheckedModule tm, + ParseAndTypecheck(kTest, /*path=*/test_path, + /*module_name=*/"test", + import_data.get())); + Module* m = tm.module; + + Function* caller = m->GetFunction("caller").value(); + Function* f = m->GetFunction("f").value(); + Function* g = m->GetFunction("g").value(); + + ASSERT_FALSE(g->parametric_bindings().empty()); + const ParametricBinding* pb = g->parametric_bindings()[0]; + auto* trta = dynamic_cast(pb->type_annotation()); + ASSERT_NE(trta, nullptr); + TypeRef* tr = trta->type_ref(); + XLS_ASSERT_OK_AND_ASSIGN( + TypeInfo::TypeSource ts, + tm.type_info->ResolveTypeDefinition(tr->type_definition())); + auto* enum_def = std::get(ts.definition); + ASSERT_NE(enum_def, nullptr); + InterpValue enum_a = + InterpValue::MakeEnum(xls::UBits(/*value=*/0, /*bit_count=*/2), + /*is_signed=*/false, enum_def); + + InvocationRewriteRule rule; + rule.from_callee = f; + rule.to_callee = g; + rule.to_callee_env = ParametricEnv( + absl::flat_hash_map{{"N", enum_a}}); + + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule new_tm, + ReplaceInvocationsInModule(tm, caller, rule, *import_data, "test.rw")); + (void)new_tm; +} + +TEST(ReplaceInvocationsTest, + ParametricEnumExplicitReplacementCrossModuleQualifiedSubject) { + const std::string kBase = R"(// base.x +pub enum E : u2 { + A = 0, + B = 1, +} +)"; + const std::string kTest = R"(// test.x +import base; +fn f(x: u32) -> u32 { x } +fn g(x: u32) -> u32 { x } +fn caller(x: u32) -> u32 { f(x) } +)"; + + const std::string base_path = "/mem/base.x"; + const std::string test_path = "/mem/test.x"; + absl::flat_hash_map files; + files.emplace(base_path, kBase); + files.emplace(test_path, kTest); + auto vfs = std::make_unique(std::move(files), std::filesystem::path("/mem")); + + auto import_data = std::make_unique( + CreateImportDataForTest(std::move(vfs), kAllWarningsSet)); + + XLS_ASSERT_OK_AND_ASSIGN(TypecheckedModule tm, + ParseAndTypecheck(kTest, /*path=*/test_path, + /*module_name=*/"test", + import_data.get())); + Module* m = tm.module; + + Function* caller = m->GetFunction("caller").value(); + Function* f = m->GetFunction("f").value(); + Function* g = m->GetFunction("g").value(); + + ASSERT_FALSE(g->parametric_bindings().empty()); + const ParametricBinding* pb = g->parametric_bindings()[0]; + auto* trta = dynamic_cast(pb->type_annotation()); + ASSERT_NE(trta, nullptr); + TypeRef* tr = trta->type_ref(); + XLS_ASSERT_OK_AND_ASSIGN(TypeInfo::TypeSource ts, + tm.type_info->ResolveTypeDefinition(tr->type_definition())); + auto* enum_def = std::get(ts.definition); + ASSERT_NE(enum_def, nullptr); + InterpValue enum_a = + InterpValue::MakeEnum(xls::UBits(/*value=*/0, /*bit_count=*/2), + /*is_signed=*/false, enum_def); + + InvocationRewriteRule rule; + rule.from_callee = f; + rule.to_callee = g; + rule.to_callee_env = ParametricEnv( + absl::flat_hash_map{{"N", enum_a}}); + + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule new_tm, + ReplaceInvocationsInModule(tm, caller, rule, *import_data, "test.rw")); + Module* new_module = new_tm.module; + Function* caller_new = new_module->GetFunction("caller").value(); + + int num_checked = 0; + XLS_ASSERT_OK_AND_ASSIGN(auto nodes, CollectUnder(caller_new->body(), /*want_types=*/false)); + for (AstNode* n : nodes) { + auto* inv = dynamic_cast(n); + if (inv == nullptr) continue; + if (inv->callee()->ToString() != "g") continue; + EXPECT_EQ(inv->explicit_parametrics().size(), 1); + // The subject should be a qualified colon ref: base::E::A + EXPECT_NE(inv->ToString().find("base::E::A"), std::string::npos); + num_checked++; + } + EXPECT_EQ(num_checked, 1); +} + +TEST(ReplaceInvocationsTest, ReplaceAllParametricWhenNoMatchEnvExplicit) { + const std::string kText = R"(// test +fn id(x: uN[N]) -> uN[N] { x } +fn id2(x: uN[N]) -> uN[N] { x } +fn caller() -> (u8, u16) { (id(u8:1), id(u16:2)) } +)"; + XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + Module* m = pt.tm.module; + + Function* caller = m->GetFunction("caller").value(); + Function* id = m->GetFunction("id").value(); + Function* id2 = m->GetFunction("id2").value(); + + InvocationRewriteRule rule; + rule.from_callee = id; + rule.to_callee = id2; + + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule new_tm, + ReplaceInvocationsInModule(pt.tm, caller, rule, *pt.import_data, + "test.rw")); + Module* new_module = new_tm.module; + Function* caller_new = new_module->GetFunction("caller").value(); + int id2_uses = 0; + int with_explicit = 0; + XLS_ASSERT_OK_AND_ASSIGN( + auto nodes, CollectUnder(caller_new->body(), /*want_types=*/false)); + for (AstNode* n : nodes) { + auto* inv = dynamic_cast(n); + if (inv == nullptr) continue; + if (inv->callee()->ToString() == "id2") { + id2_uses++; + if (!inv->explicit_parametrics().empty()) with_explicit++; + } + } + EXPECT_EQ(id2_uses, 2); + EXPECT_EQ(with_explicit, 2); +} + +TEST(ReplaceInvocationsTest, ReplaceAllParametricWhenNoMatchEnvDeduced) { + const std::string kText = R"(// test +fn id(x: uN[N]) -> uN[N] { x } +fn id2(x: uN[N]) -> uN[N] { x } +fn caller() -> (u8, u16) { + let a: u8 = u8:1; + let b: u16 = u16:2; + (id(a), id(b)) +} +)"; + XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + Module* m = pt.tm.module; + + Function* caller = m->GetFunction("caller").value(); + Function* id = m->GetFunction("id").value(); + Function* id2 = m->GetFunction("id2").value(); + + InvocationRewriteRule rule; + rule.from_callee = id; + rule.to_callee = id2; + + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule new_tm, + ReplaceInvocationsInModule(pt.tm, caller, rule, *pt.import_data, + "test.rw")); + Module* new_module = new_tm.module; + Function* caller_new = new_module->GetFunction("caller").value(); + int id2_uses = 0; + int num_with_params = 0; + XLS_ASSERT_OK_AND_ASSIGN( + auto nodes, CollectUnder(caller_new->body(), /*want_types=*/false)); + for (AstNode* n : nodes) { + auto* inv = dynamic_cast(n); + if (inv == nullptr) continue; + if (inv->callee()->ToString() == "id2") { + id2_uses++; + if (!inv->explicit_parametrics().empty()) num_with_params++; + } + } + EXPECT_EQ(id2_uses, 2); + EXPECT_EQ(num_with_params, 0); +} + +TEST(ReplaceInvocationsTest, BitsExplicitParamUnsigned) { + const std::string kText = R"(// test +fn f(x: u32) -> u32 { x } +fn g(x: u32) -> u32 { x } +fn caller(x: u32) -> u32 { f(x) } +)"; + XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + Module* m = pt.tm.module; + + Function* caller = m->GetFunction("caller").value(); + Function* f = m->GetFunction("f").value(); + Function* g = m->GetFunction("g").value(); + + InvocationRewriteRule rule; + rule.from_callee = f; + rule.to_callee = g; + rule.match_callee_env = + ParametricEnv(absl::flat_hash_map{ + {"N", InterpValue::MakeUBits(32, 1)}}); + rule.to_callee_env = + ParametricEnv(absl::flat_hash_map{ + {"N", InterpValue::MakeUBits(32, 8)}}); + + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule new_tm, + ReplaceInvocationsInModule(pt.tm, caller, rule, *pt.import_data, + "test.rw")); + Module* new_module = new_tm.module; + Function* caller_new = new_module->GetFunction("caller").value(); + int num_checked = 0; + XLS_ASSERT_OK_AND_ASSIGN( + auto nodes, CollectUnder(caller_new->body(), /*want_types=*/false)); + for (AstNode* n : nodes) { + auto* inv = dynamic_cast(n); + if (inv == nullptr) continue; + if (inv->callee()->ToString() != "g") continue; + EXPECT_EQ(inv->explicit_parametrics().size(), 1); + EXPECT_NE(inv->ToString().find("u32:8"), std::string::npos); + num_checked++; + } + EXPECT_EQ(num_checked, 1); +} + +TEST(ReplaceInvocationsTest, MatchEnvOrderIrrelevant) { + const std::string kText = R"(// test +fn id(x: u32) -> u32 { x } +fn id2(x: u32) -> u32 { x } +fn caller() -> u32 { id(u32:0) } +)"; + XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + Module* m = pt.tm.module; + + Function* caller = m->GetFunction("caller").value(); + Function* id = m->GetFunction("id").value(); + Function* id2 = m->GetFunction("id2").value(); + + InvocationRewriteRule rule2; + rule2.from_callee = id; + rule2.to_callee = id2; + rule2.match_callee_env = + ParametricEnv(absl::flat_hash_map{ + {"A", InterpValue::MakeUBits(32, 2)}, + {"B", InterpValue::MakeUBits(32, 1)}}); + + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule new_tm, + ReplaceInvocationsInModule(pt.tm, caller, rule2, *pt.import_data, + "test.rw")); + Module* new_module = new_tm.module; + Function* caller_new = new_module->GetFunction("caller").value(); + int id_uses = 0; + int id2_uses = 0; + XLS_ASSERT_OK_AND_ASSIGN( + auto nodes, CollectUnder(caller_new->body(), /*want_types=*/false)); + for (AstNode* n : nodes) { + auto* inv = dynamic_cast(n); + if (inv == nullptr) continue; + std::string s = inv->callee()->ToString(); + if (s == "id") id_uses++; + if (s == "id2") id2_uses++; + } + EXPECT_EQ(id_uses, 0); + EXPECT_EQ(id2_uses, 1); +} + +TEST(ReplaceInvocationsTest, ToEnvMissingRequiredBindingErrors) { + const std::string kText = R"(// test +fn f(x: u32) -> u32 { x } +fn g(x: u32) -> u32 { x } +fn caller() -> u32 { f(u32:0) } +)"; + XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + Module* m = pt.tm.module; + + Function* caller = m->GetFunction("caller").value(); + Function* f = m->GetFunction("f").value(); + Function* g = m->GetFunction("g").value(); + + InvocationRewriteRule rule; + rule.from_callee = f; + rule.to_callee = g; + rule.to_callee_env = + ParametricEnv(absl::flat_hash_map{ + {"K", InterpValue::MakeUBits(32, 1)}}); + + auto status_or = + ReplaceInvocationsInModule(pt.tm, caller, rule, *pt.import_data, + "test.rw"); + EXPECT_THAT(status_or, StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ReplaceInvocationsTest, ToEnvUnknownBindingErrors) { + const std::string kText = R"(// test +fn f(x: u32) -> u32 { x } +fn g(x: u32) -> u32 { x } +fn caller() -> u32 { f(u32:0) } +)"; + XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + Module* m = pt.tm.module; + + Function* caller = m->GetFunction("caller").value(); + Function* f = m->GetFunction("f").value(); + Function* g = m->GetFunction("g").value(); + + InvocationRewriteRule rule; + rule.from_callee = f; + rule.to_callee = g; + rule.to_callee_env = ParametricEnv( + absl::flat_hash_map{{"M", InterpValue::MakeUBits(32, 1)}, + {"K", InterpValue::MakeUBits(32, 1)}, + {"Z", InterpValue::MakeUBits(32, 5)}}); + + auto status_or = + ReplaceInvocationsInModule(pt.tm, caller, rule, *pt.import_data, + "test.rw"); + EXPECT_THAT(status_or, StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ReplaceInvocationsTest, EnumToEnvNoMemberErrors) { + const std::string kText = R"(// test +enum E : u2 { + A = 0, + B = 1, +} +fn f(x: u32) -> u32 { x } +fn g(x: u32) -> u32 { x } +fn caller(x: u32) -> u32 { f(x) } +)"; + XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + Module* m = pt.tm.module; + + Function* caller = m->GetFunction("caller").value(); + Function* f = m->GetFunction("f").value(); + Function* g = m->GetFunction("g").value(); + + XLS_ASSERT_OK_AND_ASSIGN(TypeDefinition td, m->GetTypeDefinition("E")); + EnumDef* e_def = std::get(td); + InterpValue enum_bad = + InterpValue::MakeEnum(xls::UBits(/*value=*/2, /*bit_count=*/2), + /*is_signed=*/false, e_def); + + InvocationRewriteRule rule; + rule.from_callee = f; + rule.to_callee = g; + rule.to_callee_env = ParametricEnv( + absl::flat_hash_map{{"N", enum_bad}}); + + auto status_or = + ReplaceInvocationsInModule(pt.tm, caller, rule, *pt.import_data, + "test.rw"); + EXPECT_THAT(status_or, StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ReplaceInvocationsTest, BulkMultipleCallersMultipleRules) { + const std::string kText = R"(// test +fn f(x: u32) -> u32 { x + u32:1 } +fn g(x: u32) -> u32 { x + u32:2 } +fn h(x: u32) -> u32 { x + u32:3 } +fn caller1(x: u32) -> u32 { f(x) + g(x) } +fn caller2(x: u32) -> u32 { f(x) + h(x) } +)"; + XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + Module* m = pt.tm.module; + + Function* caller1 = m->GetFunction("caller1").value(); + Function* caller2 = m->GetFunction("caller2").value(); + Function* f = m->GetFunction("f").value(); + Function* g = m->GetFunction("g").value(); + Function* h = m->GetFunction("h").value(); + + std::vector callers{caller1, caller2}; + std::vector rules; + rules.push_back(InvocationRewriteRule{.from_callee = f, .to_callee = g}); + rules.push_back(InvocationRewriteRule{.from_callee = g, .to_callee = h}); + + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule new_tm, + ReplaceInvocationsInModule(pt.tm, absl::MakeSpan(callers), + absl::MakeSpan(rules), *pt.import_data, + "test.rw")); + Module* new_module = new_tm.module; + + auto check_counts = [&](std::string caller_name, int expect_f, int expect_g, + int expect_h) { + Function* c = new_module->GetFunction(caller_name).value(); + int f_uses = 0, g_uses = 0, h_uses = 0; + XLS_ASSERT_OK_AND_ASSIGN(auto nodes, + CollectUnder(c->body(), /*want_types=*/false)); + for (AstNode* n : nodes) { + auto* inv = dynamic_cast(n); + if (inv == nullptr) continue; + std::string s = inv->callee()->ToString(); + if (s == "f") f_uses++; + if (s == "g") g_uses++; + if (s == "h") h_uses++; + } + EXPECT_EQ(f_uses, expect_f); + EXPECT_EQ(g_uses, expect_g); + EXPECT_EQ(h_uses, expect_h); + }; + + check_counts("caller1", /*f=*/0, /*g=*/1, /*h=*/1); + check_counts("caller2", /*f=*/0, /*g=*/1, /*h=*/1); +} + +TEST(ReplaceInvocationsTest, BulkParametricMatchAcrossCallers) { + const std::string kText = R"(// test +fn id(x: uN[N]) -> uN[N] { x } +fn id2(x: uN[N]) -> uN[N] { x } +fn id3(x: uN[N]) -> uN[N] { x } +fn caller1() -> (u8, u16) { (id(u8:1), id(u16:2)) } +fn caller2() -> (u16, u8) { (id(u16:3), id(u8:4)) } +)"; + XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + Module* m = pt.tm.module; + + Function* caller1 = m->GetFunction("caller1").value(); + Function* caller2 = m->GetFunction("caller2").value(); + Function* id = m->GetFunction("id").value(); + Function* id2 = m->GetFunction("id2").value(); + Function* id3 = m->GetFunction("id3").value(); + + std::vector callers{caller1, caller2}; + std::vector rules; + rules.push_back(InvocationRewriteRule{ + .from_callee = id, + .to_callee = id2, + .match_callee_env = + ParametricEnv(absl::flat_hash_map{ + {"N", InterpValue::MakeUBits(32, 8)}})}); + rules.push_back(InvocationRewriteRule{ + .from_callee = id, + .to_callee = id3, + .match_callee_env = + ParametricEnv(absl::flat_hash_map{ + {"N", InterpValue::MakeUBits(32, 16)}})}); + + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule new_tm, + ReplaceInvocationsInModule(pt.tm, absl::MakeSpan(callers), + absl::MakeSpan(rules), *pt.import_data, + "test.rw")); + Module* new_module = new_tm.module; + + auto count_in_caller = + [&](std::string caller_name) -> std::tuple { + Function* c = new_module->GetFunction(caller_name).value(); + int id_uses = 0, id2_uses = 0, id3_uses = 0; + auto nodes_or = CollectUnder(c->body(), /*want_types=*/false); + EXPECT_TRUE(nodes_or.ok()); + auto nodes = nodes_or.value(); + for (AstNode* n : nodes) { + auto* inv = dynamic_cast(n); + if (inv == nullptr) continue; + std::string s = inv->callee()->ToString(); + if (s == "id") id_uses++; + if (s == "id2") id2_uses++; + if (s == "id3") id3_uses++; + } + return std::tuple{id_uses, id2_uses, id3_uses}; + }; + + auto [u1, u2, u3] = count_in_caller("caller1"); + EXPECT_EQ(u1, 0); + EXPECT_EQ(u2, 1); + EXPECT_EQ(u3, 1); + + auto [v1, v2, v3] = count_in_caller("caller2"); + EXPECT_EQ(v1, 0); + EXPECT_EQ(v2, 1); + EXPECT_EQ(v3, 1); +} + +TEST(ReplaceInvocationsTest, ParametricEnumTypeAliasLocalExplicitReplacement) { + const std::string kText = R"(// test +enum E : u2 { + A = 0, + B = 1, +} +type Alias = E; +fn f(x: u32) -> u32 { x } +fn g(x: u32) -> u32 { x } +fn caller(x: u32) -> u32 { f(x) } +)"; + XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + Module* m = pt.tm.module; + + Function* caller = m->GetFunction("caller").value(); + Function* f = m->GetFunction("f").value(); + Function* g = m->GetFunction("g").value(); + + XLS_ASSERT_OK_AND_ASSIGN(TypeDefinition td, m->GetTypeDefinition("E")); + EnumDef* e_def = std::get(td); + InterpValue enum_a = + InterpValue::MakeEnum(xls::UBits(/*value=*/0, /*bit_count=*/2), + /*is_signed=*/false, e_def); + + InvocationRewriteRule rule; + rule.from_callee = f; + rule.to_callee = g; + rule.to_callee_env = ParametricEnv( + absl::flat_hash_map{{"N", enum_a}}); + + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule new_tm, + ReplaceInvocationsInModule(pt.tm, caller, rule, *pt.import_data, + "test.rw")); + Module* new_module = new_tm.module; + Function* caller_new = new_module->GetFunction("caller").value(); + int num_checked = 0; + XLS_ASSERT_OK_AND_ASSIGN(auto nodes, CollectUnder(caller_new->body(), /*want_types=*/false)); + for (AstNode* n : nodes) { + auto* inv = dynamic_cast(n); + if (inv == nullptr) continue; + if (inv->callee()->ToString() != "g") continue; + EXPECT_EQ(inv->explicit_parametrics().size(), 1); + EXPECT_NE(inv->ToString().find("Alias::A"), std::string::npos); + num_checked++; + } + EXPECT_EQ(num_checked, 1); +} + +TEST(ReplaceInvocationsTest, ParametricEnumTypeAliasCrossModuleExplicitReplacement) { + const std::string kBase = R"(// base.x +pub enum E : u2 { + A = 0, + B = 1, +} +)"; + const std::string kTest = R"(// test.x +import base; +type Alias = base::E; +fn f(x: u32) -> u32 { x } +fn g(x: u32) -> u32 { x } +fn caller(x: u32) -> u32 { f(x) } +)"; + + const std::string base_path = "/mem/base.x"; + const std::string test_path = "/mem/test.x"; + absl::flat_hash_map files; + files.emplace(base_path, kBase); + files.emplace(test_path, kTest); + auto vfs = std::make_unique(std::move(files), std::filesystem::path("/mem")); + auto import_data = std::make_unique( + CreateImportDataForTest(std::move(vfs), kAllWarningsSet)); + + XLS_ASSERT_OK_AND_ASSIGN(TypecheckedModule tm, + ParseAndTypecheck(kTest, /*path=*/test_path, + /*module_name=*/"test", + import_data.get())); + Module* m = tm.module; + Function* caller = m->GetFunction("caller").value(); + Function* f = m->GetFunction("f").value(); + Function* g = m->GetFunction("g").value(); + + ASSERT_FALSE(g->parametric_bindings().empty()); + const ParametricBinding* pb = g->parametric_bindings()[0]; + auto* trta = dynamic_cast(pb->type_annotation()); + ASSERT_NE(trta, nullptr); + TypeRef* tr = trta->type_ref(); + XLS_ASSERT_OK_AND_ASSIGN(TypeInfo::TypeSource ts, + tm.type_info->ResolveTypeDefinition(tr->type_definition())); + EnumDef* enum_def = nullptr; + if (std::holds_alternative(ts.definition)) { + enum_def = std::get(ts.definition); + } else { + auto* ta = std::get(ts.definition); + ASSERT_NE(ta, nullptr); + auto* trta2 = dynamic_cast(&ta->type_annotation()); + ASSERT_NE(trta2, nullptr); + TypeRef* tr2 = trta2->type_ref(); + XLS_ASSERT_OK_AND_ASSIGN(TypeInfo::TypeSource ts2, + tm.type_info->ResolveTypeDefinition(tr2->type_definition())); + enum_def = std::get(ts2.definition); + } + InterpValue enum_a = + InterpValue::MakeEnum(xls::UBits(/*value=*/0, /*bit_count=*/2), + /*is_signed=*/false, enum_def); + + InvocationRewriteRule rule; + rule.from_callee = f; + rule.to_callee = g; + rule.to_callee_env = ParametricEnv( + absl::flat_hash_map{{"N", enum_a}}); + + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule new_tm, + ReplaceInvocationsInModule(tm, caller, rule, *import_data, "test.rw")); + Module* new_module = new_tm.module; + Function* caller_new = new_module->GetFunction("caller").value(); + int num_checked = 0; + XLS_ASSERT_OK_AND_ASSIGN(auto nodes, CollectUnder(caller_new->body(), /*want_types=*/false)); + for (AstNode* n : nodes) { + auto* inv = dynamic_cast(n); + if (inv == nullptr) continue; + if (inv->callee()->ToString() != "g") continue; + EXPECT_EQ(inv->explicit_parametrics().size(), 1); + EXPECT_NE(inv->ToString().find("Alias::A"), std::string::npos); + num_checked++; + } + EXPECT_EQ(num_checked, 1); +} + +TEST(ReplaceInvocationsTest, ParametricEnumUseImportExplicitReplacement) { + const std::string kBase = R"(// base.x +pub enum E : u2 { + A = 0, + B = 1, +} +)"; + const std::string kTest = R"(// test.x +#![feature(use_syntax)] +use base; +fn f(x: u32) -> u32 { x } +fn g(x: u32) -> u32 { x } +fn caller(x: u32) -> u32 { f(x) } +)"; + + const std::string base_path = "/mem/base.x"; + const std::string test_path = "/mem/test.x"; + absl::flat_hash_map files; + files.emplace(base_path, kBase); + files.emplace(test_path, kTest); + auto vfs = std::make_unique(std::move(files), std::filesystem::path("/mem")); + auto import_data = std::make_unique( + CreateImportDataForTest(std::move(vfs), kAllWarningsSet)); + + XLS_ASSERT_OK_AND_ASSIGN(TypecheckedModule tm, + ParseAndTypecheck(kTest, /*path=*/test_path, + /*module_name=*/"test", + import_data.get())); + Module* m = tm.module; + Function* caller = m->GetFunction("caller").value(); + Function* f = m->GetFunction("f").value(); + Function* g = m->GetFunction("g").value(); + + ASSERT_FALSE(g->parametric_bindings().empty()); + const ParametricBinding* pb = g->parametric_bindings()[0]; + auto* trta = dynamic_cast(pb->type_annotation()); + ASSERT_NE(trta, nullptr); + TypeRef* tr = trta->type_ref(); + XLS_ASSERT_OK_AND_ASSIGN(TypeInfo::TypeSource ts, + tm.type_info->ResolveTypeDefinition(tr->type_definition())); + auto* enum_def = std::get(ts.definition); + ASSERT_NE(enum_def, nullptr); + InterpValue enum_a = + InterpValue::MakeEnum(xls::UBits(/*value=*/0, /*bit_count=*/2), + /*is_signed=*/false, enum_def); + + InvocationRewriteRule rule; + rule.from_callee = f; + rule.to_callee = g; + rule.to_callee_env = ParametricEnv( + absl::flat_hash_map{{"N", enum_a}}); + + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule new_tm, + ReplaceInvocationsInModule(tm, caller, rule, *import_data, "test.rw")); + Module* new_module = new_tm.module; + Function* caller_new = new_module->GetFunction("caller").value(); + int num_checked = 0; + XLS_ASSERT_OK_AND_ASSIGN(auto nodes, CollectUnder(caller_new->body(), /*want_types=*/false)); + for (AstNode* n : nodes) { + auto* inv = dynamic_cast(n); + if (inv == nullptr) continue; + if (inv->callee()->ToString() != "g") continue; + EXPECT_EQ(inv->explicit_parametrics().size(), 1); + EXPECT_NE(inv->ToString().find("E::A"), std::string::npos); + num_checked++; + } + EXPECT_EQ(num_checked, 1); +} + +TEST(ReplaceInvocationsTest, ParametricToEnvNonEnumTypeAnnotationErrors) { + const std::string kText = R"(// test +struct S { + a: u32, +} +fn f(x: u32) -> u32 { x } +fn g(x: u32) -> u32 { x } +fn caller(x: u32) -> u32 { f(x) } +)"; + XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + Module* m = pt.tm.module; + Function* caller = m->GetFunction("caller").value(); + Function* f = m->GetFunction("f").value(); + Function* g = m->GetFunction("g").value(); + + InvocationRewriteRule rule; + rule.from_callee = f; + rule.to_callee = g; + rule.to_callee_env = ParametricEnv( + absl::flat_hash_map{{"N", InterpValue::MakeUBits(32, 1)}}); + + auto status_or = ReplaceInvocationsInModule(pt.tm, caller, rule, *pt.import_data, + "test.rw"); + EXPECT_THAT(status_or, StatusIs(absl::StatusCode::kInvalidArgument)); +} + +} // namespace +} // namespace xls::dslx + + + + From 2f66ec5c00ca73610da19a31314c15f69b8e63a8 Mon Sep 17 00:00:00 2001 From: Sirui Lu Date: Fri, 12 Sep 2025 20:33:52 -0700 Subject: [PATCH 2/9] Add C API and dslx BUILD file --- xls/dslx/BUILD | 50 +++++++++++++++++++++++++++++ xls/public/BUILD | 1 + xls/public/c_api_dslx.cc | 62 ++++++++++++++++++++++++++++++++++++ xls/public/c_api_dslx.h | 24 ++++++++++++++ xls/public/c_api_symbols.txt | 1 + 5 files changed, 138 insertions(+) diff --git a/xls/dslx/BUILD b/xls/dslx/BUILD index ac1fa876d9..fc59b2d559 100644 --- a/xls/dslx/BUILD +++ b/xls/dslx/BUILD @@ -37,6 +37,32 @@ bzl_library( srcs = ["strip_comments.bzl"], ) +cc_library( + name = "replace_invocations", + srcs = ["replace_invocations.cc"], + hdrs = ["replace_invocations.h"], + deps = [ + ":import_data", + ":parse_and_typecheck", + ":warning_collector", + "//xls/common/status:ret_check", + "//xls/dslx/frontend:ast", + "//xls/dslx/frontend:ast_cloner", + "//xls/dslx/frontend:ast_utils", + "//xls/dslx/frontend:module", + "//xls/dslx/type_system:parametric_env", + "//xls/dslx/type_system:type_info", + "//xls/dslx/type_system:typecheck_module", + "//xls/ir:bits", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "virtualizable_file_system", srcs = ["virtualizable_file_system.cc"], @@ -133,6 +159,30 @@ cc_library( ], ) +cc_test( + name = "replace_invocations_test", + srcs = ["replace_invocations_test.cc"], + deps = [ + ":create_import_data", + ":default_dslx_stdlib_path", + ":parse_and_typecheck", + ":replace_invocations", + ":virtualizable_file_system", + "//xls/common:xls_gunit_main", + "//xls/common/status:matchers", + "//xls/dslx/frontend:ast", + "//xls/dslx/frontend:ast_utils", + "//xls/dslx/frontend:module", + "//xls/dslx/type_system:parametric_env", + "//xls/dslx/type_system:type_info", + "//xls/ir:bits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/types:span", + "@googletest//:gtest", + ], +) + cc_library( name = "errors", srcs = ["errors.cc"], diff --git a/xls/public/BUILD b/xls/public/BUILD index edf45c5bf1..ee95d69cbd 100644 --- a/xls/public/BUILD +++ b/xls/public/BUILD @@ -213,6 +213,7 @@ cc_library( deps = [ ":c_api_impl_helpers", "//xls/common:visitor", + "//xls/dslx:replace_invocations", "//xls/dslx:create_import_data", "//xls/dslx:import_data", "//xls/dslx:interp_value", diff --git a/xls/public/c_api_dslx.cc b/xls/public/c_api_dslx.cc index e8198aa742..8077d89f87 100644 --- a/xls/public/c_api_dslx.cc +++ b/xls/public/c_api_dslx.cc @@ -37,6 +37,7 @@ #include "xls/dslx/interp_value.h" #include "xls/dslx/interp_value_from_string.h" #include "xls/dslx/parse_and_typecheck.h" +#include "xls/dslx/replace_invocations.h" #include "xls/dslx/type_system/parametric_env.h" #include "xls/dslx/type_system/type.h" #include "xls/dslx/type_system/type_info.h" @@ -1028,4 +1029,65 @@ void xls_dslx_type_dim_free(struct xls_dslx_type_dim* td) { delete cpp_type_dim; } +bool xls_dslx_replace_invocations_in_module( + struct xls_dslx_typechecked_module* tm, + struct xls_dslx_function* const callers[], size_t callers_count, + const struct xls_dslx_invocation_rewrite_rule* rules, size_t rules_count, + struct xls_dslx_import_data* import_data, const char* install_subject, + char** error_out, struct xls_dslx_typechecked_module** result_out) { + CHECK(error_out != nullptr); + CHECK(result_out != nullptr); + *error_out = nullptr; + *result_out = nullptr; + CHECK(tm != nullptr); + CHECK(import_data != nullptr); + CHECK(install_subject != nullptr); + CHECK(callers != nullptr || callers_count == 0); + CHECK(rules != nullptr || rules_count == 0); + + auto* cpp_tm = reinterpret_cast(tm); + auto* cpp_import_data = reinterpret_cast(import_data); + + std::vector callers_cpp; + callers_cpp.reserve(callers_count); + for (size_t i = 0; i < callers_count; ++i) { + CHECK(callers[i] != nullptr); + callers_cpp.push_back( + reinterpret_cast(callers[i])); + } + + std::vector rules_cpp; + rules_cpp.reserve(rules_count); + for (size_t i = 0; i < rules_count; ++i) { + const xls_dslx_invocation_rewrite_rule& r = rules[i]; + CHECK(r.from_callee != nullptr); + CHECK(r.to_callee != nullptr); + xls::dslx::InvocationRewriteRule rr; + rr.from_callee = + reinterpret_cast(r.from_callee); + rr.to_callee = reinterpret_cast(r.to_callee); + if (r.match_callee_env != nullptr) { + rr.match_callee_env = *reinterpret_cast( + r.match_callee_env); + } + if (r.to_callee_env != nullptr) { + rr.to_callee_env = + *reinterpret_cast(r.to_callee_env); + } + rules_cpp.push_back(std::move(rr)); + } + + absl::StatusOr new_tm = + xls::dslx::ReplaceInvocationsInModule( + *cpp_tm, absl::MakeSpan(callers_cpp), absl::MakeSpan(rules_cpp), + *cpp_import_data, std::string_view{install_subject}); + if (!new_tm.ok()) { + *error_out = xls::ToOwnedCString(new_tm.status().ToString()); + return false; + } + auto* heap_tm = new xls::dslx::TypecheckedModule{*std::move(new_tm)}; + *result_out = reinterpret_cast(heap_tm); + return true; +} + } // extern "C" diff --git a/xls/public/c_api_dslx.h b/xls/public/c_api_dslx.h index b23a96c838..22a282c132 100644 --- a/xls/public/c_api_dslx.h +++ b/xls/public/c_api_dslx.h @@ -82,6 +82,20 @@ struct xls_dslx_type_dim; struct xls_dslx_parametric_env; struct xls_dslx_interp_value; struct xls_bits; +// Rule for rewriting invocations in bulk API. +struct xls_dslx_invocation_rewrite_rule { + // Replace invocations whose resolved callee is exactly this function. + struct xls_dslx_function* from_callee; + // The function that should be used as the new callee. + struct xls_dslx_function* to_callee; + // Optional filter: only invocations whose callee-side ParametricEnv equals + // this value will be replaced. If null, matches all instantiations. + const struct xls_dslx_parametric_env* match_callee_env; // optional + // Optional explicit env for the replacement callee. If null, retains the + // original invocation's explicit parametrics. If non-null but empty, emits + // no explicit parametrics (rely on deduction). + const struct xls_dslx_parametric_env* to_callee_env; // optional +}; struct xls_dslx_parametric_env_item { const char* identifier; @@ -454,6 +468,16 @@ struct xls_dslx_type* xls_dslx_type_array_get_element_type( struct xls_dslx_type_dim* xls_dslx_type_array_get_size( struct xls_dslx_type* type); +// Rewrites invocations in the given module according to the rules. Returns a +// new typechecked module installed in the import graph under `install_subject`. +// On error, returns false and populates error_out (owned by caller). +bool xls_dslx_replace_invocations_in_module( + struct xls_dslx_typechecked_module* tm, + struct xls_dslx_function* const callers[], size_t callers_count, + const struct xls_dslx_invocation_rewrite_rule* rules, size_t rules_count, + struct xls_dslx_import_data* import_data, const char* install_subject, + char** error_out, struct xls_dslx_typechecked_module** result_out); + } // extern "C" #endif // XLS_PUBLIC_C_API_DSLX_H_ diff --git a/xls/public/c_api_symbols.txt b/xls/public/c_api_symbols.txt index 140c41d44a..b3262b692f 100644 --- a/xls/public/c_api_symbols.txt +++ b/xls/public/c_api_symbols.txt @@ -160,6 +160,7 @@ xls_dslx_quickcheck_get_count xls_dslx_quickcheck_get_function xls_dslx_quickcheck_is_exhaustive xls_dslx_quickcheck_to_string +xls_dslx_replace_invocations_in_module xls_dslx_struct_def_get_identifier xls_dslx_struct_def_get_member xls_dslx_struct_def_get_member_count From 6c2af003cfb3cbfc6b840eaf144674e9da91c035 Mon Sep 17 00:00:00 2001 From: Sirui Lu Date: Fri, 12 Sep 2025 20:43:54 -0700 Subject: [PATCH 3/9] Update comment --- xls/dslx/replace_invocations.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xls/dslx/replace_invocations.cc b/xls/dslx/replace_invocations.cc index 6639ae83e5..efc6bc0ae5 100644 --- a/xls/dslx/replace_invocations.cc +++ b/xls/dslx/replace_invocations.cc @@ -250,8 +250,7 @@ absl::StatusOr BuildEnumParametricExpr( return static_cast(cref); } -// Dispatch point for future non-enum TypeRef support. Currently only enums are -// supported; other kinds will return an error via the enum helper. +// Currently only enums are supported; other kinds will return an error. absl::StatusOr BuildParametricExprForTypeRef( TypeRef* tr, const InterpValue& iv, TypeInfo& type_info, const Span& inv_span, Module* target_module, From 6fa2dda3afa0083461ca5db9a8868ee98f98a305 Mon Sep 17 00:00:00 2001 From: Sirui Lu Date: Fri, 12 Sep 2025 20:53:22 -0700 Subject: [PATCH 4/9] Fix format --- xls/dslx/replace_invocations.h | 4 ---- xls/dslx/replace_invocations_test.cc | 4 ---- 2 files changed, 8 deletions(-) diff --git a/xls/dslx/replace_invocations.h b/xls/dslx/replace_invocations.h index 5320734d7a..4609c4e3a8 100644 --- a/xls/dslx/replace_invocations.h +++ b/xls/dslx/replace_invocations.h @@ -69,7 +69,3 @@ absl::StatusOr ReplaceInvocationsInModule( } // namespace xls::dslx #endif // XLS_DSLX_REPLACE_INVOCATIONS_H_ - - - - diff --git a/xls/dslx/replace_invocations_test.cc b/xls/dslx/replace_invocations_test.cc index 0fdfe5b5e4..1a481853e7 100644 --- a/xls/dslx/replace_invocations_test.cc +++ b/xls/dslx/replace_invocations_test.cc @@ -1230,7 +1230,3 @@ fn caller(x: u32) -> u32 { f(x) } } // namespace } // namespace xls::dslx - - - - From 175282d7d95fc39e6868bda139e75b6e9307e22e Mon Sep 17 00:00:00 2001 From: Sirui Lu Date: Wed, 17 Sep 2025 15:16:53 -0700 Subject: [PATCH 5/9] Fix code review questions --- xls/dslx/BUILD | 1 + xls/dslx/replace_invocations.cc | 10 ++---- xls/dslx/replace_invocations_test.cc | 48 ++++++++++++++-------------- xls/public/c_api_dslx.h | 1 + 4 files changed, 28 insertions(+), 32 deletions(-) diff --git a/xls/dslx/BUILD b/xls/dslx/BUILD index fc59b2d559..db93186637 100644 --- a/xls/dslx/BUILD +++ b/xls/dslx/BUILD @@ -45,6 +45,7 @@ cc_library( ":import_data", ":parse_and_typecheck", ":warning_collector", + "//xls/common:visitor", "//xls/common/status:ret_check", "//xls/dslx/frontend:ast", "//xls/dslx/frontend:ast_cloner", diff --git a/xls/dslx/replace_invocations.cc b/xls/dslx/replace_invocations.cc index efc6bc0ae5..6dc0c7b5bc 100644 --- a/xls/dslx/replace_invocations.cc +++ b/xls/dslx/replace_invocations.cc @@ -22,6 +22,7 @@ #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "xls/common/status/ret_check.h" +#include "xls/common/visitor.h" #include "xls/dslx/frontend/ast.h" #include "xls/dslx/frontend/ast_cloner.h" #include "xls/dslx/frontend/ast_utils.h" @@ -38,13 +39,6 @@ namespace xls::dslx { namespace { -template -struct Overloaded : T... { - using T::operator()...; -}; -template -Overloaded(T...) -> Overloaded; - bool MatchesCalleeEnv(const InvocationData& data, const std::optional& want_env) { if (!want_env.has_value()) { @@ -106,7 +100,7 @@ absl::StatusOr MakeColonRefSubjectFromTypeRef( TypeRef* type_ref, const Span& inv_span, Module* target_module, const absl::flat_hash_map& old_to_new) { using ReturnT = absl::StatusOr; - return absl::visit(Overloaded{ + return absl::visit(xls::Visitor{ [&](ColonRef* old_cref) -> ReturnT { auto it = old_to_new.find(old_cref); XLS_RET_CHECK(it != old_to_new.end()); diff --git a/xls/dslx/replace_invocations_test.cc b/xls/dslx/replace_invocations_test.cc index 1a481853e7..00c4393803 100644 --- a/xls/dslx/replace_invocations_test.cc +++ b/xls/dslx/replace_invocations_test.cc @@ -41,12 +41,12 @@ namespace { using ::absl_testing::StatusIs; -struct PT { +struct ParseTypecheckResult { std::unique_ptr import_data; TypecheckedModule tm; }; -absl::StatusOr ParseTypecheck(std::string text) { +absl::StatusOr ParseTypecheck(std::string text) { std::filesystem::path stdlib = std::string(::xls::kDefaultDslxStdlibPath); auto import_data = std::make_unique(CreateImportData( stdlib, /*additional_search_paths=*/std::vector{}, @@ -55,7 +55,7 @@ absl::StatusOr ParseTypecheck(std::string text) { TypecheckedModule tm, ParseAndTypecheck(text, /*path=*/"test.x", /*module_name=*/"test", import_data.get())); - return PT{.import_data = std::move(import_data), .tm = std::move(tm)}; + return ParseTypecheckResult{.import_data = std::move(import_data), .tm = std::move(tm)}; } TEST(ReplaceInvocationsTest, NonParametricSimpleReplacement) { @@ -64,7 +64,7 @@ fn a(x: u32) -> u32 { x + u32:1 } fn b(x: u32) -> u32 { x + u32:2 } fn caller(x: u32) -> u32 { b(x) + b(x) } )"; - XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + XLS_ASSERT_OK_AND_ASSIGN(ParseTypecheckResult pt, ParseTypecheck(kText)); Module* m = pt.tm.module; ASSERT_NE(m->GetFunction("caller"), std::nullopt); @@ -112,7 +112,7 @@ fn caller() -> (u8, u16) { (y8, y16) } )"; - XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + XLS_ASSERT_OK_AND_ASSIGN(ParseTypecheckResult pt, ParseTypecheck(kText)); Module* m = pt.tm.module; ASSERT_NE(m->GetFunction("caller"), std::nullopt); @@ -160,7 +160,7 @@ fn id(x: uN[N]) -> uN[N] { x } fn id2(x: uN[N]) -> uN[N] { x } fn caller() -> u8 { id(u8:1) } )"; - XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + XLS_ASSERT_OK_AND_ASSIGN(ParseTypecheckResult pt, ParseTypecheck(kText)); Module* m = pt.tm.module; Function* caller = m->GetFunction("caller").value(); @@ -201,7 +201,7 @@ fn id(x: uN[N]) -> uN[N] { x } fn id2(x: uN[N]) -> uN[N] { x } fn caller() -> u8 { id(u8:1) } )"; - XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + XLS_ASSERT_OK_AND_ASSIGN(ParseTypecheckResult pt, ParseTypecheck(kText)); Module* m = pt.tm.module; Function* caller = m->GetFunction("caller").value(); @@ -239,7 +239,7 @@ fn id(x: uN[N]) -> uN[N] { x } fn id2(x: uN[N]) -> uN[N] { x } fn caller(x: u32) -> u32 { id(x) } )"; - XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + XLS_ASSERT_OK_AND_ASSIGN(ParseTypecheckResult pt, ParseTypecheck(kText)); Module* m = pt.tm.module; Function* caller = m->GetFunction("caller").value(); @@ -279,7 +279,7 @@ fn a(x: u32) -> u32 { x + u32:1 } fn b(x: u32) -> u32 { x + u32:2 } fn caller(x: u32) -> u32 { b(x) + b(x) } )"; - XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + XLS_ASSERT_OK_AND_ASSIGN(ParseTypecheckResult pt, ParseTypecheck(kText)); Module* m = pt.tm.module; Function* caller = m->GetFunction("caller").value(); @@ -318,7 +318,7 @@ fn id(x: uN[N]) -> uN[N] { x } fn caller() -> (u8, u16) { (id(u8:1), id(u16:2)) } fn id2(x: uN[N]) -> uN[N] { x } )"; - XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + XLS_ASSERT_OK_AND_ASSIGN(ParseTypecheckResult pt, ParseTypecheck(kText)); Module* m = pt.tm.module; Function* caller = m->GetFunction("caller").value(); @@ -346,7 +346,7 @@ fn caller() -> (u8, u16) { (id(a), id(b)) } )"; - XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + XLS_ASSERT_OK_AND_ASSIGN(ParseTypecheckResult pt, ParseTypecheck(kText)); Module* m = pt.tm.module; Function* caller = m->GetFunction("caller").value(); @@ -395,7 +395,7 @@ fn caller(x: u32) -> (u32, u32) { (a, b) } )"; - XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + XLS_ASSERT_OK_AND_ASSIGN(ParseTypecheckResult pt, ParseTypecheck(kText)); Module* m = pt.tm.module; Function* caller = m->GetFunction("caller").value(); @@ -450,7 +450,7 @@ fn f(x: u32) -> u32 { x } fn g(x: u32) -> u32 { x } fn caller(x: u32) -> u32 { f(x) } )"; - XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + XLS_ASSERT_OK_AND_ASSIGN(ParseTypecheckResult pt, ParseTypecheck(kText)); Module* m = pt.tm.module; Function* caller = m->GetFunction("caller").value(); @@ -633,7 +633,7 @@ fn id(x: uN[N]) -> uN[N] { x } fn id2(x: uN[N]) -> uN[N] { x } fn caller() -> (u8, u16) { (id(u8:1), id(u16:2)) } )"; - XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + XLS_ASSERT_OK_AND_ASSIGN(ParseTypecheckResult pt, ParseTypecheck(kText)); Module* m = pt.tm.module; Function* caller = m->GetFunction("caller").value(); @@ -676,7 +676,7 @@ fn caller() -> (u8, u16) { (id(a), id(b)) } )"; - XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + XLS_ASSERT_OK_AND_ASSIGN(ParseTypecheckResult pt, ParseTypecheck(kText)); Module* m = pt.tm.module; Function* caller = m->GetFunction("caller").value(); @@ -715,7 +715,7 @@ fn f(x: u32) -> u32 { x } fn g(x: u32) -> u32 { x } fn caller(x: u32) -> u32 { f(x) } )"; - XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + XLS_ASSERT_OK_AND_ASSIGN(ParseTypecheckResult pt, ParseTypecheck(kText)); Module* m = pt.tm.module; Function* caller = m->GetFunction("caller").value(); @@ -758,7 +758,7 @@ fn id(x: u32) -> u32 { x } fn id2(x: u32) -> u32 { x } fn caller() -> u32 { id(u32:0) } )"; - XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + XLS_ASSERT_OK_AND_ASSIGN(ParseTypecheckResult pt, ParseTypecheck(kText)); Module* m = pt.tm.module; Function* caller = m->GetFunction("caller").value(); @@ -800,7 +800,7 @@ fn f(x: u32) -> u32 { x } fn g(x: u32) -> u32 { x } fn caller() -> u32 { f(u32:0) } )"; - XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + XLS_ASSERT_OK_AND_ASSIGN(ParseTypecheckResult pt, ParseTypecheck(kText)); Module* m = pt.tm.module; Function* caller = m->GetFunction("caller").value(); @@ -826,7 +826,7 @@ fn f(x: u32) -> u32 { x } fn g(x: u32) -> u32 { x } fn caller() -> u32 { f(u32:0) } )"; - XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + XLS_ASSERT_OK_AND_ASSIGN(ParseTypecheckResult pt, ParseTypecheck(kText)); Module* m = pt.tm.module; Function* caller = m->GetFunction("caller").value(); @@ -857,7 +857,7 @@ fn f(x: u32) -> u32 { x } fn g(x: u32) -> u32 { x } fn caller(x: u32) -> u32 { f(x) } )"; - XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + XLS_ASSERT_OK_AND_ASSIGN(ParseTypecheckResult pt, ParseTypecheck(kText)); Module* m = pt.tm.module; Function* caller = m->GetFunction("caller").value(); @@ -890,7 +890,7 @@ fn h(x: u32) -> u32 { x + u32:3 } fn caller1(x: u32) -> u32 { f(x) + g(x) } fn caller2(x: u32) -> u32 { f(x) + h(x) } )"; - XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + XLS_ASSERT_OK_AND_ASSIGN(ParseTypecheckResult pt, ParseTypecheck(kText)); Module* m = pt.tm.module; Function* caller1 = m->GetFunction("caller1").value(); @@ -942,7 +942,7 @@ fn id3(x: uN[N]) -> uN[N] { x } fn caller1() -> (u8, u16) { (id(u8:1), id(u16:2)) } fn caller2() -> (u16, u8) { (id(u16:3), id(u8:4)) } )"; - XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + XLS_ASSERT_OK_AND_ASSIGN(ParseTypecheckResult pt, ParseTypecheck(kText)); Module* m = pt.tm.module; Function* caller1 = m->GetFunction("caller1").value(); @@ -1013,7 +1013,7 @@ fn f(x: u32) -> u32 { x } fn g(x: u32) -> u32 { x } fn caller(x: u32) -> u32 { f(x) } )"; - XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + XLS_ASSERT_OK_AND_ASSIGN(ParseTypecheckResult pt, ParseTypecheck(kText)); Module* m = pt.tm.module; Function* caller = m->GetFunction("caller").value(); @@ -1211,7 +1211,7 @@ fn f(x: u32) -> u32 { x } fn g(x: u32) -> u32 { x } fn caller(x: u32) -> u32 { f(x) } )"; - XLS_ASSERT_OK_AND_ASSIGN(PT pt, ParseTypecheck(kText)); + XLS_ASSERT_OK_AND_ASSIGN(ParseTypecheckResult pt, ParseTypecheck(kText)); Module* m = pt.tm.module; Function* caller = m->GetFunction("caller").value(); Function* f = m->GetFunction("f").value(); diff --git a/xls/public/c_api_dslx.h b/xls/public/c_api_dslx.h index 22a282c132..d2fe5646dc 100644 --- a/xls/public/c_api_dslx.h +++ b/xls/public/c_api_dslx.h @@ -82,6 +82,7 @@ struct xls_dslx_type_dim; struct xls_dslx_parametric_env; struct xls_dslx_interp_value; struct xls_bits; + // Rule for rewriting invocations in bulk API. struct xls_dslx_invocation_rewrite_rule { // Replace invocations whose resolved callee is exactly this function. From 467855e358cd18b926463319352ccb05ff4ac7b3 Mon Sep 17 00:00:00 2001 From: Sirui Lu Date: Mon, 6 Oct 2025 14:50:57 -0700 Subject: [PATCH 6/9] Fix one test for replace_invocations --- xls/dslx/replace_invocations_test.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xls/dslx/replace_invocations_test.cc b/xls/dslx/replace_invocations_test.cc index 00c4393803..49a073c797 100644 --- a/xls/dslx/replace_invocations_test.cc +++ b/xls/dslx/replace_invocations_test.cc @@ -1207,9 +1207,10 @@ TEST(ReplaceInvocationsTest, ParametricToEnvNonEnumTypeAnnotationErrors) { struct S { a: u32, } +const DEFAULT_S = S { a: u32:0 }; fn f(x: u32) -> u32 { x } fn g(x: u32) -> u32 { x } -fn caller(x: u32) -> u32 { f(x) } +fn caller(x: u32) -> u32 { f(x) } )"; XLS_ASSERT_OK_AND_ASSIGN(ParseTypecheckResult pt, ParseTypecheck(kText)); Module* m = pt.tm.module; From d32a6d871d17b9d9187a22fdb92ccc7e82fd33ca Mon Sep 17 00:00:00 2001 From: Sirui Lu Date: Tue, 7 Oct 2025 10:13:01 -0700 Subject: [PATCH 7/9] Add another test and fix on higher order builtins --- xls/dslx/replace_invocations.cc | 35 ++++++++++++-- xls/dslx/replace_invocations_test.cc | 70 ++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 3 deletions(-) diff --git a/xls/dslx/replace_invocations.cc b/xls/dslx/replace_invocations.cc index 6dc0c7b5bc..b1ef04b0d4 100644 --- a/xls/dslx/replace_invocations.cc +++ b/xls/dslx/replace_invocations.cc @@ -14,8 +14,10 @@ #include "xls/dslx/replace_invocations.h" +#include #include #include +#include #include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" @@ -434,9 +436,7 @@ absl::StatusOr ReplaceInvocationsInModule( auto it_nd = old_to_new.find(old_target); XLS_RET_CHECK(it_nd != old_to_new.end()); auto* new_target = down_cast(it_nd->second); - NameRef* new_callee = target_module->Make( - inv->callee()->span(), new_target->identifier(), new_target, - inv->callee()->in_parens()); + const NameDef* from_name_def = matched_rule->from_callee->name_def(); // Pre-built the name def map to avoid rebuilding it for each invocation. absl::flat_hash_map name_def_map = @@ -446,9 +446,36 @@ absl::StatusOr ReplaceInvocationsInModule( return CloneExprIntoModule(e, target_module, old_to_new, name_def_map); }; + auto maybe_replace_expr = [&](Expr* candidate) -> std::optional { + auto* nr = dynamic_cast(candidate); + if (nr == nullptr || nr->IsBuiltin()) { + return std::nullopt; + } + if (!std::holds_alternative(nr->name_def()) || + std::get(nr->name_def()) != from_name_def) { + return std::nullopt; + } + return std::optional(target_module->Make( + nr->span(), new_target->identifier(), new_target, nr->in_parens())); + }; + + Expr* original_callee = inv->callee(); + Expr* new_callee = nullptr; + if (std::optional replaced = maybe_replace_expr(original_callee); + replaced.has_value()) { + new_callee = *replaced; + } else { + XLS_ASSIGN_OR_RETURN(new_callee, clone_expr_into(original_callee)); + } + std::vector new_args; new_args.reserve(inv->args().size()); for (Expr* arg : inv->args()) { + if (std::optional replaced = maybe_replace_expr(arg); + replaced.has_value()) { + new_args.push_back(*replaced); + continue; + } XLS_ASSIGN_OR_RETURN(Expr * cloned, clone_expr_into(arg)); new_args.push_back(cloned); } @@ -480,6 +507,8 @@ absl::StatusOr ReplaceInvocationsInModule( CloneModule(module, std::move(replacer))); XLS_RET_CHECK_OK(VerifyClone(&module, cloned.get(), *module.file_table())); + std::cout << "Cloned module: " << cloned->ToString(); + return TypecheckAndInstallCloned(std::move(cloned), tm, import_data, install_subject); } diff --git a/xls/dslx/replace_invocations_test.cc b/xls/dslx/replace_invocations_test.cc index 49a073c797..a1e2e0da51 100644 --- a/xls/dslx/replace_invocations_test.cc +++ b/xls/dslx/replace_invocations_test.cc @@ -153,6 +153,76 @@ fn caller() -> (u8, u16) { EXPECT_EQ(id_uses, 1); } +TEST(ReplaceInvocationsTest, HigherOrderMapKeepsOuterCallee) { + const std::string kText = R"(// test +fn predicate(xy: (u6, u6)) -> bool { + let (pattern, payload) = xy; + pattern == payload +} + +fn predicate_alias(xy: (u6, u6)) -> bool { + predicate(xy) +} + +fn repeat(x: u6) -> u6[N] { + u6[N]:[x, ...] +} + +fn select_mask(xs: u6[2], selector: u6) -> bool[2] { + let reps = repeat(selector); + let pairs = zip(xs, reps); + map(pairs, predicate) +} +)"; + XLS_ASSERT_OK_AND_ASSIGN(ParseTypecheckResult pt, ParseTypecheck(kText)); + Module* m = pt.tm.module; + + Function* select_mask = m->GetFunction("select_mask").value(); + Function* predicate = m->GetFunction("predicate").value(); + Function* predicate_alias = m->GetFunction("predicate_alias").value(); + + InvocationRewriteRule rule; + rule.from_callee = predicate; + rule.to_callee = predicate_alias; + + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule new_tm, + ReplaceInvocationsInModule(pt.tm, select_mask, rule, *pt.import_data, + "test.rw")); + Function* select_mask_new = + new_tm.module->GetFunction("select_mask").value(); + + int map_invocations = 0; + int rewritten_invocations = 0; + int predicate_alias_refs = 0; + int predicate_refs = 0; + XLS_ASSERT_OK_AND_ASSIGN( + auto nodes, CollectUnder(select_mask_new->body(), /*want_types=*/false)); + for (AstNode* n : nodes) { + if (auto* inv = dynamic_cast(n)) { + std::string callee = inv->callee()->ToString(); + if (callee == "map") { + map_invocations++; + } + if (callee == "predicate_alias") { + rewritten_invocations++; + } + } + if (auto* nr = dynamic_cast(n)) { + if (nr->identifier() == "predicate_alias") { + predicate_alias_refs++; + } + if (nr->identifier() == "predicate") { + predicate_refs++; + } + } + } + EXPECT_EQ(map_invocations, 1); + EXPECT_EQ(rewritten_invocations, 0); + EXPECT_GE(predicate_alias_refs, 1); + EXPECT_EQ(predicate_refs, 0); +} + TEST(ReplaceInvocationsTest, ParametricReplacementNoToEnvRetainsExplicitParams) { const std::string kText = R"(// test From 4f5d627bbfbcdf1430e05a955ac5cc8debfdc84a Mon Sep 17 00:00:00 2001 From: Sirui Lu Date: Tue, 7 Oct 2025 11:30:03 -0700 Subject: [PATCH 8/9] Remove printing --- xls/dslx/replace_invocations.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/xls/dslx/replace_invocations.cc b/xls/dslx/replace_invocations.cc index b1ef04b0d4..011f4189b8 100644 --- a/xls/dslx/replace_invocations.cc +++ b/xls/dslx/replace_invocations.cc @@ -507,8 +507,6 @@ absl::StatusOr ReplaceInvocationsInModule( CloneModule(module, std::move(replacer))); XLS_RET_CHECK_OK(VerifyClone(&module, cloned.get(), *module.file_table())); - std::cout << "Cloned module: " << cloned->ToString(); - return TypecheckAndInstallCloned(std::move(cloned), tm, import_data, install_subject); } From 3785fb451e120ff97b9f2b68d374eb604b6e92a0 Mon Sep 17 00:00:00 2001 From: Sirui Lu Date: Tue, 7 Oct 2025 11:37:34 -0700 Subject: [PATCH 9/9] Change to use import --- xls/dslx/replace_invocations_test.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xls/dslx/replace_invocations_test.cc b/xls/dslx/replace_invocations_test.cc index a1e2e0da51..ae9a0791cd 100644 --- a/xls/dslx/replace_invocations_test.cc +++ b/xls/dslx/replace_invocations_test.cc @@ -1210,8 +1210,7 @@ pub enum E : u2 { } )"; const std::string kTest = R"(// test.x -#![feature(use_syntax)] -use base; +import base; fn f(x: u32) -> u32 { x } fn g(x: u32) -> u32 { x } fn caller(x: u32) -> u32 { f(x) }