@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import json
 from typing import Any, Optional

 import torch
@@ -40,7 +41,8 @@ class TorchAOConfig(QuantizationConfig):

     def __init__(self,
                  torchao_config,
-                 skip_modules: Optional[list[str]] = None) -> None:
+                 skip_modules: Optional[list[str]] = None,
+                 is_checkpoint_torchao_serialized: bool = False) -> None:
         """
         # TorchAO quantization relies on tensor subclasses. In order
         # to enable proper caching, this needs standalone compile
@@ -58,9 +60,11 @@ def __init__(self,
         super().__init__()
         self.torchao_config = torchao_config
         self.skip_modules = skip_modules or []
+        self.is_checkpoint_torchao_serialized = is_checkpoint_torchao_serialized

     def __repr__(self) -> str:
-        return f"TorchAOConfig({self.torchao_config})"
+        return f"TorchAOConfig({self.torchao_config=}, {self.skip_modules=}, " \
+               f"{self.is_checkpoint_torchao_serialized=})"

     def get_name(self) -> QuantizationMethods:
         return "torchao"
@@ -74,7 +78,10 @@ def get_min_capability(cls) -> int:

     @staticmethod
     def get_config_filenames() -> list[str]:
-        return ["config.json"]
+        """torchao doesn't require additional config files; we use the
+        huggingface `config.json` (`model_config.hf_config`) instead.
+        """
+        return []

     @classmethod
     def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":
@@ -87,6 +94,10 @@ def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":
                 "`pip install torchao>=0.10.0` to use torchao quantization."
             ) from err

+        quant_method = cls.get_from_keys_or(config, ["quant_method"], None)
+        is_checkpoint_torchao_serialized = (quant_method is not None
+                                            and "torchao" in quant_method)
+
         hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
         assert hf_config is not None, "quant_type must be specified"
         assert len(hf_config) == 1 and "default" in hf_config, (
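For context, `from_config` consumes the `quantization_config` block of the HF `config.json`; with the new `quant_method` check above, a torchao-serialized checkpoint is detected without any extra config file. A minimal sketch of the expected dict shape (the config choice and import paths are illustrative, not taken from this diff):

```python
# Sketch: shape of the HF quantization_config dict that from_config
# consumes. Int8WeightOnlyConfig is an illustrative torchao config;
# any serialized AOBaseConfig dict works as the "default" entry.
from torchao.core.config import config_to_dict
from torchao.quantization import Int8WeightOnlyConfig

from vllm.model_executor.layers.quantization.torchao import TorchAOConfig

hf_quant_config = {
    "quant_method": "torchao",  # marks the checkpoint as torchao-serialized
    "quant_type": {"default": config_to_dict(Int8WeightOnlyConfig())},
}

quant_config = TorchAOConfig.from_config(hf_quant_config)
assert quant_config.is_checkpoint_torchao_serialized
```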
@@ -110,7 +121,38 @@ def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":
             if layer_cfg is None:
                 skip_modules.append(layer)

-        return cls(ao_config, skip_modules)
+        return cls(ao_config, skip_modules, is_checkpoint_torchao_serialized)
+
+    @classmethod
+    def from_config_file(cls, config_file: str) -> "TorchAOConfig":
+        """Initialize class from a config file. Example:
+        ```
+        config = (
+            Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+        )
+        fn = "torchao_config.json"
+
+        with open(fn, "w") as f:
+            f.write(json.dumps(config_to_dict(config)))
+        ```
+        """
+        with open(config_file) as f:
+            config_dict = json.load(f)
+
+        hf_config = {"quant_type": {"default": config_dict}}
+        return cls.from_config(hf_config)
+
+    @classmethod
+    def from_config_dict_json(cls, config_dict_json: str) -> "TorchAOConfig":
+        """Initialize class from a config_dict JSON string, obtained via:
+        torchao_config_object = <some AOBaseConfig object>
+        json.dumps(config_to_dict(torchao_config_object))
+        """
+        config_dict = json.loads(config_dict_json)
+        hf_config = {"quant_type": {"default": config_dict}}
+        return cls.from_config(hf_config)

     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
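A short usage sketch for the two new constructors, following the docstring example above (assumes torchao >= 0.10.0 for `config_to_dict` and the float8 config; the file name is arbitrary):

```python
# Sketch: exercising the two new entry points with the float8 config
# from the docstring above.
import json

from torchao.core.config import config_to_dict
from torchao.quantization import (Float8DynamicActivationFloat8WeightConfig,
                                  PerRow)

from vllm.model_executor.layers.quantization.torchao import TorchAOConfig

ao_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

# Path 1: round-trip through a JSON file on disk.
with open("torchao_config.json", "w") as f:
    f.write(json.dumps(config_to_dict(ao_config)))
quant_config = TorchAOConfig.from_config_file("torchao_config.json")

# Path 2: pass the serialized config as a JSON string directly.
quant_config = TorchAOConfig.from_config_dict_json(
    json.dumps(config_to_dict(ao_config)))
```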
@@ -128,7 +170,9 @@ def get_quant_method(self, layer: torch.nn.Module,
             c = module_fqn_to_config.get(
                 module_fqn) or module_fqn_to_config.get("_default", None)
             if c is not None:
-                current_torchao_config = TorchAOConfig(c, self.skip_modules)
+                current_torchao_config = TorchAOConfig(
+                    c, self.skip_modules,
+                    self.is_checkpoint_torchao_serialized)
                 return TorchAOLinearMethod(current_torchao_config)
             else:
                 return UnquantizedLinearMethod()
@@ -172,7 +216,7 @@ class TorchAOLinearMethod(LinearMethodBase):
     """Linear method for torchao.

     Args:
-        quant_config: The torchao quantization config, a string that encodes
+        quant_config: The torchao quantization config, a string that encodes
             the type of quantization and all relevant arguments.
     """
@@ -197,8 +241,9 @@ def create_weights(
             ),
             requires_grad=False,
         )
-        weight = torchao_quantize_param_data(weight,
-                                             self.quant_config.torchao_config)
+        if self.quant_config.is_checkpoint_torchao_serialized:
+            weight = torchao_quantize_param_data(
+                weight, self.quant_config.torchao_config)

         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
@@ -212,3 +257,14 @@ def apply(
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         return F.linear(x, layer.weight, bias)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if self.quant_config.is_checkpoint_torchao_serialized:
+            return
+
+        # quantize the weight on the fly if the checkpoint is not already
+        # quantized by torchao
+        weight = torchao_quantize_param_data(layer.weight,
+                                             self.quant_config.torchao_config)
+        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
+        layer.register_parameter("weight", weight)
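The new `process_weights_after_loading` hook is what enables serving unquantized checkpoints: weights are loaded in high precision and only then swapped for torchao tensor subclasses. Outside vLLM, the equivalent flow with torchao's public `quantize_` API looks roughly like this (the layer shape and config choice are illustrative):

```python
# Sketch: the on-the-fly quantization path, expressed with torchao's
# public quantize_ API rather than vLLM's per-parameter helper.
import torch
from torchao.quantization import (Float8DynamicActivationFloat8WeightConfig,
                                  PerRow, quantize_)

# Stands in for a layer whose weights came from a checkpoint that was
# not serialized by torchao (plain bf16 weights).
linear = torch.nn.Linear(1024, 1024, bias=False,
                         dtype=torch.bfloat16, device="cuda")

# Quantize in place after loading; this mirrors what the new hook does
# per-parameter via torchao_quantize_param_data.
quantize_(linear,
          Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
```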