@@ -1,5 +1,27 @@
 #include "common.cuh"
 
+struct mma_int_A_I16K4 {
+    static constexpr int I  = 16;
+    static constexpr int K  = 4;
+    static constexpr int ne = 2;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = (l%2) * (I/2) + threadIdx.x / K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int /* l */) {
+        const int ret = threadIdx.x % K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  K);
+        return ret;
+    }
+};
+
 struct mma_int_A_I16K8 {
     static constexpr int I  = 16;
     static constexpr int K  = 8;
@@ -22,6 +44,28 @@ struct mma_int_A_I16K8 {
     }
 };
 
+struct mma_int_B_J8K4 {
+    static constexpr int J  = 8;
+    static constexpr int K  = 4;
+    static constexpr int ne = 1;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_j(const int /* l */) {
+        const int ret = threadIdx.x / K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  J);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int /* l */) {
+        const int ret = threadIdx.x % K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  K);
+        return ret;
+    }
+};
+
 struct mma_int_B_J8K8 {
     static constexpr int J  = 8;
     static constexpr int K  = 8;
@@ -65,6 +109,28 @@ struct mma_int_C_I16J8 {
         return ret;
     }
 
+    __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
+#ifdef INT8_MMA_AVAILABLE
+#if __CUDA_ARCH__ >= CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+#else
+        // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[0]), "+r"(x[1])
+            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+#endif // __CUDA_ARCH__ >= CC_AMPERE
+#else
+        GGML_UNUSED(mma_A);
+        GGML_UNUSED(mma_B);
+        NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+    }
+
     __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
 #ifdef INT8_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= CC_AMPERE
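For context, a minimal usage sketch (not part of the commit): each thread of a warp fills its registers of the new K4 fragments via the get_i/get_j/get_k mappings above, calls mma_K4 on an mma_int_C_I16J8 accumulator, and writes the int32 results back. The function name, tile pointers and layouts, and the C-fragment get_i/get_j accessors are assumptions for illustration only.

// Usage sketch, assuming the fragment layouts above; warp_mma_tile_K4 and
// the tile layouts are hypothetical.
// A: 16x4 ints (each int packs 4 int8 values), row-major.
// B:  8x4 ints, row-major.  C: 16x8 int32 accumulators, row-major.
static __device__ __forceinline__ void warp_mma_tile_K4(const int * A, const int * B, int * C) {
    mma_int_A_I16K4 mma_A;
    mma_int_B_J8K4  mma_B;
    mma_int_C_I16J8 mma_C;

    // Each of the 32 lanes loads its registers of the A and B fragments.
#pragma unroll
    for (int l = 0; l < mma_int_A_I16K4::ne; ++l) {
        mma_A.x[l] = A[mma_int_A_I16K4::get_i(l)*mma_int_A_I16K4::K + mma_int_A_I16K4::get_k(l)];
    }
    mma_B.x[0] = B[mma_int_B_J8K4::get_j(0)*mma_int_B_J8K4::K + mma_int_B_J8K4::get_k(0)];

    mma_C.mma_K4(mma_A, mma_B); // C += A*B over the 16-wide int8 K dimension

    // Write back the accumulators; get_i/get_j of mma_int_C_I16J8 are assumed
    // to exist analogously to the A/B fragment accessors.
#pragma unroll
    for (int l = 0; l < mma_int_C_I16J8::ne; ++l) {
        C[mma_int_C_I16J8::get_i(l)*mma_int_C_I16J8::J + mma_int_C_I16J8::get_j(l)] = mma_C.x[l];
    }
}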