
Commit 8db6c6b

nvfp4a16

Signed-off-by: Kyle Sayers <[email protected]>

1 parent 9e1a040 commit 8db6c6b

File tree

14 files changed: +626 −143 lines changed

examples/model_free_ptq/README.md

Lines changed: 50 additions & 0 deletions
@@ -13,3 +13,53 @@

In `kimi_k2_thinking_fp8_block.py`, we call `model_free_ptq` by providing a `scheme` and an `ignore` list, similar to how we provide recipes to `oneshot` calls. In the case of Kimi-K2 Thinking, we apply the `FP8_BLOCK` scheme and ignore layers that are incompatible with a block size of 128 (specifically, `kv_a_proj_with_mqa` and `q_a_proj`).

In contrast to `oneshot`, the model stub or path string is passed directly to `model_free_ptq`, rather than first being loaded through transformers. Once complete, the model is compressed using compressed-tensors and saved to `SAVE_DIR`.
To get started, simply call `model_free_ptq` with your desired model stub and save directory:
```python
model_free_ptq(
    model_stub="unsloth/Kimi-K2-Thinking-BF16",
    save_directory="Kimi-K2-Thinking-FP8-BLOCK",
    scheme="FP8_BLOCK",
    ignore=[
        "re:.*gate$",
        "lm_head",
        "re:.*kv_a_proj_with_mqa$",
        "re:.*q_a_proj$",
        "model.embed_tokens",
    ],
    max_workers=15,
    device="cuda:0",
)
```
# Quantizing models to NVFP4A16 / MXFP4A16
Using `model_free_ptq` to quantize models with microscale schemes (NVFP4/MXFP4) is the same as quantizing with non-microscale schemes, except for one additional step: the safetensors files must be reindexed to guarantee that fused modules (qkv, gate_up) end up in the same safetensors file, which helps `model_free_ptq` fuse global scales.

First, apply `llmcompressor.reindex_fused_weights` from the command-line entrypoint:
```bash
llmcompressor.reindex_fused_weights \
    unsloth/Kimi-K2-Thinking-BF16 \
    Kimi-K2-Thinking-BF16-reindexed \
    --num_workers=10
```
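Reindexing enforces a simple invariant that can be checked straight from the model's safetensors index. Below is a minimal sketch of that check (not the shipped implementation; the index filename and the q/k/v module naming are illustrative assumptions):

```python
# Sketch only: verifies the invariant that reindexing establishes, i.e. that all
# members of a fused group (here q/k/v projections) live in the same shard.
import json
from collections import defaultdict

with open("model.safetensors.index.json") as f:  # hypothetical local index copy
    weight_map: dict[str, str] = json.load(f)["weight_map"]

shards_per_module: dict[str, set[str]] = defaultdict(set)
for tensor_name, shard in weight_map.items():
    for suffix in ("q_proj", "k_proj", "v_proj"):
        marker = f".{suffix}."
        if marker in tensor_name:
            # group by the parent module that owns the fused projections
            shards_per_module[tensor_name.split(marker)[0]].add(shard)

for module, shards in sorted(shards_per_module.items()):
    if len(shards) > 1:
        print(f"{module}: fused weights split across {sorted(shards)}, reindex needed")
```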
Then, call `model_free_ptq` on the reindexed files:
```python
model_free_ptq(
    model_stub="Kimi-K2-Thinking-BF16-reindexed",
    save_directory="Kimi-K2-Thinking-BF16-NVFP4A16",
    scheme="NVFP4A16",
    ignore=[
        "re:.*gate$",
        "lm_head",
        "re:.*kv_a_proj_with_mqa$",
        "re:.*q_a_proj$",
        "model.embed_tokens",
    ],
    max_workers=15,
    device="cuda:0",
)
```

examples/model_free_ptq/kimi_k2_thinking_fp8_block.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,7 +1,7 @@
 from llmcompressor import model_free_ptq

 MODEL_ID = "unsloth/Kimi-K2-Thinking-BF16"
-SAVE_DIR = "Kimi-K2-Thinking-FP8-Block"
+SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8-BLOCK"

 # Apply FP8-Block to the model
 # Once quantized, the model is saved
```
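For reference, the new `SAVE_DIR` expression simply derives the output directory from the basename of the model stub; evaluating it by hand:

```python
# Quick check of what the new SAVE_DIR expression evaluates to
MODEL_ID = "unsloth/Kimi-K2-Thinking-BF16"
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8-BLOCK"
assert SAVE_DIR == "Kimi-K2-Thinking-BF16-FP8-BLOCK"
```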
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
```python
"""
NOTE: Please run the following command before using `model_free_ptq`

The command below reindexes the safetensors files of a model such that all fused
modules (gate_up, qkv) are in the same safetensors file. This is required by
model_free_ptq for microscale schemes (NVFP4A16, MXFP4A16)

llmcompressor.reindex_fused_weights \
    unsloth/Kimi-K2-Thinking-BF16 \
    Kimi-K2-Thinking-BF16-reindexed \
    --num_workers=10
"""

from llmcompressor import model_free_ptq

MODEL_ID = "unsloth/Kimi-K2-Thinking-BF16"
REINDEX_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-reindexed"
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4A16"

# See above notice pertaining to safetensors reindexing
# After running `llmcompressor.reindex_fused_weights`,
# use `model_free_ptq` to apply NVFP4A16 quantization
model_free_ptq(
    model_stub=REINDEX_DIR,
    save_directory=SAVE_DIR,
    scheme="NVFP4A16",
    ignore=[
        "re:.*gate$",
        "lm_head",
        "re:.*kv_a_proj_with_mqa$",
        "re:.*q_a_proj$",
        "model.embed_tokens",
    ],
    max_workers=15,
    device="cuda:0",
)
```

setup.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -184,6 +184,7 @@ def localversion_func(version: ScmVersion) -> str:
     entry_points={
         "console_scripts": [
             "llmcompressor.trace=llmcompressor.transformers.tracing.debug:main",
+            "llmcompressor.reindex_fused_weights=llmcompressor.entrypoints.model_free.reindex_fused_weights:main",
         ]
     },
     python_requires=">=3.10",
```
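After installation, the new console script dispatches to the module's `main` function. The sketch below is equivalent, assuming `main` reads its arguments from `sys.argv` like a typical argparse CLI (an assumption; only the module path and `main` symbol are guaranteed by the entry point):

```python
# Equivalent to running the `llmcompressor.reindex_fused_weights` console script.
# Assumes `main` parses sys.argv; the argv layout below mirrors the README example.
import sys

from llmcompressor.entrypoints.model_free.reindex_fused_weights import main

sys.argv = [
    "llmcompressor.reindex_fused_weights",
    "unsloth/Kimi-K2-Thinking-BF16",
    "Kimi-K2-Thinking-BF16-reindexed",
    "--num_workers=10",
]
main()
```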

src/llmcompressor/entrypoints/model_free/__init__.py

Lines changed: 20 additions & 62 deletions
```diff
@@ -7,27 +7,28 @@
 import torch
 import tqdm
 from compressed_tensors.quantization import QuantizationScheme
-from compressed_tensors.utils.match import _match_name
 from loguru import logger
-from safetensors.torch import load_file, save_file

-from llmcompressor.entrypoints.model_free.helpers import (
-    gpu_if_available,
-    validate_scheme,
-)
-from llmcompressor.entrypoints.model_free.lifecycle import (
-    calibrate_weights,
-    compress_module,
-    initialize_quantized_linear,
+from llmcompressor.entrypoints.model_free.helpers import gpu_if_available
+from llmcompressor.entrypoints.model_free.microscale import (
+    is_microscale_scheme,
 )
 from llmcompressor.entrypoints.model_free.model_utils import (
     get_checkpoint_files,
     is_weights_file,
 )
+from llmcompressor.entrypoints.model_free.process import (
+    process_file,
+    process_file_microscale_scheme,
+)
 from llmcompressor.entrypoints.model_free.save_utils import (
     update_config,
     update_safetensors_index,
 )
+from llmcompressor.entrypoints.model_free.validate import (
+    validate_safetensors_index,
+    validate_scheme,
+)

 __all__ = ["model_free_ptq"]

@@ -55,20 +56,24 @@ def model_free_ptq(
     model_files = get_checkpoint_files(model_stub)
     scheme_name, scheme = validate_scheme(scheme)
     device = gpu_if_available(device)
+    validate_safetensors_index(model_files, scheme)

     # 0. collect safetensors files, copy files
     jobs = []
-    for file_path, resolved_path in model_files:
+    job_fn = (
+        process_file
+        if not is_microscale_scheme(scheme)
+        else process_file_microscale_scheme
+    )
+    for file_path, resolved_path in model_files.items():
         save_path = Path(save_directory) / file_path

         if file_path.endswith("safetensors"):
-            jobs.append(
-                (_process_file, resolved_path, save_path, scheme, ignore, device)
-            )
+            jobs.append((job_fn, resolved_path, save_path, scheme, ignore, device))

         else:
             if is_weights_file(file_path):
-                logger.warning(f"Skipping weights file {file_path}")
+                logger.warning(f"Skip processing for weights file {file_path}")
             save_path.parent.mkdir(parents=True, exist_ok=True)
             logger.info(f"Copying {file_path} {save_path}")
             shutil.copyfile(resolved_path, save_path)
@@ -89,50 +94,3 @@ def model_free_ptq(
     # 5. update config and safetensors index
     update_config(save_directory, scheme_name, scheme, ignore)
     update_safetensors_index(save_directory, total_size, weight_map)
-
-
-def _process_file(
-    file_path: str | os.PathLike,
-    save_path: str | os.PathLike,
-    scheme: QuantizationScheme,
-    ignore: str | list[str],
-    device: str | torch.device,
-) -> tuple[int, dict[str, str]]:
-    """
-    Quantize and compress tensors in a given safetensors file
-
-    :param file_path: safetensors file to process
-    :param save_path: save path of file with quantized weights
-    :param scheme: quantization scheme to apply to tensors
-    :param ignore: modules to ignore. Modules ending with "norm" are automatically
-        ignored
-    :param device: device used to quantize and compress weights
-    """
-    tensors = load_file(file_path)
-
-    for name in list(tensors.keys()):
-        module_name, param_name = name.rsplit(".", 1)
-        is_linear_weight = param_name == "weight" and not module_name.endswith("norm")
-        is_ignored = any(_match_name(module_name, ign) for ign in ignore)
-        if not is_linear_weight or is_ignored:
-            continue
-
-        # 1. initialize module with qparams (on device)
-        module = initialize_quantized_linear(tensors[name], scheme, device)
-
-        # 2. calibrate weight qparams
-        calibrate_weights(module)
-
-        # 3. compress module using qparams
-        compress_module(module)
-
-        # 4. save compressed data (on cpu)
-        del tensors[name]
-        prefix = module_name + "."
-        for key, value in module.state_dict(prefix=prefix).items():
-            tensors[key] = value.to("cpu")
-
-    save_file(tensors, save_path)
-    total_size = sum(tensor.nbytes for tensor in tensors.values())
-    weight_map = {key: os.path.basename(save_path) for key in tensors.keys()}
-    return total_size, weight_map
```
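The executor that consumes `jobs` sits outside this hunk's context. Below is a plausible sketch of how the `(fn, *args)` tuples could be drained, based on the `(total_size, weight_map)` return signature of the removed `_process_file`; the pool choice and progress bar are assumptions, not the commit's actual code:

```python
# Hypothetical sketch of draining the `jobs` list built above; not the commit's
# actual executor, which is outside the diff context.
from concurrent.futures import ThreadPoolExecutor

import tqdm

def run_jobs(jobs: list[tuple], max_workers: int = 15) -> tuple[int, dict[str, str]]:
    total_size, weight_map = 0, {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(fn, *args) for fn, *args in jobs]
        for future in tqdm.tqdm(futures, desc="Processing shards"):
            size, file_map = future.result()  # each job returns (nbytes, weight_map)
            total_size += size
            weight_map.update(file_map)
    return total_size, weight_map
```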
Lines changed: 81 additions & 50 deletions
```diff
@@ -1,51 +1,25 @@
-from typing import Optional
+import os
+from collections import defaultdict
+from typing import Mapping, TypeVar

 import torch
-from compressed_tensors.quantization import QuantizationScheme, preset_name_to_scheme
-from compressed_tensors.utils import getattr_chain
 from compressed_tensors.utils.match import _match_name
 from loguru import logger
+from transformers.file_utils import CONFIG_NAME

-__all__ = ["validate_scheme", "gpu_if_available", "is_match_name"]
+__all__ = [
+    "gpu_if_available",
+    "find_safetensors_index_path",
+    "find_config_path",
+    "find_safetensors_index_file",
+    "match_names_set_eager",
+    "MatchedNamesSet",
+    "invert_mapping",
+]

-
-def validate_scheme(scheme: QuantizationScheme) -> tuple[str, QuantizationScheme]:
-    # treat strings as preset schemes
-    if isinstance(scheme, str):
-        scheme_name, scheme = scheme, preset_name_to_scheme(scheme, [])
-    else:
-        scheme_name = "config_group_0"
-
-    # weight quantization must be provided
-    if scheme.weights is None:
-        raise ValueError(
-            "Must provide a weights quanitization scheme to perform weights-only PTQ"
-        )
-
-    # activation quantization must be dynamic
-    input_dynamic = getattr_chain(scheme, "input_activations.dynamic", True)
-    output_dynamic = getattr_chain(scheme, "output_activations.dynamic", True)
-    if input_dynamic is not True or output_dynamic is not True:
-        raise ValueError(
-            "Model Free PTQ cannot calibrate activations. "
-            "Please use `oneshot` instead."
-        )
-
-    # override with static observers
-    # Remove after https://github.com/vllm-project/compressed-tensors/pull/489
-    if scheme.weights.observer in ("minmax", "mse"):
-        new_observer = f"static_{scheme.weights.observer}"
-        logger.warning(
-            f"Scheme uses {scheme.weights.observer} weight observer. "
-            f"Using {new_observer} instead"
-        )
-        scheme.weights.observer = new_observer
-
-    # target all modules; filter by ignore list
-    # technically this should be "re:.*", but vllm's
-    # ct moe layer has a hard coded check for "Linear"
-    scheme.targets = ["Linear"]
-    return scheme_name, scheme
+KeyType = TypeVar("K")
+ValueType = TypeVar("V")
+MatchedNamesSet = dict[str, str | None]


 def gpu_if_available(device: torch.device | str | None) -> torch.device:
@@ -63,13 +37,70 @@ def gpu_if_available(device: torch.device | str | None) -> torch.device:
     return torch.device("cpu")


-def is_match_name(
-    name: str, targets: list[str], ignore: Optional[str | list[str]] = None
-) -> bool:
-    targets = targets if isinstance(targets, list) else [targets]
-    ignore = ignore if isinstance(ignore, list) else [ignore]
+def find_safetensors_index_path(save_directory: str | os.PathLike) -> str | None:
+    for file_name in os.listdir(save_directory):
+        if file_name.endswith("safetensors.index.json"):
+            return os.path.join(save_directory, file_name)
+
+    return None
+
+
+def find_config_path(save_directory: str | os.PathLike) -> str | None:
+    for file_name in os.listdir(save_directory):
+        if file_name in (CONFIG_NAME, "params.json"):
+            return os.path.join(save_directory, file_name)
+
+    return None
+
+
+def find_safetensors_index_file(model_files: dict[str, str]) -> str | None:
+    for file_path, resolved_path in model_files.items():
+        if file_path.endswith("safetensors.index.json"):
+            return resolved_path
+
+    return None
+
+
+def match_names_set_eager(
+    names: set[str] | list[str],
+    targets: set[str] | list[str],
+    return_unmatched: bool = True,
+) -> list[MatchedNamesSet] | tuple[list[MatchedNamesSet], MatchedNamesSet]:
+    matched_sets = []
+    matches = dict.fromkeys(targets, None)
+
+    for name in names:
+        # match until we get a full set
+        for target in targets:
+            if _match_name(name, target):
+                if matches[target] is None:
+                    matches[target] = name
+                else:
+                    # matched target twice without completing a set
+                    raise ValueError(
+                        f"Matched a {target} twice before "
+                        f"completing set ({matches[target]}, {name})"
+                    )
+
+        # once we have a full set, yield and reset
+        if all((matches[target] is not None for target in targets)):
+            matched_sets.append(matches)
+            matches = dict.fromkeys(targets, None)
+
+    unmatched_set = matches if any((v is not None for v in matches.values())) else None
+
+    if return_unmatched:
+        return matched_sets, unmatched_set
+    else:
+        return matched_sets
+
+
+def invert_mapping(
+    mapping: Mapping[KeyType, ValueType],
+) -> dict[ValueType, list[KeyType]]:
+    inverse = defaultdict(list)

-    matches_target = any(_match_name(name, target) for target in targets)
-    matches_ignore = any(_match_name(name, ign) for ign in ignore)
+    for key, value in mapping.items():
+        inverse[value].append(key)

-    return matches_target and not matches_ignore
+    return inverse
```
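A usage sketch for the two new helpers (toy module names; behavior follows the definitions above): `match_names_set_eager` greedily groups names into complete target sets, and `invert_mapping` turns a tensor-to-shard map into shard-to-tensors.

```python
from llmcompressor.entrypoints.model_free.helpers import (
    invert_mapping,
    match_names_set_eager,
)

names = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.k_proj.weight",
    "model.layers.0.self_attn.v_proj.weight",
    "model.layers.1.self_attn.q_proj.weight",  # incomplete set
]
targets = ["re:.*q_proj.weight$", "re:.*k_proj.weight$", "re:.*v_proj.weight$"]

# layer 0 forms one complete q/k/v set; layer 1's lone q_proj is returned unmatched
matched_sets, unmatched = match_names_set_eager(names, targets)
assert len(matched_sets) == 1 and unmatched is not None

# invert a tensor -> shard map into shard -> [tensors]
weight_map = {"a.weight": "shard-0.safetensors", "b.weight": "shard-0.safetensors"}
print(dict(invert_mapping(weight_map)))
# {'shard-0.safetensors': ['a.weight', 'b.weight']}
```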
