Skip to content

Commit 592a14b

Browse files
committed
removed unnecessary code
1 parent 50011f5 commit 592a14b

File tree

7 files changed

+161
-290
lines changed

7 files changed

+161
-290
lines changed

ControlNeXt-SDXL-Training/models/unet.py

-70
Original file line numberDiff line numberDiff line change
@@ -53,76 +53,6 @@
5353

5454
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
5555

56-
UNET_CONFIG = {
57-
"_class_name": "UNet2DConditionModel",
58-
"_diffusers_version": "0.19.0.dev0",
59-
"act_fn": "silu",
60-
"addition_embed_type": "text_time",
61-
"addition_embed_type_num_heads": 64,
62-
"addition_time_embed_dim": 256,
63-
"attention_head_dim": [
64-
5,
65-
10,
66-
20
67-
],
68-
"block_out_channels": [
69-
320,
70-
640,
71-
1280
72-
],
73-
"center_input_sample": False,
74-
"class_embed_type": None,
75-
"class_embeddings_concat": False,
76-
"conv_in_kernel": 3,
77-
"conv_out_kernel": 3,
78-
"cross_attention_dim": 2048,
79-
"cross_attention_norm": None,
80-
"down_block_types": [
81-
"DownBlock2D",
82-
"CrossAttnDownBlock2D",
83-
"CrossAttnDownBlock2D"
84-
],
85-
"downsample_padding": 1,
86-
"dual_cross_attention": False,
87-
"encoder_hid_dim": None,
88-
"encoder_hid_dim_type": None,
89-
"flip_sin_to_cos": True,
90-
"freq_shift": 0,
91-
"in_channels": 4,
92-
"layers_per_block": 2,
93-
"mid_block_only_cross_attention": None,
94-
"mid_block_scale_factor": 1,
95-
"mid_block_type": "UNetMidBlock2DCrossAttn",
96-
"norm_eps": 1e-05,
97-
"norm_num_groups": 32,
98-
"num_attention_heads": None,
99-
"num_class_embeds": None,
100-
"only_cross_attention": False,
101-
"out_channels": 4,
102-
"projection_class_embeddings_input_dim": 2816,
103-
"resnet_out_scale_factor": 1.0,
104-
"resnet_skip_time_act": False,
105-
"resnet_time_scale_shift": "default",
106-
"sample_size": 128,
107-
"time_cond_proj_dim": None,
108-
"time_embedding_act_fn": None,
109-
"time_embedding_dim": None,
110-
"time_embedding_type": "positional",
111-
"timestep_post_act": None,
112-
"transformer_layers_per_block": [
113-
1,
114-
2,
115-
10
116-
],
117-
"up_block_types": [
118-
"CrossAttnUpBlock2D",
119-
"CrossAttnUpBlock2D",
120-
"UpBlock2D"
121-
],
122-
"upcast_attention": None,
123-
"use_linear_projection": True
124-
}
125-
12656

12757
@dataclass
12858
class UNet2DConditionOutput(BaseOutput):

ControlNeXt-SDXL-Training/utils/tools.py

+78-7
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,78 @@
44
from diffusers import UniPCMultistepScheduler, AutoencoderKL, ControlNetModel
55
from safetensors.torch import load_file
66
from pipeline.pipeline_controlnext import StableDiffusionXLControlNeXtPipeline
7-
from models.unet import UNet2DConditionModel, UNET_CONFIG
7+
from models.unet import UNet2DConditionModel
88
from models.controlnet import ControlNetModel
99
from . import utils
1010

