diff --git a/examples/robobrain/conf/compress.yaml b/examples/robobrain/conf/compress.yaml index 6a6bfa68a5..c80ec50ed5 100644 --- a/examples/robobrain/conf/compress.yaml +++ b/examples/robobrain/conf/compress.yaml @@ -21,3 +21,4 @@ action: run hydra: run: dir: ${experiment.exp_dir}/hydra + diff --git a/examples/robobrain/conf/compress/mix_precision.yaml b/examples/robobrain/conf/compress/mix_precision.yaml new file mode 100644 index 0000000000..3dc7a5816b --- /dev/null +++ b/examples/robobrain/conf/compress/mix_precision.yaml @@ -0,0 +1,20 @@ +defaults: + - model + - _self_ + +system: + save_dir: Qwen3_30B_MixPrecision_Search + +compress_args: + scheme: "mix_precision_search" + targets: ["Linear"] + +data: + num_calibration_samples: 128 + batch_size: 1 + + tokenizer_args: + #tokenizer_path: ${model.model_path} + use_fast: true + trust_remote_code: true + diff --git a/examples/robobrain/conf/compress_mix.yaml b/examples/robobrain/conf/compress_mix.yaml new file mode 100644 index 0000000000..d668ab6108 --- /dev/null +++ b/examples/robobrain/conf/compress_mix.yaml @@ -0,0 +1,23 @@ +defaults: + - _self_ + - compress: mix_precision + +experiment: + exp_name: robobrain_mix + exp_dir: outputs/${experiment.exp_name} + task: + type: compress + entrypoint: flagscale/compress/compressor_mix_precision.py + runner: + hostfile: null + cmds: + before_start: source activate flagscale-inference + envs: + CUDA_VISIBLE_DEVICES: 0 + CUDA_DEVICE_MAX_CONNECTIONS: 1 + +action: run + +hydra: + run: + dir: ${experiment.exp_dir}/hydra diff --git a/flagscale/compress/adapter.py b/flagscale/compress/adapter.py index 28f2a2d239..894f9eb982 100644 --- a/flagscale/compress/adapter.py +++ b/flagscale/compress/adapter.py @@ -1,162 +1,92 @@ import torch -from compressed_tensors.quantization import ( - QuantizationConfig, - QuantizationScheme, - QuantizationStatus, - apply_quantization_config, - disable_quantization, - enable_quantization, - is_preset_scheme, - preset_name_to_scheme, -) -from 
compressed_tensors.quantization.lifecycle.apply import find_name_or_class_matches -from llmcompressor.modifiers.quantization.calibration import ( - freeze_module_quantization, - initialize_observer, - update_weight_zp_scale, -) -from llmcompressor.modifiers.quantization.gptq.utils import get_output_error -from llmcompressor.modifiers.quantization.gptq.utils.gptq_wrapper import GPTQWrapper -from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor -from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward -from llmcompressor.transformers.sparsification.compressed_tensors_utils import ( - modify_save_pretrained, -) -from llmcompressor.utils.fsdp.context import fix_fsdp_module_name -from llmcompressor.utils.helpers import DisableKVCache - -from flagscale.runner.utils import logger - -__all__ = ["LLMCompressorAdapter"] - -QUANT_MAPPING_NAMES = {"gptq": GPTQWrapper} +import os +from typing import Optional, Dict, Any, Union +from transformers import PreTrainedModel, PreTrainedTokenizer +from flagscale.logger import logger +from llmcompressor import oneshot class LLMCompressorAdapter: def __init__( self, - model, - scheme, - targets, - algo=None, - ignore=None, - dataset=None, - num_calibration_steps=384, + model: PreTrainedModel, + tokenizer: Optional[PreTrainedTokenizer] = None, + dataset: Optional[Any] = None, + output_dir: str = "./output", + num_calibration_steps: int = 512, + **kwargs ): self.model = model - modify_save_pretrained(self.model) - if algo is not None: - assert len(algo) == 1 - for k, v in algo.items(): - self.algo = k - self.algo_args = v - else: - self.algo = algo - self.scheme = scheme - self.ignore = ignore - self.targets = targets - self.wrapper_cls = None - self.layer_compressors_ = [] - self.num_calibration_steps = num_calibration_steps + self.tokenizer = tokenizer self.dataset = dataset - - if (self.algo is None and is_preset_scheme(self.scheme)) or self.algo in list( - QUANT_MAPPING_NAMES.keys() - ): - 
self.wrapper_cls = QUANT_MAPPING_NAMES[self.algo] if self.algo is not None else None - quant_config = self.init_quant_config() - - ### find ignore and target to quant, initialize module for quant - ### overwrite forward if quantization_enabled is Tue - apply_quantization_config(self.model, quant_config) - if self.wrapper_cls is None: - self.preprocess_weight() + self.output_dir = output_dir + self.num_calibration_steps = num_calibration_steps + + self.algo = kwargs.get("algo", {}) + self.scheme = kwargs.get("scheme", "W8A16") + self.targets = kwargs.get("targets", ["Linear"]) + self.ignore = kwargs.get("ignore", []) + + self.is_mix_precision = (self.scheme == "mix_precision_search") or (isinstance(self.algo, str) and self.algo == "mix_precision") + + def _prepare_recipe(self): + from llmcompressor.modifiers.quantization import QuantizationModifier + + if not self.is_mix_precision: + modifier = QuantizationModifier( + targets=self.targets, + ignore=self.ignore, + scheme=self.scheme, + **(self.algo if isinstance(self.algo, dict) else {}) + ) + return [modifier] + else: - self.init_compressor() - if self.dataset is not None: - self.run_blockwise_calib_forward() - self.model.apply(freeze_module_quantization) + logger.info("Detected Mixed Precision Mode. 
Recipe will be handled by the pipeline.") + return None - def init_quant_config(self): - if self.scheme is not None: - # takes precedence over config_groups - if isinstance(self.scheme, str) and is_preset_scheme(self.scheme): - # attach targets to scheme - self.scheme = {self.scheme: self.targets} - - self.config_groups = {} - for idx, key in enumerate(self.scheme.keys()): - if is_preset_scheme(key): - scheme = preset_name_to_scheme(key, self.scheme[key]) - else: - scheme = QuantizationScheme.model_validate( - {"targets": self.scheme[key], **self.scheme} - ) - - group_name = f"group_{idx}" - self.config_groups[group_name] = scheme - - if self.config_groups is None or len(self.config_groups) == 0: - default_quant_scheme = QuantizationScheme(targets=self.targets) - self.config_groups = {"group_0": default_quant_scheme} - logger.info(f"No config groups were provided, using default {self.config_groups}") - - return QuantizationConfig( - config_groups=self.config_groups, - kv_cache_scheme=None, ### TODO(lvmengsi): not support kv cache quant for now - quantization_status=QuantizationStatus.INITIALIZED, - ignore=self.ignore, - ) - - def init_compressor(self): - for name, layer in self.model.named_modules(): - name = fix_fsdp_module_name(name) - if name is None: - continue + def run(self): + logger.info(f"Starting compression with scheme: {self.scheme}") + + if self.is_mix_precision: try: - idx = int(name.split(".")[-1]) - except: - continue - - if find_name_or_class_matches(name, layer, self.ignore): - continue - logger.info(f"prepare compressor for layer {name}") - compressor = LayerCompressor( - self.wrapper_cls, self.model, layer, idx, name, self.algo_args + import flagscale.compress.pipelines.mix_precision_pipeline + logger.info("Successfully registered MixPrecisionPipeline.") + except ImportError as e: + raise ImportError(f"Failed to import mix_precision_pipeline: {e}. 
Please check your PYTHONPATH.") + + recipe = self._prepare_recipe() + + oneshot_args = { + "model": self.model, + "dataset": self.dataset, + "output_dir": self.output_dir, + "num_calibration_batches": self.num_calibration_steps, + } + + if self.is_mix_precision: + + from llmcompressor.pipelines.registry import CalibrationPipeline + #pipeline_cls = CalibrationPipeline.load("mix_precision_search") + pipeline_cls = CalibrationPipeline.load_from_registry("mix_precision_search") + + logger.info("Invoking MixPrecisionPipeline manually...") + pipeline_cls( + model=self.model, + dataloader=self.dataset, + dataset_args=None, + output_dir=self.output_dir ) - self.layer_compressors_.append(compressor) - self.layer_compressors_[0].set_early_stop() + + else: + oneshot_args["recipe"] = recipe + oneshot(**oneshot_args) - def preprocess_weight(self): - for idx, (name, layer) in enumerate(self.model.named_modules()): - layer.apply(lambda module: initialize_observer(layer, base_name="weight")) - self.model.apply(update_weight_zp_scale) + self.save_artifacts() - def add_hook(self): - pass + def save_artifacts(self): - @torch.no_grad() - def run_blockwise_calib_forward(self): - logger.info("start calibration") - self.model.apply(disable_quantization) - with DisableKVCache(self.model): - intermediates = run_calibration_forward( - self.model, - self.dataset, - num_calibration_steps=self.num_calibration_steps, - mask_padding=False, - ) - self.layer_compressors_[0].clear_early_stop() + if self.tokenizer: + self.tokenizer.save_pretrained(self.output_dir) + logger.info(f"Artifacts saved to {self.output_dir}") - for idx, layer_compressor in enumerate(self.layer_compressors_): - logger.info(f"start calibration layer {layer_compressor.name}") - layer_compressor.pre_compress() - unquantized_outputs = layer_compressor.calibrate_layer(intermediates) - layer_compressor.compress() - layer_compressor.post_compress() - layer_compressor.revert_layer_wrappers() - quantized_outputs = 
layer_compressor.calibrate_layer(intermediates) - error = get_output_error(unquantized_outputs, quantized_outputs) - logger.info(f"Mean output error from quantization: {error:.3f}") - intermediates = quantized_outputs - self.model.apply(enable_quantization) diff --git a/flagscale/compress/compressor.py b/flagscale/compress/compressor.py index a500e462d8..3f829db4fd 100644 --- a/flagscale/compress/compressor.py +++ b/flagscale/compress/compressor.py @@ -1,93 +1,110 @@ import argparse import os -import shutil - -import torch +import sys import yaml +import torch from omegaconf import OmegaConf -from transformers import * +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModelForVision2Seq +from torch.utils.data import DataLoader +# 1. 先导入 adapter 模块 +import flagscale.compress.adapter from flagscale.compress.adapter import LLMCompressorAdapter -from flagscale.compress.combined_algo import prepare_compress_methods - -_g_ignore_fields = ["experiment", "action"] +# --- Monkey Patch Start (关键修复) --- +# 既然 adapter.py 内部调用了 oneshot,我们直接修改 adapter 模块里的这个函数引用 +# 这样无论它原本是从哪里导入的,都会执行我们的 wrapper +if hasattr(flagscale.compress.adapter, "oneshot"): + print(">> [Patch] Found 'oneshot' in adapter, applying fix...") + _real_oneshot = flagscale.compress.adapter.oneshot + + def _patched_oneshot(**kwargs): + # 拦截并删除导致报错的参数 + if "num_calibration_batches" in kwargs: + print(">> [Patch] Removing unsupported 'num_calibration_batches' argument") + del kwargs["num_calibration_batches"] + # 调用原始函数 + return _real_oneshot(**kwargs) + + # 将 adapter 模块里的 oneshot 替换为我们的版本 + flagscale.compress.adapter.oneshot = _patched_oneshot +else: + print(">> [Warning] Could not find 'oneshot' in flagscale.compress.adapter. 
Patch may not work.") +# --- Monkey Patch End --- + +def load_calibration_dataset(cfg, tokenizer): + if not cfg.data.get("path"): + return None + return None def prepare_config(config_path): - # Open the YAML file and convert it into a dictionary - with open(config_path, "r") as f: - yaml_dict = yaml.safe_load(f) - - # Extract valid config - for key in _g_ignore_fields: - yaml_dict.pop(key) - new_yaml_dict = {} - for k, v in yaml_dict.items(): - assert isinstance(v, dict), f"Expected a dictionary for key {k}, but got {v} instead" - new_yaml_dict.update(v) - config = OmegaConf.create(new_yaml_dict) + config = OmegaConf.load(config_path) return config +def compress(cfg): + if "compress" in cfg: + cfg = cfg.compress + + model_path = cfg.model.model_path + save_dir = cfg.system.save_dir + + tokenizer = None + if cfg.data.get("tokenizer_args"): + tokenizer_args = cfg.data.tokenizer_args + t_path = tokenizer_args.get("tokenizer_path", model_path) + tokenizer = AutoTokenizer.from_pretrained( + t_path, + use_fast=tokenizer_args.get("use_fast", True), + trust_remote_code=tokenizer_args.get("trust_remote_code", True) + ) + + model_cls_str = cfg.model.get("model_cls", "AutoModelForCausalLM") + model_cls = globals().get(model_cls_str) + if model_cls is None: + try: + model_cls = eval(model_cls_str) + except: + model_cls = AutoModelForCausalLM + + # Fix: normalize torch_dtype strings like "torch.float16" before getattr + dtype_str = cfg.model.get("torch_dtype", "float16") + if isinstance(dtype_str, str): + dtype_str = dtype_str.replace("torch.", "") + torch_dtype = getattr(torch, dtype_str) + else: + torch_dtype = dtype_str + + model = model_cls.from_pretrained( + model_path, + torch_dtype=torch_dtype, + trust_remote_code=True, + device_map=cfg.model.get("device_map", "auto") + ) -def copy_rest_file(src_path, dst_path): - from huggingface_hub import hf_hub_download - from transformers import TRANSFORMERS_CACHE - from transformers.utils import http_user_agent + dataset = load_calibration_dataset(cfg, tokenizer) - if not 
os.path.exists(src_path): - user_agent = http_user_agent() - config_file_path = hf_hub_download( - repo_id=src_path, - filename="config.json", - cache_dir=TRANSFORMERS_CACHE, - force_download=False, - user_agent=user_agent, - ) - src_path = os.path.sep.join(config_file_path.split(os.path.sep)[:-1]) - - dst_path_files = os.listdir(dst_path) - for filename in os.listdir(src_path): - if not filename.endswith(".safetensors") and filename not in dst_path_files: - full_file_name = os.path.join(src_path, filename) - if (not filename.endswith(".md")) and os.path.isfile(full_file_name): - shutil.copy(full_file_name, dst_path) - elif os.path.isdir(full_file_name): - shutil.copytree(full_file_name, os.path.join(dst_path, filename)) - - -def compress(cfg, model=None, dataset=None): - model_path = cfg.model.pop("model_path") - # tokenizer = None - # if cfg.data.tokenzier_args is not None: - # tokenizer = AutoTokenizer.from_pretrained( - # cfg.data.tokenzier_args.pop("tokenizer_path"), **cfg.data.tokenzier_args - # ) - if model is None: - model_cls = eval(cfg.model.pop("model_cls")) - model = model_cls.from_pretrained(model_path, **cfg.model) - assert isinstance(model, torch.nn.Module), f"model type {type(model)} error, please check it" - compress_args = cfg.compress_args - recipes = prepare_compress_methods(compress_args) - for method, recipe in recipes.items(): - for algo_args in recipe: - algo_args = OmegaConf.to_container(algo_args) - algo_args["dataset"] = dataset - algo_args["num_calibration_steps"] = cfg.data.get("max_seq_length", 384) - adapter = LLMCompressorAdapter(model=model, **algo_args) - ### modify model inplace - model = adapter.model - - # oneshot(model=model, dataset=dataset, recipe=recipe, tokenizer=tokenizer, output_dir=cfg.system.save_dir, max_seq_length=cfg.data.get("max_seq_length", 384), num_calibration_samples=cfg.data.get("num_calibration_samples", 512), splits="calibration") - model.save_pretrained(cfg.system.save_dir, save_compressed=True) - 
copy_rest_file(model_path, cfg.system.save_dir) + compress_args = OmegaConf.to_container(cfg.compress_args, resolve=True) + # Double safety: also try to strip this key before it reaches the Adapter + if "num_calibration_batches" in compress_args: + del compress_args["num_calibration_batches"] + + adapter = LLMCompressorAdapter( + model=model, + tokenizer=tokenizer, + dataset=dataset, + output_dir=save_dir, + num_calibration_steps=cfg.data.get("num_calibration_steps", 128), + **compress_args + ) + + adapter.run() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - "--config-path", type=str, required=True, help="Path to the configuration YAML file" - ) + parser.add_argument("--config-path", type=str, required=True) args = parser.parse_args() cfg = prepare_config(args.config_path) - compress(cfg) + compress(cfg) + diff --git a/flagscale/compress/compressor_mix_precision.py b/flagscale/compress/compressor_mix_precision.py new file mode 100644 index 0000000000..7b0402142d --- /dev/null +++ b/flagscale/compress/compressor_mix_precision.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +FlagScale mix-precision compress entrypoint (llm-compressor aligned) + +Key behaviors: +1) Reads FlagScale-generated Hydra config via --config-path (OmegaConf.load) +2) Selects pipeline by passing dataset=cfg.compress_args.scheme (e.g., "mix_precision_search") +3) Recipe uses: + - QuantizationModifier: default W8A16 on broad targets (e.g., ["Linear"]) + - QuIPModifier: targets=[] initially; pipeline writes back chosen targets +4) Does NOT call model.save_pretrained() after oneshot (avoid overwriting compressed artifacts) + Only saves processor/tokenizer into a subdir to avoid collisions. +""" + +import argparse +import os +from pathlib import Path +from typing import Any, Optional + +from omegaconf import OmegaConf + + +# ---- Ensure your custom CalibrationPipeline is registered ---- +# Adjust this import path to wherever your pipeline module lives. 
+# The important thing is: importing the module executes the @CalibrationPipeline.register decorator. +try: + import flagscale.compress.pipelines.mix_precision_pipeline # noqa: F401 +except Exception as e: + raise RuntimeError( + "Failed to import mix_precision_pipeline for registration. " + "Fix the import path so your CalibrationPipeline is registered." + ) from e + + +def _pick(cfg: Any, *keys: str, default=None): + """Safe getter for OmegaConf / dict-like configs.""" + cur = cfg + for k in keys: + if cur is None: + return default + if isinstance(cur, dict): + cur = cur.get(k, None) + else: + # OmegaConf supports attribute and item access; use item access defensively + try: + cur = cur.get(k) + except Exception: + try: + cur = getattr(cur, k) + except Exception: + return default + return default if cur is None else cur + + +def _as_abs_output_dir(cfg_root: Any) -> str: + """ + Build output_dir from: + - system.save_dir (required-ish) + - experiment.exp_dir (optional): if present and save_dir is relative, join them + """ + save_dir = _pick(cfg_root, "system", "save_dir", default=None) + if not save_dir: + raise ValueError("Missing config field: system.save_dir") + + save_dir = str(save_dir) + + exp_dir = _pick(cfg_root, "experiment", "exp_dir", default=None) + if exp_dir and not os.path.isabs(save_dir): + return str(Path(str(exp_dir)) / save_dir) + + return str(Path(save_dir)) + + +def _resolve_model_id_or_path(cfg_root: Any) -> str: + # common patterns: model.model_path, model.path, model.name_or_path + for cand in [("model", "model_path"), ("model", "path"), ("model", "name_or_path"), ("model_path",)]: + val = _pick(cfg_root, *cand, default=None) + if val: + return str(val) + raise ValueError("Missing model path in config. 
Expected one of: model.model_path / model.path / model.name_or_path") + + +def _load_cfg(config_path: str) -> Any: + cfg = OmegaConf.load(config_path) + + # Many FlagScale setups nest the actual compress config under `compress:` + # because compress_mix.yaml uses `defaults: - compress: mix_precision`. + # Normalize so `cfg_root` has fields: system/compress_args/data/model/... + if _pick(cfg, "compress", default=None) is not None: + # merge top-level + cfg.compress so experiment.* still accessible + cfg_root = OmegaConf.merge(cfg, cfg.compress) + return cfg_root + + return cfg + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument( + "--config-path", + required=True, + help="Path to FlagScale/Hydra generated config.yaml (e.g., outputs/.../hydra/.hydra/config.yaml)", + ) + args = ap.parse_args() + + cfg = _load_cfg(args.config_path) + + # ---- core config ---- + model_id_or_path = _resolve_model_id_or_path(cfg) + output_dir = _as_abs_output_dir(cfg) + + scheme = _pick(cfg, "compress_args", "scheme", default=None) + if not scheme: + raise ValueError("Missing config field: compress_args.scheme (e.g., 'mix_precision_search')") + + # Targets default: ["Linear"] + targets = _pick(cfg, "compress_args", "targets", default=["Linear"]) + if targets is None or targets == "": + targets = ["Linear"] + + # Calibration/data knobs (oneshot expects num_calibration_samples, not steps) + batch_size = int(_pick(cfg, "data", "batch_size", default=1)) + num_calibration_samples = _pick(cfg, "data", "num_calibration_samples", default=None) + if num_calibration_samples is None: + # backward compat: some configs use num_calibration_steps + steps = int(_pick(cfg, "data", "num_calibration_steps", default=512)) + num_calibration_samples = steps * max(batch_size, 1) + num_calibration_samples = int(num_calibration_samples) + + max_seq_length = int(_pick(cfg, "data", "max_seq_length", default=384)) + text_column = str(_pick(cfg, "data", "text_column", default="text")) + 
pad_to_max_length = bool(_pick(cfg, "data", "pad_to_max_length", default=True)) + + tokenizer_args = _pick(cfg, "data", "tokenizer_args", default={}) or {} + trust_remote_code = bool(tokenizer_args.get("trust_remote_code", True)) + + # ---- build recipe (llm-compressor aligned) ---- + # Import modifiers with a couple of fallback paths, depending on llmcompressor version. + try: + from llmcompressor.modifiers.quantization import QuantizationModifier + except Exception: + from llmcompressor.modifiers.quantization.quantization import QuantizationModifier # type: ignore + + try: + from llmcompressor.modifiers.quantization.quip import QuIPModifier + except Exception: + # some versions may expose it elsewhere + try: + from llmcompressor.modifiers.quantization.quip.quip import QuIPModifier # type: ignore + except Exception: + from llmcompressor.modifiers.transform import QuIPModifier + + # Keep ignore minimal; customize if your project passes ignore patterns in config. + #ignore = _pick(cfg, "compress_args", "ignore", default=None) + ignore = _pick(cfg, "compress_args", "ignore", default=None) or ["lm_head"] + + recipe = [ + # global default quant: 8-bit weights, fp16 acts (per your existing intent) + QuantizationModifier( + targets=targets, + scheme="W8A16", + ignore=ignore, + #ignore=cfg.compress.compress_args.get("ignore", None), + ), + ] + + # ---- run oneshot ---- + from llmcompressor import oneshot + + # Important: per your config design, dataset=scheme is what selects the registered CalibrationPipeline. 
+ # (oneshot signature confirms dataset is used for that purpose in your integration) + compressed_model = oneshot( + model=model_id_or_path, + tokenizer=model_id_or_path, # safe default; can be overridden by cfg if you expose tokenizer_path + #processor=model_id_or_path, # for VLMs; if not applicable, llmcompressor usually ignores safely + trust_remote_code_model=trust_remote_code, + recipe=recipe, + pipeline="mix_precision_search", + output_dir=output_dir, + save_compressed=True, # crucial: let llmcompressor write compressed artifacts + + ) + + # ---- avoid overwriting compressed artifacts ---- + # Save processor/tokenizer into a subdir to avoid colliding with llmcompressor exporter outputs. + aux_dir = Path(output_dir) / "aux" + aux_dir.mkdir(parents=True, exist_ok=True) + + # Best-effort: if it’s a HF model, we can load tokenizer/processor and save them. + # We intentionally do NOT call compressed_model.save_pretrained(output_dir). + try: + from transformers import AutoTokenizer, AutoProcessor + + # Tokenizer + try: + tok = AutoTokenizer.from_pretrained( + model_id_or_path, + use_fast=bool(tokenizer_args.get("use_fast", True)), + trust_remote_code=trust_remote_code, + ) + tok.save_pretrained(str(aux_dir / "tokenizer")) + except Exception: + pass + + # Processor (for VLMs); harmless if not available + try: + proc = AutoProcessor.from_pretrained( + model_id_or_path, + trust_remote_code=trust_remote_code, + ) + proc.save_pretrained(str(aux_dir / "processor")) + except Exception: + pass + + except Exception: + # transformers not available or not needed + pass + + print(f"[OK] Compressed model exported to: {output_dir}") + return compressed_model + + +if __name__ == "__main__": + main() + + diff --git a/flagscale/compress/pipelines/__init__.py b/flagscale/compress/pipelines/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flagscale/compress/pipelines/mix_precision_pipeline.py b/flagscale/compress/pipelines/mix_precision_pipeline.py new file 
mode 100644 index 0000000000..8c86d8c461 --- /dev/null +++ b/flagscale/compress/pipelines/mix_precision_pipeline.py @@ -0,0 +1,797 @@ +import torch +import re +import math +import copy +import json +import os +from typing import Optional, Dict, Any, List, Union +from torch.utils.data.dataloader import DataLoader + +from llmcompressor.core.session_functions import LifecycleCallbacks +from llmcompressor.core import active_session +from llmcompressor.pipelines.registry import CalibrationPipeline + +try: + from llmcompressor.pipelines.sequential.helpers import get_sequential_targets, match_modules +except ImportError: + from llmcompressor.pipelines.layer_sequential.helpers import get_sequential_targets, match_modules + +from llmcompressor.modifiers.quantization import QuantizationModifier +from llmcompressor.modifiers.transform import QuIPModifier + +try: + from compressed_tensors.quantization.lifecycle.forward import fake_quantize + from compressed_tensors.quantization.quant_args import QuantizationArgs +except ImportError: + raise ImportError("Could not import quantization functions. 
Please ensure llmcompressor/compressed-tensors is installed.") + +try: + from llmcompressor.modifiers.transform.utils.hadamard import get_hadamard_matrix +except ImportError: + def get_hadamard_matrix(n, dtype=torch.float32, device="cpu"): + from scipy.linalg import hadamard + H = torch.tensor(hadamard(n), dtype=dtype, device=device) + return H / math.sqrt(n) + +@CalibrationPipeline.register("mix_precision_search") +class MixPrecisionPipeline(CalibrationPipeline): + + @staticmethod + def __call__( + model: torch.nn.Module, + dataloader: Optional[DataLoader], + dataset_args: Any, + **kwargs + ): + session = active_session() + + session.initialize(model=model) + + modifiers = session.lifecycle.recipe.modifiers + + # quant_mod = next((m for m in modifiers if isinstance(m, QuantizationModifier)), None) + # if quant_mod is None: + # raise RuntimeError("QuantizationModifier not found in recipe") + + # quip_mod = next((m for m in modifiers if isinstance(m, QuIPModifier)), None) + # if quip_mod is None: + # # minimal: add one so exporter can serialize transform_config + # quip_mod = QuIPModifier( + # targets=[], # will be filled after search + # ignore=["lm_head"], + # rotations=["v", "u"], + # transform_type="hadamard", + # transform_block_size=128, + # ) + # modifiers.append(quip_mod) + + + if dataset_args is None: + from types import SimpleNamespace + dataset_args = SimpleNamespace(sequential_targets=None) + + sequential_targets = get_sequential_targets(modifiers, model, dataset_args) + found_modules = match_modules(model, sequential_targets) + + module_to_name = {m: n for n, m in model.named_modules()} + layers_to_process = [] + + if isinstance(found_modules, dict): + layers_to_process = list(found_modules.items()) + elif isinstance(found_modules, list): + for m in found_modules: + real_name = module_to_name.get(m, "unknown_layer") + layers_to_process.append((real_name, m)) + + if len(layers_to_process) == 0: + print(">>> [DEBUG] Standard discovery failed (0 layers). 
Activating Manual Fallback...") + candidates = ["model.layers", "model.decoder.layers", "transformer.h", "layers", "blocks"] + target_module_list = None + list_name = "" + + for name, module in model.named_modules(): + if any(name.endswith(c) for c in candidates) and isinstance(module, torch.nn.ModuleList): + target_module_list = module + list_name = name + break + if target_module_list is not None: + print(f">>> [DEBUG] Manually found ModuleList: {list_name} with {len(target_module_list)} layers.") + for i, layer in enumerate(target_module_list): + layer_name = f"{list_name}.{i}" + layers_to_process.append((layer_name, layer)) + else: + print(">>> [ERROR] Manual Fallback failed: Could not find any Transformer Block List.") + + def natural_keys(item): + text = item[0] + return [int(c) if c.isdigit() else c for c in re.split(r'(\d+)', text)] + + sorted_layers = sorted(layers_to_process, key=natural_keys) + + LifecycleCallbacks.calibration_epoch_start() + print("+++++++++++++++++++++++++++++++++++++++++++++") + print(f"DEBUG: Processing {len(sorted_layers)} layers with Auto-Search (8-bit vs QuIP-4bit).") + + search_results = [] + global_quip_targets = [] + + for i, (layer_name, layer) in enumerate(sorted_layers): + match = re.search(r"\.(\d+)(?:\.|$)", layer_name) + real_layer_idx = int(match.group(1)) if match else i + + print(f"\nSearching Layer {real_layer_idx}: {layer_name}") + + candidate_configs = [ + {"name": "QuIP-4bit", "bits": 4, "quip": True}, + {"name": "Std-8bit", "bits": 8, "quip": False} + ] + + best_score = -1.0 + best_config = candidate_configs[1] + + ACCEPTANCE_THRESHOLD = 0.009 + layer_stats = {} + + for config in candidate_configs: + bit = config["bits"] + use_quip = config["quip"] + name = config["name"] + + _set_layer_quantization_bits(session, layer, layer_name, bit) + + current_score, func_name, param_bytes = _calculate_layer_metrics( + layer, bit, use_quip=use_quip + ) + + layer_stats[name] = { + "score": current_score, + "size": param_bytes, 
+ "config": config, + "func_name": func_name + } + + print(f" - Testing {name:<10} | Cos Sim: {current_score:.6f} | Size: {param_bytes/1024/1024:.2f} MB | Func: {func_name}") + + score_8bit = layer_stats["Std-8bit"]["score"] + score_4bit = layer_stats["QuIP-4bit"]["score"] + + score_diff = score_8bit - score_4bit + size_diff_mb = (layer_stats["Std-8bit"]["size"] - layer_stats["QuIP-4bit"]["size"]) / 1024 / 1024 + + if score_diff <= ACCEPTANCE_THRESHOLD: + best_config = layer_stats["QuIP-4bit"]["config"] + best_score = score_4bit + decision_reason = f"Accepted (Drop {score_diff:.4f} <= {ACCEPTANCE_THRESHOLD})" + else: + best_config = layer_stats["Std-8bit"]["config"] + best_score = score_8bit + decision_reason = f"Rejected (Drop {score_diff:.4f} > {ACCEPTANCE_THRESHOLD})" + + print(f" >>> Decision: {best_config['name']} | {decision_reason}") + print(f" >>> Comparison: Saved {size_diff_mb:.2f} MB | Score Drop: {score_diff:.6f}") + + if best_config["quip"]: + _apply_official_quip_transform(model, layer_name, layer, block_size=128) + _set_layer_quantization_bits(session, layer, layer_name, best_config["bits"]) + + for sub_name, sub_mod in layer.named_modules(): + if "quip" in sub_name: continue + if isinstance(sub_mod, torch.nn.Linear) or "proj" in sub_name: + if "observer" in sub_name: continue + if not hasattr(sub_mod, "weight"): continue + full_target_string = f"re:{layer_name}.{sub_name}" + global_quip_targets.append(full_target_string) + else: + _set_layer_quantization_bits(session, layer, layer_name, best_config["bits"]) + + search_results.append({ + "layer": layer_name, + "best_mode": best_config["name"], + "score": best_score, + "size_saved_mb": size_diff_mb if best_config["bits"] == 4 else 0, + "score_drop": score_diff + }) + + dummy_input = _create_dummy_input(layer, model) + with torch.no_grad(): + try: + if isinstance(dummy_input, dict): _ = layer(**dummy_input) + else: _ = layer(dummy_input) + except Exception: pass + finally: + del dummy_input + if 
# NOTE(review): this chunk originally began mid-function — the tail of the
# auto-search loop (empty the CUDA cache, call
# LifecycleCallbacks.sequential_epoch_end(subgraph=layer), print the
# "Auto-Search Summary" table, then _sync_modifier_config_to_model(...),
# _simulate_and_verify(...) and LifecycleCallbacks.calibration_epoch_end()).
# Its enclosing `def` is outside this view, so those statements are not
# reproduced here.


