Skip to content

Commit

Permalink
implement batched serial ger (kokkos#2491)
Browse files Browse the repository at this point in the history
* implement batched serial ger

Signed-off-by: Yuuichi Asahi <[email protected]>

* fix: view constructors for x and y in ger test

Signed-off-by: Yuuichi Asahi <[email protected]>

---------

Signed-off-by: Yuuichi Asahi <[email protected]>
Co-authored-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
yasahi-hpc and Yuuichi Asahi authored Feb 5, 2025
1 parent 15d3ee3 commit 8912c6d
Show file tree
Hide file tree
Showing 5 changed files with 688 additions and 0 deletions.
118 changes: 118 additions & 0 deletions batched/dense/impl/KokkosBatched_Ger_Serial_Impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#ifndef KOKKOSBATCHED_GER_SERIAL_IMPL_HPP_
#define KOKKOSBATCHED_GER_SERIAL_IMPL_HPP_

#include <KokkosBatched_Util.hpp>
#include "KokkosBatched_Ger_Serial_Internal.hpp"

namespace KokkosBatched {
namespace Impl {
/// \brief Validates the arguments of the batched serial ger kernel.
///
/// Compile-time checks (always active): x, y and A must be Kokkos::Views of
/// rank 1, 1 and 2 respectively.
///
/// Run-time checks (only compiled in when KOKKOSKERNELS_DEBUG_LEVEL > 0):
///   - x.extent(0) (m) and A.extent(1) (n) are non-negative,
///   - y holds exactly n = A.extent(1) elements,
///   - A.extent(0) is at least max(1, m),
///   - x holds at least A.extent(0) elements.
///
/// \tparam XViewType: type of the vector x, a rank 1 view
/// \tparam YViewType: type of the vector y, a rank 1 view
/// \tparam AViewType: type of the matrix A, a rank 2 view
///
/// \param x [in]: vector x
/// \param y [in]: vector y
/// \param A [in]: matrix A
///
/// \return 0 if every check passes (or run-time checks are compiled out),
///         1 after printing a diagnostic for the first failed check.
template <typename XViewType, typename YViewType, typename AViewType>
KOKKOS_INLINE_FUNCTION static int checkGerInput([[maybe_unused]] const XViewType &x,
                                                [[maybe_unused]] const YViewType &y,
                                                [[maybe_unused]] const AViewType &A) {
  static_assert(Kokkos::is_view_v<XViewType>, "KokkosBatched::ger: XViewType is not a Kokkos::View.");
  static_assert(Kokkos::is_view_v<YViewType>, "KokkosBatched::ger: YViewType is not a Kokkos::View.");
  static_assert(Kokkos::is_view_v<AViewType>, "KokkosBatched::ger: AViewType is not a Kokkos::View.");
  static_assert(XViewType::rank == 1, "KokkosBatched::ger: XViewType must have rank 1.");
  static_assert(YViewType::rank == 1, "KokkosBatched::ger: YViewType must have rank 1.");
  static_assert(AViewType::rank == 2, "KokkosBatched::ger: AViewType must have rank 2.");
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
  const int lda = A.extent_int(0), n = A.extent_int(1);
  const int m = x.extent_int(0);
  if (m < 0) {
    Kokkos::printf(
        "KokkosBatched::ger: input parameter m must not be less than 0: m "
        "= "
        "%d\n",
        m);
    return 1;
  }

  if (n < 0) {
    Kokkos::printf(
        "KokkosBatched::ger: input parameter n must not be less than 0: n "
        "= "
        "%d\n",
        n);
    return 1;
  }

  if (y.extent_int(0) != n) {
    Kokkos::printf(
        "KokkosBatched::ger: y must contain n elements: n "
        "= "
        "%d\n",
        n);
    return 1;
  }

  if (lda < Kokkos::max(1, m)) {
    Kokkos::printf(
        "KokkosBatched::ger: leading dimension of A must not be smaller than "
        "max(1, m): "
        "lda = %d, m = %d\n",
        lda, m);
    return 1;
  }

  // SerialGer::invoke takes its row count from A.extent(0) and the kernel reads
  // one entry of x per row of A, so an x shorter than A.extent(0) would be read
  // out of bounds. The check above only rejects an x that is too long.
  if (m < lda) {
    Kokkos::printf(
        "KokkosBatched::ger: x must contain at least as many elements as A "
        "has rows: "
        "m = %d, lda = %d\n",
        m, lda);
    return 1;
  }
#endif
  return 0;
}
} // namespace Impl

// Transpose variant
// A := alpha * x * y**T + A
template <>
struct SerialGer<Trans::Transpose> {
  /// Applies the rank-1 update to A in place.
  /// Returns 0 on success or when there is nothing to do; otherwise forwards
  /// the non-zero status reported by Impl::checkGerInput.
  template <typename ScalarType, typename XViewType, typename YViewType, typename AViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const YViewType &y,
                                           const AViewType &A) {
    const int num_rows = A.extent_int(0);
    const int num_cols = A.extent_int(1);

    // An empty matrix or a zero scaling factor leaves A untouched.
    const bool nothing_to_do = (num_rows == 0) || (num_cols == 0) || (alpha == ScalarType(0));
    if (nothing_to_do) return 0;

    // Validation happens only after the quick return above.
    const int status = Impl::checkGerInput(x, y, A);
    if (status != 0) return status;

    // OpID is the operator applied to the entries of y for the y**T form.
    return Impl::SerialGerInternal::invoke(KokkosBlas::Impl::OpID(), num_rows, num_cols, alpha, x.data(),
                                           x.stride(0), y.data(), y.stride(0), A.data(), A.stride(0),
                                           A.stride(1));
  }
};

// Conjugate-transpose variant
// A := alpha * x * y**H + A
template <>
struct SerialGer<Trans::ConjTranspose> {
  /// Applies the rank-1 update to A in place.
  /// Returns 0 on success or when there is nothing to do; otherwise forwards
  /// the non-zero status reported by Impl::checkGerInput.
  template <typename ScalarType, typename XViewType, typename YViewType, typename AViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const YViewType &y,
                                           const AViewType &A) {
    const int num_rows = A.extent_int(0);
    const int num_cols = A.extent_int(1);

    // An empty matrix or a zero scaling factor leaves A untouched.
    const bool nothing_to_do = (num_rows == 0) || (num_cols == 0) || (alpha == ScalarType(0));
    if (nothing_to_do) return 0;

    // Validation happens only after the quick return above.
    const int status = Impl::checkGerInput(x, y, A);
    if (status != 0) return status;

    // OpConj is the operator applied to the entries of y for the y**H form.
    return Impl::SerialGerInternal::invoke(KokkosBlas::Impl::OpConj(), num_rows, num_cols, alpha, x.data(),
                                           x.stride(0), y.data(), y.stride(0), A.data(), A.stride(0),
                                           A.stride(1));
  }
};

} // namespace KokkosBatched