11+
UNET_CONFIG = {
12+
"act_fn": "silu",
13+
"addition_embed_type": "text_time",
14+
"addition_embed_type_num_heads": 64,
15+
"addition_time_embed_dim": 256,
16+
"attention_head_dim": [
17+
5,
18+
10,
19+
20
20+
],
21+
"block_out_channels": [
22+
320,
23+
640,
24+
1280
25+
],
26+
"center_input_sample": False,
27+
"class_embed_type": None,
28+
"class_embeddings_concat": False,
29+
"conv_in_kernel": 3,
30+
"conv_out_kernel": 3,
31+
"cross_attention_dim": 2048,
32+
"cross_attention_norm": None,
33+
"down_block_types": [
34+
"DownBlock2D",
35+
"CrossAttnDownBlock2D",
36+
"CrossAttnDownBlock2D"
37+
],
38+
"downsample_padding": 1,
39+
"dual_cross_attention": False,
40+
"encoder_hid_dim": None,
41+
"encoder_hid_dim_type": None,
42+
"flip_sin_to_cos": True,
43+
"freq_shift": 0,
44+
"in_channels": 4,
45+
"layers_per_block": 2,
46+
"mid_block_only_cross_attention": None,
47+
"mid_block_scale_factor": 1,
48+
"mid_block_type": "UNetMidBlock2DCrossAttn",
49+
"norm_eps": 1e-05,
50+
"norm_num_groups": 32,
51+
"num_attention_heads": None,
52+
"num_class_embeds": None,
53+
"only_cross_attention": False,
54+
"out_channels": 4,
55+
"projection_class_embeddings_input_dim": 2816,
56+
"resnet_out_scale_factor": 1.0,
57+
"resnet_skip_time_act": False,
58+
"resnet_time_scale_shift": "default",
59+
"sample_size": 128,
60+
"time_cond_proj_dim": None,
61+
"time_embedding_act_fn": None,
62+
"time_embedding_dim": None,
63+
"time_embedding_type": "positional",
64+
"timestep_post_act": None,
65+
"transformer_layers_per_block": [
66+
1,
67+
2,
68+
10
69+
],
70+
"up_block_types": [
71+
"CrossAttnUpBlock2D",
72+
"CrossAttnUpBlock2D",
73+
"UpBlock2D"
74+
],
75+
"upcast_attention": None,
76+
"use_linear_projection": True
77+
}
78+
1179
CONTROLNET_CONFIG = {
1280
'in_channels': [128, 128],
1381
'out_channels': [128, 256],
@@ -83,19 +151,22 @@ def get_pipeline(
83151

84152
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
85153
if unet_model_name_or_path is not None:
154+
print(f"loading controlnext unet from {unet_model_name_or_path}")
86155
pipeline.load_controlnext_unet_weights(
87156
unet_model_name_or_path,
88157
load_weight_increasement=load_weight_increasement,
89158
use_safetensors=True,
90159
torch_dtype=torch.float16,
91160
cache_dir=hf_cache_dir,
92161
)
93-
pipeline.load_controlnext_controlnet_weights(
94-
controlnet_model_name_or_path,
95-
use_safetensors=True,
96-
torch_dtype=torch.float32,
97-
cache_dir=hf_cache_dir,
98-
)
162+
if controlnet_model_name_or_path is not None:
163+
print(f"loading controlnext controlnet from {controlnet_model_name_or_path}")
164+
pipeline.load_controlnext_controlnet_weights(
165+
controlnet_model_name_or_path,
166+
use_safetensors=True,
167+
torch_dtype=torch.float32,
168+
cache_dir=hf_cache_dir,
169+
)
99170
pipeline.set_progress_bar_config()
100171
pipeline = pipeline.to(device, dtype=torch.float16)
101172

ControlNeXt-SDXL-Training/utils/utils.py

-68
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,5 @@
11
import math
22
from typing import Tuple, Union, Optional
3-
from safetensors.torch import load_file
4-
from transformers import PretrainedConfig
5-
6-
7-
def count_num_parameters_of_safetensors_model(safetensors_path):
8-
state_dict = load_file(safetensors_path)
9-
return sum(p.numel() for p in state_dict.values())
10-
11-
12-
def import_model_class_from_model_name_or_path(
13-
pretrained_model_name_or_path: str, revision: str, subfolder: str = None
14-
):
15-
text_encoder_config = PretrainedConfig.from_pretrained(
16-
pretrained_model_name_or_path, revision=revision, subfolder=subfolder
17-
)
18-
model_class = text_encoder_config.architectures[0]
19-
if model_class == "CLIPTextModel":
20-
from transformers import CLIPTextModel
21-
return CLIPTextModel
22-
elif model_class == "CLIPTextModelWithProjection":
23-
from transformers import CLIPTextModelWithProjection
24-
return CLIPTextModelWithProjection
25-
else:
26-
raise ValueError(f"{model_class} is not supported.")
27-
28-
29-
def fix_clip_text_encoder_position_ids(text_encoder):
30-
if hasattr(text_encoder.text_model.embeddings, "position_ids"):
31-
text_encoder.text_model.embeddings.position_ids = text_encoder.text_model.embeddings.position_ids.long()
32-
33-
34-
def load_controlnext_unet_state_dict(unet_sd, controlnext_unet_sd):
35-
assert all(
36-
k in unet_sd for k in controlnext_unet_sd), f"controlnext unet state dict is not compatible with unet state dict, missing keys: {set(controlnext_unet_sd.keys()) - set(unet_sd.keys())}, extra keys: {set(unet_sd.keys()) - set(controlnext_unet_sd.keys())}"
37-
for k in controlnext_unet_sd.keys():
38-
unet_sd[k] = controlnext_unet_sd[k]
39-
return unet_sd
40-
41-
42-
def convert_to_controlnext_unet_state_dict(state_dict):
43-
import re
44-
pattern = re.compile(r'.*attn2.*to_out.*')
45-
state_dict = {k: v for k, v in state_dict.items() if pattern.match(k)}
46-
# state_dict = extract_unet_state_dict(state_dict)
47-
if is_sdxl_state_dict(state_dict):
48-
state_dict = convert_sdxl_unet_state_dict_to_diffusers(state_dict)
49-
return state_dict
503

514

525
def make_unet_conversion_map():
@@ -166,27 +119,6 @@ def extract_unet_state_dict(state_dict):
166119
return unet_sd
167120

168121

169-
def is_sdxl_state_dict(state_dict):
170-
return any(key.startswith('input_blocks') for key in state_dict.keys())
171-
172-
173-
def contains_unet_keys(state_dict):
174-
UNET_KEY_PREFIX = "model.diffusion_model."
175-
return any(k.startswith(UNET_KEY_PREFIX) for k in state_dict.keys())
176-
177-
178-
def load_safetensors(model, safetensors_path, strict=True, load_weight_increasement=False):
179-
if not load_weight_increasement:
180-
state_dict = load_file(safetensors_path)
181-
model.load_state_dict(state_dict, strict=strict)
182-
else:
183-
state_dict = load_file(safetensors_path)
184-
pretrained_state_dict = model.state_dict()
185-
for k in state_dict.keys():
186-
state_dict[k] = state_dict[k] + pretrained_state_dict[k]
187-
model.load_state_dict(state_dict, strict=False)
188-
189-
190122
def log_model_info(model, name):
191123
sd = model.state_dict() if hasattr(model, "state_dict") else model
192124
print(

ControlNeXt-SDXL/models/unet.py

-70
Original file line numberDiff line numberDiff line change
@@ -53,76 +53,6 @@
5353

5454
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
5555

56-
UNET_CONFIG = {
57-
"_class_name": "UNet2DConditionModel",
58-
"_diffusers_version": "0.19.0.dev0",
59-
"act_fn": "silu",
60-
"addition_embed_type": "text_time",
61-
"addition_embed_type_num_heads": 64,
62-
"addition_time_embed_dim": 256,
63-
"attention_head_dim": [
64-
5,
65-
10,
66-
20
67-
],
68-
"block_out_channels": [
69-
320,
70-
640,
71-
1280
72-
],
73-
"center_input_sample": False,
74-
"class_embed_type": None,
75-
"class_embeddings_concat": False,
76-
"conv_in_kernel": 3,
77-
"conv_out_kernel": 3,
78-
"cross_attention_dim": 2048,
79-
"cross_attention_norm": None,
80-
"down_block_types": [
81-
"DownBlock2D",
82-
"CrossAttnDownBlock2D",
83-
"CrossAttnDownBlock2D"
84-
],
85-
"downsample_padding": 1,
86-
"dual_cross_attention": False,
87-
"encoder_hid_dim": None,
88-
"encoder_hid_dim_type": None,
89-
"flip_sin_to_cos": True,
90-
"freq_shift": 0,
91-
"in_channels": 4,
92-
"layers_per_block": 2,
93-
"mid_block_only_cross_attention": None,
94-
"mid_block_scale_factor": 1,
95-
"mid_block_type": "UNetMidBlock2DCrossAttn",
96-
"norm_eps": 1e-05,
97-
"norm_num_groups": 32,
98-
"num_attention_heads": None,
99-
"num_class_embeds": None,
100-
"only_cross_attention": False,
101-
"out_channels": 4,
102-
"projection_class_embeddings_input_dim": 2816,
103-
"resnet_out_scale_factor": 1.0,
104-
"resnet_skip_time_act": False,
105-
"resnet_time_scale_shift": "default",
106-
"sample_size": 128,
107-
"time_cond_proj_dim": None,
108-
"time_embedding_act_fn": None,
109-
"time_embedding_dim": None,
110-
"time_embedding_type": "positional",
111-
"timestep_post_act": None,
112-
"transformer_layers_per_block": [
113-
1,
114-
2,
115-
10
116-
],
117-
"up_block_types": [
118-
"CrossAttnUpBlock2D",
119-
"CrossAttnUpBlock2D",
120-
"UpBlock2D"
121-
],
122-
"upcast_attention": None,
123-
"use_linear_projection": True
124-
}
125-
12656

12757
@dataclass
12858
class UNet2DConditionOutput(BaseOutput):

ControlNeXt-SDXL/run_controlnext.py

+5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import torch
33
import cv2
4+
import gc
45
import numpy as np
56
import argparse
67
from PIL import Image
@@ -112,6 +113,10 @@ def log_validation(
112113
print("Save images to:", file_path)
113114
cv2.imwrite(file_path, formatted_images)
114115

116+
gc.collect()
117+
if str(device) == 'cuda' and torch.cuda.is_available():
118+
torch.cuda.empty_cache()
119+
115120
return image_logs
116121

117122

0 commit comments

Comments
 (0)