#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])

// -------------------------------------- FP32 --------------------------------------
// Warp Reduce Sum
// ... (reduce helpers and preceding kernels elided; below is the tail of
// layer_norm_f16_f32_kernel) ...
  }
}

template<const int NUM_THREADS = 256>
__global__ void layer_norm_f16x8_pack_f16_kernel(half* x, half* y, float g, float b, int N, int K) {
  int tid = threadIdx.x; // 0..(K/8)-1
  int bid = blockIdx.x;  // 0..N-1
  int idx = (bid * blockDim.x + threadIdx.x) * 8;
  const half epsilon = __float2half(1e-5f);
  const half g_ = __float2half(g);
  const half b_ = __float2half(b);
  const half K_ = __int2half_rn(K);
  const half z_ = __float2half(0.0f);

  __shared__ half s_mean;     // shared within block
  __shared__ half s_variance; // shared within block
  // temporary registers (may spill to addressable .local space in PTX)
  half pack_x[8], pack_y[8]; // 8 x 16 bits = 128 bits.
  // reinterpret as float4 and load 128 bits in one memory issue.
  LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]); // load 128 bits

  half value = z_;
  #pragma unroll
  for (int i = 0; i < 8; ++i) {
    value += ((idx + i) < N * K ? pack_x[i] : z_);
  }
  half sum = block_reduce_sum_f16_f16<NUM_THREADS>(value);
  if (tid == 0) s_mean = sum / K_;
  // wait for s_mean in shared memory to be ready for all threads
  __syncthreads();

  half variance = z_;
  #pragma unroll
  for (int i = 0; i < 8; ++i) {
    half v_hat = pack_x[i] - s_mean;
    variance += ((idx + i) < N * K ? v_hat * v_hat : z_);
  }
  variance = block_reduce_sum_f16_f16<NUM_THREADS>(variance);
  // epsilon belongs outside the division: 1/sqrt(var/K + eps)
  if (tid == 0) s_variance = hrsqrt(variance / K_ + epsilon);
  // wait for s_variance in shared memory to be ready for all threads
  __syncthreads();

  #pragma unroll
  for (int i = 0; i < 8; ++i) {
    // TODO: use __hfma2, __hsub2, __hmul2 here
    pack_y[i] = __hfma((pack_x[i] - s_mean) * s_variance, g_, b_);
  }
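  // A hedged sketch (not in the original) of the half2 variant the TODO above
  // points at: view pack_x/pack_y as four half2 lanes and fuse the normalize
  // with __hsub2/__hmul2/__hfma2. Assumes the same names and layout as above.
  //   half2 mean2 = __half2half2(s_mean), var2 = __half2half2(s_variance);
  //   half2 g2 = __half2half2(g_), b2 = __half2half2(b_);
  //   #pragma unroll
  //   for (int i = 0; i < 4; ++i) {
  //     half2 v = reinterpret_cast<half2*>(pack_x)[i];
  //     reinterpret_cast<half2*>(pack_y)[i] =
  //         __hfma2(__hmul2(__hsub2(v, mean2), var2), g2, b2);
  //   }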
  // reinterpret as float4 and store 128 bits in one memory issue.
  if ((idx + 7) < N * K) { LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]); }
  // TODO: support K that is not a multiple of 8
}
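
// A hedged sketch (not in the original) for the non-multiple-of-8 TODO above:
// keep the 128-bit store for full packs and fall back to guarded scalar stores
// for a partial tail pack; the unguarded load at the top would need the same.
//   if ((idx + 7) < N * K) {
//     LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]);
//   } else {
//     #pragma unroll
//     for (int i = 0; i < 8; ++i) {
//       if ((idx + i) < N * K) y[idx + i] = pack_y[i];
//     }
//   }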

// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
  // ... (macro body elided) ...

// ... (launch macro for layer_norm_f32_kernel elided) ...

#define DISPATCH_LAYER_NORM_F32_KERNEL(N, K) \
  dim3 block((K)); \
  dim3 grid((N)); \
  switch ((K)) \
  { \
  case 64: \
    // ... (remaining cases elided) ...

// ... (launch macro for layer_norm_f32x4_kernel elided) ...

#define DISPATCH_LAYER_NORM_F32x4_KERNEL(N, K) \
  dim3 block((K)/4); \
  dim3 grid((N)); \
  switch ((K)) \
  { \
  case 64: \
    LANUCH_LAYER_NORM_F32x4_KERNEL(64) \
    break; \
  // ... (cases 128/256/512 follow the same pattern) ...
  case 1024: \
    LANUCH_LAYER_NORM_F32x4_KERNEL(1024) \
    break; \
  case 2048: \
    LANUCH_LAYER_NORM_F32x4_KERNEL(2048) \
    break; \
  case 4096: \
    LANUCH_LAYER_NORM_F32x4_KERNEL(4096) \
    break; \
  default: \
    throw std::runtime_error( \
      "only support K: 64/128/.../4096"); \
    break; \
  }

// ... (launch macro for layer_norm_f16_f16_kernel elided) ...

#define DISPATCH_LAYER_NORM_F16F16_KERNEL(N, K) \
  dim3 block((K)); \
  dim3 grid((N)); \
  switch ((K)) \
  { \
  case 64: \
    // ... (remaining cases elided) ...

// ... (launch macro for layer_norm_f16_f32_kernel elided) ...

#define DISPATCH_LAYER_NORM_F16F32_KERNEL(N, K) \
  dim3 block((K)); \
  dim3 grid((N)); \
  switch ((K)) \
  { \
  case 64: \
    // ... (remaining cases elided) ...

// ... (launch macro for layer_norm_f16x2_f16_kernel elided) ...

#define DISPATCH_LAYER_NORM_F16x2F16_KERNEL(N, K) \
  dim3 block((K)/2); \
  dim3 grid((N)); \
  switch ((K)) \
  { \
  case 64: \
    LANUCH_LAYER_NORM_F16x2F16_KERNEL(64) \
    break; \
  // ... (cases 128/256/512 follow the same pattern) ...
  case 1024: \
    LANUCH_LAYER_NORM_F16x2F16_KERNEL(1024) \
    break; \
  case 2048: \
    LANUCH_LAYER_NORM_F16x2F16_KERNEL(2048) \
    break; \
  default: \
    throw std::runtime_error( \
      "only support K: 64/128/.../2048"); \
    break; \
  }

// ... (launch macro for layer_norm_f16x8_f16_kernel elided) ...

#define DISPATCH_LAYER_NORM_F16x8F16_KERNEL(N, K) \
  dim3 block((K)/8); \
  dim3 grid((N)); \
  switch ((K)) \
  { \
  case 64: \
    LANUCH_LAYER_NORM_F16x8F16_KERNEL(64) \
    break; \
  // ... (cases 128/256/512 follow the same pattern) ...
  case 1024: \
    LANUCH_LAYER_NORM_F16x8F16_KERNEL(1024) \
    break; \
  case 2048: \
    LANUCH_LAYER_NORM_F16x8F16_KERNEL(2048) \
    break; \
  case 4096: \
    LANUCH_LAYER_NORM_F16x8F16_KERNEL(4096) \
    break; \
  case 8192: \
    LANUCH_LAYER_NORM_F16x8F16_KERNEL(8192) \
    break; \
  default: \
    throw std::runtime_error( \
      "only support K: 64/128/.../8192"); \
    break; \
  }

#define LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(K) \
  layer_norm_f16x8_pack_f16_kernel<(K)/8><<<grid, block>>>( \
    reinterpret_cast<half*>(x.data_ptr()), \
    reinterpret_cast<half*>(y.data_ptr()), \
    g, b, N, (K));

#define DISPATCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(N, K) \
  dim3 block((K)/8); \
  dim3 grid((N)); \
  switch ((K)) \
  { \
  case 64: \
    LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(64) \
    break; \
  case 128: \
    LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(128) \
    break; \
  case 256: \
    LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(256) \
    break; \
  case 512: \
    LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(512) \
    break; \
  case 1024: \
    LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(1024) \
    break; \
  case 2048: \
    LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(2048) \
    break; \
  case 4096: \
    LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(4096) \
    break; \
  case 8192: \
    LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(8192) \
    break; \
  default: \
    throw std::runtime_error( \
      "only support K: 64/128/.../8192"); \
    break; \
  }
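
// For example (illustrative, not in the original): with K = 256 the dispatch
// above expands to layer_norm_f16x8_pack_f16_kernel<32><<<N, 32>>>(...), i.e.
// one block of 32 threads per row, each thread packing 8 halves.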

void layer_norm_f16_f16(torch::Tensor x, torch::Tensor y, float g, float b) {
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
  CHECK_TORCH_TENSOR_SHAPE(x, y)
  const int N = x.size(0);
  const int K = x.size(1);
  DISPATCH_LAYER_NORM_F16F16_KERNEL(N, K)
}

// ... (layer_norm_f16x2_f16 elided) ...

void layer_norm_f16x8_f16(torch::Tensor x, torch::Tensor y, float g, float b) {
  // ... (same dtype/shape checks and N/K as above) ...
  DISPATCH_LAYER_NORM_F16x8F16_KERNEL(N, K)
}

void layer_norm_f16x8_pack_f16(torch::Tensor x, torch::Tensor y, float g, float b) {
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
  CHECK_TORCH_TENSOR_SHAPE(x, y)
  const int N = x.size(0);
  const int K = x.size(1);
  DISPATCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(N, K)
}
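
// Note (assumption, not stated in the original): these wrappers reinterpret
// data_ptr() as half*, so x and y are expected to be contiguous 2-D (N, K)
// half tensors on the GPU.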
void layer_norm_f16_f32(torch::Tensor x, torch::Tensor y, float g, float b) {
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
  CHECK_TORCH_TENSOR_SHAPE(x, y)
  const int N = x.size(0);
  const int K = x.size(1);
  DISPATCH_LAYER_NORM_F16F32_KERNEL(N, K)
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // ... (f32 bindings elided) ...
  TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16_f16)
  TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16x2_f16)
  TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16x8_f16)
  TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16x8_pack_f16)
  TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16_f32)
}