|
15 | 15 |
|
16 | 16 | import dataclasses |
17 | 17 | import os |
| 18 | +import re |
18 | 19 | import warnings |
| 20 | +from collections import defaultdict |
19 | 21 | from contextlib import contextmanager |
20 | 22 | from typing import Any |
21 | 23 |
|
|
30 | 32 | from modelopt.torch.utils.dataset_utils import get_dataset_dataloader |
31 | 33 |
|
32 | 34 |
|
def convert_amax_hf2vllm(
    hf_state_dict: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    """Convert amax values from HuggingFace format to vLLM format.

    vLLM fuses the attention and MLP input projections, so the per-projection
    amax entries of a HuggingFace checkpoint must be fused as well:

    - ``q_proj`` / ``k_proj`` / ``v_proj`` amax entries are merged into a
      single ``qkv_proj`` entry (elementwise maximum).
    - ``gate_proj`` / ``up_proj`` amax entries are merged into a single
      ``gate_up_proj`` entry (elementwise maximum), except for keys that
      contain ``"mixer"``, which pass through unchanged.

    All other entries (non-amax keys, and amax keys such as ``o_proj`` or
    ``down_proj`` that need no fusion) are copied through unmodified.

    Args:
        hf_state_dict: HuggingFace state dict containing amax values.

    Returns:
        vLLM format state dict with merged amax values.
    """
    # Hoist the patterns; group(1) is the module prefix, group(3) the
    # quantizer suffix (e.g. ".input_quantizer._amax").
    qkv_pat = re.compile(r"(.*\.)([qkv])_proj(\..+_amax)$")
    gate_up_pat = re.compile(r"(.*\.)(gate|up)_proj(\..+_amax)$")

    converted: dict[str, torch.Tensor] = {}
    # Fused target key -> list of (original key, amax tensor) to merge.
    grouped: dict[str, list[tuple[str, torch.Tensor]]] = defaultdict(list)

    for key, value in hf_state_dict.items():
        if "_amax" not in key:
            # Non-amax entries pass through untouched.
            converted[key] = value
        elif m := qkv_pat.search(key):
            grouped[m.group(1) + "qkv_proj" + m.group(3)].append((key, value))
        elif "mixer" not in key and (m := gate_up_pat.search(key)):
            grouped[m.group(1) + "gate_up_proj" + m.group(3)].append((key, value))
        else:
            # Remaining amax entries (o_proj, down_proj, mixer gate/up, ...)
            # keep their original key.
            converted[key] = value

    # Resolve each fused key: a lone projection is simply renamed, while
    # multiple projections are reduced with an elementwise maximum.
    for merged_key, pairs in grouped.items():
        if len(pairs) == 1:
            converted[merged_key] = pairs[0][1]
        else:
            stacked = torch.stack([amax for _, amax in pairs])
            converted[merged_key] = stacked.max(dim=0)[0]
            print(f"Merged {len(pairs)} keys into {merged_key}")
            for orig_key, _ in pairs:
                print(f"  - {orig_key}")

    return converted
| 95 | + |
| 96 | + |
33 | 97 | @contextmanager |
34 | 98 | def disable_compilation(model): |
35 | 99 | do_not_compile = True |
@@ -154,8 +218,17 @@ def calibrate_loop(model: Any = None) -> None: |
154 | 218 | if amax_file_path: |
155 | 219 | print(f"Loading amax values from {amax_file_path}") |
156 | 220 | saved_amax_dict = torch.load(amax_file_path) |
157 | | - current_state_dict = model.state_dict() |
| 221 | + # convert amax keys to vLLM format |
| 222 | + if hasattr(self.model_runner.model, "hf_to_vllm_mapper"): |
| 223 | + saved_amax_dict = self.model_runner.model.hf_to_vllm_mapper.apply_dict(saved_amax_dict) |
| 224 | + saved_amax_dict = { |
| 225 | + key.replace("quantizer_amax", "quantizer._amax"): value |
| 226 | + for key, value in saved_amax_dict.items() |
| 227 | + if key.endswith("quantizer_amax") |
| 228 | + } |
| 229 | + saved_amax_dict = convert_amax_hf2vllm(saved_amax_dict) |
158 | 230 |
|
| 231 | + current_state_dict = model.state_dict() |
159 | 232 | # Count amax keys in checkpoint and model |
160 | 233 | checkpoint_amax_keys = [key for key in saved_amax_dict if key.endswith("_amax")] |
161 | 234 | model_amax_keys = [key for key in current_state_dict if key.endswith("_amax")] |
|
0 commit comments