#endif // KOKKOSBATCHED_GER_SERIAL_IMPL_HPP_
57 changes: 57 additions & 0 deletions batched/dense/impl/KokkosBatched_Ger_Serial_Internal.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#ifndef KOKKOSBATCHED_GER_SERIAL_INTERNAL_HPP_
#define KOKKOSBATCHED_GER_SERIAL_INTERNAL_HPP_

#include <KokkosBatched_Util.hpp>

namespace KokkosBatched {
namespace Impl {

///
/// Serial Internal Impl
/// ====================
///
/// Pointer-and-stride kernel for the rank-1 update
///   A(i, j) += x(i) * (alpha * op(y(j)))
/// over i in [0, am) and j in [0, an).
///
/// \param op    [in]: functor applied to each entry of y before scaling
/// \param am    [in]: number of rows of A to update (entries of x read)
/// \param an    [in]: number of columns of A to update (entries of y read)
/// \param alpha [in]: scalar multiplier
/// \param x     [in]: base pointer of x; entry i is x[i * xs0]
/// \param y     [in]: base pointer of y; entry j is y[j * ys0]
/// \param A  [inout]: base pointer of A; entry (i, j) is A[i * as0 + j * as1]
///
/// invoke always returns 0.
struct SerialGerInternal {
  template <typename Op, typename ScalarType, typename ValueType>
  KOKKOS_INLINE_FUNCTION static int invoke(Op op, const int am, const int an, const ScalarType alpha,
                                           const ValueType *KOKKOS_RESTRICT x, const int xs0,
                                           const ValueType *KOKKOS_RESTRICT y, const int ys0,
                                           ValueType *KOKKOS_RESTRICT A, const int as0, const int as1);
};

template <typename Op, typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialGerInternal::invoke(Op op, const int am, const int an, const ScalarType alpha,
                                                     const ValueType *KOKKOS_RESTRICT x, const int xs0,
                                                     const ValueType *KOKKOS_RESTRICT y, const int ys0,
                                                     ValueType *KOKKOS_RESTRICT A, const int as0, const int as1) {
  for (int col = 0; col < an; ++col) {
    const int y_offset = col * ys0;

    // A zero entry of y contributes nothing to this column, so skip it.
    if (!(y[y_offset] != 0)) continue;

    // The scaled, op-transformed entry of y is shared by every row of the column.
    auto scaled_y    = alpha * op(y[y_offset]);
    const int a_base = col * as1;
    for (int row = 0; row < am; ++row) {
      A[row * as0 + a_base] += x[row * xs0] * scaled_y;
    }
  }

  return 0;
}

} // namespace Impl
} // namespace KokkosBatched

