Skip to content

Commit

Permalink
Merge pull request #102 from czgdp1807/union_01
Browse files Browse the repository at this point in the history
Ported ``integration_tests/union_01.py`` from LPython and added support for union in LC to compile it
  • Loading branch information
czgdp1807 authored Mar 5, 2024
2 parents 0d1fb2e + bf23900 commit 0584a31
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 12 deletions.
2 changes: 2 additions & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,5 @@ RUN(NAME enum_01.cpp LABELS gcc llvm NOFAST)
RUN(NAME enum_02.cpp LABELS gcc llvm NOFAST)
RUN(NAME enum_03.cpp LABELS gcc llvm NOFAST)
RUN(NAME enum_04.cpp LABELS gcc llvm NOFAST)

RUN(NAME union_01.cpp LABELS gcc llvm NOFAST)
37 changes: 37 additions & 0 deletions integration_tests/union_01.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#include <iostream>

union u_type {
int32_t integer32;
float real32;
double real64;
int64_t integer64;
};

#define assert(cond) if( !(cond) ) { \
exit(2); \
} \

void test_union() {
union u_type unionobj;
unionobj.integer32 = 1;
std::cout << unionobj.integer32 << std::endl;
assert( unionobj.integer32 == 1 );

unionobj.real32 = 2.0;
std::cout << unionobj.real32 << std::endl;
assert( abs(unionobj.real32 - 2.0) <= 1e-6 );

unionobj.real64 = 3.5;
std::cout << unionobj.real64 << std::endl;
assert( abs(unionobj.real64 - 3.5) <= 1e-12 );

unionobj.integer64 = 4;
std::cout << unionobj.integer64 << std::endl;
assert( unionobj.integer64 == 4 );
}

