#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
+#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])

// -------------------------------------- FP32 --------------------------------------
// Warp Reduce Sum
@@ -123,7 +124,7 @@ __global__ void dot_prod_f16_f32_kernel(half* a, half* b, float* y, int N) {
  if (tid == 0) atomicAdd(y, prod);
}

-template<const int NUM_THREADS = 256>
+template<const int NUM_THREADS = 256/2>
__global__ void dot_prod_f16x2_f32_kernel(half* a, half* b, float* y, int N) {
  int tid = threadIdx.x;
  int idx = (blockIdx.x * NUM_THREADS + tid) * 2; // 2 half elements per thread
@@ -148,6 +149,38 @@ __global__ void dot_prod_f16x2_f32_kernel(half* a, half* b, float* y, int N) {
  if (tid == 0) atomicAdd(y, prod);
}

+template<const int NUM_THREADS = 256/8>
+__global__ void dot_prod_f16x8_pack_f32_kernel(half* a, half* b, float* y, int N) {
+  int tid = threadIdx.x;
+  int idx = (blockIdx.x * NUM_THREADS + tid) * 8; // 8 half elements per thread
+  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
+  __shared__ float reduce_smem[NUM_WARPS];
+  // per-thread temporary storage (addressable, .local space in PTX)
+  half pack_a[8], pack_b[8]; // 8 x 16 bits = 128 bits
+  LDST128BITS(pack_a[0]) = LDST128BITS(a[idx]); // load 128 bits
+  LDST128BITS(pack_b[0]) = LDST128BITS(b[idx]); // load 128 bits
+  const half z = __float2half(0.0f);
+
+  half prod_f16 = z;
+  #pragma unroll
+  for (int i = 0; i < 8; i += 2) {
+    half2 v = __hmul2(HALF2(pack_a[i]), HALF2(pack_b[i]));
+    prod_f16 += (((idx + i) < N) ? (v.x + v.y) : z);
+  }
+
+  int warp = tid / WARP_SIZE;
+  int lane = tid % WARP_SIZE;
+  // perform a warp-synchronous reduction.
+  float prod = warp_reduce_sum_f16_f32<WARP_SIZE>(prod_f16);
+  // warp leaders store their partial sums to shared memory.
+  if (lane == 0) reduce_smem[warp] = prod;
+  __syncthreads(); // make sure all partial sums are in shared memory.
+  // the first warp computes the final sum.
+  prod = (lane < NUM_WARPS) ? reduce_smem[lane] : 0.0f;
+  if (warp == 0) prod = warp_reduce_sum_f32<NUM_WARPS>(prod);
+  if (tid == 0) atomicAdd(y, prod);
+}
+
// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
@@ -159,8 +192,42 @@ if(((T).options().dtype() != (th_type))) { \
  throw std::runtime_error("values must be " #th_type); \
}

-#define CHECK_TORCH_TENSOR_SHAPE(T, S0) \
-if (((T).size(0) != (S0))) { throw std::runtime_error("Tensor size mismatch!"); }
+#define LANUCH_DOT_PROD_KERNEL(NT, packed_type, acc_type, element_type) \
+  dot_prod_##packed_type##_##acc_type##_kernel<(NT)><<<grid, block>>>( \
+    reinterpret_cast<element_type*>(a.data_ptr()), \
+    reinterpret_cast<element_type*>(b.data_ptr()), \
+    prod.data_ptr<float>(), N);
+
+#define DISPATCH_DOT_PROD_KERNEL(K, packed_type, acc_type, element_type, n_elements) \
+  const int NT = (K)/(n_elements); \
+  dim3 block(NT); \
+  dim3 grid((S)); \
+  switch (NT) \
+  { \
+    case 32: \
+      LANUCH_DOT_PROD_KERNEL(32, packed_type, acc_type, element_type) \
+      break; \
+    case 64: \
+      LANUCH_DOT_PROD_KERNEL(64, packed_type, acc_type, element_type) \
+      break; \
+    case 128: \
+      LANUCH_DOT_PROD_KERNEL(128, packed_type, acc_type, element_type) \
+      break; \
+    case 256: \
+      LANUCH_DOT_PROD_KERNEL(256, packed_type, acc_type, element_type) \
+      break; \
+    case 512: \
+      LANUCH_DOT_PROD_KERNEL(512, packed_type, acc_type, element_type) \
+      break; \
+    case 1024: \
+      LANUCH_DOT_PROD_KERNEL(1024, packed_type, acc_type, element_type) \
+      break; \
+    default: \
+      throw std::runtime_error( \
+        "only support (K)/(n_elements): 32/64/128/256/512/1024"); \
+      break; \
+  }
+

#define TORCH_BINDING_DOT_PROD(packed_type, acc_type, th_type, element_type, n_elements) \
torch::Tensor dot_prod_##packed_type##_##acc_type(torch::Tensor a, torch::Tensor b) { \
@@ -169,30 +236,49 @@ torch::Tensor dot_prod_##packed_type##_##acc_type(torch::Tensor a, torch::Tensor
  auto options = torch::TensorOptions().dtype(torch::kFloat32).device( \
    torch::kCUDA, 0); \
  auto prod = torch::zeros({1}, options); \
-  const int N = a.size(0); \
-  CHECK_TORCH_TENSOR_SHAPE(b, N) \
-  static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
-  const int NUM_BLOCKS = (N + 256 - 1) / 256; \
-  dim3 block(NUM_THREADS_PER_BLOCK); \
-  dim3 grid(NUM_BLOCKS); \
-  dot_prod_##packed_type##_##acc_type##_kernel< \
-    NUM_THREADS_PER_BLOCK><<<grid, block>>>( \
+  const int ndim = a.dim(); \
+  if (ndim != 2) { \
+    int N = 1; \
+    for (int i = 0; i < ndim; ++i) { N *= a.size(i); } \
+    dim3 block(256); \
+    dim3 grid(((N + 256 - 1) / 256) / (n_elements)); \
+    dot_prod_##packed_type##_##acc_type##_kernel< \
+      256><<<grid, block>>>( \
    reinterpret_cast<element_type*>(a.data_ptr()), \
    reinterpret_cast<element_type*>(b.data_ptr()), \
    prod.data_ptr<float>(), N); \
+  } else { \
+    const int S = a.size(0); \
+    const int K = a.size(1); \
+    const int N = S * K; \
+    if ((K/(n_elements)) <= 1024) { \
+      DISPATCH_DOT_PROD_KERNEL(K, packed_type, acc_type, element_type, n_elements) \
+    } else { \
+      int N = 1; \
+      for (int i = 0; i < ndim; ++i) { N *= a.size(i); } \
+      dim3 block(256); \
+      dim3 grid(((N + 256 - 1) / 256) / (n_elements)); \
+      dot_prod_##packed_type##_##acc_type##_kernel< \
+        256><<<grid, block>>>( \
+          reinterpret_cast<element_type*>(a.data_ptr()), \
+          reinterpret_cast<element_type*>(b.data_ptr()), \
+          prod.data_ptr<float>(), N); \
+    } \
+  } \
  return prod; \
}

// packed_type, acc_type, th_type, element_type, n_elements_per_pack
-TORCH_BINDING_DOT_PROD(f32,   f32, torch::kFloat32, float, 1)
-TORCH_BINDING_DOT_PROD(f32x4, f32, torch::kFloat32, float, 4)
-TORCH_BINDING_DOT_PROD(f16,   f32, torch::kHalf,    half,  1)
-TORCH_BINDING_DOT_PROD(f16x2, f32, torch::kHalf,    half,  2)
-
+TORCH_BINDING_DOT_PROD(f32,        f32, torch::kFloat32, float, 1)
+TORCH_BINDING_DOT_PROD(f32x4,      f32, torch::kFloat32, float, 4)
+TORCH_BINDING_DOT_PROD(f16,        f32, torch::kHalf,    half,  1)
+TORCH_BINDING_DOT_PROD(f16x2,      f32, torch::kHalf,    half,  2)
+TORCH_BINDING_DOT_PROD(f16x8_pack, f32, torch::kHalf,    half,  8)

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  TORCH_BINDING_COMMON_EXTENSION(dot_prod_f32_f32)
  TORCH_BINDING_COMMON_EXTENSION(dot_prod_f32x4_f32)
  TORCH_BINDING_COMMON_EXTENSION(dot_prod_f16_f32)
  TORCH_BINDING_COMMON_EXTENSION(dot_prod_f16x2_f32)
+  TORCH_BINDING_COMMON_EXTENSION(dot_prod_f16x8_pack_f32)
}
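
For reference, a minimal Python usage sketch (not part of this diff; the source file name, module name, and build flags are assumptions) that JIT-builds the extension and calls the new f16x8 packed binding:

import torch
from torch.utils.cpp_extension import load

# Hypothetical build step; "dot_product.cu" is an assumed name for the file this diff touches.
lib = load(name="dot_prod_lib", sources=["dot_product.cu"],
           extra_cuda_cflags=["-O3"])

# Contiguous fp16 inputs; the f16x8 kernel reads 8 halves (128 bits) per thread.
a = torch.randn(1024, 1024, device="cuda", dtype=torch.half).contiguous()
b = torch.randn(1024, 1024, device="cuda", dtype=torch.half).contiguous()

out = lib.dot_prod_f16x8_pack_f32(a, b)  # 1-element fp32 CUDA tensor
ref = (a.float() * b.float()).sum()
print(out.item(), ref.item())            # should agree up to fp16 rounding error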