def _apply_official_quip_transform(model, layer_name, layer_module, block_size=128):
    """Apply the official QuIPModifier Hadamard transform to one decoder layer.

    Collects every Linear/projection submodule of ``layer_module`` as a
    ``re:`` target, builds a QuIPModifier with u/v rotations, and runs its
    initialize/finalize lifecycle so the rotation weights are materialized.
    """
    print(f" >>> [QuIP Fix] Applying official QuIPModifier logic to {layer_name}...")
    current_layer_targets = []
    for name, submodule in layer_module.named_modules():
        if isinstance(submodule, torch.nn.Linear) or "proj" in name:
            # Skip observer/transform artifacts and weight-less modules.
            if "observer" in name or "quip" in name:
                continue
            if not hasattr(submodule, "weight"):
                continue
            current_layer_targets.append(f"re:{layer_name}.{name}")

    if not current_layer_targets:
        return

    modifier = QuIPModifier(
        targets=current_layer_targets,
        rotations=["v", "u"],
        transform_block_size=block_size,
        transform_type="hadamard",
        ignore=["lm_head"],
    )

    if not modifier.initialized:
        modifier.on_initialize(state=active_session().lifecycle)
        # Rotation parameters may still live on the meta device right after
        # on_initialize; materialize them before finalizing.
        _ensure_quip_weights_materialized(layer_module)

    modifier.on_finalize(state=active_session().lifecycle)


