@@ -47,17 +47,19 @@ def per_channel_quant(x: torch.Tensor, n_bits: int, dtype: torch.dtype):
 @triton.autotune(
     configs=[
         triton.Config({
-            'BLOCK_N': 64,
+            'BLOCK_M': 128,
+            'BLOCK_N': 256,
             'BLOCK_K': 128,
         },
-                      num_stages=4,
-                      num_warps=4),
+                      num_stages=3,
+                      num_warps=8),
         triton.Config({
+            'BLOCK_M': 256,
             'BLOCK_N': 128,
             'BLOCK_K': 128,
         },
-                      num_stages=4,
-                      num_warps=4)
+                      num_stages=3,
+                      num_warps=8)
     ],
     key=['N', 'K'],
     warmup=5,
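
Note: for context, a minimal self-contained sketch (not part of this change; kernel and names are illustrative only) of how triton.autotune drives configs like the ones above. Every entry in configs is compiled and timed the first time the kernel runs with a new combination of the key arguments (N and K in this file), and the fastest config is cached and reused for subsequent calls with the same key.

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        # each Config bundles tile sizes with launch parameters, as above
        triton.Config({'BLOCK_SIZE': 1024}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE': 2048}, num_stages=3, num_warps=8),
    ],
    # re-tune whenever n_elements changes; otherwise the cached winner is used
    key=['n_elements'],
)
@triton.jit
def _scale_kernel(x_ptr, out_ptr, scale, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * scale, mask=mask)


def scale(x: torch.Tensor, s: float) -> torch.Tensor:
    # assumes a contiguous CUDA tensor
    out = torch.empty_like(x)
    n = x.numel()
    # BLOCK_SIZE is chosen by the autotuner, so the grid must read it from META
    grid = lambda META: (triton.cdiv(n, META['BLOCK_SIZE']), )
    _scale_kernel[grid](x, out, s, n)
    return out
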
@@ -112,7 +114,7 @@ def _linear(
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=None)
         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=None)
-        accumulator += tl.dot(a, b)
+        accumulator = tl.dot(a, b, accumulator, out_dtype=ACCUMULATOR_DTYPE)
         a_ptrs += BLOCK_K * stride_ak
         b_ptrs += BLOCK_K * stride_bk
     c = accumulator.to(tl.float32)
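
Note: a standalone sketch of the tl.dot form adopted here (assumptions: row-major contiguous inputs, dimensions divisible by the block sizes, no masking; names are made up for this example). Passing the running accumulator as the third argument together with an explicit out_dtype keeps accumulation in ACCUMULATOR_DTYPE, which lets the same kernel body accumulate in tl.int32 for int8 inputs and tl.float32 for fp8/fp16 inputs.

import torch
import triton
import triton.language as tl


@triton.jit
def _tiny_matmul(a_ptr, b_ptr, c_ptr, M, N, K,
                 BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                 BLOCK_K: tl.constexpr, ACC_DTYPE: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_DTYPE)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_ptr + offs_m[:, None] * K + (k * BLOCK_K + offs_k)[None, :])
        b = tl.load(b_ptr + (k * BLOCK_K + offs_k)[:, None] * N + offs_n[None, :])
        # same result as `acc += tl.dot(a, b)` for floating-point inputs, but
        # the accumulator/out_dtype form also supports integer accumulation
        acc = tl.dot(a, b, acc, out_dtype=ACC_DTYPE)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)


def tiny_matmul_int8(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # a: (M, K) int8, b: (K, N) int8, both contiguous; returns an int32 result
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), dtype=torch.int32, device=a.device)
    grid = (triton.cdiv(M, 32), triton.cdiv(N, 32))
    _tiny_matmul[grid](a, b, c, M, N, K,
                       BLOCK_M=32, BLOCK_N=32, BLOCK_K=32, ACC_DTYPE=tl.int32)
    return c
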
@@ -131,17 +133,19 @@ def _linear(
 @triton.autotune(
     configs=[
         triton.Config({
-            'BLOCK_N': 64,
+            'BLOCK_M': 128,
+            'BLOCK_N': 256,
             'BLOCK_K': 128,
         },
-                      num_stages=4,
-                      num_warps=4),
+                      num_stages=3,
+                      num_warps=8),
         triton.Config({
+            'BLOCK_M': 256,
             'BLOCK_N': 128,
             'BLOCK_K': 128,
         },
-                      num_stages=4,
-                      num_warps=4)
+                      num_stages=3,
+                      num_warps=8)
     ],
     key=['N', 'K'],
     warmup=5,
@@ -181,7 +185,7 @@ def _linear_add(A, B, C, residual_ptr, M, N, K, stride_am, stride_ak,
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=None)
         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=None)
-        accumulator += tl.dot(a, b)
+        accumulator = tl.dot(a, b, accumulator, out_dtype=ACCUMULATOR_DTYPE)
         a_ptrs += BLOCK_K * stride_ak
         b_ptrs += BLOCK_K * stride_bk
     c = accumulator.to(tl.float32)
@@ -226,13 +230,10 @@ def matmul_kernel_dynamic_quant(a,
         assert residual.is_contiguous()
     c = a.new_empty(c_shape, dtype=output_dtype)
     accumulator_dtype = tl.float32 if a.is_floating_point() else tl.int32
-    BLOCK_M = 128
-    if M < BLOCK_M:
-        BLOCK_M = triton.next_power_of_2(M)
-        BLOCK_M = max(BLOCK_M, 16)
 
     def grid(META):
-        return (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, META['BLOCK_N']), )
+        return (triton.cdiv(M, META['BLOCK_M']) *
+                triton.cdiv(N, META['BLOCK_N']), )
 
     kernel_meta = get_kernel_meta(a)
     if residual is not None:
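
Note: since BLOCK_M is now owned by the autotuner rather than computed on the host, the launch grid has to stay a callable: Triton evaluates it after a config has been selected and passes the chosen meta-parameters in META. A small self-contained illustration of that flattened one-dimensional tile grid (the trivial kernel and names are invented for this sketch, not taken from the file):

import torch
import triton
import triton.language as tl


@triton.jit
def _fill_tiles(out_ptr, M, N, value,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # one flattened program id per (BLOCK_M, BLOCK_N) tile, as in the kernels above
    pid = tl.program_id(0)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    vals = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + value
    tl.store(out_ptr + offs_m[:, None] * N + offs_n[None, :], vals, mask=mask)


def fill(M: int, N: int, value: float) -> torch.Tensor:
    out = torch.empty(M, N, device='cuda')

    def grid(META):
        # evaluated at launch time; with @triton.autotune the BLOCK_* values
        # here would come from the winning config instead of the call below
        return (triton.cdiv(M, META['BLOCK_M']) *
                triton.cdiv(N, META['BLOCK_N']), )

    _fill_tiles[grid](out, M, N, value, BLOCK_M=64, BLOCK_N=64)
    return out
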
@@ -249,7 +250,6 @@ def grid(META):
                           b.stride(0),
                           c.stride(-2),
                           c.stride(-1),
-                          BLOCK_M=BLOCK_M,
                           GROUP_SIZE_M=8,
                           rms_scale_ptr=rms_scale,
                           linear_scale_ptr=linear_scale,
@@ -268,7 +268,6 @@ def grid(META):
                       b.stride(0),
                       c.stride(-2),
                       c.stride(-1),
-                      BLOCK_M=BLOCK_M,
                       GROUP_SIZE_M=8,
                       rms_scale_ptr=rms_scale,
                       linear_scale_ptr=linear_scale,
@@ -585,26 +584,11 @@ def per_token_quant_int8_torch(x, eps, quant_dtype):
                           x_s_torch.flatten().to(torch.float32)))
 
 
-@triton.testing.perf_report(
-    triton.testing.Benchmark(
-        x_names=['M'],
-        x_vals=[1, 16, 32, 64, 128, 256] + [512 * i * 2 for i in range(1, 17)],
-        line_arg='provider',
-        line_vals=['int8_dynamic_triton_op', 'float_torch'],
-        line_names=['int8_dynamic_triton_op', 'float_torch'],
-        styles=[('blue', '-'), ('green', '-'), ('orange', '-'),
-                ('yellow', '-'), ('yellow', '-')],
-        ylabel='GB/s',
-        plot_name='forward',
-        args={
-            'dtype': torch.float16,
-        }))
-def bench_rms_and_linear(M,
-                         dtype,
-                         provider,
-                         eps=1e-5,
-                         device='cuda',
-                         quant_dtype=torch.int8):
+def bench_rms_and_linear(M: int,
+                         provider: str,
+                         dtype: torch.dtype = torch.float16,
+                         eps: float = 1e-5):
+    """benchmark rms and linear."""
 
     def rms_norm_torch(x, w, eps):
         variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
@@ -616,6 +600,7 @@ def linear_torch(x, b):
 
     N = 4096
     K = 4096
+
     x_shape = (M, K)
     rms_w_shape = (x_shape[-1], )
     rms_weight = torch.randn(rms_w_shape,
@@ -627,17 +612,29 @@ def linear_torch(x, b):
                              dtype=dtype,
                              device='cuda',
                              requires_grad=True)
-    linear_weight_quant, linear_scale = per_channel_quant(
-        linear_weight, 8, quant_dtype)
 
-    alpha = max(x.max().abs(), x.min().abs())
-    if quant_dtype.is_floating_point:
-        qdtype_info = torch.finfo(quant_dtype)
-    else:
-        qdtype_info = torch.iinfo(quant_dtype)
-    rms_scale = alpha / qdtype_info.max
+    if provider == 'torch_fp16':
+        rms_out_torch = rms_norm_torch(x, rms_weight, eps).half()
 
-    if provider == 'int8_dynamic_triton_op':
+        def y_fwd():
+            linear_torch(rms_out_torch, linear_weight)
+    else:
+        if provider == 'triton_int8':
+            quant_dtype = torch.int8
+        elif provider == 'triton_fp8_e4m3':
+            quant_dtype = torch.float8_e4m3fn
+        elif provider == 'triton_fp8_e5m2':
+            quant_dtype = torch.float8_e5m2
+
+        linear_weight_quant, linear_scale = per_channel_quant(
+            linear_weight, 8, quant_dtype)
+
+        alpha = max(x.max().abs(), x.min().abs())
+        if quant_dtype.is_floating_point:
+            qdtype_info = torch.finfo(quant_dtype)
+        else:
+            qdtype_info = torch.iinfo(quant_dtype)
+        rms_scale = alpha / qdtype_info.max
         rms_out, rms_scale = rms_norm_dynamic_quant(x,
                                                     rms_weight,
                                                     eps,
@@ -650,17 +647,16 @@ def y_fwd():
                                         rms_scale,
                                         linear_scale,
                                         output_dtype=dtype)
-    elif provider == 'float_torch':
-        rms_out_torch = rms_norm_torch(x, rms_weight, eps).half()
-
-        def y_fwd():
-            linear_torch(rms_out_torch, linear_weight)
 
     quantiles = [0.5, 0.2, 0.8]
     ms, min_ms, max_ms = triton.testing.do_bench(y_fwd,
                                                  quantiles=quantiles,
                                                  rep=500)
-    return ms, max_ms, min_ms
+
+    def perf(ms):
+        return 2 * M * N * K * 1e-12 / (ms * 1e-3)
+
+    return perf(ms), perf(max_ms), perf(min_ms)
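
Note: the new return value converts latency into throughput. A GEMM of shape M x K by K x N performs 2*M*N*K multiply-add operations, so dividing by the measured time in seconds and scaling by 1e-12 yields TFLOPS. A quick sanity check of the formula with illustrative numbers (not taken from a real run):

# 2 * M * N * K for M = N = K = 4096 is ~1.374e11 operations,
# so a 1.0 ms run corresponds to roughly 137 TFLOPS.
M = N = K = 4096
ms = 1.0
tflops = 2 * M * N * K * 1e-12 / (ms * 1e-3)
print(f'{tflops:.1f}')  # ~137.4
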
 
 
 if __name__ == '__main__':
@@ -720,4 +716,26 @@ def y_fwd():
     test_per_token_quant(x, eps, quant_dtype=torch.float8_e4m3fn)
     test_per_token_quant(x, eps, quant_dtype=torch.float8_e5m2)
 
-    bench_rms_and_linear.run(print_data=True)
+    # benchmark triton kernels
+    line_vals = ['triton_int8', 'torch_fp16']
+    line_names = ['triton_int8', 'torch_fp16']
+
+    if is_fp8_supported:
+        line_vals += ['triton_fp8_e4m3', 'triton_fp8_e5m2']
+        line_names += ['triton_fp8_e4m3', 'triton_fp8_e5m2']
+    config = triton.testing.Benchmark(x_names=['M'],
+                                      x_vals=[1, 16, 32, 64, 128, 256] +
+                                      [512 * i * 2 for i in range(1, 5)],
+                                      line_arg='provider',
+                                      line_vals=line_vals,
+                                      line_names=line_names,
+                                      styles=[('blue', '-'), ('green', '-'),
+                                              ('orange', '-'), ('black', '-'),
+                                              ('yellow', '-')],
+                                      ylabel='TFLOPS',
+                                      plot_name='bench-triton',
+                                      args={
+                                          'dtype': torch.float16,
+                                      })
+    bench_funch = (triton.testing.perf_report(config))(bench_rms_and_linear)
+    bench_funch.run(print_data=True)
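
Note: applying triton.testing.perf_report as a plain call instead of a decorator is what lets the benchmark config, in particular the fp8 providers, be assembled at runtime after the is_fp8_supported check; perf_report(config) simply returns a decorator. A minimal self-contained sketch of the same pattern (the benchmark function and provider names below are invented for illustration):

import torch
import triton


def bench_copy(M, provider, dtype=torch.float16):
    x = torch.randn(M, 4096, dtype=dtype, device='cuda')
    ms, min_ms, max_ms = triton.testing.do_bench(lambda: x.clone(),
                                                 quantiles=[0.5, 0.2, 0.8])
    return ms, max_ms, min_ms


providers = ['torch_clone']  # extend conditionally, e.g. after device checks
config = triton.testing.Benchmark(x_names=['M'],
                                  x_vals=[256, 1024, 4096],
                                  line_arg='provider',
                                  line_vals=providers,
                                  line_names=providers,
                                  styles=[('blue', '-')],
                                  ylabel='ms',
                                  plot_name='bench-copy',
                                  args={'dtype': torch.float16})
# decorator applied after the fact, the same pattern as the __main__ block above
triton.testing.perf_report(config)(bench_copy).run(print_data=True)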