forked from kokkos/kokkos-kernels
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
implement batched serial ger (kokkos#2491)
* implement batched serial ger Signed-off-by: Yuuichi Asahi <[email protected]> * fixx: view constructors for x and y in ger test Signed-off-by: Yuuichi Asahi <[email protected]> --------- Signed-off-by: Yuuichi Asahi <[email protected]> Co-authored-by: Yuuichi Asahi <[email protected]>
- Loading branch information
1 parent
15d3ee3
commit 8912c6d
Showing
5 changed files
with
688 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
//@HEADER | ||
// ************************************************************************ | ||
// | ||
// Kokkos v. 4.0 | ||
// Copyright (2022) National Technology & Engineering | ||
// Solutions of Sandia, LLC (NTESS). | ||
// | ||
// Under the terms of Contract DE-NA0003525 with NTESS, | ||
// the U.S. Government retains certain rights in this software. | ||
// | ||
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://kokkos.org/LICENSE for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//@HEADER | ||
|
||
#ifndef KOKKOSBATCHED_GER_SERIAL_IMPL_HPP_ | ||
#define KOKKOSBATCHED_GER_SERIAL_IMPL_HPP_ | ||
|
||
#include <KokkosBatched_Util.hpp> | ||
#include "KokkosBatched_Ger_Serial_Internal.hpp" | ||
|
||
namespace KokkosBatched { | ||
namespace Impl { | ||
template <typename XViewType, typename YViewType, typename AViewType> | ||
KOKKOS_INLINE_FUNCTION static int checkGerInput([[maybe_unused]] const XViewType &x, | ||
[[maybe_unused]] const YViewType &y, | ||
[[maybe_unused]] const AViewType &A) { | ||
static_assert(Kokkos::is_view_v<XViewType>, "KokkosBatched::ger: XViewType is not a Kokkos::View."); | ||
static_assert(Kokkos::is_view_v<YViewType>, "KokkosBatched::ger: YViewType is not a Kokkos::View."); | ||
static_assert(Kokkos::is_view_v<AViewType>, "KokkosBatched::ger: AViewType is not a Kokkos::View."); | ||
static_assert(XViewType::rank == 1, "KokkosBatched::ger: XViewType must have rank 1."); | ||
static_assert(YViewType::rank == 1, "KokkosBatched::ger: YViewType must have rank 1."); | ||
static_assert(AViewType::rank == 2, "KokkosBatched::ger: AViewType must have rank 2."); | ||
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) | ||
const int lda = A.extent_int(0), n = A.extent_int(1); | ||
const int m = x.extent_int(0); | ||
if (m < 0) { | ||
Kokkos::printf( | ||
"KokkosBatched::ger: input parameter m must not be less than 0: m " | ||
"= " | ||
"%d\n", | ||
m); | ||
return 1; | ||
} | ||
|
||
if (n < 0) { | ||
Kokkos::printf( | ||
"KokkosBatched::ger: input parameter n must not be less than 0: n " | ||
"= " | ||
"%d\n", | ||
n); | ||
return 1; | ||
} | ||
|
||
if (y.extent_int(0) != n) { | ||
Kokkos::printf( | ||
"KokkosBatched::ger: y must contain n elements: n " | ||
"= " | ||
"%d\n", | ||
n); | ||
return 1; | ||
} | ||
|
||
if (lda < Kokkos::max(1, m)) { | ||
Kokkos::printf( | ||
"KokkosBatched::ger: leading dimension of A must not be smaller than " | ||
"max(1, m): " | ||
"lda = %d, m = %d\n", | ||
lda, m); | ||
return 1; | ||
} | ||
#endif | ||
return 0; | ||
} | ||
} // namespace Impl | ||
|
||
// T | ||
// A: alpha * x * y**T + A | ||
template <> | ||
struct SerialGer<Trans::Transpose> { | ||
template <typename ScalarType, typename XViewType, typename YViewType, typename AViewType> | ||
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const YViewType &y, | ||
const AViewType &A) { | ||
// Quick return if possible | ||
const int m = A.extent_int(0), n = A.extent_int(1); | ||
if (m == 0 || n == 0 || (alpha == ScalarType(0))) return 0; | ||
|
||
auto info = Impl::checkGerInput(x, y, A); | ||
if (info) return info; | ||
|
||
return Impl::SerialGerInternal::invoke(KokkosBlas::Impl::OpID(), m, n, alpha, x.data(), x.stride(0), y.data(), | ||
y.stride(0), A.data(), A.stride(0), A.stride(1)); | ||
} | ||
}; | ||
|
||
// C | ||
// A: alpha * x * y**H + A | ||
template <> | ||
struct SerialGer<Trans::ConjTranspose> { | ||
template <typename ScalarType, typename XViewType, typename YViewType, typename AViewType> | ||
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const YViewType &y, | ||
const AViewType &A) { | ||
// Quick return if possible | ||
const int m = A.extent_int(0), n = A.extent_int(1); | ||
if (m == 0 || n == 0 || (alpha == ScalarType(0))) return 0; | ||
|
||
auto info = Impl::checkGerInput(x, y, A); | ||
if (info) return info; | ||
|
||
return Impl::SerialGerInternal::invoke(KokkosBlas::Impl::OpConj(), m, n, alpha, x.data(), x.stride(0), y.data(), | ||
y.stride(0), A.data(), A.stride(0), A.stride(1)); | ||
} | ||
}; | ||
|
||
} // namespace KokkosBatched | ||
|
||
#endif // KOKKOSBATCHED_GER_SERIAL_IMPL_HPP_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
//@HEADER | ||
// ************************************************************************ | ||
// | ||
// Kokkos v. 4.0 | ||
// Copyright (2022) National Technology & Engineering | ||
// Solutions of Sandia, LLC (NTESS). | ||
// | ||
// Under the terms of Contract DE-NA0003525 with NTESS, | ||
// the U.S. Government retains certain rights in this software. | ||
// | ||
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://kokkos.org/LICENSE for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//@HEADER | ||
|
||
#ifndef KOKKOSBATCHED_GER_SERIAL_INTERNAL_HPP_ | ||
#define KOKKOSBATCHED_GER_SERIAL_INTERNAL_HPP_ | ||
|
||
#include <KokkosBatched_Util.hpp> | ||
|
||
namespace KokkosBatched { | ||
namespace Impl { | ||
|
||
/// | ||
/// Serial Internal Impl | ||
/// ==================== | ||
|
||
struct SerialGerInternal { | ||
template <typename Op, typename ScalarType, typename ValueType> | ||
KOKKOS_INLINE_FUNCTION static int invoke(Op op, const int am, const int an, const ScalarType alpha, | ||
const ValueType *KOKKOS_RESTRICT x, const int xs0, | ||
const ValueType *KOKKOS_RESTRICT y, const int ys0, | ||
ValueType *KOKKOS_RESTRICT A, const int as0, const int as1); | ||
}; | ||
|
||
template <typename Op, typename ScalarType, typename ValueType> | ||
KOKKOS_INLINE_FUNCTION int SerialGerInternal::invoke(Op op, const int am, const int an, const ScalarType alpha, | ||
const ValueType *KOKKOS_RESTRICT x, const int xs0, | ||
const ValueType *KOKKOS_RESTRICT y, const int ys0, | ||
ValueType *KOKKOS_RESTRICT A, const int as0, const int as1) { | ||
for (int j = 0; j < an; j++) { | ||
if (y[j * ys0] != 0) { | ||
auto temp = alpha * op(y[j * ys0]); | ||
for (int i = 0; i < am; i++) { | ||
A[i * as0 + j * as1] += x[i * xs0] * temp; | ||
} | ||
} | ||
} | ||
|
||
return 0; | ||
} | ||
|
||
} // namespace Impl | ||
} // namespace KokkosBatched | ||
|
||
#endif // KOKKOSBATCHED_GER_SERIAL_INTERNAL_HPP_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
//@HEADER | ||
// ************************************************************************ | ||
// | ||
// Kokkos v. 4.0 | ||
// Copyright (2022) National Technology & Engineering | ||
// Solutions of Sandia, LLC (NTESS). | ||
// | ||
// Under the terms of Contract DE-NA0003525 with NTESS, | ||
// the U.S. Government retains certain rights in this software. | ||
// | ||
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://kokkos.org/LICENSE for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//@HEADER | ||
#ifndef KOKKOSBATCHED_GER_HPP_ | ||
#define KOKKOSBATCHED_GER_HPP_ | ||
|
||
#include <KokkosBatched_Util.hpp> | ||
|
||
/// \author Yuuichi Asahi ([email protected]) | ||
|
||
namespace KokkosBatched { | ||
|
||
/// \brief Serial Batched Ger: | ||
/// Performs the rank 1 operation | ||
/// A := alpha*x*y**T + A or A := alpha*x*y**H + A | ||
/// where alpha is a scalar, x is an m element vector, y is an n element | ||
/// vector and A is an m by n matrix. | ||
/// | ||
/// \tparam ScalarType: Input type for the scalar alpha | ||
/// \tparam XViewType: Input type for the vector x, needs to be a 1D view | ||
/// \tparam YViewType: Input type for the vector y, needs to be a 1D view | ||
/// \tparam AViewType: Input/output type for the matrix A, needs to be a 2D view | ||
/// | ||
/// \param alpha [in]: A is a m by n general matrix, a rank 2 view | ||
/// \param x [in]: x is a length m vector, a rank 1 view | ||
/// \param y [in]: y is a length n vector, a rank 1 view | ||
/// \param A [inout]: A is a m by n matrix, a rank 2 view | ||
/// | ||
/// No nested parallel_for is used inside of the function. | ||
/// | ||
template <typename ArgTrans> | ||
struct SerialGer { | ||
template <typename ScalarType, typename XViewType, typename YViewType, typename AViewType> | ||
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const YViewType &y, | ||
const AViewType &a); | ||
}; | ||
} // namespace KokkosBatched | ||
|
||
#include "KokkosBatched_Ger_Serial_Impl.hpp" | ||
|
||
#endif // KOKKOSBATCHED_GER_HPP_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.