diff --git a/tc/core/halide2isl.cc b/tc/core/halide2isl.cc index f5e1ed753..e2ae351aa 100644 --- a/tc/core/halide2isl.cc +++ b/tc/core/halide2isl.cc @@ -17,6 +17,7 @@ #include #include +#include #include #include "tc/core/constants.h" @@ -238,7 +239,20 @@ isl::set makeParamContext(isl::ctx ctx, const SymbolTable& symbolTable) { return context; } -isl::map extractAccess( +// Extract a tagged affine access relation from Halide IR. +// The relation is tagged with a unique identifier, i.e. it lives in the space +// [D[...] -> __tc_ref_#[]] -> A[] +// where # is a unique sequential number, D is the statement identifier +// extracted from "domain" and A is the tensor identifier constructed from +// "tensor". "accesses" map is updated to keep track of the Halide IR nodes in +// which a particular reference # appeared. +// Returns the access relation and a flag indicating whether this relation is +// exact or not. The relation is overapproximated (that is, not exact) if it +// represents a non-affine access, for example, an access with indirection such +// as O(Index(i)) = 42. In such overapproximated access relation, dimensions +// that correspond to affine subscripts are still exact while those that +// correspond to non-affine subscripts are not constrained. +std::pair extractAccess( isl::set domain, const IRNode* op, const std::string& tensor, @@ -267,6 +281,7 @@ isl::map extractAccess( isl::map map = isl::map::universe(domainSpace.map_from_domain_and_range(rangeSpace)); + bool exact = true; for (size_t i = 0; i < args.size(); i++) { // Then add one equality constraint per dimension to encode the // point in the allocation actually read/written for each point in @@ -278,15 +293,17 @@ isl::map extractAccess( isl::pw_aff(isl::local_space(rangeSpace), isl::dim_type::set, i); // ... equals the coordinate accessed as a function of the domain. auto domainPoint = halide2isl::makeIslAffFromExpr(domainSpace, args[i]); - if (!domainPoint.is_null()) { + if (!domainPoint) { + exact = false; + } else { map = map.intersect(isl::pw_aff(domainPoint).eq_map(rangePoint)); } } - return map; + return std::make_pair(map, exact); } -std::pair +std::tuple extractAccesses(isl::set domain, const Stmt& s, AccessMap* accesses) { class FindAccesses : public IRGraphVisitor { using IRGraphVisitor::visit; @@ -294,31 +311,46 @@ extractAccesses(isl::set domain, const Stmt& s, AccessMap* accesses) { void visit(const Call* op) override { IRGraphVisitor::visit(op); if (op->call_type == Call::Halide || op->call_type == Call::Image) { - reads = reads.unite( - extractAccess(domain, op, op->name, op->args, accesses)); + // Read relations can be safely overapproximated. + isl::map read; + std::tie(read, std::ignore) = + extractAccess(domain, op, op->name, op->args, accesses); + reads = reads.unite(read); } } void visit(const Provide* op) override { IRGraphVisitor::visit(op); - writes = - writes.unite(extractAccess(domain, op, op->name, op->args, accesses)); + + // If the write access relation is not exact, we consider that any + // element _may_ be written by the statement. If it is exact, then we + // can guarantee that all the elements specified by the relation _must_ + // be written and any previously stored value will be killed. + isl::map write; + bool exact; + std::tie(write, exact) = + extractAccess(domain, op, op->name, op->args, accesses); + if (exact) { + mustWrites = mustWrites.unite(write); + } + mayWrites = mayWrites.unite(write); } const isl::set& domain; AccessMap* accesses; public: - isl::union_map reads, writes; + isl::union_map reads, mayWrites, mustWrites; FindAccesses(const isl::set& domain, AccessMap* accesses) : domain(domain), accesses(accesses), reads(isl::union_map::empty(domain.get_space())), - writes(isl::union_map::empty(domain.get_space())) {} + mayWrites(isl::union_map::empty(domain.get_space())), + mustWrites(isl::union_map::empty(domain.get_space())) {} } finder(domain, accesses); s.accept(&finder); - return {finder.reads, finder.writes}; + return std::make_tuple(finder.reads, finder.mayWrites, finder.mustWrites); } /* @@ -343,7 +375,8 @@ isl::schedule makeScheduleTreeHelper( isl::set set, std::vector& outer, isl::union_map* reads, - isl::union_map* writes, + isl::union_map* mayWrites, + isl::union_map* mustWrites, AccessMap* accesses, StatementMap* statements, IteratorMap* iterators) { @@ -389,7 +422,8 @@ isl::schedule makeScheduleTreeHelper( set, outerNext, reads, - writes, + mayWrites, + mustWrites, accesses, statements, iterators); @@ -422,7 +456,15 @@ isl::schedule makeScheduleTreeHelper( std::vector schedules; for (Stmt s : stmts) { schedules.push_back(makeScheduleTreeHelper( - s, set, outer, reads, writes, accesses, statements, iterators)); + s, + set, + outer, + reads, + mayWrites, + mustWrites, + accesses, + statements, + iterators)); } schedule = schedules[0].sequence(schedules[1]); @@ -437,23 +479,25 @@ isl::schedule makeScheduleTreeHelper( isl::set domain = set.set_tuple_id(id); schedule = isl::schedule::from_domain(domain); - isl::union_map newReads, newWrites; - std::tie(newReads, newWrites) = + isl::union_map newReads, newMayWrites, newMustWrites; + std::tie(newReads, newMayWrites, newMustWrites) = halide2isl::extractAccesses(domain, op, accesses); *reads = reads->unite(newReads); - *writes = writes->unite(newWrites); + *mayWrites = mayWrites->unite(newMayWrites); + *mustWrites = mustWrites->unite(newMustWrites); } else { LOG(FATAL) << "Unhandled Halide stmt: " << s; } return schedule; -}; +} ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) { ScheduleTreeAndAccesses result; - result.writes = result.reads = isl::union_map::empty(paramSpace); + result.mayWrites = result.mustWrites = result.reads = + isl::union_map::empty(paramSpace); // Walk the IR building a schedule tree std::vector outer; @@ -462,7 +506,8 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) { isl::set::universe(paramSpace), outer, &result.reads, - &result.writes, + &result.mayWrites, + &result.mustWrites, &result.accesses, &result.statements, &result.iterators); diff --git a/tc/core/halide2isl.h b/tc/core/halide2isl.h index 74ab23f88..bd771c01f 100644 --- a/tc/core/halide2isl.h +++ b/tc/core/halide2isl.h @@ -70,7 +70,7 @@ struct ScheduleTreeAndAccesses { /// Union maps describing the reads and writes done. Uses the ids in /// the schedule tree to denote the containing Stmt, and tags each /// access with a unique reference id of the form __tc_ref_N. - isl::union_map reads, writes; + isl::union_map reads, mayWrites, mustWrites; /// The correspondence between from Call and Provide nodes and the /// reference ids in the reads and writes maps. diff --git a/tc/core/libraries.h b/tc/core/libraries.h index 9d8386cd0..87cc0ee5f 100644 --- a/tc/core/libraries.h +++ b/tc/core/libraries.h @@ -32,8 +32,14 @@ namespace c { constexpr auto types = R"C( // Halide type handling -typedef int int32; -typedef long int64; +typedef signed char int8; +typedef unsigned char uint8; +typedef signed short int16; +typedef unsigned short uint16; +typedef signed int int32; +typedef unsigned int uint32; +typedef signed long int64; +typedef unsigned long uint64; typedef float float32; typedef double float64; )C"; @@ -81,16 +87,16 @@ float fmodf ( float x, float y ); //float frexpf ( float x, int* nptr ); float hypotf ( float x, float y ); //int ilogbf ( float x ); -//__RETURN_TYPE isfinite ( float a ); -//__RETURN_TYPE isinf ( float a ); -//__RETURN_TYPE isnan ( float a ); +//__RETURN_TYPE isfinite ( float a ); +//__RETURN_TYPE isinf ( float a ); +//__RETURN_TYPE isnan ( float a ); float j0f ( float x ); float j1f ( float x ); //float jnf ( int n, float x ); //float ldexpf ( float x, int exp ); float lgammaf ( float x ); -//long long int llrintf ( float x ); -//long long int llroundf ( float x ); +//long long int llrintf ( float x ); +//long long int llroundf ( float x ); float log10f ( float x ); float log1pf ( float x ); float log2f ( float x ); @@ -120,7 +126,7 @@ float roundf ( float x ); float rsqrtf ( float x ); //float scalblnf ( float x, long int n ); //float scalbnf ( float x, int n ); -//__RETURN_TYPE signbit ( float a ); +//__RETURN_TYPE signbit ( float a ); //void sincosf ( float x, float* sptr, float* cptr ); //void sincospif ( float x, float* sptr, float* cptr ); float sinf ( float x ); diff --git a/tc/core/polyhedral/memory_promotion.cc b/tc/core/polyhedral/memory_promotion.cc index f12a60b89..5896dbace 100644 --- a/tc/core/polyhedral/memory_promotion.cc +++ b/tc/core/polyhedral/memory_promotion.cc @@ -348,7 +348,7 @@ TensorGroups TensorReferenceGroup::accessedBySubtree( auto schedule = partialSchedule(scop.scheduleRoot(), tree); addSingletonReferenceGroups( - tensorGroups, scop.writes, domain, schedule, AccessType::Write); + tensorGroups, scop.mayWrites, domain, schedule, AccessType::Write); addSingletonReferenceGroups( tensorGroups, scop.reads, domain, schedule, AccessType::Read); diff --git a/tc/core/polyhedral/scop.cc b/tc/core/polyhedral/scop.cc index 54a968556..0c2335ca8 100644 --- a/tc/core/polyhedral/scop.cc +++ b/tc/core/polyhedral/scop.cc @@ -61,7 +61,8 @@ ScopUPtr Scop::makeScop( auto tree = halide2isl::makeScheduleTree(paramSpace, components.stmt); scop->scheduleTreeUPtr = std::move(tree.tree); scop->reads = tree.reads; - scop->writes = tree.writes; + scop->mayWrites = tree.mayWrites; + scop->mustWrites = tree.mustWrites; scop->halide.statements = std::move(tree.statements); scop->halide.accesses = std::move(tree.accesses); scop->halide.reductions = halide2isl::findReductions(components.stmt); @@ -109,7 +110,8 @@ const isl::union_set Scop::domain() const { std::ostream& operator<<(std::ostream& os, const Scop& s) { os << "domain: " << s.domain() << "\n"; os << "reads: " << s.reads << "\n"; - os << "writes: " << s.writes << "\n"; + os << "mayWrites: " << s.mayWrites << "\n"; + os << "mustWrites: " << s.mustWrites << "\n"; os << "schedule: " << *s.scheduleRoot() << "\n"; os << "idx: { "; for (auto i : s.halide.idx) { @@ -351,19 +353,29 @@ namespace { using namespace tc::polyhedral; +// Compute the dependence using the given may/must sources, sinks and kills. +// Any of the inputs may be an empty (but non-null) union map. +// Dependence analysis removes the cases transitively covered by a must source +// or a kill. isl::union_map computeDependences( - isl::union_map sources, + isl::union_map maySources, + isl::union_map mustSources, isl::union_map sinks, + isl::union_map kills, isl::schedule schedule) { auto uai = isl::union_access_info(sinks); - uai = uai.set_may_source(sources); + uai = uai.set_may_source(maySources); + uai = uai.set_must_source(mustSources); + uai = uai.set_kill(kills); uai = uai.set_schedule(schedule); auto flow = uai.compute_flow(); return flow.get_may_dependence(); } -// Do the simplest possible dependence analysis. -// Live-range reordering needs tagged access relations to be available. +// Set up schedule constraints by performing the dependence analysis using +// access relations from "scop". Set up callbacks in the constraints depending +// on "scheduleOptions". +// // The domain of the constraints is intersected with "restrictDomain" if it is // provided. isl::schedule_constraints makeScheduleConstraints( @@ -373,12 +385,16 @@ isl::schedule_constraints makeScheduleConstraints( auto schedule = toIslSchedule(scop.scheduleRoot()); auto firstChildNode = scop.scheduleRoot()->child({0}); auto reads = scop.reads.domain_factor_domain(); - auto writes = scop.writes.domain_factor_domain(); + auto mayWrites = scop.mayWrites.domain_factor_domain(); + auto mustWrites = scop.mustWrites.domain_factor_domain(); + auto empty = isl::union_map::empty(mustWrites.get_space()); // RAW - auto flowDeps = computeDependences(writes, reads, schedule); + auto flowDeps = + computeDependences(mayWrites, mustWrites, reads, mustWrites, schedule); // WAR and WAW - auto falseDeps = computeDependences(writes.unite(reads), writes, schedule); + auto falseDeps = computeDependences( + mayWrites.unite(reads), empty, mayWrites, mustWrites, schedule); auto allDeps = flowDeps.unite(falseDeps).coalesce(); diff --git a/tc/core/polyhedral/scop.h b/tc/core/polyhedral/scop.h index bc4953b1b..918d0ee4f 100644 --- a/tc/core/polyhedral/scop.h +++ b/tc/core/polyhedral/scop.h @@ -66,7 +66,8 @@ struct Scop { res->globalParameterContext = scop.globalParameterContext; res->halide = scop.halide; res->reads = scop.reads; - res->writes = scop.writes; + res->mayWrites = scop.mayWrites; + res->mustWrites = scop.mustWrites; res->scheduleTreeUPtr = detail::ScheduleTree::makeScheduleTree(*scop.scheduleTreeUPtr); res->treeSyncUpdateMap = scop.treeSyncUpdateMap; @@ -115,7 +116,8 @@ struct Scop { void specializeToContext() { domain() = domain().intersect_params(globalParameterContext); reads = reads.intersect_params(globalParameterContext); - writes = writes.intersect_params(globalParameterContext); + mayWrites = mayWrites.intersect_params(globalParameterContext); + mustWrites = mustWrites.intersect_params(globalParameterContext); } // Returns a set that specializes the named scop's subset of @@ -442,8 +444,14 @@ struct Scop { // This globalParameterContext lives in a parameter space. isl::set globalParameterContext; // TODO: not too happy about this name + // Access relations. + // Elements in mayWrite may or may not be written by the execution, depending + // on some dynamic condition. Those in mustWrites are always written. + // Thefore, mayWrites do not participate in transitively-covered dependence + // removal. isl::union_map reads; - isl::union_map writes; + isl::union_map mayWrites; + isl::union_map mustWrites; private: // By analogy with generalized functions, a ScheduleTree is a (piecewise diff --git a/tc/core/tc2halide.cc b/tc/core/tc2halide.cc index 7286c9509..9119190ff 100644 --- a/tc/core/tc2halide.cc +++ b/tc/core/tc2halide.cc @@ -216,7 +216,7 @@ Expr translateExpr( } } -vector unboundVariables(const vector& lhs, Expr rhs) { +vector unboundVariables(const vector& lhs, Expr rhs) { class FindUnboundVariables : public IRVisitor { using IRVisitor::visit; @@ -241,14 +241,19 @@ vector unboundVariables(const vector& lhs, Expr rhs) { set visited; public: - FindUnboundVariables(const vector& lhs) { - for (auto v : lhs) { - bound.push(v.name()); + FindUnboundVariables(const vector& lhs) { + for (auto e : lhs) { + if (const Variable* v = e.as()) { + bound.push(v->name); + } } } vector result; } finder(lhs); rhs.accept(&finder); + for (auto e : lhs) { + e.accept(&finder); + } return finder.result; } @@ -507,22 +512,31 @@ void translateComprehension( f = Function(c.ident().name()); (*funcs)[c.ident().name()] = f; } + + // we currently inline all of the let bindings generated in where clauses + // in the future we may consider using Halide Let bindings when they + // are supported later + map lets; + // Function is the internal Halide IR type for a pipeline // stage. Func is the front-end class that wraps it. Here it's // convenient to use both. Func func(f); - vector lhs; - vector lhs_as_exprs; - for (lang::Ident id : c.indices()) { - lhs.push_back(Var(id.name())); - lhs_as_exprs.push_back(lhs.back()); + vector lhs; + vector lhs_vars; + bool total_definition = true; + for (lang::TreeRef idx : c.indices()) { + Expr e = translateExpr(idx, params, *funcs, lets); + if (const Variable* op = e.as()) { + lhs_vars.push_back(Var(op->name)); + } else { + total_definition = false; + lhs_vars.push_back(Var()); + } + lhs.push_back(e); } - // we currently inline all of the let bindings generated in where clauses - // in the future we may consider using Halide Let bindings when they - // are supported later - map lets; for (auto wc : c.whereClauses()) { if (wc->kind() == lang::TK_LET) { auto let = lang::Let(wc); @@ -546,9 +560,8 @@ void translateComprehension( auto setupIdentity = [&](const Expr& identity, bool zero) { if (!f.has_pure_definition()) { added_implicit_initialization = true; - func(lhs) = (zero) ? identity - : undef(rhs.type()); // undef causes the original value - // to remain in input arrays + // undef causes the original value to remain in input arrays + func(lhs_vars) = (zero) ? identity : undef(rhs.type()); } }; @@ -587,6 +600,9 @@ void translateComprehension( break; case '=': + if (!total_definition) { + setupIdentity(rhs, false); + } break; default: throw lang::ErrorReport(c) << "Unimplemented reduction " @@ -618,9 +634,10 @@ void translateComprehension( for (auto& exp : all_exprs) { exp = bindParams.mutate(exp); } - - // TODO: When the LHS incorporates general expressions we'll need to - // bind params there too. + for (auto& e : lhs) { + e = bindParams.mutate(e); + all_exprs.push_back(e); + } // Do forward bounds inference -- construct an expression that says // this expression never reads out of bounds on its inputs, and @@ -660,19 +677,34 @@ void translateComprehension( // (e.g. an in-place stencil)?. The .bound directive will use the // bounds of the last stage for all stages. - // Does a tensor have a single bound, or can its bounds shrink over - // time? Solve for a single bound for now. + // Set the bounds to be the union of the boxes written to by every + // comprehension touching the tensor. + for (size_t i = 0; i < lhs.size(); i++) { + Expr e = lhs[i]; + if (const Variable* v = e.as()) { + if (!solution.contains(v->name)) { + throw lang::ErrorReport(c) + << "Free variable " << v + << " was not solved in range inference. May not be used right-hand side"; + } + } - for (Var v : lhs) { - if (!solution.contains(v.name())) { - throw lang::ErrorReport(c) - << "Free variable " << v - << " was not solved in range inference. May not be used right-hand side"; + Interval in = bounds_of_expr_in_scope(e, solution); + if (!in.is_bounded()) { + throw lang::ErrorReport(c.indices()[i]) + << "Left-hand side expression is unbounded"; } - // TODO: We're enforcing a single bound across all comprehensions - // for now. We should really check later ones are equal to earlier - // ones instead of just clobbering. - (*bounds)[f][v.name()] = solution.get(v.name()); + in.min = cast(in.min); + in.max = cast(in.max); + + map& b = (*bounds)[f]; + string dim_name = f.dimensions() ? f.args()[i] : lhs_vars[i].name(); + auto old = b.find(dim_name); + if (old != b.end()) { + // Take the union with any existing bounds + in.include(old->second); + } + b[dim_name] = in; } // Free variables that appear on the rhs but not the lhs are @@ -703,6 +735,9 @@ void translateComprehension( for (auto v : unbound) { Expr rv = Variable::make(Int(32), v->name, domain); rhs = substitute(v->name, rv, rhs); + for (Expr& e : lhs) { + e = substitute(v->name, rv, e); + } } rdom = RDom(domain); } @@ -718,9 +753,12 @@ void translateComprehension( } } while (!lhs.empty()) { - loop_nest.push_back(lhs.back()); + if (const Variable* v = lhs.back().as()) { + loop_nest.push_back(Var(v->name)); + } lhs.pop_back(); } + stage.reorder(loop_nest); if (added_implicit_initialization) { // Also reorder reduction initializations to the TC convention @@ -734,7 +772,6 @@ void translateComprehension( } func.compute_root(); - stage.reorder(loop_nest); } HalideComponents translateDef(const lang::Def& def, bool throwWarnings) { diff --git a/tc/lang/parser.h b/tc/lang/parser.h index 4083771f7..ceaffd32f 100644 --- a/tc/lang/parser.h +++ b/tc/lang/parser.h @@ -151,6 +151,15 @@ struct Parser { TreeRef parseExpList() { return parseList('(', ',', ')', [&](int i) { return parseExp(); }); } + TreeRef parseOptionalExpList() { + TreeRef list = nullptr; + if (L.cur().kind == '(') { + list = parseExpList(); + } else { + list = List::create(L.cur().range, {}); + } + return list; + } TreeRef parseIdentList() { return parseList('(', ',', ')', [&](int i) { return parseIdent(); }); } @@ -226,7 +235,7 @@ struct Parser { } TreeRef parseStmt() { auto ident = parseIdent(); - TreeRef list = parseOptionalIdentList(); + TreeRef list = parseOptionalExpList(); auto assign = parseAssignment(); auto rhs = parseExp(); TreeRef equivalent_statement = parseEquivalent(); diff --git a/tc/lang/sema.h b/tc/lang/sema.h index 406a82711..34a9e7732 100644 --- a/tc/lang/sema.h +++ b/tc/lang/sema.h @@ -437,15 +437,35 @@ struct Sema { return checkRangeConstraint(RangeConstraint(ref)); } } + + private: + // Traverse the list of trees, recursively descending into arguments of APPLY + // and ACCESS subtrees and into all subtrees of different types (mostly + // expressions), and collect names and types of IDENT subtrees in + // "index_env". Expects to be called on the indices of the LHS tensor. + template + void registerLHSIndices(const Collection& treeRefs) { + for (const auto& treeRef : treeRefs) { + if (treeRef->kind() == TK_IDENT) { + std::string idx = Ident(treeRef).name(); + auto typ = indexType(treeRef); + insert(index_env, Ident(treeRef), typ, true); + } else if (treeRef->kind() == TK_APPLY) { + registerLHSIndices(Apply(treeRef).arguments()); + } else if (treeRef->kind() == TK_ACCESS) { + registerLHSIndices(Access(treeRef).arguments()); + } else { + registerLHSIndices(treeRef->trees()); + } + } + } + + public: TreeRef checkStmt(TreeRef stmt_) { auto stmt = Comprehension(stmt_); // register index variables (non-reductions) - for (const auto& index : stmt.indices()) { - std::string idx = index.name(); - auto typ = indexType(index); - insert(index_env, index, typ, true); - } + registerLHSIndices(stmt.indices()); // make dimension variables for each dimension of the output tensor std::string name = stmt.ident().name(); @@ -462,9 +482,13 @@ struct Sema { // where clauses are checked _before_ the rhs because they // introduce let bindings that are in scope for the rhs + // auto where_clauses_ = stmt.whereClauses().map( [&](TreeRef rc) { return checkWhereClause(rc); }); + auto indices_ = + stmt.indices().map([&](TreeRef idx) { return checkExp(idx, true); }); + TreeRef rhs_ = checkExp(stmt.rhs(), true); TreeRef scalar_type = typeOfExpr(rhs_); @@ -525,7 +549,7 @@ struct Sema { TreeRef result = Comprehension::create( stmt.range(), stmt.ident(), - stmt.indices(), + indices_, stmt.assignment(), rhs_, where_clauses_, diff --git a/tc/lang/tc_format.cc b/tc/lang/tc_format.cc index 8f1fbe8f1..55457d55a 100644 --- a/tc/lang/tc_format.cc +++ b/tc/lang/tc_format.cc @@ -60,8 +60,9 @@ std::ostream& operator<<(std::ostream& s, const Param& p) { } std::ostream& operator<<(std::ostream& s, const Comprehension& comp) { - s << comp.ident() << "(" << comp.indices() << ") " - << kindToToken(comp.assignment()->kind()) << " "; + s << comp.ident() << "("; + showList(s, comp.indices(), showExpr); + s << ") " << kindToToken(comp.assignment()->kind()) << " "; showExpr(s, comp.rhs()); if (!comp.whereClauses().empty()) throw std::runtime_error("Printing of where clauses is not supported yet"); diff --git a/tc/lang/tree_views.h b/tc/lang/tree_views.h index 1e26b8437..099b4c458 100644 --- a/tc/lang/tree_views.h +++ b/tc/lang/tree_views.h @@ -386,8 +386,8 @@ struct Comprehension : public TreeView { Ident ident() const { return Ident(subtree(0)); } - ListView indices() const { - return ListView(subtree(1)); + ListView indices() const { + return ListView(subtree(1)); } // kind == '=', TK_PLUS_EQ, TK_PLUS_EQ_B, etc. TreeRef assignment() const { diff --git a/test/cuda/test_execution_engine.cc b/test/cuda/test_execution_engine.cc index cd508ae8c..3ca4558eb 100644 --- a/test/cuda/test_execution_engine.cc +++ b/test/cuda/test_execution_engine.cc @@ -145,6 +145,25 @@ def concat(float(M, N) A, float(M, N) B) -> (O1) { outputs); } +TEST_F(ATenCompilationUnitTest, Concat2) { + at::Tensor a = at::CUDA(at::kFloat).rand({32, 16}); + at::Tensor b = at::CUDA(at::kFloat).rand({32, 16}); + std::vector inputs = {a, b}; + std::vector outputs; + + Check( + R"( +def concat(float(M, N) A, float(M, N) B) -> (O1) { + O1(n, 0, m) = A(m, n) + O1(n, 1, m) = B(m, n) +} + )", + "concat", + tc::CudaMappingOptions::makeNaiveCudaMappingOptions(), + inputs, + outputs); +} + TEST_F(ATenCompilationUnitTest, Indexing) { at::Tensor a = at::CUDA(at::kFloat).rand({3, 4}); at::Tensor b = at::CUDA(at::kInt).ones({2}); diff --git a/test/cuda/test_tc_mapper.cc b/test/cuda/test_tc_mapper.cc index d005e4db3..6a81e9055 100644 --- a/test/cuda/test_tc_mapper.cc +++ b/test/cuda/test_tc_mapper.cc @@ -352,6 +352,49 @@ def fun(float(B, R) LUT, int32(B, N) I) -> (O) { checkFun); } +TEST_F(TcCudaMapperTest, Histogram) { + const int N = 17, M = 82; + at::Tensor I = + at::CUDA(at::kFloat).rand({N, M}).mul_(256).floor_().toType(at::kByte); + std::vector inputs = {I}; + std::vector outputs; + + static constexpr auto TC = R"TC( +def fun(uint8(N, M) I) -> (O) { + O(I(i, j)) +=! 1 +} +)TC"; + + auto checkFun = [=](const std::vector& inputs, + std::vector& outputs) { + at::Tensor I = inputs[0].toBackend(at::kCPU); + at::Tensor O = outputs[0].toBackend(at::kCPU); + auto IAccessor = I.accessor(); + auto OAccessor = O.accessor(); + int sum = 0; + for (int i = 0; i < 256; i++) { + sum += OAccessor[i]; + } + CHECK_EQ(sum, N * M); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < M; j++) { + OAccessor[IAccessor[i][j]]--; + } + } + + for (int i = 0; i < 256; i++) { + CHECK_EQ(OAccessor[i], 0); + } + }; + Check( + TC, + "fun", + tc::CudaMappingOptions::makeNaiveCudaMappingOptions(), + inputs, + checkFun); +} + /////////////////////////////////////////////////////////////////////////////// // SpatialBatchNormalization /////////////////////////////////////////////////////////////////////////////// diff --git a/test/test_core.cc b/test/test_core.cc index 022ef1020..2896e090b 100644 --- a/test/test_core.cc +++ b/test/test_core.cc @@ -165,6 +165,10 @@ struct TC2Isl : public ::testing::Test { auto scheduleHalide = polyhedral::detail::fromIslSchedule( polyhedral::detail::toIslSchedule(scop->scheduleRoot()).reset_user()); } + + std::unique_ptr MakeScop(const std::string& tc) { + return polyhedral::Scop::makeScop(isl::with_exceptions::globalIslCtx(), tc); + } }; TEST_F(TC2Isl, Copy1D) { @@ -313,6 +317,69 @@ def fun(float(M, N) I) -> (O1, O2, O3) { Check(tc, {123, 13}); } +// FIXME: range inference seems unaware of indirections on the LHS +TEST_F(TC2Isl, DISABLED_MayWritesOnly) { + string tc = R"TC( +def scatter(int32(N) A, int32(M) B) -> (O) { + O(A(i)) = B(i) +} +)TC"; + auto scop = MakeScop(tc); + CHECK(scop->mustWrites.is_empty()) + << "expected empty must-writes for scatter, got\n" + << scop->mustWrites; + CHECK(!scop->mayWrites.is_empty()) + << "expected non-empty may-writes for scatter, got\n" + << scop->mayWrites; +} + +TEST_F(TC2Isl, AllMustWrites) { + string tc = R"TC( +def gather(int32(N) A, int32(N) B) -> (O) { + O(i) = A(B(i)) where i in 0:N +} +)TC"; + auto scop = MakeScop(tc); + CHECK(!scop->mustWrites.is_empty()) + << "expected non-empty must-writes for gather, got\n" + << scop->mustWrites; + CHECK_EQ(scop->mustWrites, scop->mayWrites); +} + +TEST_F(TC2Isl, Computed) { + string tc = R"TC( +def gather(int32(N) A, int32(N) B) -> (O) { + O(i - 2) = A(i) + B(i) +} +)TC"; + auto scop = MakeScop(tc); + CHECK(!scop->mustWrites.is_empty()) + << "expected non-empty must-writes for gather, got\n" + << scop->mustWrites; + CHECK_EQ(scop->mustWrites, scop->mayWrites); +} + +// FIXME: range inference seems unaware of indirections on the LHS +TEST_F(TC2Isl, DISABLED_MustWritesSubsetMayWrites) { + string tc = R"TC( +def scatter_gather(int32(N) A, int32(N) B) -> (O1,O2) { + O1(i) = A(B(i)) where i in 0:N + O2(A(i)) = B(i) +} +)TC"; + auto scop = MakeScop(tc); + CHECK(!scop->mustWrites.is_empty()) << "expected non-empty must-writes, got\n" + << scop->mustWrites; + CHECK(!scop->mayWrites.is_empty()) << "expected non-empty may-writes, got\n" + << scop->mustWrites; + CHECK(scop->mustWrites.is_subset(scop->mayWrites)) + << scop->mustWrites << " is expected to be a subsetset of " + << scop->mayWrites; + CHECK(!scop->mayWrites.subtract(scop->mustWrites).is_empty()) + << scop->mustWrites << "is expected to be a strict subset of " + << scop->mayWrites; +} + int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); ::gflags::ParseCommandLineFlags(&argc, &argv, true);