Convert compressed-tensor int4 format to GPTQ int4 format #590
Open
Edwardf0t1 wants to merge 3 commits into main from zhiyu/ct-int4-to-gptq-int4
Changes from all commits (3 commits):

9dfa893  add a script to convert compressed-tensor int4 format to GPTQ int4 fo… (Edwardf0t1)
514e300  add a script to convert compressed-tensor int4 format to GPTQ int4 fo… (Edwardf0t1)
574ca5d  add a script to convert compressed-tensor int4 format to GPTQ int4 fo… (Edwardf0t1)
@@ -0,0 +1,371 @@

#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Convert Kimi-K2-Thinking compressed-tensors checkpoint to TRT-LLM GPTQ format.

The source checkpoint uses:
- int32 packing: 8 int4 values per int32
- group_size: 32
- symmetric quantization

TRT-LLM GPTQ format uses:
- int32 packing: 8 int4 values per int32 (same)
- Requires .qweight, .scales, .qzeros tensors
"""

import argparse
import json
from pathlib import Path

import safetensors
import safetensors.torch
import torch
from tqdm import tqdm
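
# Layout summary (editorial sketch, not part of the original script): with
# group_size = 32, each quantized weight in the source checkpoint is stored as
#     weight_packed: (N, K // 8)  int32     - 8 int4 values per int32
#     weight_scale:  (N, K // 32) bfloat16  - one scale per 32-value group
#     weight_shape:  [N, K]                 - logical (unpacked) shape
# and the GPTQ output produced below holds
#     qweight (K // 8, N) int32, scales (K // 32, N) float16,
#     qzeros  (K // 32, N // 8) int32.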

def unpack_int32_to_int4(weight_packed: torch.Tensor) -> torch.Tensor:
    """
    Unpack int32 tensor containing 8 int4 values into int8 tensor.

    Args:
        weight_packed: Shape (N, K/8) dtype int32

    Returns:
        unpacked: Shape (N, K) dtype int8 with values in range [-8, 7]
    """
    # Convert int32 to uint8 view to extract nibbles
    w_packed_uint8 = weight_packed.contiguous().view(torch.uint8)

    # Each int32 = 4 bytes, each byte has 2 int4 values
    # So shape (N, K/8) int32 -> (N, K/8, 4) uint8 -> (N, K/2) uint8
    n, k_div_8 = weight_packed.shape
    w_packed_uint8 = w_packed_uint8.view(n, k_div_8 * 4)

    # Allocate output: (N, K) where K = K_div_8 * 8
    k = k_div_8 * 8
    w_unpacked = torch.zeros(n, k, dtype=torch.int8)

    # Extract low and high nibbles
    w_unpacked[:, 0::2] = (w_packed_uint8 & 0x0F).to(torch.int8)
    w_unpacked[:, 1::2] = (w_packed_uint8 >> 4).to(torch.int8)

    # Convert from uint4 [0, 15] to int4 [-8, 7]
    # Values > 7 should be interpreted as negative
    w_unpacked[w_unpacked > 7] -= 16

    return w_unpacked.contiguous()

def pack_int4_to_int32_gptq(weight_unpacked: torch.Tensor) -> torch.Tensor:
    """
    Pack int8 tensor (with int4 values) into int32 GPTQ format.

    Args:
        weight_unpacked: Shape (N, K) dtype int8 with values in range [-8, 7]

    Returns:
        packed: Shape (N, K/8) dtype int32
    """
    n, k = weight_unpacked.shape
    assert k % 8 == 0, "K must be divisible by 8"

    # Convert int4 [-8, 7] to uint4 [0, 15]
    w_uint4 = weight_unpacked.clone()
    w_uint4[w_uint4 < 0] += 16
    w_uint4 = w_uint4.to(torch.uint8)

    # Pack 2 uint4 into 1 uint8: even K index in the low nibble, odd in the high nibble
    w_packed_uint8 = (w_uint4[:, 1::2] << 4) | (w_uint4[:, 0::2])

    # Reinterpret every 4 consecutive bytes as one int32: (N, K/2) uint8 -> (N, K/8) int32
    w_packed_int32 = w_packed_uint8.contiguous().view(torch.int32)

    return w_packed_int32.contiguous()
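
# Illustrative round-trip check (editorial sketch, not part of the original
# script): packing followed by unpacking should reproduce the input int4
# values exactly. Shapes here are arbitrary small examples.
def _roundtrip_check_example(n: int = 4, k: int = 64) -> None:
    w = torch.randint(-8, 8, (n, k), dtype=torch.int8)  # int4 values held in int8
    packed = pack_int4_to_int32_gptq(w)  # shape (n, k // 8), dtype int32
    restored = unpack_int32_to_int4(packed)  # shape (n, k), dtype int8
    assert packed.shape == (n, k // 8)
    assert torch.equal(restored, w)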

def convert_compressed_tensor_to_gptq(
    weight_packed: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_shape: list[int],
    group_size: int = 32,
) -> dict[str, torch.Tensor]:
    """
    Convert compressed-tensors format to GPTQ format for TRT-LLM.

    Args:
        weight_packed: Shape (N, K/8) dtype int32
        weight_scale: Shape (N, K/group_size) dtype bfloat16
        weight_shape: [N, K] - original weight shape
        group_size: Quantization group size

    Returns:
        Dictionary with:
        - qweight: Shape (K/8, N) dtype int32 (transposed!)
        - scales: Shape (K/group_size, N) dtype fp16 (transposed!)
        - qzeros: Shape (K/group_size, N/8) dtype int32 (for symmetric quantization,
          every nibble holds the fixed zero-point 8)
    """
    n, k = weight_shape

    # TRT-LLM expects weights transposed: (K, N) instead of (N, K)
    # But packed format keeps the packing dimension, so:
    # Source: (N, K/8) -> Target: (K/8, N)
    qweight = weight_packed.t().contiguous()

    # Scales also need transpose: (N, K/group_size) -> (K/group_size, N)
    scales = weight_scale.t().contiguous().to(torch.float16)

    # For symmetric quantization, the zero-point is the midpoint of the uint4 range (8).
    # GPTQ format expects qzeros in packed format:
    # Shape (K/group_size, N/8), since zero-points are also packed 8 per int32.
    num_groups = k // group_size
    # Pack as uint4: each byte is 0x88, so each int32 contains 8 nibbles of value 8.
    # Create as bytes first, then view as int32.
    qzeros_uint8 = torch.full((num_groups, n // 2), 0x88, dtype=torch.uint8)
    qzeros = qzeros_uint8.view(torch.int32).contiguous()

    return {
        "qweight": qweight,
        "scales": scales,
        "qzeros": qzeros,
    }
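
# Editorial note (not in the original script) on the qzeros encoding above:
# each 0x88 byte holds two uint4 zero-point nibbles of value 8, so every int32
# in qzeros carries the bit pattern 0x88888888, i.e. eight zero-points of 8.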

def convert_checkpoint(
    input_dir: str,
    output_dir: str,
    num_shards: int | None = None,
    skip_existing: bool = True,
):
    """
    Convert all shards from compressed-tensors to GPTQ format.

    Args:
        input_dir: Source checkpoint directory
        output_dir: Output checkpoint directory
        num_shards: Number of shards to process (None = all)
        skip_existing: Skip conversion if output shard already exists
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Find all safetensors files
    shard_files = sorted(input_path.glob("model-*.safetensors"))
    if not shard_files:
        raise ValueError(f"No model shards found in {input_dir}")

    if num_shards is not None:
        shard_files = shard_files[:num_shards]

    print(f"Found {len(shard_files)} shards to process")

    # Track weight mapping for index file
    new_weight_map = {}

    # Load and convert each shard
    for shard_idx, shard_file in enumerate(tqdm(shard_files, desc="Processing shards")):
        shard_name = shard_file.name
        output_file = output_path / shard_name

        # Check if output already exists
        if skip_existing and output_file.exists():
            print(f"\n⏭️ Skipping {shard_name} (already exists)")
            # Still need to build the weight_map from the existing file
            with safetensors.safe_open(str(output_file), framework="pt", device="cpu") as f:
                for key in f.keys():
                    new_weight_map[key] = shard_name
            continue

        print(f"\n🔄 Converting {shard_file.name}...")

        # Load source shard
        source_tensors = {}
        with safetensors.safe_open(str(shard_file), framework="pt", device="cpu") as f:
            for key in f.keys():
                source_tensors[key] = f.get_tensor(key)

        # Convert tensors
        output_tensors = {}

        for key, tensor in tqdm(source_tensors.items(), desc="Converting tensors", leave=False):
            if key.endswith(".weight_packed"):
                # This is a quantized weight - convert to GPTQ format
                base_key = key[: -len(".weight_packed")]
                scale_key = base_key + ".weight_scale"
                shape_key = base_key + ".weight_shape"

                if scale_key in source_tensors and shape_key in source_tensors:
                    weight_shape = source_tensors[shape_key].tolist()

                    # Convert to GPTQ format
                    gptq_tensors = convert_compressed_tensor_to_gptq(
                        weight_packed=tensor,
                        weight_scale=source_tensors[scale_key],
                        weight_shape=weight_shape,
                        group_size=32,
                    )

                    # Save with GPTQ naming convention and track in weight_map
                    qweight_key = base_key + ".qweight"
                    scales_key = base_key + ".scales"
                    qzeros_key = base_key + ".qzeros"

                    output_tensors[qweight_key] = gptq_tensors["qweight"]
                    output_tensors[scales_key] = gptq_tensors["scales"]
                    output_tensors[qzeros_key] = gptq_tensors["qzeros"]

                    new_weight_map[qweight_key] = shard_name
                    new_weight_map[scales_key] = shard_name
                    new_weight_map[qzeros_key] = shard_name
                else:
                    print(f"Warning: Missing scale or shape for {key}")

            elif key.endswith((".weight_scale", ".weight_shape")):
                # Skip these as they're handled above
                continue
            else:
                # Keep non-quantized tensors as-is
                output_tensors[key] = tensor
                new_weight_map[key] = shard_name

        # Save converted shard
        safetensors.torch.save_file(output_tensors, str(output_file))
        print(f"✅ Saved to {output_file}")

    # Copy config.json and update quantization settings
    config_file = input_path / "config.json"
    if config_file.exists():
        with open(config_file) as f:
            config = json.load(f)

        # Remove HuggingFace quantization_config if present
        config.pop("quantization_config", None)

        # Add TRT-LLM native quantization config for GPTQ
        config["quantization"] = {
            "quant_algo": "W4A16_GPTQ",  # TRT-LLM's enum value
            "group_size": 32,
            "has_zero_point": True,  # GPTQ uses asymmetric quantization
            "pre_quant_scale": False,  # No pre-quantization scaling
        }

        output_config_file = output_path / "config.json"
        with open(output_config_file, "w") as f:
            json.dump(config, f, indent=2)
        print(f"\nSaved config to {output_config_file}")

    # Generate new safetensors index file
    index_data = {
        "metadata": {
            "total_size": sum(
                (output_path / shard_file.name).stat().st_size for shard_file in shard_files
            )
        },
        "weight_map": new_weight_map,
    }

    index_file = output_path / "model.safetensors.index.json"
    with open(index_file, "w") as f:
        json.dump(index_data, f, indent=2)
    print(f"\nGenerated index file: {index_file}")
    print(f"  Total tensors: {len(new_weight_map)}")

    # Copy other necessary files
    import shutil

    # JSON files (tokenizer and generation config)
    for file in [
        "generation_config.json",
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "vocab.json",
        "merges.txt",
    ]:
        src = input_path / file
        if src.exists():
            shutil.copy(src, output_path / file)
            print(f"Copied: {file}")

    # Python files (model architecture, custom tokenizers)
    for file in ["configuration_deepseek.py", "modeling_deepseek.py", "tokenization_kimi.py"]:
        src = input_path / file
        if src.exists():
            shutil.copy(src, output_path / file)
            print(f"Copied: {file}")

    # Tokenizer model files
    for file in ["tiktoken.model", "tokenizer.model", "sentencepiece.model"]:
        src = input_path / file
        if src.exists():
            shutil.copy(src, output_path / file)
            print(f"Copied: {file}")

    # Template files
    for file in ["chat_template.jinja", "chat_template.json"]:
        src = input_path / file
        if src.exists():
            shutil.copy(src, output_path / file)
            print(f"Copied: {file}")

    print(f"\n✓ Conversion complete! Output saved to {output_dir}")

def main():
    parser = argparse.ArgumentParser(
        description="Convert Kimi-K2 checkpoint to TRT-LLM GPTQ format"
    )
    parser.add_argument(
        "--input-dir",
        type=str,
        default="/home/scratch.omniml_data_1/models/Kimi-K2-Thinking",
        help="Input checkpoint directory with compressed-tensors format",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="/home/scratch.omniml_data_2/zhiyuc/checkpoints/Kimi-K2-Thinking-GPTQ",
        help="Output directory for GPTQ format checkpoint",
    )
    parser.add_argument(
        "--num-shards",
        type=int,
        default=None,
        help="Number of shards to convert (default: all)",
    )
    parser.add_argument(
        "--no-skip-existing",
        action="store_true",
        help="Re-convert shards even if they already exist (default: skip existing)",
    )

    args = parser.parse_args()

    convert_checkpoint(
        input_dir=args.input_dir,
        output_dir=args.output_dir,
        num_shards=args.num_shards,
        skip_existing=not args.no_skip_existing,
    )


if __name__ == "__main__":
    main()
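
For reference, the converter can be run directly through its CLI flags (--input-dir, --output-dir, --num-shards, --no-skip-existing) or driven programmatically. A minimal programmatic sketch, with placeholder paths and a hypothetical module name for the script file:

# Hypothetical module name; adjust to wherever the script above is saved.
from convert_compressed_tensor_to_gptq import convert_checkpoint

convert_checkpoint(
    input_dir="/path/to/Kimi-K2-Thinking",        # compressed-tensors checkpoint
    output_dir="/path/to/Kimi-K2-Thinking-GPTQ",  # GPTQ-format output for TRT-LLM
    num_shards=None,                              # convert every shard
    skip_existing=True,                           # skip shards that already exist
)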
@Edwardf0t1 do you feel we need to create a kimi_k2-specific folder under the examples?
Also, please add a README covering how this script should be used and how its output can be used.