Skip to content

Commit 40a7c0f

Browse files
committed
Adding basic sparse multiplication kernel for default CPU and GPU
1 parent aa4a7a3 commit 40a7c0f

File tree

7 files changed

+626
-8
lines changed

7 files changed

+626
-8
lines changed

DefaultCPU/sp_gemm.hh

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#pragma once
2+
3+
#if defined CPU_DEFAULT
4+
5+
#include "../include/kernels/CPU/sp_gemm.hh"
6+
#include "../include/utilities.hh"
7+
8+
namespace cpu {
9+
/** A class for GEMM CPU BLAS kernels. */
10+
template <typename T>
11+
class sp_gemm_cpu : public sp_gemm<T> {
12+
public:
13+
using sp_gemm<T>::sp_gemm;
14+
using sp_gemm<T>::callConsume;
15+
using sp_gemm<T>::m_;
16+
using sp_gemm<T>::n_;
17+
using sp_gemm<T>::k_;
18+
using sp_gemm<T>::A_;
19+
using sp_gemm<T>::B_;
20+
using sp_gemm<T>::C_;
21+
22+
private:
23+
/** Perform the GEMM kernel. */
24+
void callGemm() override {
25+
/** A naive implementation of a column-major GEMM. Alpha and Beta are always
26+
* 1 and 0 respectively.
27+
* Operation takes the form of C[M,N] = A[M,K] * B[K,N].
28+
* callConsume() is required to ensure that the compiler does not optimise
29+
* away this function. */
30+
int x, y, z;
31+
T acc;
32+
for (x = 0; x < m_; x++) {
33+
for (y = 0; y < n_; y++) {
34+
acc = 0.0;
35+
for (z = 0; z < k_; z++) {
36+
acc += A_[z * m_ + x] * B_[y * k_ + z];
37+
}
38+
C_[y * m_ + x] = acc;
39+
}
40+
}
41+
// Ensure compiler doesn't optimise away the work being done
42+
callConsume();
43+
}
44+
45+
/** Perform any required steps before calling the GEMM kernel that should
46+
* be timed. */
47+
void preLoopRequirements() override {}
48+
49+
/** Perform any required steps after calling the GEMM kernel that should
50+
* be timed. */
51+
void postLoopRequirements() override {}
52+
};
53+
54+
} // namespace cpu
55+
#endif

DefaultGPU/sp_gemm.hh

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#pragma once
2+
3+
#if defined GPU_DEFAULT
4+
5+
#include <cmath>
6+
7+
#include "../include/kernels/GPU/sp_gemm.hh"
8+
#include "../include/utilities.hh"
9+
10+
namespace gpu {
11+
/** A class for GEMM GPU BLAS kernels. */
12+
template <typename T>
13+
class sp_gemm_gpu : public sp_gemm<T> {
14+
public:
15+
using sp_gemm<T>::sp_gemm;
16+
17+
/** Call the BLAS kernel n times, with 1 warmup run.
18+
* Returns the time elapsed for n BLAS calls in seconds. */
19+
time_checksum_gflop compute() {
20+
// Override function in base `kernel` class as DefaultGPU should do nothing.
21+
return {INFINITY, INFINITY, 0.0};
22+
}
23+
24+
/** Initialise the required data structures. */
25+
void initialise(gpuOffloadType offload, int m, int n, int k) override {
26+
// Default GPU implementation - do nothing.
27+
}
28+
29+
private:
30+
/** Make a call to the BLAS Library Kernel. */
31+
void callGemm() override {
32+
// Default GPU implementation - do nothing.
33+
}
34+
35+
/** Perform any required steps before calling the GEMM kernel that should
36+
* be timed. */
37+
void preLoopRequirements() override {
38+
// Default GPU implementation - do nothing.
39+
}
40+
41+
/** Perform any required steps after calling the GEMM kernel that should
42+
* be timed. */
43+
void postLoopRequirements() override {
44+
// Default GPU implementation - do nothing.
45+
}
46+
47+
/** Do any necessary cleanup (free pointers, close library handles, etc.)
48+
* after Kernel has been called. */
49+
void postCallKernelCleanup() override {
50+
// Default GPU implementation - do nothing.
51+
}
52+
};
53+
} // namespace gpu
54+
#endif

0 commit comments

Comments
 (0)