@@ -858,6 +858,218 @@ def test_backward_override_recipe_matches_requested_mode(
858858 assert quant_recipe .backward_override is None
859859
860860
@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list)
@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias"))
def test_linear_backward_override_dequantized_ignores_save_original_input(
    recipe_name: str,
    use_bias: bool,
) -> None:
    """Compare ``te.Linear`` with and without ``save_original_input`` in dequantized mode.

    Two modules that differ only in ``save_original_input`` are run with the
    same input, output gradient, and a recipe built with
    ``backward_override="dequantized"``. The first saved operand of the
    ``save_original_input=True`` module is validated with
    ``_assert_saved_quantized_operand_uses_rowwise_only``, and the forward
    output and every gradient must equal the reference exactly
    (``rtol=0``, ``atol=0``).
    """
    reset_rng_states()
    dtype = torch.bfloat16
    input_shape = (32, 128)
    out_features = 128
    in_features = input_shape[-1]

    # Bail out early for recipe / dtype / shape combinations that are not supported.
    _maybe_skip_recipe_dtype(recipe_name, dtype, "linear")
    _maybe_skip_unsupported_recipe_module_combo(recipe_name, "linear")
    _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, "linear")

    mode_recipe = make_recipe(recipe_name, backward_override="dequantized")
    skip_unsupported_backward_override("linear", mode_recipe, "dequantized")

    # Construction order (reference first) is kept so RNG-driven initialisation is unchanged.
    linear_kwargs = {"bias": use_bias, "params_dtype": dtype, "device": "cuda"}
    module_ref = te.Linear(in_features, out_features, save_original_input=False, **linear_kwargs)
    module_test = te.Linear(in_features, out_features, save_original_input=True, **linear_kwargs)
    # Presumably copies the reference parameters into the test module - confirm against helper.
    _copy_named_parameters(module_ref, module_test)

    x = torch.randn(*input_shape, dtype=dtype, device="cuda")
    dy = torch.randn(input_shape[0], out_features, dtype=dtype, device="cuda")

    y_ref, dx_ref, dw_ref, db_ref = _run_single_step(module_ref, x, dy, mode_recipe)
    y_test, x_test, saved_operands = _run_single_step_with_saved_operands(
        module_test, x, mode_recipe
    )
    _assert_saved_quantized_operand_uses_rowwise_only(saved_operands[0], name="linear_input")

    # Snapshot the forward output before running backward on it.
    y_test_detached = y_test.detach().clone()
    y_test.backward(dy)

    assert x_test.grad is not None
    assert module_test.weight.grad is not None
    dx_test = x_test.grad.detach().clone()
    dw_test = module_test.weight.grad.detach().clone()

    bias_param = getattr(module_test, "bias", None)
    if bias_param is None or bias_param.grad is None:
        db_test = None
    else:
        db_test = bias_param.grad.detach().clone()

    # Exact comparison: zero tolerances plus a dtype check.
    for actual, expected in (
        (y_test_detached, y_ref),
        (dx_test, dx_ref),
        (dw_test, dw_ref),
    ):
        assert_close(actual, expected, rtol=0, atol=0, check_dtype=True)
    if use_bias:
        assert db_test is not None and db_ref is not None
        assert_close(db_test, db_ref, rtol=0, atol=0, check_dtype=True)
