@@ -1,5 +1,6 @@
 import os
 import shutil
+from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from typing import Optional
@@ -13,13 +14,19 @@
 
 from llmcompressor.entrypoints.model_free.helpers import (
     gpu_if_available,
+    validate_safetensors_index,
     validate_scheme,
 )
 from llmcompressor.entrypoints.model_free.lifecycle import (
-    calibrate_weights,
+    calibrate_global_scale,
+    calibrate_scale_zp,
     compress_module,
     initialize_quantized_linear,
 )
+from llmcompressor.entrypoints.model_free.microscale import (
+    get_fused_names,
+    is_microscale_scheme,
+)
 from llmcompressor.entrypoints.model_free.model_utils import (
     get_checkpoint_files,
     is_weights_file,
@@ -55,16 +62,20 @@ def model_free_ptq(
     model_files = get_checkpoint_files(model_stub)
     scheme_name, scheme = validate_scheme(scheme)
     device = gpu_if_available(device)
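+    # assumption: this new check most likely verifies that weights which must be
+    # quantized together (e.g. fused submodules) live in the same safetensors shard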
+    validate_safetensors_index(model_files, scheme)
 
     # 0. collect safetensors files, copy files
     jobs = []
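+    # microscale schemes (e.g. NVFP4-style block formats) share global scales across
+    # fused submodules, so their safetensors files go through a dedicated worker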
+    job_fn = (
+        _process_file
+        if not is_microscale_scheme(scheme)
+        else _process_file_microscale_scheme
+    )
     for file_path, resolved_path in model_files:
         save_path = Path(save_directory) / file_path
 
         if file_path.endswith("safetensors"):
-            jobs.append(
-                (_process_file, resolved_path, save_path, scheme, ignore, device)
-            )
+            jobs.append((job_fn, resolved_path, save_path, scheme, ignore, device))
 
         else:
             if is_weights_file(file_path):
@@ -108,6 +119,7 @@ def _process_file(
         ignored
     :param device: device used to quantize and compress weights
     """
+    assert not is_microscale_scheme(scheme), "Use `_process_file_microscale_scheme`"
     tensors = load_file(file_path)
 
     for name in list(tensors.keys()):
@@ -121,7 +133,66 @@ def _process_file(
         module = initialize_quantized_linear(tensors[name], scheme, device)
 
         # 2. calibrate weight qparams
-        calibrate_weights(module)
+        calibrate_scale_zp(module)
+
+        # 3. compress module using qparams
+        compress_module(module)
+
+        # 4. save compressed data (on cpu)
+        del tensors[name]
+        prefix = module_name + "."
+        for key, value in module.state_dict(prefix=prefix).items():
+            tensors[key] = value.to("cpu")
+
+    save_file(tensors, save_path)
+    total_size = sum(tensor.nbytes for tensor in tensors.values())
+    weight_map = {key: os.path.basename(save_path) for key in tensors.keys()}
+    return total_size, weight_map
+
+
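+# NOTE: each per-file worker returns (total_size, weight_map); the caller presumably
+# merges these into the standard safetensors index layout:
+#   {"metadata": {"total_size": ...}, "weight_map": {tensor_name: shard_filename}}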
+def _process_file_microscale_scheme(
+    file_path: str | os.PathLike,
+    save_path: str | os.PathLike,
+    scheme: QuantizationScheme,
+    ignore: str | list[str],
+    device: str | torch.device,
+) -> tuple[int, dict[str, str]]:
+    """
+    Quantize and compress tensors in a given safetensors file
+
+    :param file_path: safetensors file to process
+    :param save_path: save path of file with quantized weights
+    :param scheme: quantization scheme to apply to tensors
+    :param ignore: modules to ignore. Modules ending with "norm" are automatically
+        ignored
+    :param device: device used to quantize and compress weights
+    """
+    assert is_microscale_scheme(scheme), "Use `_process_file` for non-microscale schemes"
+    tensors = load_file(file_path)
+    fused_names = get_fused_names(tensors)
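+    # get_fused_names maps each fused-parent prefix to the weight names that are
+    # quantized together (for example q/k/v projections fused at inference time)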
+    fused_names_to_parent = {
+        name: prefix for prefix, names in fused_names.items() for name in names
+    }
+    fused_parent_submodules = defaultdict(dict)
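+    # submodules whose scale/zp calibration is deferred until every fused sibling
+    # has been collected, keyed by fused-parent prefix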
+
+    for name in list(tensors.keys()):
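+        # only quantize linear weights; "*norm" weights and ignored modules pass through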
+        module_name, param_name = name.rsplit(".", 1)
+        is_linear_weight = param_name == "weight" and not module_name.endswith("norm")
+        is_ignored = any(_match_name(module_name, ign) for ign in ignore)
+        if not is_linear_weight or is_ignored:
+            continue
+
+        # 1. initialize module with qparams (on device)
+        module = initialize_quantized_linear(tensors[name], scheme, device)
+
+        # 2. calibrate weight qparams. Delay scale/zp calibration for fused modules
+        calibrate_global_scale(module)
+        if name in fused_names_to_parent:
+            fused_parent = fused_names_to_parent[name]
+            fused_parent_submodules[fused_parent][name] = module
+            continue
+
+        calibrate_scale_zp(module)
 
         # 3. compress module using qparams
         compress_module(module)
@@ -132,6 +203,28 @@ def _process_file(
         for key, value in module.state_dict(prefix=prefix).items():
             tensors[key] = value.to("cpu")
 
+    # compress and save microscale fused modules
+    for parent_name, named_modules in fused_parent_submodules.items():
+        # 2.1. fuse global scales
+        global_scales = [m.weight_global_scale for m in named_modules.values()]
+        fused_global_scale = torch.min(torch.cat(global_scales, dim=0))
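+        # assumption: global scales shrink as a weight's dynamic range grows, so the
+        # min keeps the shared scale in range for every fused member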
+
+        for name, module in named_modules.items():
+            module_name, param_name = name.rsplit(".", 1)
+            module.weight_global_scale.data.copy_(fused_global_scale)
+
+            # 2.2. finish calibration with fused global scales
+            calibrate_scale_zp(module)
+
+            # 3. compress module using qparams
+            compress_module(module)
+
+            # 4. save compressed data (on cpu)
+            del tensors[name]
+            prefix = module_name + "."
+            for key, value in module.state_dict(prefix=prefix).items():
+                tensors[key] = value.to("cpu")
+
     save_file(tensors, save_path)
     total_size = sum(tensor.nbytes for tensor in tensors.values())
     weight_map = {key: os.path.basename(save_path) for key in tensors.keys()}