diff --git a/deepcompressor/app/diffusion/nn/patch.py b/deepcompressor/app/diffusion/nn/patch.py index a39ff40..aa6b69d 100644 --- a/deepcompressor/app/diffusion/nn/patch.py +++ b/deepcompressor/app/diffusion/nn/patch.py @@ -2,18 +2,22 @@ from diffusers.models.attention_processor import Attention from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock -from deepcompressor.nn.patch.conv import ConcatConv2d, ShiftedConv2d +from deepcompressor.nn.patch.conv import ConcatConv2d, ShiftedConv2d, Conv2dAsLinear from deepcompressor.nn.patch.linear import ConcatLinear, ShiftedLinear from deepcompressor.utils import patch, tools from .attention import DiffusionAttentionProcessor from .struct import DiffusionFeedForwardStruct, DiffusionModelStruct, DiffusionResnetStruct, UNetStruct +from diffusers.pipelines import StableDiffusionXLPipeline +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel +from diffusers.models.resnet import ResnetBlock2D __all__ = [ "replace_up_block_conv_with_concat_conv", "replace_fused_linear_with_concat_linear", "replace_attn_processor", "shift_input_activations", + "replace_conv2d_with_equivalent_linear_in_sdxl", ] @@ -52,6 +56,64 @@ def replace_up_block_conv_with_concat_conv(model: nn.Module) -> None: tools.logging.Formatter.indent_dec() tools.logging.Formatter.indent_dec() tools.logging.Formatter.indent_dec() + + +def replace_conv2d_with_equivalent_linear_in_sdxl(pipeline: StableDiffusionXLPipeline) -> None: + """Replace Conv2d layers in SDXL UNet2DConditionModel with equivalent linear.""" + if not isinstance(pipeline, StableDiffusionXLPipeline): + return + logger = tools.logging.getLogger(__name__) + logger.info("Replacing Conv2d with equivalent linear.") + tools.logging.Formatter.indent_inc() + assert isinstance(pipeline.unet, UNet2DConditionModel) + model: UNet2DConditionModel = pipeline.unet + # logger.info("replacing conv_in") + # model.conv_in = Conv2dAsLinear(model.conv_in) + logger.info("+ replacing down_blocks") + + tools.logging.Formatter.indent_inc() + for i, down_block in enumerate(model.down_blocks): + assert hasattr(down_block, "resnets") + assert isinstance(getattr(down_block, "resnets"), nn.ModuleList) + for j, resnet in enumerate(getattr(down_block, "resnets")): + assert isinstance(resnet, ResnetBlock2D) + logger.info(f"- replacing down_blocks.{i}.resnets.{j}.conv1") + resnet.conv1 = Conv2dAsLinear(resnet.conv1) + logger.info(f"- replacing down_blocks.{i}.resnets.{j}.conv2") + resnet.conv2 = Conv2dAsLinear(resnet.conv2) + # logger.info(f"- replacing down_blocks.{i}.resnets.{j}.conv_shortcut") + # resnet.conv_shortcut = _convert(resnet.conv_shortcut) + tools.logging.Formatter.indent_dec() + + logger.info("+ replacing mid_block") + assert hasattr(model.mid_block, "resnets") + assert isinstance(getattr(model.mid_block, "resnets"), nn.ModuleList) + tools.logging.Formatter.indent_inc() + for j, resnet in enumerate(getattr(model.mid_block, "resnets")): + assert isinstance(resnet, ResnetBlock2D) + logger.info(f"- replacing mid_block.resnets.{j}.conv1") + resnet.conv1 = Conv2dAsLinear(resnet.conv1) + logger.info(f"- replacing mid_block.resnets.{j}.conv2") + resnet.conv2 = Conv2dAsLinear(resnet.conv2) + tools.logging.Formatter.indent_dec() + + logger.info("+ replacing up_blocks") + tools.logging.Formatter.indent_inc() + for i, up_block in enumerate(model.up_blocks): + assert hasattr(up_block, "resnets") + assert isinstance(getattr(up_block, "resnets"), nn.ModuleList) + for j, resnet in enumerate(getattr(up_block, 
"resnets")): + assert isinstance(resnet, ResnetBlock2D) + logger.info(f"- replacing up_blocks.{i}.resnets.{j}.conv1") + resnet.conv1 = Conv2dAsLinear(resnet.conv1) + logger.info(f"- replacing up_blocks.{i}.resnets.{j}.conv2") + resnet.conv2 = Conv2dAsLinear(resnet.conv2) + # logger.info(f"- replacing up_blocks.{i}.resnets.{j}.conv_shortcut") + # resnet.conv_shortcut = _convert(resnet.conv_shortcut) + tools.logging.Formatter.indent_dec() + # logger.info("replacing conv_out") + # model.conv_out = Conv2dAsLinear(model.conv_out) + tools.logging.Formatter.indent_dec() def replace_fused_linear_with_concat_linear(model: nn.Module) -> None: diff --git a/deepcompressor/app/diffusion/nn/struct.py b/deepcompressor/app/diffusion/nn/struct.py index a25c4f4..96843c4 100644 --- a/deepcompressor/app/diffusion/nn/struct.py +++ b/deepcompressor/app/diffusion/nn/struct.py @@ -58,7 +58,7 @@ StableDiffusionXLPipeline, ) -from deepcompressor.nn.patch.conv import ConcatConv2d, ShiftedConv2d +from deepcompressor.nn.patch.conv import ConcatConv2d, ShiftedConv2d, Conv2dAsLinear from deepcompressor.nn.patch.linear import ConcatLinear, ShiftedLinear from deepcompressor.nn.struct.attn import ( AttentionConfigStruct, @@ -70,6 +70,7 @@ ) from deepcompressor.nn.struct.base import BaseModuleStruct from deepcompressor.utils.common import join_name +from deepcompressor.utils import tools from .attention import DiffusionAttentionProcessor @@ -966,6 +967,7 @@ def construct( idx: int = 0, **kwargs, ) -> "DiffusionResnetStruct": + logger = tools.logging.getLogger(__name__) if isinstance(module, ResnetBlock2D): assert module.upsample is None, "upsample must be None" assert module.downsample is None, "downsample must be None" @@ -986,6 +988,9 @@ def construct( shifted = True conv1_convs = [module.conv1.conv] conv1_names = ["conv1.conv"] + elif isinstance(module.conv1, Conv2dAsLinear): + conv1_convs, conv1_names = [module.conv1], ["conv1"] + logger.info(f"construct conv1 Conv2dAsLinear in DiffusionResnetStruct, fname={fname}, rname={rname}, rkey={rkey}, idx={idx}") else: assert isinstance(module.conv1, nn.Conv2d) conv1_convs, conv1_names = [module.conv1], ["conv1"] @@ -1004,8 +1009,11 @@ def construct( shifted = True conv2_convs = [module.conv2.conv] conv2_names = ["conv2.conv"] + elif isinstance(module.conv2, Conv2dAsLinear): + conv2_convs, conv2_names = [module.conv2], ["conv2"] + logger.info(f"construct conv2 Conv2dAsLinear in DiffusionResnetStruct, fname={fname}, rname={rname}, rkey={rkey}, idx={idx}") else: - assert isinstance(module.conv2, nn.Conv2d) + assert isinstance(module.conv2, (nn.Conv2d, Conv2dAsLinear)) conv2_convs, conv2_names = [module.conv2], ["conv2"] convs, conv_rnames = [conv1_convs, conv2_convs], [conv1_names, conv2_names] norms, norm_rnames = [module.norm1, module.norm2], ["norm1", "norm2"] @@ -1017,8 +1025,8 @@ def construct( else: raise NotImplementedError(f"Unsupported module type: {type(module)}") config = FeedForwardConfigStruct( - hidden_size=convs[0][0].weight.shape[1], - intermediate_size=convs[0][0].weight.shape[0], + hidden_size=convs[0][0].linear.weight.shape[1], # TODO verify + intermediate_size=convs[0][0].linear.weight.shape[0], # TODO verify intermediate_act_type=act_type, num_experts=1, ) @@ -1518,6 +1526,7 @@ def _default_construct( @classmethod def _get_default_key_map(cls) -> dict[str, set[str]]: """Get the default allowed keys.""" + logger = tools.logging.getLogger(__name__) key_map: dict[str, set[str]] = defaultdict(set) for idx, (block_key, block_cls) in enumerate( ( @@ -1561,7 +1570,9 
@@ def _get_default_key_map(cls) -> dict[str, set[str]]: if key in key_map: key_map[key].clear() key_map[key].add(key) - return {k: v for k, v in key_map.items() if v} + ret = {k: v for k, v in key_map.items() if v} + logger.info(f">> key_map: {ret}") + return ret @dataclass(kw_only=True) diff --git a/deepcompressor/app/diffusion/pipeline/config.py b/deepcompressor/app/diffusion/pipeline/config.py index c286d42..62ed593 100644 --- a/deepcompressor/app/diffusion/pipeline/config.py +++ b/deepcompressor/app/diffusion/pipeline/config.py @@ -2,6 +2,8 @@ """Diffusion pipeline configuration module.""" import gc +import json +import os import typing as tp from dataclasses import dataclass, field @@ -11,12 +13,16 @@ DiffusionPipeline, FluxControlPipeline, FluxFillPipeline, + FluxPipeline, SanaPipeline, ) +from diffusers.models.transformers import FluxTransformer2DModel from omniconfig import configclass from torch import nn from transformers import PreTrainedModel, PreTrainedTokenizer, T5EncoderModel +from safetensors.torch import load_file +from deepcompressor.app.diffusion.pipeline.utils.mapping import comfy_to_diffusers from deepcompressor.data.utils.dtype import eval_dtype from deepcompressor.quantizer.processor import Quantizer from deepcompressor.utils import tools @@ -28,6 +34,7 @@ replace_fused_linear_with_concat_linear, replace_up_block_conv_with_concat_conv, shift_input_activations, + replace_conv2d_with_equivalent_linear_in_sdxl, ) __all__ = ["DiffusionPipelineConfig"] @@ -357,6 +364,16 @@ def _default_build( pipeline.text_encoder.to(dtype) else: pipeline = SanaPipeline.from_pretrained(path, torch_dtype=dtype) + elif name == "flux.1-dev-custom": + with open(os.path.join(path, "config.json"), "r", encoding="utf-8") as f: + json_config = json.load(f) + transformer = FluxTransformer2DModel.from_config(json_config).to(dtype) + checkpoint_file = os.path.basename(path) + ".safetensors" + state_dict = comfy_to_diffusers(load_file(os.path.join(path, checkpoint_file))) + transformer.load_state_dict(state_dict) + pipeline = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16 + ) else: pipeline = AutoPipelineForText2Image.from_pretrained(path, torch_dtype=dtype) pipeline = pipeline.to(device) @@ -365,6 +382,7 @@ def _default_build( # The quantization and inference for resblock_conv layer is not completed. # So here we do not do any pre process to conv layers. 
# replace_up_block_conv_with_concat_conv(model) + replace_conv2d_with_equivalent_linear_in_sdxl(pipeline) if shift_activations: shift_input_activations(model) return pipeline diff --git a/deepcompressor/app/diffusion/ptq.py b/deepcompressor/app/diffusion/ptq.py index 74c1c17..ab67aca 100644 --- a/deepcompressor/app/diffusion/ptq.py +++ b/deepcompressor/app/diffusion/ptq.py @@ -383,6 +383,9 @@ def main(config: DiffusionPtqRunConfig, logging_level: int = tools.logging.DEBUG try: main(config, logging_level=tools.logging.DEBUG) except Exception as e: + torch.cuda.synchronize() + torch.cuda.memory._dump_snapshot("/data/dongd/ptq_err_snapshot.pickle") + torch.cuda.memory._record_memory_history(enabled=None) tools.logging.Formatter.indent_reset() tools.logging.error("=== Error ===") tools.logging.error(traceback.format_exc()) diff --git a/deepcompressor/app/diffusion/quant/activation.py b/deepcompressor/app/diffusion/quant/activation.py index 7e0dd27..882bc69 100644 --- a/deepcompressor/app/diffusion/quant/activation.py +++ b/deepcompressor/app/diffusion/quant/activation.py @@ -10,7 +10,9 @@ from deepcompressor.data.cache import IOTensorsCache from deepcompressor.data.common import TensorType +from deepcompressor.nn.patch.conv import Conv2dAsLinear from deepcompressor.utils import tools +from deepcompressor.utils.common import join_name from ..nn.struct import ( DiffusionAttentionStruct, @@ -124,11 +126,20 @@ def quantize_diffusion_block_activations( # noqa: C901 assert field_name == "add_k_proj" assert module_name == parent.add_k_proj_name modules, module_names = parent.add_qkv_proj, parent.add_qkv_proj_names + elif isinstance(module, Conv2dAsLinear): + modules = [module.linear] + module_names = [join_name(module_name, "linear")] + eval_module = module + eval_name = module_name + eval_kwargs = layer_kwargs if layer_kwargs else {} if modules is None: assert module not in used_modules used_modules.add(module) orig_wgts = [(module.weight, orig_state_dict[f"{module_name}.weight"])] if orig_state_dict else None args_caches.append((module_key, In, [module], [module_name], module, module_name, None, orig_wgts)) + elif isinstance(module , Conv2dAsLinear): + orig_wgts = [(module.linear.weight, orig_state_dict[f"{join_name(module_name, "linear")}.weight"])] if orig_state_dict else None + args_caches.append((module_key, In, modules, module_names, eval_module, eval_name, eval_kwargs, orig_wgts)) else: orig_wgts = [] for proj_module in modules: @@ -166,6 +177,8 @@ def quantize_diffusion_block_activations( # noqa: C901 key=module_key, tensor_type=tensor_type, ) + if isinstance(eval_module, Conv2dAsLinear): + logger.info(f"[calib_dynamic_range]Conv2dAsLinear eval_name: {eval_name}, module_names: {module_names}, quantizer.is_enabled: {quantizer.is_enabled()}") if quantizer.is_enabled(): if cache_keys[0] not in quantizer_state_dict: logger.debug("- Calibrating %s", ", ".join(cache_keys)) diff --git a/deepcompressor/app/diffusion/quant/smooth.py b/deepcompressor/app/diffusion/quant/smooth.py index 22f7b12..314b7e7 100644 --- a/deepcompressor/app/diffusion/quant/smooth.py +++ b/deepcompressor/app/diffusion/quant/smooth.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- """Diffusion smooth quantization module.""" +import gc import typing as tp import torch @@ -9,8 +10,10 @@ from deepcompressor.calib.smooth import ActivationSmoother, smooth_linear_modules from deepcompressor.data.cache import IOTensorsCache +from deepcompressor.nn.patch.conv import Conv2dAsLinear from deepcompressor.quantizer import Quantizer from 
deepcompressor.utils import tools +from deepcompressor.utils.common import join_name from deepcompressor.utils.hooks import KeyedInputPackager from ..nn.struct import ( @@ -71,8 +74,8 @@ def smooth_diffusion_qkv_proj( attn.qkv_proj, scale=smooth_cache.get(cache_key, None), config=config.smooth.proj, - weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank), - input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=module_key), + weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank, develop_dtype=config.develop_dtype), + input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=module_key, develop_dtype=config.develop_dtype), inputs=block_cache[attn.q_proj_name].inputs if block_cache else None, eval_inputs=block_cache[attn.name].inputs if block_cache else None, eval_module=attn, @@ -114,8 +117,8 @@ def smooth_diffusion_qkv_proj( attn.add_qkv_proj, scale=smooth_cache.get(cache_key, None), config=config.smooth.proj, - weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank), - input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=module_key), + weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank, develop_dtype=config.develop_dtype), + input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=module_key, develop_dtype=config.develop_dtype), inputs=block_cache[attn.add_k_proj_name].inputs if block_cache else None, eval_inputs=block_cache[attn.name].inputs if block_cache else None, eval_module=wrap_joint_attn(attn, indexes=1) if attn.is_joint_attn() else attn, @@ -166,8 +169,8 @@ def smooth_diffusion_out_proj( # noqa: C901 attn.o_proj, scale=smooth_cache.get(cache_key, None), config=config.smooth.proj, - weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank), - input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=module_key), + weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank, develop_dtype=config.develop_dtype), + input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=module_key, develop_dtype=config.develop_dtype), inputs=block_cache[attn.o_proj_name].inputs if block_cache else None, eval_inputs=block_cache[attn.o_proj_name].inputs if block_cache else None, eval_module=attn.o_proj, @@ -187,8 +190,8 @@ def smooth_diffusion_out_proj( # noqa: C901 attn.add_o_proj, scale=smooth_cache.get(cache_key, None), config=config.smooth.proj, - weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank), - input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=module_key), + weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank, develop_dtype=config.develop_dtype), + input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=module_key, develop_dtype=config.develop_dtype), inputs=block_cache[attn.add_o_proj_name].inputs if block_cache else None, eval_inputs=block_cache[attn.add_o_proj_name].inputs if block_cache else None, eval_module=attn.add_o_proj, @@ -208,8 +211,8 @@ def smooth_diffusion_out_proj( # noqa: C901 [attn.o_proj, attn.add_o_proj], scale=smooth_cache.get(cache_key, None), config=config.smooth.proj, - weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank), - input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=module_key), + weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank, develop_dtype=config.develop_dtype), + 
input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=module_key, develop_dtype=config.develop_dtype), inputs=block_cache[attn.o_proj_name].inputs if block_cache else None, eval_inputs=block_cache[attn.name].inputs if block_cache else None, eval_module=wrap_joint_attn(attn, indexes=(0, 1)), @@ -262,8 +265,8 @@ def smooth_diffusion_up_proj( ffn.up_projs, scale=smooth_cache.get(cache_key, None), config=config.smooth.proj, - weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank), - input_quantizer=Quantizer(config.ipts, channels_dim=channels_dim, key=module_key), + weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank, develop_dtype=config.develop_dtype), + input_quantizer=Quantizer(config.ipts, channels_dim=channels_dim, key=module_key, develop_dtype=config.develop_dtype), inputs=block_cache[ffn.up_proj_name].inputs if block_cache else None, eval_inputs=block_cache[ffn.up_proj_name].inputs if block_cache else None, eval_module=ffn.up_proj, @@ -285,7 +288,8 @@ def smooth_diffusion_down_proj( ) -> dict[str, torch.Tensor]: logger = tools.logging.getLogger(f"{__name__}.SmoothQuant") # ffn down projection - module_key = ffn.down_proj_key.upper() + # module_key = ffn.down_proj_key.upper() + module_key = ffn.down_proj_key needs_quant = config.enabled_wgts and config.wgts.is_enabled_for(module_key) needs_quant = needs_quant or (config.enabled_ipts and config.ipts.is_enabled_for(module_key)) if needs_quant and config.smooth.enabled_proj and config.smooth.proj.is_enabled_for(module_key): @@ -301,8 +305,8 @@ def smooth_diffusion_down_proj( ffn.down_proj, scale=smooth_cache.get(cache_key, None), config=config.smooth.proj, - weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank), - input_quantizer=Quantizer(config_ipts, channels_dim=channels_dim, key=module_key), + weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank, develop_dtype=config.develop_dtype), + input_quantizer=Quantizer(config_ipts, channels_dim=channels_dim, key=module_key, develop_dtype=config.develop_dtype), inputs=block_cache[ffn.down_proj_name].inputs if block_cache else None, eval_inputs=block_cache[ffn.down_proj_name].inputs if block_cache else None, eval_module=ffn.down_proj, @@ -341,8 +345,8 @@ def smooth_diffusion_parallel_qkv_up_proj( modules, scale=smooth_cache.get(cache_key, None), config=config.smooth.proj, - weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank), - input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=module_key), + weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank, develop_dtype=config.develop_dtype), + input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=module_key, develop_dtype=config.develop_dtype), inputs=block_cache[attn.q_proj_name].inputs if block_cache else None, eval_inputs=block_cache[block.name].inputs if block_cache else None, eval_module=block, @@ -385,8 +389,8 @@ def smooth_diffusion_parallel_qkv_up_proj( modules, scale=smooth_cache.get(cache_key, None), config=config.smooth.proj, - weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank), - input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=module_key), + weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank, develop_dtype=config.develop_dtype), + input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=module_key, develop_dtype=config.develop_dtype), 
inputs=block_cache[attn.add_k_proj_name].inputs if block_cache else None, eval_inputs=block_cache[block.name].inputs if block_cache else None, eval_module=block, @@ -510,8 +514,8 @@ def smooth_diffusion_module( module, scale=smooth_cache.get(cache_key, None), config=config.smooth.proj, - weight_quantizer=Quantizer(config_wgts, key=module_key), - input_quantizer=Quantizer(config.ipts, channels_dim=channels_dim, key=module_key), + weight_quantizer=Quantizer(config_wgts, key=module_key, develop_dtype=config.develop_dtype), + input_quantizer=Quantizer(config.ipts, channels_dim=channels_dim, key=module_key, develop_dtype=config.develop_dtype), inputs=layer_cache[module_name].inputs if layer_cache else None, eval_inputs=layer_cache[module_name].inputs if layer_cache else None, eval_module=module, @@ -549,6 +553,8 @@ def smooth_diffusion_layer( """ logger = tools.logging.getLogger(f"{__name__}.SmoothQuant") logger.debug("- Smoothing Diffusion Block %s", layer.name) + # if layer.name == "up_blocks.1": + # torch.cuda.memory._record_memory_history() tools.logging.Formatter.indent_inc() layer_cache = layer_cache or {} layer_kwargs = layer_kwargs or {} @@ -579,6 +585,16 @@ def smooth_diffusion_layer( block_kwargs=layer_kwargs, ) tools.logging.Formatter.indent_dec() + elif isinstance(module, Conv2dAsLinear): + logger.debug(f"- Smoothing Conv2dAsLinear module_name: {module_name}. module_key: {module_key}") + smooth_cache = smooth_diffusion_module( + module_key=module_key, # TODO verify + module_name=join_name(module_name, "linear"), # TODO verify + module=module.linear, + config=config, + smooth_cache=smooth_cache, + layer_cache=layer_cache, + ) elif isinstance(module, (nn.Linear, nn.Conv2d)): smooth_cache = smooth_diffusion_module( module_key=module_key, @@ -594,7 +610,12 @@ def smooth_diffusion_layer( if needs_quant and config.smooth.enabled_proj and config.smooth.proj.is_enabled_for(module_key): raise NotImplementedError(f"Module {module_name} is not supported for smoothing") logger.debug("- Skipping Module %s", module_name) + gc.collect() + torch.cuda.empty_cache() tools.logging.Formatter.indent_dec() + # if layer.name == "up_blocks.1": + # torch.cuda.memory._dump_snapshot("/data/dongd/upblocks.1_snapshot.pickle") + # torch.cuda.memory._record_memory_history(enabled=None) @torch.inference_mode() diff --git a/deepcompressor/app/diffusion/quant/weight.py b/deepcompressor/app/diffusion/quant/weight.py index c9a046d..35f3765 100644 --- a/deepcompressor/app/diffusion/quant/weight.py +++ b/deepcompressor/app/diffusion/quant/weight.py @@ -10,8 +10,10 @@ from deepcompressor.data.cache import IOTensorsCache from deepcompressor.data.zero import ZeroPointDomain +from deepcompressor.nn.patch.conv import Conv2dAsLinear from deepcompressor.nn.patch.lowrank import LowRankBranch from deepcompressor.utils import tools +from deepcompressor.utils.common import join_name from ..nn.struct import DiffusionAttentionStruct, DiffusionBlockStruct, DiffusionModelStruct, DiffusionModuleStruct from .config import DiffusionQuantConfig @@ -46,6 +48,8 @@ def calibrate_diffusion_block_low_rank_branch( # noqa: C901 assert config.wgts.low_rank is not None logger = tools.logging.getLogger(f"{__name__}.WeightQuantSVD") logger.debug("- Calibrating low-rank branches of block %s", layer.name) + if layer.name == "up_blocks.2": + torch.cuda.memory._record_memory_history() layer_cache = layer_cache or {} layer_kwargs = layer_kwargs or {} for module_key, module_name, module, parent, field_name in layer.named_key_modules(): @@ -89,6 +93,8 @@ 
def calibrate_diffusion_block_low_rank_branch( # noqa: C901 if isinstance(modules[0], nn.Linear): assert all(isinstance(m, nn.Linear) for m in modules) channels_dim = -1 + elif isinstance(module, Conv2dAsLinear): + channels_dim = -1 else: assert all(isinstance(m, nn.Conv2d) for m in modules) channels_dim = 1 @@ -96,9 +102,20 @@ def calibrate_diffusion_block_low_rank_branch( # noqa: C901 if config.enabled_extra_wgts and config.extra_wgts.is_enabled_for(module_key): config_wgts = config.extra_wgts quantizer = DiffusionWeightQuantizer(config_wgts, develop_dtype=config.develop_dtype, key=module_key) + if isinstance(module, Conv2dAsLinear): + logger.info(f"[low_rank_branch]Conv2dAsLinear module_name: {module_name}, module_key: {module_key}, quantizer.is_enabled: {quantizer.is_enabled()}, quantizer.is_enabled_low_rank: {quantizer.is_enabled_low_rank()}") if quantizer.is_enabled() and quantizer.is_enabled_low_rank(): if isinstance(module, nn.Conv2d): assert module.weight.shape[2:].numel() + elif isinstance(module, Conv2dAsLinear): + assert len(module_names) == 1 + assert len(modules) == 1 + module_name = join_name(module_name, "linear") + eval_name = module_name + module_names = [module_name] + eval_module = modules[0].linear + modules = [modules[0].linear] + module = module.linear else: assert isinstance(module, nn.Linear) if module_name not in branch_state_dict: @@ -106,7 +123,7 @@ def calibrate_diffusion_block_low_rank_branch( # noqa: C901 tools.logging.Formatter.indent_inc() branch_state_dict[module_name] = quantizer.calibrate_low_rank( input_quantizer=DiffusionActivationQuantizer( - config.ipts, key=module_key, channels_dim=channels_dim + config.ipts, key=module_key, channels_dim=channels_dim, develop_dtype=config.develop_dtype ), modules=modules, inputs=layer_cache[module_name].inputs if layer_cache else None, @@ -140,11 +157,16 @@ def calibrate_diffusion_block_low_rank_branch( # noqa: C901 module.weight.data.sub_(branch.get_effective_weight().view(module.weight.data.shape)) branch.as_hook().register(module) else: + # if isinstance(module, Conv2dAsLinear): + # module = module.linear module.weight.data.sub_(shared_branch.get_effective_weight().view(module.weight.data.shape)) shared_branch.as_hook().register(module) del shared_branch gc.collect() torch.cuda.empty_cache() + if layer.name == "up_blocks.2": + torch.cuda.memory._dump_snapshot("/data/dongd/low_rank_calib_upblocks.2_snapshot.pickle") + torch.cuda.memory._record_memory_history(enabled=None) @torch.inference_mode() @@ -191,6 +213,10 @@ def update_diffusion_block_weight_quantizer_state_dict( if config.enabled_extra_wgts and config.extra_wgts.is_enabled_for(module_key): config_wgts = config.extra_wgts quantizer = DiffusionWeightQuantizer(config_wgts, develop_dtype=config.develop_dtype, key=module_key) + if isinstance(module, Conv2dAsLinear): + logger.info(f"[update_quantizer]Conv2dAsLinear module_name: {module_name}, module_key: {module_key}, quantizer.is_enabled: {quantizer.is_enabled()}") + module_name = join_name(module_name, "linear") + module = module.linear if quantizer.is_enabled(): if module_name not in quantizer_state_dict: logger.debug("- Calibrating %s.weight quantizer", module_name) @@ -247,6 +273,10 @@ def quantize_diffusion_block_weights( tools.logging.Formatter.indent_inc() for module_key, module_name, module, _, _ in layer.named_key_modules(): + if isinstance(module, Conv2dAsLinear): + logger.info(f"[quantize_weight]Conv2dAsLinear module_name: {module_name}, module_key: {module_key}") + module_name = 
join_name(module_name, "linear") + module = module.linear if module_name in quantizer_state_dict: param_name = f"{module_name}.weight" logger.debug("- Quantizing %s", param_name) @@ -257,6 +287,7 @@ def quantize_diffusion_block_weights( logger.debug(" + group_shape: %s", str(config_wgts.group_shapes)) logger.debug(" + scale_dtype: %s", str(config_wgts.scale_dtypes)) quantizer = DiffusionWeightQuantizer(config_wgts, develop_dtype=config.develop_dtype, key=module_key) + logger.info(f"[quantize_weight] quantizer.is_enabled: {quantizer.is_enabled()}") quantizer.load_state_dict(quantizer_state_dict[module_name], device=module.weight.device) result = quantizer.quantize( module.weight.data, diff --git a/deepcompressor/backend/nunchaku/sdxl.py b/deepcompressor/backend/nunchaku/sdxl.py index d3b1d58..6e77ec4 100644 --- a/deepcompressor/backend/nunchaku/sdxl.py +++ b/deepcompressor/backend/nunchaku/sdxl.py @@ -1,10 +1,11 @@ import torch -from .common import convert_to_nunchaku_transformer_block_state_dict, update_state_dict +from .common import convert_to_nunchaku_transformer_block_state_dict, convert_to_nunchaku_w4x4y16_linear_state_dict, update_state_dict -def _get_sdxl_transformer_block_names(state_dict): +def _get_sdxl_block_names(state_dict): transformer_block_names: set[str] = set() + conv2d_as_linear_names: set[str] = set() other: dict[str, torch.Tensor] = {} for param_name in state_dict.keys(): if ".transformer_blocks." in param_name: @@ -14,11 +15,19 @@ def _get_sdxl_transformer_block_names(state_dict): transformer_block_names.add(".".join(param_name.split(".")[:5])) else: raise ValueError(f"Unknown block name: {param_name}") + elif ".resnets." in param_name and ".conv" in param_name and ".linear" in param_name: + if param_name.startswith("up_blocks") or param_name.startswith("down_blocks"): + conv2d_as_linear_names.add(".".join(param_name.split(".")[:6])) + elif param_name.startswith("mid_block"): + conv2d_as_linear_names.add(".".join(param_name.split(".")[:5])) + else: + raise ValueError(f"Unknown block name: {param_name}") else: other[param_name] = state_dict[param_name] # all the numbers in sdxl state dict are single-digit, so there's no need to convert to int for sorting. 
transformer_block_names = sorted(transformer_block_names, key=lambda x: tuple(x.split("."))) - return transformer_block_names, other + conv2d_as_linear_names = sorted(conv2d_as_linear_names, key=lambda x: tuple(x.split("."))) + return transformer_block_names, conv2d_as_linear_names, other def convert_to_nunchaku_sdxl_state_dicts( @@ -28,10 +37,10 @@ def convert_to_nunchaku_sdxl_state_dicts( branch_dict: dict[str, torch.Tensor], float_point: bool = False, ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: - block_names, other = _get_sdxl_transformer_block_names(model_dict) - print(f"Converting {len(block_names)} transformer blocks...") + transformer_block_names, conv2d_as_linear_names, other = _get_sdxl_block_names(model_dict) + print(f"Converting {len(transformer_block_names)} transformer blocks...") converted: dict[str, torch.Tensor] = {} - for block_name in block_names: + for block_name in transformer_block_names: d = convert_to_nunchaku_transformer_block_state_dict( state_dict=model_dict, scale_dict=scale_dict, @@ -78,4 +87,15 @@ def convert_to_nunchaku_sdxl_state_dicts( float_point=float_point, ) update_state_dict(converted, d, prefix=block_name,) + + for _name in conv2d_as_linear_names: + conv2d_as_linear_state_dict = convert_to_nunchaku_w4x4y16_linear_state_dict( + weight=model_dict[f"{_name}.weight"], + scale=scale_dict[f"{_name}.weight.scale.0"], + bias=model_dict[f"{_name}.bias"], + smooth=smooth_dict[_name], + lora=(branch_dict[_name]["a.weight"], branch_dict[_name]["b.weight"]), + ) + update_state_dict(converted, conv2d_as_linear_state_dict, prefix=_name) + return converted, other diff --git a/deepcompressor/calib/smooth.py b/deepcompressor/calib/smooth.py index 876daa1..0f9de8f 100644 --- a/deepcompressor/calib/smooth.py +++ b/deepcompressor/calib/smooth.py @@ -64,6 +64,8 @@ def process(self, tensor: torch.Tensor) -> torch.Tensor: `torch.Tensor`: The smoothed tensor. 
""" + # logger = tools.logging.getLogger(__name__) + # logger.info(f"- [ActivationSmoother] tensor.dtype={tensor.dtype}, develop_dtype={self.develop_dtype}, smooth_scale_dtype={self.smooth_scale.dtype}") device, dtype = tensor.device, tensor.dtype if self.develop_dtype is None: self.develop_dtype = dtype @@ -511,7 +513,7 @@ def _process_x_in_xw(self, x: torch.Tensor, channels_dim: int | _MISSING_TYPE = x = x.to(dtype=self.develop_dtype) if dtype != self.develop_dtype else x.clone() x = x.div_(scale) x = self.x_quantizer.quantize( - x, channels_dim=channels_dim, default_dtype=dtype, develop_dtype=self.develop_dtype + x, channels_dim=channels_dim, default_dtype=dtype, develop_dtype=self.develop_dtype, in_place=True ).data x = x.mul_(scale).to(dtype=dtype) return x.view(shape) diff --git a/deepcompressor/data/range.py b/deepcompressor/data/range.py index d0f0d5d..9e559a2 100644 --- a/deepcompressor/data/range.py +++ b/deepcompressor/data/range.py @@ -7,6 +7,8 @@ import torch +from deepcompressor.data.utils.range import absmax + from .dtype import QuantDataType from .zero import ZeroPointDomain @@ -310,9 +312,11 @@ def measure( # noqa: C901 # region step 1: determine the value range (i.e., vmax and vmin) if zero_domain is None: vmin = None - vmax = tensors[0].abs().amax(dim=reduced, keepdim=True) + vmax = absmax(tensors[0], reduced) + # vmax = tensors[0].abs().amax(dim=reduced, keepdim=True) for tensor in tensors[1:]: - vmax = torch.maximum(vmax, tensor.abs().amax(dim=reduced, keepdim=True).to(vmax.device)) + # vmax = torch.maximum(vmax, tensor.abs().amax(dim=reduced, keepdim=True).to(vmax.device)) + vmax = torch.maximum(vmax, absmax(tensor, reduced).to(vmax.device)) else: vmax = tensors[0].amax(dim=reduced, keepdim=True) for tensor in tensors[1:]: diff --git a/deepcompressor/data/utils/range.py b/deepcompressor/data/utils/range.py new file mode 100644 index 0000000..2ccdfab --- /dev/null +++ b/deepcompressor/data/utils/range.py @@ -0,0 +1,7 @@ +import torch + +def absmax(tensor: torch.Tensor, reduce_dim: int): + # call amin()/amax() before calling abs(), in order to optimize memory usage. 
+ t_minabs = tensor.amin(dim=reduce_dim, keepdim=True).abs() + t_maxabs = tensor.amax(dim=reduce_dim, keepdim=True).abs() + return torch.maximum(t_minabs, t_maxabs) \ No newline at end of file diff --git a/deepcompressor/dataset/cache.py b/deepcompressor/dataset/cache.py index 3ba9bd8..f7ab387 100644 --- a/deepcompressor/dataset/cache.py +++ b/deepcompressor/dataset/cache.py @@ -15,10 +15,12 @@ import torch.utils.hooks from tqdm import tqdm +from deepcompressor.nn.patch.conv import Conv2dAsLinear + from ..data.cache import IOTensorsCache, ModuleForwardInput, TensorCache from ..data.utils.reshape import ConvInputReshapeFn, ConvOutputReshapedFn, LinearReshapeFn from ..utils import tools -from ..utils.common import tree_copy_with_ref, tree_map +from ..utils.common import join_name, tree_copy_with_ref, tree_map from ..utils.hooks import EarlyStopException, EarlyStopHook, Hook from .action import CacheAction @@ -260,19 +262,35 @@ def _iter_layer_activations( # noqa: C901 needs_inputs = needs_inputs_fn(module_name, module) needs_outputs = needs_outputs_fn(module_name, module) if needs_inputs or needs_outputs: - module_names[layer_name].append(module_name) - cache.setdefault(layer_name, {})[module_name] = self._init_cache(module_name, module) - hook_args.setdefault(layer_name, []).append((module_name, module, needs_inputs, needs_outputs)) - info_hooks.extend( - action.register( - name=module_name, - module=module, - cache=cache[layer_name][module_name], - info_mode=True, - needs_inputs=needs_inputs, - needs_outputs=needs_outputs, + if isinstance(module, Conv2dAsLinear): + linear = join_name(module_name, "linear") + module_names[layer_name].append(linear) + cache.setdefault(layer_name, {})[linear] = self._init_cache(linear, module.linear) + hook_args.setdefault(layer_name, []).append((linear, module.linear, needs_inputs, needs_outputs)) + info_hooks.extend( + action.register( + name=linear, + module=module.linear, + cache=cache[layer_name][linear], + info_mode=True, + needs_inputs=needs_inputs, + needs_outputs=needs_outputs, + ) + ) + else: + module_names[layer_name].append(module_name) + cache.setdefault(layer_name, {})[module_name] = self._init_cache(module_name, module) + hook_args.setdefault(layer_name, []).append((module_name, module, needs_inputs, needs_outputs)) + info_hooks.extend( + action.register( + name=module_name, + module=module, + cache=cache[layer_name][module_name], + info_mode=True, + needs_inputs=needs_inputs, + needs_outputs=needs_outputs, + ) ) - ) if len(cache) == 0: return if layers is not None: @@ -311,7 +329,7 @@ def _iter_layer_activations( # noqa: C901 if early_stop_module is not None: forward_hooks.append(early_stop_module.register_forward_hook(EarlyStopHook())) with torch.inference_mode(): - device = "cuda" if torch.cuda.is_available() else "cpu" + device = next(model.parameters()).device if torch.cuda.is_available() else "cpu" tbar = tqdm( desc="collecting acts info", leave=False, diff --git a/deepcompressor/nn/patch/conv.py b/deepcompressor/nn/patch/conv.py index d094b4f..93f6e0e 100644 --- a/deepcompressor/nn/patch/conv.py +++ b/deepcompressor/nn/patch/conv.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from torch.nn.common_types import _size_2_t -__all__ = ["ConcatConv2d", "ShiftedConv2d"] +__all__ = ["ConcatConv2d", "ShiftedConv2d", "Conv2dAsLinear"] class ConcatConv2d(nn.Module): @@ -180,3 +180,66 @@ def from_conv2d(conv: nn.Conv2d, shift: float | torch.Tensor) -> "ShiftedConv2d" else: shifted.conv.bias.data.copy_(-shifted_bias.to(dtype)) return shifted + + 
+class Conv2dAsLinear(nn.Module):
+    def __init__(self, conv: nn.Conv2d):
+        super().__init__()
+        assert isinstance(conv, nn.Conv2d)
+        self.in_channels = conv.in_channels
+        self.out_channels = conv.out_channels
+        self.kernel_size = conv.kernel_size
+        self.stride = conv.stride
+        self.padding = conv.padding
+        self.bias_flag = conv.bias is not None
+
+        # -------------------------------
+        # unfold Conv2d's weight and bias to Linear
+        # Conv2d.weight.shape = [out_channels, in_channels, kH, kW]
+        # Linear.weight.shape = [out_features, in_features]
+        # in_features = in_channels * kH * kW
+        # out_features = out_channels
+        # -------------------------------
+        kH, kW = self.kernel_size
+        self.linear = nn.Linear(
+            in_features=self.in_channels * kH * kW,
+            out_features=self.out_channels,
+            bias=self.bias_flag,
+            device=conv.weight.device,
+            dtype=conv.weight.dtype
+        )
+
+        # copy parameters
+        with torch.no_grad():
+            self.linear.weight.copy_(conv.weight.view(self.out_channels, -1))
+            if self.bias_flag:
+                self.linear.bias.copy_(conv.bias)
+
+    def forward(self, x: torch.Tensor):
+        # input: x.shape = [N, C_in, H, W]
+        # output: y.shape = [N, C_out, H_out, W_out]
+        N, C, H, W = x.shape
+
+        # 1. unfold
+        # after unfolding: x_unf.shape = [N, C*kH*kW, L], where L = H_out * W_out
+        x_unf = F.unfold(
+            x,
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            padding=self.padding
+        )
+
+        # 2. transpose: [N, L, C*kH*kW]
+        x_unf = x_unf.transpose(1, 2)
+
+        # 3. pass through linear layer
+        # output shape: [N, L, C_out]
+        y: torch.Tensor = self.linear(x_unf)
+
+        # 4. transpose back and reshape
+        # output shape: [N, C_out, H_out, W_out]
+        H_out = (H + 2 * self.padding[0] - self.kernel_size[0]) // self.stride[0] + 1
+        W_out = (W + 2 * self.padding[1] - self.kernel_size[1]) // self.stride[1] + 1
+        y = y.transpose(1, 2).reshape(N, self.out_channels, H_out, W_out)
+
+        return y
\ No newline at end of file
diff --git a/deepcompressor/quantizer/impl/base.py b/deepcompressor/quantizer/impl/base.py
index 2b26f03..8ace10d 100644
--- a/deepcompressor/quantizer/impl/base.py
+++ b/deepcompressor/quantizer/impl/base.py
@@ -64,6 +64,7 @@ def quantize(
         return_with_quant: bool = False,
         default_dtype: torch.dtype | None = torch.float16,
         develop_dtype: torch.dtype = torch.float32,
+        in_place: bool = False,
         **kwargs,
     ) -> QuantTensor:
         """Quantize a floating point tensor.
@@ -93,6 +94,8 @@
                 The default dtype for scale.
             develop_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
                 The develop dtype.
+            in_place (`bool`):
+                Whether to quantize the tensor in place.
             **kwargs:
                 Other keyword arguments for the quantization kernel.
                 For example, ``inputs`` for the input tensors in GPTQ kernel,
@@ -121,6 +124,7 @@
             return_with_quant=return_with_quant,
             default_dtype=default_dtype or tensor.dtype,
             develop_dtype=develop_dtype,
+            in_place=in_place,
             **kwargs,
         )
         if result.data is not None:
@@ -147,6 +151,7 @@
         return_with_quant: bool = False,
         default_dtype: torch.dtype = torch.float16,
         develop_dtype: torch.dtype = torch.float32,
+        in_place: bool = False,
         **kwargs,
     ) -> QuantTensor:
         """Quantize a floating point tensor.
@@ -174,6 +179,8 @@
                 The default dtype for scale.
            develop_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
                 The develop dtype.
+            in_place (`bool`):
+                Whether to quantize the tensor in place.
             **kwargs:
                 Other keyword arguments for the quantization kernel.
For example, ``inputs`` for the input tensors in GPTQ kernel, @@ -207,7 +214,13 @@ def _quantize( # noqa: C901 # endregion # region compute and quantize the scales and zero point for quantization quant_scale = QuantScale() - develop_tensor = tensor.to(dtype=develop_dtype) if dtype != develop_dtype else tensor.clone() + if in_place: + assert dtype == develop_dtype + develop_tensor = tensor + else: + if develop_dtype != torch.bfloat16: + raise RuntimeError(f"develop_dtype: {develop_dtype}") + develop_tensor = tensor.to(dtype=develop_dtype) if dtype != develop_dtype else tensor.clone() for step, (step_info, step_scale, step_dynamic_range) in enumerate( zip(self.info.steps, scale, dynamic_range, strict=True) ): @@ -247,8 +260,10 @@ def _quantize( # noqa: C901 round_delta=round_delta, **kwargs, ) - assert not develop_tensor.isnan().any(), "Quantized tensor contains NaN." - assert not develop_tensor.isinf().any(), "Quantized tensor contains Inf." + # assert not develop_tensor.isnan().any(), "Quantized tensor contains NaN." + # assert not develop_tensor.isinf().any(), "Quantized tensor contains Inf." + assert not _yes(develop_tensor, torch.isnan), "Quantized tensor contains NaN." + assert not _yes(develop_tensor, torch.isinf), "Quantized tensor contains Inf." # endregion # region update the quantized tensor quantized = None @@ -321,3 +336,14 @@ def update( config, tensor_shape, default_dtype, quant_range=quant_range, range_bound=range_bound ) return self.info + + +def _yes(tensor: torch.Tensor, judge_func: tp.Callable): + chunk_size = 50_000_000 + numel = tensor.numel() + flat = tensor.flatten() + for i in range(0, numel, chunk_size): + end = min(i + chunk_size, numel) + if judge_func(flat[i:end]).any(): + return True + return False \ No newline at end of file diff --git a/deepcompressor/quantizer/kernel/devices.py b/deepcompressor/quantizer/kernel/devices.py new file mode 100644 index 0000000..113ee40 --- /dev/null +++ b/deepcompressor/quantizer/kernel/devices.py @@ -0,0 +1,181 @@ + +import functools +import gc + +import torch + +from deepcompressor.data.cache import TensorCache +from deepcompressor.utils import tools + + +def _backup_deivces(): + return ["cuda:5", "cuda:6", "cuda:7"] + + +def try_all_devices(forced_dtype=None): + logger = tools.logging.getLogger(__name__) + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + # args_tensors = [] + # for arg in args: + # if isinstance(arg, torch.Tensor): + # args_tensors.append(arg) + # elif isinstance(arg, TensorCache): + # for _d in arg.data: + # args_tensors.append(_d) + # if len(args_tensors) > 0: + # _arg_device = args_tensors[0].device + # _arg_dtype = args_tensors[0].dtype + # assert all(t.device == _arg_device for t in args_tensors) + # assert all(t.dtype == _arg_dtype for t in args_tensors) + + # kwargs_tensors = [] + # for k, v in kwargs.items(): + # if isinstance(v, torch.Tensor): + # kwargs_tensors.append(v) + # elif isinstance(v, TensorCache): + # for _d in v.data: + # kwargs_tensors.append(_d) + # if len(kwargs_tensors) > 0: + # _kwarg_device = kwargs_tensors[0].device + # _kwarg_dtype = kwargs_tensors[0].dtype + # assert all(t.device == _kwarg_device for t in kwargs_tensors), f"{list(kwargs.keys())}, {[t.device for t in kwargs_tensors]}" + # assert all(t.dtype == _kwarg_dtype for t in kwargs_tensors), f"{list(kwargs.keys())}, {[t.dtype for t in kwargs_tensors]}" + + # if len(args_tensors) == 0 and len(kwargs_tensors) == 0: + # # no tensor in args and kwargs + # return func(*args, **kwargs) + # if 
len(args_tensors) > 0 and len(kwargs_tensors) > 0:
+            #     assert _arg_device == _kwarg_device
+            #     assert _arg_dtype == _kwarg_dtype
+
+            # orig_device = args_tensors[0].device if len(args_tensors) > 0 else kwargs_tensors[0].device
+            # orig_dtype = args_tensors[0].dtype if len(args_tensors) > 0 else kwargs_tensors[0].dtype
+
+            assert len(args) >= 1 and isinstance(args[0], torch.Tensor)
+            orig_device = args[0].device
+            orig_dtype = args[0].dtype
+
+            devices = [orig_device, *_backup_deivces()]
+            to_dtype = orig_dtype if forced_dtype is None else forced_dtype
+
+            tried_devices = []
+            class _PlaceHolder:
+                pass
+            ret = _PlaceHolder()
+
+            for idx, dev in enumerate(devices):
+                try:
+                    # if idx == 0:
+                    #     prev_args = args
+                    #     prev_kwargs = kwargs
+                    moved_args = []
+                    for arg in args:
+                        if isinstance(arg, torch.Tensor):
+                            moved_args.append(arg.to(device=dev, dtype=to_dtype))
+                            del arg
+                        # elif isinstance(arg, TensorCache):
+                        #     _arg = TensorCache(
+                        #         data=[_d.to(device=dev, dtype=to_dtype) for _d in arg.data],
+                        #         channels_dim=arg.channels_dim,
+                        #         reshape=arg.reshape,
+                        #         num_cached=arg.num_cached,
+                        #         num_total=arg.num_total,
+                        #         num_samples=arg.num_samples,
+                        #         orig_device=arg.orig_device
+                        #     )
+                        #     moved_args.append(_arg)
+                        else:
+                            moved_args.append(arg)
+                    moved_kwargs = {}
+                    for k, v in kwargs.items():
+                        if isinstance(v, torch.Tensor):
+                            moved_kwargs[k] = v.to(device=dev, dtype=to_dtype)
+                            del v
+                        # elif isinstance(v, TensorCache):
+                        #     moved_kwargs[k] = TensorCache(
+                        #         data=[_d.to(device=dev, dtype=to_dtype) for _d in v.data],
+                        #         channels_dim=v.channels_dim,
+                        #         reshape=v.reshape,
+                        #         num_cached=v.num_cached,
+                        #         num_total=v.num_total,
+                        #         num_samples=v.num_samples,
+                        #         orig_device=v.orig_device
+                        #     )
+                        else:
+                            moved_kwargs[k] = v
+                    # gc.collect()
+                    # torch.cuda.empty_cache()
+                    # torch.cuda.synchronize()
+                    ret = func(*moved_args, **moved_kwargs)
+                    break
+                except (torch.OutOfMemoryError, RuntimeError) as e:
+                    if isinstance(e, torch.OutOfMemoryError) or "CUDA out of memory" in str(e):
+                        logger.info(f"** OOM OCCURRED when calling {func.__name__} on device {dev}")
+                        tried_devices.append(dev)
+                        last_error = e
+                        # prev_args = moved_args
+                        # prev_kwargs = moved_kwargs
+                    else:
+                        raise e
+            if isinstance(ret, _PlaceHolder):
+                logger.info(f"** OOM OCCURRED on these devices {tried_devices}")
+                raise last_error
+            else:
+                if isinstance(ret, tuple):
+                    moved_ret = []
+                    # iterate over the returned tuple (not the builtin `tuple` type)
+                    for t in ret:
+                        if isinstance(t, torch.Tensor):
+                            moved_ret.append(t.to(device=orig_device, dtype=orig_dtype))
+                        else:
+                            moved_ret.append(t)
+                    return tuple(moved_ret)
+                elif isinstance(ret, torch.Tensor):
+                    return ret.to(device=orig_device, dtype=orig_dtype)
+                else:
+                    return ret
+
+        return wrapper
+
+    return decorator
+
+
+#### for demo ####
+def repeat(num_times=None):
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            for _ in range(num_times):
+                func(*args, **kwargs)
+        return wrapper
+    return decorator
+####
+
+if __name__ == "__main__":
+
+    @repeat(num_times=3)
+    def hello(name):
+        print(f"hello {name}")
+
+    hello("alice")
+
+
+    @try_all_devices()
+    def func1(x, *, y=4):
+        print(f"in func1, x={x}, y={y}, x.device={x.device}, x.dtype={x.dtype}")
+        if x.device == torch.device("cuda:4"):
+            raise RuntimeError("CUDA out of memory fake")
+        return x + y
+
+
+    t1 = torch.Tensor([1,2]).to(device="cuda:4", dtype=torch.bfloat16)
+    # t2 = torch.Tensor([3,4]).to("cuda:4")
+
+    t3 = func1(t1, y=5)
+    print(t3)
+    print(t3.dtype)
+
+
\ No newline at end of file
diff --git a/deepcompressor/quantizer/kernel/gptq.py
b/deepcompressor/quantizer/kernel/gptq.py index 274a9ba..176d737 100644 --- a/deepcompressor/quantizer/kernel/gptq.py +++ b/deepcompressor/quantizer/kernel/gptq.py @@ -8,6 +8,8 @@ import torch from omniconfig import configclass +from deepcompressor.quantizer.kernel.devices import try_all_devices + from ...data.cache import TensorCache from ...data.dtype import QuantDataType from ...data.range import QuantRange, RangeBound @@ -126,6 +128,7 @@ def quantize( @torch.no_grad() +@try_all_devices(forced_dtype=torch.float32) def gptq_quantize( # noqa: C901 tensor: torch.Tensor, *, @@ -214,18 +217,19 @@ def gptq_quantize( # noqa: C901 # endregion # region step 5: get the inverse of the Hessian matrix stable_inv, num_inv_tries = False, 0 + logger = tools.logging.getLogger(f"{__name__}.GPTQ") while (not stable_inv) and num_inv_tries < gptq_config.num_inv_tries: num_inv_tries += 1 try: hessian_inv = torch.linalg.cholesky(hessian) hessian_inv = torch.cholesky_inverse(hessian_inv) hessian_inv = torch.linalg.cholesky(hessian_inv, upper=True) - except RuntimeError: + except RuntimeError as err: + logger.info(f" - Hessian RuntimeError: {str(err)}") hessian_diag += (gptq_config.damp_percentage * 0.1) * hessian_diag_mean continue stable_inv = True if num_inv_tries > 1: - logger = tools.logging.getLogger(f"{__name__}.GPTQ") logger.debug(" - Hessian is not stable %s %d tries.", "until" if stable_inv else "after", num_inv_tries) assert not hessian_inv.isinf().any(), "Inverse of Hessian matrix contains Inf." assert not hessian_inv.isnan().any(), "Inverse of Hessian matrix contains NaN." diff --git a/deepcompressor/quantizer/kernel/rtn.py b/deepcompressor/quantizer/kernel/rtn.py index 587b2f7..e8680e2 100644 --- a/deepcompressor/quantizer/kernel/rtn.py +++ b/deepcompressor/quantizer/kernel/rtn.py @@ -3,6 +3,8 @@ import torch +from deepcompressor.quantizer.kernel.devices import try_all_devices + from ...data.dtype import QuantDataType from ...data.range import QuantRange from ...data.zero import ZeroPointDomain @@ -65,6 +67,7 @@ def quantize( ) +@try_all_devices() def rtn_quantize( tensor: torch.Tensor, *, @@ -104,7 +107,7 @@ def rtn_quantize( round_delta = round_delta.view(view_shape) if round_delta is not None else None if zero_domain == ZeroPointDomain.PostScale: qtensor = qtensor.add_(zero) - qtensor = qtensor.div(scale) + qtensor = qtensor.div_(scale) if zero_domain == ZeroPointDomain.PreScale: qtensor = qtensor.add_(zero) qtensor = simple_quantize( diff --git a/deepcompressor/quantizer/processor.py b/deepcompressor/quantizer/processor.py index 36a098b..e332b49 100644 --- a/deepcompressor/quantizer/processor.py +++ b/deepcompressor/quantizer/processor.py @@ -6,6 +6,9 @@ import torch +from deepcompressor.data.common import TensorType +from deepcompressor.utils import tools + from ..data.range import DynamicRange, QuantRange, RangeBound from ..data.tensor import QuantTensor from ..nn.patch.lowrank import LowRankBranch @@ -88,7 +91,11 @@ def get_output_packager(self) -> BaseOutputPackager | None: return self.output_packager def process(self, tensor: torch.Tensor) -> torch.Tensor: - return self.quantize(tensor).data + _in_place = hasattr(self, 'tensor_type') and self.tensor_type == TensorType.Inputs and self.key == "up_resblock_conv" + if _in_place: + logger = tools.logging.getLogger(__name__) + logger.info(f"** quantizer process key:{self.key}, type:{type(self)}, in_place:{_in_place}") + return self.quantize(tensor, in_place=_in_place).data def quantize( self, @@ -111,6 +118,7 @@ def quantize( quant_range: 
QuantRange | None | _MISSING_TYPE = MISSING, default_dtype: torch.dtype | None | _MISSING_TYPE = MISSING, develop_dtype: torch.dtype | _MISSING_TYPE = MISSING, + in_place: bool = False, **kwargs, ) -> QuantTensor: """Quantize a tensor. @@ -143,6 +151,8 @@ def quantize( The default scale dtype. develop_dtype (`torch.dtype` or `_MISSING_TYPE`, *optional*, defaults to `MISSING`): The quantization development dtype. + in_place (`bool`): + Whether to quantize the tensor in place. **kwargs: Other keyword arguments for the quantization kernel. For example, ``inputs`` for the input tensors in GPTQ kernel, @@ -179,6 +189,7 @@ def quantize( return_with_quant=return_with_quant, default_dtype=default_dtype, develop_dtype=develop_dtype, + in_place=in_place, **kwargs, ) diff --git a/deepcompressor/utils/hooks/branch.py b/deepcompressor/utils/hooks/branch.py index 274e197..e823359 100644 --- a/deepcompressor/utils/hooks/branch.py +++ b/deepcompressor/utils/hooks/branch.py @@ -23,7 +23,7 @@ def __init__( ): super().__init__(pre=True, post=True, input_packager=input_packager, output_packager=output_packager) self.branch = branch - self.tensor = None + # self.tensor = None def pre_forward( self, module: nn.Module, input_args: tuple[torch.Tensor, ...], input_kwargs: dict[str, tp.Any] @@ -37,7 +37,11 @@ def pre_forward( """ tensors = self.input_packager.unpack(module, input_args, input_kwargs) assert len(tensors) == 1, "BranchHook only supports single input tensor" - self.tensor = next(iter(tensors.values())) + # self.tensor = next(iter(tensors.values())) + _input_tensor = next(iter(tensors.values())) + # self.branch_output = self.branch(self.tensor) + if self.branch is not None: + self.branch_output = self.branch(_input_tensor) return None def post_forward( @@ -59,6 +63,7 @@ def post_forward( assert len(output_tensors) == 1, "LoRAHook only supports single output tensor" output_key, output_tensor = next(iter(output_tensors.items())) if self.branch is not None: - output_tensor = output_tensor + self.branch(self.tensor) - self.tensor = None + # output_tensor = output_tensor + self.branch(self.tensor) + output_tensor = output_tensor + self.branch_output + # self.tensor = None return self.output_packager.repack({output_key: output_tensor}, module, input_args, input_kwargs, output) diff --git a/examples/diffusion/configs/collect/qdiff.yaml.dongd b/examples/diffusion/configs/collect/qdiff.yaml.dongd new file mode 100644 index 0000000..edf0cf2 --- /dev/null +++ b/examples/diffusion/configs/collect/qdiff.yaml.dongd @@ -0,0 +1,5 @@ +collect: + root: /data/dongd/deepcompr/datasets + dataset_name: qdiff + data_path: examples/diffusion/prompts/qdiff.yaml + num_samples: 128 diff --git a/examples/diffusion/configs/model/flux.1-dev-custom.yaml b/examples/diffusion/configs/model/flux.1-dev-custom.yaml new file mode 100644 index 0000000..bfac75f --- /dev/null +++ b/examples/diffusion/configs/model/flux.1-dev-custom.yaml @@ -0,0 +1,70 @@ +cache: + root: /data/dongd/deepcompr/runs +output: + root: /data/dongd/deepcompr/runs +pipeline: + name: flux.1-dev-custom + device: "cuda:4" + dtype: torch.bfloat16 + path: /data/dongd/deepcompr/original-model/flux/f0a56ab51074043f7fb95e86f5246031e2424ea1a54b64a8948f6e612fb52fd6 +eval: + num_steps: 50 + guidance_scale: 3.5 + protocol: fmeuler{num_steps}-g{guidance_scale} + batch_size_per_gpu: 2 + num_gpus: 1 + num_samples: 100 + benchmarks: + - "/data/dongd/datasets/MJHQ30K.json" +quant: + calib: + batch_size: 16 + path: /data/dongd/deepcompr/datasets/{dtype}/{model}/{protocol}/{data}/s128 # 
s128 is for qdiff calib data samples + wgts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + low_rank: + sample_batch_size: 16 + sample_size: -1 + skips: + - embed + - resblock_shortcut + - resblock_time_proj + - transformer_proj_in + - transformer_proj_out + - down_sample + - up_sample + ipts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + skips: + - embed + - resblock_shortcut + - resblock_time_proj + - transformer_proj_in + - transformer_proj_out + - transformer_norm + - transformer_add_norm + - down_sample + - up_sample + opts: + calib_range: + element_batch_size: 64 + sample_batch_size: 16 + element_size: 512 + sample_size: -1 + smooth: + proj: + element_batch_size: -1 + sample_batch_size: 16 + element_size: -1 + sample_size: -1 + attn: + sample_batch_size: 16 + sample_size: -1 diff --git a/examples/diffusion/configs/model/sdxl-turbo.20250906_1029.yaml b/examples/diffusion/configs/model/sdxl-turbo.20250906_1029.yaml new file mode 100644 index 0000000..f4f42a6 --- /dev/null +++ b/examples/diffusion/configs/model/sdxl-turbo.20250906_1029.yaml @@ -0,0 +1,50 @@ +cache: + root: /data/dongd/deepcompr/runs +output: + root: /data/dongd/deepcompr/runs +pipeline: + name: sdxl-turbo + dtype: torch.bfloat16 +eval: + num_steps: 4 + guidance_scale: 0 + protocol: eulera{num_steps}-g{guidance_scale} + batch_size_per_gpu: 2 + num_gpus: 2 + num_samples: 100 + benchmarks: + - "MJHQ" +quant: + calib: + batch_size: 2 + combine: false + num_samples: 128 + path: /data/dongd/deepcompr/datasets/{dtype}/{model}/{protocol}/{data}/s128 # s128 is for qdiff calib data samples + wgts: + skips: + - embed + - resblock_shortcut + - resblock_time_proj + - resblock_conv + - transformer_proj_in + - transformer_proj_out + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + - down_sample + - up_sample + ipts: + skips: + - embed + - resblock_shortcut + - resblock_time_proj + - resblock_conv + - transformer_proj_in + - transformer_proj_out + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + - down_sample + - up_sample diff --git a/examples/diffusion/configs/model/sdxl-turbo.20250909_0510.yaml b/examples/diffusion/configs/model/sdxl-turbo.20250909_0510.yaml new file mode 100644 index 0000000..6734ad1 --- /dev/null +++ b/examples/diffusion/configs/model/sdxl-turbo.20250909_0510.yaml @@ -0,0 +1,50 @@ +cache: + root: /data/dongd/deepcompr/runs +output: + root: /data/dongd/deepcompr/runs +pipeline: + name: sdxl-turbo + dtype: torch.bfloat16 +eval: + num_steps: 4 + guidance_scale: 0 + protocol: eulera{num_steps}-g{guidance_scale} + batch_size_per_gpu: 2 + num_gpus: 1 + num_samples: 100 + benchmarks: + - "/data/dongd/datasets/MJHQ30K.json" +quant: + calib: + batch_size: 2 + combine: false + num_samples: 128 + path: /data/dongd/deepcompr/datasets/{dtype}/{model}/{protocol}/{data}/s128 # s128 is for qdiff calib data samples + wgts: + skips: + - embed + - resblock_shortcut + - resblock_time_proj + - resblock_conv + - transformer_proj_in + - transformer_proj_out + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + - down_sample + - up_sample + ipts: + skips: + - embed + - resblock_shortcut + - resblock_time_proj + - resblock_conv + - transformer_proj_in + - transformer_proj_out + - transformer_norm + - transformer_add_norm + - attn_add + - ffn_add + - down_sample + - up_sample diff --git a/examples/diffusion/configs/model/sdxl-turbo.yaml.my 
diff --git a/examples/diffusion/configs/model/sdxl-turbo.yaml.my b/examples/diffusion/configs/model/sdxl-turbo.yaml.my
new file mode 100644
index 0000000..6734ad1
--- /dev/null
+++ b/examples/diffusion/configs/model/sdxl-turbo.yaml.my
@@ -0,0 +1,50 @@
+cache:
+  root: /data/dongd/deepcompr/runs
+output:
+  root: /data/dongd/deepcompr/runs
+pipeline:
+  name: sdxl-turbo
+  dtype: torch.bfloat16
+eval:
+  num_steps: 4
+  guidance_scale: 0
+  protocol: eulera{num_steps}-g{guidance_scale}
+  batch_size_per_gpu: 2
+  num_gpus: 1
+  num_samples: 100
+  benchmarks:
+    - "/data/dongd/datasets/MJHQ30K.json"
+quant:
+  calib:
+    batch_size: 2
+    combine: false
+    num_samples: 128
+    path: /data/dongd/deepcompr/datasets/{dtype}/{model}/{protocol}/{data}/s128 # s128 is for qdiff calib data samples
+  wgts:
+    skips:
+      - embed
+      - resblock_shortcut
+      - resblock_time_proj
+      - resblock_conv
+      - transformer_proj_in
+      - transformer_proj_out
+      - transformer_norm
+      - transformer_add_norm
+      - attn_add
+      - ffn_add
+      - down_sample
+      - up_sample
+  ipts:
+    skips:
+      - embed
+      - resblock_shortcut
+      - resblock_time_proj
+      - resblock_conv
+      - transformer_proj_in
+      - transformer_proj_out
+      - transformer_norm
+      - transformer_add_norm
+      - attn_add
+      - ffn_add
+      - down_sample
+      - up_sample
diff --git a/examples/diffusion/configs/model/sdxl.yaml b/examples/diffusion/configs/model/sdxl.yaml
new file mode 100644
index 0000000..bcb7806
--- /dev/null
+++ b/examples/diffusion/configs/model/sdxl.yaml
@@ -0,0 +1,51 @@
+cache:
+  root: /data/dongd/deepcompr/runs
+output:
+  root: /data/dongd/deepcompr/runs
+pipeline:
+  name: sdxl
+  dtype: torch.bfloat16
+  device: "cuda:4"
+eval:
+  num_steps: 50
+  guidance_scale: 5.0
+  protocol: EulerDiscrete{num_steps}-g{guidance_scale}
+  batch_size_per_gpu: 2
+  num_gpus: 1
+  num_samples: 100
+  benchmarks:
+    - "/data/dongd/datasets/MJHQ30K.json"
+quant:
+  calib:
+    batch_size: 2
+    combine: false
+    num_samples: 128
+    path: /data/dongd/deepcompr/datasets/{dtype}/{model}/{protocol}/{data}/s128 # s128 is for qdiff calib data samples
+  wgts:
+    skips:
+      - embed
+      - resblock_shortcut
+      - resblock_time_proj
+      - resblock_conv
+      - transformer_proj_in
+      - transformer_proj_out
+      - transformer_norm
+      - transformer_add_norm
+      - attn_add
+      - ffn_add
+      - down_sample
+      - up_sample
+  ipts:
+    skips:
+      - embed
+      - resblock_shortcut
+      - resblock_time_proj
+      - resblock_conv
+      - transformer_proj_in
+      - transformer_proj_out
+      - transformer_norm
+      - transformer_add_norm
+      - attn_add
+      - ffn_add
+      - down_sample
+      - up_sample
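sdxl-turbo.yaml.my and sdxl-turbo.20250909_0510.yaml are added with the same blob hash (6734ad1), so the two files should be byte-identical copies; a quick standalone check (not part of the repo tooling):

    import filecmp

    assert filecmp.cmp(
        "examples/diffusion/configs/model/sdxl-turbo.20250909_0510.yaml",
        "examples/diffusion/configs/model/sdxl-turbo.yaml.my",
        shallow=False,  # compare file contents, not just os.stat signatures
    )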
diff --git a/examples/diffusion/configs/model/sdxl_conv_no_skip.yaml b/examples/diffusion/configs/model/sdxl_conv_no_skip.yaml
new file mode 100644
index 0000000..4c06c17
--- /dev/null
+++ b/examples/diffusion/configs/model/sdxl_conv_no_skip.yaml
@@ -0,0 +1,71 @@
+cache:
+  root: /data/dongd/deepcompr/runs
+output:
+  root: /data/dongd/deepcompr/runs
+pipeline:
+  name: sdxl
+  dtype: torch.bfloat16
+  device: "cuda:4"
+eval:
+  num_steps: 50
+  guidance_scale: 5.0
+  protocol: EulerDiscrete{num_steps}-g{guidance_scale}
+  batch_size_per_gpu: 2
+  num_gpus: 1
+  num_samples: 100
+  benchmarks:
+    - "/data/dongd/datasets/MJHQ30K.json"
+quant:
+  develop_dtype: torch.bfloat16
+  calib:
+    batch_size: 2
+    combine: false
+    num_samples: 128
+    path: /data/dongd/deepcompr/datasets/{dtype}/{model}/{protocol}/{data}/s128 # s128 is for qdiff calib data samples
+  wgts:
+    skips:
+      - embed
+      - resblock_shortcut
+      - resblock_time_proj
+      - transformer_proj_in
+      - transformer_proj_out
+      - transformer_norm
+      - transformer_add_norm
+      - attn_add
+      - ffn_add
+      - down_sample
+      - up_sample
+    low_rank:
+      skips:
+        - embed
+        - transformer_proj_in
+        - transformer_proj_out
+        - transformer_norm
+        - transformer_add_norm
+        - down_sample
+        - up_sample
+  ipts:
+    skips:
+      - embed
+      - resblock_shortcut
+      - resblock_time_proj
+      - transformer_proj_in
+      - transformer_proj_out
+      - transformer_norm
+      - transformer_add_norm
+      - attn_add
+      - ffn_add
+      - down_sample
+      - up_sample
+  smooth:
+    proj:
+      skips:
+        - embed
+        - resblock_shortcut
+        - resblock_time_proj
+        - transformer_proj_in
+        - transformer_proj_out
+        - transformer_norm
+        - transformer_add_norm
+        - down_sample
+        - up_sample
diff --git a/examples/diffusion/configs/model/sdxl_conv_only.yaml b/examples/diffusion/configs/model/sdxl_conv_only.yaml
new file mode 100644
index 0000000..3690531
--- /dev/null
+++ b/examples/diffusion/configs/model/sdxl_conv_only.yaml
@@ -0,0 +1,79 @@
+cache:
+  root: /data/dongd/deepcompr/runs
+output:
+  root: /data/dongd/deepcompr/runs
+pipeline:
+  name: sdxl
+  dtype: torch.bfloat16
+  device: "cuda:4"
+eval:
+  num_steps: 50
+  guidance_scale: 5.0
+  protocol: EulerDiscrete{num_steps}-g{guidance_scale}
+  batch_size_per_gpu: 2
+  num_gpus: 1
+  num_samples: 100
+  benchmarks:
+    - "/data/dongd/datasets/MJHQ30K.json"
+quant:
+  develop_dtype: torch.bfloat16
+  calib:
+    batch_size: 2
+    combine: false
+    num_samples: 128
+    path: /data/dongd/deepcompr/datasets/{dtype}/{model}/{protocol}/{data}/s128 # s128 is for qdiff calib data samples
+  wgts:
+    skips:
+      - embed
+      - attn
+      - ffn
+      - resblock_shortcut
+      - resblock_time_proj
+      - transformer_proj_in
+      - transformer_proj_out
+      - transformer_norm
+      - transformer_add_norm
+      - attn_add
+      - ffn_add
+      - down_sample
+      - up_sample
+    low_rank:
+      skips:
+        - embed
+        - attn
+        - ffn
+        - transformer_proj_in
+        - transformer_proj_out
+        - transformer_norm
+        - transformer_add_norm
+        - down_sample
+        - up_sample
+  ipts:
+    skips:
+      - embed
+      - attn
+      - ffn
+      - resblock_shortcut
+      - resblock_time_proj
+      - transformer_proj_in
+      - transformer_proj_out
+      - transformer_norm
+      - transformer_add_norm
+      - attn_add
+      - ffn_add
+      - down_sample
+      - up_sample
+  smooth:
+    proj:
+      skips:
+        - embed
+        - attn
+        - ffn
+        - resblock_shortcut
+        - resblock_time_proj
+        - transformer_proj_in
+        - transformer_proj_out
+        - transformer_norm
+        - transformer_add_norm
+        - down_sample
+        - up_sample
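The three SDXL configs differ mainly in their skip lists: sdxl.yaml skips resblock_conv, sdxl_conv_no_skip.yaml removes that skip so the ResNet-block convolutions are quantized, and sdxl_conv_only.yaml additionally skips attn and ffn so that essentially only those convolutions remain quantized. A small standalone helper (hypothetical, not part of the repo) can make such differences explicit by loading the YAML files and diffing the skip sets; it assumes the quant.<section>.skips nesting used above:

    import yaml  # PyYAML

    def load_skips(path: str, section: str = "wgts") -> set[str]:
        """Return quant.<section>.skips of a model config as a set."""
        with open(path) as f:
            cfg = yaml.safe_load(f)
        return set(cfg["quant"][section].get("skips", []))

    base = load_skips("examples/diffusion/configs/model/sdxl.yaml")
    conv = load_skips("examples/diffusion/configs/model/sdxl_conv_no_skip.yaml")
    conv_only = load_skips("examples/diffusion/configs/model/sdxl_conv_only.yaml")

    print(base - conv)       # expected: {'resblock_conv'}
    print(conv_only - conv)  # expected: {'attn', 'ffn'}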