Skip to content

Commit

Permalink
Merge pull request #28 from czgdp1807/lc_19
Browse files Browse the repository at this point in the history
Support `xt::xtensor<...>`
  • Loading branch information
czgdp1807 authored Dec 21, 2023
2 parents 9a60262 + e321e3a commit 9f409f6
Showing 3 changed files with 61 additions and 10 deletions.
2 changes: 2 additions & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -198,3 +198,5 @@ RUN(NAME array_06.cpp LABELS gcc llvm NOFAST
EXTRA_ARGS --extra-arg=-I${CONDA_PREFIX}/include)
RUN(NAME array_07.cpp LABELS gcc llvm NOFAST
EXTRA_ARGS --extra-arg=-I${CONDA_PREFIX}/include)
RUN(NAME array_08.cpp LABELS gcc llvm NOFAST
EXTRA_ARGS --extra-arg=-I${CONDA_PREFIX}/include)
44 changes: 44 additions & 0 deletions integration_tests/array_08.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#include <iostream>
#include <xtensor/xtensor.hpp>
#include "xtensor/xio.hpp"
#include "xtensor/xview.hpp"

xt::xtensor<double, 2> matmul(
xt::xtensor<double, 2>& a,
xt::xtensor<double, 2>& b) {
xt::xtensor<double, 2> ab = xt::empty<double>({a.shape(0), b.shape(1)});
for( int i = 0; i < a.shape(0); i++ ) {
for( int j = 0; j < b.shape(1); j++ ) {
ab(i, j) = 0;
for( int k = 0; k < a.shape(1); k++ ) {
ab(i, j) += a(i, k) * b(k, j);
}
}
}
return ab;
}

int main() {

xt::xtensor<double, 2> a = xt::empty<double>({3, 2});
xt::xtensor<double, 2> b = xt::empty<double>({2, 3});
xt::xtensor<double, 2> c = xt::empty<double>({3, 3});

for( int i = 0; i < a.shape(0); i++ ) {
for( int j = 0; j < a.shape(1); j++ ) {
a(i, j) = i * a.shape(0) * j;
}
}

for( int i = 0; i < b.shape(0); i++ ) {
for( int j = 0; j < b.shape(1); j++ ) {
b(i, j) = i * b.shape(0) * j;
}
}

std::cout << a << b << std::endl;
c = matmul(a, b);
std::cout << c << std::endl;

return 0;
}
25 changes: 15 additions & 10 deletions src/lc/clang_ast_to_asr.h
Original file line number Diff line number Diff line change
@@ -477,10 +477,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
}
tmp = ASR::make_ArrayItem_t(al, Lloc(x), obj,
array_indices.p, array_indices.size(),
ASRUtils::type_get_past_allocatable(
ASRUtils::type_get_past_pointer(
ASRUtils::type_get_past_array(
ASRUtils::expr_type(obj)))),
ASRUtils::extract_type(ASRUtils::expr_type(obj)),
ASR::arraystorageType::RowMajor, nullptr);
} else {
throw std::runtime_error("Only indexing arrays is supported for now with operator().");
@@ -534,6 +531,9 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
if( cxx_operator_name.size() > 0 ) {
func_name = cxx_operator_name;
} else if( member_name.size() > 0 ) {
if( callee == nullptr ) {
throw std::runtime_error("Callee object not available.");
}
func_name = member_name;
} else {
ASR::Var_t* callee_Var = ASR::down_cast<ASR::Var_t>(callee);
@@ -667,8 +667,12 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
}

bool TraverseCallExpr(clang::CallExpr *x) {
cxx_operator_name.clear();
TraverseStmt(x->getCallee());
ASR::expr_t* callee = ASRUtils::EXPR(tmp);
ASR::expr_t* callee = nullptr;
if( tmp != nullptr ) {
callee = ASRUtils::EXPR(tmp);
}
if( check_and_handle_special_function(x, callee) ) {
return true;
}
@@ -843,7 +847,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
assignment_target = var;
TraverseStmt(x->getInit());
assignment_target = nullptr;
if( tmp != nullptr ) {
if( tmp != nullptr && !is_stmt_created ) {
ASR::expr_t* init_val = ASRUtils::EXPR(tmp);
add_reshape_if_needed(init_val, var);
tmp = ASR::make_Assignment_t(al, Lloc(x), var, init_val, nullptr);
@@ -1060,16 +1064,17 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
}

bool TraverseDeclRefExpr(clang::DeclRefExpr* x) {
cxx_operator_name.clear();
std::string name = x->getNameInfo().getAsString();
if( name == "operator<<" || name == "cout" ||
ASR::symbol_t* sym = resolve_symbol(name);
if( sym == nullptr &&
(name == "operator<<" || name == "cout" ||
name == "endl" || name == "operator()" ||
name == "operator+" || name == "operator=" ||
name == "view" || name == "empty" ) {
name == "view" || name == "empty" || name == "printf") ) {
cxx_operator_name = name;
return true;
}
ASR::symbol_t* sym = resolve_symbol(name);
LCOMPILERS_ASSERT(sym != nullptr);
tmp = ASR::make_Var_t(al, Lloc(x), sym);
is_stmt_created = false;
return true;

0 comments on commit 9f409f6

Please sign in to comment.