#endif // KOKKOSBATCHED_GER_SERIAL_INTERNAL_HPP_
53 changes: 53 additions & 0 deletions batched/dense/src/KokkosBatched_Ger.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER
#ifndef KOKKOSBATCHED_GER_HPP_
#define KOKKOSBATCHED_GER_HPP_

#include <KokkosBatched_Util.hpp>

/// \author Yuuichi Asahi ([email protected])

namespace KokkosBatched {

/// \brief Serial Batched Ger:
/// Performs the rank 1 operation
/// A := alpha*x*y**T + A or A := alpha*x*y**H + A
/// where alpha is a scalar, x is an m element vector, y is an n element
/// vector and A is an m by n matrix.
///
/// \tparam ArgTrans: Selects the form of the update. Specializations are
///                   provided for Trans::Transpose (y**T) and
///                   Trans::ConjTranspose (y**H).
/// \tparam ScalarType: Input type for the scalar alpha
/// \tparam XViewType: Input type for the vector x, needs to be a 1D view
/// \tparam YViewType: Input type for the vector y, needs to be a 1D view
/// \tparam AViewType: Input/output type for the matrix A, needs to be a 2D view
///
/// \param alpha [in]: alpha is the scalar multiplier of x*y**T (or x*y**H)
/// \param x [in]: x is a length m vector, a rank 1 view
/// \param y [in]: y is a length n vector, a rank 1 view
/// \param a [inout]: A is a m by n matrix, a rank 2 view, updated in place
///
/// \return 0 on success, or when A is empty or alpha is zero (quick return);
///         1 if the input checks, which are compiled in only when
///         KOKKOSKERNELS_DEBUG_LEVEL > 0, reject the arguments.
///
/// No nested parallel_for is used inside of the function.
///
template <typename ArgTrans>
struct SerialGer {
template <typename ScalarType, typename XViewType, typename YViewType, typename AViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const YViewType &y,
const AViewType &a);
};
} // namespace KokkosBatched

#include "KokkosBatched_Ger_Serial_Impl.hpp"

#endif // KOKKOSBATCHED_GER_HPP_
1 change: 1 addition & 0 deletions batched/dense/unit_test/Test_Batched_Dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
#include "Test_Batched_SerialIamax.hpp"
#include "Test_Batched_SerialGetrf.hpp"
#include "Test_Batched_SerialGetrs.hpp"
#include "Test_Batched_SerialGer.hpp"

// Team Kernels
#include "Test_Batched_TeamAxpy.hpp"
Expand Down
Loading

0 comments on commit 8912c6d

Please sign in to comment.