Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for compile time evaluation of StructInstanceMember #91

Merged
merged 3 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions integration_tests/nbody.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ Source: https://benchmarksgame-team.pages.debian.net/benchmarksgame/program/nbod
#include "xtensor/xio.hpp"
#include "xtensor/xview.hpp"

const int nb = 5;
const double PI = 3.141592653589793;
const double SOLAR_MASS = 4 * PI * PI;
const int N = (nb - 1) * nb/2;
constexpr int nb = 5;
constexpr double PI = 3.141592653589793;
constexpr double SOLAR_MASS = 4 * PI * PI;
constexpr int N = (nb - 1) * nb/2;

void offset_momentum(const int k, xt::xtensor_fixed<double, xt::xshape<3, nb>>& v,
const xt::xtensor_fixed<double, xt::xshape<nb>>& mass) {
Expand Down Expand Up @@ -101,7 +101,7 @@ double energy(const xt::xtensor_fixed<double, xt::xshape<3, nb>>& x,
struct body {
double x, y, z, u, vx, vy, vz, vu, mass;

body(double x_, double y_, double z_, double u_,
constexpr body(double x_, double y_, double z_, double u_,
double vx_, double vy_, double vz_, double vu_,
double mass_) : x{x_}, y{y_}, z{z_}, u{u_}, vx{vx_},
vy{vy_}, vz{vz_}, vu{vu_}, mass{mass_} {
Expand All @@ -112,40 +112,40 @@ struct body {
int main() {

const double tstep = 0.01;
const double DAYS_PER_YEAR = 365.24;
constexpr double DAYS_PER_YEAR = 365.24;

const struct body jupiter = body(
constexpr struct body jupiter = body(
4.84143144246472090, -1.16032004402742839,
-1.03622044471123109e-01, 0.0, 1.66007664274403694e-03 * DAYS_PER_YEAR,
7.69901118419740425e-03 * DAYS_PER_YEAR,
-6.90460016972063023e-05 * DAYS_PER_YEAR, 0.0,
9.54791938424326609e-04 * SOLAR_MASS);

const struct body saturn = body(
constexpr struct body saturn = body(
8.34336671824457987, 4.12479856412430479,
-4.03523417114321381e-01, 0.0,
-2.76742510726862411e-03 * DAYS_PER_YEAR,
4.99852801234917238e-03 * DAYS_PER_YEAR,
2.30417297573763929e-05 * DAYS_PER_YEAR, 0.0,
2.85885980666130812e-04 * SOLAR_MASS);

const struct body uranus = body(
constexpr struct body uranus = body(
1.28943695621391310e+01, -1.51111514016986312e+01,
-2.23307578892655734e-01, 0.0,
2.96460137564761618e-03 * DAYS_PER_YEAR,
2.37847173959480950e-03 * DAYS_PER_YEAR,
-2.96589568540237556e-05 * DAYS_PER_YEAR, 0.0,
4.36624404335156298e-05 * SOLAR_MASS);

const struct body neptune = body(
constexpr struct body neptune = body(
1.53796971148509165e+01, -2.59193146099879641e+01,
1.79258772950371181e-01, 0.0,
2.68067772490389322e-03 * DAYS_PER_YEAR,
1.62824170038242295e-03 * DAYS_PER_YEAR,
-9.51592254519715870e-05 * DAYS_PER_YEAR, 0.0,
5.15138902046611451e-05 * SOLAR_MASS);

const struct body sun = body(0.0, 0.0, 0.0, 0.0, 0.0,
constexpr struct body sun = body(0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, SOLAR_MASS);

xt::xtensor_fixed<double, xt::xshape<nb>> mass = {
Expand Down
98 changes: 96 additions & 2 deletions src/lc/clang_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
Vec<ASR::stmt_t*>* default_stmt;
OneTimeUseBool is_break_stmt_present;
bool enable_fall_through;
std::map<ASR::symbol_t*, std::map<std::string, ASR::expr_t*>> struct2member_inits;

explicit ClangASTtoASRVisitor(clang::ASTContext *Context_,
Allocator& al_, ASR::asr_t*& tu_):
Expand Down Expand Up @@ -481,6 +482,26 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
}

bool TraverseCXXRecordDecl(clang::CXXRecordDecl* x) {
for( auto constructors = x->ctor_begin(); constructors != x->ctor_end(); constructors++ ) {
clang::CXXConstructorDecl* constructor = *constructors;
if( constructor->isTrivial() || constructor->isImplicit() ) {
continue ;
}
for( auto ctor = constructor->init_begin(); ctor != constructor->init_end(); ctor++ ) {
clang::CXXCtorInitializer* ctor_init = *ctor;
clang::Expr* init_expr = ctor_init->getInit();
if( init_expr->getStmtClass() == clang::Stmt::StmtClass::InitListExprClass ) {
init_expr = static_cast<clang::InitListExpr*>(init_expr)->getInit(0);
}
if( init_expr->getStmtClass() != clang::Stmt::StmtClass::ImplicitCastExprClass ||
static_cast<clang::ImplicitCastExpr*>(init_expr)->getSubExpr()->getStmtClass() !=
clang::Stmt::StmtClass::DeclRefExprClass ) {
throw std::runtime_error("Initialisation expression in constructor should "
"only be the argument itself.");
}
}
}

SymbolTable* parent_scope = current_scope;
current_scope = al.make_new<SymbolTable>(parent_scope);
std::string struct_name = x->getNameAsString();
Expand Down Expand Up @@ -576,6 +597,20 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
return true;
}

ASR::expr_t* evaluate_compile_time_value_for_StructInstanceMember(
ASR::expr_t* base, const std::string& member_name) {
if( ASR::is_a<ASR::Var_t>(*base) ) {
ASR::Var_t* var_t = ASR::down_cast<ASR::Var_t>(base);
ASR::symbol_t* v = ASRUtils::symbol_get_past_external(var_t->m_v);
if( struct2member_inits.find(v) == struct2member_inits.end() ) {
return nullptr;
}
return struct2member_inits[v][member_name];
}

return nullptr;
}

bool TraverseMemberExpr(clang::MemberExpr* x) {
TraverseStmt(x->getBase());
ASR::expr_t* base = ASRUtils::EXPR(tmp.get());
Expand All @@ -596,8 +631,8 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
struct_type_t->m_name, nullptr, 0, s2c(al, member_name),
ASR::accessType::Public));
current_scope->add_symbol(mangled_name, member);
tmp = ASR::make_StructInstanceMember_t(al, Lloc(x), base, member,
ASRUtils::symbol_type(member), nullptr);
tmp = ASR::make_StructInstanceMember_t(al, Lloc(x), base, member, ASRUtils::symbol_type(member),
evaluate_compile_time_value_for_StructInstanceMember(base, member_name));
} else if( special_function_map.find(member_name) != special_function_map.end() ) {
member_name_obj.set(member_name);
return clang::RecursiveASTVisitor<ClangASTtoASRVisitor>::TraverseMemberExpr(x);
Expand Down Expand Up @@ -1308,6 +1343,13 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
}

bool TraverseCXXTemporaryObjectExpr(clang::CXXTemporaryObjectExpr *x) {
if( !x->getConstructor()->isConstexpr() ) {
throw std::runtime_error("Constructors for user-define types "
"must be defined with constexpr.");
}
if( static_cast<clang::CompoundStmt*>(x->getConstructor()->getBody())->size() > 0 ) {
throw std::runtime_error("Constructor for user-defined must have empty body.");
}
std::string type_name = x->getConstructor()->getNameAsString();
ASR::symbol_t* s = current_scope->resolve_symbol(type_name);
if( s == nullptr ) {
Expand All @@ -1322,6 +1364,11 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
ASR::call_arg_t call_arg;
call_arg.loc = Lloc(x);
call_arg.m_value = ASRUtils::EXPR(tmp.get());
if( !ASRUtils::is_value_constant(ASRUtils::expr_value(call_arg.m_value)) ) {
throw std::runtime_error("Constructor for user-defined types "
"must be initialised with constant values, " + std::to_string(i) +
"-th argument is not a constant.");
}
ASR::ttype_t* orig_type = ASRUtils::symbol_type(
struct_type_t->m_symtab->resolve_symbol(struct_type_t->m_members[i]));
ASR::ttype_t* arg_type = ASRUtils::expr_type(call_arg.m_value);
Expand All @@ -1338,6 +1385,48 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
return true;
}

void TraverseAPValue(clang::APValue& field) {
Location loc;
loc.first = 1; loc.last = 1;
switch( field.getKind() ) {
case clang::APValue::Int: {
tmp = ASR::make_IntegerConstant_t(al, loc, field.getInt().getLimitedValue(),
ASRUtils::TYPE(ASR::make_Integer_t(al, loc, field.getInt().getBitWidth()/8)));
break;
}
case clang::APValue::Float: {
tmp = ASR::make_RealConstant_t(al, loc, field.getFloat().convertToDouble(),
ASRUtils::TYPE(ASR::make_Real_t(al, loc, 8)));
break;
}
default: {
throw std::runtime_error("APValue not supported for clang::APValue::" +
std::to_string(field.getKind()));
}
}
}

void evaluate_compile_time_value_for_Var(clang::APValue* ap_value, ASR::symbol_t* v) {
switch( ap_value->getKind() ) {
case clang::APValue::Struct: {
ASR::ttype_t* v_type = ASRUtils::type_get_past_const(ASRUtils::symbol_type(v));
if( !ASR::is_a<ASR::Struct_t>(*v_type) ) {
throw std::runtime_error("Expected ASR::Struct_t type found, " +
ASRUtils::type_to_str(v_type));
}
ASR::Struct_t* struct_t = ASR::down_cast<ASR::Struct_t>(v_type);
ASR::StructType_t* struct_type_t = ASR::down_cast<ASR::StructType_t>(
ASRUtils::symbol_get_past_external(struct_t->m_derived_type));
for( size_t i = 0; i < ap_value->getStructNumFields(); i++ ) {
clang::APValue& field = ap_value->getStructField(i);
TraverseAPValue(field);
struct2member_inits[v][struct_type_t->m_members[i]] = ASRUtils::EXPR(tmp.get());
}
break;
}
}
}

bool TraverseVarDecl(clang::VarDecl *x) {
std::string name = x->getName().str();
if( scopes.size() > 0 ) {
Expand Down Expand Up @@ -1385,6 +1474,11 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
is_stmt_created = true;
}
}

if( x->getEvaluatedValue() ) {
clang::APValue* ap_value = x->getEvaluatedValue();
evaluate_compile_time_value_for_Var(ap_value, v);
}
}
return true;
}
Expand Down
Loading