Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sycl::vec overloads for elementwise functions #1223

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,15 @@ namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

using dpctl::tensor::type_utils::is_complex;
using dpctl::tensor::type_utils::vec_cast;

template <typename argT, typename resT> struct AbsFunctor
{

using is_constant = typename std::false_type;
// constexpr resT constant_value = resT{};
using supports_vec = typename std::false_type;
using supports_vec = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
using supports_sg_loadstore = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;

Expand Down Expand Up @@ -87,6 +89,40 @@ template <typename argT, typename resT> struct AbsFunctor
}
}

// Overload of the call operator for sycl::vec arguments.
//
// Dispatches at compile time on argT:
//   * bool and unsigned integral types: the input vector is returned as is;
//   * remaining integral types: sycl::abs is applied to the vector;
//   * every other type: sycl::fabs is applied to the vector.
// Whenever the element type produced by the SYCL built-in differs from
// resT, the result is converted with vec_cast before being returned.
template <int vec_sz>
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
{
    constexpr bool integral_arg = std::is_integral_v<argT>;
    constexpr bool never_negative =
        std::is_same<argT, bool>::value || std::is_unsigned_v<argT>;

    if constexpr (integral_arg && never_negative) {
        // nothing to compute for bool / unsigned elements
        return in;
    }
    else if constexpr (integral_arg) {
        const auto abs_vec = sycl::abs(in);
        // element type actually produced by sycl::abs
        using elemT = typename std::decay_t<decltype(abs_vec)>::element_type;

        if constexpr (!std::is_same<resT, elemT>::value) {
            return vec_cast<resT, elemT, vec_sz>(abs_vec);
        }
        else {
            return abs_vec;
        }
    }
    else {
        const auto fabs_vec = sycl::fabs(in);
        // element type actually produced by sycl::fabs
        using elemT = typename std::decay_t<decltype(fabs_vec)>::element_type;

        if constexpr (!std::is_same<resT, elemT>::value) {
            return vec_cast<resT, elemT, vec_sz>(fabs_vec);
        }
        else {
            return fabs_vec;
        }
    }
}

private:
template <typename realT> realT cabs(std::complex<realT> const &z) const
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

using dpctl::tensor::type_utils::is_complex;
using dpctl::tensor::type_utils::vec_cast;

template <typename argT, typename resT> struct CosFunctor
{
Expand All @@ -59,7 +60,8 @@ template <typename argT, typename resT> struct CosFunctor
// constant value, if constant
// constexpr resT constant_value = resT{};
// is function defined for sycl::vec
using supports_vec = typename std::false_type;
using supports_vec = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
// do both argTy and resTy support subgroup store/load operation
using supports_sg_loadstore = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
Expand Down Expand Up @@ -165,6 +167,20 @@ template <typename argT, typename resT> struct CosFunctor
return std::cos(in);
}
}

// Overload of the call operator for sycl::vec arguments.
//
// Applies sycl::cos to the whole input vector. If the element type of the
// computed vector already is resT it is returned directly; otherwise it is
// converted to a vector of resT with vec_cast.
template <int vec_sz>
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
{
    const auto cos_vec = sycl::cos(in);
    // element type actually produced by sycl::cos
    using elemT = typename std::decay_t<decltype(cos_vec)>::element_type;

    if constexpr (!std::is_same<resT, elemT>::value) {
        return vec_cast<resT, elemT, vec_sz>(cos_vec);
    }
    else {
        return cos_vec;
    }
}
};

template <typename argTy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

using dpctl::tensor::type_utils::is_complex;
using dpctl::tensor::type_utils::vec_cast;

template <typename argT, typename resT> struct Expm1Functor
{
Expand All @@ -60,7 +61,8 @@ template <typename argT, typename resT> struct Expm1Functor
// constant value, if constant
// constexpr resT constant_value = resT{};
// is function defined for sycl::vec
using supports_vec = typename std::false_type;
using supports_vec = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
// do both argTy and resTy support subgroup store/load operation
using supports_sg_loadstore = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
Expand Down Expand Up @@ -132,6 +134,20 @@ template <typename argT, typename resT> struct Expm1Functor
return std::expm1(in);
}
}

// Overload of the call operator for sycl::vec arguments.
//
// Applies sycl::expm1 to the whole input vector. If the element type of the
// computed vector already is resT it is returned directly; otherwise it is
// converted to a vector of resT with vec_cast.
template <int vec_sz>
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
{
    const auto expm1_vec = sycl::expm1(in);
    // element type actually produced by sycl::expm1
    using elemT = typename std::decay_t<decltype(expm1_vec)>::element_type;

    if constexpr (!std::is_same<resT, elemT>::value) {
        return vec_cast<resT, elemT, vec_sz>(expm1_vec);
    }
    else {
        return expm1_vec;
    }
}
};

template <typename argTy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

using dpctl::tensor::type_utils::is_complex;
using dpctl::tensor::type_utils::vec_cast;

