Skip to content

Commit 50011f5

Browse files
committed
simplified model loading from pipeline
1 parent 901940f commit 50011f5

File tree

10 files changed

+626
-224
lines changed

10 files changed

+626
-224
lines changed

ControlNeXt-SDXL-Training/README.md

+4-5
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ The output will be saved in `train/example`.
2929

3030
## Usage
3131

32+
We recommend to only save & load the weights difference of the UNet's trainable parameters, i.e., $\Delta W = W_{finetune} - W_{pretrained}$, rather than the actual weight.
33+
This is useful when adapting to various base models since the weights difference is model-agnostic.
34+
3235
```python
3336
accelerate launch train_controlnext.py --pretrained_model_name_or_path "stabilityai/stable-diffusion-xl-base-1.0" \
3437
--pretrained_vae_model_name_or_path "madebyollin/sdxl-vae-fp16-fix" \
@@ -40,7 +43,7 @@ accelerate launch train_controlnext.py --pretrained_model_name_or_path "stabilit
4043
--gradient_checkpointing \
4144
--set_grads_to_none \
4245
--proportion_empty_prompts 0.2 \
43-
--controlnet_scale_factor 1.0 \
46+
--controlnet_scale_factor 1.0 \ # the strength of the controlnet output. For depth, we recommend 1.0, and for canny, we recommend 0.35
4447
--save_weights_increaments \
4548
--mixed_precision fp16 \
4649
--enable_xformers_memory_efficient_attention \
@@ -51,7 +54,3 @@ accelerate launch train_controlnext.py --pretrained_model_name_or_path "stabilit
5154
--validation_prompt "a stone tower on a rocky island" \
5255
--validation_image "examples/vidit_depth/condition_0.png"
5356
```
54-
55-
> --pretrained*model_name_or_path : pretrained base model \
56-
> --controlnet_scale_factor : the strength of the controlnet output. For depth, we recommend 1.0, and for canny, we recommend 0.35 \
57-
> --save_weights_increaments : whether to save the trainable parameters of unet directly or just the weight increments, i.e., $W*{finetune} - W\_{pretrained}$. This is useful when adapting to various base models.

ControlNeXt-SDXL-Training/pipeline/pipeline_controlnext.py

