@@ -493,6 +493,126 @@ def mean_batch_invariant(input,
     return result
 
 
+@triton.jit
+def _rms_norm_kernel(
+    input_ptr,
+    weight_ptr,
+    output_ptr,
+    input_row_stride,
+    output_row_stride,
+    n_cols,
+    eps,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Compute RMS normalization along the last dimension of a 2D tensor.
+    RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight
+    Each block handles one row of the input tensor.
+    """
+    row_idx = tl.program_id(0).to(tl.int64)
+    row_start_ptr = input_ptr + row_idx * input_row_stride
+    output_row_start_ptr = output_ptr + row_idx * output_row_stride
+
+    # Step 1: Compute sum of squares
+    sum_sq = 0.0
+    for col_offset in range(0, n_cols, BLOCK_SIZE):
+        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
+        mask = col_idx < n_cols
+
+        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
+        sq_vals = vals * vals
+        sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0))
+
+    # Step 2: Compute RMS (root mean square)
+    mean_sq = sum_sq / n_cols
+    rms = tl.sqrt(mean_sq + eps)
+    inv_rms = 1.0 / rms
+
+    # Step 3: Normalize and apply weight
+    for col_offset in range(0, n_cols, BLOCK_SIZE):
+        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
+        mask = col_idx < n_cols
+        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
+        weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0)
+        output = vals * inv_rms * weight
+        tl.store(output_row_start_ptr + col_idx, output, mask=mask)
+
+
+def rms_norm(input: torch.Tensor,
+             weight: torch.Tensor,
+             eps: float = 1e-6) -> torch.Tensor:
+    """
+    Compute RMS normalization using a Triton kernel.
+
+    RMS Norm normalizes the input by the root mean square and scales by weight:
+        output = input / sqrt(mean(input^2) + eps) * weight
+
+    Args:
+        input: Input tensor of shape (..., hidden_size)
+        weight: Weight tensor of shape (hidden_size,)
+        eps: Small constant for numerical stability
+
+    Returns:
+        Tensor with RMS normalization applied along the last dimension
+    """
+    assert input.is_cuda, "Input must be a CUDA tensor"
+    assert weight.is_cuda, "Weight must be a CUDA tensor"
+    assert weight.dim() == 1, "Weight must be 1-dimensional"
+    assert input.shape[-1] == weight.shape[0], (
+        f"Input last dimension ({input.shape[-1]}) must match "
+        f"weight dimension ({weight.shape[0]})")
+
+    # Flatten all dimensions except the last one
+    original_shape = input.shape
+    input_2d = input.reshape(-1, input.shape[-1])
+    input_2d = input_2d.contiguous()
+    weight = weight.contiguous()
+
+    n_rows, n_cols = input_2d.shape
+
+    output = torch.empty_like(input_2d)
+    BLOCK_SIZE = 1024
+    grid = (n_rows, )
+    _rms_norm_kernel[grid](
+        input_2d,
+        weight,
+        output,
+        input_2d.stride(0),
+        output.stride(0),
+        n_cols,
+        eps,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    return output.reshape(original_shape)
+
+
+def rms_norm_batch_invariant(input: torch.Tensor,
+                             weight: torch.Tensor,
+                             eps: float = 1e-6) -> torch.Tensor:
+    """
+    Batch-invariant wrapper for RMS normalization.
+
+    This function provides a deterministic, batch-invariant implementation
+    of RMS normalization for use with the batch_invariant mode.
+
+    Args:
+        input: Input tensor of shape (..., hidden_size)
+        weight: Weight tensor of shape (hidden_size,)
+        eps: Small constant for numerical stability
+
+    Returns:
+        RMS normalized tensor
+    """
+    return rms_norm(input, weight, eps=eps)
+
+
+def linear_batch_invariant(input, weight, bias=None):
+    output = torch.mm(input, weight.t())
+    if bias is not None:
+        output = output + bias
+    return output
+
+
 _batch_invariant_MODE = False
 _batch_invariant_LIB = None
 
@@ -510,6 +630,7 @@ def enable_batch_invariant_mode():
     _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
     _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
     _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
+    _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
     _batch_invariant_LIB.impl("aten::_log_softmax",
                               _log_softmax_batch_invariant, "CUDA")
     _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
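A minimal usage sketch of the new RMS norm path (illustrative only, not part of this commit): it assumes a CUDA device is available, that the functions added above are in scope, and that the tensor shapes are purely hypothetical.

    import torch

    # Hypothetical shapes: (batch, seq, hidden) activations and a per-channel scale.
    x = torch.randn(4, 11, 4096, device="cuda", dtype=torch.float16)
    w = torch.ones(4096, device="cuda", dtype=torch.float16)

    out = rms_norm_batch_invariant(x, w, eps=1e-6)
    assert out.shape == x.shape

    # Each Triton program reduces a single row with a fixed BLOCK_SIZE loop, so the
    # value computed for a given row should not depend on how many rows are batched.
    out_single = rms_norm_batch_invariant(x[:1], w, eps=1e-6)
    torch.testing.assert_close(out_single, out[:1])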