def _ensure_quip_weights_materialized(module):
    """Replace meta-device QuIP rotation params/buffers with real Hadamard matrices.

    Assumes every square rotation tensor registered by QuIP has its Hadamard
    dimension as shape[0] — TODO confirm against QuIPModifier internals.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    for name, child in module.named_modules():
        if "quip" not in name:
            continue
        # Snapshot before mutating: delattr/register_parameter changes
        # child._parameters while named_parameters() iterates it, which
        # raises "dictionary changed size during iteration".
        for param_name, param in list(child.named_parameters(recurse=False)):
            if param.device.type == 'meta':
                dim = param.shape[0]
                H = get_hadamard_matrix(dim).to(dtype=torch.float16, device=device)
                delattr(child, param_name)
                child.register_parameter(param_name, torch.nn.Parameter(H))
        for buf_name, buf in list(child.named_buffers(recurse=False)):
            if buf.device.type == 'meta':
                dim = buf.shape[0]
                H = get_hadamard_matrix(dim).to(dtype=torch.float16, device=device)
                setattr(child, buf_name, H)


def _hadamard_unscaled(n: int, device, dtype=None) -> torch.Tensor:
    """Return the (unnormalized) n x n Sylvester Hadamard matrix.

    Raises ValueError unless ``n`` is a power of two.
    """
    if n < 1 or (n & (n - 1)) != 0:
        raise ValueError(f"Hadamard size n must be a power of 2, got n={n}")
    H = torch.ones((1, 1), device=device, dtype=dtype or torch.float32)
    # Sylvester doubling: H_{2k} = [[H, H], [H, -H]].
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H


def _hadamard_orthonormal(n: int, device, dtype) -> torch.Tensor:
    """Return the orthonormal Hadamard matrix H / sqrt(n) (so H @ H.T == I)."""
    H = _hadamard_unscaled(n, device=device, dtype=dtype)
    return H / math.sqrt(n)


def _apply_rotation(weight, block_size=128):
    """Rotate a weight matrix by block-wise Hadamard on both sides.

    Pads in/out features up to a multiple of ``block_size``, applies H on the
    input blocks (right) and on the output blocks (left), and returns
    ``(rotated_weight, pad_in, pad_out, H)``.
    """
    out_features, in_features = weight.shape
    device = weight.device
    dtype = weight.dtype
    H = _hadamard_orthonormal(block_size, device=device, dtype=dtype)
    pad_in = (block_size - (in_features % block_size)) % block_size
    w_padded_in = torch.nn.functional.pad(weight, (0, pad_in))
    # Right-rotate each input block: W -> W @ H.
    w_v = torch.matmul(w_padded_in.view(out_features, -1, block_size), H)
    w_v = w_v.view(out_features, -1)
    pad_out = (block_size - (out_features % block_size)) % block_size
    if pad_out > 0:
        w_v = torch.nn.functional.pad(w_v, (0, 0, 0, pad_out))
    # Left-rotate each output block: W -> H @ W.
    w_v_reshaped = w_v.view(-1, block_size, w_v.shape[1])
    w_u = torch.matmul(H, w_v_reshaped)
    w_final = w_u.view(-1, w_v.shape[1])
    return w_final, pad_in, pad_out, H


class QuIPWrapper(torch.nn.Module):
    """Drop-in replacement for a Linear that stores a Hadamard-rotated weight.

    Forward rotates the input, runs the rotated Linear, then un-rotates the
    output, so the wrapper is numerically equivalent to the original layer
    (H is orthonormal and symmetric, hence its own inverse).
    """

    def __init__(self, original_linear: torch.nn.Linear, block_size=128):
        super().__init__()
        self.block_size = block_size
        with torch.no_grad():
            w_rotated, pad_in, pad_out, H = _apply_rotation(
                original_linear.weight.data, block_size
            )
            self.pad_in = pad_in
            self.pad_out = pad_out
            self.register_buffer('H', H)
            out_features_padded, in_features_padded = w_rotated.shape
            self.linear = torch.nn.Linear(
                in_features_padded,
                out_features_padded,
                bias=original_linear.bias is not None,
            )
            self.linear.weight.data = w_rotated
            if original_linear.bias is not None:
                # Bias is rotated once (H @ b); the forward's output rotation
                # applies H again, restoring the original bias (H @ H = I).
                b_padded = torch.nn.functional.pad(original_linear.bias.data, (0, pad_out))
                b_reshaped = b_padded.view(-1, block_size, 1)
                b_rotated = torch.matmul(self.H, b_reshaped).view(-1)
                self.linear.bias.data = b_rotated
            if hasattr(original_linear, "quantization_scheme"):
                self.linear.quantization_scheme = copy.deepcopy(
                    original_linear.quantization_scheme
                )

    def forward(self, x):
        dtype = x.dtype
        x = x.to(self.H.dtype)
        if self.pad_in > 0:
            x = torch.nn.functional.pad(x, (0, self.pad_in))
        orig_shape = x.shape
        # Rotate each input block by H before the rotated Linear.
        x_reshaped = x.view(*orig_shape[:-1], -1, self.block_size)
        x_rotated = torch.matmul(x_reshaped, self.H)
        x_rotated = x_rotated.view(*orig_shape)
        out = self.linear(x_rotated.to(dtype))
        out = out.to(self.H.dtype)
        out_shape = out.shape
        # Rotate each output block by H to undo the weight's left rotation.
        out_reshaped = out.view(*out_shape[:-1], -1, self.block_size)
        out_unrotated = torch.matmul(out_reshaped, self.H)
        out_final = out_unrotated.view(*out_shape)
        if self.pad_out > 0:
            out_final = out_final[..., :-self.pad_out]
        return out_final.to(dtype)
def _get_scale_and_zeropoint(weight, bits):
    """Compute per-row symmetric quantization scale and (zero) zero-point.

    Scale is max|w| per output row divided by the signed max level, clamped
    to 1e-5 so all-zero rows do not produce a zero scale.
    """
    w_max = weight.abs().amax(dim=1, keepdim=True)
    max_q = 2 ** (bits - 1) - 1
    scale = torch.clamp(w_max / max_q, min=1e-5)
    zero_point = torch.zeros_like(scale, dtype=torch.int32)
    return scale, zero_point


def _calculate_layer_metrics(layer, bits, use_quip=False):
    """
    Calculates Cosine Similarity with STRICT filtering to avoid QuIP artifacts crash.

    Returns ``(mean_cos_sim, mode_label, total_bytes)`` where total_bytes is
    the estimated packed size of all measured weights at ``bits`` bits.
    """
    total_cos = 0.0
    total_bytes = 0
    count = 0
    func_used = "unknown"
    q_args = QuantizationArgs(num_bits=bits, symmetric=True)

    for name, submodule in layer.named_modules():
        # Skip QuIP rotation artifacts and observers outright.
        if any(x in name for x in ["observer", "v_input", "u_output", "quip"]):
            continue
        if "Hadamard" in submodule.__class__.__name__:
            continue
        if not hasattr(submodule, "weight") or submodule.weight is None:
            continue

        if isinstance(submodule, torch.nn.Linear) or "proj" in name:
            # Trigger accelerate-style offload hooks so meta weights load.
            hook_triggered = False
            if hasattr(submodule, "_hf_hook"):
                try:
                    submodule._hf_hook.pre_forward(submodule)
                    hook_triggered = True
                except Exception:
                    pass

            weight = submodule.weight
            if weight.device.type == 'meta':
                if hook_triggered:
                    submodule._hf_hook.post_forward(submodule, None)
                continue

            weight_data = weight.data
            total_bytes += weight_data.numel() * (bits / 8.0)

            try:
                w_to_quant = weight_data
                H_for_unrotate = None
                pad_in = 0
                pad_out = 0

                if use_quip:
                    if not hasattr(submodule, "bias"):
                        continue
                    # Rotate via a throw-away wrapper and quantize in the
                    # rotated basis, as real QuIP would.
                    temp_wrapper = QuIPWrapper(submodule, block_size=128)
                    w_to_quant = temp_wrapper.linear.weight.data
                    H_for_unrotate = temp_wrapper.H
                    if hasattr(temp_wrapper, 'pad_in'):
                        pad_in = temp_wrapper.pad_in
                    elif hasattr(temp_wrapper, 'pad_len'):
                        pad_in = temp_wrapper.pad_len
                    else:
                        pad_in = 0
                    pad_out = getattr(temp_wrapper, 'pad_out', 0)
                    func_used = f"QuIP_Real({bits}b)"
                else:
                    func_used = f"Std({bits}b)"

                scale, zero_point = _get_scale_and_zeropoint(w_to_quant, bits)
                q_args.num_bits = bits
                w_dq_rotated = fake_quantize(
                    x=w_to_quant, scale=scale, zero_point=zero_point, args=q_args
                )

                if use_quip:
                    # Undo both Hadamard rotations (H is its own inverse) and
                    # strip padding so we compare in the original basis.
                    rows_padded = w_dq_rotated.shape[0]
                    cols_padded = w_dq_rotated.shape[1]
                    block_size = 128

                    w_dq_reshaped_u = w_dq_rotated.view(-1, block_size, cols_padded)
                    w_u_inv = torch.matmul(H_for_unrotate, w_dq_reshaped_u)
                    w_u_inv = w_u_inv.view(rows_padded, cols_padded)
                    if pad_out > 0:
                        w_u_inv = w_u_inv[:-pad_out, :]

                    rows_orig = w_u_inv.shape[0]
                    w_dq_reshaped_v = w_u_inv.view(rows_orig, -1, block_size)
                    w_recon = torch.matmul(w_dq_reshaped_v, H_for_unrotate)
                    w_recon = w_recon.view(rows_orig, cols_padded)
                    if pad_in > 0:
                        w_recon = w_recon[:, :-pad_in]

                    w_dq = w_recon
                else:
                    w_dq = w_dq_rotated

                min_rows = min(weight_data.shape[0], w_dq.shape[0])
                min_cols = min(weight_data.shape[1], w_dq.shape[1])

                cos_sim = torch.nn.functional.cosine_similarity(
                    weight_data[:min_rows, :min_cols].flatten(),
                    w_dq[:min_rows, :min_cols].flatten(),
                    dim=0, eps=1e-8
                ).item()

                # Clamp float noise symmetrically: cosine is in [-1, 1].
                cos_sim = max(-1.0, min(1.0, cos_sim))
                total_cos += cos_sim
                count += 1
            except Exception:
                # Best-effort metric: a failing submodule is skipped rather
                # than aborting the whole layer measurement.
                pass
            finally:
                if hook_triggered:
                    submodule._hf_hook.post_forward(submodule, None)

    if count == 0:
        return 0.0, "none", 0
    return total_cos / count, func_used, total_bytes


def _extract_real_scheme_from_module(layer, target_bits):
    """Pull the first weight-quantization config found on the layer's Linears.

    Returns a plain dict with ``num_bits`` overridden to ``target_bits``, or
    None when no submodule carries a usable quantization_scheme.
    """
    for name, submodule in layer.named_modules():
        if isinstance(submodule, torch.nn.Linear) or "proj" in name:
            if hasattr(submodule, "quantization_scheme") and submodule.quantization_scheme:
                scheme = submodule.quantization_scheme
                if hasattr(scheme, "weights") and scheme.weights:
                    w_config = None
                    if hasattr(scheme.weights, "dict"):
                        w_config = scheme.weights.dict()
                    elif isinstance(scheme.weights, dict):
                        w_config = copy.deepcopy(scheme.weights)
                    else:
                        # Fall back to attribute-by-attribute extraction.
                        try:
                            w_config = {
                                "num_bits": scheme.weights.num_bits,
                                "group_size": getattr(scheme.weights, "group_size", 128),
                                "symmetric": getattr(scheme.weights, "symmetric", True),
                                "strategy": getattr(scheme.weights, "strategy", "group"),
                                "observer": getattr(scheme.weights, "observer", "minmax"),
                                "type": "int",
                                "actorder": getattr(scheme.weights, "actorder", None),
                                "block_structure": getattr(scheme.weights, "block_structure", None),
                                "dynamic": getattr(scheme.weights, "dynamic", False),
                                "observer_kwargs": getattr(scheme.weights, "observer_kwargs", {}),
                            }
                        except Exception:
                            pass
                    if w_config:
                        w_config["num_bits"] = target_bits
                        return w_config
    return None
def _set_layer_quantization_bits(session, layer, layer_name, target_bits, transform_scheme=None):
    """Force every quantizable submodule of one layer to ``target_bits``.

    Updates each submodule's own quantization_scheme, then rewrites the
    QuantizationModifier's config_groups so group_0 holds the default 8-bit
    scheme and group_1 accumulates the re: targets of low-bit layers.
    ``transform_scheme`` is currently unused (kept for interface stability).
    """
    for name, submodule in layer.named_modules():
        # MoE router gates are excluded from requantization.
        if ("gate" in name and "proj" not in name) or name.endswith(".gate"):
            continue
        if hasattr(submodule, "quantization_scheme") and submodule.quantization_scheme is not None:
            # Deep-copy first: schemes may be shared between modules.
            submodule.quantization_scheme = copy.deepcopy(submodule.quantization_scheme)
            if hasattr(submodule.quantization_scheme, 'weights') and submodule.quantization_scheme.weights is not None:
                submodule.quantization_scheme.weights.num_bits = target_bits

    # Guarded lookup: a partially initialized session (no lifecycle/recipe)
    # should be a no-op instead of an AttributeError.
    recipe = getattr(getattr(session, "lifecycle", None), "recipe", None)
    modifiers = getattr(recipe, "modifiers", None) or []
    modifier = None
    for m in modifiers:
        if isinstance(m, QuantizationModifier):
            modifier = m
            break
    if not modifier:
        return
    if modifier.config_groups is None:
        modifier.config_groups = {}

    # Normalize existing groups to plain dicts.
    current_groups = {}
    for k, v in modifier.config_groups.items():
        if hasattr(v, 'dict'):
            current_groups[k] = v.dict()
        elif isinstance(v, dict):
            current_groups[k] = copy.deepcopy(v)
        else:
            current_groups[k] = v

    target_group_key = "group_1"
    # Drop stale targets for this layer before (possibly) re-adding them.
    if target_group_key in current_groups and "targets" in current_groups[target_group_key]:
        old_targets = current_groups[target_group_key]["targets"]
        prefix = f"re:{layer_name}."
        new_targets = [t for t in old_targets if not t.startswith(prefix)]
        current_groups[target_group_key]["targets"] = new_targets

    if "group_0" not in current_groups:
        # Promote the modifier's flat weights/targets into an explicit
        # default group so per-layer groups can override it.
        try:
            def to_dict(obj):
                if hasattr(obj, 'dict'):
                    return obj.dict()
                if isinstance(obj, dict):
                    return obj
                return obj

            flat_weights = to_dict(modifier.weights) if hasattr(modifier, 'weights') else None
            if not flat_weights or (isinstance(flat_weights, dict) and all(v is None for v in flat_weights.values())):
                extracted = _extract_real_scheme_from_module(layer, 8)
                if extracted:
                    flat_weights = extracted
            flat_targets = modifier.targets if hasattr(modifier, 'targets') and modifier.targets else ["Linear"]
            group_0_dict = {
                "format": "pack-quantized",
                "input_activations": to_dict(modifier.input_activations) if hasattr(modifier, 'input_activations') else None,
                "output_activations": to_dict(modifier.output_activations) if hasattr(modifier, 'output_activations') else None,
                "targets": flat_targets,
                "weights": flat_weights,
            }
            if "ignore" in group_0_dict:
                del group_0_dict["ignore"]
            current_groups["group_0"] = group_0_dict
            if hasattr(modifier, 'targets'):
                modifier.targets = []
            if hasattr(modifier, 'weights'):
                modifier.weights = None
        except Exception as e:
            print(f"WARNING: Failed to auto-create group_0: {e}")

    modifier.config_groups = current_groups
    # 8-bit layers stay in the default group; nothing more to record.
    if target_bits == 8:
        return

    if target_group_key not in current_groups:
        base_source = current_groups.get("group_0")
        if not base_source:
            base_source = {"weights": {"num_bits": 4}, "targets": []}
        new_group = copy.deepcopy(base_source)
        new_group['targets'] = []
        real_scheme = _extract_real_scheme_from_module(layer, target_bits)
        if real_scheme:
            new_group['weights'] = real_scheme
        else:
            if 'weights' in new_group and new_group['weights']:
                new_group['weights']['num_bits'] = target_bits
        if "ignore" in new_group:
            del new_group["ignore"]
        if "transform" in new_group:
            del new_group["transform"]
        current_groups[target_group_key] = new_group

    target_group = current_groups[target_group_key]
    if 'targets' not in target_group or target_group['targets'] is None:
        target_group['targets'] = []

    for name, submodule in layer.named_modules():
        if not (isinstance(submodule, torch.nn.Linear) or "proj" in name):
            continue
        if "observer" in name or "input" in name or "output" in name or "quip" in name:
            continue
        full_target_name = f"re:{layer_name}.{name}"
        if full_target_name not in target_group['targets']:
            target_group['targets'].append(full_target_name)
    modifier.config_groups = current_groups
def _collapse_moe_targets(target_list):
    """Collapse per-expert MoE targets (``.experts.<N>.``) into ``.experts.*.``.

    Non-expert targets pass through; the result is deduplicated and sorted.
    """
    if not target_list:
        return []

    non_experts = [t for t in target_list if ".experts." not in t]

    collapsed_experts = set()
    for t in target_list:
        if ".experts." in t:
            collapsed_experts.add(re.sub(r'\.experts\.\d+\.', '.experts.*.', t))

    return sorted(non_experts + list(collapsed_experts))


def _sync_modifier_config_to_model(session, model, quip_layers_list):
    """Mirror the modifier's final quantization groups into model.config.

    Builds a canonical group_0 (8-bit default) / group_1 (low-bit overrides)
    layout plus an optional QuIP transform_config, writes it into
    ``model.config.quantization_config``, and wraps ``save_pretrained`` so
    config.json on disk always carries the same structure.
    """
    # Guarded lookup (consistent with _set_layer_quantization_bits).
    recipe = getattr(getattr(session, "lifecycle", None), "recipe", None)
    modifiers = getattr(recipe, "modifiers", None) or []
    modifier = None
    for m in modifiers:
        if isinstance(m, QuantizationModifier):
            modifier = m
            break
    if not modifier:
        return

    def layer_sort_key(s):
        # Sort by numeric decoder-layer index when present, name otherwise.
        match = re.search(r"\.layers\.(\d+)\.", s)
        if match:
            return int(match.group(1)), s
        return 999999, s

    final_groups = {}
    source_groups = modifier.config_groups or {}

    def to_dict_safe(obj):
        if hasattr(obj, 'dict'):
            return obj.dict()
        if isinstance(obj, dict):
            return obj
        return obj

    g0_source = to_dict_safe(source_groups.get("group_0", {}))
    default_weights = {
        "num_bits": 8, "type": "int", "symmetric": True,
        "strategy": "tensor", "dynamic": False, "actorder": None
    }
    final_weights_0 = default_weights.copy()
    if "weights" in g0_source and g0_source["weights"]:
        final_weights_0.update(to_dict_safe(g0_source["weights"]))
        # group_0 is always the 8-bit default, whatever the source said.
        final_weights_0["num_bits"] = 8

    input_acts = to_dict_safe(modifier.input_activations) if hasattr(modifier, 'input_activations') else None
    output_acts = to_dict_safe(modifier.output_activations) if hasattr(modifier, 'output_activations') else None

    final_groups["group_0"] = {
        "format": "pack-quantized",
        "input_activations": input_acts,
        "output_activations": output_acts,
        "targets": ["Linear"],
        "weights": final_weights_0
    }

    if "group_1" in source_groups:
        v = copy.deepcopy(to_dict_safe(source_groups["group_1"]))
        if "targets" in v and v["targets"]:
            v["targets"] = _collapse_moe_targets(v["targets"])
            v["targets"] = sorted(v["targets"], key=layer_sort_key)
        final_groups["group_1"] = v

    transform_config_dict = {}
    if quip_layers_list:
        sorted_targets = sorted(
            _collapse_moe_targets(list(set(quip_layers_list))), key=layer_sort_key
        )
        transform_config_dict = {
            "config_groups": {
                "u": {
                    "type": "hadamard", "head_dim": 128, "precision": "torch.float64",
                    "randomize": False, "requires_grad": False,
                    "apply": [
                        {"location": "weight_output", "inverse": False, "targets": sorted_targets, "ignore": ["lm_head"]},
                        {"location": "output", "inverse": True, "targets": sorted_targets, "ignore": ["lm_head"]}
                    ]
                },
                "v": {
                    "type": "hadamard", "head_dim": 128, "precision": "torch.float64",
                    "randomize": False, "requires_grad": False,
                    "apply": [
                        {"location": "input", "inverse": False, "targets": sorted_targets, "ignore": ["lm_head"]},
                        {"location": "weight_input", "inverse": True, "targets": sorted_targets, "ignore": ["lm_head"]}
                    ]
                }
            }
        }

    if not hasattr(model, 'config'):
        model.config = type('Config', (), {})()
    if not hasattr(model.config, 'quantization_config') or model.config.quantization_config is None:
        model.config.quantization_config = {}

    q_config_data = {"config_groups": final_groups, "quant_method": "compressed-tensors"}

    if isinstance(model.config.quantization_config, dict):
        model.config.quantization_config.update(q_config_data)
        if transform_config_dict:
            model.config.quantization_config['transform_config'] = transform_config_dict
    else:
        model.config.quantization_config.config_groups = final_groups
        model.config.quantization_config.quant_method = "compressed-tensors"
        if transform_config_dict:
            model.config.quantization_config.transform_config = transform_config_dict

    # Cache the *original* save_pretrained once: re-wrapping on every sync
    # would stack wrappers and re-run the overwrite logic N times.
    original_save = getattr(model, "_fs_original_save_pretrained", model.save_pretrained)
    model._fs_original_save_pretrained = original_save

    def new_save_pretrained(save_directory, *args, **kwargs):
        original_save(save_directory, *args, **kwargs)
        print(f"DEBUG: Force overwriting config.json in {save_directory}...")
        config_path = os.path.join(save_directory, "config.json")
        try:
            with open(config_path, 'r') as f:
                data = json.load(f)
            if "quantization_config" not in data:
                data["quantization_config"] = {}
            data["quantization_config"]["config_groups"] = final_groups
            data["quantization_config"]["quant_method"] = "compressed-tensors"
            if transform_config_dict:
                data["quantization_config"]["transform_config"] = transform_config_dict
            with open(config_path, 'w') as f:
                json.dump(data, f, indent=2)
            print("DEBUG: config.json overwritten with FULL STRUCTURE & SORTING!")
        except Exception as e:
            print(f"WARNING: Failed to overwrite config.json: {e}")

    model.save_pretrained = new_save_pretrained


def _create_dummy_input(layer: torch.nn.Module, model: torch.nn.Module) -> Union[torch.Tensor, Dict[str, Any]]:
    """Build a minimal forward input matching the layer's device/dtype.

    Decoder layers get a dict with hidden_states + attention_mask; everything
    else gets a bare (1, 1, hidden_size) tensor. hidden_size falls back to
    the first parameter's last dim, then 4096, when the model has no config.
    """
    try:
        param = next(layer.parameters())
    except StopIteration:
        param = torch.tensor(0).cuda() if torch.cuda.is_available() else torch.tensor(0)
    device = param.device
    dtype = param.dtype
    config = getattr(model, 'config', None)
    hidden_size = getattr(config, 'hidden_size', param.shape[-1] if len(param.shape) > 0 else 4096)
    # NOTE(review): if the layer has only integer params, randn with that
    # dtype would fail — assumed float parameters; confirm against callers.
    hidden_states = torch.randn(1, 1, hidden_size, device=device, dtype=dtype)
    if hasattr(layer, 'self_attn') or 'DecoderLayer' in layer.__class__.__name__:
        return {'hidden_states': hidden_states, 'attention_mask': torch.ones(1, 1, device=device, dtype=torch.long)}
    return hidden_states
def _simulate_and_verify(model, tokenizer=None):
    """Run a fake-quantized smoke-test generation on the searched model.

    Temporarily installs forward hooks that quantize each low-bit Linear's
    weight on entry and restore it on exit, generates a few short prompts,
    then removes every hook and saved weight reference. Prints only; returns
    None. When no tokenizer is available the generation step is skipped.
    """
    from functools import partial
    print("\n+++++++++++++++++++++++++++++++++++++++++++++")
    print(">>> [Simulation] Starting FakeQuant Inference Verification (Hook Method)...")

    if tokenizer is None:
        # Best effort: recover the tokenizer from the model's source path.
        try:
            model_path = getattr(model.config, "_name_or_path", None)
            if model_path:
                from transformers import AutoTokenizer
                tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        except Exception:
            pass

    def quantize_pre_hook(module, input, bits=8, group_size=128):
        """Swap the weight for its fake-quantized version before forward."""
        if not hasattr(module, "weight") or module.weight is None:
            return
        w = module.weight
        if w.device.type == 'meta':
            return
        module._saved_weight_ref = w.data

        try:
            out_f, in_f = w.shape
            # Group-wise quantization only when in_features divides evenly.
            use_group = (group_size > 0) and (in_f % group_size == 0)
            w_float = w.data.float()

            if use_group:
                w_reshaped = w_float.view(out_f, -1, group_size)
                w_max = w_reshaped.abs().amax(dim=-1, keepdim=True)
            else:
                w_reshaped = w_float
                w_max = w_float.abs().amax(dim=1, keepdim=True)

            max_q = 2 ** (bits - 1) - 1
            scale = torch.clamp(w_max / max_q, min=1e-5)
            zp = torch.zeros_like(scale, dtype=torch.int32)

            q_args = QuantizationArgs(num_bits=bits, symmetric=True)
            w_fake_reshaped = fake_quantize(w_reshaped.to(w.dtype), scale.to(w.dtype), zp, q_args)

            w_fake = w_fake_reshaped.view(out_f, in_f) if use_group else w_fake_reshaped
            module.weight.data = w_fake
        except Exception:
            # On any failure restore the original weight immediately.
            if hasattr(module, "_saved_weight_ref"):
                module.weight.data = module._saved_weight_ref

    def restore_post_hook(module, input, output):
        """Put the original weight back after forward."""
        if hasattr(module, "_saved_weight_ref"):
            module.weight.data = module._saved_weight_ref
            del module._saved_weight_ref

    hooks = []
    skip_names = ["lm_head", "output"]
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) or "proj" in name:
            if any(s in name for s in skip_names):
                continue
            if any(x in name for x in ["v_input", "u_output", "quip"]):
                continue
            if not hasattr(module, "weight") or module.weight is None:
                continue

            bits = 16
            if hasattr(module, "quantization_scheme") and module.quantization_scheme:
                bits = module.quantization_scheme.weights.num_bits

            if bits < 16:
                h1 = module.register_forward_pre_hook(partial(quantize_pre_hook, bits=bits))
                h2 = module.register_forward_hook(restore_post_hook)
                hooks.extend([h1, h2])

    print(f" Registered hooks on {len(hooks)//2} layers.")

    test_cases = [
        {"prompt": "1 + 1 ="},
        {"prompt": "The capital of China is"},
        {"prompt": "Hello, my name is"}
    ]

    if tokenizer:
        model.eval()
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        try:
            target_device = next(p.device for p in model.parameters() if p.device.type != 'meta')
        except Exception:
            # No non-meta parameter found; fall back to the first GPU.
            target_device = "cuda:0"

        for i, case in enumerate(test_cases):
            prompt = case["prompt"]
            print(f" --- Test {i+1}: {prompt}")
            try:
                inputs = tokenizer(prompt, return_tensors="pt").to(target_device)
                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=10,
                        do_sample=False,
                        pad_token_id=tokenizer.pad_token_id
                    )

                res = tokenizer.decode(outputs[0], skip_special_tokens=True)[len(prompt):].strip()
                if "stop" in case and case["stop"] in res:
                    res = res.split(case["stop"])[0]

                print(f" [Answer]: \033[92m{res}\033[0m")
            except Exception as e:
                print(f" [Error]: {e}")
    else:
        print(" [Skip] No tokenizer available.")

    # Always clean up: remove hooks and any stranded weight backups.
    for h in hooks:
        h.remove()
    for module in model.modules():
        if hasattr(module, "_saved_weight_ref"):
            del module._saved_weight_ref
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("+++++++++++++++++++++++++++++++++++++++++++++\n")
import os
import shlex
import time
from datetime import datetime

import hydra
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf

from flagscale.runner.runner_base import JobStatus, RunnerBase
from flagscale.runner.runner_utils import (
    add_decive_extra_config,
    flatten_dict_to_args,
    get_free_port,
    get_host_name_or_ip,
    get_nnodes,
    get_nproc_per_node,
    logger,
    parse_hostfile,
    run_local_command,
    run_scp_command,
    run_ssh_command,
)


def _get_args_llmcompressor(config: DictConfig):
    """Build CLI args for the llmcompressor entrypoint.

    Returns a single ``--config-path`` argument pointing at the config.yaml
    Hydra materialized for this run.
    """
    # See https://github.com/facebookresearch/hydra/discussions/2750 for why
    # the composed config is re-read from Hydra's output directory.
    hydra_config = HydraConfig.get()
    output_dir = hydra_config.runtime.output_dir
    output_subdir = hydra_config.output_subdir
    config_path = os.path.join(output_dir, f"{output_subdir}/config.yaml")
    config_path = hydra.utils.to_absolute_path(config_path)

    return [f"--config-path={config_path}"]


def _update_config_compress(config: DictConfig):
    """Normalize the logging directories under ``config.compress.system``.

    Creates ``experiment.exp_dir`` and rewrites logging.* paths in place.
    """
    exp_dir = os.path.abspath(config.experiment.exp_dir)
    os.makedirs(exp_dir, exist_ok=True)

    OmegaConf.set_struct(config, False)
    # NOTE(review): from here on `config` is the compress.system subtree, so
    # the closing set_struct(True) re-locks only that subtree, not the root
    # config — confirm this asymmetry is intended.
    config = config.compress.system

    wandb_dir = (
        os.path.abspath(config.logging.wandb_save_dir)
        if config.logging.get("wandb_save_dir", None)
        else os.path.join(exp_dir, "wandb")
    )
    tensorboard_dir = (
        os.path.abspath(config.logging.tensorboard_dir)
        if config.logging.get("tensorboard_dir", None)
        else os.path.join(exp_dir, "tensorboard")
    )
    # log_dir is always forced to <exp_dir>/compress_logs; a user-supplied
    # logging.log_dir was previously computed and then unconditionally
    # overwritten, so that dead branch has been removed.
    log_dir = os.path.join(exp_dir, "compress_logs")
    scripts_dir = os.path.join(log_dir, "scripts")
    pids_dir = os.path.join(log_dir, "pids")

    config.logging.log_dir = log_dir
    config.logging.scripts_dir = scripts_dir
    config.logging.pids_dir = pids_dir
    config.logging.tensorboard_dir = tensorboard_dir
    config.logging.wandb_save_dir = wandb_dir

    OmegaConf.set_struct(config, True)


def _generate_run_script_compress(
    config, host, node_rank, cmd, background=True, with_test=False
):
    """Write the per-host bash launch script and return its path.

    The script runs optional pre-start commands, creates output directories,
    exports PYTHONPATH (compress + megatron + repo root) and executes ``cmd``
    — in the background with its PID recorded unless ``with_test`` is set.
    """
    system_config = config.compress.system
    logging_config = config.compress.system.logging

    no_shared_fs = config.experiment.runner.get("no_shared_fs", False)
    if no_shared_fs:
        host_output_file = os.path.join(logging_config.log_dir, "host.output")
    else:
        host_output_file = os.path.join(
            logging_config.log_dir, f"host_{node_rank}_{host}.output"
        )
    host_run_script_file = os.path.join(
        logging_config.scripts_dir, f"host_{node_rank}_{host}_run.sh"
    )
    host_pid_file = os.path.join(
        logging_config.pids_dir, f"host_{node_rank}_{host}.pid"
    )

    os.makedirs(logging_config.scripts_dir, exist_ok=True)

    root_dir = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    )
    compress_dir = os.path.join(root_dir, "compress")
    # Megatron is added to PYTHONPATH for its dataset utilities.
    megatron_dir = os.path.join(root_dir, "megatron")
    cmds_config = config.experiment.get("cmds", None)
    before_start = cmds_config.get("before_start", "") if cmds_config else ""

    with open(host_run_script_file, "w") as f:
        f.write("#!/bin/bash\n\n")
        f.write(f"{before_start}\n")
        f.write(f"mkdir -p {system_config.save_dir}\n")
        f.write(f"mkdir -p {system_config.logging.log_dir}\n")
        f.write(f"mkdir -p {system_config.logging.pids_dir}\n")
        f.write(f"mkdir -p {system_config.logging.tensorboard_dir}\n")
        f.write(f"mkdir -p {system_config.logging.wandb_save_dir}\n")
        f.write("\n")
        f.write(f"cd {root_dir}\n")
        f.write("\n")
        f.write(f"export PYTHONPATH={compress_dir}:{megatron_dir}:{root_dir}\n")
        f.write("\n")
        f.write(f'cmd="{cmd}"\n')
        f.write("\n")
        if with_test:
            f.write('bash -c "$cmd; sync" \n')
        else:
            # TODO: add an option to choose append vs overwrite; the output
            # file is currently always appended to.
            if background:
                f.write(
                    f'nohup bash -c "$cmd; sync" >> {host_output_file} 2>&1 & echo $! > {host_pid_file}\n'
                )
            else:
                f.write(f'bash -c "$cmd; sync" >> {host_output_file} 2>&1\n')
        f.write("\n")
        f.flush()
        os.fsync(f.fileno())
    os.chmod(host_run_script_file, 0o755)

    return host_run_script_file


class SSHCompressRunner(RunnerBase):
    """Runner that launches a compression task locally or over SSH."""

    def __init__(self, config: DictConfig):
        super().__init__(config)
        self.task_type = getattr(self.config.experiment.task, "type", None)
        assert self.task_type == "compress", f"Unsupported task type: {self.task_type}"
        self._prepare()

    def _prepare(self):
        """Resolve config paths, user args/envs and hostfile resources."""
        _update_config_compress(self.config)
        self.user_args = _get_args_llmcompressor(self.config)
        self.rdzv_id = datetime.now().strftime("%Y%m%d_%H%M%S.%f")
        self.user_envs = self.config.experiment.get("envs", {})
        self.cur_envs = None  # current node envs
        self.user_script = self.config.experiment.task.entrypoint
        self.resources = parse_hostfile(
            self.config.experiment.runner.get("hostfile", None)
        )
        logger.info("\n************** configuration **************")
        logger.info(f"\n{OmegaConf.to_yaml(self.config)}")
+ + def _run_each( + self, + host, + master_addr, + master_port, + nnodes, + node_rank, + nproc_per_node, + with_test=False, + dryrun=False, + ): + export_cmd = [] + for k, v in self.user_envs.items(): + export_cmd += [f"{k}={v}"] + + cmd = shlex.join(export_cmd + ["python"] + [self.user_script] + self.user_args) + + logging_config = self.config.compress.system.logging + host_run_script_file = _generate_run_script_compress( + self.config, host, node_rank, cmd, background=True, with_test=with_test + ) + + if host != "localhost": + ssh_port = self.config.experiment.runner.get("ssh_port", 22) + # Step 1: make sure the scripts_dir exists on the remote host + run_ssh_command( + host, f"mkdir -p {logging_config.scripts_dir}", ssh_port, dryrun + ) + + # Step 2: copy the host_run_script_file to the remote host + no_shared_fs = self.config.experiment.runner.get("no_shared_fs", False) + if no_shared_fs: + run_scp_command( + host, + host_run_script_file, + logging_config.scripts_dir, + ssh_port, + dryrun, + ) + + # Step 3: run the host_run_script_file on the remote host + run_ssh_command(host, f"bash {host_run_script_file}", ssh_port, dryrun) + else: + run_local_command(f"bash {host_run_script_file}", dryrun) + + def run(self, with_test=False, dryrun=False): + num_visible_devices = None + visible_devices = self.user_envs.get("CUDA_VISIBLE_DEVICES", None) + if visible_devices is not None and isinstance(visible_devices, str): + visible_devices = visible_devices.split(",") + num_visible_devices = len(visible_devices) + + runner_config = self.config.experiment.runner + + # If hostfile is provided, use the resources from the hostfile + if self.resources is not None: + nnodes_from_hostfile = len(self.resources.keys()) + nnodes_from_args = runner_config.get("nnodes", None) + nnodes = get_nnodes(nnodes_from_hostfile, nnodes_from_args) + available_ip = list(self.resources.keys())[0] + available_port = get_free_port() + for node_rank, (host, resource_info) in 
enumerate(self.resources.items()): + if node_rank >= nnodes: + break + nproc_from_hostfile = resource_info["slots"] + nproc_from_args = runner_config.get("nproc_per_node", None) + nproc_per_node = get_nproc_per_node( + nproc_from_hostfile, nproc_from_args, num_visible_devices + ) + master_addr = runner_config.get("master_addr", available_ip) + master_port = runner_config.get("master_port", available_port) + self._run_each( + host, + master_addr, + master_port, + nnodes, + node_rank, + nproc_per_node, + with_test=with_test, + dryrun=dryrun, + ) + else: + # If hostfile is not provided, run the job on localhost + nproc_from_args = runner_config.get("nproc_per_node", None) + nproc_per_node = get_nproc_per_node( + None, nproc_from_args, num_visible_devices + ) + available_addr = runner_config.get("master_addr", "localhost") + available_port = runner_config.get("master_port", get_free_port()) + self._run_each( + "localhost", + available_addr, + available_port, + 1, + 0, + nproc_per_node, + with_test=with_test, + dryrun=dryrun, + ) + + def stop(self): + if self.resources is None: + self._stop_each("localhost", 0) + return + + nnodes = get_nnodes( + len(self.resources), self.config.experiment.runner.get("nnodes", None) + ) + + for node_rank, (host, _) in enumerate(self.resources.items()): + if node_rank >= nnodes: + break + self._stop_each(host, node_rank)