@@ -14,7 +14,6 @@

 import inspect
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-from packaging import version
 import torch
 from transformers import (
     CLIPImageProcessor,
@@ -57,6 +56,7 @@
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from huggingface_hub.utils import validate_hf_hub_args

 if is_invisible_watermark_available():
     from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
@@ -87,8 +87,128 @@
 ```
 """

+CONTROLNEXT_WEIGHT_NAME = "controlnet.bin"
+CONTROLNEXT_WEIGHT_NAME_SAFE = "controlnet.safetensors"
+UNET_WEIGHT_NAME = "unet.bin"
+UNET_WEIGHT_NAME_SAFE = "unet.safetensors"
+
+
+# Copied from https://github.com/kohya-ss/sd-scripts/blob/main/library/sdxl_model_util.py
+
+def is_sdxl_state_dict(state_dict):
+    return any(key.startswith('input_blocks') for key in state_dict.keys())
+
+
+def convert_sdxl_unet_state_dict_to_diffusers(sd):
+    unet_conversion_map = make_unet_conversion_map()
+
+    conversion_dict = {sd_key: hf_key for sd_key, hf_key in unet_conversion_map}
+    return convert_unet_state_dict(sd, conversion_dict)
+
+
+def convert_unet_state_dict(src_sd, conversion_map):
+    converted_sd = {}
+    for src_key, value in src_sd.items():
+        src_key_fragments = src_key.split(".")[:-1]  # remove weight/bias
+        while len(src_key_fragments) > 0:
+            src_key_prefix = ".".join(src_key_fragments) + "."
+            if src_key_prefix in conversion_map:
+                converted_prefix = conversion_map[src_key_prefix]
+                converted_key = converted_prefix + src_key[len(src_key_prefix):]
+                converted_sd[converted_key] = value
+                break
+            src_key_fragments.pop(-1)
+        assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map"
+
+    return converted_sd
+
+
+def make_unet_conversion_map():
+    unet_conversion_map_layer = []
+
+    for i in range(3):  # num_blocks is 3 in sdxl
+        # loop over downblocks/upblocks
+        for j in range(2):
+            # loop over resnets/attentions for downblocks
+            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+            if i < 3:
+                # no attention layers in down_blocks.3
+                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+        for j in range(3):
+            # loop over resnets/attentions for upblocks
+            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+            # the `if i > 0` guard from the SD 1.x script is removed for sdxl
+            # (prefixes for blocks without attention are simply never matched)
+            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+        if i < 3:
+            # no downsample in down_blocks.3
+            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+            # no upsample in up_blocks.3
+            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}."  # change for sdxl
+            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+    hf_mid_atn_prefix = "mid_block.attentions.0."
+    sd_mid_atn_prefix = "middle_block.1."
+    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+    for j in range(2):
+        hf_mid_res_prefix = f"mid_block.resnets.{j}."
+        sd_mid_res_prefix = f"middle_block.{2*j}."
+        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+    unet_conversion_map_resnet = [
+        # (stable-diffusion, HF Diffusers)
+        ("in_layers.0.", "norm1."),
+        ("in_layers.2.", "conv1."),
+        ("out_layers.0.", "norm2."),
+        ("out_layers.3.", "conv2."),
+        ("emb_layers.1.", "time_emb_proj."),
+        ("skip_connection.", "conv_shortcut."),
+    ]
+
+    unet_conversion_map = []
+    for sd, hf in unet_conversion_map_layer:
+        if "resnets" in hf:
+            for sd_res, hf_res in unet_conversion_map_resnet:
+                unet_conversion_map.append((sd + sd_res, hf + hf_res))
+        else:
+            unet_conversion_map.append((sd, hf))
+
+    for j in range(2):
+        hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
+        sd_time_embed_prefix = f"time_embed.{j*2}."
+        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
+
+    for j in range(2):
+        hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
+        sd_label_embed_prefix = f"label_emb.0.{j*2}."
+        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
+
+    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
+    unet_conversion_map.append(("out.0.", "conv_norm_out."))
+    unet_conversion_map.append(("out.2.", "conv_out."))
+
+    return unet_conversion_map

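To see what the prefix matching above does, here is a minimal sketch (illustrative only, not part of the patch) that converts a single SDXL-layout key; the sample tensor shape is arbitrary:

    # is_sdxl_state_dict detects the SD/SDXL key layout via the input_blocks prefix
    sample = {"input_blocks.1.0.in_layers.0.weight": torch.zeros(320)}
    assert is_sdxl_state_dict(sample)

    # make_unet_conversion_map() contains the pair
    # ("input_blocks.1.0.in_layers.0.", "down_blocks.0.resnets.0.norm1."),
    # so the longest-prefix match rewrites the key as follows:
    converted = convert_sdxl_unet_state_dict_to_diffusers(sample)
    assert "down_blocks.0.resnets.0.norm1.weight" in converted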
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+
+
 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     """
     Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -280,6 +400,156 @@ def __init__(
         else:
             self.watermark = None

+    def load_controlnext_weights(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        load_weight_increasement: bool = False,
+        **kwargs,
+    ):
+        self.load_controlnext_unet_weights(pretrained_model_name_or_path_or_dict, load_weight_increasement, **kwargs)
+        kwargs['torch_dtype'] = torch.float32
+        self.load_controlnext_controlnet_weights(pretrained_model_name_or_path_or_dict, **kwargs)
+
+    def load_controlnext_unet_weights(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        load_weight_increasement: bool = False,
+        **kwargs,
+    ):
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        state_dict = self.controlnext_unet_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        if is_sdxl_state_dict(state_dict):
+            state_dict = convert_sdxl_unet_state_dict_to_diffusers(state_dict)
+
+        logger.info("Loading ControlNeXt UNet" + (" with weight increasement." if load_weight_increasement else "."))
+        if load_weight_increasement:
+            unet_sd = self.unet.state_dict()
+            for k in state_dict.keys():
+                state_dict[k] = state_dict[k] + unet_sd[k]
+        self.unet.load_state_dict(state_dict, strict=False)
+
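When load_weight_increasement is set, the fetched checkpoint is treated as weight deltas: the loop above adds each tensor onto the corresponding base UNet weight before the non-strict load. A toy sketch of that merge, with made-up tensors:

    base = {"w": torch.tensor([1.0, 2.0])}    # stands in for self.unet.state_dict()
    delta = {"w": torch.tensor([0.5, -0.5])}  # stands in for the fetched checkpoint
    merged = {k: delta[k] + base[k] for k in delta}
    assert torch.equal(merged["w"], torch.tensor([1.5, 1.5]))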
+    @classmethod
+    @validate_hf_hub_args
+    def controlnext_unet_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        if 'weight_name' not in kwargs:
+            kwargs['weight_name'] = UNET_WEIGHT_NAME_SAFE if kwargs.get('use_safetensors', False) else UNET_WEIGHT_NAME
+        return cls.controlnext_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+    def load_controlnext_controlnet_weights(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        if self.controlnet is None:
+            raise ValueError("No ControlNeXt ControlNet found in the pipeline.")
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        state_dict = self.controlnext_controlnet_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        logger.info("Loading ControlNeXt ControlNet.")
+        self.controlnet.load_state_dict(state_dict, strict=True)
+
+    @classmethod
+    @validate_hf_hub_args
+    def controlnext_controlnet_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        if 'weight_name' not in kwargs:
+            kwargs['weight_name'] = CONTROLNEXT_WEIGHT_NAME_SAFE if kwargs.get('use_safetensors', False) else CONTROLNEXT_WEIGHT_NAME
+        return cls.controlnext_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+    @classmethod
+    @validate_hf_hub_args
+    def controlnext_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return the state dict of the ControlNeXt weights.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+            weight_name (`str`, *optional*, defaults to `None`):
+                Name of the serialized state dict file.
+        """
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        unet_config = kwargs.pop("unet_config", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        state_dict = cls._fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        return state_dict
+
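Since these helpers are classmethods, a checkpoint can also be inspected without assembling a full pipeline. A sketch under stated assumptions: the repo id is a placeholder, and the pipeline class name is assumed to be the one this file defines:

    # "user/controlnext-sdxl-demo" is a hypothetical repo id; with use_safetensors=True
    # the helper defaults weight_name to "controlnet.safetensors" (see above)
    sd = StableDiffusionXLControlNeXtPipeline.controlnext_controlnet_state_dict(
        "user/controlnext-sdxl-demo", use_safetensors=True
    )
    print(list(sd.keys())[:5])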
     def prepare_image(
         self,
         image,
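End to end, the new loaders are meant to be called on an assembled pipeline. A hedged usage sketch: the ControlNeXt repo id is a placeholder, the pipeline class name is assumed as above, and `controlnet` is a previously constructed ControlNeXt ControlNet module:

    pipe = StableDiffusionXLControlNeXtPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet
    )
    # fetches unet.safetensors (optionally applied as weight deltas), then
    # controlnet.safetensors, which load_controlnext_weights forces to float32
    pipe.load_controlnext_weights(
        "user/controlnext-sdxl-demo",
        load_weight_increasement=True,
        use_safetensors=True,
    )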