+271-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import inspect
1616
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17-
from packaging import version
1817
import torch
1918
from transformers import (
2019
CLIPImageProcessor,
@@ -57,6 +56,7 @@
5756
from diffusers.utils.torch_utils import randn_tensor
5857
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
5958
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
59+
from huggingface_hub.utils import validate_hf_hub_args
6060

6161
if is_invisible_watermark_available():
6262
from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
@@ -87,8 +87,128 @@
8787
```
8888
"""
8989

90+
CONTROLNEXT_WEIGHT_NAME = "controlnet.bin"
91+
CONTROLNEXT_WEIGHT_NAME_SAFE = "controlnet.safetensors"
92+
UNET_WEIGHT_NAME = "unet.bin"
93+
UNET_WEIGHT_NAME_SAFE = "unet.safetensors"
94+
95+
96+
# Copied from https://github.com/kohya-ss/sd-scripts/blob/main/library/sdxl_model_util.py
97+
98+
def is_sdxl_state_dict(state_dict):
99+
return any(key.startswith('input_blocks') for key in state_dict.keys())
100+
101+
102+
def convert_sdxl_unet_state_dict_to_diffusers(sd):
103+
unet_conversion_map = make_unet_conversion_map()
104+
105+
conversion_dict = {sd: hf for sd, hf in unet_conversion_map}
106+
return convert_unet_state_dict(sd, conversion_dict)
107+
108+
109+
def convert_unet_state_dict(src_sd, conversion_map):
110+
converted_sd = {}
111+
for src_key, value in src_sd.items():
112+
src_key_fragments = src_key.split(".")[:-1] # remove weight/bias
113+
while len(src_key_fragments) > 0:
114+
src_key_prefix = ".".join(src_key_fragments) + "."
115+
if src_key_prefix in conversion_map:
116+
converted_prefix = conversion_map[src_key_prefix]
117+
converted_key = converted_prefix + src_key[len(src_key_prefix):]
118+
converted_sd[converted_key] = value
119+
break
120+
src_key_fragments.pop(-1)
121+
assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map"
122+
123+
return converted_sd
124+
125+
126+
def make_unet_conversion_map():
127+
unet_conversion_map_layer = []
128+
129+
for i in range(3): # num_blocks is 3 in sdxl
130+
# loop over downblocks/upblocks
131+
for j in range(2):
132+
# loop over resnets/attentions for downblocks
133+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
134+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
135+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
136+
137+
if i < 3:
138+
# no attention layers in down_blocks.3
139+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
140+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
141+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
142+
143+
for j in range(3):
144+
# loop over resnets/attentions for upblocks
145+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
146+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
147+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
148+
149+
# if i > 0: commentout for sdxl
150+
# no attention layers in up_blocks.0
151+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
152+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
153+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
154+
155+
if i < 3:
156+
# no downsample in down_blocks.3
157+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
158+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
159+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
160+
161+
# no upsample in up_blocks.3
162+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
163+
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
164+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
165+
166+
hf_mid_atn_prefix = "mid_block.attentions.0."
167+
sd_mid_atn_prefix = "middle_block.1."
168+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
169+
170+
for j in range(2):
171+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
172+
sd_mid_res_prefix = f"middle_block.{2*j}."
173+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
174+
175+
unet_conversion_map_resnet = [
176+
# (stable-diffusion, HF Diffusers)
177+
("in_layers.0.", "norm1."),
178+
("in_layers.2.", "conv1."),
179+
("out_layers.0.", "norm2."),
180+
("out_layers.3.", "conv2."),
181+
("emb_layers.1.", "time_emb_proj."),
182+
("skip_connection.", "conv_shortcut."),
183+
]
184+
185+
unet_conversion_map = []
186+
for sd, hf in unet_conversion_map_layer:
187+
if "resnets" in hf:
188+
for sd_res, hf_res in unet_conversion_map_resnet:
189+
unet_conversion_map.append((sd + sd_res, hf + hf_res))
190+
else:
191+
unet_conversion_map.append((sd, hf))
192+
193+
for j in range(2):
194+
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
195+
sd_time_embed_prefix = f"time_embed.{j*2}."
196+
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
197+
198+
for j in range(2):
199+
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
200+
sd_label_embed_prefix = f"label_emb.0.{j*2}."
201+
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
202+
203+
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
204+
unet_conversion_map.append(("out.0.", "conv_norm_out."))
205+
unet_conversion_map.append(("out.2.", "conv_out."))
206+
207+
return unet_conversion_map
90208

91209
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
210+
211+
92212
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
93213
"""
94214
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -280,6 +400,156 @@ def __init__(
280400
else:
281401
self.watermark = None
282402

403+
def load_controlnext_weights(
404+
self,
405+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
406+
load_weight_increasement: bool = False,
407+
**kwargs,
408+
):
409+
self.load_controlnext_unet_weights(pretrained_model_name_or_path_or_dict, load_weight_increasement, **kwargs)
410+
kwargs['torch_dtype'] = torch.float32
411+
self.load_controlnext_controlnet_weights(pretrained_model_name_or_path_or_dict, **kwargs)
412+
413+
def load_controlnext_unet_weights(
414+
self,
415+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
416+
load_weight_increasement: bool = False,
417+
**kwargs,
418+
):
419+
if isinstance(pretrained_model_name_or_path_or_dict, dict):
420+
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
421+
422+
state_dict = self.controlnext_unet_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
423+
if is_sdxl_state_dict(state_dict):
424+
state_dict = convert_sdxl_unet_state_dict_to_diffusers(state_dict)
425+
426+
logger.info(f"Loading ControlNeXt UNet" + (f" with weight increasement." if load_weight_increasement else "."))
427+
if load_weight_increasement:
428+
unet_sd = self.unet.state_dict()
429+
for k in state_dict.keys():
430+
state_dict[k] = state_dict[k] + unet_sd[k]
431+
self.unet.load_state_dict(state_dict, strict=False)
432+
433+
@classmethod
434+
@validate_hf_hub_args
435+
def controlnext_unet_state_dict(
436+
cls,
437+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
438+
**kwargs,
439+
):
440+
if 'weight_name' not in kwargs:
441+
kwargs['weight_name'] = UNET_WEIGHT_NAME_SAFE if kwargs.get('use_safetensors', False) else UNET_WEIGHT_NAME
442+
return cls.controlnext_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
443+
444+
def load_controlnext_controlnet_weights(
445+
self,
446+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
447+
**kwargs,
448+
):
449+
if self.controlnet is None:
450+
raise ValueError("No ControlNeXt ControlNet found in the pipeline.")
451+
if isinstance(pretrained_model_name_or_path_or_dict, dict):
452+
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
453+
454+
state_dict = self.controlnext_controlnet_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
455+
456+
logger.info(f"Loading ControlNeXt ControlNet")
457+
self.controlnet.load_state_dict(state_dict, strict=True)
458+
459+
@classmethod
460+
@validate_hf_hub_args
461+
def controlnext_controlnet_state_dict(
462+
cls,
463+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
464+
**kwargs,
465+
):
466+
if 'weight_name' not in kwargs:
467+
kwargs['weight_name'] = CONTROLNEXT_WEIGHT_NAME_SAFE if kwargs.get('use_safetensors', False) else CONTROLNEXT_WEIGHT_NAME
468+
return cls.controlnext_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
469+
470+
@classmethod
471+
@validate_hf_hub_args
472+
def controlnext_state_dict(
473+
cls,
474+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
475+
**kwargs,
476+
):
477+
r"""
478+
Return state dict for controlnext weights.
479+
480+
Parameters:
481+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
482+
Can be either:
483+
484+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
485+
the Hub.
486+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
487+
with [`ModelMixin.save_pretrained`].
488+
- A [torch state
489+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
490+
491+
cache_dir (`Union[str, os.PathLike]`, *optional*):
492+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
493+
is not used.
494+
force_download (`bool`, *optional*, defaults to `False`):
495+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
496+
cached versions if they exist.
497+
498+
proxies (`Dict[str, str]`, *optional*):
499+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
500+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
501+
local_files_only (`bool`, *optional*, defaults to `False`):
502+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
503+
won't be downloaded from the Hub.
504+
token (`str` or *bool*, *optional*):
505+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
506+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
507+
revision (`str`, *optional*, defaults to `"main"`):
508+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
509+
allowed by Git.
510+
subfolder (`str`, *optional*, defaults to `""`):
511+
The subfolder location of a model file within a larger model repository on the Hub or locally.
512+
weight_name (`str`, *optional*, defaults to None):
513+
Name of the serialized state dict file.
514+
"""
515+
cache_dir = kwargs.pop("cache_dir", None)
516+
force_download = kwargs.pop("force_download", False)
517+
proxies = kwargs.pop("proxies", None)
518+
local_files_only = kwargs.pop("local_files_only", None)
519+
token = kwargs.pop("token", None)
520+
revision = kwargs.pop("revision", None)
521+
subfolder = kwargs.pop("subfolder", None)
522+
weight_name = kwargs.pop("weight_name", None)
523+
unet_config = kwargs.pop("unet_config", None)
524+
use_safetensors = kwargs.pop("use_safetensors", None)
525+
526+
allow_pickle = False
527+
if use_safetensors is None:
528+
use_safetensors = True
529+
allow_pickle = True
530+
531+
user_agent = {
532+
"file_type": "attn_procs_weights",
533+
"framework": "pytorch",
534+
}
535+
536+
state_dict = cls._fetch_state_dict(
537+
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
538+
weight_name=weight_name,
539+
use_safetensors=use_safetensors,
540+
local_files_only=local_files_only,
541+
cache_dir=cache_dir,
542+
force_download=force_download,
543+
proxies=proxies,
544+
token=token,
545+
revision=revision,
546+
subfolder=subfolder,
547+
user_agent=user_agent,
548+
allow_pickle=allow_pickle,
549+
)
550+
551+
return state_dict
552+
283553
def prepare_image(
284554
self,
285555
image,

0 commit comments

Comments
 (0)