922+
923+
@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list)
@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias"))
def test_grouped_linear_backward_override_dequantized_ignores_save_original_input(
    recipe_name: str,
    use_bias: bool,
) -> None:
    """Compare ``te.GroupedLinear`` with and without ``save_original_input`` in dequantized mode.

    Two grouped modules that differ only in ``save_original_input`` are run
    with the same input, splits, output gradient, and a recipe built with
    ``backward_override="dequantized"``. The first ``num_gemms`` saved operands
    of the ``save_original_input=True`` module are validated with
    ``_assert_saved_quantized_operand_uses_rowwise_only``, and the forward
    output and every per-GEMM gradient must equal the reference exactly
    (``rtol=0``, ``atol=0``).
    """
    reset_rng_states()
    dtype = torch.bfloat16
    in_features = 128
    out_features = 128
    m_splits = [64, 64]
    num_gemms = len(m_splits)
    num_tokens = sum(m_splits)
    _maybe_skip_recipe_dtype(recipe_name, dtype, "grouped_linear")
    _maybe_skip_unsupported_recipe_module_combo(recipe_name, "grouped_linear")
    _maybe_skip_unsupported_grouped_splits(recipe_name, m_splits)

    mode_recipe = make_recipe(recipe_name, backward_override="dequantized")
    skip_unsupported_backward_override("grouped_linear", mode_recipe, "dequantized")

    module_ref = te.GroupedLinear(
        num_gemms,
        in_features,
        out_features,
        bias=use_bias,
        params_dtype=dtype,
        device="cuda",
        save_original_input=False,
    )
    module_test = te.GroupedLinear(
        num_gemms,
        in_features,
        out_features,
        bias=use_bias,
        params_dtype=dtype,
        device="cuda",
        save_original_input=True,
    )
    # Presumably copies the reference parameters into the test module - confirm against helper.
    _copy_named_parameters(module_ref, module_test)

    x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda")
    dy = torch.randn(num_tokens, out_features, dtype=dtype, device="cuda")

    y_ref, dx_ref, dw_ref, db_ref = _run_grouped_linear_single_step(
        module_ref, x, m_splits, dy, mode_recipe
    )
    y_test, x_test, saved_operands = _run_grouped_linear_step_with_saved_operands(
        module_test, x, m_splits, mode_recipe
    )
    saved_inputs = saved_operands[:num_gemms]
    for i, saved_input in enumerate(saved_inputs):
        _assert_saved_quantized_operand_uses_rowwise_only(
            saved_input, name=f"grouped_linear_input{i}"
        )

    # Snapshot the forward output before running backward on it.
    y_test_detached = y_test.detach().clone()
    y_test.backward(dy)
    assert x_test.grad is not None
    dx_test = x_test.grad.detach().clone()

    # Check each per-GEMM gradient exists before cloning it, so a missing
    # gradient fails with a clear assertion instead of an AttributeError.
    dw_test = []
    for i in range(num_gemms):
        weight_grad = getattr(module_test, f"weight{i}").grad
        assert weight_grad is not None, f"weight{i} received no gradient"
        dw_test.append(weight_grad.detach().clone())

    db_test: list[Optional[torch.Tensor]] = []
    for i in range(num_gemms):
        if use_bias:
            bias_grad = getattr(module_test, f"bias{i}").grad
            assert bias_grad is not None, f"bias{i} received no gradient"
            db_test.append(bias_grad.detach().clone())
        else:
            db_test.append(None)

    assert_close(y_test_detached, y_ref, rtol=0, atol=0, check_dtype=True)
    assert_close(dx_test, dx_ref, rtol=0, atol=0, check_dtype=True)

    # zip() stops at the shorter iterable; pin the reference length so no
    # per-GEMM comparison can be skipped silently.
    dw_ref = list(dw_ref)
    assert len(dw_ref) == num_gemms, "reference weight gradient count != num_gemms"
    for test_dw, ref_dw in zip(dw_test, dw_ref):
        assert_close(test_dw, ref_dw, rtol=0, atol=0, check_dtype=True)
    if use_bias:
        db_ref = list(db_ref)
        assert len(db_ref) == num_gemms, "reference bias gradient count != num_gemms"
        for test_db, ref_db in zip(db_test, db_ref):
            assert test_db is not None and ref_db is not None
            assert_close(test_db, ref_db, rtol=0, atol=0, check_dtype=True)
