Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Allow computed expressions on the left-hand-side #262

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 65 additions & 20 deletions tc/core/halide2isl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include <algorithm>
#include <numeric>
#include <tuple>
#include <unordered_set>

#include "tc/core/constants.h"
Expand Down Expand Up @@ -238,7 +239,20 @@ isl::set makeParamContext(isl::ctx ctx, const SymbolTable& symbolTable) {
return context;
}

isl::map extractAccess(
// Extract a tagged affine access relation from Halide IR.
// The relation is tagged with a unique identifier, i.e. it lives in the space
// [D[...] -> __tc_ref_#[]] -> A[]
// where # is a unique sequential number, D is the statement identifier
// extracted from "domain" and A is the tensor identifier constructed from
// "tensor". The "accesses" map is updated to keep track of the Halide IR nodes
// in which a particular reference # appeared.
// Returns the access relation and a flag indicating whether this relation is
// exact or not. The relation is overapproximated (that is, not exact) if it
// represents a non-affine access, for example, an access with indirection such
// as O(Index(i)) = 42. In such an overapproximated access relation,
// dimensions that correspond to affine subscripts are still exact while those
// that correspond to non-affine subscripts are not constrained.
std::pair<isl::map, bool> extractAccess(
isl::set domain,
const IRNode* op,
const std::string& tensor,
Expand Down Expand Up @@ -267,6 +281,7 @@ isl::map extractAccess(
isl::map map =
isl::map::universe(domainSpace.map_from_domain_and_range(rangeSpace));

bool exact = true;
for (size_t i = 0; i < args.size(); i++) {
// Then add one equality constraint per dimension to encode the
// point in the allocation actually read/written for each point in
Expand All @@ -278,47 +293,64 @@ isl::map extractAccess(
isl::pw_aff(isl::local_space(rangeSpace), isl::dim_type::set, i);
// ... equals the coordinate accessed as a function of the domain.
auto domainPoint = halide2isl::makeIslAffFromExpr(domainSpace, args[i]);
if (!domainPoint.is_null()) {
if (!domainPoint) {
exact = false;
} else {
map = map.intersect(isl::pw_aff(domainPoint).eq_map(rangePoint));
}
}

return map;
return std::make_pair(map, exact);
}

std::pair<isl::union_map, isl::union_map>
std::tuple<isl::union_map, isl::union_map, isl::union_map>
extractAccesses(isl::set domain, const Stmt& s, AccessMap* accesses) {
class FindAccesses : public IRGraphVisitor {
using IRGraphVisitor::visit;

void visit(const Call* op) override {
IRGraphVisitor::visit(op);
if (op->call_type == Call::Halide || op->call_type == Call::Image) {
reads = reads.unite(
extractAccess(domain, op, op->name, op->args, accesses));
// Read relations can be safely overapproximated.
isl::map read;
std::tie(read, std::ignore) =
extractAccess(domain, op, op->name, op->args, accesses);
reads = reads.unite(read);
}
}

void visit(const Provide* op) override {
IRGraphVisitor::visit(op);
writes =
writes.unite(extractAccess(domain, op, op->name, op->args, accesses));

// If the write access relation is not exact, we consider that any
// element _may_ be written by the statement. If it is exact, then we
// can guarantee that all the elements specified by the relation _must_
// be written and any previously stored value will be killed.
isl::map write;
bool exact;
std::tie(write, exact) =
extractAccess(domain, op, op->name, op->args, accesses);
if (exact) {
mustWrites = mustWrites.unite(write);
}
mayWrites = mayWrites.unite(write);
}

const isl::set& domain;
AccessMap* accesses;

public:
isl::union_map reads, writes;
isl::union_map reads, mayWrites, mustWrites;

FindAccesses(const isl::set& domain, AccessMap* accesses)
: domain(domain),
accesses(accesses),
reads(isl::union_map::empty(domain.get_space())),
writes(isl::union_map::empty(domain.get_space())) {}
mayWrites(isl::union_map::empty(domain.get_space())),
mustWrites(isl::union_map::empty(domain.get_space())) {}
} finder(domain, accesses);
s.accept(&finder);
return {finder.reads, finder.writes};
return std::make_tuple(finder.reads, finder.mayWrites, finder.mustWrites);
}

/*
Expand All @@ -343,7 +375,8 @@ isl::schedule makeScheduleTreeHelper(
isl::set set,
std::vector<std::string>& outer,
isl::union_map* reads,
isl::union_map* writes,
isl::union_map* mayWrites,
isl::union_map* mustWrites,
AccessMap* accesses,
StatementMap* statements,
IteratorMap* iterators) {
Expand Down Expand Up @@ -389,7 +422,8 @@ isl::schedule makeScheduleTreeHelper(
set,
outerNext,
reads,
writes,
mayWrites,
mustWrites,
accesses,
statements,
iterators);
Expand Down Expand Up @@ -422,7 +456,15 @@ isl::schedule makeScheduleTreeHelper(
std::vector<isl::schedule> schedules;
for (Stmt s : stmts) {
schedules.push_back(makeScheduleTreeHelper(
s, set, outer, reads, writes, accesses, statements, iterators));
s,
set,
outer,
reads,
mayWrites,
mustWrites,
accesses,
statements,
iterators));
}
schedule = schedules[0].sequence(schedules[1]);

Expand All @@ -437,23 +479,25 @@ isl::schedule makeScheduleTreeHelper(
isl::set domain = set.set_tuple_id(id);
schedule = isl::schedule::from_domain(domain);

isl::union_map newReads, newWrites;
std::tie(newReads, newWrites) =
isl::union_map newReads, newMayWrites, newMustWrites;
std::tie(newReads, newMayWrites, newMustWrites) =
halide2isl::extractAccesses(domain, op, accesses);

*reads = reads->unite(newReads);
*writes = writes->unite(newWrites);
*mayWrites = mayWrites->unite(newMayWrites);
*mustWrites = mustWrites->unite(newMustWrites);

} else {
LOG(FATAL) << "Unhandled Halide stmt: " << s;
}
return schedule;
};
}

ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
ScheduleTreeAndAccesses result;

result.writes = result.reads = isl::union_map::empty(paramSpace);
result.mayWrites = result.mustWrites = result.reads =
isl::union_map::empty(paramSpace);

// Walk the IR building a schedule tree
std::vector<std::string> outer;
Expand All @@ -462,7 +506,8 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
isl::set::universe(paramSpace),
outer,
&result.reads,
&result.writes,
&result.mayWrites,
&result.mustWrites,
&result.accesses,
&result.statements,
&result.iterators);
Expand Down
2 changes: 1 addition & 1 deletion tc/core/halide2isl.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ struct ScheduleTreeAndAccesses {
/// Union maps describing the reads and writes done. Uses the ids in
/// the schedule tree to denote the containing Stmt, and tags each
/// access with a unique reference id of the form __tc_ref_N.
isl::union_map reads, writes;
isl::union_map reads, mayWrites, mustWrites;

/// The correspondence between Call and Provide nodes and the
/// reference ids in the reads and writes maps.
Expand Down
22 changes: 14 additions & 8 deletions tc/core/libraries.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,14 @@ namespace c {

constexpr auto types = R"C(
// Halide type handling
typedef int int32;
typedef long int64;
typedef signed char int8;
typedef unsigned char uint8;
typedef signed short int16;
typedef unsigned short uint16;
typedef signed int int32;
typedef unsigned int uint32;
typedef signed long int64;
typedef unsigned long uint64;
typedef float float32;
typedef double float64;
)C";
Expand Down Expand Up @@ -81,16 +87,16 @@ float fmodf ( float x, float y );
//float frexpf ( float x, int* nptr );
float hypotf ( float x, float y );
//int ilogbf ( float x );
//__RETURN_TYPE isfinite ( float a );
//__RETURN_TYPE isinf ( float a );
//__RETURN_TYPE isnan ( float a );
//__RETURN_TYPE isfinite ( float a );
//__RETURN_TYPE isinf ( float a );
//__RETURN_TYPE isnan ( float a );
float j0f ( float x );
float j1f ( float x );
//float jnf ( int n, float x );
//float ldexpf ( float x, int exp );
float lgammaf ( float x );
//long long int llrintf ( float x );
//long long int llroundf ( float x );
//long long int llrintf ( float x );
//long long int llroundf ( float x );
float log10f ( float x );
float log1pf ( float x );
float log2f ( float x );
Expand Down Expand Up @@ -120,7 +126,7 @@ float roundf ( float x );
float rsqrtf ( float x );
//float scalblnf ( float x, long int n );
//float scalbnf ( float x, int n );
//__RETURN_TYPE signbit ( float a );
//__RETURN_TYPE signbit ( float a );
//void sincosf ( float x, float* sptr, float* cptr );
//void sincospif ( float x, float* sptr, float* cptr );
float sinf ( float x );
Expand Down
2 changes: 1 addition & 1 deletion tc/core/polyhedral/memory_promotion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ TensorGroups TensorReferenceGroup::accessedBySubtree(
auto schedule = partialSchedule(scop.scheduleRoot(), tree);

addSingletonReferenceGroups(
tensorGroups, scop.writes, domain, schedule, AccessType::Write);
tensorGroups, scop.mayWrites, domain, schedule, AccessType::Write);
addSingletonReferenceGroups(
tensorGroups, scop.reads, domain, schedule, AccessType::Read);

Expand Down
34 changes: 25 additions & 9 deletions tc/core/polyhedral/scop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ ScopUPtr Scop::makeScop(
auto tree = halide2isl::makeScheduleTree(paramSpace, components.stmt);
scop->scheduleTreeUPtr = std::move(tree.tree);
scop->reads = tree.reads;
scop->writes = tree.writes;
scop->mayWrites = tree.mayWrites;
scop->mustWrites = tree.mustWrites;
scop->halide.statements = std::move(tree.statements);
scop->halide.accesses = std::move(tree.accesses);
scop->halide.reductions = halide2isl::findReductions(components.stmt);
Expand Down Expand Up @@ -109,7 +110,8 @@ const isl::union_set Scop::domain() const {
std::ostream& operator<<(std::ostream& os, const Scop& s) {
os << "domain: " << s.domain() << "\n";
os << "reads: " << s.reads << "\n";
os << "writes: " << s.writes << "\n";
os << "mayWrites: " << s.mayWrites << "\n";
os << "mustWrites: " << s.mustWrites << "\n";
os << "schedule: " << *s.scheduleRoot() << "\n";
os << "idx: { ";
for (auto i : s.halide.idx) {
Expand Down Expand Up @@ -351,19 +353,29 @@ namespace {

using namespace tc::polyhedral;

// Compute the dependence using the given may/must sources, sinks and kills.
// Any of the inputs may be an empty (but non-null) union map.
// Dependence analysis removes the cases transitively covered by a must source
// or a kill.
isl::union_map computeDependences(
isl::union_map sources,
isl::union_map maySources,
isl::union_map mustSources,
isl::union_map sinks,
isl::union_map kills,
isl::schedule schedule) {
auto uai = isl::union_access_info(sinks);
uai = uai.set_may_source(sources);
uai = uai.set_may_source(maySources);
uai = uai.set_must_source(mustSources);
uai = uai.set_kill(kills);
uai = uai.set_schedule(schedule);
auto flow = uai.compute_flow();
return flow.get_may_dependence();
}

// Do the simplest possible dependence analysis.
// Live-range reordering needs tagged access relations to be available.
// Set up schedule constraints by performing the dependence analysis using
// access relations from "scop". Set up callbacks in the constraints depending
// on "scheduleOptions".
//
// The domain of the constraints is intersected with "restrictDomain" if it is
// provided.
isl::schedule_constraints makeScheduleConstraints(
Expand All @@ -373,12 +385,16 @@ isl::schedule_constraints makeScheduleConstraints(
auto schedule = toIslSchedule(scop.scheduleRoot());
auto firstChildNode = scop.scheduleRoot()->child({0});
auto reads = scop.reads.domain_factor_domain();
auto writes = scop.writes.domain_factor_domain();
auto mayWrites = scop.mayWrites.domain_factor_domain();
auto mustWrites = scop.mustWrites.domain_factor_domain();
auto empty = isl::union_map::empty(mustWrites.get_space());

// RAW
auto flowDeps = computeDependences(writes, reads, schedule);
auto flowDeps =
computeDependences(mayWrites, mustWrites, reads, mustWrites, schedule);
// WAR and WAW
auto falseDeps = computeDependences(writes.unite(reads), writes, schedule);
auto falseDeps = computeDependences(
mayWrites.unite(reads), empty, mayWrites, mustWrites, schedule);

auto allDeps = flowDeps.unite(falseDeps).coalesce();

Expand Down
14 changes: 11 additions & 3 deletions tc/core/polyhedral/scop.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ struct Scop {
res->globalParameterContext = scop.globalParameterContext;
res->halide = scop.halide;
res->reads = scop.reads;
res->writes = scop.writes;
res->mayWrites = scop.mayWrites;
res->mustWrites = scop.mustWrites;
res->scheduleTreeUPtr =
detail::ScheduleTree::makeScheduleTree(*scop.scheduleTreeUPtr);
res->treeSyncUpdateMap = scop.treeSyncUpdateMap;
Expand Down Expand Up @@ -115,7 +116,8 @@ struct Scop {
void specializeToContext() {
domain() = domain().intersect_params(globalParameterContext);
reads = reads.intersect_params(globalParameterContext);
writes = writes.intersect_params(globalParameterContext);
mayWrites = mayWrites.intersect_params(globalParameterContext);
mustWrites = mustWrites.intersect_params(globalParameterContext);
}

// Returns a set that specializes the named scop's subset of
Expand Down Expand Up @@ -442,8 +444,14 @@ struct Scop {
// This globalParameterContext lives in a parameter space.
isl::set globalParameterContext; // TODO: not too happy about this name

// Access relations.
// Elements in mayWrites may or may not be written by the execution, depending
// on some dynamic condition. Those in mustWrites are always written.
// Therefore, mayWrites do not participate in transitively-covered dependence
// removal.
isl::union_map reads;
isl::union_map writes;
isl::union_map mayWrites;
isl::union_map mustWrites;

private:
// By analogy with generalized functions, a ScheduleTree is a (piecewise
Expand Down
Loading