diff --git a/src/cpp.cc b/src/cpp.cc index ec62452d..0470f39d 100644 --- a/src/cpp.cc +++ b/src/cpp.cc @@ -105,6 +105,59 @@ void Printer::Cpp::print(const Statement::For &stmt) { stream << stmt.statements; } +void Printer::Cpp::print(const Statement::SYCL_Host_Accessor_Decl &stmt) { + assert(stmt.name); + + stream << indent() + << "sycl::host_accessor " << stmt.name->name << "_hacc{" + << stmt.name->name << "};" + << endl; +} + + +void Printer::Cpp::print(const Statement::SYCL_Buffer_Decl &stmt) { + assert(stmt.type); + assert(stmt.dimension); + assert(stmt.name); + assert(stmt.value); + + stream << indent() << + "sycl::buffer<" << *stmt.type << "," << stmt.dimension << "> " << + *stmt.name << "(sycl::range<" << stmt.dimension << ">" << + "(" << *stmt.value->name << "));" << endl; +} + +void Printer::Cpp::print(const Statement::SYCL_Accessor_Decl &stmt) { + assert(stmt.variable); + assert(stmt.conext); + + stream << indent() << + "auto " << *stmt.variable->name + "_acc = sycl::accessor{" + << *stmt.variable->name << ", " << *stmt.context->name << ", "; + + if (*stmt.write && *stmt.read) { + stream << "sycl::access_mode::read_write"; + } else if (*stmt.write) { + stream << "sycl::access_mode::write"; + } else { + stream << "sycl::access_mode::read"; + } + + stream << "};" << endl; +} + +void Printer::Cpp::print(const Statement::SYCL_Submit_Kernel &stmt) { + assert(stmt.queue); + assert(stmt.context); + + stream << indent() << + *stmt.queue->name << + ".submit([&]sycl::handler &" << *stmt.context->name << ") "; + + stream << stmt.statements; + stream << ");" << endl; +} + void Printer::Cpp::print(const Statement::While &stmt) { stream << indent() << "while(" << stmt.expr() << ")\n"; stream << stmt.statements; diff --git a/src/cpp.hh b/src/cpp.hh index 60dbe3f8..20a6e903 100644 --- a/src/cpp.hh +++ b/src/cpp.hh @@ -150,6 +150,11 @@ class Cpp : public Base { void print(const Statement::Hash_Decl &stmt); void print(const Statement::Marker_Decl &stmt); + void print(const Statement::SYCL_Accessor_Decl &stmt); + void print(const Statement::SYCL_Submit_Kernel &stmt); + void print(const Statement::SYCL_Buffer_Decl &stmt); + void print(const Statement::SYCL_Host_Accessor_Decl &stmt); + void print(const Fn_Def &fn_def); void print(const Operator &op); diff --git a/src/cyk.cc b/src/cyk.cc index fc61bffa..dce3f5db 100644 --- a/src/cyk.cc +++ b/src/cyk.cc @@ -26,6 +26,7 @@ #include #include #include +#include "statement.hh" static const char *MUTEX = "mutex"; static const char *VARNAME_OuterLoop1 = "outer_loop_1_idx"; @@ -1299,7 +1300,28 @@ Fn_Def *print_CYK(const AST &ast) { } } } - fn_cyk->stmts.push_back(new Statement::CustomCode("#pragma omp parallel")); + + int dimension = 1; + std::string name = "test"; + Statement::Var_Decl *value = new Statement::Var_Decl( + new Type::String, "test_value"); + + fn_cyk->stmts.push_back( + new Statement::SYCL_Buffer_Decl(new Type::Int, dimension, value, value)); + + Statement::Var_Decl *queue = new Statement::Var_Decl( + new Type::External("sycl::queue"), "q"); + + fn_cyk->stmts.push_back(queue); + + Statement::SYCL_Submit_Kernel *blk_sycl = new Statement::SYCL_Submit_Kernel( + queue, new Statement::Var_Decl( + new Type::External("sycl::handler&"), "cgh")); + + bool* test = new bool(true); + blk_sycl->statements.push_back( + new Statement::SYCL_Accessor_Decl(value, value, test, test)); + Statement::Block *blk_parallel = new Statement::Block(); if (ast.checkpoint && ast.checkpoint->cyk) { diff --git a/src/printer.cc b/src/printer.cc index d4777963..e67d7d96 100644 --- a/src/printer.cc +++ b/src/printer.cc @@ -78,6 +78,12 @@ void Printer::Base::print(const Statement::Table_Decl &stmt) {} void Printer::Base::print(const std::list &stmts) {} +void Printer::Base::print(const Statement::SYCL_Buffer_Decl &stmt) {} +void Printer::Base::print(const Statement::SYCL_Submit_Kernel &stmt) {} +void Printer::Base::print(const Statement::SYCL_Accessor_Decl &stmt) {} +void Printer::Base::print(const Statement::SYCL_Host_Accessor_Decl &stmt) {} + + void Printer::Base::print(const Type::List &t) {} void Printer::Base::print(const Type::Tuple &t) {} void Printer::Base::print(const Type::TupleDef &t) {} diff --git a/src/printer.hh b/src/printer.hh index ef3436f4..ed25fd31 100644 --- a/src/printer.hh +++ b/src/printer.hh @@ -124,6 +124,12 @@ class Base { virtual void print(const Statement::Table_Decl &stmt); + virtual void print(const Statement::SYCL_Buffer_Decl &stmt); + virtual void print(const Statement::SYCL_Submit_Kernel &stmt); + virtual void print(const Statement::SYCL_Accessor_Decl &stmt); + virtual void print(const Statement::SYCL_Host_Accessor_Decl &stmt); + + virtual void print(const Expr::Base &); virtual void print(const Type::Base &); virtual void print(const Var_Acc::Base &); diff --git a/src/statement.cc b/src/statement.cc index e911c304..c53efabf 100644 --- a/src/statement.cc +++ b/src/statement.cc @@ -31,6 +31,7 @@ #include "statement.hh" +#include "statement/base.hh" #include "var_acc.hh" #include "expr.hh" @@ -71,6 +72,23 @@ Statement::Var_Decl::Var_Decl(::Type::Base *t, std::string *n, Expr::Base *e) : Base(VAR_DECL), type(t), name(n), rhs(e) { } +Statement::SYCL_Buffer_Decl::SYCL_Buffer_Decl( + ::Type::Base *t, int d, Var_Decl *v, Var_Decl *n) + : Base(VAR_DECL), type(t), dimension(d), value(v), name(n) { +} + +Statement::SYCL_Accessor_Decl::SYCL_Accessor_Decl( + Var_Decl *v, Var_Decl *c, bool *r, bool *w) + : Base(VAR_DECL), variable(v), context(c), read(r), write(w) { +} + +Statement::SYCL_Host_Accessor_Decl::SYCL_Host_Accessor_Decl(Var_Decl *n) + : Base(VAR_DECL), name(n) { + } + +Statement::SYCL_Submit_Kernel::SYCL_Submit_Kernel(Var_Decl *q, Var_Decl *c) + : Block_Base(BLOCK), queue(q), context(c) { + } Statement::Var_Decl *Statement::Var_Decl::clone() const { Var_Decl *ret = new Var_Decl(*this); @@ -82,6 +100,22 @@ Statement::Var_Decl *Statement::Var_Decl::clone() const { } +void Statement::SYCL_Accessor_Decl::print(Printer::Base &p) const { + p.print(*this); +} + +void Statement::SYCL_Buffer_Decl::print(Printer::Base &p) const { +p.print(*this); +} + +void Statement::SYCL_Submit_Kernel::print(Printer::Base &p) const { + p.print(*this); +} + +void Statement::SYCL_Host_Accessor_Decl::print(Printer::Base &p) const { + p.print(*this); +} + void Statement::Var_Decl::print(Printer::Base &p) const { p.print(*this); } diff --git a/src/statement.hh b/src/statement.hh index e20a432d..b468921b 100644 --- a/src/statement.hh +++ b/src/statement.hh @@ -194,6 +194,75 @@ class Increase : public Base { }; +/** + * @brief Declare a host accessor for Variables + * @example sycl::host_accessor result{results}; + * @param name the name of the variable you want a host accessor for + */ +class SYCL_Host_Accessor_Decl : public Base { + public: + Var_Decl *name; + + void print(Printer::Base &p) const; + + explicit SYCL_Host_Accessor_Decl(Var_Decl *n); +}; +/** + * @example auto aResult = sycl::accessor{results, cgh, sycl::read_write}; + * @brief Declare a Accessor for SYCL Kernels + * @param variable Var_Decl name and variable to create accessor to + * @param context Context defaults to cgh + * @param access_mode Access Mode (Read or Write) + */ +class SYCL_Accessor_Decl : public Base { + public: + Var_Decl *variable; + Var_Decl *context; + bool *read; + bool *write; + + void print(Printer::Base &p) const; + + SYCL_Accessor_Decl(Var_Decl *v, Var_Decl *c, bool *r, bool *w); +}; + +/** + * @example q.submit([&](sycl::handler &cgh) { ... }); + * @brief Submit Kernel + * @param queue The Queue to add the Kernel to + * @param context The Context Handler for the Kernel + */ +class SYCL_Submit_Kernel : public Block_Base { + public: + Var_Decl *queue; + Var_Decl *context; + + void print(Printer::Base &p) const; + + SYCL_Submit_Kernel(Var_Decl *q, Var_Decl *c); +}; + +/** + * @example sycl::buffer results(sycl::range<1>(n*m)); + * @brief A function to summarize and create an buffer + * + * @param type What type the buffer should hold + * @param dimension + * @param name + * @param size + */ +class SYCL_Buffer_Decl : public Base { + public: + ::Type::Base *type; + int dimension; + Var_Decl *name; + Var_Decl *value; + + SYCL_Buffer_Decl(::Type::Base *t, int d, Var_Decl *v, Var_Decl *n); + + void print(Printer::Base &p) const; +}; + class Var_Decl : public Base { private: Bool use_as_itr;