int main() {

test_union();

}
48 changes: 38 additions & 10 deletions src/lc/clang_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,12 +472,16 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
}
} else if( clang_type->getTypeClass() == clang::Type::TypeClass::Record ) {
const clang::CXXRecordDecl* record_type = clang_type->getAsCXXRecordDecl();
std::string struct_name = record_type->getNameAsString();
ASR::symbol_t* struct_t = current_scope->resolve_symbol(struct_name);
if( !struct_t ) {
throw std::runtime_error(struct_name + " not defined.");
std::string name = record_type->getNameAsString();
ASR::symbol_t* type_t = current_scope->resolve_symbol(name);
if( !type_t ) {
throw std::runtime_error(name + " not defined.");
}
if( clang_type->isUnionType() ) {
type = ASRUtils::TYPE(ASR::make_Union_t(al, l, type_t));
} else {
type = ASRUtils::TYPE(ASR::make_Struct_t(al, l, type_t));
}
type = ASRUtils::TYPE(ASR::make_Struct_t(al, l, struct_t));
} else {
throw std::runtime_error("clang::QualType not yet supported " +
std::string(clang_type->getTypeClassName()));
Expand Down Expand Up @@ -518,17 +522,24 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit

SymbolTable* parent_scope = current_scope;
current_scope = al.make_new<SymbolTable>(parent_scope);
std::string struct_name = x->getNameAsString();
std::string name = x->getNameAsString();
Vec<char*> field_names; field_names.reserve(al, 1);
for( auto field = x->field_begin(); field != x->field_end(); field++ ) {
clang::FieldDecl* field_decl = *field;
TraverseFieldDecl(field_decl);
field_names.push_back(al, s2c(al, field_decl->getNameAsString()));
}
ASR::symbol_t* struct_t = ASR::down_cast<ASR::symbol_t>(ASR::make_StructType_t(al, Lloc(x), current_scope,
s2c(al, struct_name), nullptr, 0, field_names.p, field_names.size(), ASR::abiType::Source,
ASR::accessType::Public, false, x->isAbstract(), nullptr, 0, nullptr, nullptr));
parent_scope->add_symbol(struct_name, struct_t);
if( x->isUnion() ) {
ASR::symbol_t* union_t = ASR::down_cast<ASR::symbol_t>(ASR::make_UnionType_t(al, Lloc(x), current_scope,
s2c(al, name), nullptr, 0, field_names.p, field_names.size(), ASR::abiType::Source,
ASR::accessType::Public, nullptr, 0, nullptr));
parent_scope->add_symbol(name, union_t);
} else {
ASR::symbol_t* struct_t = ASR::down_cast<ASR::symbol_t>(ASR::make_StructType_t(al, Lloc(x), current_scope,
s2c(al, name), nullptr, 0, field_names.p, field_names.size(), ASR::abiType::Source,
ASR::accessType::Public, false, x->isAbstract(), nullptr, 0, nullptr, nullptr));
parent_scope->add_symbol(name, struct_t);
}
current_scope = parent_scope;
return true;
}
Expand Down Expand Up @@ -734,6 +745,23 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
current_scope->add_symbol(mangled_name, member);
tmp = ASR::make_StructInstanceMember_t(al, Lloc(x), base, member, ASRUtils::symbol_type(member),
evaluate_compile_time_value_for_StructInstanceMember(base, member_name));
} else if( ASR::is_a<ASR::Union_t>(*base_type) ) {
ASR::Union_t* union_t = ASR::down_cast<ASR::Union_t>(base_type);
ASR::UnionType_t* union_type_t = ASR::down_cast<ASR::UnionType_t>(
ASRUtils::symbol_get_past_external(union_t->m_union_type));
ASR::symbol_t* member = union_type_t->m_symtab->resolve_symbol(member_name);
if( !member ) {
throw std::runtime_error(member_name + " not found in the scope of " + union_type_t->m_name);
}
std::string mangled_name = current_scope->get_unique_name(
member_name + "@" + union_type_t->m_name);
member = ASR::down_cast<ASR::symbol_t>(ASR::make_ExternalSymbol_t(
al, Lloc(x), current_scope, s2c(al, mangled_name), member,
union_type_t->m_name, nullptr, 0, s2c(al, member_name),
ASR::accessType::Public));
current_scope->add_symbol(mangled_name, member);
tmp = ASR::make_UnionInstanceMember_t(al, Lloc(x),
base, member, ASRUtils::symbol_type(member), nullptr);
} else if( special_function_map.find(member_name) != special_function_map.end() ) {
member_name_obj.set(member_name);
return clang::RecursiveASTVisitor<ClangASTtoASRVisitor>::TraverseMemberExpr(x);
Expand Down
9 changes: 8 additions & 1 deletion src/libasr/asr_verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,7 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
ASR::Module_t *m = ASRUtils::get_sym_module(x.m_external);
ASR::StructType_t* sm = nullptr;
ASR::EnumType_t* em = nullptr;
ASR::UnionType_t* um = nullptr;
ASR::Function_t* fm = nullptr;
bool is_valid_owner = false;
is_valid_owner = m != nullptr && ((ASR::symbol_t*) m == ASRUtils::get_asr_owner(x.m_external));
Expand All @@ -711,13 +712,17 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
ASR::symbol_t* asr_owner_sym = ASRUtils::get_asr_owner(x.m_external);
is_valid_owner = (ASR::is_a<ASR::StructType_t>(*asr_owner_sym) ||
ASR::is_a<ASR::EnumType_t>(*asr_owner_sym) ||
ASR::is_a<ASR::Function_t>(*asr_owner_sym));
ASR::is_a<ASR::Function_t>(*asr_owner_sym) ||
ASR::is_a<ASR::UnionType_t>(*asr_owner_sym));
if( ASR::is_a<ASR::StructType_t>(*asr_owner_sym) ) {
sm = ASR::down_cast<ASR::StructType_t>(asr_owner_sym);
asr_owner_name = sm->m_name;
} else if( ASR::is_a<ASR::EnumType_t>(*asr_owner_sym) ) {
em = ASR::down_cast<ASR::EnumType_t>(asr_owner_sym);
asr_owner_name = em->m_name;
} else if( ASR::is_a<ASR::UnionType_t>(*asr_owner_sym) ) {
um = ASR::down_cast<ASR::UnionType_t>(asr_owner_sym);
asr_owner_name = um->m_name;
} else if( ASR::is_a<ASR::Function_t>(*asr_owner_sym) ) {
fm = ASR::down_cast<ASR::Function_t>(asr_owner_sym);
asr_owner_name = fm->m_name;
Expand Down Expand Up @@ -746,6 +751,8 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
s = em->m_symtab->resolve_symbol(std::string(x.m_original_name));
} else if( fm ) {
s = fm->m_symtab->resolve_symbol(std::string(x.m_original_name));
} else if( um ) {
s = um->m_symtab->resolve_symbol(std::string(x.m_original_name));
}
require(s != nullptr,
"ExternalSymbol::m_original_name ('"
Expand Down
3 changes: 2 additions & 1 deletion src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1500,7 +1500,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
this->visit_expr(*x.m_v);
ptr_loads = ptr_loads_copy;
llvm::Value* union_llvm = tmp;
ASR::Variable_t* member_var = ASR::down_cast<ASR::Variable_t>(x.m_m);
ASR::Variable_t* member_var = ASR::down_cast<ASR::Variable_t>(
ASRUtils::symbol_get_past_external(x.m_m));
ASR::ttype_t* member_type_asr = ASRUtils::get_contained_type(member_var->m_type);
if( ASR::is_a<ASR::Struct_t>(*member_type_asr) ) {
ASR::Struct_t* d = ASR::down_cast<ASR::Struct_t>(member_type_asr);
Expand Down

0 comments on commit 0584a31

Please sign in to comment.