Commit bdcb8f4

CUDA: int8 tensor cores for MMQ (q4_K, q5_K, q6_K) (ggml-org#7860)
1 parent: c2ce6c4

2 files changed: +360 −6 lines changed


ggml-cuda/mma.cuh (+66 lines)
@@ -1,5 +1,27 @@
 #include "common.cuh"
 
+struct mma_int_A_I16K4 {
+    static constexpr int I  = 16;
+    static constexpr int K  = 4;
+    static constexpr int ne = 2;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = (l%2) * (I/2) + threadIdx.x / K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int /* l */) {
+        const int ret = threadIdx.x % K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  K);
+        return ret;
+    }
+};
+
 struct mma_int_A_I16K8 {
     static constexpr int I  = 16;
     static constexpr int K  = 8;
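The new A fragment maps one warp onto a 16×4 tile of 32-bit registers (each register packing four int8 values): element l of thread t sits at row (l%2)*8 + t/4 and column t%4. As a minimal sketch, not part of the commit, this is how a warp could fill the fragment from a row-major tile; load_A_tile and tile are illustrative names:

    // Illustrative only: populate the new A fragment from a row-major 16x4 tile
    // of 32-bit ints (each int packs four int8 values). Every thread of the
    // warp fills its ne = 2 registers; get_i/get_k give its (row, column) pair.
    static __device__ __forceinline__ void load_A_tile(mma_int_A_I16K4 & frag, const int * tile) {
    #pragma unroll
        for (int l = 0; l < mma_int_A_I16K4::ne; ++l) {
            const int i = mma_int_A_I16K4::get_i(l); // row    in [0, 16)
            const int k = mma_int_A_I16K4::get_k(l); // column in [0,  4)
            frag.x[l] = tile[i*mma_int_A_I16K4::K + k];
        }
    }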
@@ -22,6 +44,28 @@ struct mma_int_A_I16K8 {
     }
 };
 
+struct mma_int_B_J8K4 {
+    static constexpr int J  = 8;
+    static constexpr int K  = 4;
+    static constexpr int ne = 1;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_j(const int /* l */) {
+        const int ret = threadIdx.x / K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  J);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int /* l */) {
+        const int ret = threadIdx.x % K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  K);
+        return ret;
+    }
+};
+
 struct mma_int_B_J8K8 {
     static constexpr int J  = 8;
     static constexpr int K  = 8;
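The B fragment is the analogous 8×4 register tile, but with ne = 1 the warp's 32 registers cover the 32 elements exactly: thread t owns element (t/4, t%4). A matching load sketch, with load_B_tile and tile again illustrative names rather than commit code:

    // Illustrative only: with ne = 1, one register per thread covers the
    // whole 8x4 tile, so no loop over l is needed.
    static __device__ __forceinline__ void load_B_tile(mma_int_B_J8K4 & frag, const int * tile) {
        const int j = mma_int_B_J8K4::get_j(0); // row    in [0, 8)
        const int k = mma_int_B_J8K4::get_k(0); // column in [0, 4)
        frag.x[0] = tile[j*mma_int_B_J8K4::K + k];
    }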
@@ -65,6 +109,28 @@ struct mma_int_C_I16J8 {
         return ret;
     }
 
+    __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
+#ifdef INT8_MMA_AVAILABLE
+#if __CUDA_ARCH__ >= CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+#else
+        // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[0]), "+r"(x[1])
+            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+#endif // __CUDA_ARCH__ >= CC_AMPERE
+#else
+        GGML_UNUSED(mma_A);
+        GGML_UNUSED(mma_B);
+        NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+    }
+
     __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
 #ifdef INT8_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= CC_AMPERE

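Taken together, the K4 fragments feed the new mma_K4 accumulator on the C fragment: K = 4 registers per row is 16 int8 values, so one call accumulates a 16×8 int32 tile from a 16×16 int8 A tile and an 8×16 int8 B tile (the PTX row.col layout, i.e. B supplied column-major), issued as one m16n8k16 mma on Ampere or two m8n8k16 mmas on Turing. A hedged end-to-end sketch, assuming the two hypothetical loaders above; warp_tile_mma is an illustrative name:

    // Illustrative only: one warp accumulates a 16x8 int32 output tile.
    static __device__ __forceinline__ void warp_tile_mma(
            mma_int_C_I16J8 & C, const int * A_tile, const int * B_tile) {
        mma_int_A_I16K4 A;
        mma_int_B_J8K4  B;
        load_A_tile(A, A_tile); // hypothetical loaders sketched above
        load_B_tile(B, B_tile);
        C.mma_K4(A, B);         // int8 tensor core MMA with int32 accumulation
    }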