#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <vector>
#include <algorithm>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <torch/types.h>
#include <torch/extension.h>

#define WARP_SIZE 32
#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])

// -------------------------------------- FP16 --------------------------------------
// Warp Reduce Sum
template <const int kWarpSize = WARP_SIZE>
__device__ __forceinline__ half warp_reduce_sum_f16(half val) {
#pragma unroll
  for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) {
    val += __shfl_xor_sync(0xffffffff, val, mask);
  }
  return val;
}
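// Note on the reduction above: it is a butterfly (XOR) shuffle reduction. With
// kWarpSize = 32 the mask takes the values 16, 8, 4, 2, 1, so after 5 rounds of
// __shfl_xor_sync every lane holds the sum of all 32 lane values; the callers
// below then let a single lane write the result out.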

// HGEMV: Warp HGEMV K32
// Assume K is a multiple of 32; each warp handles one row.
// grid(M/4), block(32, 4): blockDim.x = 32 (one warp per row), blockDim.y = 4 rows per block
// a: MxK, x: Kx1, y: Mx1, compute: y = a * x
__global__ void hgemv_k32_f16_kernel(half* a, half* x, half* y, int M, int K) {
  int tx = threadIdx.x;         // 0~31
  int ty = threadIdx.y;         // 0~3
  int bx = blockIdx.x;          // 0~M/4
  int lane = tx % WARP_SIZE;    // 0~31
  int m = bx * blockDim.y + ty; // (0~M/4) * 4 + (0~3)
  if (m < M) {
    half sum = 0.0f;
    int NUM_WARPS = (K + WARP_SIZE - 1) / WARP_SIZE;
#pragma unroll
    for (int w = 0; w < NUM_WARPS; ++w) {
      // When NUM_WARPS >= 2, each lane first accumulates its partial products
      // over the warp-sized chunks of this row before the warp reduction.
      int k = w * WARP_SIZE + lane;
      sum += a[m * K + k] * x[k];
    }
    sum = warp_reduce_sum_f16<WARP_SIZE>(sum);
    if (lane == 0) y[m] = sum;
  }
}

// HGEMV: Warp HGEMV K128 + half2x2
// Assume K is a multiple of 128; each thread loads 4 halfs per iteration (two half2 loads).
// grid(M/4), block(32, 4): blockDim.x = 32, blockDim.y = 4
// a: MxK, x: Kx1, y: Mx1, compute: y = a * x
__global__ void hgemv_k128_f16x4_kernel(half* a, half* x, half* y, int M, int K) {
  // Each thread handles 4 elements, so one warp covers 128 elements per iteration.
  int tx = threadIdx.x;         // 0~31
  int ty = threadIdx.y;         // 0~3
  int bx = blockIdx.x;          // 0~M/4
  int lane = tx % WARP_SIZE;    // 0~31
  int m = blockDim.y * bx + ty; // (0~M/4) * 4 + (0~3)

  if (m < M) {
    half sum = 0.0f;
    // process 4 * WARP_SIZE elements per warp per iteration.
    int NUM_WARPS = (((K + WARP_SIZE - 1) / WARP_SIZE) + 4 - 1) / 4;
#pragma unroll
    for (int w = 0; w < NUM_WARPS; ++w) {
      int k = (w * WARP_SIZE + lane) * 4;
      half2 reg_x_0 = HALF2(x[k + 0]);
      half2 reg_x_1 = HALF2(x[k + 2]);
      half2 reg_a_0 = HALF2(a[m * K + k + 0]);
      half2 reg_a_1 = HALF2(a[m * K + k + 2]);
      sum += (reg_x_0.x * reg_a_0.x + reg_x_0.y * reg_a_0.y
              + reg_x_1.x * reg_a_1.x + reg_x_1.y * reg_a_1.y);
    }
    sum = warp_reduce_sum_f16<WARP_SIZE>(sum);
    if (lane == 0) y[m] = sum;
  }
}
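// Access-pattern note: per iteration each lane reads 4 consecutive halfs (8 bytes)
// of x and of row m of a via two half2 loads, so a warp touches 128 consecutive
// halfs (256 bytes) of each operand per iteration.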

// HGEMV: Warp HGEMV K16
// Assume K = 16 (< 32); each warp handles 2 rows with 16 elements per row.
// NUM_THREADS = 128, NUM_WARPS = NUM_THREADS / WARP_SIZE;
// NUM_ROWS = NUM_WARPS * ROW_PER_WARP, grid(M/NUM_ROWS), block(32, NUM_WARPS)
// a: MxK, x: Kx1, y: Mx1, compute: y = a * x
template <const int ROW_PER_WARP = 2>
__global__ void hgemv_k16_f16_kernel(half* A, half* x, half* y, int M, int K) {
  constexpr int K_WARP_SIZE = (WARP_SIZE + ROW_PER_WARP - 1) / ROW_PER_WARP;
  int tx = threadIdx.x;       // 0~31
  int ty = threadIdx.y;       // 0~NUM_WARPS-1
  int bx = blockIdx.x;        // 0~M/NUM_ROWS (NUM_ROWS = NUM_WARPS * ROW_PER_WARP)
  int lane = tx % WARP_SIZE;  // 0~31
  int k = lane % K_WARP_SIZE; // 0~15
  // global row of a: MxK and y: Mx1, blockDim.y = NUM_WARPS
  int m = (blockDim.y * bx + ty) * ROW_PER_WARP + lane / K_WARP_SIZE;
  if (m < M) {
    half sum = A[m * K + k] * x[k];
    sum = warp_reduce_sum_f16<K_WARP_SIZE>(sum);
    // Note: the write condition is k == 0, not lane == 0, since each warp produces
    // ROW_PER_WARP results (one per sub-group of K_WARP_SIZE lanes).
    if (k == 0) y[m] = sum;
  }
}
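// Mapping example (ROW_PER_WARP = 2, K = 16): K_WARP_SIZE = 16, so within a warp
// lanes 0..15 compute row r (k = 0..15) and lanes 16..31 compute row r + 1; the
// XOR reduction with kWarpSize = 16 then runs independently over each 16-lane half
// of the warp, which is why two results per warp can be written out.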

// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
  m.def(STRINGFY(func), &func, STRINGFY(func));

#define CHECK_TORCH_TENSOR_DTYPE(T, th_type)                   \
  if (((T).options().dtype() != (th_type))) {                  \
    std::cout << "Tensor Info: " << (T).options() << std::endl; \
    throw std::runtime_error("values must be " #th_type);      \
  }

#define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1)             \
  if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \
    throw std::runtime_error("Tensor size mismatch!");  \
  }

#define ASSERT_K_IS_MULTIPLE_OF(V) \
  if (K % (V) != 0) { throw std::runtime_error("K must be multiple of " #V); }

#define ASSERT_K_IS_EQUAL_OF(V) \
  if (K != (V)) { throw std::runtime_error("K must be " #V); }

void hgemv_k32_f16(torch::Tensor a, torch::Tensor x, torch::Tensor y) {
  CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
  const int M = a.size(0);
  const int K = a.size(1);
  CHECK_TORCH_TENSOR_SHAPE(a, M, K)
  CHECK_TORCH_TENSOR_SHAPE(x, K, 1)
  CHECK_TORCH_TENSOR_SHAPE(y, M, 1)
  ASSERT_K_IS_MULTIPLE_OF(32)

  dim3 block(32, 4);
  dim3 grid((M + 4 - 1) / 4);

  hgemv_k32_f16_kernel<<<grid, block>>>(
    reinterpret_cast<half*>(a.data_ptr()),
    reinterpret_cast<half*>(x.data_ptr()),
    reinterpret_cast<half*>(y.data_ptr()),
    M, K
  );
}

void hgemv_k128_f16x4(torch::Tensor a, torch::Tensor x, torch::Tensor y) {
  CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
  const int M = a.size(0);
  const int K = a.size(1);
  CHECK_TORCH_TENSOR_SHAPE(a, M, K)
  CHECK_TORCH_TENSOR_SHAPE(x, K, 1)
  CHECK_TORCH_TENSOR_SHAPE(y, M, 1)
  ASSERT_K_IS_MULTIPLE_OF(128)

  dim3 block(32, 4);
  dim3 grid((M + 4 - 1) / 4);

  hgemv_k128_f16x4_kernel<<<grid, block>>>(
    reinterpret_cast<half*>(a.data_ptr()),
    reinterpret_cast<half*>(x.data_ptr()),
    reinterpret_cast<half*>(y.data_ptr()),
    M, K
  );
}

void hgemv_k16_f16(torch::Tensor a, torch::Tensor x, torch::Tensor y) {
  CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
  const int M = a.size(0);
  const int K = a.size(1);
  CHECK_TORCH_TENSOR_SHAPE(a, M, K)
  CHECK_TORCH_TENSOR_SHAPE(x, K, 1)
  CHECK_TORCH_TENSOR_SHAPE(y, M, 1)
  ASSERT_K_IS_EQUAL_OF(16)

  constexpr int NUM_THREADS = 128;
  constexpr int ROW_PER_WARP = 2;
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; // 4
  constexpr int NUM_ROWS = NUM_WARPS * ROW_PER_WARP; // 4 * 2 = 8

  dim3 block(32, NUM_WARPS);
  dim3 grid((M + NUM_ROWS - 1) / NUM_ROWS);

  hgemv_k16_f16_kernel<ROW_PER_WARP><<<grid, block>>>(
    reinterpret_cast<half*>(a.data_ptr()),
    reinterpret_cast<half*>(x.data_ptr()),
    reinterpret_cast<half*>(y.data_ptr()),
    M, K
  );
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  TORCH_BINDING_COMMON_EXTENSION(hgemv_k32_f16)
  TORCH_BINDING_COMMON_EXTENSION(hgemv_k128_f16x4)
  TORCH_BINDING_COMMON_EXTENSION(hgemv_k16_f16)
}
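
// Usage sketch (assumption: this file is built as a PyTorch C++/CUDA extension; the
// extension name, source filename, and flags below are illustrative only, loaded here
// via torch.utils.cpp_extension.load):
//
//   import torch
//   from torch.utils.cpp_extension import load
//   lib = load(name="hgemv_lib", sources=["hgemv.cu"], extra_cuda_cflags=["-O3"])
//
//   M, K = 1024, 128
//   a = torch.randn(M, K, dtype=torch.half, device="cuda")
//   x = torch.randn(K, 1, dtype=torch.half, device="cuda")
//   y = torch.zeros(M, 1, dtype=torch.half, device="cuda")
//   lib.hgemv_k128_f16x4(a, x, y)   # y ~= a @ x
//
// The bound functions write the result into the preallocated y tensor in place.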