Commit c312468

Support RL online quantization with torchao (#23014)
Signed-off-by: Jerry Zhang <[email protected]>
1 parent 4134312 commit c312468

File tree

6 files changed: +465 −16 lines changed
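As a quick orientation before the per-file diffs: this commit lets a torchao quantization config be passed straight through `hf_overrides`, so vLLM quantizes high-precision weights on the fly at load time. A minimal sketch based on the tests added below (model, prompt, and dtype are illustrative; torchao >= 0.10.0 and a CUDA device are assumed):

```python
# Minimal sketch of online torchao quantization, mirroring
# test_on_the_fly_quant_config_dict_json from this commit.
import json

from torchao.core.config import config_to_dict
from torchao.quantization import (Float8DynamicActivationFloat8WeightConfig,
                                  PerRow)

from vllm import LLM

quant_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
llm = LLM(
    model="facebook/opt-125m",  # illustrative; any supported model works
    dtype="bfloat16",
    quantization="torchao",
    hf_overrides={
        "quantization_config_dict_json":
        json.dumps(config_to_dict(quant_config))
    },
)
outputs = llm.generate(["The capital of France is"])
print(outputs[0].outputs[0].text)
```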

tests/quantization/test_torchao.py

Lines changed: 121 additions & 5 deletions
@@ -20,7 +20,6 @@ def test_pre_quantized_model(vllm_runner):
         output = llm.generate_greedy(["The capital of France is"],
                                      max_tokens=32)
     assert output
-    print(output)
 
 
 @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@@ -42,7 +41,6 @@ def test_opt_125m_int8wo_model_loading_with_params(vllm_runner,
                                      max_tokens=32)
 
     assert output
-    print(output)
 
 
 @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@@ -57,7 +55,6 @@ def test_opt_125m_int4wo_model_per_module_quant(vllm_runner):
                                      max_tokens=32)
 
     assert output
-    print(output)
 
 
 @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@@ -72,7 +69,6 @@ def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
                                      max_tokens=32)
 
     assert output
-    print(output)
 
 
 @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@@ -92,7 +88,127 @@ def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner):
                                      max_tokens=32)
 
     assert output
-    print(output)
+
+
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+def test_on_the_fly_quant_config_dict_json(vllm_runner):
+    """Testing on the fly quantization, load_weights integration point,
+    with config dict serialized to json string
+    """
+    torch._dynamo.reset()
+    model_name = "facebook/opt-125m"
+
+    import json
+
+    from torchao.core.config import config_to_dict
+    from torchao.quantization import (
+        Float8DynamicActivationFloat8WeightConfig, PerRow)
+
+    torchao_quant_config = Float8DynamicActivationFloat8WeightConfig(
+        granularity=PerRow())
+    hf_overrides = {
+        "quantization_config_dict_json":
+        json.dumps(config_to_dict(torchao_quant_config))
+    }
+    with vllm_runner(model_name=model_name,
+                     dtype="bfloat16",
+                     pt_load_map_location="cuda:0",
+                     quantization="torchao",
+                     hf_overrides=hf_overrides) as llm:
+        output = llm.generate_greedy(["The capital of France is"],
+                                     max_tokens=32)
+
+    assert output
+
+
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+def test_on_the_fly_quant_config_file(vllm_runner):
+    """Testing on the fly quantization, load_weights integration point,
+    with config file
+    """
+    torch._dynamo.reset()
+    model_name = "facebook/opt-125m"
+    import json
+    from tempfile import NamedTemporaryFile
+
+    from torchao.core.config import config_to_dict
+    from torchao.quantization import (
+        Float8DynamicActivationFloat8WeightConfig, PerRow)
+
+    config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+
+    with NamedTemporaryFile(mode="w", delete=False) as f:
+        f.write(json.dumps(config_to_dict(config)))
+        # close the file to save it
+        f.close()
+        config_file_name = str(f.name)
+
+    hf_overrides = {"quantization_config_file": config_file_name}
+    with vllm_runner(model_name=model_name,
+                     dtype="bfloat16",
+                     pt_load_map_location="cuda:0",
+                     quantization="torchao",
+                     hf_overrides=hf_overrides) as llm:
+        output = llm.generate_greedy(["The capital of France is"],
+                                     max_tokens=32)
+
+    assert output
+
+
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+def test_reload_weights():
+    import json
+
+    from torchao.core.config import config_to_dict
+    from torchao.quantization import (
+        Float8DynamicActivationFloat8WeightConfig, PerRow)
+
+    from vllm import LLM, SamplingParams
+
+    torchao_quant_config = Float8DynamicActivationFloat8WeightConfig(
+        granularity=PerRow())
+
+    hf_overrides = {
+        "quantization_config_dict_json":
+        json.dumps(config_to_dict(torchao_quant_config))
+    }
+
+    llm = LLM(
+        model="Qwen/Qwen3-0.6B",
+        dtype="bfloat16",
+        load_format="dummy",
+        enforce_eager=True,
+        quantization="torchao",
+        hf_overrides=hf_overrides,
+    )
+    # Update load format from `dummy` to `auto`
+    llm.collective_rpc("update_config",
+                       args=({
+                           "load_config": {
+                               "load_format": "auto"
+                           }
+                       }, ))
+    # Now reload real weights inplace
+    llm.collective_rpc("reload_weights")
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    # Create a sampling params object.
+    sampling_params = SamplingParams(temperature=0, top_p=0.95)
+    outputs = llm.generate(prompts, sampling_params)
+    # make sure it runs
+    for output in outputs:
+        generated_text = output.outputs[0].text
+        assert generated_text
+        # can also uncomment locally to make sure the generated
+        # output makes sense
+        # prompt = output.prompt
+        # print(f"Prompt: {prompt!r}")
+        # print(f"Output: {generated_text!r}")
+        # print("-" * 60)
 
 
 if __name__ == "__main__":
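The `quantization_config_file` route exercised by `test_on_the_fly_quant_config_file` above has a direct user-facing equivalent; a sketch (the file path is arbitrary):

```python
# Sketch of the config-file variant: serialize the torchao config to JSON
# once, then point vLLM at it via hf_overrides. The path is arbitrary.
import json

from torchao.core.config import config_to_dict
from torchao.quantization import (Float8DynamicActivationFloat8WeightConfig,
                                  PerRow)

from vllm import LLM

config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
with open("torchao_config.json", "w") as f:
    json.dump(config_to_dict(config), f)

llm = LLM(model="facebook/opt-125m",
          dtype="bfloat16",
          quantization="torchao",
          hf_overrides={"quantization_config_file": "torchao_config.json"})
```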

vllm/model_executor/layers/quantization/torchao.py

Lines changed: 64 additions & 8 deletions
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import json
 from typing import Any, Optional
 
 import torch
@@ -40,7 +41,8 @@ class TorchAOConfig(QuantizationConfig):
 
     def __init__(self,
                  torchao_config,
-                 skip_modules: Optional[list[str]] = None) -> None:
+                 skip_modules: Optional[list[str]] = None,
+                 is_checkpoint_torchao_serialized: bool = False) -> None:
         """
         # TorchAO quantization relies on tensor subclasses. In order
         # to enable proper caching this needs standalone compile
@@ -58,9 +60,11 @@ def __init__(self,
         super().__init__()
         self.torchao_config = torchao_config
         self.skip_modules = skip_modules or []
+        self.is_checkpoint_torchao_serialized = is_checkpoint_torchao_serialized
 
     def __repr__(self) -> str:
-        return f"TorchAOConfig({self.torchao_config})"
+        return f"TorchAOConfig({self.torchao_config=}, {self.skip_modules=}, " \
+               f"{self.is_checkpoint_torchao_serialized=})"
 
     def get_name(self) -> QuantizationMethods:
         return "torchao"
@@ -74,7 +78,10 @@ def get_min_capability(cls) -> int:
 
     @staticmethod
     def get_config_filenames() -> list[str]:
-        return ["config.json"]
+        """torchao doesn't require additional config files, we use
+        `config.json` from huggingface: `model_config.hf_config`
+        """
+        return []
 
     @classmethod
     def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":
@@ -87,6 +94,10 @@ def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":
                 "`pip install torchao>=0.10.0` to use torchao quantization."
             ) from err
 
+        quant_method = cls.get_from_keys_or(config, ["quant_method"], None)
+        is_checkpoint_torchao_serialized = (quant_method is not None
+                                            and "torchao" in quant_method)
+
         hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
         assert hf_config is not None, "quant_type must be specified"
         assert len(hf_config) == 1 and "default" in hf_config, (
@@ -110,7 +121,38 @@ def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":
             if layer_cfg is None:
                 skip_modules.append(layer)
 
-        return cls(ao_config, skip_modules)
+        return cls(ao_config, skip_modules, is_checkpoint_torchao_serialized)
+
+    @classmethod
+    def from_config_file(cls, config_file: str) -> "TorchAOConfig":
+        """Initialize class from a config file. Example:
+        ```
+        config = (
+            Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+        )
+        fn = "torchao_config.json"
+
+        with open(fn, "w") as f:
+            f.write(json.dumps(config_to_dict(config)))
+        ```
+        """
+        with open(config_file) as f:
+            f.seek(0)
+            f_read = f.read()
+            config_dict = json.loads(f_read)
+
+        hf_config = {"quant_type": {"default": config_dict}}
+        return cls.from_config(hf_config)
+
+    @classmethod
+    def from_config_dict_json(cls, config_dict_json: str) -> "TorchAOConfig":
+        """Initialize class from a config_dict json string, obtained from
+        torchao_config_object = some AOBaseConfig object
+        json.dumps(config_to_dict(torchao_config_object))
+        """
+        config_dict = json.loads(config_dict_json)
+        hf_config = {"quant_type": {"default": config_dict}}
+        return cls.from_config(hf_config)
 
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
@@ -128,7 +170,9 @@ def get_quant_method(self, layer: torch.nn.Module,
             c = module_fqn_to_config.get(
                 module_fqn) or module_fqn_to_config.get("_default", None)
             if c is not None:
-                current_torchao_config = TorchAOConfig(c, self.skip_modules)
+                current_torchao_config = TorchAOConfig(
+                    c, self.skip_modules,
+                    self.is_checkpoint_torchao_serialized)
                 return TorchAOLinearMethod(current_torchao_config)
             else:
                 return UnquantizedLinearMethod()
@@ -172,7 +216,7 @@ class TorchAOLinearMethod(LinearMethodBase):
     """Linear method for torchao.
 
     Args:
-        quant_config: The torchao quantization config, a string that encodes 
+        quant_config: The torchao quantization config, a string that encodes
             the type of quantization and all relevant arguments.
     """
 
@@ -197,8 +241,9 @@ def create_weights(
             ),
             requires_grad=False,
         )
-        weight = torchao_quantize_param_data(weight,
-                                             self.quant_config.torchao_config)
+        if self.quant_config.is_checkpoint_torchao_serialized:
+            weight = torchao_quantize_param_data(
+                weight, self.quant_config.torchao_config)
 
         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
 
@@ -212,3 +257,14 @@ def apply(
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         return F.linear(x, layer.weight, bias)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if self.quant_config.is_checkpoint_torchao_serialized:
+            return
+
+        # quantize the weight on the fly if the checkpoint is not already
+        # quantized by torchao
+        weight = torchao_quantize_param_data(layer.weight,
+                                             self.quant_config.torchao_config)
+        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
+        layer.register_parameter("weight", weight)
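The new `process_weights_after_loading` hook is where online quantization actually happens: when the checkpoint is not torchao-serialized, the freshly loaded high-precision weight is converted to a torchao tensor subclass after load. A conceptual sketch using torchao's public `quantize_` API (vLLM's `torchao_quantize_param_data` helper, not shown in this diff, differs in details):

```python
# Conceptual sketch of quantizing an already-loaded bf16 weight, as the new
# process_weights_after_loading hook does for non-torchao checkpoints.
import torch
from torch import nn
from torchao.quantization import (Float8DynamicActivationFloat8WeightConfig,
                                  PerRow, quantize_)

config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

# Wrap the weight in a throwaway Linear so the module-level quantize_ API
# can swap it for a torchao tensor subclass.
dummy = nn.Linear(1, 1, bias=False)
dummy.weight = nn.Parameter(torch.randn(64, 128, dtype=torch.bfloat16,
                                        device="cuda"),
                            requires_grad=False)
quantize_(dummy, config)
quantized_weight = dummy.weight  # float8 weight with per-row scales
```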

vllm/model_executor/model_loader/default_loader.py

Lines changed: 29 additions & 2 deletions
@@ -261,8 +261,35 @@ def download_model(self, model_config: ModelConfig) -> None:
     def load_weights(self, model: nn.Module,
                      model_config: ModelConfig) -> None:
         weights_to_load = {name for name, _ in model.named_parameters()}
-        loaded_weights = model.load_weights(
-            self.get_all_weights(model_config, model))
+
+        # if we don't have `model.weight_metadata_and_attr_saved` defined and
+        # set to True, it means this is either the offline quantization case
+        # or the first run of online quantization
+        # see online_quantization.py for detailed notes
+        offline_quantization_or_first_run_of_online_quantization = not getattr(
+            model, "weight_metadata_and_attr_saved", False)
+
+        if model_config.quantization is None:
+            # model is not quantized
+            loaded_weights = model.load_weights(
+                self.get_all_weights(model_config, model))
+        elif offline_quantization_or_first_run_of_online_quantization:
+            # case 1: offline quantized checkpoint
+            # case 2: Step I1, first run of weight loading with
+            # online quantization
+            # see online_quantization.py for detailed notes
+            loaded_weights = model.load_weights(
+                self.get_all_weights(model_config, model))
+        else:
+            # to avoid circular dependency
+            from vllm.model_executor.model_loader.online_quantization import (
+                load_weights_and_online_quantize)
+
+            # subsequent runs of weight loading with online
+            # quantization
+            loaded_weights = load_weights_and_online_quantize(
+                self, model, model_config)
+
         self.counter_after_loading_weights = time.perf_counter()
         logger.info(
             "Loading weights took %.2f seconds",
