Skip to content

Commit

Permalink
DEV: Add support for declaring functions accepting torch::Tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
czgdp1807 committed Mar 22, 2024
1 parent d22ae00 commit 85677ba
Showing 1 changed file with 74 additions and 13 deletions.
87 changes: 74 additions & 13 deletions src/lc/clang_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ enum SpecialFunc {
Data,
Reserve,

TorchOnes
TorchOnes,
TorchEmpty,
};

std::map<std::string, SpecialFunc> special_function_map = {
Expand Down Expand Up @@ -79,7 +80,9 @@ std::map<std::string, SpecialFunc> special_function_map = {
{"clear", SpecialFunc::Clear},
{"data", SpecialFunc::Data},
{"reserve", SpecialFunc::Reserve},

{"torch::ones", SpecialFunc::TorchOnes},
{"torch::empty", SpecialFunc::TorchEmpty},
};

class OneTimeUseString {
Expand Down Expand Up @@ -603,7 +606,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
}
*is_third_party_cpp_array = true;
*array_type = ThirdPartyCPPArrayTypes::PyTorchArray;
type = ASRUtils::TYPE(ASR::make_Array_t(al, l, ASRUtils::TYPE(ASR::make_Real_t(al, l, 4)),
type = ASRUtils::TYPE(ASR::make_Array_t(al, l, ASRUtils::TYPE(ASR::make_Real_t(al, l, 8)),
nullptr, 0, ASR::array_physical_typeType::DescriptorArray));
return type;
} else if( qualified_name == "c10::Scalar" ) {
Expand Down Expand Up @@ -793,7 +796,18 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
if( name == "" ) {
name = current_scope->get_unique_name("param");
}
ASR::ttype_t* type = ClangTypeToASRType(x->getType());

bool is_third_party_array_type = false;
ThirdPartyCPPArrayTypes array_type;
Vec<ASR::dimension_t> shape_result; shape_result.reserve(al, 1);
ASR::ttype_t* type = ClangTypeToASRType(x->getType(), &shape_result,
&array_type, &is_third_party_array_type);
if( is_third_party_array_type &&
array_type == ThirdPartyCPPArrayTypes::PyTorchArray ) {
if( !x->getDefaultArg() ) {
throw std::runtime_error("torch::Tensor type arguments must have default arguments.");
}
}
ASR::intentType intent_type = ASR::intentType::InOut;
if( ASR::is_a<ASR::Const_t>(*type) ) {
intent_type = ASR::intentType::In;
Expand All @@ -806,20 +820,30 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
ASRUtils::is_array(type) ) {
throw std::runtime_error("Array objects should be passed by reference only.");
}
tmp = ASR::make_Variable_t(al, Lloc(x), current_scope, s2c(al, name),
nullptr, 0, intent_type, nullptr, nullptr,
ASR::storage_typeType::Default, type, nullptr, ASR::abiType::Source,
ASR::accessType::Public, ASR::presenceType::Required, false);
ASR::symbol_t* tmp_sym = ASR::down_cast<ASR::symbol_t>(tmp.get());
current_scope->add_symbol(name, tmp_sym);
ASR::asr_t* tmp_ = ASR::make_Var_t(al, Lloc(x), tmp_sym);

clang::Expr *init = x->getDefaultArg();
ASR::expr_t* asr_init = nullptr;
if (init) {
ASR::expr_t* assignment_target_copy = assignment_target;
assignment_target = ASRUtils::EXPR(tmp_);
TraverseStmt(init);
asr_init = ASRUtils::EXPR(tmp.get());
if( tmp != nullptr && !is_stmt_created ) {
asr_init = ASRUtils::EXPR(tmp.get());
}
}

tmp = ASR::make_Variable_t(al, Lloc(x), current_scope, s2c(al, name),
nullptr, 0, ASR::intentType::InOut, asr_init, nullptr,
ASR::storage_typeType::Default, type, nullptr, ASR::abiType::Source,
ASR::accessType::Public, ASR::presenceType::Required, false);
ASR::symbol_t* tmp_sym = ASR::down_cast<ASR::symbol_t>(tmp.get());
current_scope->add_symbol(name, tmp_sym);
tmp = ASR::make_Var_t(al, Lloc(x), tmp_sym);
// TODO: For PyTorch tensor create an intrinsic empty
// and then fill the initialiser value with a call
// to that intrinsic.

tmp = tmp_;
is_stmt_created = false;
return true;
}
Expand Down Expand Up @@ -1340,7 +1364,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit

ASR::expr_t* shape_arg = args.p[0];
ASR::expr_t* one = ASRUtils::get_constant_one_with_given_type(
al, ASRUtils::TYPE(ASR::make_Real_t(al, Lloc(x), 4)));
al, ASRUtils::TYPE(ASR::make_Real_t(al, Lloc(x), 8)));
Vec<ASR::dimension_t> expr_dims; expr_dims.reserve(al, 1);
if( ASR::is_a<ASR::IntegerConstant_t>(*shape_arg) ) {
ASR::dimension_t expr_dim;
Expand All @@ -1350,7 +1374,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
expr_dim.m_length = shape_arg;
expr_dims.push_back(al, expr_dim);
ASR::ttype_t* type = ASRUtils::TYPE(ASR::make_Array_t(al, Lloc(x),
ASRUtils::TYPE(ASR::make_Real_t(al, Lloc(x), 4)), expr_dims.p,
ASRUtils::TYPE(ASR::make_Real_t(al, Lloc(x), 8)), expr_dims.p,
expr_dims.size(), ASR::array_physical_typeType::FixedSizeArray));
int num_ones = ASR::down_cast<ASR::IntegerConstant_t>(shape_arg)->m_n;
Vec<ASR::expr_t*> ones_vec; ones_vec.reserve(al, num_ones);
Expand Down Expand Up @@ -1497,6 +1521,37 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
} else {
throw std::runtime_error("Only {...} is allowed for supplying shape to xt::empty.");
}
} else if (sf == SpecialFunc::TorchEmpty) {
if( args.size() < 1 ) { // Ignore the last two
throw std::runtime_error("torch::empty must be provided with shape.");
}
if( assignment_target != nullptr && ASRUtils::expr_intent(assignment_target) == ASR::intentType::Local) {
throw std::runtime_error("torch::empty isn't handled in assignment statement yet.");
}

if( ASR::is_a<ASR::ArrayConstant_t>(*args.p[0]) ) {
ASR::ArrayConstant_t* array_constant = ASR::down_cast<ASR::ArrayConstant_t>(args.p[0]);

Vec<ASR::dimension_t> empty_dims; empty_dims.reserve(al, array_constant->n_args);
for( size_t idim = 0; idim < array_constant->n_args; idim++ ) {
ASR::dimension_t empty_dim;
empty_dim.loc = Lloc(x);
empty_dim.m_start = ASRUtils::get_constant_zero_with_given_type(
al, ASRUtils::TYPE(ASR::make_Integer_t(al, Lloc(x), 4)));
empty_dim.m_length = nullptr;
empty_dims.push_back(al, empty_dim);
}
ASR::ttype_t* type = ASRUtils::TYPE(ASR::make_Array_t(al, Lloc(x),
ASRUtils::extract_type(ASRUtils::expr_type(assignment_target)),
empty_dims.p, empty_dims.size(), ASR::array_physical_typeType::DescriptorArray));
type = ASRUtils::TYPE(ASR::make_Allocatable_t(al, Lloc(x), type));
ASR::down_cast<ASR::Variable_t>(
ASR::down_cast<ASR::Var_t>(assignment_target)->m_v)->m_type = type;
tmp = nullptr;
is_stmt_created = false;
} else {
throw std::runtime_error("Only {...} is allowed for supplying shape to xt::empty.");
}
} else if (sf == SpecialFunc::Iota) {
tmp = ASR::make_ComplexConstant_t(al, Lloc(x), 0.0, 1.0,
ASRUtils::TYPE(ASR::make_Complex_t(al, Lloc(x), 8)));
Expand Down Expand Up @@ -1932,6 +1987,12 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
ThirdPartyCPPArrayTypes array_type; bool is_third_party_array_type = false;
ASR::ttype_t *asr_type = ClangTypeToASRType(x->getType(), &xshape_result,
&array_type, &is_third_party_array_type);
if( is_third_party_array_type &&
array_type == ThirdPartyCPPArrayTypes::PyTorchArray ) {
if( !x->hasInit() ) {
throw std::runtime_error("torch::Tensor variables must have initialiser value.");
}
}
ASR::symbol_t *v = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(al, Lloc(x),
current_scope, s2c(al, name), nullptr, 0, ASR::intentType::Local, nullptr, nullptr,
ASR::storage_typeType::Default, asr_type, nullptr, ASR::abiType::Source,
Expand Down

0 comments on commit 85677ba

Please sign in to comment.