999+
1000+
@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list)
def test_linear_backward_override_high_precision_forces_save_original_input(
    recipe_name: str,
) -> None:
    """Check the saved input of ``te.Linear`` under the high-precision backward override.

    The module is built with ``save_original_input=False`` and run with a
    recipe created with ``backward_override="high_precision"``; the first saved
    operand must be a ``torch.Tensor`` instance.
    """
    reset_rng_states()
    dtype = torch.bfloat16
    input_shape = (32, 128)

    # Bail out early for recipe / dtype / shape combinations that are not supported.
    _maybe_skip_recipe_dtype(recipe_name, dtype, "linear")
    _maybe_skip_unsupported_recipe_module_combo(recipe_name, "linear")
    _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, "linear")

    mode_recipe = make_recipe(recipe_name, backward_override="high_precision")
    skip_unsupported_backward_override("linear", mode_recipe, "high_precision")

    linear_kwargs = {
        "bias": False,
        "params_dtype": dtype,
        "device": "cuda",
        "save_original_input": False,
    }
    module = te.Linear(input_shape[-1], 128, **linear_kwargs)
    x = torch.randn(*input_shape, dtype=dtype, device="cuda")

    # Only the saved operands matter here; the forward output and input handle are discarded.
    _unused_y, _unused_x, saved_operands = _run_single_step_with_saved_operands(
        module, x, mode_recipe
    )

    saved_input = saved_operands[0]
    assert isinstance(saved_input, torch.Tensor)
1028+
1029+
@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list)
def test_grouped_linear_backward_override_high_precision_forces_save_original_input(
    recipe_name: str,
) -> None:
    """Check what ``te.GroupedLinear`` saves under the high-precision backward override.

    The module is built with ``save_original_input=False`` and run with a
    recipe created with ``backward_override="high_precision"``. Among the first
    ``num_gemms`` saved operands, the first must be a ``torch.Tensor`` with the
    same shape as the input and the rest must be ``None``. The operands in the
    slice ``[2 * num_gemms, 3 * num_gemms)`` must all be ``torch.Tensor``
    instances.
    """
    reset_rng_states()
    dtype = torch.bfloat16
    in_features = 128
    out_features = 128
    m_splits = [64, 64]
    num_gemms = len(m_splits)
    num_tokens = sum(m_splits)

    # Bail out early for recipe / dtype / split combinations that are not supported.
    _maybe_skip_recipe_dtype(recipe_name, dtype, "grouped_linear")
    _maybe_skip_unsupported_recipe_module_combo(recipe_name, "grouped_linear")
    _maybe_skip_unsupported_grouped_splits(recipe_name, m_splits)

    mode_recipe = make_recipe(recipe_name, backward_override="high_precision")
    skip_unsupported_backward_override("grouped_linear", mode_recipe, "high_precision")

    grouped_kwargs = {
        "bias": False,
        "params_dtype": dtype,
        "device": "cuda",
        "save_original_input": False,
    }
    module = te.GroupedLinear(num_gemms, in_features, out_features, **grouped_kwargs)
    x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda")

    # Only the saved operands matter here; the forward output and input handle are discarded.
    _unused_y, _unused_x, saved_operands = _run_grouped_linear_step_with_saved_operands(
        module, x, m_splits, mode_recipe
    )

    saved_inputs = saved_operands[:num_gemms]
    assert isinstance(saved_inputs[0], torch.Tensor)
    assert saved_inputs[0].shape == x.shape
    for extra_input in saved_inputs[1:]:
        assert extra_input is None

    # Presumably the high-precision weights occupy the third group of num_gemms
    # saved operands - confirm against the module's save layout.
    weights_begin = 2 * num_gemms
    weights_end = 3 * num_gemms
    saved_weights = saved_operands[weights_begin:weights_end]
    assert all(isinstance(saved_weight, torch.Tensor) for saved_weight in saved_weights)
1071+
1072+
8611073@pytest .mark .parametrize ("recipe_name" , _quantized_numerics_recipe_list )
8621074@pytest .mark .parametrize ("module_type" , ("linear" , "layernorm_linear" , "ops_linear" ))
8631075@pytest .mark .parametrize ("input_shape,out_features" , _shape_test_cases )
0 commit comments