diff --git a/CMakeLists.txt b/CMakeLists.txt index 2dd1d46e0..3666cbd9d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -225,6 +225,7 @@ message(STATUS "Found ATen.so file: ${ATEN_LIBRARIES}") ################################################################################ # use locally generated C++ bindings include_directories(AFTER ${PROJECT_SOURCE_DIR}/isl_interface/include) +include_directories(AFTER ${CMAKE_CURRENT_BINARY_DIR}/isl_interface/include) include_directories(AFTER ${PROJECT_SOURCE_DIR}/third-party/islpp/include) include_directories(AFTER ${CMAKE_CURRENT_BINARY_DIR}/third-party/islpp/include) add_subdirectory(external/isl) @@ -334,6 +335,34 @@ else() message(STATUS "Not building benchmarks, caffe2 or CUDA not available") endif() +SET(ISL_CPP_H "${CMAKE_CURRENT_LIST_DIR}/isl_interface/include/isl/cpp.h") + +add_executable(generate_template_isl isl_interface/generate_template_isl.cc) + +find_program(CLANG_FORMAT_BIN clang-format PATHS ${CLANG_PREFIX} + PATH_SUFFIXES bin + NO_DEFAULT_PATH) + +SET(ISL_TEMPLATE_CPP_DIR + "${CMAKE_CURRENT_BINARY_DIR}/isl_interface/include/isl") +SET(ISL_TEMPLATE_CPP_H "${ISL_TEMPLATE_CPP_DIR}/template_cpp.h") +add_custom_command( + OUTPUT ${ISL_TEMPLATE_CPP_H} + DEPENDS ${ISL_CPP_H} + DEPENDS generate_template_isl + COMMAND mkdir -p ${ISL_TEMPLATE_CPP_DIR} + COMMAND generate_template_isl < ${ISL_CPP_H} > ${ISL_TEMPLATE_CPP_H} + COMMAND ${CLANG_FORMAT_BIN} -i ${ISL_TEMPLATE_CPP_H} +) + if (WITH_BINDINGS) add_subdirectory(isl_interface) + + # generate_isl_cpp_h is the dependency that should be used + # by code that depends on the isl C++ bindings. + add_custom_target(generate_isl_cpp_h + DEPENDS generate_isl_cpp_h_core ${ISL_TEMPLATE_CPP_H}) +else() + add_custom_target(generate_isl_cpp_h + DEPENDS ${ISL_TEMPLATE_CPP_H}) endif() diff --git a/isl_interface/CMakeLists.txt b/isl_interface/CMakeLists.txt index e51621299..5c0772f16 100644 --- a/isl_interface/CMakeLists.txt +++ b/isl_interface/CMakeLists.txt @@ -86,7 +86,6 @@ target_link_libraries(extract_isl_interface # Dummy library to ensure that C++ bindings depend on contents of header files. add_library(isl_all_h_dep STATIC ${ISL_DIR}/all.c) -SET(ISL_CPP_H "${CMAKE_CURRENT_LIST_DIR}/include/isl/cpp.h") add_custom_command( OUTPUT ${ISL_CPP_H} DEPENDS isl_all_h_dep @@ -106,6 +105,5 @@ add_custom_command( COMMAND cat ${ISL_DIR}/cpp/cpp.h.bot >> ${ISL_CPP_H} || exit 1 DEPENDS extract_isl_interface ) -# generate_isl_cpp_h is the dependency that should be used -# by code that depends on the isl C++ bindings. -add_custom_target(generate_isl_cpp_h DEPENDS ${ISL_CPP_H}) + +add_custom_target(generate_isl_cpp_h_core DEPENDS ${ISL_CPP_H}) diff --git a/isl_interface/generate_template_isl.cc b/isl_interface/generate_template_isl.cc new file mode 100644 index 000000000..7a8d9393e --- /dev/null +++ b/isl_interface/generate_template_isl.cc @@ -0,0 +1,1135 @@ +/** + * Copyright (c) 2018, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +constexpr auto header = R"CPP( +struct Anonymous; + +template +struct NamedPair; + +template +using Pair = NamedPair; +)CPP"; + +constexpr auto footer = R"CPP( +template +using AffOn = Aff; + +template +using PwAffOn = PwAff; + +template +using UnionPwAffOn = UnionPwAff; + +template +using AffListOn = AffList; + +template +using PwAffListOn = PwAffList; + +template +using UnionPwAffListOn = UnionPwAffList; +)CPP"; + +static std::string dropIslNamespace(std::string type) { + return std::regex_replace(type, std::regex("isl::"), ""); +} + +template +struct Signature { + Type returnType; + std::vector argTypes; +}; + +using Type = std::string; +struct BaseKind { + BaseKind(const char* name) : name(name) {} + BaseKind(std::initializer_list c) : children(c) { + if (c.size() == 2) { + children.insert(children.begin(), "Anonymous"); + } + if (children.size() != 3) { + abort(); + } + } + std::string name; + std::vector children; + bool operator<(const BaseKind& other) const { + if (children.size() != other.children.size()) { + return children.size() < other.children.size(); + } + if (children.size() == 0) { + return name < other.name; + } + for (size_t i = 0; i < children.size(); ++i) { + if (children[i] != other.children[i]) { + return children[i] < other.children[i]; + } + } + return false; + } + bool operator==(const BaseKind& other) const { + return name == other.name && children == other.children; + } + bool operator!=(const BaseKind& other) const { + return !(*this == other); + } +}; +using Kind = std::vector; + +struct Method { + std::string name; + Signature signature; +}; + +using Exported = std::unordered_map>; + +struct Class { + std::string name; + std::vector kinds; +}; + +static Kind params_type() { + return {}; +} + +static Kind set_type() { + return {"Domain"}; +} + +static Kind map_type() { + return {"Domain", "Range"}; +} + +static const std::unordered_map classes{ + {"space", {"Space", {params_type(), set_type(), map_type()}}}, + {"multi_id", {"MultiId", {set_type()}}}, + {"multi_val", {"MultiVal", {set_type()}}}, + {"set", {"Set", {params_type(), set_type()}}}, + {"map", {"Map", {map_type()}}}, + {"aff", {"Aff", {set_type(), map_type()}}}, + {"aff_list", {"AffList", {set_type(), map_type()}}}, + {"pw_aff", {"PwAff", {set_type(), map_type()}}}, + {"union_pw_aff", {"UnionPwAff", {set_type(), map_type()}}}, + {"multi_aff", {"MultiAff", {map_type()}}}, + {"pw_aff_list", {"PwAffList", {set_type(), map_type()}}}, + {"union_pw_aff_list", {"UnionPwAffList", {set_type(), map_type()}}}, + {"multi_union_pw_aff", {"MultiUnionPwAff", {map_type()}}}, + {"union_pw_multi_aff", {"UnionPwMultiAff", {map_type()}}}, + {"union_set", {"UnionSet", {params_type(), set_type()}}}, + {"union_map", {"UnionMap", {map_type()}}}, + {"map_list", {"MapList", {map_type()}}}, + {"union_access_info", {"UnionAccessInfo", {map_type()}}}, + {"union_flow", {"UnionFlow", {map_type()}}}, + {"stride_info", {"StrideInfo", {map_type()}}}, + {"fixed_box", {"FixedBox", {map_type()}}}, +}; + +static Signature create_params() { + return {{}, {}}; +} + +static Signature create_set() { + return {{"Domain"}, {}}; +} + +static Signature change_set() { + return {{"ModifiedDomain"}, {{"Domain"}}}; +} + +static Signature change_wrapped_set() { + return {{{"ModifiedWrap", "WrappedDomain", "WrappedRange"}}, + {{{"Wrap", "WrappedDomain", "WrappedRange"}}}}; +} + +static Signature change_range() { + return {{"Domain", "ModifiedRange"}, {{"Domain", "Range"}}}; +} + +static Signature create_map() { + return {{"Domain", "Range"}, {}}; +} + +static Signature apply_set() { + return {{"Range"}, {{"Domain"}, {"Domain", "Range"}}}; +} + +static Signature apply_domain() { + return {{"Range3", "Range"}, {{"Domain", "Range"}, {"Domain", "Range3"}}}; +} + +static Signature preimage_domain() { + return {{"Domain2", "Range"}, {{"Domain", "Range"}, {"Domain2", "Domain"}}}; +} + +static Signature apply_range() { + return {{"Domain", "Range2"}, {{"Domain", "Range"}, {"Range", "Range2"}}}; +} + +static Signature modify_params_unary() { + return {{}, {{}}}; +} + +static Signature modify_params_binary() { + return {{}, {{}, {}}}; +} + +static Signature modify_set_params() { + return {{"Domain"}, {{"Domain"}, {}}}; +} + +static Signature modify_set_unary() { + return {{"Domain"}, {{"Domain"}}}; +} + +static Signature modify_set_binary() { + return {{"Domain"}, {{"Domain"}, {"Domain"}}}; +} + +static Signature modify_domain() { + return {{"Domain", "Range"}, {{"Domain", "Range"}, {"Domain"}}}; +} + +static Signature modify_range() { + return {{"Domain", "Range"}, {{"Domain", "Range"}, {"Range"}}}; +} + +static Signature modify_map_params() { + return {{"Domain", "Range"}, {{"Domain", "Range"}, {}}}; +} + +static Signature modify_map() { + return {{"Domain", "Range"}, {{"Domain", "Range"}, {"Domain", "Range"}}}; +} + +static Signature modify_map_unary() { + return {{"Domain", "Range"}, {{"Domain", "Range"}}}; +} + +static Signature map_from_set() { + return {{"Domain", "Domain"}, {{"Domain"}}}; +} + +static Signature map_on_domain() { + return {{"Domain", "Domain"}, {{"Domain", "Range"}}}; +} + +static Signature map_from_domain() { + return {{"Domain", "Range"}, {{"Domain"}}}; +} + +static Signature set_from_params() { + return {{"Domain"}, {{}}}; +} + +static Signature map_from_params() { + return {{"Domain", "Range"}, {{}}}; +} + +static Signature set_params() { + return {{}, {{"Domain"}}}; +} + +static Signature map_params() { + return {{}, {{"Domain", "Range"}}}; +} + +static Signature set_binary_params() { + return {{}, {{"Domain"}, {"Domain"}}}; +} + +static Signature domain() { + return {{"Domain"}, {{"Domain", "Range"}}}; +} + +static Signature domain_map() { + return {{{"Domain", "Range"}, "Domain"}, {{"Domain", "Range"}}}; +} + +static Signature static_domain_map() { + return {{{"Range", "Domain2"}, "Range"}, {{"Range", "Domain2"}}}; +} + +static Signature static_range_map() { + return {{{"Domain2", "Range"}, "Range"}, {{"Domain2", "Range"}}}; +} + +static Signature static_wrapped_range_map() { + return {{{"Wrap", "WrappedDomain", "WrappedRange"}, "WrappedRange"}, + {{{"Wrap", "WrappedDomain", "WrappedRange"}}}}; +} + +static Signature domain_binary() { + return {{"Domain"}, {{"Domain", "Range"}, {"Domain", "Range"}}}; +} + +static Signature domain_binary_map() { + return {{"Domain", "Domain2"}, {{"Domain", "Range"}, {"Domain2", "Range"}}}; +} + +static Signature range() { + return {{"Range"}, {{"Domain", "Range"}}}; +} + +static Signature from_domain_and_range() { + return {{"Domain", "Range"}, {{"Domain"}, {"Range"}}}; +} + +static Signature from_range_and_domain() { + return {{"Domain2", "Domain"}, {{"Domain"}, {"Domain2"}}}; +} + +static Signature reverse() { + return {{"Range", "Domain"}, {{"Domain", "Range"}}}; +} + +static Signature test_map() { + return {{"Domain", "Domain"}, {{"Domain", "Domain"}, {"Domain", "Range2"}}}; +} + +static Signature range_product() { + return {{"Domain", {"Range", "Range3"}}, + {{"Domain", "Range"}, {"Domain", "Range3"}}}; +} + +static Signature flat_range_product() { + return {{"Domain", "Range2"}, {{"Domain", "Range"}, {"Domain", "Range3"}}}; +} + +static Signature curry() { + return {{"Domain2", {"Range2", "Range"}}, {{{"Domain2", "Range2"}, "Range"}}}; +} + +static Signature uncurry() { + return {{{"Domain", "Domain2"}, "Range2"}, + {{"Domain", {"Domain2", "Range2"}}}}; +} + +static Signature domain_factor_domain() { + return {{"Domain2", "Range"}, {{{"Domain2", "Range2"}, "Range"}}}; +} + +static Signature wrap() { + return {{{"Domain", "Range"}}, {{"Domain", "Range"}}}; +} + +static Signature wrap_binary() { + return {{{"Domain", "Range"}}, {{"Domain"}, {"Range"}}}; +} + +static Signature unwrap() { + return {{"MapDomain", "MapRange"}, {{{"MapDomain", "MapRange"}}}}; +} + +static Signature zip() { + return {{{"WrappedDomain1", "WrappedDomain2"}, + {"WrappedRange1", "WrappedRange2"}}, + {{{"WrappedDomain1", "WrappedRange1"}, + {"WrappedDomain2", "WrappedRange2"}}}}; +} + +static Signature get_map_anonymous() { + return {{"Domain", "Anonymous"}, {{"Domain", "Range"}}}; +} + +static Signature add_map_anonymous() { + return {{"Domain", "Range"}, {{"Domain", "Range"}, {"Domain", "Anonymous"}}}; +} + +static Signature add_range_anonymous() { + return {{"Domain", "Range"}, {{"Range"}, {"Domain", "Anonymous"}}}; +} + +static const std::unordered_map>> + signatures{ + {"add_param", {modify_params_unary(), modify_set_unary()}}, + {"align_params", {modify_set_params(), modify_map_params()}}, + {"apply", {apply_set(), apply_range()}}, + {"apply_domain", {apply_domain()}}, + {"preimage_domain", {preimage_domain()}}, + {"pullback", {preimage_domain()}}, + {"apply_range", {apply_range()}}, + {"coalesce", + {modify_params_unary(), modify_set_unary(), modify_map_unary()}}, + {"eq_at", {test_map()}}, + {"ge_set", {set_binary_params()}}, + {"get_space", + {modify_params_unary(), modify_set_unary(), modify_map_unary()}}, + {"gist", {modify_set_binary(), modify_map()}}, + {"intersect", + {modify_params_binary(), modify_set_binary(), modify_map()}}, + {"intersect_domain", {modify_domain()}}, + {"intersect_range", {modify_range()}}, + {"intersect_params", {modify_set_params(), modify_map_params()}}, + {"lt_set", {domain_binary()}}, + {"le_set", {domain_binary()}}, + {"eq_set", {domain_binary()}}, + {"lt_map", {domain_binary_map()}}, + {"gt_map", {domain_binary_map()}}, + {"params", {set_params(), map_params()}}, + {"from_params", {set_from_params()}}, + {"map_from_set", {map_from_set()}}, + {"add_named_tuple_id_ui", {set_from_params()}}, + {"add_unnamed_tuple_ui", {set_from_params(), map_from_domain()}}, + {"product", {wrap_binary()}}, + {"map_from_domain_and_range", {from_domain_and_range()}}, + {"domain", {domain()}}, + {"domain_map", {domain_map()}}, + {"range", {range()}}, + {"reverse", {reverse()}}, + {"subtract", {modify_set_binary(), modify_map()}}, + {"unbind_params_insert_domain", {from_range_and_domain()}}, + {"sum", {modify_map()}}, + {"unite", {modify_set_binary(), modify_map()}}, + {"union_add", {modify_map()}}, + {"range_product", {range_product()}}, + {"flat_range_product", {flat_range_product()}}, + {"curry", {curry()}}, + {"uncurry", {uncurry()}}, + {"domain_factor_domain", {domain_factor_domain()}}, + {"wrap", {wrap()}}, + {"unwrap", {unwrap()}}, + {"zip", {zip()}}, + {"add", {modify_set_binary(), modify_map()}}, + {"sub", {modify_set_binary(), modify_map()}}, + {"mul", {modify_set_binary()}}, + {"div", {modify_set_binary()}}, + {"mod", + {modify_set_binary(), + modify_map(), + modify_set_unary(), + modify_map_unary()}}, + {"get_at", {modify_set_unary(), modify_map_unary()}}, + {"get_map_list", {modify_map_unary()}}, + {"get_aff", {get_map_anonymous()}}, + {"set_aff", {add_map_anonymous()}}, + {"get_aff_list", {get_map_anonymous()}}, + {"get_union_pw_aff", {get_map_anonymous()}}, + {"set_union_pw_aff", {add_map_anonymous()}}, + {"get_union_pw_aff_list", {get_map_anonymous()}}, + {"pos_set", {set_params(), domain()}}, + {"nonneg_set", {set_params(), domain()}}, + {"zero_set", {set_params(), domain()}}, + {"zero_union_set", {set_params(), domain()}}, + {"add_constant", {modify_set_unary(), modify_map_unary()}}, + {"add_constant_si", {modify_set_unary(), modify_map_unary()}}, + {"set_val", {modify_set_unary()}}, + {"floor", {modify_map_unary()}}, + {"neg", {modify_set_unary(), modify_map_unary()}}, + {"drop", {modify_map_unary()}}, + {"scale", {modify_map_unary(), modify_range()}}, + {"scale_down", {modify_map_unary(), modify_range()}}, + {"set_set_tuple_id", {change_wrapped_set(), change_set()}}, + {"set_tuple_id", {change_wrapped_set(), change_set()}}, + {"set_range_tuple_id", {change_set(), change_range()}}, + {"get_range_stride_info", {get_map_anonymous()}}, + {"get_offset", {modify_map_unary()}}, + {"get_range_simple_fixed_box_hull", {modify_map_unary()}}, + {"set_may_source", {modify_map()}}, + {"set_schedule", {modify_map_unary()}}, + {"compute_flow", {modify_map_unary()}}, + {"get_may_dependence", {map_on_domain()}}, + }; + +static const std:: + map, std::vector>> + specificSignatures{ + {{"set", "identity"}, {map_from_set()}}, + {{"union_set", "get_space"}, {set_params()}}, + {{"union_map", "get_space"}, {set_params()}}, + {{"union_pw_aff", "get_space"}, {set_params()}}, + {{"union_pw_multi_aff", "get_space"}, {set_params()}}, + {{"union_set", "universe"}, {modify_set_unary()}}, + {{"union_map", "universe"}, {modify_map_unary()}}, + // should be called "gist_domain" + {{"multi_union_pw_aff", "gist"}, {modify_domain()}}, + {{"multi_union_pw_aff", "get_space"}, {range()}}, + {{"aff_list", "reverse"}, {modify_map_unary()}}, + {{"pw_aff_list", "reverse"}, {modify_map_unary()}}, + {{"union_pw_aff_list", "reverse"}, {modify_map_unary()}}, + {{"map_list", "reverse"}, {modify_map_unary()}}, + }; + +static const std::unordered_map>> + staticSignatures{ + {"from", {modify_map_unary()}}, + {"identity", {modify_map_unary()}}, + {"param_on_domain_space", {set_from_params()}}, + {"param_on_domain", {map_from_domain()}}, + {"empty", + {modify_params_unary(), modify_set_unary(), modify_map_unary()}}, + {"universe", + {modify_params_unary(), modify_set_unary(), modify_map_unary()}}, + {"zero", {modify_set_unary(), modify_map_unary()}}, + {"zero_on_domain", {map_from_domain()}}, + {"from_domain", {map_from_domain()}}, + }; + +static const std:: + map, std::vector>> + specificStaticSignatures{ + {{"multi_aff", "domain_map"}, {static_domain_map()}}, + {{"multi_aff", "range_map"}, {static_range_map()}}, + {{"multi_aff", "wrapped_range_map"}, {static_wrapped_range_map()}}, + {{"union_set", "empty"}, {set_from_params()}}, + {{"union_map", "empty"}, {map_from_params()}}, + }; + +struct Constructor { + std::vector argTypes; + std::vector> signatures; +}; + +static const std::unordered_map> + constructors{ + {"multi_id", {{{"space", "id_list"}, {modify_set_unary()}}}}, + {"multi_val", {{{"space", "val_list"}, {modify_set_unary()}}}}, + {"multi_aff", {{{"space", "aff_list"}, {add_map_anonymous()}}}}, + {"union_pw_aff", {{{"union_set", "val"}, {map_from_domain()}}}}, + {"multi_union_pw_aff", + {{{"space", "union_pw_aff_list"}, {add_range_anonymous()}}, + {{"union_set", "multi_val"}, {from_domain_and_range()}}}}, + {"map", {{{"multi_aff"}, {modify_map_unary()}}}}, + {"union_map", {{{"map"}, {modify_map_unary()}}}}, + {"union_set", {{{"set"}, {modify_set_unary()}}}}, + {"pw_aff", {{{"aff"}, {modify_set_unary(), modify_map_unary()}}}}, + {"aff_list", + {{{"aff"}, {modify_set_unary(), modify_map_unary()}}, + {{"ctx", "int"}, {create_set(), create_map()}}}}, + // should be replaced by constructor without int argument + {"space", {{{"ctx", "int"}, {create_params()}}}}, + {"union_pw_aff_list", {{{"ctx", "int"}, {create_set(), create_map()}}}}, + {"union_access_info", {{{"union_map"}, {modify_map_unary()}}}}, + }; + +static bool isForeach(const std::string& name) { + return name.find("foreach_") != std::string::npos; +} + +using Subs = std::map; + +static std::set +collect(const Kind& kind, const Subs& subs, std::set set = {}); + +static std::set collect( + const BaseKind& base, + const Subs& subs, + std::set set = {}) { + if (base.children.size() == 0) { + if (subs.count(base.name) != 0) { + set = collect(subs.at(base.name), {}, set); + } else if (base.name != "Anonymous") { + set.insert(base.name); + } + } else { + for (const auto& el : base.children) { + set = collect(el, subs, set); + } + } + return set; +} + +static std::set +collect(const Kind& kind, const Subs& subs, std::set set) { + for (auto base : kind) { + set = collect(base, subs, set); + } + return set; +} + +static std::set collect( + const Signature& signature, + const Subs& subs) { + auto set = collect(signature.returnType, subs); + for (auto arg : signature.argTypes) { + set = collect(arg, subs, set); + } + return set; +} + +static void printTemplateList( + const std::set set, + const std::string& qualifier) { + std::cout << "<"; + bool first = true; + for (auto s : set) { + if (!first) { + std::cout << ", "; + } + std::cout << qualifier << s; + first = false; + } + std::cout << ">"; +} + +static void +print(std::ostream& os, const BaseKind& base, const Subs& subs = {}); + +static void +print(std::ostream& os, const std::string& s, const Subs& subs = {}) { + if (subs.count(s) != 0) { + print(os, subs.at(s)); + } else { + os << s; + } +} + +static void print(std::ostream& os, const BaseKind& base, const Subs& subs) { + if (base.children.size() == 3) { + if (base.children[0] == "Anonymous") { + os << "Pair<"; + } else { + os << "NamedPair<"; + print(os, base.children[0], subs); + os << ","; + } + print(os, base.children[1], subs); + os << ","; + print(os, base.children[2], subs); + os << ">"; + } else { + print(os, base.name, subs); + } +} + +template +static void printTemplateList( + const std::vector list, + const std::string& qualifier, + const Subs& subs = {}) { + std::cout << "<"; + for (unsigned i = 0; i < list.size(); ++i) { + if (i > 0) { + std::cout << ", "; + } + std::cout << qualifier; + print(std::cout, list[i], subs); + } + std::cout << ">"; +} + +template +static void printTemplate(const T& t) { + std::cout << "template "; + printTemplateList(t, "typename "); + std::cout << "\n"; +} + +static void printClassDeclaration(const std::string& name, const Kind& kind) { + printTemplate(collect(kind, {})); + std::cout << "struct " << name; + printTemplateList(kind, ""); +} + +static void printForwardDeclarations() { + for (auto kvp : classes) { + std::cout << "\n"; + std::cout << "template \n"; + std::cout << "struct " << kvp.second.name; + std::cout << ";\n"; + } +} + +static BaseKind specialize(const BaseKind& base, const Subs& subs) { + if (base.children.size() == 0) { + if (subs.count(base.name) != 0) { + return subs.at(base.name); + } else { + return base; + } + } else { + return BaseKind{specialize(base.children[0], subs), + specialize(base.children[1], subs), + specialize(base.children[2], subs)}; + } +} + +static Kind specialize(const Kind& kind, const Subs& subs) { + if (subs.size() == 0) { + return kind; + } + Kind specialized; + for (auto base : kind) { + specialized.emplace_back(specialize(base, subs)); + } + return specialized; +} + +static std::vector specialize( + const std::vector& vector, + const Subs& subs) { + std::vector specialized; + for (auto kind : vector) { + specialized.emplace_back(specialize(kind, subs)); + } + return specialized; +} + +static Signature specialize( + const Signature& signature, + const Subs& subs) { + return {specialize(signature.returnType, subs), + specialize(signature.argTypes, subs)}; +} + +static void printExtraTemplate( + const Kind& classKind, + const Signature& signature, + const Subs& subs, + bool isStatic) { + auto classBase = collect(classKind, {}); + classBase.insert("Anonymous"); + auto signatureBase = collect(signature, subs); + std::vector extra; + for (auto base : signatureBase) { + if (classBase.count(base) == 0) { + extra.emplace_back(base); + } + } + if (extra.size() != 0) { + printTemplate(extra); + } +} + +static void +printType(const Type& type, const Kind& kind, const Subs& subs = {}) { + if (classes.count(type) == 0) { + std::cout << type; + } else { + const auto& returnType = classes.at(type); + std::cout << returnType.name; + printTemplateList(kind, "", subs); + } +} + +static void printReturnType( + const Signature& signature, + const Method& method, + const Subs& subs = {}) { + printType(method.signature.returnType, signature.returnType, subs); +} + +static Subs specializer( + const std::vector& dst, + const std::vector& src, + Subs subs = {}) { + for (size_t i = 0; i < src.size(); ++i) { + if (src[i].children.size() == 0) { + subs.emplace(src[i].name, dst[i]); + } else if (src[i].children.size() == dst[i].children.size()) { + subs = specializer(dst[i].children, src[i].children, subs); + } + } + return subs; +} + +static Signature specialize( + const Signature& signature, + const Kind& classKind) { + Subs subs = specializer(classKind, signature.argTypes[0]); + return specialize(signature, subs); +} + +static bool printMethod( + const std::string& base, + const Kind& classKind, + const Signature& signature, + const Method& method, + const Subs& subs, + bool isStatic = false) { + auto specializedSignature = + isStatic ? signature : specialize(signature, classKind); + const auto& match = isStatic ? specializedSignature.returnType + : specializedSignature.argTypes[0]; + auto specializedMatch = specialize(match, subs); + if (specializedMatch != classKind) { + return false; + } + printExtraTemplate(classKind, specializedSignature, subs, isStatic); + if (isStatic) { + std::cout << "static "; + } + std::cout << "inline "; + printReturnType(specializedSignature, method, subs); + std::cout << " "; + std::cout << method.name; + std::cout << "("; + size_t j = isStatic ? 0 : 1; + for (size_t i = 0; i < method.signature.argTypes.size(); ++i) { + if (i > 0) { + std::cout << ", "; + } + std::cout << "const "; + const Type& type = method.signature.argTypes[i]; + if (classes.count(type) == 0) { + std::cout << type; + } else { + printType(type, specializedSignature.argTypes[j++], subs); + } + std::cout << "& arg" << i; + } + std::cout << ")"; + if (!isStatic) { + std::cout << " const"; + } + std::cout << " {\n"; + std::cout << "auto res = "; + if (!isStatic) { + std::cout << "this->"; + } + std::cout << base << "::" << method.name << "("; + for (size_t i = 0; i < method.signature.argTypes.size(); ++i) { + if (i > 0) { + std::cout << ", "; + } + std::cout << "arg" << i; + } + std::cout << ");\n"; + std::cout << "return "; + printReturnType(specializedSignature, method, subs); + std::cout << "(res);\n"; + std::cout << "}\n"; + return true; +} + +static void printTo( + const Signature& signature, + const Method& method, + const Subs& subs) { + std::cout << "inline "; + printReturnType(signature, method, subs); + std::cout << " to" << classes.at(method.signature.returnType).name + << "() const {\n"; + std::cout << "return "; + printReturnType(signature, method, subs); + std::cout << "::from(*this);\n"; + std::cout << "}\n"; +} + +static void printAs( + const Signature& signature, + const Method& method, + const Subs& subs) { + std::cout << "inline "; + printReturnType(signature, method, subs); + std::cout << " as" << classes.at(method.signature.returnType).name + << "() const {\n"; + std::cout << "return "; + printReturnType(signature, method, subs); + std::cout << "(*this);\n"; + std::cout << "}\n"; +} + +static void printForeach( + const std::string& base, + const Kind& classKind, + const Method& method) { + const auto& fn = method.signature.argTypes[0]; + auto open = fn.find("("); + auto close = fn.find(")", open + 1); + if (close == std::string::npos) { + return; + } + auto argType = fn.substr(open + 1, close - (open + 1)); + if (classes.count(argType) == 0) { + return; + } + std::cout << "inline void " << method.name << "("; + std::cout << fn.substr(0, open + 1); + printType(argType, classKind); + std::cout << fn.substr(close); + std::cout << "& fn) const {\n"; + std::cout << "auto lambda = [fn](" << argType << " arg) -> void {\n"; + std::cout << "fn("; + printType(argType, classKind); + std::cout << "(arg));"; + std::cout << "};\n"; + std::cout << "this->" << base << "::" << method.name << "(lambda);\n"; + std::cout << "}\n"; +} + +static bool matches( + const Kind& classKind, + const Signature& signature, + const Method& method) { + if (signature.argTypes[0].size() != classKind.size()) { + return false; + } + size_t count = 0; + for (const auto& type : method.signature.argTypes) { + if (classes.count(type) != 0) { + ++count; + } + } + return signature.argTypes.size() == 1 + count; +} + +static void printMethods( + const std::string& base, + const Kind& classKind, + const std::vector& methods, + const Subs& subs) { + for (auto method : methods) { + if (specificSignatures.count({base, method.name}) != 0) { + for (const auto& signature : specificSignatures.at({base, method.name})) { + if (matches(classKind, signature, method)) { + printMethod(base, classKind, signature, method, subs); + } + } + } else if (specificStaticSignatures.count({base, method.name}) != 0) { + for (const auto& signature : + specificStaticSignatures.at({base, method.name})) { + if (signature.returnType.size() == classKind.size()) { + printMethod(base, classKind, signature, method, subs, true); + } + } + } else if (signatures.count(method.name) != 0) { + for (const auto& signature : signatures.at(method.name)) { + if (matches(classKind, signature, method)) { + if (printMethod(base, classKind, signature, method, subs)) { + break; + } + } + } + } else if (staticSignatures.count(method.name) != 0) { + for (const auto& signature : staticSignatures.at(method.name)) { + if (signature.returnType.size() == classKind.size()) { + printMethod(base, classKind, signature, method, subs, true); + } + } + } else if ( + method.name == "#to" && + classes.count(method.signature.returnType) == 1) { + for (auto returnKind : classes.at(method.signature.returnType).kinds) { + if (classKind.size() == 2 && returnKind.size() == 2) { + printTo(modify_map_unary(), method, subs); + } + } + } else if ( + method.name == "#as" && + classes.count(method.signature.returnType) == 1) { + for (auto returnKind : classes.at(method.signature.returnType).kinds) { + if (classKind.size() == returnKind.size()) { + for (const auto& constructor : + constructors.at(method.signature.returnType)) { + for (const auto& signature : constructor.signatures) { + if (constructor.argTypes[0] == base && + signature.returnType.size() == classKind.size()) { + printAs(signature, method, subs); + } + } + } + } + } + } else if (isForeach(method.name)) { + printForeach(base, classKind, method); + } + } +} + +static void printConstructor( + const std::string& base, + const std::string& className, + const Kind& classKind, + const std::vector& argTypes, + const Signature& signature, + const Subs& subs) { + std::cout << className << "("; + for (size_t i = 0; i < argTypes.size(); ++i) { + if (i > 0) { + std::cout << ", "; + } + std::cout << "const "; + printType(argTypes[i], signature.argTypes[i], subs); + std::cout << "& arg" << i; + } + std::cout << ") : " << base << "("; + for (size_t i = 0; i < argTypes.size(); ++i) { + if (i > 0) { + std::cout << ", "; + } + std::cout << "arg" << i; + } + std::cout << ") {}\n"; +} + +static void printOneDefinition( + const std::string& base, + const std::string& className, + const Kind& classKind, + const Exported& exported, + const Subs& subs) { + std::cout << "\n"; + printClassDeclaration(className, classKind); + std::cout << " : public " << base << " {\n"; + std::cout << className << "() = default;\n"; + std::cout << "explicit " << className << "(const " << base + << "& obj) : " << base << "(obj) {}\n"; + if (constructors.count(base) != 0) { + for (const auto& constructor : constructors.at(base)) { + for (const auto& signature : constructor.signatures) { + if (classKind.size() == signature.returnType.size()) { + printConstructor( + base, + className, + classKind, + constructor.argTypes, + signature, + subs); + } + } + } + } + if (exported.count(base) != 0) { + printMethods(base, classKind, exported.at(base), subs); + } + std::cout << "};\n"; +} + +static void printDefinition( + const std::string& base, + const std::string& className, + const Kind& classKind, + const Exported& exported) { + std::set kinds{classKind}; + if (exported.count(base) != 0) { + for (auto method : exported.at(base)) { + if (specificStaticSignatures.count({base, method.name}) != 0) { + for (const auto& signature : + specificStaticSignatures.at({base, method.name})) { + if (signature.returnType.size() == classKind.size() && + signature.returnType != classKind) { + kinds.emplace(signature.returnType); + } + } + } else if (signatures.count(method.name) != 0) { + for (const auto& signature : signatures.at(method.name)) { + if (matches(classKind, signature, method) && + signature.argTypes[0] != classKind) { + kinds.emplace(signature.argTypes[0]); + } + } + } + } + } + for (auto kind : kinds) { + Subs subs; + for (size_t i = 0; i < classKind.size(); ++i) { + if (classKind[i] != kind[i]) { + subs.emplace(classKind[i].name, kind[i]); + } + } + printOneDefinition(base, className, kind, exported, subs); + } +} + +static void printDefinitions(const Exported& exported) { + printDefinition("space", "Space", params_type(), exported); + printDefinition("set", "Set", params_type(), exported); + for (auto kvp : classes) { + for (auto kind : kvp.second.kinds) { + if (kind.size() == 0 && kvp.first != "space" && kvp.first != "set") { + printDefinition(kvp.first, kvp.second.name, kind, exported); + } + } + } + for (auto kvp : classes) { + for (auto kind : kvp.second.kinds) { + if (kind.size() != 0) { + printDefinition(kvp.first, kvp.second.name, kind, exported); + } + } + } +} + +static std::string extractArg(std::string arg) { + size_t start = 0; + constexpr auto constStr = "const "; + if (arg.find(constStr) != std::string::npos) { + start += strlen(constStr); + } + return dropIslNamespace(arg.substr(0, arg.find(" ", start))); +} + +static std::vector splitArgs(const std::string& args) { + std::vector list; + size_t pos, old = 0; + + while ((pos = args.find(", ", old)) != std::string::npos) { + list.emplace_back(extractArg(args.substr(old, pos))); + old = pos + 2; + } + if (args.length() > 0) { + list.emplace_back(extractArg(args.substr(old))); + } + return list; +} + +int main(int argc, char** argv) { + Exported exported; + for (std::string line; std::getline(std::cin, line);) { + std::regex declaration("^([a-z_:]+) (.*)::([a-z_]+)\\((.*)\\)(.*const)?$"); + std::smatch match; + if (!std::regex_match(line, match, declaration)) { + continue; + } + + auto retType = dropIslNamespace(match[1].str()); + auto className = dropIslNamespace(match[2].str()); + auto name = match[3].str(); + auto args = splitArgs(match[4].str()); + + if (name == "from" && args.size() == 1) { + exported[args[0]].emplace_back(Method{"#to", {retType, args}}); + } + if (signatures.count(name) == 0 && + specificSignatures.count({className, name}) == 0 && + staticSignatures.count(name) == 0 && + specificStaticSignatures.count({className, name}) == 0 && + !isForeach(name)) { + continue; + } + + exported[className].emplace_back(Method{name, {retType, args}}); + } + for (auto kvp : constructors) { + for (const auto& constructor : kvp.second) { + const auto& args = constructor.argTypes; + if (args.size() == 1) { + exported[args[0]].emplace_back(Method{"#as", {kvp.first, args}}); + } + } + } + + std::cout << header; + printForwardDeclarations(); + printDefinitions(exported); + std::cout << footer; + + return EXIT_SUCCESS; +} diff --git a/isl_interface/include/isl/cpp.h b/isl_interface/include/isl/cpp.h index aa9ce0a00..70217ac36 100644 --- a/isl_interface/include/isl/cpp.h +++ b/isl_interface/include/isl/cpp.h @@ -318,6 +318,7 @@ class aff { inline std::string to_str() const; inline isl::aff add(isl::aff aff2) const; + inline isl::aff add_constant(isl::val v) const; inline isl::aff add_constant_si(int v) const; inline isl::aff ceil() const; inline isl::aff div(isl::aff aff2) const; @@ -1704,6 +1705,7 @@ class map { static inline isl::map universe(isl::space space); inline isl::basic_map unshifted_simple_hull() const; inline isl::set wrap() const; + inline isl::map zip() const; typedef isl_map* isl_ptr_t; }; @@ -2166,6 +2168,7 @@ class pw_aff { inline isl::pw_aff neg() const; inline isl::set nonneg_set() const; inline isl::set params() const; + inline isl::set pos_set() const; inline isl::pw_aff project_domain_on_params() const; inline isl::pw_aff pullback(isl::multi_aff ma) const; inline isl::pw_aff pullback(isl::pw_multi_aff pma) const; @@ -3593,6 +3596,18 @@ isl::aff aff::add(isl::aff aff2) const return manage(res); } +isl::aff aff::add_constant(isl::val v) const +{ + if (!ptr || v.is_null()) + throw isl::exception::create(isl_error_invalid, + "NULL input", __FILE__, __LINE__); + options_scoped_set_on_error saved_on_error(get_ctx(), ISL_ON_ERROR_CONTINUE); + auto res = isl_aff_add_constant_val(copy(), v.release()); + if (!res) + throw exception::create_from_last_error(get_ctx()); + return manage(res); +} + isl::aff aff::add_constant_si(int v) const { if (!ptr) @@ -9734,6 +9749,18 @@ isl::set map::wrap() const return manage(res); } +isl::map map::zip() const +{ + if (!ptr) + throw isl::exception::create(isl_error_invalid, + "NULL input", __FILE__, __LINE__); + options_scoped_set_on_error saved_on_error(get_ctx(), ISL_ON_ERROR_CONTINUE); + auto res = isl_map_zip(copy()); + if (!res) + throw exception::create_from_last_error(get_ctx()); + return manage(res); +} + // implementations for isl::map_list isl::map_list manage(__isl_take isl_map_list *ptr) { if (!ptr) @@ -13082,6 +13109,18 @@ isl::set pw_aff::params() const return manage(res); } +isl::set pw_aff::pos_set() const +{ + if (!ptr) + throw isl::exception::create(isl_error_invalid, + "NULL input", __FILE__, __LINE__); + options_scoped_set_on_error saved_on_error(get_ctx(), ISL_ON_ERROR_CONTINUE); + auto res = isl_pw_aff_pos_set(copy()); + if (!res) + throw exception::create_from_last_error(get_ctx()); + return manage(res); +} + isl::pw_aff pw_aff::project_domain_on_params() const { if (!ptr) diff --git a/tc/core/CMakeLists.txt b/tc/core/CMakeLists.txt index 504dea4e6..61ff92d8c 100644 --- a/tc/core/CMakeLists.txt +++ b/tc/core/CMakeLists.txt @@ -32,7 +32,6 @@ add_library( polyhedral/schedule_print.cc polyhedral/schedule_utils.cc polyhedral/scop.cc - polyhedral/separation.cc polyhedral/unroll.cc polyhedral/utils.cc ) @@ -51,9 +50,7 @@ target_link_libraries( tc_version tc_proto ) -if (WITH_BINDINGS) - add_dependencies(tc_core generate_isl_cpp_h) -endif() +add_dependencies(tc_core generate_isl_cpp_h) install( TARGETS tc_core diff --git a/tc/core/halide2isl.cc b/tc/core/halide2isl.cc index dbc7e60f2..fef70e0fe 100644 --- a/tc/core/halide2isl.cc +++ b/tc/core/halide2isl.cc @@ -21,6 +21,7 @@ #include "tc/core/check.h" #include "tc/core/constants.h" #include "tc/core/polyhedral/body.h" +#include "tc/core/polyhedral/domain_types.h" #include "tc/core/polyhedral/schedule_isl_conversion.h" #include "tc/core/polyhedral/schedule_transforms.h" #include "tc/core/polyhedral/schedule_tree.h" @@ -80,13 +81,13 @@ SymbolTable makeSymbolTable(const tc2halide::HalideComponents& components) { return builder.table; } -isl::aff makeIslAffFromInt(isl::space space, int64_t val) { +isl::AffOn<> makeIslAffFromInt(isl::Space<> space, int64_t val) { isl::val v = isl::val(space.get_ctx(), val); - return isl::aff(isl::local_space(space), v); + return isl::AffOn<>(isl::aff(isl::local_space(space), v)); } -std::vector makeIslAffBoundsFromExpr( - isl::space space, +std::vector> makeIslAffBoundsFromExpr( + isl::Space<> space, const Expr& e, bool allowMin, bool allowMax); @@ -101,9 +102,9 @@ namespace { * x > max(a,max(b,c)) <=> x > a AND x > b AND x > c */ template -inline std::vector -concatAffs(isl::space space, T op, bool allowMin, bool allowMax) { - std::vector result; +inline std::vector> +concatAffs(isl::Space<> space, T op, bool allowMin, bool allowMax) { + std::vector> result; for (const auto& aff : makeIslAffBoundsFromExpr(space, op->a, allowMin, allowMax)) { @@ -129,10 +130,10 @@ concatAffs(isl::space space, T op, bool allowMin, bool allowMax) { * x < a + max(b,c) NOT <=> x < a + b AND x < a + c for negative values. */ template -inline std::vector combineSingleAffs( - isl::space space, +inline std::vector> combineSingleAffs( + isl::Space<> space, T op, - isl::aff (isl::aff::*combine)(isl::aff) const) { + isl::AffOn<> (isl::AffOn<>::*combine)(const isl::AffOn<>&) const) { auto left = makeIslAffBoundsFromExpr(space, op->a, false, false); auto right = makeIslAffBoundsFromExpr(space, op->b, false, false); TC_CHECK_LE(left.size(), 1u); @@ -162,8 +163,8 @@ inline std::vector combineSingleAffs( * If a Halide expression cannot be converted into a list of affine expressions, * return an empty list. */ -std::vector makeIslAffBoundsFromExpr( - isl::space space, +std::vector> makeIslAffBoundsFromExpr( + isl::Space<> space, const Expr& e, bool allowMin, bool allowMax) { @@ -178,7 +179,7 @@ std::vector makeIslAffBoundsFromExpr( if (const Variable* op = e.as()) { isl::id id(space.get_ctx(), op->name); if (space.has_param(id)) { - return {isl::aff::param_on_domain_space(space, id)}; + return {isl::AffOn<>::param_on_domain_space(space, id)}; } LOG(FATAL) << "Variable not found in isl::space: " << space << ": " << op << ": " << op->name << '\n'; @@ -188,13 +189,13 @@ std::vector makeIslAffBoundsFromExpr( } else if (maxOp != nullptr && allowMax) { return concatAffs(space, maxOp, allowMin, allowMax); } else if (const Add* op = e.as()) { - return combineSingleAffs(space, op, &isl::aff::add); + return combineSingleAffs(space, op, &isl::AffOn<>::add); } else if (const Sub* op = e.as()) { - return combineSingleAffs(space, op, &isl::aff::sub); + return combineSingleAffs(space, op, &isl::AffOn<>::sub); } else if (const Mul* op = e.as()) { - return combineSingleAffs(space, op, &isl::aff::mul); + return combineSingleAffs(space, op, &isl::AffOn<>::mul); } else if (const Div* op = e.as
()) { - return combineSingleAffs(space, op, &isl::aff::div); + return combineSingleAffs(space, op, &isl::AffOn<>::div); } else if (const Mod* op = e.as()) { std::vector result; // We cannot span multiple constraints if a modulo operation is involved. @@ -211,7 +212,7 @@ std::vector makeIslAffBoundsFromExpr( return {}; } -isl::aff makeIslAffFromExpr(isl::space space, const Expr& e) { +isl::AffOn<> makeIslAffFromExpr(isl::Space<> space, const Expr& e) { auto list = makeIslAffBoundsFromExpr(space, e, false, false); TC_CHECK_LE(list.size(), 1u) << "Halide expr " << e << " unrolled into more than 1 isl aff" @@ -219,13 +220,13 @@ isl::aff makeIslAffFromExpr(isl::space space, const Expr& e) { // Non-affine if (list.size() == 0) { - return isl::aff(); + return isl::AffOn<>(); } return list[0]; } -isl::space makeParamSpace(isl::ctx ctx, const ParameterVector& params) { - auto space = isl::space(ctx, 0); +isl::Space<> makeParamSpace(isl::ctx ctx, const ParameterVector& params) { + auto space = isl::Space<>(ctx, 0); // set parameter names for (auto p : params) { space = space.add_param(isl::id(ctx, p.name())); @@ -233,19 +234,29 @@ isl::space makeParamSpace(isl::ctx ctx, const ParameterVector& params) { return space; } -isl::set makeParamContext(isl::ctx ctx, const ParameterVector& params) { +isl::Set<> makeParamContext(isl::ctx ctx, const ParameterVector& params) { auto space = makeParamSpace(ctx, params); - auto context = isl::set::universe(space); + auto context = isl::Set<>::universe(space); for (auto p : params) { - isl::aff a(isl::aff::param_on_domain_space(space, isl::id(ctx, p.name()))); - context = context & (a >= 0); + auto a(isl::AffOn<>::param_on_domain_space(space, isl::id(ctx, p.name()))); + context = context & a.asPwAff().nonneg_set(); } return context; } namespace { -isl::map extractAccess( +/* + * Call the domain_map factory method of the isl::MultiAff + * with appropriate template arguments. + */ +template +static isl::MultiAff, Domain> domainMap( + isl::Space space) { + return isl::MultiAff, Domain>::domain_map(space); +} + +isl::Map, Tensor> extractAccess( const IterationDomain& domain, const IRNode* op, const std::string& tensor, @@ -258,16 +269,17 @@ isl::map extractAccess( // to the outer loop iterators) and then convert this set // into a map in terms of the iteration domain. - isl::space paramSpace = domain.paramSpace; + auto paramSpace = domain.paramSpace; isl::id tensorID(paramSpace.get_ctx(), tensor); auto tensorTuple = constructTensorTuple(paramSpace, tensorID, args.size()); auto tensorSpace = tensorTuple.get_space(); // Start with a totally unconstrained set - every point in // the allocation could be accessed. - isl::set access = isl::set::universe(tensorSpace); + auto access = isl::Set::universe(tensorSpace); - auto identity = isl::multi_aff::identity(tensorSpace.map_from_set()); + auto identity = + isl::MultiAff::identity(tensorSpace.map_from_set()); for (size_t i = 0; i < args.size(); i++) { // Then add one equality constraint per dimension to encode the // point in the allocation actually read/written for each point in @@ -277,9 +289,9 @@ isl::map extractAccess( // The coordinate written to in the range ... auto rangePoint = identity.get_aff(i); // ... equals the coordinate accessed as a function of the parameters. - auto domainPoint = halide2isl::makeIslAffFromExpr(paramSpace, args[i]); - if (!domainPoint.is_null()) { - domainPoint = domainPoint.unbind_params_insert_domain(tensorTuple); + auto paramPoint = halide2isl::makeIslAffFromExpr(paramSpace, args[i]); + if (!paramPoint.is_null()) { + auto domainPoint = paramPoint.unbind_params_insert_domain(tensorTuple); access = access.intersect(domainPoint.eq_set(rangePoint)); } } @@ -292,15 +304,16 @@ isl::map extractAccess( std::string tag = "__tc_ref_" + std::to_string(accesses->size()); isl::id tagID(domain.paramSpace.get_ctx(), tag); accesses->emplace(op, tagID); - isl::space domainSpace = map.get_space().domain(); - isl::space tagSpace = domainSpace.params().add_named_tuple_id_ui(tagID, 0); - domainSpace = domainSpace.product(tagSpace).unwrap(); - map = map.preimage_domain(isl::multi_aff::domain_map(domainSpace)); - - return map; + auto domainSpace = map.get_space().domain(); + auto tagSpace = domainSpace.params().add_named_tuple_id_ui(tagID, 0); + auto taggedSpace = domainSpace.product(tagSpace).unwrap(); + return map.preimage_domain(domainMap(taggedSpace)); } -std::pair extractAccesses( +std::pair< + isl::UnionMap, Tensor>, + isl::UnionMap, Tensor>> +extractAccesses( const IterationDomain& domain, const Stmt& s, AccessMap* accesses) { @@ -325,7 +338,7 @@ std::pair extractAccesses( AccessMap* accesses; public: - isl::union_map reads, writes; + isl::UnionMap, Tensor> reads, writes; FindAccesses(const IterationDomain& domain, AccessMap* accesses) : domain(domain), @@ -355,24 +368,26 @@ bool isReductionUpdate(const Provide* op) { * then converted into an expression on that iteration domain * by reinterpreting the parameters as input dimensions. */ -static isl::multi_aff mapToOther( +template +static isl::MultiAff mapToOther( const IterationDomain& iterationDomain, std::unordered_set skip, isl::id id) { auto ctx = iterationDomain.tuple.get_ctx(); - auto list = isl::aff_list(ctx, 0); + auto list = isl::AffListOn(ctx, 0); for (auto id : iterationDomain.tuple.get_id_list()) { if (skip.count(id.get_name()) == 1) { continue; } - auto aff = isl::aff::param_on_domain_space(iterationDomain.paramSpace, id); - aff = aff.unbind_params_insert_domain(iterationDomain.tuple); - list = list.add(aff); + auto aff = + isl::AffOn<>::param_on_domain_space(iterationDomain.paramSpace, id); + list = list.add(aff.unbind_params_insert_domain(iterationDomain.tuple)); } auto domainSpace = iterationDomain.tuple.get_space(); - auto space = domainSpace.params().add_named_tuple_id_ui(id, list.size()); - space = domainSpace.product(space).unwrap(); - return isl::multi_aff(space, list); + auto space = + domainSpace.params().add_named_tuple_id_ui(id, list.size()); + auto productSpace = domainSpace.product(space).unwrap(); + return isl::MultiAff(productSpace, list); } /* @@ -392,7 +407,7 @@ static isl::multi_aff mapToOther( * that all statement instances that belong to the same reduction * write to the same tensor element. */ -isl::union_map extractReduction( +isl::UnionMap extractReduction( const IterationDomain& iterationDomain, const Provide* op, size_t index) { @@ -409,16 +424,19 @@ isl::union_map extractReduction( } finder; if (!isReductionUpdate(op)) { - return isl::union_map::empty(iterationDomain.tuple.get_space().params()); + auto space = iterationDomain.tuple.get_space().params(); + return isl::UnionMap::empty(space); } op->accept(&finder); if (finder.reductionVars.size() == 0) { - return isl::union_map::empty(iterationDomain.tuple.get_space().params()); + auto space = iterationDomain.tuple.get_space().params(); + return isl::UnionMap(isl::union_map::empty(space)); } auto ctx = iterationDomain.tuple.get_ctx(); isl::id id(ctx, kReductionLabel + op->name + "_" + std::to_string(index)); - auto reduction = mapToOther(iterationDomain, finder.reductionVars, id); - return isl::union_map(isl::map(reduction)); + auto reduction = + mapToOther(iterationDomain, finder.reductionVars, id); + return reduction.asMap().asUnionMap(); } /* @@ -458,7 +476,7 @@ onDomains(isl::aff f, isl::union_set domain, const IterationDomainMap& map) { */ isl::schedule makeScheduleTreeHelper( const Stmt& s, - isl::set set, + isl::Set<> set, isl::id_list outer, Body* body, AccessMap* accesses, @@ -472,7 +490,7 @@ isl::schedule makeScheduleTreeHelper( // Construct a variable (affine function) that references // the new parameter. - auto loopVar = isl::aff::param_on_domain_space(space, id); + auto loopVar = isl::AffOn<>::param_on_domain_space(space, id); // Then we add our new loop bound constraints. auto lbs = @@ -527,16 +545,16 @@ isl::schedule makeScheduleTreeHelper( size_t stmtIndex = statements->size(); isl::id id(set.get_ctx(), kStatementLabel + std::to_string(stmtIndex)); statements->emplace(id, op); - auto tupleSpace = isl::space(set.get_ctx(), 0); - tupleSpace = tupleSpace.add_named_tuple_id_ui(id, outer.size()); + auto space = isl::Space<>(set.get_ctx(), 0); + auto tupleSpace = space.add_named_tuple_id_ui(id, outer.size()); IterationDomain iterationDomain; iterationDomain.paramSpace = set.get_space(); - iterationDomain.tuple = isl::multi_id(tupleSpace, outer); + iterationDomain.tuple = isl::MultiId(tupleSpace, outer); domains->emplace(id, iterationDomain); auto domain = set.unbind_params(iterationDomain.tuple); schedule = isl::schedule::from_domain(domain); - isl::union_map newReads, newWrites; + isl::UnionMap, Tensor> newReads, newWrites; std::tie(newReads, newWrites) = extractAccesses(iterationDomain, op, accesses); // A tensor may be involved in multiple reductions. @@ -553,7 +571,9 @@ isl::schedule makeScheduleTreeHelper( return schedule; }; -ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) { +ScheduleTreeAndAccesses makeScheduleTree( + isl::Space<> paramSpace, + const Stmt& s) { ScheduleTreeAndAccesses result; Body body(paramSpace); @@ -562,7 +582,7 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) { isl::id_list outer(paramSpace.get_ctx(), 0); auto schedule = makeScheduleTreeHelper( s, - isl::set::universe(paramSpace), + isl::Set<>::universe(paramSpace), outer, &body, &result.accesses, diff --git a/tc/core/halide2isl.h b/tc/core/halide2isl.h index b3ecc95ce..f735846a1 100644 --- a/tc/core/halide2isl.h +++ b/tc/core/halide2isl.h @@ -23,6 +23,7 @@ #include #include "tc/core/polyhedral/body.h" +#include "tc/core/polyhedral/domain_types.h" #include "tc/core/polyhedral/schedule_tree.h" #include "tc/core/tc2halide.h" #include "tc/external/isl.h" @@ -43,14 +44,14 @@ struct SymbolTable { SymbolTable makeSymbolTable(const tc2halide::HalideComponents& components); /// Make the space of all given parameter values -isl::space makeParamSpace(isl::ctx ctx, const ParameterVector& params); +isl::Space<> makeParamSpace(isl::ctx ctx, const ParameterVector& params); /// Make the parameter set marking all given parameters /// as non-negative. -isl::set makeParamContext(isl::ctx ctx, const ParameterVector& params); +isl::Set<> makeParamContext(isl::ctx ctx, const ParameterVector& params); /// Make a constant-valued affine function over a space. -isl::aff makeIslAffFromInt(isl::space space, int64_t i); +isl::AffOn<> makeIslAffFromInt(isl::Space<> space, int64_t i); // Make an affine function over a space from a Halide Expr. Returns a // null isl::aff if the expression is not affine. Fails if Variable @@ -58,17 +59,17 @@ isl::aff makeIslAffFromInt(isl::space space, int64_t i); // Note that the input space can be either a parameter space or // a set space, but the expression can only reference // the parameters in the space. -isl::aff makeIslAffFromExpr(isl::space space, const Halide::Expr& e); +isl::AffOn<> makeIslAffFromExpr(isl::Space<> space, const Halide::Expr& e); // Iteration domain information associated to a statement identifier. struct IterationDomain { // All parameters active at the point where the iteration domain // was created, including those corresponding to outer loop iterators. - isl::space paramSpace; + isl::Space<> paramSpace; // The identifier tuple corresponding to the iteration domain. // The identifiers in the tuple are the outer loop iterators, // from outermost to innermost. - isl::multi_id tuple; + isl::MultiId tuple; }; typedef std::unordered_map @@ -102,7 +103,7 @@ struct ScheduleTreeAndAccesses { /// Make a schedule tree from a Halide Stmt, along with auxiliary data /// structures describing the memory access patterns. ScheduleTreeAndAccesses makeScheduleTree( - isl::space paramSpace, + isl::Space<> paramSpace, const Halide::Internal::Stmt& s); } // namespace halide2isl diff --git a/tc/core/polyhedral/body.h b/tc/core/polyhedral/body.h index 7461a5cc9..7186cacdd 100644 --- a/tc/core/polyhedral/body.h +++ b/tc/core/polyhedral/body.h @@ -17,6 +17,7 @@ #include +#include "tc/core/polyhedral/domain_types.h" #include "tc/external/isl.h" namespace tc { @@ -26,11 +27,13 @@ namespace polyhedral { struct Body { Body() = default; Body(isl::space paramSpace) { - reductions = writes = reads = isl::union_map::empty(paramSpace); + auto empty = isl::union_map::empty(paramSpace); + writes = reads = isl::UnionMap, Tensor>(empty); + reductions = isl::UnionMap(empty); } // Specialize to the given context. - void specialize(isl::set context) { + void specialize(isl::Set<> context) { reads = reads.intersect_params(context); writes = writes.intersect_params(context); reductions = reductions.intersect_params(context); @@ -39,7 +42,7 @@ struct Body { // Union maps describing the reads and writes done. Uses the ids in // the schedule tree to denote the containing Stmt, and tags each // access with a unique reference id of the form __tc_ref_N. - isl::union_map reads, writes; + isl::UnionMap, Tensor> reads, writes; // A function on reduction update statement instances that partitions them // into individual reductions, where each reduction consists of @@ -73,7 +76,7 @@ struct Body { // That is, in the example above, it would just be // // { S[i] -> R[] : 0 <= i < 4 } - isl::union_map reductions; + isl::UnionMap reductions; }; std::ostream& operator<<(std::ostream& os, const Body& body); diff --git a/tc/core/polyhedral/codegen_llvm.cc b/tc/core/polyhedral/codegen_llvm.cc index c87289d1f..2942b7ff7 100644 --- a/tc/core/polyhedral/codegen_llvm.cc +++ b/tc/core/polyhedral/codegen_llvm.cc @@ -70,7 +70,7 @@ int64_t toSInt(isl::val v) { return v.get_num_si(); } -int64_t getTensorSize(isl::set context, const Halide::Expr& e) { +int64_t getTensorSize(isl::Set<> context, const Halide::Expr& e) { // isl will take care of substituting parameter values if they are known and // simplifying the expression. auto aff = halide2isl::makeIslAffFromExpr(context.get_space(), e); @@ -81,7 +81,7 @@ int64_t getTensorSize(isl::set context, const Halide::Expr& e) { std::vector getTensorSizesWithoutLeadingDim( const Halide::OutputImageParam& t, - isl::set context) { + isl::Set<> context) { auto dims = t.dimensions(); std::vector sizes; sizes.reserve(dims); diff --git a/tc/core/polyhedral/cuda/codegen.cc b/tc/core/polyhedral/cuda/codegen.cc index 606cd235c..1aafa3f75 100644 --- a/tc/core/polyhedral/cuda/codegen.cc +++ b/tc/core/polyhedral/cuda/codegen.cc @@ -723,8 +723,9 @@ void emitMappedTensorAccess( // MA = multi_aff, PMA = pw_multi_aff auto access = makeMultiAffAccess(tensorId, subscripts, context); // MA :: D -> O - auto promotion = promotionInfo.group->promotion(); // MA :: [S -> O] -> P - promotion = promotion.set_range_tuple_id(promotionInfo.groupId); + auto prePromotion = promotionInfo.group->promotion(); // MA :: [S -> O] -> P + auto promotion = + prePromotion.set_range_tuple_id(promotionInfo.groupId); auto iteratorMap = context.iteratorMap(); // PMA :: A -> D auto schedule = isl::map::from(promotionInfo.outerSchedule.intersect_domain( context.domain())); // map :: D -> S diff --git a/tc/core/polyhedral/cuda/mapped_scop.cc b/tc/core/polyhedral/cuda/mapped_scop.cc index 02e603b2a..e00a947e5 100644 --- a/tc/core/polyhedral/cuda/mapped_scop.cc +++ b/tc/core/polyhedral/cuda/mapped_scop.cc @@ -32,6 +32,7 @@ #include "tc/core/polyhedral/cuda/mapping_types.h" #include "tc/core/polyhedral/cuda/memory_promotion_heuristic.h" #include "tc/core/polyhedral/cuda/tighten_launch_bounds.h" +#include "tc/core/polyhedral/domain_types.h" #include "tc/core/polyhedral/exceptions.h" #include "tc/core/polyhedral/schedule_transforms.h" #include "tc/core/polyhedral/schedule_tree_matcher.h" @@ -135,7 +136,9 @@ const CudaDim& mappingSize(const MappedScop* mscop) { // Return a pointer to the updated node (below the inserted filter) // for call chaining purposes. template -ScheduleTree* MappedScop::map(ScheduleTree* tree, isl::union_pw_aff_list list) { +ScheduleTree* MappedScop::map( + ScheduleTree* tree, + isl::UnionPwAffListOn list) { size_t nToMap = list.size(); const auto& extent = mappingSize(this).view; TC_CHECK_LE(nToMap, extent.size()) << "dimension overflow"; @@ -145,7 +148,7 @@ ScheduleTree* MappedScop::map(ScheduleTree* tree, isl::union_pw_aff_list list) { auto universe = domain.universe(); std::vector idList; - auto affList = isl::union_pw_aff_list(list.get_ctx(), 0); + auto affList = isl::UnionPwAffListOn(list.get_ctx(), 0); for (size_t i = 0; i < nToMap; ++i) { auto id = MappingTypeId::makeId(i); auto upa = list.get_at(i); @@ -157,8 +160,8 @@ ScheduleTree* MappedScop::map(ScheduleTree* tree, isl::union_pw_aff_list list) { for (size_t i = nToMap; i < extent.size(); ++i) { auto id = MappingTypeId::makeId(i); - affList = affList.add( - isl::union_pw_aff(universe, isl::val::zero(domain.get_ctx()))); + affList = affList.add(isl::UnionPwAffOn( + universe, isl::val::zero(domain.get_ctx()))); idList.emplace_back(id); } @@ -225,7 +228,10 @@ void fixThreadsBelow(MappedScop& mscop, ScheduleTree* tree, size_t begin) { * Anything that depends on an update statement is ordered after * the update statements. Anything else is ordered before. */ -bool separatedOut(Scop& scop, ScheduleTree* tree, isl::union_set updates) { +bool separatedOut( + Scop& scop, + ScheduleTree* tree, + isl::UnionSet updates) { auto domain = activeDomainPoints(scop.scheduleRoot(), tree); auto other = domain.subtract(updates); if (other.is_empty()) { @@ -295,8 +301,8 @@ bool MappedScop::detectReductions(ScheduleTree* tree) { }); // The outer (coincident) members, together with the prefix schedule, // need to determine a single reduction. - auto prefix = prefixScheduleMupa(schedule(), tree); - prefix = prefix.range_product(band->memberRange(0, nCoincident)); + auto prefix = prefixScheduleMupa(schedule(), tree) + .range_product(band->memberRange(0, nCoincident)); if (!isSingleReductionWithin(updates, prefix, scop())) { return false; } @@ -318,8 +324,8 @@ bool MappedScop::needReductionSeparation(const ScheduleTree* st) { return !reductionBandUpdates_.at(st).separated; } -isl::multi_union_pw_aff MappedScop::reductionMapSchedule( - const ScheduleTree* st) { +isl::MultiUnionPwAff +MappedScop::reductionMapSchedule(const ScheduleTree* st) { TC_CHECK(reductionBandUpdates_.count(st) == 1); auto reductionBand = st->as(); TC_CHECK(reductionBand); @@ -330,7 +336,7 @@ isl::multi_union_pw_aff MappedScop::reductionMapSchedule( TC_CHECK_GE(nMember, reductionDim + 1); auto first = reductionDim + 1 - nMappedThreads; - return reductionBand->memberRange(first, nMappedThreads); + return reductionBand->memberRange(first, nMappedThreads); } ScheduleTree* MappedScop::separateReduction(ScheduleTree* st) { @@ -341,10 +347,10 @@ ScheduleTree* MappedScop::separateReduction(ScheduleTree* st) { auto root = scop_->scheduleRoot(); auto domain = activeDomainPoints(root, st); - auto prefixSchedule = prefixScheduleMupa(root, st); + auto prefixSchedule = prefixScheduleMupa(root, st); auto reductionSchedule = reductionMapSchedule(st); auto space = reductionSchedule.get_space(); - auto size = isl::multi_val::zero(space); + auto size = isl::MultiVal::zero(space); for (size_t i = 0; i < numThreads.view.size(); ++i) { auto pos = numThreads.view.size() - 1 - i; size = size.set_val(pos, isl::val(st->ctx_, numThreads.view[i])); @@ -501,18 +507,19 @@ constexpr auto kWarp = "warp"; * (of size "warpSize") to a warp identifier, * based on the thread sizes s_x, s_y up to s_z in "block". */ -isl::multi_aff constructThreadToWarp( +isl::MultiAff constructThreadToWarp( isl::ctx ctx, const unsigned warpSize, const Block& block) { - auto space = isl::space(ctx, 0); + auto space = isl::Space<>(ctx, 0); auto id = isl::id(ctx, kBlock); - auto blockSpace = space.add_named_tuple_id_ui(id, block.view.size()); - auto warpSpace = space.add_named_tuple_id_ui(isl::id(ctx, kWarp), 1); - auto aff = isl::aff::zero_on_domain(blockSpace); + auto blockSpace = space.add_named_tuple_id_ui(id, block.view.size()); + auto warpSpace = space.add_named_tuple_id_ui(isl::id(ctx, kWarp), 1); + auto aff = isl::AffOn::zero_on_domain(blockSpace); auto nThread = block.view.size(); - auto identity = isl::multi_aff::identity(blockSpace.map_from_set()); + auto identity = + isl::MultiAff::identity(blockSpace.map_from_set()); for (int i = nThread - 1; i >= 0; --i) { aff = aff.scale(isl::val(ctx, block.view[i])); aff = aff.add(identity.get_aff(i)); @@ -520,35 +527,35 @@ isl::multi_aff constructThreadToWarp( aff = aff.scale_down(isl::val(ctx, warpSize)).floor(); auto mapSpace = blockSpace.product(warpSpace).unwrap(); - return isl::multi_aff(mapSpace, isl::aff_list(aff)); + return isl::MultiAff(mapSpace, aff.asAffList()); } } // namespace -isl::multi_union_pw_aff MappedScop::threadMappingSchedule( +isl::MultiUnionPwAff MappedScop::threadMappingSchedule( const ScheduleTree* tree) const { std::vector ids; for (size_t i = 0; i < numThreads.view.size(); ++i) { ids.emplace_back(mapping::ThreadId::makeId(i)); } auto tupleId = isl::id(tree->ctx_, kBlock); - return extractDomainToIds(scop_->scheduleRoot(), tree, ids, tupleId); + return extractDomainToIds(scop_->scheduleRoot(), tree, ids, tupleId); } -isl::multi_union_pw_aff MappedScop::blockMappingSchedule( +isl::MultiUnionPwAff MappedScop::blockMappingSchedule( const ScheduleTree* tree) const { std::vector ids; for (size_t i = 0; i < numBlocks.view.size(); ++i) { ids.emplace_back(mapping::BlockId::makeId(i)); } auto tupleId = isl::id(tree->ctx_, kGrid); - return extractDomainToIds(scop_->scheduleRoot(), tree, ids, tupleId); + return extractDomainToIds(scop_->scheduleRoot(), tree, ids, tupleId); } Scop::SyncLevel MappedScop::findBestSync( ScheduleTree* st1, ScheduleTree* st2, - isl::multi_union_pw_aff domainToThread, - isl::multi_union_pw_aff domainToWarp) { + isl::MultiUnionPwAff domainToThread, + isl::MultiUnionPwAff domainToWarp) { // Active points in the two schedule trees auto stRoot = scop_->scheduleRoot(); auto activePoints1 = activeDomainPointsBelow(stRoot, st1); @@ -862,10 +869,11 @@ void MappedScop::insertMappingContext() { {TX, TX.mappingSize(block)}, {TY, TY.mappingSize(block)}, {TZ, TZ.mappingSize(block)}}; - auto space = scop.domain().universe().get_space(); + auto space = scop.domain().get_space(); auto mappingContext = makeParameterContext( space, mappingIdsWithSizes.begin(), mappingIdsWithSizes.end()); - updateTopLevelContext(scop.scheduleRoot(), mappingContext.from_params()); + updateTopLevelContext( + scop.scheduleRoot(), mappingContext.from_params()); } namespace { @@ -895,7 +903,7 @@ std::unique_ptr makeSpecializedMappedScop( // outer schedule dimensions, so the space of a parameter context code is that // of a zero-dimensional space. auto root = scop->scheduleRoot(); - updateTopLevelContext(root, scop->context().from_params()); + updateTopLevelContext(root, scop->context().from_params()); tc::Grid grid = mappedScop.numBlocks; tc::Block block = mappedScop.numThreads; diff --git a/tc/core/polyhedral/cuda/mapped_scop.h b/tc/core/polyhedral/cuda/mapped_scop.h index 47aef39c6..fc92a2932 100644 --- a/tc/core/polyhedral/cuda/mapped_scop.h +++ b/tc/core/polyhedral/cuda/mapped_scop.h @@ -24,6 +24,7 @@ #include "tc/core/cuda/cuda_mapping_options.h" #include "tc/core/polyhedral/cuda/mapping_types.h" #include "tc/core/polyhedral/cuda/memory_promotion_heuristic.h" +#include "tc/core/polyhedral/domain_types.h" #include "tc/core/polyhedral/scop.h" #include "tc/core/tensor.h" #include "tc/external/isl.h" @@ -150,7 +151,7 @@ class MappedScop { template detail::ScheduleTree* map( detail::ScheduleTree* tree, - isl::union_pw_aff_list list); + isl::UnionPwAffListOn list); // Map "band" to block identifiers and then scale // the band members by "tileSizes". void mapToBlocksAndScaleBand( @@ -171,7 +172,8 @@ class MappedScop { // Return the schedule that will be used by mapInnermostBandsToThreads // for mapping to thread identifiers, with the last function // corresponding to thread identifier x. - isl::multi_union_pw_aff reductionMapSchedule(const detail::ScheduleTree* st); + isl::MultiUnionPwAff reductionMapSchedule( + const detail::ScheduleTree* st); // Separate out reductions that can be mapped to an entire block. // The remaining parts, if any, are no longer considered for replacement // by a library call. @@ -186,8 +188,8 @@ class MappedScop { Scop::SyncLevel findBestSync( detail::ScheduleTree* st1, detail::ScheduleTree* st2, - isl::multi_union_pw_aff domainToThread, - isl::multi_union_pw_aff domainToWarp); + isl::MultiUnionPwAff domainToThread, + isl::MultiUnionPwAff domainToWarp); public: // Find best configuration of synchronizations in a sequence, minimizing @@ -208,14 +210,14 @@ class MappedScop { // to the thread identifiers, where all branches in "tree" // are assumed to have been mapped to thread identifiers. // The result lives in a space of the form block[x, ...]. - isl::multi_union_pw_aff threadMappingSchedule( + isl::MultiUnionPwAff threadMappingSchedule( const detail::ScheduleTree* tree) const; // Extract a mapping from the domain elements active at "tree" // to the block identifiers, where all branches in "tree" // are assumed to have been mapped to block identifiers. // The result lives in a space of the form grid[x, ...]. - isl::multi_union_pw_aff blockMappingSchedule( + isl::MultiUnionPwAff blockMappingSchedule( const detail::ScheduleTree* tree) const; private: diff --git a/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc b/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc index b1bca6923..2b142763f 100644 --- a/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc +++ b/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc @@ -19,6 +19,7 @@ #include "tc/core/polyhedral/cuda/mapped_scop.h" #include "tc/core/polyhedral/cuda/mapping_types.h" +#include "tc/core/polyhedral/domain_types.h" #include "tc/core/polyhedral/exceptions.h" #include "tc/core/polyhedral/memory_promotion.h" #include "tc/core/polyhedral/schedule_tree.h" @@ -141,11 +142,11 @@ std::vector collectBranchMarkers(T root, T node) { * In other words, check that the mapping from statement instances * to pairs of outer schedule points and group elements is not injective. */ +template bool hasReuseWithin( const TensorReferenceGroup& group, - isl::multi_union_pw_aff outer) { - auto map = isl::union_map::from(outer); - map = map.range_product(group.originalAccesses()); + isl::MultiUnionPwAff outer) { + auto map = outer.toUnionMap().range_product(group.originalAccesses()); return !map.is_injective(); } @@ -153,9 +154,12 @@ bool hasReuseWithin( * Create a map that increments the "dim"-th dimension and keeps all other * dimensions unchanged. */ -isl::map makeNextElementMap(isl::space setSpace, unsigned dim) { +template +isl::Map makeNextElementMap( + isl::Space setSpace, + unsigned dim) { auto mapSpace = setSpace.map_from_set(); - auto identityMA = isl::multi_aff::identity(mapSpace); + auto identityMA = isl::MultiAff::identity(mapSpace); size_t size = identityMA.size(); if (dim < 0 || dim >= size) { @@ -166,7 +170,7 @@ isl::map makeNextElementMap(isl::space setSpace, unsigned dim) { auto aff = identityMA.get_aff(dim); identityMA = identityMA.set_aff(dim, aff + 1); - return isl::map(identityMA); + return identityMA.asMap(); } /* @@ -225,9 +229,9 @@ bool promotionImprovesCoalescing( auto depth = marker->scheduleDepth(root); auto activePoints = activeDomainPoints(root, mapping); auto localAccesses = originalAccesses.intersect_domain(activePoints); - auto schedule = prefixSchedule(root, marker); + auto schedule = prefixSchedule(root, marker); auto scheduledAccesses = localAccesses.apply_domain(schedule); - for (auto access : isl::UnionAsVector(scheduledAccesses)) { + for (auto access : scheduledAccesses.get_map_list()) { auto scheduleSpace = access.get_space().domain(); auto tensorSpace = access.get_space().range(); auto elementToNext = makeNextElementMap(tensorSpace, tensorDim - 1); @@ -247,13 +251,13 @@ bool promotionImprovesCoalescing( * Returns the union of all mapping filters to "MappingType" in "scop". */ template -isl::union_set collectMappingsTo(const Scop& scop) { +isl::UnionSet collectMappingsTo(const Scop& scop) { auto root = scop.scheduleRoot(); auto domain = scop.domain(); auto mappingFilters = detail::ScheduleTree::collect(root, detail::ScheduleTreeType::Mapping); mappingFilters = functional::Filter(isMappingTo, mappingFilters); - auto mapping = isl::union_set::empty(domain.get_space()); + auto mapping = isl::UnionSet::empty(domain.get_space()); for (auto mf : mappingFilters) { auto filterNode = mf->as(); auto filter = filterNode->filter_.intersect(activeDomainPoints(root, mf)); @@ -292,11 +296,12 @@ isl::union_set collectMappingsTo(const Scop& scop) { * different references may have different values, but all of them remain * independent of non-unrolled loop iterators. */ +template bool accessSubscriptsAreUnrolledLoops( const TensorReferenceGroup& group, const detail::ScheduleTree* root, const detail::ScheduleTree* scope, - isl::multi_union_pw_aff outerSchedule) { + isl::MultiUnionPwAff outerSchedule) { using namespace detail; auto nodes = ScheduleTree::collect(scope); @@ -313,14 +318,14 @@ bool accessSubscriptsAreUnrolledLoops( ancestors.push_back(leaf); auto subdomain = activeDomainPointsBelow(root, leaf); - auto unrolledDims = isl::union_pw_aff_list(leaf->ctx_, 1); + auto unrolledDims = isl::UnionPwAffListOn(leaf->ctx_, 1); for (auto node : ancestors) { - auto band = node->as(); + auto band = node->template as(); if (!band) { continue; } - isl::multi_union_pw_aff schedule = band->mupa_; + auto schedule = band->mupa_; schedule = schedule.intersect_domain(subdomain); for (size_t i = 0, e = band->nMember(); i < e; ++i) { if (!band->unroll_[i]) { @@ -330,9 +335,10 @@ bool accessSubscriptsAreUnrolledLoops( } } - auto space = - subdomain.get_space().add_unnamed_tuple_ui(unrolledDims.size()); - auto unrolledDimsMupa = isl::multi_union_pw_aff(space, unrolledDims); + auto space = subdomain.get_space().template add_unnamed_tuple_ui( + unrolledDims.size()); + auto unrolledDimsMupa = + isl::MultiUnionPwAff(space, unrolledDims); // It is possible that no loops are unrolled, in which case // unrolledDimsMupa is zero-dimensional and needs an explicit domain @@ -341,10 +347,11 @@ bool accessSubscriptsAreUnrolledLoops( unrolledDimsMupa.intersect_domain(group.originalAccesses().domain()); auto accesses = group.originalAccesses(); - auto schedule = outerSchedule.flat_range_product(unrolledDimsMupa); - accesses = accesses.apply_domain(isl::union_map::from(schedule)); + auto schedule = outerSchedule.range_product(unrolledDimsMupa); + auto scheduleMap = schedule.toUnionMap(); + auto scheduledAccesses = accesses.apply_domain(scheduleMap); - if (!accesses.is_single_valued()) { + if (!scheduledAccesses.is_single_valued()) { return false; } } @@ -364,23 +371,24 @@ bool accessSubscriptsAreUnrolledLoops( * thread associated to a given pair of tensor element and outer schedule * iteration. */ +template bool isPromotableToRegistersBelow( const TensorReferenceGroup& group, const detail::ScheduleTree* root, const detail::ScheduleTree* scope, - isl::multi_union_pw_aff outer, - isl::multi_union_pw_aff thread) { + isl::MultiUnionPwAff outer, + isl::MultiUnionPwAff thread) { if (!accessSubscriptsAreUnrolledLoops( - group, root, scope, outer.flat_range_product(thread))) { + group, root, scope, outer.range_product(thread))) { return false; } auto originalAccesses = group.originalAccesses(); - auto map = isl::union_map::from(outer); - map = map.range_product(originalAccesses); - map = map.apply_domain(isl::union_map::from(thread)); + auto outerMap = outer.toUnionMap(); + auto pair = outerMap.range_product(originalAccesses); + auto threadToPair = pair.apply_domain(thread.toUnionMap()); - return map.is_injective(); + return threadToPair.is_injective(); } /* @@ -471,13 +479,13 @@ void promoteToSharedBelow( throw promotion::IncorrectScope("cannot promote below a sequence/set node"); } - auto partialSched = partialSchedule(root, node); + auto partialSched = partialSchedule(root, node); auto mapping = collectMappingsTo(scop); auto groupMap = TensorReferenceGroup::accessedWithin( partialSched.intersect_domain(mapping), scop.body); // Pure affine schedule without (mapping) filters. - auto partialSchedMupa = partialScheduleMupa(root, node); + auto partialSchedMupa = partialScheduleMupa(root, node); // Prepare groups for sorting, to have specified order necessary for // reproducibility and tests. @@ -645,7 +653,7 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) { auto blockMapping = collectMappingsTo(scop); auto mapping = collectMappingsTo(scop).intersect(blockMapping); - auto schedule = partialSchedule(scop.scheduleRoot(), scope); + auto schedule = partialSchedule(scop.scheduleRoot(), scope); auto groupMap = TensorReferenceGroup::accessedWithin( schedule.intersect_domain(mapping), scop.body); @@ -653,7 +661,7 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) { auto blockSchedule = mscop.blockMappingSchedule(mscop.schedule()); // Pure affine schedule without (mapping) filters. - auto partialSchedMupa = partialScheduleMupa(root, scope); + auto partialSchedMupa = partialScheduleMupa(root, scope); // Schedule with block mapping filter. auto partialSched = isl::union_map::from(partialSchedMupa).intersect_domain(blockMapping); @@ -661,7 +669,7 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) { // performed with respect to the block mapping, so append the block schedule. // If the partial schedule contains it already, it will just end up with // identical dimensions without affecting the result of the checks. - partialSchedMupa = partialSchedMupa.flat_range_product(blockSchedule); + auto partialSchedBlockMupa = partialSchedMupa.range_product(blockSchedule); for (auto& tensorGroups : groupMap) { auto tensorId = tensorGroups.first; @@ -675,11 +683,11 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) { continue; } if (!isPromotableToRegistersBelow( - *group, root, scope, partialSchedMupa, threadSchedule)) { + *group, root, scope, partialSchedBlockMupa, threadSchedule)) { continue; } // Check reuse within threads. - auto schedule = partialSchedMupa.flat_range_product(threadSchedule); + auto schedule = partialSchedBlockMupa.range_product(threadSchedule); if (!hasReuseWithin(*group, schedule)) { continue; } diff --git a/tc/core/polyhedral/domain_types.h b/tc/core/polyhedral/domain_types.h new file mode 100644 index 000000000..95f942d87 --- /dev/null +++ b/tc/core/polyhedral/domain_types.h @@ -0,0 +1,17 @@ +namespace tc { +namespace polyhedral { + +struct Band; +struct Prefix; +struct Promoted; +struct Reduction; +struct ReductionSchedule; +struct Statement; +struct Tag; +struct Tensor; +struct Thread; +struct Unrolled; +struct Warp; + +} // namespace polyhedral +} // namespace tc diff --git a/tc/core/polyhedral/memory_promotion.cc b/tc/core/polyhedral/memory_promotion.cc index b1b947c00..84224cd03 100644 --- a/tc/core/polyhedral/memory_promotion.cc +++ b/tc/core/polyhedral/memory_promotion.cc @@ -22,6 +22,7 @@ #include "tc/core/check.h" #include "tc/core/polyhedral/body.h" +#include "tc/core/polyhedral/domain_types.h" #include "tc/core/polyhedral/exceptions.h" #include "tc/core/polyhedral/schedule_tree.h" #include "tc/core/polyhedral/scop.h" @@ -44,26 +45,30 @@ namespace { // D -> O: o_i = f(D) // // by subtracting "offsets" and by dividing the result by "strides". -isl::map removeRangeStrides( - isl::map relation, - isl::multi_val strides, - isl::multi_aff offsets) { +template +isl::Map removeRangeStrides( + isl::Map relation, + isl::MultiVal strides, + isl::MultiAff offsets) { TC_CHECK_EQ(strides.size(), offsets.size()); auto space = relation.get_space(); - auto stridesMA = isl::multi_aff::identity(space.range().map_from_set()); + auto stridesMA = + isl::MultiAff::identity(space.range().map_from_set()); stridesMA = stridesMA / strides; - return relation.sum(isl::map(offsets.neg())).apply_range(isl::map(stridesMA)); + return relation.sum(offsets.neg().asMap()).apply_range(stridesMA.asMap()); } // Compute a box approximation of the range of the given relation, // including the lower bounds, the box sizes, and the strides. // If the range has strides, remove them first. -ScopedFootprint outputRanges(isl::map access) { +ScopedFootprint outputRanges(isl::Map access) { ScopedFootprint footprint; - footprint.strideValues = isl::multi_val::zero(access.get_space().range()); - footprint.strideOffsets = isl::multi_aff::zero(access.get_space()); + footprint.strideValues = + isl::MultiVal::zero(access.get_space().range()); + footprint.strideOffsets = + isl::MultiAff::zero(access.get_space()); int nSubscripts = footprint.strideValues.size(); for (int i = 0; i < nSubscripts; ++i) { @@ -73,10 +78,10 @@ ScopedFootprint outputRanges(isl::map access) { footprint.strideOffsets.set_aff(i, si.get_offset()); } - access = removeRangeStrides( + auto accessNoStrides = removeRangeStrides( access, footprint.strideValues, footprint.strideOffsets); - footprint.box = access.get_range_simple_fixed_box_hull(); + footprint.box = accessNoStrides.get_range_simple_fixed_box_hull(); return footprint; } } // namespace @@ -84,13 +89,13 @@ ScopedFootprint outputRanges(isl::map access) { // Access has the shape :: [S -> ref] -> O // Extract the reference ID, store it separately and simplify the access. std::unique_ptr TensorReferenceGroup::makeSingleton( - isl::map originalAccess, - isl::map scopedAccess, + isl::Map, Tensor> originalAccess, + isl::Map, Tensor> scopedTaggedAccess, AccessType type) { auto ref = std::unique_ptr(new TensorReference); auto refId = - scopedAccess.get_space().domain().unwrap().get_map_range_tuple_id(); - scopedAccess = scopedAccess.domain_factor_domain(); + scopedTaggedAccess.get_space().domain().unwrap().get_map_range_tuple_id(); + auto scopedAccess = scopedTaggedAccess.domain_factor_domain(); ref->originalAccess = originalAccess.domain_factor_domain(); ref->scopedAccess = scopedAccess; ref->type = type; @@ -109,12 +114,15 @@ std::unique_ptr TensorReferenceGroup::makeSingleton( return group; } -isl::map TensorReferenceGroup::approximateScopedAccesses() const { +isl::Map TensorReferenceGroup::approximateScopedAccesses() + const { auto scopedDomain = scopedAccesses().domain(); auto space = approximation.box.get_space(); - auto accessed = isl::map::universe(space).intersect_domain(scopedDomain); + auto accessed = + isl::Map::universe(space).intersect_domain(scopedDomain); - auto identity = isl::multi_aff::identity(space.range().map_from_set()); + auto identity = + isl::MultiAff::identity(space.range().map_from_set()); for (size_t i = 0; i < approximation.dim(); ++i) { auto offset = approximation.lowerBound(i); auto stride = approximation.stride(i); @@ -123,15 +131,15 @@ isl::map TensorReferenceGroup::approximateScopedAccesses() const { auto rhs = identity.get_aff(i); auto lowerBound = offset * stride + strideOffset; auto upperBound = (offset + size) * stride + strideOffset; - auto partial = - (isl::aff_map(lowerBound) <= rhs) & (isl::aff_map(upperBound) > rhs); + auto partial = lowerBound.asPwAff().lt_map((rhs + 1).asPwAff()) & + upperBound.asPwAff().gt_map(rhs.asPwAff()); accessed = accessed & partial; } return accessed; } -isl::multi_aff ScopedFootprint::lowerBounds() const { +isl::MultiAff ScopedFootprint::lowerBounds() const { if (dim() == 0) { throw promotion::PromotionNYI("promotion for scalars"); } @@ -146,20 +154,20 @@ bool TensorReferenceGroup::isReadOnly() const { return result; } -isl::set TensorReferenceGroup::promotedFootprint() const { +isl::Set TensorReferenceGroup::promotedFootprint() const { auto space = scopedAccesses().get_space().range(); auto sizes = approximation.box.get_size(); if (!sizes.get_space().has_equal_tuples(space)) { throw promotion::GroupingError("unexpected dimensionality mismatch"); } - isl::set footprint = isl::set::universe(space); - auto identity = isl::multi_aff::identity(space.map_from_set()); + isl::Set footprint = isl::Set::universe(space); + auto identity = isl::MultiAff::identity(space.map_from_set()); for (size_t i = 0, e = sizes.size(); i < e; ++i) { auto aff = identity.get_aff(i); auto size = sizes.get_val(i); - footprint = - footprint & (isl::aff_set(aff) >= 0) & (isl::aff_set(aff) < size); + footprint = footprint & aff.asPwAff().nonneg_set() & + (size - aff).asPwAff().pos_set(); } return footprint; } @@ -174,14 +182,14 @@ std::vector TensorReferenceGroup::approximationSizes() const { } namespace { -isl::map referenceScopedAccessesImpl( +isl::Map referenceScopedAccessesImpl( const TensorReferenceGroup& group, AccessType type) { if (group.references.size() == 0) { throw promotion::GroupingError("no references in the group"); } - auto accesses = - isl::map::empty(group.references.front()->scopedAccess.get_space()); + auto accesses = isl::Map::empty( + group.references.front()->scopedAccess.get_space()); for (const auto& ref : group.references) { if (ref->type != type) { @@ -202,40 +210,40 @@ isl::set TensorReferenceGroup::readFootprint() const { return referenceScopedAccessesImpl(*this, AccessType::Read).range(); } -isl::map TensorReferenceGroup::scopedWrites() const { +isl::Map TensorReferenceGroup::scopedWrites() const { return referenceScopedAccessesImpl(*this, AccessType::Write); } -isl::map TensorReferenceGroup::scopedReads() const { +isl::Map TensorReferenceGroup::scopedReads() const { return referenceScopedAccessesImpl(*this, AccessType::Read); } namespace { -isl::union_map referenceOriginalAccessesImpl( +isl::UnionMap referenceOriginalAccessesImpl( const TensorReferenceGroup& group, AccessType type) { if (group.references.size() == 0) { throw promotion::GroupingError("no references in the group"); } - auto accesses = isl::union_map::empty( - group.references.front()->originalAccess.get_space()); + auto accesses = isl::UnionMap::empty( + group.references.front()->originalAccess.get_space().params()); for (const auto& ref : group.references) { if (ref->type != type) { continue; } auto current = ref->originalAccess; - accesses = accesses.unite(isl::union_map(current)); + accesses = accesses.unite(current.asUnionMap()); } return accesses; } } // namespace -isl::union_map TensorReferenceGroup::originalWrites() const { +isl::UnionMap TensorReferenceGroup::originalWrites() const { return referenceOriginalAccessesImpl(*this, AccessType::Write); } -isl::union_map TensorReferenceGroup::originalReads() const { +isl::UnionMap TensorReferenceGroup::originalReads() const { return referenceOriginalAccessesImpl(*this, AccessType::Read); } @@ -281,29 +289,28 @@ void joinOverlappingWrites( void addSingletonReferenceGroup( TensorGroups& tensorGroups, isl::id targetTensor, - isl::union_map schedule, - isl::map access, + isl::UnionMap schedule, + isl::Map, Tensor> access, AccessType type) { - auto scopedUnionAccess = isl::union_map(access.curry()); - scopedUnionAccess = scopedUnionAccess.apply_domain(schedule); - auto scopedAccess = isl::map::from(scopedUnionAccess); - scopedAccess = scopedAccess.uncurry(); + auto unionAccess = access.curry().asUnionMap(); + auto scopedUnionAccess = unionAccess.apply_domain(schedule); + auto scopedAccess = scopedUnionAccess.toMap().uncurry(); tensorGroups[targetTensor].push_back( TensorReferenceGroup::makeSingleton(access, scopedAccess, type)); } void addSingletonReferenceGroups( TensorGroups& tensorGroups, - isl::union_map accesses, - isl::union_set domain, - isl::union_map schedule, + isl::UnionMap, Tensor> accesses, + isl::UnionSet domain, + isl::UnionMap schedule, AccessType type) { // access relations have a shape :: [D -> ref] -> O // use currying to isolate the D part before intersecting with the domain // Compute initial groups with single reference per group. std::unordered_set unapproximatable; - for (auto a : isl::UnionAsVector(accesses)) { - if (isl::union_map(a.curry()).intersect_domain(domain).is_empty()) { + for (auto a : accesses.get_map_list()) { + if (a.curry().asUnionMap().intersect_domain(domain).is_empty()) { continue; } @@ -342,7 +349,7 @@ void addSingletonReferenceGroups( // TensorReferenceGroup, with each group potentially containing multiple // references. TensorGroups TensorReferenceGroup::accessedWithin( - isl::union_map outerSchedule, + isl::UnionMap outerSchedule, const Body& body) { TensorGroups tensorGroups; auto domain = outerSchedule.domain(); @@ -369,14 +376,16 @@ TensorGroups TensorReferenceGroup::accessedWithin( // elements of the promoted array get assigned different values of the original // array in different outer loop iterations; it's impossible to project out the // outer schedule dimensions. -isl::multi_aff TensorReferenceGroup::promotion() const { +isl::MultiAff, Tensor> +TensorReferenceGroup::promotion() const { // access space is S -> O - isl::map map = scopedAccesses(); + auto map = scopedAccesses(); auto accessSpace = map.get_space(); // Construct a projection multi-aff in [S -> O] -> S // for further precomposition. - auto originalSpaceInserter = isl::multi_aff::domain_map(accessSpace); + auto originalSpaceInserter = + isl::MultiAff, Prefix>::domain_map(accessSpace); // Lower bounds and offsets space is S -> O; transform into [S -> O] -> O. auto lowerBounds = @@ -384,7 +393,8 @@ isl::multi_aff TensorReferenceGroup::promotion() const { auto offsets = approximation.strideOffsets.pullback(originalSpaceInserter); // Create promotion starting by identity in [S -> O] -> O. - auto original = isl::multi_aff::range_map(accessSpace); + auto original = + isl::MultiAff, Tensor>::range_map(accessSpace); auto promotion = (original - offsets) / approximation.strideValues - lowerBounds; @@ -408,25 +418,26 @@ namespace { // each dimension of the tensor is contrained by the min_aff on the left and // by the min_aff + extent_aff on the right. Intersect this set with the // context of the scop. -isl::set tensorElementsSet(const Scop& scop, isl::id tensorId) { +isl::Set tensorElementsSet(const Scop& scop, isl::id tensorId) { auto halideParameter = scop.findArgument(tensorId).parameter(); auto space = scop.domain().get_space(); auto nDim = halideParameter.dimensions(); auto tensorTuple = constructTensorTuple(space, tensorId, nDim); auto tensorSpace = tensorTuple.get_space(); - auto tensorElements = isl::set::universe(tensorSpace); - auto identity = isl::multi_aff::identity(tensorSpace.map_from_set()); + auto tensorElements = isl::Set::universe(tensorSpace); + auto identity = + isl::MultiAff::identity(tensorSpace.map_from_set()); for (int i = 0; i < nDim; ++i) { auto minAff = halide2isl::makeIslAffFromExpr( space, halideParameter.min_constraint(i)); auto extentAff = halide2isl::makeIslAffFromExpr( space, halideParameter.extent_constraint(i)); - minAff = minAff.unbind_params_insert_domain(tensorTuple); - extentAff = extentAff.unbind_params_insert_domain(tensorTuple); + auto minAff2 = minAff.unbind_params_insert_domain(tensorTuple); + auto extentAff2 = extentAff.unbind_params_insert_domain(tensorTuple); auto aff = identity.get_aff(i); - tensorElements = tensorElements & (minAff <= isl::aff_set(aff)) & - (isl::aff_set(aff) < (minAff + extentAff)); + tensorElements = tensorElements & (minAff2.le_set(aff)) & + (aff.lt_set(minAff2 + extentAff2)); } tensorElements = tensorElements.intersect_params(scop.context()); @@ -440,11 +451,12 @@ isl::set tensorElementsSet(const Scop& scop, isl::id tensorId) { * Note that this function drops the name of the target space of "schedule", * but this space is irrelevant for the caller. */ -isl::multi_aff dropDummyTensorDimensions( - isl::multi_aff schedule, +template +isl::MultiAff dropDummyTensorDimensions( + isl::MultiAff schedule, const Scop::PromotedDecl& decl) { auto list = schedule.get_aff_list(); - auto space = schedule.get_space().domain(); + auto domainSpace = schedule.get_space().domain(); auto n = list.size(); for (int i = n - 1; i >= 0; --i) { @@ -453,8 +465,8 @@ isl::multi_aff dropDummyTensorDimensions( } } - space = space.add_unnamed_tuple_ui(list.size()); - return isl::multi_aff(space, list); + auto space = domainSpace.template add_unnamed_tuple_ui(list.size()); + return isl::MultiAff(space, list); } inline void unrollAllMembers(detail::ScheduleTreeBand* band) { @@ -478,23 +490,30 @@ ScheduleTree* insertCopiesUnder( // Take the set of all tensor elements. auto tensorElements = tensorElementsSet(scop, tensorId); - auto promotion = isl::map(group.promotion()).set_range_tuple_id(groupId); + auto promotion = + group.promotion().asMap().set_range_tuple_id(groupId); auto promotionSpace = promotion.get_space(); - auto identityCopySchedule = - isl::multi_aff::identity(promotionSpace.range().map_from_set()); + auto identityCopySchedule = isl::MultiAff::identity( + promotionSpace.range().map_from_set()); // Only iterate over significant tensor dimensions. auto decl = scop.promotedDecl(groupId); identityCopySchedule = dropDummyTensorDimensions(identityCopySchedule, decl); - auto readSpace = promotionSpace.wrap().set_set_tuple_id(readId); - auto writeSpace = promotionSpace.wrap().set_set_tuple_id(writeId); + auto readSpace = promotionSpace.wrap().set_set_tuple_id(readId); + auto writeSpace = promotionSpace.wrap().set_set_tuple_id(writeId); auto readSchedule = isl::multi_union_pw_aff(identityCopySchedule.pullback( - isl::multi_aff::wrapped_range_map(readSpace))); + isl::MultiAff< + isl::NamedPair, Promoted>, + Promoted>::wrapped_range_map(readSpace))); auto writeSchedule = isl::multi_union_pw_aff(identityCopySchedule.pullback( - isl::multi_aff::wrapped_range_map(writeSpace))); + isl::MultiAff< + isl::NamedPair, Promoted>, + Promoted>::wrapped_range_map(writeSpace))); - auto readBandNode = ScheduleTree::makeBand(readSchedule); - auto writeBandNode = ScheduleTree::makeBand(writeSchedule); + auto readBandNode = ScheduleTree::makeBand( + isl::MultiUnionPwAff(readSchedule)); + auto writeBandNode = ScheduleTree::makeBand( + isl::MultiUnionPwAff(writeSchedule)); if (unrollAllCopies) { unrollAllMembers(readBandNode->as()); @@ -508,20 +527,22 @@ ScheduleTree* insertCopiesUnder( // control flow, but we should only write back elements that are actually // written to. In any case, intersect the footprint with the set of existing // tensor elements. - auto promotedFootprint = group.promotedFootprint().set_tuple_id(groupId); + auto promotedFootprint = + group.promotedFootprint().set_tuple_id(groupId); auto scheduleUniverse = - isl::set::universe(promotionSpace.domain().unwrap().domain()); + isl::Set::universe(promotionSpace.domain().unwrap().domain()); auto arrayId = promotionSpace.domain().unwrap().get_map_range_tuple_id(); auto approximatedRead = group.approximateScopedAccesses().intersect_range(tensorElements).wrap(); - approximatedRead = approximatedRead.product(promotedFootprint); + auto product = approximatedRead.product(promotedFootprint); auto readExtension = - extension.intersect_range(approximatedRead).set_range_tuple_id(readId); - auto writtenElements = - group.scopedWrites().intersect_range(tensorElements).wrap(); - writtenElements = writtenElements.product(promotedFootprint); - auto writeExtension = - extension.intersect_range(writtenElements).set_range_tuple_id(writeId); + extension.intersect_range(product).set_range_tuple_id(readId); + auto writtenElements = group.scopedWrites() + .intersect_range(tensorElements) + .wrap() + .product(promotedFootprint); + auto writeExtension = extension.intersect_range(writtenElements) + .set_range_tuple_id(writeId); auto readFilterNode = ScheduleTree::makeFilter( isl::set::universe(readExtension.get_space().range()), diff --git a/tc/core/polyhedral/memory_promotion.h b/tc/core/polyhedral/memory_promotion.h index 8ada806aa..236c391d7 100644 --- a/tc/core/polyhedral/memory_promotion.h +++ b/tc/core/polyhedral/memory_promotion.h @@ -17,6 +17,7 @@ #include +#include "tc/core/polyhedral/domain_types.h" #include "tc/core/polyhedral/schedule_tree.h" #include "tc/core/polyhedral/scop.h" #include "tc/external/isl.h" @@ -47,21 +48,21 @@ struct ScopedFootprint { isl::val size(size_t pos) const { return box.get_size().get_val(pos); } - isl::aff lowerBound(size_t pos) const { + isl::AffOn lowerBound(size_t pos) const { return box.get_offset().get_aff(pos); } isl::val stride(size_t pos) const { return strideValues.get_val(pos); } - isl::aff strideOffset(size_t pos) const { + isl::AffOn strideOffset(size_t pos) const { return strideOffsets.get_aff(pos); } - isl::fixed_box box; - isl::multi_val strideValues; - isl::multi_aff strideOffsets; + isl::FixedBox box; + isl::MultiVal strideValues; + isl::MultiAff strideOffsets; - isl::multi_aff lowerBounds() const; + isl::MultiAff lowerBounds() const; }; // Descriptor of tensor reference in a Scop. @@ -80,11 +81,11 @@ class TensorReference { public: // Original access relation in terms of the Scop domain. - isl::map originalAccess; + isl::Map originalAccess; // Access relation in terms of partial schedule at the point where the // reference group is introduced in the tree. - isl::map scopedAccess; + isl::Map scopedAccess; // Access direction (read or write). AccessType type; @@ -112,7 +113,7 @@ class TensorReferenceGroup { public: static TensorGroups accessedWithin( - isl::union_map outerSchedule, + isl::UnionMap outerSchedule, const Body& body); bool isReadOnly() const; @@ -125,27 +126,27 @@ class TensorReferenceGroup { } // Access relations in terms of partial schedule of the scoping point. - isl::map scopedWrites() const; - isl::map scopedReads() const; - isl::map scopedAccesses() const { + isl::Map scopedWrites() const; + isl::Map scopedReads() const; + isl::Map scopedAccesses() const { return scopedWrites().unite(scopedReads()); } // Access relations in terms of Scop domain elements. // The resulting union relations have different domain spaces but identical // range spaces. - isl::union_map originalWrites() const; - isl::union_map originalReads() const; - isl::union_map originalAccesses() const { + isl::UnionMap originalWrites() const; + isl::UnionMap originalReads() const; + isl::UnionMap originalAccesses() const { return originalWrites().unite(originalReads()); } // Rectangular overapproximation of the set of tensor elements accessed below // and relative to the scoping point. - isl::map approximateScopedAccesses() const; + isl::Map approximateScopedAccesses() const; - isl::multi_aff promotion() const; - isl::set promotedFootprint() const; + isl::MultiAff, Tensor> promotion() const; + isl::Set promotedFootprint() const; std::vector approximationSizes() const; @@ -155,8 +156,8 @@ class TensorReferenceGroup { std::unique_ptr&& g1, std::unique_ptr&& g2); static std::unique_ptr makeSingleton( - isl::map originalAccess, - isl::map scopedAccess, + isl::Map, Tensor> originalAccess, + isl::Map, Tensor> scopedAccess, AccessType type); public: diff --git a/tc/core/polyhedral/reduction_matcher.cc b/tc/core/polyhedral/reduction_matcher.cc index 8b21f4023..318e9bd6f 100644 --- a/tc/core/polyhedral/reduction_matcher.cc +++ b/tc/core/polyhedral/reduction_matcher.cc @@ -18,6 +18,7 @@ #include #include "tc/core/check.h" +#include "tc/core/polyhedral/domain_types.h" #include "tc/core/polyhedral/schedule_tree.h" #include "tc/core/polyhedral/scop.h" #include "tc/external/isl.h" @@ -54,10 +55,12 @@ bool isSupportedReductionUpdateId(isl::id id, const Scop& scop) { } // namespace -isl::union_set reductionUpdates(isl::union_set domain, const Scop& scop) { +isl::UnionSet reductionUpdates( + isl::UnionSet domain, + const Scop& scop) { domain = scop.body.reductions.intersect_domain(domain).domain(); - auto update = isl::union_set::empty(domain.get_space()); - domain.foreach_set([&update, &scop](isl::set set) { + auto update = isl::UnionSet::empty(domain.get_space()); + domain.foreach_set([&update, &scop](isl::Set set) { auto setId = set.get_tuple_id(); if (isSupportedReductionUpdateId(setId, scop)) { update = update.unite(set); @@ -66,16 +69,5 @@ isl::union_set reductionUpdates(isl::union_set domain, const Scop& scop) { return update; } -bool isSingleReductionWithin( - isl::union_set domain, - isl::multi_union_pw_aff prefix, - const Scop& scop) { - auto reductions = scop.body.reductions; - reductions = reductions.intersect_domain(domain); - auto prefixMap = isl::union_map::from(prefix); - auto prefixToReduction = reductions.apply_domain(prefixMap); - return prefixToReduction.is_single_valued(); -} - } // namespace polyhedral } // namespace tc diff --git a/tc/core/polyhedral/schedule_isl_conversion.cc b/tc/core/polyhedral/schedule_isl_conversion.cc index 2b01979b9..28f1fa1de 100644 --- a/tc/core/polyhedral/schedule_isl_conversion.cc +++ b/tc/core/polyhedral/schedule_isl_conversion.cc @@ -23,6 +23,7 @@ #include "tc/core/check.h" #include "tc/core/flags.h" +#include "tc/core/polyhedral/domain_types.h" #include "tc/core/polyhedral/schedule_transforms.h" #include "tc/external/isl.h" @@ -81,7 +82,7 @@ isl::schedule_node insertBranch( */ std::vector findCorePositions( const ScheduleTree* st, - isl::union_set domain) { + isl::UnionSet domain) { std::vector positions; TC_CHECK(st->as()); for (size_t i = 0; i < st->numChildren(); ++i) { @@ -125,7 +126,7 @@ isl::schedule_node insertExtension( isl::schedule_node node, const ScheduleTree* st) { auto depth0 = node.get_tree_depth(); - auto domain = node.get_universe_domain(); + auto domain = isl::UnionSet(node.get_universe_domain()); auto child = st->child({0}); auto corePos = findCorePositions(child, domain); TC_CHECK(!corePos.empty()); @@ -242,8 +243,8 @@ std::unique_ptr fromIslScheduleNodeBand( for (size_t i = 0; i < n; ++i) { coincident[i] = b.member_get_coincident(i); } - return ScheduleTreeBand::make( - b.get_partial_schedule(), b.get_permutable(), coincident, unroll); + auto mupa = isl::MultiUnionPwAff(b.get_partial_schedule()); + return ScheduleTreeBand::make(mupa, b.get_permutable(), coincident, unroll); } std::unique_ptr elemFromIslScheduleNode(isl::schedule_node node) { @@ -251,7 +252,7 @@ std::unique_ptr elemFromIslScheduleNode(isl::schedule_node node) { if (auto band = node.as()) { return fromIslScheduleNodeBand(band); } else if (auto context = node.as()) { - auto c = context.get_context(); + auto c = isl::Set(context.get_context()); return ScheduleTreeContext::make(c); } else if (auto domain = node.as()) { auto c = domain.get_domain(); diff --git a/tc/core/polyhedral/schedule_transforms.cc b/tc/core/polyhedral/schedule_transforms.cc index 2aee2e095..9bdc99f28 100644 --- a/tc/core/polyhedral/schedule_transforms.cc +++ b/tc/core/polyhedral/schedule_transforms.cc @@ -31,6 +31,7 @@ #include "tc/core/check.h" #include "tc/core/constants.h" #include "tc/core/functional.h" +#include "tc/core/polyhedral/domain_types.h" #include "tc/core/polyhedral/mapping_types.h" #include "tc/core/polyhedral/schedule_tree_elem.h" #include "tc/core/polyhedral/schedule_tree_matcher.h" @@ -88,7 +89,8 @@ ScheduleTree* joinBandsHelper(ScheduleTree* st, bool& moveChildren) { auto& partialSchedule = eb->mupa_; auto& partialScheduleChild = ebChild->mupa_; - partialSchedule = partialSchedule.flat_range_product(partialScheduleChild); + partialSchedule = + partialSchedule.flat_range_product(partialScheduleChild); eb->coincident_.resize( eb->coincident_.size() + ebChild->coincident_.size(), false); eb->unroll_.insert( @@ -284,7 +286,9 @@ ScheduleTree* insertTopLevelEmptyBand(ScheduleTree* root) { return insertNodeBelow(node, ScheduleTree::makeEmptyBand(root)); } -void updateTopLevelContext(detail::ScheduleTree* root, isl::set context) { +void updateTopLevelContext( + detail::ScheduleTree* root, + isl::Set context) { if (!matchOne(tc::polyhedral::domain(tc::polyhedral::context(any())), root)) { root->appendChild( ScheduleTree::makeContext(context, root->detachChildren())); @@ -362,19 +366,21 @@ detail::ScheduleTree* insertEmptyExtensionAbove( * Construct an extension map for a zero-dimensional statement * with the given identifier. */ -isl::map labelExtension(ScheduleTree* root, ScheduleTree* tree, isl::id id) { - auto prefix = prefixScheduleMupa(root, tree); +isl::Map +labelExtension(ScheduleTree* root, ScheduleTree* tree, isl::id id) { + auto prefix = prefixScheduleMupa(root, tree); auto scheduleSpace = prefix.get_space(); - auto space = scheduleSpace.params().add_named_tuple_id_ui(id, 0); + auto space = scheduleSpace.params().add_named_tuple_id_ui(id, 0); auto extensionSpace = scheduleSpace.map_from_domain_and_range(space); - return isl::map::universe(extensionSpace); + return isl::Map::universe(extensionSpace); } /* * Construct a filter node for a zero-dimensional extension statement * with the given extension map. */ -ScheduleTreeUPtr labelFilterFromExtension(isl::map extension) { +ScheduleTreeUPtr labelFilterFromExtension( + isl::Map extension) { return detail::ScheduleTree::makeFilter(extension.range()); } @@ -391,7 +397,7 @@ void insertExtensionAt( ScheduleTree* relativeRoot, ScheduleTree* seqNode, size_t pos, - isl::union_map extension, + isl::UnionMap extension, ScheduleTreeUPtr&& filterNode) { auto extensionTree = seqNode->ancestor(relativeRoot, 1); auto extensionNode = extensionTree->as(); @@ -410,7 +416,7 @@ void insertExtensionBefore( const ScheduleTree* root, ScheduleTree* relativeRoot, ScheduleTree* tree, - isl::union_map extension, + isl::UnionMap extension, ScheduleTreeUPtr&& filterNode) { size_t pos; auto parent = tree->ancestor(relativeRoot, 1); @@ -439,7 +445,7 @@ void insertExtensionAfter( const ScheduleTree* root, ScheduleTree* relativeRoot, ScheduleTree* tree, - isl::union_map extension, + isl::UnionMap extension, ScheduleTreeUPtr&& filterNode) { size_t pos; auto parent = tree->ancestor(relativeRoot, 1); @@ -501,7 +507,7 @@ namespace { * of band node partial schedules. * Elements of a sequence that end up with an empty filter are removed. */ -void gist(ScheduleTree* tree, isl::union_set context) { +void gist(ScheduleTree* tree, isl::UnionSet context) { if (auto bandElem = tree->as()) { bandElem->mupa_ = bandElem->mupa_.gist(context); } else if (auto filterElem = tree->as()) { @@ -531,7 +537,9 @@ void gist(ScheduleTree* tree, isl::union_set context) { * Create a filter node with the given filter and single child node, * after simplifying the child node in the context of the filter. */ -ScheduleTreeUPtr gistedFilter(isl::union_set filter, ScheduleTreeUPtr child) { +ScheduleTreeUPtr gistedFilter( + isl::UnionSet filter, + ScheduleTreeUPtr child) { gist(child.get(), filter); return ScheduleTree::makeFilter(filter, std::move(child)); } @@ -542,19 +550,20 @@ ScheduleTreeUPtr gistedFilter(isl::union_set filter, ScheduleTreeUPtr child) { * without violating any of the (active) "dependences"? */ bool canOrder( - isl::union_set first, - isl::union_set second, - isl::union_map dependences) { + isl::UnionSet first, + isl::UnionSet second, + isl::UnionMap dependences) { if (first.is_empty() || second.is_empty()) { return true; } // Create an ordering schedule function first -> 0; second -> 1. auto ctx = dependences.get_ctx(); - auto space = isl::space(ctx, 0).add_unnamed_tuple_ui(1); - auto zero = isl::multi_val::zero(space); + auto space = isl::Space<>(ctx, 0).add_unnamed_tuple_ui(1); + auto zero = isl::MultiVal::zero(space); auto one = zero.set_val(0, isl::val::one(ctx)); - auto order = isl::multi_union_pw_aff(first, zero); - order = order.union_add(isl::multi_union_pw_aff(second, one)); + auto order = isl::MultiUnionPwAff(first, zero); + order = order.union_add( + isl::MultiUnionPwAff(second, one)); // Check that this ordering preserves all dependences. auto preserved = dependences.lex_lt_at(order).unite(dependences.eq_at(order)); @@ -566,26 +575,29 @@ bool canOrder( bool canOrderBefore( ScheduleTree* root, ScheduleTree* tree, - isl::union_set filter, - isl::union_map dependences) { - auto other = activeDomainPoints(root, tree).subtract(filter); + isl::UnionSet filter, + isl::UnionMap dependences) { + auto active = activeDomainPoints(root, tree); + auto other = active.subtract(filter); return canOrder(filter, other, dependences); } bool canOrderAfter( ScheduleTree* root, ScheduleTree* tree, - isl::union_set filter, - isl::union_map dependences) { - auto other = activeDomainPoints(root, tree).subtract(filter); + isl::UnionSet filter, + isl::UnionMap dependences) { + auto active = activeDomainPoints(root, tree); + auto other = active.subtract(filter); return canOrder(other, filter, dependences); } void orderBefore( ScheduleTree* root, ScheduleTree* tree, - isl::union_set filter) { - auto other = activeDomainPoints(root, tree).subtract(filter); + isl::UnionSet filter) { + auto active = activeDomainPoints(root, tree); + auto other = active.subtract(filter); auto seq = ScheduleTree::makeSequence( gistedFilter(filter, ScheduleTree::makeScheduleTree(*tree))); auto parent = tree->ancestor(root, 1); @@ -594,8 +606,12 @@ void orderBefore( parent->insertChild(childPos, std::move(seq)); } -void orderAfter(ScheduleTree* root, ScheduleTree* tree, isl::union_set filter) { - auto other = activeDomainPoints(root, tree).subtract(filter); +void orderAfter( + ScheduleTree* root, + ScheduleTree* tree, + isl::UnionSet filter) { + auto active = activeDomainPoints(root, tree); + auto other = active.subtract(filter); auto seq = ScheduleTree::makeSequence( gistedFilter(filter, ScheduleTree::makeScheduleTree(*tree))); auto parent = tree->ancestor(root, 1); diff --git a/tc/core/polyhedral/schedule_transforms.h b/tc/core/polyhedral/schedule_transforms.h index add33d7f2..4fa1f71d2 100644 --- a/tc/core/polyhedral/schedule_transforms.h +++ b/tc/core/polyhedral/schedule_transforms.h @@ -22,6 +22,7 @@ #include #include "tc/core/functional.h" +#include "tc/core/polyhedral/domain_types.h" #include "tc/core/polyhedral/mapping_types.h" #include "tc/core/polyhedral/options.h" #include "tc/core/polyhedral/schedule_tree.h" @@ -110,7 +111,9 @@ detail::ScheduleTree* insertTopLevelEmptyBand(detail::ScheduleTree* root); // Update the top-level context node by intersecting it with "context". The // top-level context node must be located directly under the root of the tree. // If there is no such node, insert one first. -void updateTopLevelContext(detail::ScheduleTree* root, isl::set context); +void updateTopLevelContext( + detail::ScheduleTree* root, + isl::Set context); // In a tree starting at "root", insert a sequence node with // as only child the node identified by "tree". @@ -174,7 +177,7 @@ void insertExtensionBefore( const detail::ScheduleTree* root, detail::ScheduleTree* relativeRoot, detail::ScheduleTree* tree, - isl::union_map extension, + isl::UnionMap extension, ScheduleTreeUPtr&& filterNode); // Insert an extension with the given extension map and extension filter node @@ -189,7 +192,7 @@ void insertExtensionAfter( const detail::ScheduleTree* root, detail::ScheduleTree* relativeRoot, detail::ScheduleTree* tree, - isl::union_map extension, + isl::UnionMap extension, ScheduleTreeUPtr&& filterNode); // Given a sequence node in the schedule tree, insert @@ -233,29 +236,29 @@ void insertExtensionLabelAfter( bool canOrderBefore( detail::ScheduleTree* root, detail::ScheduleTree* tree, - isl::union_set filter, - isl::union_map dependences); + isl::UnionSet filter, + isl::UnionMap dependences); // Is it possible to order the elements in the given filter // after the other active elements without violating // any of the given dependences? bool canOrderAfter( detail::ScheduleTree* root, detail::ScheduleTree* tree, - isl::union_set filter, - isl::union_map dependences); + isl::UnionSet filter, + isl::UnionMap dependences); // Insert a sequence to ensure that the active domain elements // in the given filter are executed before the other active domain elements. void orderBefore( detail::ScheduleTree* root, detail::ScheduleTree* tree, - isl::union_set filter); + isl::UnionSet filter); // Insert a sequence to ensure that the active domain elements // in the given filter are executed after the other active domain elements. void orderAfter( detail::ScheduleTree* root, detail::ScheduleTree* tree, - isl::union_set filter); + isl::UnionSet filter); } // namespace polyhedral } // namespace tc diff --git a/tc/core/polyhedral/schedule_tree.cc b/tc/core/polyhedral/schedule_tree.cc index c17c2542a..fe5763cb2 100644 --- a/tc/core/polyhedral/schedule_tree.cc +++ b/tc/core/polyhedral/schedule_tree.cc @@ -30,6 +30,7 @@ #include "tc/core/check.h" #include "tc/core/constants.h" #include "tc/core/functional.h" +#include "tc/core/polyhedral/domain_types.h" #include "tc/core/polyhedral/schedule_tree_elem.h" #include "tc/core/scope_guard.h" #include "tc/external/isl.h" @@ -199,7 +200,7 @@ size_t ScheduleTree::scheduleDepth(const ScheduleTree* relativeRoot) const { } std::unique_ptr ScheduleTree::makeBand( - isl::multi_union_pw_aff mupa, + isl::MultiUnionPwAff mupa, std::vector&& children) { std::vector coincident(mupa.size(), false); std::vector unroll(mupa.size(), false); @@ -211,9 +212,9 @@ std::unique_ptr ScheduleTree::makeBand( ScheduleTreeUPtr ScheduleTree::makeEmptyBand(const ScheduleTree* root) { auto domain = root->as(); TC_CHECK(domain); - auto space = domain->domain_.get_space().set_from_params(); - auto mv = isl::multi_val::zero(space); - auto zero = isl::multi_union_pw_aff(domain->domain_, mv); + auto space = domain->domain_.get_space().add_unnamed_tuple_ui(0); + auto mv = isl::MultiVal::zero(space); + auto zero = isl::MultiUnionPwAff(domain->domain_, mv); return ScheduleTree::makeBand(zero); } @@ -224,7 +225,7 @@ std::unique_ptr ScheduleTree::makeDomain( } std::unique_ptr ScheduleTree::makeContext( - isl::set context, + isl::Set context, std::vector&& children) { return ScheduleTreeContext::make(context, std::move(children)); } @@ -237,7 +238,7 @@ std::unique_ptr ScheduleTree::makeFilter( std::unique_ptr ScheduleTree::makeMappingUnsafe( const std::vector& mappedIds, - isl::union_pw_aff_list mappedAffs, + isl::UnionPwAffListOn mappedAffs, std::vector&& children) { TC_CHECK_EQ(mappedIds.size(), static_cast(mappedAffs.size())) << "expected as many mapped ids as affs"; diff --git a/tc/core/polyhedral/schedule_tree.h b/tc/core/polyhedral/schedule_tree.h index 403dd0210..485d2fb65 100644 --- a/tc/core/polyhedral/schedule_tree.h +++ b/tc/core/polyhedral/schedule_tree.h @@ -21,6 +21,7 @@ #include #include "tc/core/check.h" +#include "tc/core/polyhedral/domain_types.h" #include "tc/core/polyhedral/mapping_types.h" #include "tc/core/polyhedral/options.h" #include "tc/core/utils/vararg.h" @@ -280,7 +281,7 @@ struct ScheduleTree { // Factory functions // static ScheduleTreeUPtr makeBand( - isl::multi_union_pw_aff mupa, + isl::MultiUnionPwAff mupa, std::vector&& children = {}); // Return a zero-dimensional band for use in a tree with the given root. @@ -291,7 +292,7 @@ struct ScheduleTree { std::vector&& children = {}); static ScheduleTreeUPtr makeContext( - isl::set context, + isl::Set context, std::vector&& children = {}); static ScheduleTreeUPtr makeFilter( @@ -301,7 +302,7 @@ struct ScheduleTree { template static inline ScheduleTreeUPtr makeMapping( const std::vector& mappedIds, - isl::union_pw_aff_list mappedAffs, + isl::UnionPwAffListOn mappedAffs, std::vector&& children = {}) { static_assert( std::is_base_of::value, @@ -318,7 +319,7 @@ struct ScheduleTree { // Internal type-unsafe function to construct mappings. static ScheduleTreeUPtr makeMappingUnsafe( const std::vector& mappedIds, - isl::union_pw_aff_list mappedAffs, + isl::UnionPwAffListOn mappedAffs, std::vector&& children); public: @@ -332,7 +333,7 @@ struct ScheduleTree { template static ScheduleTreeUPtr makeBand( - isl::multi_union_pw_aff mupa, + isl::MultiUnionPwAff mupa, Args&&... args) { return makeBand( mupa, vectorFromArgs(std::forward(args)...)); @@ -345,7 +346,7 @@ struct ScheduleTree { } template - static ScheduleTreeUPtr makeContext(isl::set context, Args&&... args) { + static ScheduleTreeUPtr makeContext(isl::Set<> context, Args&&... args) { return makeContext( context, vectorFromArgs(std::forward(args)...)); } diff --git a/tc/core/polyhedral/schedule_tree_elem.cc b/tc/core/polyhedral/schedule_tree_elem.cc index 77db24859..a0a244610 100644 --- a/tc/core/polyhedral/schedule_tree_elem.cc +++ b/tc/core/polyhedral/schedule_tree_elem.cc @@ -26,6 +26,7 @@ #include "tc/core/check.h" #include "tc/core/constants.h" #include "tc/core/flags.h" +#include "tc/core/polyhedral/domain_types.h" #include "tc/core/polyhedral/schedule_isl_conversion.h" #include "tc/core/polyhedral/schedule_tree.h" #include "tc/core/scope_guard.h" @@ -144,7 +145,8 @@ ScheduleTreeMapping::ScheduleTreeMapping( auto id = kvp.first; // Create mapping filter by equating the // parameter mappedIds[i] to the "i"-th affine function. - upa = upa.sub(isl::union_pw_aff::param_on_domain(domain.universe(), id)); + upa = upa.sub( + isl::UnionPwAffOn::param_on_domain(domain.universe(), id)); filter_ = filter_.intersect(upa.zero_union_set()); } } @@ -184,7 +186,7 @@ std::unique_ptr ScheduleTreeSet::make( } std::unique_ptr ScheduleTreeBand::make( - isl::multi_union_pw_aff mupa, + isl::MultiUnionPwAff mupa, bool permutable, std::vector coincident, std::vector unroll, @@ -235,10 +237,10 @@ void ScheduleTreeBand::drop(size_t pos, size_t n) { auto nBegin = nMember(); auto list = mupa_.get_union_pw_aff_list(); - auto space = mupa_.get_space().domain(); + auto space = mupa_.get_space().params(); list = list.drop(pos, n); - space = space.add_unnamed_tuple_ui(list.size()); - mupa_ = isl::multi_union_pw_aff(space, list); + auto spaceBand = space.add_unnamed_tuple_ui(list.size()); + mupa_ = isl::MultiUnionPwAff(spaceBand, list); std::copy( coincident_.begin() + pos + n, @@ -250,17 +252,6 @@ void ScheduleTreeBand::drop(size_t pos, size_t n) { TC_CHECK_EQ(nBegin - n, nMember()); } -isl::multi_union_pw_aff ScheduleTreeBand::memberRange(size_t first, size_t n) - const { - auto list = mupa_.get_union_pw_aff_list(); - auto space = mupa_.get_space().params().add_unnamed_tuple_ui(n); - auto end = first + n; - TC_CHECK_LE(end, nMember()); - list = list.drop(end, nMember() - end); - list = list.drop(0, first); - return isl::multi_union_pw_aff(space, list); -} - std::unique_ptr ScheduleTreeThreadSpecificMarker::make( isl::ctx ctx, diff --git a/tc/core/polyhedral/schedule_tree_elem.h b/tc/core/polyhedral/schedule_tree_elem.h index 5d782049d..3e1dd42aa 100644 --- a/tc/core/polyhedral/schedule_tree_elem.h +++ b/tc/core/polyhedral/schedule_tree_elem.h @@ -23,6 +23,7 @@ #include "tc/external/isl.h" #include "tc/core/check.h" +#include "tc/core/polyhedral/domain_types.h" #include "tc/core/polyhedral/mapping_types.h" #include "tc/core/polyhedral/schedule_tree.h" @@ -63,7 +64,7 @@ struct ScheduleTreeContext : public ScheduleTree { } public: - isl::set context_; + isl::Set context_; }; struct ScheduleTreeDomain : public ScheduleTree { @@ -99,7 +100,7 @@ struct ScheduleTreeDomain : public ScheduleTree { } public: - isl::union_set domain_; + isl::UnionSet domain_; }; struct ScheduleTreeExtension : public ScheduleTree { @@ -135,7 +136,7 @@ struct ScheduleTreeExtension : public ScheduleTree { } public: - isl::union_map extension_; + isl::UnionMap extension_; }; struct ScheduleTreeFilter : public ScheduleTree { @@ -171,14 +172,14 @@ struct ScheduleTreeFilter : public ScheduleTree { } public: - isl::union_set filter_; + isl::UnionSet filter_; }; struct ScheduleTreeMapping : public ScheduleTree { public: using Mapping = std::unordered_map< mapping::MappingId, - isl::union_pw_aff, + isl::UnionPwAffOn, typename mapping::MappingId::Hash>; static constexpr detail::ScheduleTreeType NodeType = @@ -215,7 +216,7 @@ struct ScheduleTreeMapping : public ScheduleTree { // Mapping from identifiers to affine functions on domain elements. const Mapping mapping; // Assignment of the affine functions to the identifiers as parameters. - isl::union_set filter_; + isl::UnionSet filter_; }; struct ScheduleTreeSequence : public ScheduleTree { @@ -309,7 +310,7 @@ struct ScheduleTreeBand : public ScheduleTree { // Replace "mupa" by its greatest integer part to ensure that the // schedule is always integral. static std::unique_ptr make( - isl::multi_union_pw_aff mupa, + isl::MultiUnionPwAff mupa, bool permutable, std::vector coincident, std::vector unroll, @@ -331,11 +332,21 @@ struct ScheduleTreeBand : public ScheduleTree { // Extract the range of "n" members starting at "first" // (in an anonymous space). - isl::multi_union_pw_aff memberRange(size_t first, size_t n) const; + template + isl::MultiUnionPwAff memberRange(size_t first, size_t n) + const { + auto list = mupa_.get_union_pw_aff_list(); + auto space = mupa_.get_space().params().add_unnamed_tuple_ui(n); + auto end = first + n; + TC_CHECK_LE(end, nMember()); + list = list.drop(end, nMember() - end); + list = list.drop(0, first); + return isl::MultiUnionPwAff(space, list); + } public: bool permutable_{false}; - isl::multi_union_pw_aff mupa_; + isl::MultiUnionPwAff mupa_; std::vector coincident_; // For each member, should the corresponding loop in the generated code diff --git a/tc/core/polyhedral/schedule_tree_matcher-inl.h b/tc/core/polyhedral/schedule_tree_matcher-inl.h index 41a4ba434..1a83746cd 100644 --- a/tc/core/polyhedral/schedule_tree_matcher-inl.h +++ b/tc/core/polyhedral/schedule_tree_matcher-inl.h @@ -16,6 +16,7 @@ #pragma once #include "tc/core/check.h" +#include "tc/core/polyhedral/domain_types.h" #include "tc/core/polyhedral/schedule_tree.h" #include "tc/core/polyhedral/schedule_tree_elem.h" @@ -291,5 +292,17 @@ inline std::vector match( return matchDFSPreorder(matcher, tree); } +template +inline bool isSingleReductionWithin( + isl::UnionSet domain, + isl::MultiUnionPwAff prefix, + const Scop& scop) { + auto reductions = scop.body.reductions; + reductions = reductions.intersect_domain(domain); + auto prefixMap = prefix.toUnionMap(); + auto prefixToReduction = reductions.apply_domain(prefixMap); + return prefixToReduction.is_single_valued(); +} + } // namespace polyhedral } // namespace tc diff --git a/tc/core/polyhedral/schedule_tree_matcher.h b/tc/core/polyhedral/schedule_tree_matcher.h index fd4b986a9..8b51f9ce9 100644 --- a/tc/core/polyhedral/schedule_tree_matcher.h +++ b/tc/core/polyhedral/schedule_tree_matcher.h @@ -19,6 +19,7 @@ #include #include +#include "tc/core/polyhedral/domain_types.h" #include "tc/core/polyhedral/schedule_tree.h" #include "tc/core/polyhedral/scop.h" @@ -27,14 +28,17 @@ namespace polyhedral { // Return the union of the reduction update statements // that appear in "domain". -isl::union_set reductionUpdates(isl::union_set domain, const Scop& scop); +isl::UnionSet reductionUpdates( + isl::UnionSet domain, + const Scop& scop); // Does "prefix" partition "domain" into individual reductions? // In particular, do the elements of "domain" access a single tensor // element within "prefix"? +template bool isSingleReductionWithin( - isl::union_set domain, - isl::multi_union_pw_aff prefix, + isl::UnionSet domain, + isl::MultiUnionPwAff prefix, const Scop& scop); } // namespace polyhedral diff --git a/tc/core/polyhedral/schedule_utils-inl.h b/tc/core/polyhedral/schedule_utils-inl.h new file mode 100644 index 000000000..416f0111c --- /dev/null +++ b/tc/core/polyhedral/schedule_utils-inl.h @@ -0,0 +1,216 @@ +/** + * Copyright (c) 2017-2018, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include "tc/core/check.h" +#include "tc/core/polyhedral/domain_types.h" +#include "tc/core/polyhedral/schedule_tree.h" +#include "tc/core/polyhedral/schedule_tree_elem.h" +#include "tc/external/isl.h" + +namespace tc { +namespace polyhedral { + +template +isl::UnionMap extendSchedule( + const detail::ScheduleTree* node, + isl::UnionMap schedule) { + using namespace polyhedral::detail; + + if (auto bandElem = node->as()) { + if (bandElem->nMember() > 0) { + schedule = schedule.template flat_range_product( + bandElem->mupa_.toUnionMap()); + } + } else if (auto filterElem = node->as()) { + schedule = schedule.intersect_domain(filterElem->filter_); + } else if (auto extensionElem = node->as()) { + // FIXME: we may need to restrict the range of reversed extension map to + // schedule values that correspond to active domain elements at this + // point. + auto extension = extensionElem->extension_.reverse(); + auto specializedExtension = isl::UnionMap(extension); + schedule = + schedule.unite(specializedExtension.intersect_range(schedule.range())); + } + + return schedule; +} + +namespace detail { +template +inline isl::UnionMap partialScheduleImpl( + const ScheduleTree* root, + const ScheduleTree* node, + bool useNode) { + auto nodes = node->ancestors(root); + if (useNode) { + nodes.push_back(node); + } + TC_CHECK_GT(nodes.size(), 0u) << "root node does not have a prefix schedule"; + auto domain = root->as(); + TC_CHECK(domain); + auto schedule = + isl::UnionMap::from_domain(domain->domain_); + for (auto anc : nodes) { + if (anc->as()) { + TC_CHECK(anc == root); + } else { + schedule = extendSchedule(anc, schedule); + } + } + return schedule; +} +} // namespace detail + +template +inline isl::UnionMap prefixSchedule( + const detail::ScheduleTree* root, + const detail::ScheduleTree* node) { + return detail::partialScheduleImpl(root, node, false); +} + +template +inline isl::UnionMap partialSchedule( + const detail::ScheduleTree* root, + const detail::ScheduleTree* node) { + return detail::partialScheduleImpl(root, node, true); +} + +namespace detail { + +template +inline std::vector filterType( + const std::vector& vec) { + std::vector result; + for (auto e : vec) { + if (e->as()) { + result.push_back(e); + } + } + return result; +} + +template +inline T +foldl(const std::vector vec, Func op, T init = T()) { + T value = init; + for (auto st : vec) { + value = op(st, value); + } + return value; +} + +} // namespace detail + +template +inline isl::MultiUnionPwAff infixScheduleMupa( + const detail::ScheduleTree* root, + const detail::ScheduleTree* relativeRoot, + const detail::ScheduleTree* tree) { + using namespace polyhedral::detail; + + auto domainElem = root->as(); + TC_CHECK(domainElem); + auto domain = domainElem->domain_.universe(); + auto zero = isl::MultiVal::zero( + domain.get_space().add_unnamed_tuple_ui(0)); + auto prefix = isl::MultiUnionPwAff(domain, zero); + prefix = foldl( + filterType(tree->ancestors(relativeRoot)), + [](const ScheduleTree* st, + isl::MultiUnionPwAff pref) { + auto mupa = st->as()->mupa_; + return pref.template flat_range_product(mupa); + }, + prefix); + return prefix; +} + +template +inline isl::MultiUnionPwAff prefixScheduleMupa( + const detail::ScheduleTree* root, + const detail::ScheduleTree* tree) { + return infixScheduleMupa(root, root, tree); +} + +template +inline isl::MultiUnionPwAff partialScheduleMupa( + const detail::ScheduleTree* root, + const detail::ScheduleTree* tree) { + using namespace polyhedral::detail; + + auto prefix = prefixScheduleMupa(root, tree); + auto band = tree->as(); + return band ? prefix.template flat_range_product(band->mupa_) + : prefix; +} + +/* + * Extract a mapping from the domain elements active at "tree" + * to identifiers "ids", where all branches in "tree" + * are assumed to have been mapped to these identifiers. + * The result lives in a space of the form "tupleId"["ids"...]. + */ +template +isl::MultiUnionPwAff extractDomainToIds( + const detail::ScheduleTree* root, + const detail::ScheduleTree* tree, + const std::vector& ids, + isl::id tupleId) { + using namespace polyhedral::detail; + + auto paramSpace = isl::Space<>(tree->ctx_, 0); + auto empty = isl::UnionSet::empty(paramSpace); + auto space = + paramSpace.add_named_tuple_id_ui(tupleId, ids.size()); + auto zero = isl::MultiVal::zero(space); + auto domainToIds = isl::MultiUnionPwAff(empty, zero); + + for (auto mapping : tree->collect(tree, ScheduleTreeType::Mapping)) { + auto mappingNode = mapping->as(); + auto list = isl::UnionPwAffListOn(tree->ctx_, ids.size()); + for (auto id : ids) { + if (mappingNode->mapping.count(id) == 0) { + break; + } + auto idMap = mappingNode->mapping.at(id); + list = list.add(idMap); + } + // Ignore this node if it does not map to all required ids. + if (static_cast(list.size()) != ids.size()) { + continue; + } + auto nodeToIds = isl::MultiUnionPwAff(space, list); + auto active = activeDomainPoints(root, mapping); + TC_CHECK(active.intersect(domainToIds.domain()).is_empty()) + << "conflicting mappings; are the filters in the tree disjoint?"; + nodeToIds = nodeToIds.intersect_domain(active); + domainToIds = domainToIds.union_add(nodeToIds); + } + + auto active = activeDomainPoints(root, tree); + TC_CHECK(active.is_subset(domainToIds.domain())) + << "not all domain points of\n" + << active << "\nwere mapped to the required ids"; + + return domainToIds; +} + +} // namespace polyhedral +} // namespace tc diff --git a/tc/core/polyhedral/schedule_utils.cc b/tc/core/polyhedral/schedule_utils.cc index 37a1a7d96..8e6f26bad 100644 --- a/tc/core/polyhedral/schedule_utils.cc +++ b/tc/core/polyhedral/schedule_utils.cc @@ -28,68 +28,13 @@ using namespace detail; using std::ostream; using std::vector; -isl::union_map extendSchedule( - const ScheduleTree* node, - isl::union_map schedule) { - if (auto bandElem = node->as()) { - if (bandElem->nMember() > 0) { - schedule = - schedule.flat_range_product(isl::union_map::from(bandElem->mupa_)); - } - } else if (auto filterElem = node->as()) { - schedule = schedule.intersect_domain(filterElem->filter_); - } else if (auto extensionElem = node->as()) { - // FIXME: we may need to restrict the range of reversed extension map to - // schedule values that correspond to active domain elements at this - // point. - schedule = schedule.unite( - extensionElem->extension_.reverse().intersect_range(schedule.range())); - } - - return schedule; -} - -namespace { -isl::union_map partialScheduleImpl( - const ScheduleTree* root, - const ScheduleTree* node, - bool useNode) { - auto nodes = node->ancestors(root); - if (useNode) { - nodes.push_back(node); - } - TC_CHECK_GT(nodes.size(), 0u) << "root node does not have a prefix schedule"; - auto domain = root->as(); - TC_CHECK(domain); - auto schedule = isl::union_map::from_domain(domain->domain_); - for (auto anc : nodes) { - if (anc->as()) { - TC_CHECK(anc == root); - } else { - schedule = extendSchedule(anc, schedule); - } - } - return schedule; -} -} // namespace - -isl::union_map prefixSchedule( - const ScheduleTree* root, - const ScheduleTree* node) { - return partialScheduleImpl(root, node, false); -} - -isl::union_map partialSchedule( - const ScheduleTree* root, - const ScheduleTree* node) { - return partialScheduleImpl(root, node, true); -} - namespace { /* * If "node" is any filter, then intersect "domain" with that filter. */ -isl::union_set applyFilter(isl::union_set domain, const ScheduleTree* node) { +isl::UnionSet applyFilter( + isl::UnionSet domain, + const ScheduleTree* node) { if (auto filterElem = node->as()) { return domain.intersect(filterElem->filter_); } @@ -99,7 +44,9 @@ isl::union_set applyFilter(isl::union_set domain, const ScheduleTree* node) { /* * If "node" is a mapping, then intersect "domain" with its filter. */ -isl::union_set applyMapping(isl::union_set domain, const ScheduleTree* node) { +isl::UnionSet applyMapping( + isl::UnionSet domain, + const ScheduleTree* node) { if (auto filterElem = node->as()) { return domain.intersect(filterElem->filter_); } @@ -112,10 +59,11 @@ isl::union_set applyMapping(isl::union_set domain, const ScheduleTree* node) { // Domain elements are introduced by the root domain node. Some nodes // refine this set of elements based on "filter". Extension nodes // are considered to introduce additional domain points. -isl::union_set collectDomain( +isl::UnionSet collectDomain( const ScheduleTree* root, const vector& nodes, - isl::union_set (*filter)(isl::union_set domain, const ScheduleTree* node)) { + isl::UnionSet ( + *filter)(isl::UnionSet domain, const ScheduleTree* node)) { auto domainElem = root->as(); TC_CHECK(domainElem) << "root must be a Domain node" << *root; @@ -124,7 +72,7 @@ isl::union_set collectDomain( for (auto anc : nodes) { domain = filter(domain, anc); if (auto extensionElem = anc->as()) { - auto parentSchedule = prefixSchedule(root, anc); + auto parentSchedule = prefixSchedule(root, anc); auto extension = extensionElem->extension_; TC_CHECK(parentSchedule) << "missing root domain node"; parentSchedule = parentSchedule.intersect_domain(domain); @@ -136,7 +84,7 @@ isl::union_set collectDomain( // Get the set of domain elements that are active below // the given branch of nodes. -isl::union_set activeDomainPointsHelper( +isl::UnionSet activeDomainPointsHelper( const ScheduleTree* root, const vector& nodes) { return collectDomain(root, nodes, &applyFilter); @@ -144,19 +92,19 @@ isl::union_set activeDomainPointsHelper( } // namespace -isl::union_set prefixMappingFilter( +isl::UnionSet prefixMappingFilter( const ScheduleTree* root, const ScheduleTree* node) { return collectDomain(root, node->ancestors(root), &applyMapping); } -isl::union_set activeDomainPoints( +isl::UnionSet activeDomainPoints( const ScheduleTree* root, const ScheduleTree* node) { return activeDomainPointsHelper(root, node->ancestors(root)); } -isl::union_set activeDomainPointsBelow( +isl::UnionSet activeDomainPointsBelow( const ScheduleTree* root, const ScheduleTree* node) { auto ancestors = node->ancestors(root); @@ -186,111 +134,5 @@ vector collectScheduleTreesPath( return res; } -namespace { - -template -vector filterType(const vector& vec) { - vector result; - for (auto e : vec) { - if (e->as()) { - result.push_back(e); - } - } - return result; -} - -template -T foldl(const vector vec, Func op, T init = T()) { - T value = init; - for (auto st : vec) { - value = op(st, value); - } - return value; -} - -} // namespace - -isl::multi_union_pw_aff infixScheduleMupa( - const ScheduleTree* root, - const ScheduleTree* relativeRoot, - const ScheduleTree* tree) { - auto domainElem = root->as(); - TC_CHECK(domainElem); - auto domain = domainElem->domain_.universe(); - auto zero = isl::multi_val::zero(domain.get_space().set_from_params()); - auto prefix = isl::multi_union_pw_aff(domain, zero); - prefix = foldl( - filterType(tree->ancestors(relativeRoot)), - [](const ScheduleTree* st, isl::multi_union_pw_aff pref) { - auto mupa = st->as()->mupa_; - return pref.flat_range_product(mupa); - }, - prefix); - return prefix; -} - -isl::multi_union_pw_aff prefixScheduleMupa( - const ScheduleTree* root, - const ScheduleTree* tree) { - return infixScheduleMupa(root, root, tree); -} - -isl::multi_union_pw_aff partialScheduleMupa( - const detail::ScheduleTree* root, - const detail::ScheduleTree* tree) { - auto prefix = prefixScheduleMupa(root, tree); - auto band = tree->as(); - return band ? prefix.flat_range_product(band->mupa_) : prefix; -} - -/* - * Extract a mapping from the domain elements active at "tree" - * to identifiers "ids", where all branches in "tree" - * are assumed to have been mapped to these identifiers. - * The result lives in a space of the form "tupleId"["ids"...]. - */ -isl::multi_union_pw_aff extractDomainToIds( - const detail::ScheduleTree* root, - const detail::ScheduleTree* tree, - const std::vector& ids, - isl::id tupleId) { - using namespace polyhedral::detail; - - auto space = isl::space(tree->ctx_, 0); - auto empty = isl::union_set::empty(space); - space = space.add_named_tuple_id_ui(tupleId, ids.size()); - auto zero = isl::multi_val::zero(space); - auto domainToIds = isl::multi_union_pw_aff(empty, zero); - - for (auto mapping : tree->collect(tree, ScheduleTreeType::Mapping)) { - auto mappingNode = mapping->as(); - auto list = isl::union_pw_aff_list(tree->ctx_, ids.size()); - for (auto id : ids) { - if (mappingNode->mapping.count(id) == 0) { - break; - } - auto idMap = mappingNode->mapping.at(id); - list = list.add(idMap); - } - // Ignore this node if it does not map to all required ids. - if (static_cast(list.size()) != ids.size()) { - continue; - } - auto nodeToIds = isl::multi_union_pw_aff(space, list); - auto active = activeDomainPoints(root, mapping); - TC_CHECK(active.intersect(domainToIds.domain()).is_empty()) - << "conflicting mappings; are the filters in the tree disjoint?"; - nodeToIds = nodeToIds.intersect_domain(active); - domainToIds = domainToIds.union_add(nodeToIds); - } - - auto active = activeDomainPoints(root, tree); - TC_CHECK(active.is_subset(domainToIds.domain())) - << "not all domain points of\n" - << active << "\nwere mapped to the required ids"; - - return domainToIds; -} - } // namespace polyhedral } // namespace tc diff --git a/tc/core/polyhedral/schedule_utils.h b/tc/core/polyhedral/schedule_utils.h index adca3decf..26ffb04b4 100644 --- a/tc/core/polyhedral/schedule_utils.h +++ b/tc/core/polyhedral/schedule_utils.h @@ -19,6 +19,7 @@ #include "tc/core/check.h" #include "tc/core/constants.h" +#include "tc/core/polyhedral/domain_types.h" #include "tc/core/polyhedral/schedule_tree.h" #include "tc/core/polyhedral/schedule_tree_elem.h" @@ -46,18 +47,21 @@ std::vector collectScheduleTreesPath( // Given a schedule defined by the ancestors of the given node, // extend it to a schedule that also covers the node itself. -isl::union_map extendSchedule( +template +isl::UnionMap extendSchedule( const detail::ScheduleTree* node, - isl::union_map schedule); + isl::UnionMap schedule); // Get the partial schedule defined by ancestors of the given node and the node // itself. -isl::union_map partialSchedule( +template +isl::UnionMap partialSchedule( const detail::ScheduleTree* root, const detail::ScheduleTree* node); // Return the schedule defined by the ancestors of the given node. -isl::union_map prefixSchedule( +template +isl::UnionMap prefixSchedule( const detail::ScheduleTree* root, const detail::ScheduleTree* node); @@ -68,7 +72,8 @@ isl::union_map prefixSchedule( // function on the universe domain of the schedule tree. // Note that this function does not take into account // any intermediate filter nodes. -isl::multi_union_pw_aff infixScheduleMupa( +template +isl::MultiUnionPwAff infixScheduleMupa( const detail::ScheduleTree* root, const detail::ScheduleTree* relativeRoot, const detail::ScheduleTree* tree); @@ -78,7 +83,8 @@ isl::multi_union_pw_aff infixScheduleMupa( // function on the universe domain of the schedule tree. // Note that unlike isl_schedule_node_get_prefix_schedule_multi_union_pw_aff, // this function does not take into account any intermediate filter nodes. -isl::multi_union_pw_aff prefixScheduleMupa( +template +isl::MultiUnionPwAff prefixScheduleMupa( const detail::ScheduleTree* root, const detail::ScheduleTree* tree); @@ -86,7 +92,8 @@ isl::multi_union_pw_aff prefixScheduleMupa( // including that of the node itself. // Note that this function does not take into account // any intermediate filter nodes. -isl::multi_union_pw_aff partialScheduleMupa( +template +isl::MultiUnionPwAff partialScheduleMupa( const detail::ScheduleTree* root, const detail::ScheduleTree* tree); @@ -94,7 +101,7 @@ isl::multi_union_pw_aff partialScheduleMupa( // point is active if it was not filtered away on the path from the // root to the node. The root must be a domain element, otherwise no // elements would be considered active. -isl::union_set activeDomainPoints( +isl::UnionSet activeDomainPoints( const detail::ScheduleTree* root, const detail::ScheduleTree* node); @@ -102,13 +109,13 @@ isl::union_set activeDomainPoints( // point is active if it was not filtered away on the path from the // root to the node. The root must be a domain element, otherwise no // elements would be considered active. -isl::union_set activeDomainPointsBelow( +isl::UnionSet activeDomainPointsBelow( const detail::ScheduleTree* root, const detail::ScheduleTree* node); // Collect the outer block/thread identifier mappings // into a filter on the active domain elements. -isl::union_set prefixMappingFilter( +isl::UnionSet prefixMappingFilter( const detail::ScheduleTree* root, const detail::ScheduleTree* node); @@ -116,7 +123,8 @@ isl::union_set prefixMappingFilter( // rooted at "root") to identifiers "ids", where all branches in "tree" are // assumed to have been mapped to these identifiers. The result lives in a // space of the form "tupleId"["ids"...]. -isl::multi_union_pw_aff extractDomainToIds( +template +isl::MultiUnionPwAff extractDomainToIds( const detail::ScheduleTree* root, const detail::ScheduleTree* tree, const std::vector& ids, @@ -140,3 +148,5 @@ bool isMappingTo(const detail::ScheduleTree* tree) { } // namespace polyhedral } // namespace tc + +#include "tc/core/polyhedral/schedule_utils-inl.h" diff --git a/tc/core/polyhedral/scop.cc b/tc/core/polyhedral/scop.cc index 71d3bc176..4046dedd4 100644 --- a/tc/core/polyhedral/scop.cc +++ b/tc/core/polyhedral/scop.cc @@ -27,6 +27,7 @@ #include "tc/core/functional.h" #include "tc/core/halide2isl.h" #include "tc/core/polyhedral/body.h" +#include "tc/core/polyhedral/domain_types.h" #include "tc/core/polyhedral/memory_promotion.h" #include "tc/core/polyhedral/schedule_isl_conversion.h" #include "tc/core/polyhedral/schedule_transforms.h" @@ -51,7 +52,7 @@ ScopUPtr Scop::makeScop( halide2isl::SymbolTable sym = halide2isl::makeSymbolTable(components); - isl::space paramSpace = halide2isl::makeParamSpace(ctx, sym.params); + auto paramSpace = halide2isl::makeParamSpace(ctx, sym.params); ScopUPtr scop(new Scop()); scop->halide.params = sym.params; @@ -84,7 +85,7 @@ ScopUPtr Scop::makeScop( return makeScop(ctx, tc2halide::translate(ctx, treeRef, compilerOptions)); } -isl::union_set& Scop::domainRef() { +isl::UnionSet& Scop::domainRef() { auto dom = scheduleRoot()->as(); TC_CHECK(dom) << "root is not a domain in: " << *scheduleRoot(); // TODO: activate this when the invariant has a chance of working (i.e. we @@ -97,7 +98,7 @@ isl::union_set& Scop::domainRef() { return dom->domain_; } -const isl::union_set Scop::domain() const { +const isl::UnionSet Scop::domain() const { return const_cast(this)->domainRef(); } @@ -264,7 +265,7 @@ void Scop::promoteEverythingAt(std::vector pos) { auto tree = scheduleRoot()->child(pos); checkFiltersDisjointStatements(scheduleRoot()); - auto schedule = partialSchedule(root, tree); + auto schedule = partialSchedule(root, tree); auto groupMap = TensorReferenceGroup::accessedWithin(schedule, body); for (auto& p : groupMap) { @@ -299,11 +300,11 @@ namespace { using namespace tc::polyhedral; -isl::union_map computeDependences( - isl::union_map sources, - isl::union_map sinks, +isl::UnionMap computeDependences( + isl::UnionMap sources, + isl::UnionMap sinks, isl::schedule schedule) { - auto uai = isl::union_access_info(sinks); + auto uai = sinks.asUnionAccessInfo(); uai = uai.set_may_source(sources); uai = uai.set_schedule(schedule); auto flow = uai.compute_flow(); @@ -368,8 +369,9 @@ void Scop::computeAllDependences() { dependences = flowDeps.unite(falseDeps).coalesce(); } -isl::union_map Scop::activeDependences(detail::ScheduleTree* tree) { - auto prefix = prefixScheduleMupa(scheduleRoot(), tree); +isl::UnionMap Scop::activeDependences( + detail::ScheduleTree* tree) { + auto prefix = prefixScheduleMupa(scheduleRoot(), tree); auto domain = activeDomainPoints(scheduleRoot(), tree); auto active = dependences; active = active.intersect_domain(domain); @@ -484,7 +486,7 @@ void Scop::reschedule( auto parentTree = tree->ancestor(root, 1); auto treePos = tree->positionInParent(parentTree); auto domain = activeDomainPoints(root, tree); - auto prefix = prefixScheduleMupa(root, tree); + auto prefix = prefixScheduleMupa(root, tree); // Restrict the constraints to domain points reachable from point loops // and update the current prefix. @@ -515,12 +517,12 @@ const Halide::OutputImageParam& Scop::findArgument(isl::id id) const { return *halide.inputs.begin(); } -isl::aff Scop::makeIslAffFromStmtExpr(isl::id stmtId, const Halide::Expr& e) - const { +isl::AffOn Scop::makeIslAffFromStmtExpr( + isl::id stmtId, + const Halide::Expr& e) const { auto domain = halide.domains.at(stmtId); auto aff = halide2isl::makeIslAffFromExpr(domain.paramSpace, e); - aff = aff.unbind_params_insert_domain(domain.tuple); - return aff; + return aff.unbind_params_insert_domain(domain.tuple); } } // namespace polyhedral diff --git a/tc/core/polyhedral/scop.h b/tc/core/polyhedral/scop.h index 796974d9f..572fb833a 100644 --- a/tc/core/polyhedral/scop.h +++ b/tc/core/polyhedral/scop.h @@ -27,6 +27,7 @@ #include "tc/core/halide2isl.h" #include "tc/core/mapping_options.h" #include "tc/core/polyhedral/body.h" +#include "tc/core/polyhedral/domain_types.h" #include "tc/core/polyhedral/schedule_transforms.h" #include "tc/core/polyhedral/schedule_tree.h" #include "tc/core/tc2halide.h" @@ -92,7 +93,7 @@ struct Scop { // The schedule tree of the scop does not necessarily have // a context node. Call updateTopLevelContext on the schedule tree // to introduce or refine such a context node. - isl::set context() const { + isl::Set<> context() const { auto ctx = domain().get_ctx(); auto context = halide2isl::makeParamContext(ctx, halide.params); return context.intersect(makeContext(parameterValues)); @@ -131,7 +132,7 @@ struct Scop { // Returns a set that specializes the named scop's subset of // parameter space to the integer values passed to the function. template - isl::set makeContext( + isl::Set<> makeContext( const std::unordered_map& sizes = std::unordered_map()) const { auto s = domain().get_space(); @@ -421,7 +422,9 @@ struct Scop { // Return a null isl::aff if the expression is not affine. Fail if any // of the variables does not correspond to a parameter or // an instance identifier of the statement. - isl::aff makeIslAffFromStmtExpr(isl::id stmtId, const Halide::Expr& e) const; + isl::AffOn makeIslAffFromStmtExpr( + isl::id stmtId, + const Halide::Expr& e) const; // Promote a tensor reference group to a storage of a given "kind", // inserting the copy @@ -483,7 +486,8 @@ struct Scop { void computeAllDependences(); // Return the set of dependences that are active // at the given position. - isl::union_map activeDependences(detail::ScheduleTree* tree); + isl::UnionMap activeDependences( + detail::ScheduleTree* tree); public: // Halide stuff @@ -505,17 +509,17 @@ struct Scop { // By analogy with generalized functions, the domain is the "support" part // of the ScheduleTree "function". private: - isl::union_set& domainRef(); + isl::UnionSet& domainRef(); public: - const isl::union_set domain() const; + const isl::UnionSet domain() const; // The parameter values of a specialized Scop. std::unordered_map parameterValues; Body body; // RAW, WAR, and WAW dependences - isl::union_map dependences; + isl::UnionMap dependences; private: // By analogy with generalized functions, a ScheduleTree is a (piecewise diff --git a/tc/core/polyhedral/separation.cc b/tc/core/polyhedral/separation-inl.h similarity index 75% rename from tc/core/polyhedral/separation.cc rename to tc/core/polyhedral/separation-inl.h index c5d38436e..d9d62c44f 100644 --- a/tc/core/polyhedral/separation.cc +++ b/tc/core/polyhedral/separation-inl.h @@ -13,35 +13,37 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#pragma once -#include "tc/core/polyhedral/separation.h" #include "tc/core/check.h" +#include "tc/core/polyhedral/domain_types.h" #include "tc/external/isl.h" namespace tc { namespace polyhedral { -isl::union_set partialTargetTiles( - isl::union_set domain, - isl::multi_union_pw_aff prefix, - isl::multi_union_pw_aff pretile, - isl::multi_val size) { +template +isl::UnionSet partialTargetTiles( + isl::UnionSet domain, + isl::MultiUnionPwAff prefix, + isl::MultiUnionPwAff pretile, + isl::MultiVal size) { auto space = pretile.get_space(); - auto tile = isl::multi_aff::identity(space.map_from_set()); + auto tile = isl::MultiAff::identity(space.map_from_set()); tile = tile.scale_down(size).floor(); - auto tileMap = isl::map(tile); + auto tileMap = tile.asMap(); // Relation between pairs of elements in the same target tile. - auto sameTile = isl::union_map(tileMap.apply_range(tileMap.reverse())); + auto sameTile = tileMap.apply_range(tileMap.reverse()).asUnionMap(); // Mapping between domain elements and pairs of prefix and target values. // D -> [P -> T] auto schedule = prefix.range_product(pretile); - auto scheduleMap = isl::union_map::from(schedule); + auto scheduleMap = schedule.toUnionMap(); // Mapping between prefix values and target values // for some common domain element // P -> T TC_CHECK(domain.is_subset(scheduleMap.domain())); - auto target = domain.apply(scheduleMap).unwrap(); + auto target = domain.apply(scheduleMap).template unwrap(); // Mapping between prefix values and target values // for some common domain element, extended to complete target tiles. // P -> Tc diff --git a/tc/core/polyhedral/separation.h b/tc/core/polyhedral/separation.h index 668db7c93..a08c8b50d 100644 --- a/tc/core/polyhedral/separation.h +++ b/tc/core/polyhedral/separation.h @@ -15,6 +15,8 @@ */ #pragma once +#include "tc/core/polyhedral/domain_types.h" + #include "tc/external/isl.h" namespace tc { @@ -25,11 +27,14 @@ namespace polyhedral { * Return the elements in "domain" that map to partial tiles * in this space for fixed values of "prefix". */ -isl::union_set partialTargetTiles( - isl::union_set domain, - isl::multi_union_pw_aff prefix, - isl::multi_union_pw_aff pretile, - isl::multi_val size); +template +isl::UnionSet partialTargetTiles( + isl::UnionSet domain, + isl::MultiUnionPwAff prefix, + isl::MultiUnionPwAff pretile, + isl::MultiVal size); } // namespace polyhedral } // namespace tc + +#include "tc/core/polyhedral/separation-inl.h" diff --git a/tc/core/polyhedral/unroll.cc b/tc/core/polyhedral/unroll.cc index e90e78f16..8353699d7 100644 --- a/tc/core/polyhedral/unroll.cc +++ b/tc/core/polyhedral/unroll.cc @@ -100,7 +100,7 @@ isl::val boundInstancesAndMarkUnroll( auto outerMap = prefix; if (i > 0) { list = list.drop(i, 1); - auto outerSpace = space.add_unnamed_tuple_ui(list.size()); + auto outerSpace = space.add_unnamed_tuple_ui(list.size()); auto outer = isl::multi_union_pw_aff(outerSpace, list); outerMap = outerMap.flat_range_product(isl::union_map::from(outer)); } @@ -116,7 +116,7 @@ isl::val boundInstancesAndMarkUnroll( isl::val boundInstancesAndMarkUnroll( detail::ScheduleTree* st, - isl::union_map prefix, + isl::UnionMap prefix, isl::val unrollFactor); /* @@ -134,7 +134,7 @@ isl::val boundInstancesAndMarkUnroll( */ isl::val boundChildrenInstancesAndMarkUnroll( detail::ScheduleTree* st, - isl::union_map prefix, + isl::UnionMap prefix, isl::val unrollFactor) { if (st->children().size() == 0) { return isl::val::one(unrollFactor.get_ctx()); @@ -163,7 +163,7 @@ isl::val boundChildrenInstancesAndMarkUnroll( */ isl::val boundInstancesAndMarkUnroll( detail::ScheduleTree* st, - isl::union_map prefix, + isl::UnionMap prefix, isl::val unrollFactor) { auto bound = boundChildrenInstancesAndMarkUnroll(st, prefix, unrollFactor); @@ -184,7 +184,7 @@ void markUnroll( } auto unrollVal = isl::val(st->ctx_, unroll); - auto prefix = prefixSchedule(root, st); + auto prefix = prefixSchedule(root, st); prefix = prefix.intersect_domain(prefixMappingFilter(root, st)); boundInstancesAndMarkUnroll(st, prefix, unrollVal); } diff --git a/tc/core/polyhedral/utils.cc b/tc/core/polyhedral/utils.cc index 80cf11f78..07e1b83ee 100644 --- a/tc/core/polyhedral/utils.cc +++ b/tc/core/polyhedral/utils.cc @@ -15,6 +15,8 @@ */ #include "tc/core/polyhedral/utils.h" +#include "tc/core/polyhedral/domain_types.h" + namespace tc { namespace polyhedral { @@ -24,15 +26,15 @@ namespace polyhedral { * of the user. * Since some names are required, use names of the form "__tc_tensor_arg*". */ -isl::multi_id -constructTensorTuple(isl::space paramSpace, isl::id tensorId, size_t dim) { - auto tensorSpace = paramSpace.add_named_tuple_id_ui(tensorId, dim); +isl::MultiId +constructTensorTuple(isl::Space<> paramSpace, isl::id tensorId, size_t dim) { + auto tensorSpace = paramSpace.add_named_tuple_id_ui(tensorId, dim); isl::id_list tensorArgs(paramSpace.get_ctx(), 0); for (size_t i = 0; i < dim; ++i) { auto name = std::string("__tc_tensor_arg") + std::to_string(i); tensorArgs = tensorArgs.add(isl::id(paramSpace.get_ctx(), name)); } - return isl::multi_id(tensorSpace, tensorArgs); + return isl::MultiId(tensorSpace, tensorArgs); } } // namespace polyhedral diff --git a/tc/core/polyhedral/utils.h b/tc/core/polyhedral/utils.h index 459bca648..31512f9f3 100644 --- a/tc/core/polyhedral/utils.h +++ b/tc/core/polyhedral/utils.h @@ -15,6 +15,7 @@ */ #pragma once +#include "tc/core/polyhedral/domain_types.h" #include "tc/external/isl.h" namespace tc { @@ -24,8 +25,8 @@ namespace polyhedral { * dimension "dim" from the parameter space "paramSpace", * without any specific names for the indices. */ -isl::multi_id -constructTensorTuple(isl::space paramSpace, isl::id tensorId, size_t dim); +isl::MultiId +constructTensorTuple(isl::Space<> paramSpace, isl::id tensorId, size_t dim); } // namespace polyhedral } // namespace tc diff --git a/tc/external/detail/islpp-inl.h b/tc/external/detail/islpp-inl.h index 6991d257a..8ce3579f9 100644 --- a/tc/external/detail/islpp-inl.h +++ b/tc/external/detail/islpp-inl.h @@ -20,6 +20,16 @@ namespace isl { /////////////////////////////////////////////////////////////////////////////// // Operations on isl::aff to perform arithmetic and create/combine with sets /////////////////////////////////////////////////////////////////////////////// +template +inline auto operator*(isl::val V, T A) -> decltype(A.scale(V)) { + return A.scale(V); +} + +template +inline auto operator*(T A, isl::val V) -> decltype(A.scale(V)) { + return V * A; +} + inline isl::aff operator*(int i, isl::aff A) { isl::val V(isl::val(A.get_ctx(), i)); return A * V; @@ -29,42 +39,37 @@ inline isl::aff operator*(isl::aff A, int i) { return i * A; } -inline isl::aff operator*(isl::val V, isl::aff A) { - return A.scale(V); -} - -inline isl::aff operator*(isl::aff A, isl::val V) { - return V * A; -} - inline isl::aff operator/(isl::aff A, int i) { return A.scale_down(isl::val(A.get_ctx(), i)); } template -inline isl::aff operator+(int i, T A) { +inline T operator+(int i, T A) { return A.add_constant_si(i); } -inline isl::aff operator+(isl::aff A, isl::val v) { - isl::aff T(isl::local_space(A.get_space().domain()), v); - return A.add(T); +template +inline auto operator+(T A, isl::val v) -> decltype(A.add_constant(v)) { + return A.add_constant(v); } inline isl::aff operator+(isl::val v, isl::aff A) { return A + v; } -inline isl::aff operator+(isl::aff A, int i) { +template +inline auto operator+(T A, int i) -> decltype(A.add_constant_si(i)) { return i + A; } -inline isl::aff operator-(isl::aff A, int i) { +template +inline auto operator-(T A, int i) -> decltype(A.add_constant_si(i)) { return A + (-i); } -inline isl::aff operator-(int i, isl::aff A) { - return (A + (-i)).neg(); +template +inline auto operator-(int i, T A) -> decltype(A.add_constant_si(i)) { + return i + A.neg(); } inline isl::set operator>=(isl::aff_set A, isl::val v) { @@ -184,7 +189,8 @@ inline isl::map operator<=(isl::aff_map A, isl::aff B) { /////////////////////////////////////////////////////////////////////////////// // Operations on isl::multi_aff /////////////////////////////////////////////////////////////////////////////// -inline isl::multi_aff operator/(isl::multi_aff left, isl::multi_val right) { +template +inline auto operator/(S left, T right) -> decltype(left.scale_down(right)) { return left.scale_down(right); } diff --git a/tc/external/detail/islpp.h b/tc/external/detail/islpp.h index 96fa168ff..114646ff0 100644 --- a/tc/external/detail/islpp.h +++ b/tc/external/detail/islpp.h @@ -45,6 +45,11 @@ inline T operator-(T a, T b) { return a.sub(b); } +template +inline auto operator-(isl::val a, T b) -> decltype(b.add_constant(a)) { + return b.neg().add_constant(a); +} + template inline T operator&(T S1, T S2) { return S1.intersect(S2); @@ -123,18 +128,11 @@ inline bool operator!=(isl::val v1, isl::val v2) { /////////////////////////////////////////////////////////////////////////////// isl::aff operator*(int i, isl::aff A); isl::aff operator*(isl::aff A, int i); -isl::aff operator*(isl::aff A, isl::val V); -isl::aff operator*(isl::val V, isl::aff A); isl::aff operator/(isl::aff A, int i); -isl::aff operator+(isl::aff A, int i); -isl::aff operator+(isl::aff A, isl::val v); isl::aff operator+(isl::val v, isl::aff A); -isl::aff operator-(isl::aff A, int i); -isl::aff operator-(int i, isl::aff A); - // Thin wrapper around aff to disambiguate types for operators and avoid case // where return type overloading occurs struct aff_set { @@ -180,11 +178,6 @@ isl::map operator<=(isl::aff_map A, isl::aff B); isl::map operator>(isl::aff_map A, isl::aff B); isl::map operator<(isl::aff_map A, isl::aff B); -/////////////////////////////////////////////////////////////////////////////// -// Operations on isl::multi_aff -/////////////////////////////////////////////////////////////////////////////// -isl::multi_aff operator/(isl::multi_aff left, isl::multi_val right); - /////////////////////////////////////////////////////////////////////////////// // Operations on isl::set and isl::union_set /////////////////////////////////////////////////////////////////////////////// @@ -264,6 +257,8 @@ struct UnionAsVector } }; +#include + struct IslIdIslHash { size_t operator()(const isl::id& id) const { return isl_id_get_hash(id.get()); @@ -284,9 +279,11 @@ inline bool operator!=(const isl::id& id1, const isl::id& id2) { /////////////////////////////////////////////////////////////////////////////// // Given a space and a list of values, this returns the corresponding multi_val. -template -isl::multi_val makeMultiVal(isl::space s, const std::vector& vals) { - isl::multi_val mv = isl::multi_val::zero(s); +template +isl::MultiVal makeMultiVal( + isl::Space s, + const std::vector& vals) { + isl::MultiVal mv = isl::MultiVal::zero(s); TC_CHECK_EQ(vals.size(), static_cast(mv.size())); for (size_t i = 0; i < vals.size(); ++i) { mv = mv.set_val(i, isl::val(s.get_ctx(), vals[i])); @@ -301,17 +298,17 @@ isl::multi_val makeMultiVal(isl::space s, const std::vector& vals) { // 2. each new parameter dimension p(i) is bounded to be in [0, e(i) - 1] // 3. if e (i) == 0 then no constraint is set on the corresponding id(i) template -inline isl::set makeParameterContext( - isl::space paramSpace, +inline isl::Set<> makeParameterContext( + isl::Space<> paramSpace, const IterPair begin, const IterPair end) { for (auto it = begin; it != end; ++it) { paramSpace = paramSpace.add_param(it->first); } - isl::set res(isl::set::universe(paramSpace)); + auto res(isl::Set<>::universe(paramSpace)); for (auto it = begin; it != end; ++it) { - isl::aff a(isl::aff::param_on_domain_space(paramSpace, it->first)); - res = res & (isl::aff_set(a) >= 0) & (isl::aff_set(a) < it->second); + auto a(isl::AffOn<>::param_on_domain_space(paramSpace, it->first)); + res = res & a.asPwAff().nonneg_set() & (it->second - a).asPwAff().pos_set(); } return res; } @@ -320,18 +317,18 @@ inline isl::set makeParameterContext( // that ties the space parameter to the values. // template -inline isl::set makeSpecializationSet( - isl::space space, +inline isl::Set<> makeSpecializationSet( + isl::Space<> space, const std::unordered_map& paramValues) { auto ctx = space.get_ctx(); for (auto kvp : paramValues) { space = space.add_param(isl::id(ctx, kvp.first)); } - auto set = isl::set::universe(space); + auto set = isl::Set<>::universe(space); for (auto kvp : paramValues) { auto id = isl::id(ctx, kvp.first); - isl::aff affParam(isl::aff::param_on_domain_space(space, id)); - set = set & (isl::aff_set(affParam) == kvp.second); + auto affParam(isl::AffOn<>::param_on_domain_space(space, id)); + set = set & (affParam - kvp.second).asPwAff().zero_set(); } return set; } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 5f18be9c6..f97eeb01f 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -13,9 +13,7 @@ set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) add_executable(test_basic test_basic.cc) add_test(test_basic test_basic) target_link_libraries(test_basic ${GOOGLE_LIBRARIES} ${ISL_LIBRARIES} ${ATEN_LIBRARIES} pthread) -if (WITH_BINDINGS) - add_dependencies(test_basic generate_isl_cpp_h) -endif() +add_dependencies(test_basic generate_isl_cpp_h) ################################################################################ # Core library only tests diff --git a/test/test_cuda_mapper_memory_promotion.cc b/test/test_cuda_mapper_memory_promotion.cc index 0fb7405b8..8cf0bc05f 100644 --- a/test/test_cuda_mapper_memory_promotion.cc +++ b/test/test_cuda_mapper_memory_promotion.cc @@ -64,7 +64,7 @@ class TestMapper : public ::testing::Test { TensorGroups accessedBySubtree( const polyhedral::detail::ScheduleTree* tree, const Scop& scop) { - auto schedule = partialSchedule(scop.scheduleRoot(), tree); + auto schedule = partialSchedule(scop.scheduleRoot(), tree); return TensorReferenceGroup::accessedWithin(schedule, scop.body); } }; @@ -309,7 +309,7 @@ def fun(float(N, M) A, float(N, M) B) -> (C) { EXPECT_EQ( np, std::min(tile1, problemSize1) * std::min(tile2, problemSize2)); - auto schedule = partialSchedule( + auto schedule = partialSchedule( scop.scheduleRoot(), scop.scheduleRoot()->child(childPos)); auto scopedAccess = oneGroup->originalAccesses().apply_domain(schedule); TC_CHECK(scopedAccess.is_equal(oneGroup->scopedAccesses())) @@ -376,7 +376,7 @@ def fun(float(N, M) A) -> (B, C) { auto active = activeDomainPoints(scop.scheduleRoot(), t); LOG(INFO) << "Active: " << active; - auto schedule = partialSchedule(scop.scheduleRoot(), t); + auto schedule = partialSchedule(scop.scheduleRoot(), t); auto scopedAccess = groupsB[0]->originalAccesses().apply_domain(schedule); TC_CHECK(scopedAccess.is_equal(groupsB[0]->scopedAccesses())) << "expected original accesses " << groupsB[0]->originalAccesses() diff --git a/third-party/islpp b/third-party/islpp index 35748a0c1..e1f96fb82 160000 --- a/third-party/islpp +++ b/third-party/islpp @@ -1 +1 @@ -Subproject commit 35748a0c1fc63ea3ba7296f8b4e40346426f53b3 +Subproject commit e1f96fb82746f01cad0f345fa105dafe255df80e