@@ -60,8 +60,10 @@ def per_channel_quant(x: torch.Tensor, n_bits: int, dtype: torch.dtype):
                   num_warps=4)
     ],
     key=['N', 'K'],
+    warmup=5,
+    rep=20,
 )
-@triton.jit
+@triton.jit(do_not_specialize=['M'])
 def _linear(
     A,
     B,
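
The two decorator changes work together: `warmup=5, rep=20` caps how long
`triton.autotune` benchmarks each config, and `do_not_specialize=['M']` keeps
the JIT from specializing the kernel on the token count `M`, so a new sequence
length reuses the cached binary instead of recompiling (and re-autotuning).
A minimal, self-contained sketch of the specialization knob (`_scale_kernel`
is a toy kernel, not part of this diff):

    import torch
    import triton
    import triton.language as tl

    @triton.jit(do_not_specialize=['n_elements'])
    def _scale_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
        # n_elements stays a runtime argument; without do_not_specialize,
        # every distinct value could trigger a separate compilation.
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n_elements
        x = tl.load(x_ptr + offs, mask=mask)
        tl.store(out_ptr + offs, x * 2, mask=mask)

    x = torch.randn(1000, device='cuda')
    out = torch.empty_like(x)
    # Launches with different n_elements now share one compiled kernel.
    _scale_kernel[(triton.cdiv(x.numel(), 256), )](x, out, x.numel(), BLOCK=256)
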
@@ -142,8 +144,10 @@ def _linear(
                   num_warps=4)
     ],
     key=['N', 'K'],
+    warmup=5,
+    rep=20,
 )
-@triton.jit
+@triton.jit(do_not_specialize=['M'])
 def _linear_add(A, B, C, residual_ptr, M, N, K, stride_am, stride_ak,
                 stride_bk, stride_bn, stride_cm, stride_cn,
                 BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
@@ -281,7 +285,8 @@ def _per_token_quant_int8(
     y_ptr,
     y_q_ptr,
     y_s_ptr,
-    y_stride,
+    y_stride: tl.constexpr,
+    yq_stride: tl.constexpr,
     N,  # number of columns in X
     eps: tl.constexpr,  # epsilon to avoid division by zero
     BLOCK: tl.constexpr,
@@ -296,7 +301,7 @@ def _per_token_quant_int8(
     # Map the program id to the row of X and Y it should compute.
     row = tl.program_id(0)
     y_ptr += row * y_stride
-    y_q_ptr += row * y_stride
+    y_q_ptr += row * yq_stride
     y_s_ptr += row

     cols = tl.arange(0, BLOCK)  # N <= BLOCK
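
The separate `yq_stride` matters once the wrapper may pass a non-contiguous
view: the input's row stride and the freshly allocated output's row stride no
longer have to agree. A small illustration of how they diverge (shapes are
assumed, not from this diff):

    import torch

    x_full = torch.randn(4, 128, device='cuda')
    x = x_full[:, :100]  # column slice: rows still stride by 128 elements
    # Non-dense input, so empty_like allocates a contiguous buffer whose
    # rows stride by 100 elements.
    x_q = torch.empty_like(x, dtype=torch.int8)
    print(x.stride(-2), x_q.stride(-2))  # 128 100
    print(x.stride(-1))  # 1, so per-row addressing in the kernel stays valid
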
@@ -333,15 +338,20 @@ def per_token_quant_int8(x, eps, quant_dtype=torch.int8):
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
+
+    if x.dim() > 2:
+        x = x.flatten(0, -2)
+    assert x.stride(-1) == 1
     # enqueue kernel
     kernel_meta = get_kernel_meta(x)
     _per_token_quant_int8[(M, )](
         x,
         x_q,
         x_s,
-        x.stride(-2),
-        N,
-        eps,
+        y_stride=x.stride(-2),
+        yq_stride=x_q.stride(-2),
+        N=N,
+        eps=eps,
         BLOCK=BLOCK,
         Q_MAX=q_max,
         IS_FLOATING_POINT=quant_dtype.is_floating_point,
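
A hedged usage sketch of the updated wrapper (assuming it returns the
quantized tensor and the per-token scales shaped `x.shape[:-1] + (1,)`; the
eps value is illustrative):

    import torch

    x = torch.randn(2, 16, 4096, device='cuda', dtype=torch.float16)
    x_q, x_s = per_token_quant_int8(x, eps=1e-7)
    # Dequantize to sanity-check: each token row shares one scale.
    x_hat = x_q.float() * x_s
    print((x.float() - x_hat).abs().max())
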
@@ -352,46 +362,98 @@ def per_token_quant_int8(x, eps, quant_dtype=torch.int8):


 @triton.jit
-def _rms_norm_fwd_fused_dynamic_symmetric(
-    X,  # pointer to the input
-    Y,  # pointer to the output
-    W,  # pointer to the weights
-    Scale,  # pointer to the scales of the output activation
-    stride,  # how much to increase the pointer when moving by 1 row
-    N,  # number of columns in X
-    eps: tl.constexpr,  # epsilon to avoid division by zero
-    BLOCK_SIZE: tl.constexpr,
+def _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr):
+    """Compute RMS norm."""
+    xf = x.to(tl.float32)
+
+    var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS)
+    out = xf * tl.math.rsqrt(var + eps)
+    out = (w * out).to(x.dtype)
+    return out
+
+
+@triton.jit
+def rms_norm_quant_kernel(
+    input,
+    weight,
+    output,
+    out_scale,
+    input_row_stride: tl.constexpr,
+    eps: tl.constexpr,
+    N_COLS: tl.constexpr,
+    BLOCK_N: tl.constexpr,
     Q_MIN: tl.constexpr,
     Q_MAX: tl.constexpr,
     IS_FLOATING_POINT: tl.constexpr,
 ):
-    """A Triton kernel that calculates Root Mean Square (RMS) normalization
-    with fused dynamic symmetric quantization."""
-    row = tl.program_id(0)
-    Y += row * stride
-    X += row * stride
389+ """rms norm kernel."""
+    prog_id = tl.program_id(0)
+    offsets = tl.arange(0, BLOCK_N)

-    cols = tl.arange(0, BLOCK_SIZE)
-    mask = cols < N
-    x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
-    _var = x * x
-    var = tl.sum(_var, axis=0) / N
-    rstd = tl.math.rsqrt(var + eps)
-
-    w = tl.load(W + cols, mask=mask)
-    x_hat = x * rstd
-    y = x_hat * w
-
-    scale = tl.max(tl.abs(y)).to(tl.float32) / Q_MAX
-    tl.store(Scale + row, scale)
-    y = y / scale
+    w = tl.load(weight + offsets, mask=offsets < N_COLS)
+
+    x_ptr = input + prog_id * input_row_stride
+    x = tl.load(x_ptr + offsets, mask=offsets < N_COLS)
+    out = _compute_rms_norm(x, w, eps, N_COLS)
+
+    scale = tl.max(tl.abs(out)).to(tl.float32) / Q_MAX
+    out_s_ptr = out_scale + prog_id
+    tl.store(out_s_ptr, scale)
+    out = out / scale
+    if not IS_FLOATING_POINT:
+        out = tl_round(out)
+    out = tl.clamp(out, Q_MIN, Q_MAX)
+    out_ptr = output + prog_id * input_row_stride
+    tl.store(out_ptr + offsets, out, mask=offsets < N_COLS)
+
+
+@triton.jit
+def add_rms_norm_quant_kernel(
+    input,
+    weight,
+    residual,
+    output,
+    out_scale,
+    out_residual,
+    input_row_stride: tl.constexpr,
+    residual_row_stride: tl.constexpr,
+    eps: tl.constexpr,
+    N_COLS: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    Q_MIN: tl.constexpr,
+    Q_MAX: tl.constexpr,
+    IS_FLOATING_POINT: tl.constexpr,
+):
427+ """rms norm kernel."""
+    prog_id = tl.program_id(0)
+    offsets = tl.arange(0, BLOCK_N)
+
+    w = tl.load(weight + offsets, mask=offsets < N_COLS)
+
+    x_ptr = input + prog_id * input_row_stride
+    x = tl.load(x_ptr + offsets, mask=offsets < N_COLS)
+
+    res_ptr = residual + prog_id * residual_row_stride
+    res = tl.load(res_ptr + offsets, mask=offsets < N_COLS)
+
+    new_x = x + res
+    out_res_ptr = out_residual + prog_id * residual_row_stride
+    tl.store(out_res_ptr + offsets, new_x, mask=offsets < N_COLS)
+
+    out = _compute_rms_norm(new_x, w, eps, N_COLS)
+
+    scale = tl.max(tl.abs(out)).to(tl.float32) / Q_MAX
+    out_s_ptr = out_scale + prog_id
+    tl.store(out_s_ptr, scale)
+    out = out / scale
     if not IS_FLOATING_POINT:
-        y = tl_round(y)
-    y = tl.clamp(y, Q_MIN, Q_MAX)
-    tl.store(Y + cols, y, mask=mask)
+        out = tl_round(out)
+    out = tl.clamp(out, Q_MIN, Q_MAX)
+    out_ptr = output + prog_id * input_row_stride
+    tl.store(out_ptr + offsets, out, mask=offsets < N_COLS)


-def rms_norm_dynamic_quant(x, w, eps, quant_dtype=torch.int8):
+def rms_norm_dynamic_quant(x, w, eps, residual=None, quant_dtype=torch.int8):
     """Performs RMS normalization with dynamic quantization.

     The function reshapes the input tensor `x`, creates an empty tensor `y`
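
Factoring the math into `_compute_rms_norm` works because a `@triton.jit`
function may call another `@triton.jit` function, which is inlined at compile
time; both kernels above thus share one normalization implementation. A toy
sketch of the pattern (not from this diff):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _double(x):
        # jit helper: inlined into any jit kernel that calls it
        return x * 2

    @triton.jit
    def _kernel(ptr, N: tl.constexpr):
        offs = tl.arange(0, N)
        v = tl.load(ptr + offs)
        tl.store(ptr + offs, _double(v))

    x = torch.arange(8, device='cuda', dtype=torch.float32)
    _kernel[(1, )](x, N=8)
    print(x)  # 0, 2, 4, ..., 14
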
@@ -401,32 +463,52 @@ def rms_norm_dynamic_quant(x, w, eps, quant_dtype=torch.int8):
     qdtype_info = torch.finfo(
         quant_dtype) if quant_dtype.is_floating_point else torch.iinfo(
             quant_dtype)
-    x_arg = x.flatten(0, -2)
     y = torch.empty_like(x, dtype=quant_dtype)
-    M, K = x_arg.shape
-    MAX_FUSED_SIZE = 65536 // x.element_size()
-    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(K))
-    if K > BLOCK_SIZE:
-        raise RuntimeError(
-            "This rms norm doesn't support feature dim >= 64KB.")
-    num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
     scale = x.new_empty(x.shape[:-1] + (1, ), dtype=torch.float32)
-    kernel_meta = get_kernel_meta(x_arg)
-    _rms_norm_fwd_fused_dynamic_symmetric[(M, )](
-        x_arg,
-        y,
-        w,
-        scale,
-        x_arg.stride(0),
-        K,
-        eps,
-        BLOCK_SIZE=BLOCK_SIZE,
-        Q_MIN=qdtype_info.min,
-        Q_MAX=qdtype_info.max,
-        IS_FLOATING_POINT=quant_dtype.is_floating_point,
-        num_warps=num_warps,
-        **kernel_meta)
-    return y, scale
+
+    feat_size = w.shape[0]
+    seq_len = x.numel() // x.size(-1)
+    input_stride = x.stride(-2)
+    BLOCK_N = triton.next_power_of_2(feat_size)
+    grid = (seq_len, )
+
+    if residual is None:
+        rms_norm_quant_kernel[grid](
+            x,
+            w,
+            y,
+            scale,
+            input_row_stride=input_stride,
+            eps=eps,
+            N_COLS=feat_size,
+            BLOCK_N=BLOCK_N,
+            Q_MIN=qdtype_info.min,
+            Q_MAX=qdtype_info.max,
+            IS_FLOATING_POINT=quant_dtype.is_floating_point,
+            num_warps=4,
+            num_stages=2)
+        return y, scale
+    else:
+        out_residual = torch.empty_like(x)
+        res_stride = residual.stride(-2)
+        add_rms_norm_quant_kernel[grid](
+            x,
+            w,
+            residual,
+            y,
+            scale,
+            out_residual,
+            input_row_stride=input_stride,
+            residual_row_stride=res_stride,
+            eps=eps,
+            N_COLS=feat_size,
+            BLOCK_N=BLOCK_N,
+            Q_MIN=qdtype_info.min,
+            Q_MAX=qdtype_info.max,
+            IS_FLOATING_POINT=quant_dtype.is_floating_point,
+            num_warps=4,
+            num_stages=2)
+        return y, scale, out_residual


 def test_rms_and_linear(x,
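
A usage sketch of the new residual path (the call signature and return values
follow this diff; shapes, dtype, and eps are assumed):

    import torch

    hidden = torch.randn(1, 32, 4096, device='cuda', dtype=torch.float16)
    res = torch.randn_like(hidden)
    w = torch.ones(4096, device='cuda', dtype=torch.float16)

    # Without a residual the old two-value return is kept.
    y, scale = rms_norm_dynamic_quant(hidden, w, eps=1e-6)

    # With a residual the kernel fuses `hidden + res` before the norm and
    # also returns the pre-norm sum for the next layer's skip connection.
    y, scale, out_res = rms_norm_dynamic_quant(hidden, w, 1e-6, residual=res)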