template <typename argT, typename resT> struct LogFunctor
{
Expand All @@ -60,7 +61,8 @@ template <typename argT, typename resT> struct LogFunctor
// constant value, if constant
// constexpr resT constant_value = resT{};
// is function defined for sycl::vec
using supports_vec = typename std::false_type;
using supports_vec = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
// do both argTy and resTy support subgroup store/load operation
using supports_sg_loadstore = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
Expand All @@ -79,6 +81,20 @@ template <typename argT, typename resT> struct LogFunctor
return std::log(in);
}
}

// Overload of the call operator for sycl::vec arguments.
//
// Applies sycl::log to the whole input vector. If the element type of the
// computed vector already is resT it is returned directly; otherwise it is
// converted to a vector of resT with vec_cast.
template <int vec_sz>
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
{
    const auto log_vec = sycl::log(in);
    // element type actually produced by sycl::log
    using elemT = typename std::decay_t<decltype(log_vec)>::element_type;

    if constexpr (!std::is_same<resT, elemT>::value) {
        return vec_cast<resT, elemT, vec_sz>(log_vec);
    }
    else {
        return log_vec;
    }
}
};

template <typename argTy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

using dpctl::tensor::type_utils::is_complex;
using dpctl::tensor::type_utils::vec_cast;

// TODO: evaluate precision against alternatives
template <typename argT, typename resT> struct Log1pFunctor
Expand All @@ -60,7 +61,8 @@ template <typename argT, typename resT> struct Log1pFunctor
// constant value, if constant
// constexpr resT constant_value = resT{};
// is function defined for sycl::vec
using supports_vec = typename std::false_type;
using supports_vec = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
// do both argTy and resTy support subgroup store/load operation
using supports_sg_loadstore = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
Expand Down Expand Up @@ -99,6 +101,20 @@ template <typename argT, typename resT> struct Log1pFunctor
return std::log1p(in);
}
}

// Overload of the call operator for sycl::vec arguments.
//
// Applies sycl::log1p to the whole input vector. If the element type of the
// computed vector already is resT it is returned directly; otherwise it is
// converted to a vector of resT with vec_cast.
template <int vec_sz>
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
{
    const auto log1p_vec = sycl::log1p(in);
    // element type actually produced by sycl::log1p
    using elemT = typename std::decay_t<decltype(log1p_vec)>::element_type;

    if constexpr (!std::is_same<resT, elemT>::value) {
        return vec_cast<resT, elemT, vec_sz>(log1p_vec);
    }
    else {
        return log1p_vec;
    }
}
};

template <typename argTy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

using dpctl::tensor::type_utils::is_complex;
using dpctl::tensor::type_utils::vec_cast;

template <typename argT, typename resT> struct SinFunctor
{
Expand All @@ -58,7 +59,8 @@ template <typename argT, typename resT> struct SinFunctor
// constant value, if constant
// constexpr resT constant_value = resT{};
// is function defined for sycl::vec
using supports_vec = typename std::false_type;
using supports_vec = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
// do both argTy and resTy support subgroup store/load operation
using supports_sg_loadstore = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
Expand Down Expand Up @@ -181,6 +183,20 @@ template <typename argT, typename resT> struct SinFunctor
return std::sin(in);
}
}

// Overload of the call operator for sycl::vec arguments.
//
// Applies sycl::sin to the whole input vector. If the element type of the
// computed vector already is resT it is returned directly; otherwise it is
// converted to a vector of resT with vec_cast.
template <int vec_sz>
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
{
    const auto sin_vec = sycl::sin(in);
    // element type actually produced by sycl::sin
    using elemT = typename std::decay_t<decltype(sin_vec)>::element_type;

    if constexpr (!std::is_same<resT, elemT>::value) {
        return vec_cast<resT, elemT, vec_sz>(sin_vec);
    }
    else {
        return sin_vec;
    }
}
};

template <typename argTy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

using dpctl::tensor::type_utils::is_complex;
using dpctl::tensor::type_utils::vec_cast;

template <typename argT, typename resT> struct SqrtFunctor
{
Expand All @@ -62,7 +63,8 @@ template <typename argT, typename resT> struct SqrtFunctor
// constant value, if constant
// constexpr resT constant_value = resT{};
// is function defined for sycl::vec
using supports_vec = typename std::false_type;
using supports_vec = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
// do both argTy and resTy support subgroup store/load operation
using supports_sg_loadstore = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
Expand Down Expand Up @@ -95,6 +97,20 @@ template <typename argT, typename resT> struct SqrtFunctor
}
}

// Overload of the call operator for sycl::vec arguments.
//
// Applies sycl::sqrt to the whole input vector. If the element type of the
// computed vector already is resT it is returned directly; otherwise it is
// converted to a vector of resT with vec_cast.
template <int vec_sz>
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
{
    const auto sqrt_vec = sycl::sqrt(in);
    // element type actually produced by sycl::sqrt
    using elemT = typename std::decay_t<decltype(sqrt_vec)>::element_type;

    if constexpr (!std::is_same<resT, elemT>::value) {
        return vec_cast<resT, elemT, vec_sz>(sqrt_vec);
    }
    else {
        return sqrt_vec;
    }
}

private:
template <typename T> std::complex<T> csqrt(std::complex<T> const &z) const
{
Expand Down