64 changes: 63 additions & 1 deletion deepcompressor/app/diffusion/nn/patch.py
@@ -2,18 +2,22 @@
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock

from deepcompressor.nn.patch.conv import ConcatConv2d, ShiftedConv2d, Conv2dAsLinear
from deepcompressor.nn.patch.linear import ConcatLinear, ShiftedLinear
from deepcompressor.utils import patch, tools

from .attention import DiffusionAttentionProcessor
from .struct import DiffusionFeedForwardStruct, DiffusionModelStruct, DiffusionResnetStruct, UNetStruct
from diffusers.pipelines import StableDiffusionXLPipeline
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.models.resnet import ResnetBlock2D

__all__ = [
"replace_up_block_conv_with_concat_conv",
"replace_fused_linear_with_concat_linear",
"replace_attn_processor",
"shift_input_activations",
"replace_conv2d_with_equivalent_linear_in_sdxl",
]


@@ -52,6 +56,64 @@ def replace_up_block_conv_with_concat_conv(model: nn.Module) -> None:
    tools.logging.Formatter.indent_dec()
    tools.logging.Formatter.indent_dec()
    tools.logging.Formatter.indent_dec()


def replace_conv2d_with_equivalent_linear_in_sdxl(pipeline: StableDiffusionXLPipeline) -> None:
    """Replace Conv2d layers in the SDXL UNet2DConditionModel with equivalent linear layers."""
    if not isinstance(pipeline, StableDiffusionXLPipeline):
        return
    logger = tools.logging.getLogger(__name__)
    logger.info("Replacing Conv2d with equivalent linear.")
    tools.logging.Formatter.indent_inc()
    assert isinstance(pipeline.unet, UNet2DConditionModel)
    model: UNet2DConditionModel = pipeline.unet
    # logger.info("replacing conv_in")
    # model.conv_in = Conv2dAsLinear(model.conv_in)
    logger.info("+ replacing down_blocks")

    tools.logging.Formatter.indent_inc()
    for i, down_block in enumerate(model.down_blocks):
        assert hasattr(down_block, "resnets")
        assert isinstance(getattr(down_block, "resnets"), nn.ModuleList)
        for j, resnet in enumerate(getattr(down_block, "resnets")):
            assert isinstance(resnet, ResnetBlock2D)
            logger.info(f"- replacing down_blocks.{i}.resnets.{j}.conv1")
            resnet.conv1 = Conv2dAsLinear(resnet.conv1)
            logger.info(f"- replacing down_blocks.{i}.resnets.{j}.conv2")
            resnet.conv2 = Conv2dAsLinear(resnet.conv2)
            # logger.info(f"- replacing down_blocks.{i}.resnets.{j}.conv_shortcut")
            # resnet.conv_shortcut = _convert(resnet.conv_shortcut)
    tools.logging.Formatter.indent_dec()

    logger.info("+ replacing mid_block")
    assert hasattr(model.mid_block, "resnets")
    assert isinstance(getattr(model.mid_block, "resnets"), nn.ModuleList)
    tools.logging.Formatter.indent_inc()
    for j, resnet in enumerate(getattr(model.mid_block, "resnets")):
        assert isinstance(resnet, ResnetBlock2D)
        logger.info(f"- replacing mid_block.resnets.{j}.conv1")
        resnet.conv1 = Conv2dAsLinear(resnet.conv1)
        logger.info(f"- replacing mid_block.resnets.{j}.conv2")
        resnet.conv2 = Conv2dAsLinear(resnet.conv2)
    tools.logging.Formatter.indent_dec()

    logger.info("+ replacing up_blocks")
    tools.logging.Formatter.indent_inc()
    for i, up_block in enumerate(model.up_blocks):
        assert hasattr(up_block, "resnets")
        assert isinstance(getattr(up_block, "resnets"), nn.ModuleList)
        for j, resnet in enumerate(getattr(up_block, "resnets")):
            assert isinstance(resnet, ResnetBlock2D)
            logger.info(f"- replacing up_blocks.{i}.resnets.{j}.conv1")
            resnet.conv1 = Conv2dAsLinear(resnet.conv1)
            logger.info(f"- replacing up_blocks.{i}.resnets.{j}.conv2")
            resnet.conv2 = Conv2dAsLinear(resnet.conv2)
            # logger.info(f"- replacing up_blocks.{i}.resnets.{j}.conv_shortcut")
            # resnet.conv_shortcut = _convert(resnet.conv_shortcut)
    tools.logging.Formatter.indent_dec()
    # logger.info("replacing conv_out")
    # model.conv_out = Conv2dAsLinear(model.conv_out)
    tools.logging.Formatter.indent_dec()

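# --- Reviewer sketch (assumption, not part of this diff) ----------------------
# Conv2dAsLinear is assumed to rewrite a k x k Conv2d as a single nn.Linear over
# unfolded input patches, so linear-only quantization passes can reach the conv
# weights; the struct.py and activation.py hunks below confirm the wrapper
# exposes its inner module as `.linear`. A minimal version under that assumption
# (`nn` comes from this module's imports; the class name is hypothetical):
import torch
import torch.nn.functional as F


class Conv2dAsLinearSketch(nn.Module):
    def __init__(self, conv: nn.Conv2d) -> None:
        super().__init__()
        self.kernel_size, self.stride, self.padding = conv.kernel_size, conv.stride, conv.padding
        self.linear = nn.Linear(
            conv.in_channels * conv.kernel_size[0] * conv.kernel_size[1],
            conv.out_channels,
            bias=conv.bias is not None,
        )
        with torch.no_grad():
            # Flatten conv weight [C_out, C_in, kh, kw] -> [C_out, C_in * kh * kw].
            self.linear.weight.copy_(conv.weight.flatten(1))
            if conv.bias is not None:
                self.linear.bias.copy_(conv.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n, _, h, w = x.shape
        # Extract sliding patches: [N, C_in * kh * kw, L] with L = h_out * w_out.
        patches = F.unfold(x, self.kernel_size, padding=self.padding, stride=self.stride)
        out = self.linear(patches.transpose(1, 2))  # [N, L, C_out]
        h_out = (h + 2 * self.padding[0] - self.kernel_size[0]) // self.stride[0] + 1
        w_out = (w + 2 * self.padding[1] - self.kernel_size[1]) // self.stride[1] + 1
        return out.transpose(1, 2).reshape(n, -1, h_out, w_out)


# Quick equivalence check for the sketch:
#   conv = nn.Conv2d(4, 8, 3, padding=1)
#   x = torch.randn(2, 4, 16, 16)
#   assert torch.allclose(conv(x), Conv2dAsLinearSketch(conv)(x), atol=1e-5)
# -------------------------------------------------------------------------------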

def replace_fused_linear_with_concat_linear(model: nn.Module) -> None:
21 changes: 16 additions & 5 deletions deepcompressor/app/diffusion/nn/struct.py
@@ -58,7 +58,7 @@
    StableDiffusionXLPipeline,
)

from deepcompressor.nn.patch.conv import ConcatConv2d, ShiftedConv2d, Conv2dAsLinear
from deepcompressor.nn.patch.linear import ConcatLinear, ShiftedLinear
from deepcompressor.nn.struct.attn import (
    AttentionConfigStruct,
@@ -70,6 +70,7 @@
)
from deepcompressor.nn.struct.base import BaseModuleStruct
from deepcompressor.utils.common import join_name
from deepcompressor.utils import tools

from .attention import DiffusionAttentionProcessor

@@ -966,6 +967,7 @@ def construct(
        idx: int = 0,
        **kwargs,
    ) -> "DiffusionResnetStruct":
        logger = tools.logging.getLogger(__name__)
        if isinstance(module, ResnetBlock2D):
            assert module.upsample is None, "upsample must be None"
            assert module.downsample is None, "downsample must be None"
@@ -986,6 +988,9 @@
                shifted = True
                conv1_convs = [module.conv1.conv]
                conv1_names = ["conv1.conv"]
            elif isinstance(module.conv1, Conv2dAsLinear):
                conv1_convs, conv1_names = [module.conv1], ["conv1"]
                logger.info(
                    f"construct conv1 Conv2dAsLinear in DiffusionResnetStruct, "
                    f"fname={fname}, rname={rname}, rkey={rkey}, idx={idx}"
                )
            else:
                assert isinstance(module.conv1, nn.Conv2d)
                conv1_convs, conv1_names = [module.conv1], ["conv1"]
@@ -1004,8 +1009,11 @@
                shifted = True
                conv2_convs = [module.conv2.conv]
                conv2_names = ["conv2.conv"]
            elif isinstance(module.conv2, Conv2dAsLinear):
                conv2_convs, conv2_names = [module.conv2], ["conv2"]
                logger.info(
                    f"construct conv2 Conv2dAsLinear in DiffusionResnetStruct, "
                    f"fname={fname}, rname={rname}, rkey={rkey}, idx={idx}"
                )
            else:
                assert isinstance(module.conv2, (nn.Conv2d, Conv2dAsLinear))
                conv2_convs, conv2_names = [module.conv2], ["conv2"]
            convs, conv_rnames = [conv1_convs, conv2_convs], [conv1_names, conv2_names]
            norms, norm_rnames = [module.norm1, module.norm2], ["norm1", "norm2"]
@@ -1017,8 +1025,8 @@
        else:
            raise NotImplementedError(f"Unsupported module type: {type(module)}")
        first_conv = convs[0][0]
        # Conv2dAsLinear keeps its weight on the wrapped `.linear`; a plain Conv2d keeps it on itself.
        conv_weight = first_conv.linear.weight if isinstance(first_conv, Conv2dAsLinear) else first_conv.weight
        config = FeedForwardConfigStruct(
            hidden_size=conv_weight.shape[1],  # TODO: verify
            intermediate_size=conv_weight.shape[0],  # TODO: verify
            intermediate_act_type=act_type,
            num_experts=1,
        )
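        # Reviewer note (assumption, not verified here): if Conv2dAsLinear flattens
        # k x k patches, conv_weight.shape[1] is C_in * kh * kw rather than the channel
        # count C_in, so the TODOs above should check what hidden_size is meant to hold.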
@@ -1518,6 +1526,7 @@ def _default_construct(
    @classmethod
    def _get_default_key_map(cls) -> dict[str, set[str]]:
        """Get the default allowed keys."""
        logger = tools.logging.getLogger(__name__)
        key_map: dict[str, set[str]] = defaultdict(set)
        for idx, (block_key, block_cls) in enumerate(
            (
@@ -1561,7 +1570,9 @@ def _get_default_key_map(cls) -> dict[str, set[str]]:
            if key in key_map:
                key_map[key].clear()
                key_map[key].add(key)
        ret = {k: v for k, v in key_map.items() if v}
        logger.info(f">> key_map: {ret}")
        return ret


@dataclass(kw_only=True)
18 changes: 18 additions & 0 deletions deepcompressor/app/diffusion/pipeline/config.py
@@ -2,6 +2,8 @@
"""Diffusion pipeline configuration module."""

import gc
import json
import os
import typing as tp
from dataclasses import dataclass, field

@@ -11,12 +13,16 @@
    DiffusionPipeline,
    FluxControlPipeline,
    FluxFillPipeline,
    FluxPipeline,
    SanaPipeline,
)
from diffusers.models.transformers import FluxTransformer2DModel
from omniconfig import configclass
from torch import nn
from transformers import PreTrainedModel, PreTrainedTokenizer, T5EncoderModel
from safetensors.torch import load_file

from deepcompressor.app.diffusion.pipeline.utils.mapping import comfy_to_diffusers
from deepcompressor.data.utils.dtype import eval_dtype
from deepcompressor.quantizer.processor import Quantizer
from deepcompressor.utils import tools
@@ -28,6 +34,7 @@
    replace_fused_linear_with_concat_linear,
    replace_up_block_conv_with_concat_conv,
    shift_input_activations,
    replace_conv2d_with_equivalent_linear_in_sdxl,
)

__all__ = ["DiffusionPipelineConfig"]
@@ -357,6 +364,16 @@ def _default_build(
                pipeline.text_encoder.to(dtype)
            else:
                pipeline = SanaPipeline.from_pretrained(path, torch_dtype=dtype)
        elif name == "flux.1-dev-custom":
            with open(os.path.join(path, "config.json"), "r", encoding="utf-8") as f:
                json_config = json.load(f)
            transformer = FluxTransformer2DModel.from_config(json_config).to(dtype)
            checkpoint_file = os.path.basename(path) + ".safetensors"
            state_dict = comfy_to_diffusers(load_file(os.path.join(path, checkpoint_file)))
            transformer.load_state_dict(state_dict)
            pipeline = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
            )
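            # Expected on-disk layout for "flux.1-dev-custom" (inferred from the code above):
            #   <path>/config.json                    -> FluxTransformer2DModel config
            #   <path>/<basename(path)>.safetensors   -> ComfyUI-format transformer weights
            # e.g. path="/ckpts/flux-custom" loads "/ckpts/flux-custom/flux-custom.safetensors".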
        else:
            pipeline = AutoPipelineForText2Image.from_pretrained(path, torch_dtype=dtype)
        pipeline = pipeline.to(device)
@@ -365,6 +382,7 @@
        # The quantization and inference for the resblock_conv layer are not completed,
        # so we do not apply any preprocessing to the conv layers here.
        # replace_up_block_conv_with_concat_conv(model)
        replace_conv2d_with_equivalent_linear_in_sdxl(pipeline)
        if shift_activations:
            shift_input_activations(model)
        return pipeline
3 changes: 3 additions & 0 deletions deepcompressor/app/diffusion/ptq.py
@@ -383,6 +383,9 @@ def main(config: DiffusionPtqRunConfig, logging_level: int = tools.logging.DEBUG
    try:
        main(config, logging_level=tools.logging.DEBUG)
    except Exception as e:
        torch.cuda.synchronize()
        torch.cuda.memory._dump_snapshot("/data/dongd/ptq_err_snapshot.pickle")
        torch.cuda.memory._record_memory_history(enabled=None)
        tools.logging.Formatter.indent_reset()
        tools.logging.error("=== Error ===")
        tools.logging.error(traceback.format_exc())
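        # Reviewer sketch (assumption, not part of this diff): _dump_snapshot only
        # contains allocation history if recording was enabled before the failing
        # workload ran, presumably near process start, e.g.:
        #
        #   if torch.cuda.is_available():
        #       torch.cuda.memory._record_memory_history(max_entries=100_000)
        #
        # The _record_memory_history(enabled=None) call above then stops recording.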
13 changes: 13 additions & 0 deletions deepcompressor/app/diffusion/quant/activation.py
@@ -10,7 +10,9 @@

from deepcompressor.data.cache import IOTensorsCache
from deepcompressor.data.common import TensorType
from deepcompressor.nn.patch.conv import Conv2dAsLinear
from deepcompressor.utils import tools
from deepcompressor.utils.common import join_name

from ..nn.struct import (
    DiffusionAttentionStruct,
@@ -124,11 +126,20 @@ def quantize_diffusion_block_activations(  # noqa: C901
            assert field_name == "add_k_proj"
            assert module_name == parent.add_k_proj_name
            modules, module_names = parent.add_qkv_proj, parent.add_qkv_proj_names
        elif isinstance(module, Conv2dAsLinear):
            # Quantize the wrapped inner linear, but keep evaluating through the wrapper.
            modules = [module.linear]
            module_names = [join_name(module_name, "linear")]
            eval_module = module
            eval_name = module_name
            eval_kwargs = layer_kwargs if layer_kwargs else {}
        if modules is None:
            assert module not in used_modules
            used_modules.add(module)
            orig_wgts = [(module.weight, orig_state_dict[f"{module_name}.weight"])] if orig_state_dict else None
            args_caches.append((module_key, In, [module], [module_name], module, module_name, None, orig_wgts))
        elif isinstance(module, Conv2dAsLinear):
            orig_wgts = (
                [(module.linear.weight, orig_state_dict[f"{join_name(module_name, 'linear')}.weight"])]
                if orig_state_dict
                else None
            )
            args_caches.append((module_key, In, modules, module_names, eval_module, eval_name, eval_kwargs, orig_wgts))
        else:
            orig_wgts = []
            for proj_module in modules:
@@ -166,6 +177,8 @@
                key=module_key,
                tensor_type=tensor_type,
            )
            if isinstance(eval_module, Conv2dAsLinear):
                logger.info(
                    f"[calib_dynamic_range] Conv2dAsLinear eval_name: {eval_name}, "
                    f"module_names: {module_names}, quantizer.is_enabled: {quantizer.is_enabled()}"
                )
            if quantizer.is_enabled():
                if cache_keys[0] not in quantizer_state_dict:
                    logger.debug("- Calibrating %s", ", ".join(cache_keys))