Commit 0677428

nfp4
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 1c85a66 commit 0677428

File tree

6 files changed: +195 −30 lines

src/llmcompressor/entrypoints/model_free/__init__.py
src/llmcompressor/entrypoints/model_free/helpers.py
src/llmcompressor/entrypoints/model_free/lifecycle.py
src/llmcompressor/entrypoints/model_free/microscale.py
src/llmcompressor/entrypoints/model_free/model_utils.py
tests/llmcompressor/pipelines/test_model_free_ptq.py

src/llmcompressor/entrypoints/model_free/__init__.py

Lines changed: 98 additions & 5 deletions
@@ -1,5 +1,6 @@
 import os
 import shutil
+from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from typing import Optional
@@ -13,13 +14,19 @@
 
 from llmcompressor.entrypoints.model_free.helpers import (
     gpu_if_available,
+    validate_safetensors_index,
     validate_scheme,
 )
 from llmcompressor.entrypoints.model_free.lifecycle import (
-    calibrate_weights,
+    calibrate_global_scale,
+    calibrate_scale_zp,
     compress_module,
     initialize_quantized_linear,
 )
+from llmcompressor.entrypoints.model_free.microscale import (
+    get_fused_names,
+    is_microscale_scheme,
+)
 from llmcompressor.entrypoints.model_free.model_utils import (
     get_checkpoint_files,
     is_weights_file,
@@ -55,16 +62,20 @@ def model_free_ptq(
     model_files = get_checkpoint_files(model_stub)
     scheme_name, scheme = validate_scheme(scheme)
     device = gpu_if_available(device)
+    validate_safetensors_index(model_files, scheme)
 
     # 0. collect safetensors files, copy files
     jobs = []
+    job_fn = (
+        _process_file
+        if not is_microscale_scheme(scheme)
+        else _process_file_microscale_scheme
+    )
     for file_path, resolved_path in model_files:
         save_path = Path(save_directory) / file_path
 
         if file_path.endswith("safetensors"):
-            jobs.append(
-                (_process_file, resolved_path, save_path, scheme, ignore, device)
-            )
+            jobs.append((job_fn, resolved_path, save_path, scheme, ignore, device))
 
         else:
             if is_weights_file(file_path):
@@ -108,6 +119,7 @@ def _process_file(
         ignored
     :param device: device used to quantize and compress weights
     """
+    assert not is_microscale_scheme(scheme), "Use `_process_file_microscale_scheme`"
     tensors = load_file(file_path)
 
     for name in list(tensors.keys()):
@@ -121,7 +133,66 @@ def _process_file(
         module = initialize_quantized_linear(tensors[name], scheme, device)
 
         # 2. calibrate weight qparams
-        calibrate_weights(module)
+        calibrate_scale_zp(module)
+
+        # 3. compress module using qparams
+        compress_module(module)
+
+        # 4. save compressed data (on cpu)
+        del tensors[name]
+        prefix = module_name + "."
+        for key, value in module.state_dict(prefix=prefix).items():
+            tensors[key] = value.to("cpu")
+
+    save_file(tensors, save_path)
+    total_size = sum(tensor.nbytes for tensor in tensors.values())
+    weight_map = {key: os.path.basename(save_path) for key in tensors.keys()}
+    return total_size, weight_map
+
+
+def _process_file_microscale_scheme(
+    file_path: str | os.PathLike,
+    save_path: str | os.PathLike,
+    scheme: QuantizationScheme,
+    ignore: str | list[str],
+    device: str | torch.device,
+) -> tuple[int, dict[str, str]]:
+    """
+    Quantize and compress tensors in a given safetensors file
+
+    :param file_path: safetensors file to process
+    :param save_path: save path of file with quantized weights
+    :param scheme: quantization scheme to apply to tensors
+    :param ignore: modules to ignore. Modules ending with "norm" are automatically
+        ignored
+    :param device: device used to quantize and compress weights
+    """
+    assert is_microscale_scheme(scheme), "Use `_process_file` for non-microscale scheme"
+    tensors = load_file(file_path)
+    fused_names = get_fused_names(tensors)
+    fused_names_to_parent = {
+        name: prefix for prefix, names in fused_names.items() for name in names
+    }
+    fused_parent_submodules = defaultdict(dict)
+
+    for name in list(tensors.keys()):
+        module_name, param_name = name.rsplit(".", 1)
+        is_linear_weight = param_name == "weight" and not module_name.endswith("norm")
+        is_ignored = any(_match_name(module_name, ign) for ign in ignore)
+        if not is_linear_weight or is_ignored:
+            continue
+
+        # 1. initialize module with qparams (on device)
+        module = initialize_quantized_linear(tensors[name], scheme, device)
+
+        # 2. calibrate weight qparams. Delay scale/zp calibration for fused modules
+        calibrate_global_scale(module)
+        if name in fused_names_to_parent:
+            fused_parent = fused_names_to_parent[name]
+            fused_parent_submodules[fused_parent][name] = module
+            continue
+
+        calibrate_scale_zp(module)
 
         # 3. compress module using qparams
         compress_module(module)
@@ -132,6 +203,28 @@ def _process_file(
         for key, value in module.state_dict(prefix=prefix).items():
             tensors[key] = value.to("cpu")
 
+    # compress and save microscale fused modules
+    for parent_name, named_modules in fused_parent_submodules.items():
+        # 2.1. fuse global scales
+        global_scales = [m.weight_global_scale for m in named_modules.values()]
+        fused_global_scale = torch.min(torch.cat(global_scales, dim=0))
+
+        for name, module in named_modules.items():
+            module_name, param_name = name.rsplit(".", 1)
+            module.weight_global_scale.data.copy_(fused_global_scale)
+
+            # 2.2. finish calibration with fused global scales
+            calibrate_scale_zp(module)
+
+            # 3. compress module using qparams
+            compress_module(module)
+
+            # 4. save compressed data (on cpu)
+            del tensors[name]
+            prefix = module_name + "."
+            for key, value in module.state_dict(prefix=prefix).items():
+                tensors[key] = value.to("cpu")
+
     save_file(tensors, save_path)
     total_size = sum(tensor.nbytes for tensor in tensors.values())
     weight_map = {key: os.path.basename(save_path) for key in tensors.keys()}
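
Note on the fused-module handling added above: under a microscale scheme, the q/k/v and gate/up siblings identified by get_fused_names end up sharing one weight_global_scale, which _process_file_microscale_scheme computes as the minimum over the siblings' individual global scales before finishing their scale/zero-point calibration. A minimal sketch of that reduction, using plain tensors as stand-in global scales (names and values here are illustrative, not from the commit):

import torch

global_scales = [
    torch.tensor([448.0]),  # e.g. q_proj weight_global_scale
    torch.tensor([312.0]),  # e.g. k_proj weight_global_scale
    torch.tensor([501.0]),  # e.g. v_proj weight_global_scale
]

# the fused group adopts the smallest of the individual global scales
fused_global_scale = torch.min(torch.cat(global_scales, dim=0))
print(fused_global_scale)  # tensor(312.)

for scale in global_scales:
    scale.data.copy_(fused_global_scale)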

src/llmcompressor/entrypoints/model_free/helpers.py

Lines changed: 36 additions & 16 deletions
@@ -1,12 +1,16 @@
-from typing import Optional
+import json
 
 import torch
-from compressed_tensors.quantization import QuantizationScheme, preset_name_to_scheme
+from compressed_tensors.quantization import (
+    QuantizationScheme,
+    preset_name_to_scheme,
+)
 from compressed_tensors.utils import getattr_chain
-from compressed_tensors.utils.match import _match_name
 from loguru import logger
 
-__all__ = ["validate_scheme", "gpu_if_available", "is_match_name"]
+from .microscale import get_fused_names, is_microscale_scheme
+
+__all__ = ["validate_scheme", "gpu_if_available"]
 
 
 def validate_scheme(scheme: QuantizationScheme) -> tuple[str, QuantizationScheme]:
@@ -48,6 +52,34 @@ def validate_scheme(scheme: QuantizationScheme) -> tuple[str, QuantizationScheme]:
     return scheme_name, scheme
 
 
+def validate_safetensors_index(
+    model_files: list[tuple[str, str]], scheme: QuantizationScheme
+):
+    resolved_paths = [
+        resolved_path
+        for file_path, resolved_path in model_files
+        if file_path.endswith("safetensors.index.json")
+    ]
+    if len(resolved_paths) <= 0:
+        return
+    resolved_path = resolved_paths[0]
+
+    if is_microscale_scheme(scheme):
+        with open(resolved_path, "r") as file:
+            weight_map: dict[str, str] = json.load(file)["weight_map"]
+
+        fused_names = get_fused_names(weight_map)
+        for submodule_names in fused_names.values():
+            file_names = [weight_map[name] for name in submodule_names]
+            if not all(file_name == file_names[0] for file_name in file_names):
+                raise NotImplementedError(
+                    "When using a microscale scheme (NVFP4, MXFP4), global scales "
+                    "will be fused. Current implementation requires that all fused "
+                    "modules (attention and non-moe mlp) be stored in the same file. "
+                    f"Instead, got {submodule_names}\n\n {file_names}"
+                )
+
+
 def gpu_if_available(device: torch.device | str | None) -> torch.device:
     if device is not None:
         return torch.device(device)
@@ -61,15 +93,3 @@ def gpu_if_available(device: torch.device | str | None) -> torch.device:
     else:
         logger.warning("CUDA/XPU is not available! Compressing model on CPU instead")
         return torch.device("cpu")
-
-
-def is_match_name(
-    name: str, targets: list[str], ignore: Optional[str | list[str]] = None
-) -> bool:
-    targets = targets if isinstance(targets, list) else [targets]
-    ignore = ignore if isinstance(ignore, list) else [ignore]
-
-    matches_target = any(_match_name(name, target) for target in targets)
-    matches_ignore = any(_match_name(name, ign) for ign in ignore)
-
-    return matches_target and not matches_ignore
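
Note: validate_safetensors_index only rejects a checkpoint when a microscale scheme is requested and a fused group is split across shards. A sketch of the condition it enforces, using a hypothetical two-shard weight_map (module and shard names are illustrative):

from llmcompressor.entrypoints.model_free.microscale import get_fused_names

weight_map = {
    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
    "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
    "model.layers.0.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
}

for submodule_names in get_fused_names(weight_map).values():
    file_names = [weight_map[name] for name in submodule_names]
    if len(set(file_names)) > 1:
        # q/k/v global scales must be fused, but the tensors live in different
        # shards, so validate_safetensors_index raises NotImplementedError
        print("unsupported:", submodule_names)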

src/llmcompressor/entrypoints/model_free/lifecycle.py

Lines changed: 9 additions & 7 deletions
@@ -3,7 +3,6 @@
 from compressed_tensors.config.format import _get_quant_compression_format
 from compressed_tensors.quantization import (
     QuantizationScheme,
-    QuantizationStrategy,
     initialize_module_for_quantization,
 )
 
@@ -17,7 +16,8 @@
 
 __all__ = [
     "initialize_quantized_linear",
-    "calibrate_weights",
+    "calibrate_global_scale",
+    "calibrate_scale_zp",
     "compress_module",
 ]
 
@@ -35,15 +35,17 @@ def initialize_quantized_linear(
     return module
 
 
-def calibrate_weights(module: torch.nn.Linear):
-    scheme: QuantizationScheme = getattr(module, "quantization_scheme")
+def calibrate_global_scale(module: torch.nn.Linear):
     initialize_observer(module, "weight")
+    apply_calibration_status(module)
+    update_weight_global_scale(module)
+    freeze_module_quantization(module)
 
+
+def calibrate_scale_zp(module: torch.nn.Linear):
+    initialize_observer(module, "weight")
     apply_calibration_status(module)
-    if scheme.weights.strategy == QuantizationStrategy.TENSOR_GROUP:
-        update_weight_global_scale(module)
     update_weight_zp_scale(module)
-
     freeze_module_quantization(module)
 
 
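Note: the refactor above splits the old calibrate_weights entry point into calibrate_global_scale and calibrate_scale_zp, so the microscale path can pause between the two steps while sibling global scales are fused. A rough, untested sketch of the per-module call order; the weight shape, the CPU device, and construction of the scheme via compressed_tensors' preset_name_to_scheme are assumptions for illustration, not part of the commit:

import torch
from compressed_tensors.quantization import preset_name_to_scheme

from llmcompressor.entrypoints.model_free.lifecycle import (
    calibrate_global_scale,
    calibrate_scale_zp,
    compress_module,
    initialize_quantized_linear,
)

# hypothetical inputs: one linear weight and a weight-only NVFP4 scheme
weight = torch.randn(256, 256)
scheme = preset_name_to_scheme("NVFP4A16", targets=["Linear"])

module = initialize_quantized_linear(weight, scheme, "cpu")
calibrate_global_scale(module)  # observer -> weight_global_scale, then freeze
# (for fused q/k/v or gate/up siblings, global scales are min-reduced here)
calibrate_scale_zp(module)      # observer -> per-group scale / zero-point, then freeze
compress_module(module)         # pack the weight using the calibrated qparams
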
src/llmcompressor/entrypoints/model_free/microscale.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+import torch
+from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy
+
+__all__ = ["get_fused_names", "is_microscale_scheme"]
+
+
+def is_microscale_scheme(scheme: QuantizationScheme) -> bool:
+    assert scheme.weights is not None
+    return scheme.weights.strategy == QuantizationStrategy.TENSOR_GROUP
+
+
+def get_fused_names(tensors: dict[str, torch.Tensor]) -> dict[str, list[str]]:
+    fused_names = {}
+
+    for name in tensors:
+        parts = name.rsplit(".")
+        if len(parts) < 3:
+            continue
+
+        parent, module, param = parts[-3:]
+
+        if (
+            ("attn" in parent or "attention" in parent)
+            and module == "q_proj"
+            and param == "weight"
+        ):
+            parent_name = ".".join((*parts[:-3], parent))
+            q_name = ".".join((parent_name, "q_proj", param))
+            k_name = ".".join((parent_name, "k_proj", param))
+            v_name = ".".join((parent_name, "v_proj", param))
+
+            submodule_names = [q_name, k_name, v_name]
+
+            if all(name in tensors for name in submodule_names):
+                assert parent_name not in fused_names
+                fused_names[parent_name] = submodule_names
+
+        if "mlp" in parent and module == "gate_proj" and param == "weight":
+            parent_name = ".".join((*parts[:-3], parent))
+            gate_name = ".".join((parent_name, "gate_proj", param))
+            up_name = ".".join((parent_name, "up_proj", param))
+
+            submodule_names = [gate_name, up_name]
+
+            if all(name in tensors for name in submodule_names):
+                assert parent_name not in fused_names
+                fused_names[parent_name] = submodule_names
+
+    return fused_names
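
Note: get_fused_names groups q/k/v weights under an attention parent and gate/up weights under an mlp parent, keyed by the parent module, and only reports a group when every sibling is present. A small illustration using Llama-style module names (the names are illustrative; the function only inspects the dict keys, so placeholder values suffice):

from llmcompressor.entrypoints.model_free.microscale import get_fused_names

names = {
    "model.layers.0.self_attn.q_proj.weight": None,
    "model.layers.0.self_attn.k_proj.weight": None,
    "model.layers.0.self_attn.v_proj.weight": None,
    "model.layers.0.mlp.gate_proj.weight": None,
    "model.layers.0.mlp.up_proj.weight": None,
    "model.layers.0.mlp.down_proj.weight": None,   # not part of any fused group
    "model.layers.0.input_layernorm.weight": None,
}

print(get_fused_names(names))
# {
#   "model.layers.0.self_attn": [
#       "model.layers.0.self_attn.q_proj.weight",
#       "model.layers.0.self_attn.k_proj.weight",
#       "model.layers.0.self_attn.v_proj.weight",
#   ],
#   "model.layers.0.mlp": [
#       "model.layers.0.mlp.gate_proj.weight",
#       "model.layers.0.mlp.up_proj.weight",
#   ],
# }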

src/llmcompressor/entrypoints/model_free/model_utils.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ def is_weights_file(file_name: str) -> bool:
     return any(file_name.endswith(suffix) for suffix in weights_files)
 
 
-def get_checkpoint_files(model_stub: str | os.PathLike) -> list[str]:
+def get_checkpoint_files(model_stub: str | os.PathLike) -> list[tuple[str, str]]:
     # In the future, this function can accept and pass download kwargs to cached_file
 
     if os.path.exists(model_stub):
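
Note: the updated annotation matches how callers already unpack the result, pairing each checkpoint-relative file name with its resolved local path (see the "for file_path, resolved_path in model_files" loops in model_free_ptq and validate_safetensors_index). A hypothetical example of the returned shape (paths are illustrative):

model_files = [
    ("config.json", "/cache/snapshots/abc123/config.json"),
    ("model.safetensors", "/cache/snapshots/abc123/model.safetensors"),
]
for file_path, resolved_path in model_files:
    print(file_path, "->", resolved_path)  # relative name -> local path on disk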

tests/llmcompressor/pipelines/test_model_free_ptq.py

Lines changed: 2 additions & 1 deletion
@@ -41,7 +41,8 @@ def _get_tiny_block_quant():
 
 @requires_gpu
 @pytest.mark.parametrize(
-    "scheme", [_get_tiny_w4a16_quant(), "FP8_dynamic", _get_tiny_block_quant()]
+    "scheme",
+    [_get_tiny_w4a16_quant(), "FP8_dynamic", _get_tiny_block_quant(), "NVFP4A16"],
 )
 def test_model_free_ptq_matches_oneshot(scheme, tmp_path):
     model = "nm-testing/tinysmokellama-3.2"
