@@ -876,6 +876,8 @@ def get_act_max_hook(module, input, output):
         pbar = tqdm(all_to_quantized_module_names)
         block_names_cnt = len(flatten_list(get_block_names(self.model, True)))
         clear_mem_freq = len(all_to_quantized_module_names) // block_names_cnt
+        if clear_mem_freq == 0:
+            clear_mem_freq = 1
         cnt = 1
         for name in pbar:
             pbar.set_description(f"Quantizing {name}")
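The new `if clear_mem_freq == 0` clamp protects the periodic memory-clearing check later in the loop: when there are more blocks than modules left to quantize, the integer division yields 0, and a subsequent `cnt % clear_mem_freq` test would raise `ZeroDivisionError`. A minimal sketch of the failure mode and the fix; the counts and module names are illustrative, and the modulo-based clearing is assumed from the surrounding `cnt`/`clear_memory` usage rather than shown in this hunk:

```python
# Why the clamp matters: integer division can give 0 when there are fewer
# modules to quantize than blocks, and the periodic `cnt % clear_mem_freq`
# memory-clearing check would then divide by zero.
# (Counts and names below are illustrative, not the real model layout.)
all_to_quantized_module_names = ["lm_head"]   # fewer modules than blocks
block_names_cnt = 32                          # e.g. 32 transformer blocks

clear_mem_freq = len(all_to_quantized_module_names) // block_names_cnt  # == 0
if clear_mem_freq == 0:
    clear_mem_freq = 1                        # clamp so the modulo below is safe

cnt = 1
for name in all_to_quantized_module_names:
    # ... quantize `name` here ...
    if cnt % clear_mem_freq == 0:             # ZeroDivisionError without the clamp
        pass                                  # clear_memory() in the real code
    cnt += 1
```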
@@ -895,26 +897,27 @@ def get_act_max_hook(module, input, output):
                 model = model.to("cpu")
                 clear_memory()
                 self.quantize_via_rtn_blockwise(all_to_quantized_module_names)
-            except Exception:
-                # Final fallback: warn and use CPU-only quantization
-                logger.warning("Fallback to CPU. "
-                               "Consider enabling `low_gpu_mem_usage` or using more GPUs via `--device 0,1,2,3`.")
-                model = model.to("cpu")
-                clear_memory()
-                if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1:
-                    import accelerate
-                    accelerate.hooks.remove_hook_from_submodules(model)
-
-                orig_device = self.device
-                self.device = "cpu"
-                self.quantize_via_rtn_blockwise(all_to_quantized_module_names)
-                self.device = orig_device
-            finally:
-                # Always remove hooks
-                for hook in hooks:
-                    hook.remove()
-        else:
-            raise
+            except RuntimeError as e:
+                if "CUDA out of memory" in str(e) or "MODULE:PT_DEVMEM" in str(e):
+                    # Final fallback: warn and use CPU-only quantization
+                    logger.warning("Fallback to CPU. "
+                                   "Consider enabling `low_gpu_mem_usage` or using more GPUs via `--device 0,1,2,3`.")
+                    model = model.to("cpu")
+                    clear_memory()
+                    if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1:
+                        import accelerate
+                        accelerate.hooks.remove_hook_from_submodules(model)
+
+                    orig_device = self.device
+                    self.device = "cpu"
+                    self.quantize_via_rtn_blockwise(all_to_quantized_module_names)
+                    self.device = orig_device
+                else:
+                    raise
+            finally:
+                # Always remove hooks
+                for hook in hooks:
+                    hook.remove()
 
         # Move back to CPU and free memory
         model.to("cpu")
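This hunk narrows the fallback: instead of swallowing any `Exception`, only a `RuntimeError` whose message points at device memory exhaustion ("CUDA out of memory" on CUDA, "MODULE:PT_DEVMEM", which appears to cover Intel Gaudi/HPU device-memory failures) triggers the CPU retry; everything else is re-raised, and the hooks are still removed in `finally` on every path. A minimal sketch of the same pattern, with `run_on_device`, `run_on_cpu`, and `_is_oom` as illustrative stand-ins rather than project APIs:

```python
import logging

logger = logging.getLogger(__name__)


def _is_oom(err: RuntimeError) -> bool:
    # "CUDA out of memory" comes from PyTorch's CUDA allocator;
    # "MODULE:PT_DEVMEM" shows up in device-memory errors on HPU-style backends.
    msg = str(err)
    return "CUDA out of memory" in msg or "MODULE:PT_DEVMEM" in msg


def quantize_with_cpu_fallback(run_on_device, run_on_cpu, hooks):
    """Try the device path; fall back to CPU only on out-of-memory errors;
    always remove forward hooks, whichever path ran."""
    try:
        return run_on_device()
    except RuntimeError as e:
        if _is_oom(e):
            logger.warning("Fallback to CPU. Consider `low_gpu_mem_usage` or more GPUs.")
            return run_on_cpu()
        raise  # any other RuntimeError is a real failure and should surface
    finally:
        for hook in hooks:
            hook.remove()
```

Keeping the `raise` bare preserves the original traceback, and the `finally` block guarantees hook cleanup on success, fallback, and re-raise alike.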
@@ -1119,6 +1122,8 @@ def quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
         else:
             block_names_cnt = len(flatten_list(get_block_names(self.model, True)))
             clear_mem_freq = len(all_to_quantized_module_names) // block_names_cnt
+            if clear_mem_freq == 0:
+                clear_mem_freq = 1
             pbar = tqdm(all_to_quantized_module_names)
             cnt = 1
             for name in pbar:
@@ -1223,6 +1228,8 @@ def quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) -
         cnt = 1
         block_names_cnt = len(flatten_list(get_block_names(self.model, True)))
         clear_mem_freq = len(all_to_quantized_module_names) // block_names_cnt
+        if clear_mem_freq == 0:
+            clear_mem_freq = 1
         # Process remaining layers not in blocks
         for name in all_to_quantized_module_names:
             self.quantize_layer_via_rtn(name)