diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp index 911452931e..791fa99f63 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp @@ -52,13 +52,15 @@ namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; +using dpctl::tensor::type_utils::vec_cast; template struct AbsFunctor { using is_constant = typename std::false_type; // constexpr resT constant_value = resT{}; - using supports_vec = typename std::false_type; + using supports_vec = typename std::negation< + std::disjunction, is_complex>>; using supports_sg_loadstore = typename std::negation< std::disjunction, is_complex>>; @@ -87,6 +89,40 @@ template struct AbsFunctor } } + template + sycl::vec operator()(const sycl::vec &in) + { + if constexpr (std::is_integral::value) { + if constexpr (std::is_same_v || + std::is_unsigned::value) { + return in; + } + else { + auto const &res_vec = sycl::abs(in); + using deducedT = typename std::remove_cv_t< + std::remove_reference_t>::element_type; + if constexpr (std::is_same_v) { + return res_vec; + } + else { + + return vec_cast(res_vec); + } + } + } + else { + auto const &res_vec = sycl::fabs(in); + using deducedT = typename std::remove_cv_t< + std::remove_reference_t>::element_type; + if constexpr (std::is_same_v) { + return res_vec; + } + else { + return vec_cast(res_vec); + } + } + } + private: template realT cabs(std::complex const &z) const { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp index c8cd8ef18c..8c05a8a4fd 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp @@ -50,6 +50,7 @@ namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; +using dpctl::tensor::type_utils::vec_cast; template struct CosFunctor { @@ -59,7 +60,8 @@ template struct CosFunctor // constant value, if constant // constexpr resT constant_value = resT{}; // is function defined for sycl::vec - using supports_vec = typename std::false_type; + using supports_vec = typename std::negation< + std::disjunction, is_complex>>; // do both argTy and resTy support sugroup store/load operation using supports_sg_loadstore = typename std::negation< std::disjunction, is_complex>>; @@ -165,6 +167,20 @@ template struct CosFunctor return std::cos(in); } } + + template + sycl::vec operator()(const sycl::vec &in) + { + auto const &res_vec = sycl::cos(in); + using deducedT = typename std::remove_cv_t< + std::remove_reference_t>::element_type; + if constexpr (std::is_same_v) { + return res_vec; + } + else { + return vec_cast(res_vec); + } + } }; template struct Expm1Functor { @@ -60,7 +61,8 @@ template struct Expm1Functor // constant value, if constant // constexpr resT constant_value = resT{}; // is function defined for sycl::vec - using supports_vec = typename std::false_type; + using supports_vec = typename std::negation< + std::disjunction, is_complex>>; // do both argTy and resTy support sugroup store/load operation using supports_sg_loadstore = typename std::negation< std::disjunction, is_complex>>; @@ -132,6 +134,20 @@ template struct Expm1Functor return std::expm1(in); } } + + template + sycl::vec operator()(const sycl::vec &in) + { + auto const &res_vec = sycl::expm1(in); + using deducedT = typename std::remove_cv_t< + std::remove_reference_t>::element_type; + if constexpr (std::is_same_v) { + return res_vec; + } + else { + return vec_cast(res_vec); + } + } }; template struct LogFunctor { @@ -60,7 +61,8 @@ template struct LogFunctor // constant value, if constant // constexpr resT constant_value = resT{}; // is function defined for sycl::vec - using supports_vec = typename std::false_type; + using supports_vec = typename std::negation< + std::disjunction, is_complex>>; // do both argTy and resTy support sugroup store/load operation using supports_sg_loadstore = typename std::negation< std::disjunction, is_complex>>; @@ -79,6 +81,20 @@ template struct LogFunctor return std::log(in); } } + + template + sycl::vec operator()(const sycl::vec &in) + { + auto const &res_vec = sycl::log(in); + using deducedT = typename std::remove_cv_t< + std::remove_reference_t>::element_type; + if constexpr (std::is_same_v) { + return res_vec; + } + else { + return vec_cast(res_vec); + } + } }; template struct Log1pFunctor @@ -60,7 +61,8 @@ template struct Log1pFunctor // constant value, if constant // constexpr resT constant_value = resT{}; // is function defined for sycl::vec - using supports_vec = typename std::false_type; + using supports_vec = typename std::negation< + std::disjunction, is_complex>>; // do both argTy and resTy support sugroup store/load operation using supports_sg_loadstore = typename std::negation< std::disjunction, is_complex>>; @@ -99,6 +101,20 @@ template struct Log1pFunctor return std::log1p(in); } } + + template + sycl::vec operator()(const sycl::vec &in) + { + auto const &res_vec = sycl::log1p(in); + using deducedT = typename std::remove_cv_t< + std::remove_reference_t>::element_type; + if constexpr (std::is_same_v) { + return res_vec; + } + else { + return vec_cast(res_vec); + } + } }; template struct SinFunctor { @@ -58,7 +59,8 @@ template struct SinFunctor // constant value, if constant // constexpr resT constant_value = resT{}; // is function defined for sycl::vec - using supports_vec = typename std::false_type; + using supports_vec = typename std::negation< + std::disjunction, is_complex>>; // do both argTy and resTy support sugroup store/load operation using supports_sg_loadstore = typename std::negation< std::disjunction, is_complex>>; @@ -181,6 +183,20 @@ template struct SinFunctor return std::sin(in); } } + + template + sycl::vec operator()(const sycl::vec &in) + { + auto const &res_vec = sycl::sin(in); + using deducedT = typename std::remove_cv_t< + std::remove_reference_t>::element_type; + if constexpr (std::is_same_v) { + return res_vec; + } + else { + return vec_cast(res_vec); + } + } }; template struct SqrtFunctor { @@ -62,7 +63,8 @@ template struct SqrtFunctor // constant value, if constant // constexpr resT constant_value = resT{}; // is function defined for sycl::vec - using supports_vec = typename std::false_type; + using supports_vec = typename std::negation< + std::disjunction, is_complex>>; // do both argTy and resTy support sugroup store/load operation using supports_sg_loadstore = typename std::negation< std::disjunction, is_complex>>; @@ -95,6 +97,20 @@ template struct SqrtFunctor } } + template + sycl::vec operator()(const sycl::vec &in) + { + auto const &res_vec = sycl::sqrt(in); + using deducedT = typename std::remove_cv_t< + std::remove_reference_t>::element_type; + if constexpr (std::is_same_v) { + return res_vec; + } + else { + return vec_cast(res_vec); + } + } + private: template std::complex csqrt(std::complex const &z) const {