|
4 | 4 | from diffusers import UniPCMultistepScheduler, AutoencoderKL, ControlNetModel
|
5 | 5 | from safetensors.torch import load_file
|
6 | 6 | from pipeline.pipeline_controlnext import StableDiffusionXLControlNeXtPipeline
|
7 |
| -from models.unet import UNet2DConditionModel, UNET_CONFIG |
| 7 | +from models.unet import UNet2DConditionModel |
8 | 8 | from models.controlnet import ControlNetModel
|
9 | 9 | from . import utils
|
10 | 10 |
|
| 11 | +UNET_CONFIG = { |
| 12 | + "act_fn": "silu", |
| 13 | + "addition_embed_type": "text_time", |
| 14 | + "addition_embed_type_num_heads": 64, |
| 15 | + "addition_time_embed_dim": 256, |
| 16 | + "attention_head_dim": [ |
| 17 | + 5, |
| 18 | + 10, |
| 19 | + 20 |
| 20 | + ], |
| 21 | + "block_out_channels": [ |
| 22 | + 320, |
| 23 | + 640, |
| 24 | + 1280 |
| 25 | + ], |
| 26 | + "center_input_sample": False, |
| 27 | + "class_embed_type": None, |
| 28 | + "class_embeddings_concat": False, |
| 29 | + "conv_in_kernel": 3, |
| 30 | + "conv_out_kernel": 3, |
| 31 | + "cross_attention_dim": 2048, |
| 32 | + "cross_attention_norm": None, |
| 33 | + "down_block_types": [ |
| 34 | + "DownBlock2D", |
| 35 | + "CrossAttnDownBlock2D", |
| 36 | + "CrossAttnDownBlock2D" |
| 37 | + ], |
| 38 | + "downsample_padding": 1, |
| 39 | + "dual_cross_attention": False, |
| 40 | + "encoder_hid_dim": None, |
| 41 | + "encoder_hid_dim_type": None, |
| 42 | + "flip_sin_to_cos": True, |
| 43 | + "freq_shift": 0, |
| 44 | + "in_channels": 4, |
| 45 | + "layers_per_block": 2, |
| 46 | + "mid_block_only_cross_attention": None, |
| 47 | + "mid_block_scale_factor": 1, |
| 48 | + "mid_block_type": "UNetMidBlock2DCrossAttn", |
| 49 | + "norm_eps": 1e-05, |
| 50 | + "norm_num_groups": 32, |
| 51 | + "num_attention_heads": None, |
| 52 | + "num_class_embeds": None, |
| 53 | + "only_cross_attention": False, |
| 54 | + "out_channels": 4, |
| 55 | + "projection_class_embeddings_input_dim": 2816, |
| 56 | + "resnet_out_scale_factor": 1.0, |
| 57 | + "resnet_skip_time_act": False, |
| 58 | + "resnet_time_scale_shift": "default", |
| 59 | + "sample_size": 128, |
| 60 | + "time_cond_proj_dim": None, |
| 61 | + "time_embedding_act_fn": None, |
| 62 | + "time_embedding_dim": None, |
| 63 | + "time_embedding_type": "positional", |
| 64 | + "timestep_post_act": None, |
| 65 | + "transformer_layers_per_block": [ |
| 66 | + 1, |
| 67 | + 2, |
| 68 | + 10 |
| 69 | + ], |
| 70 | + "up_block_types": [ |
| 71 | + "CrossAttnUpBlock2D", |
| 72 | + "CrossAttnUpBlock2D", |
| 73 | + "UpBlock2D" |
| 74 | + ], |
| 75 | + "upcast_attention": None, |
| 76 | + "use_linear_projection": True |
| 77 | +} |
| 78 | + |
11 | 79 | CONTROLNET_CONFIG = {
|
12 | 80 | 'in_channels': [128, 128],
|
13 | 81 | 'out_channels': [128, 256],
|
@@ -83,19 +151,22 @@ def get_pipeline(
|
83 | 151 |
|
84 | 152 | pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
|
85 | 153 | if unet_model_name_or_path is not None:
|
| 154 | + print(f"loading controlnext unet from {unet_model_name_or_path}") |
86 | 155 | pipeline.load_controlnext_unet_weights(
|
87 | 156 | unet_model_name_or_path,
|
88 | 157 | load_weight_increasement=load_weight_increasement,
|
89 | 158 | use_safetensors=True,
|
90 | 159 | torch_dtype=torch.float16,
|
91 | 160 | cache_dir=hf_cache_dir,
|
92 | 161 | )
|
93 |
| - pipeline.load_controlnext_controlnet_weights( |
94 |
| - controlnet_model_name_or_path, |
95 |
| - use_safetensors=True, |
96 |
| - torch_dtype=torch.float32, |
97 |
| - cache_dir=hf_cache_dir, |
98 |
| - ) |
| 162 | + if controlnet_model_name_or_path is not None: |
| 163 | + print(f"loading controlnext controlnet from {controlnet_model_name_or_path}") |
| 164 | + pipeline.load_controlnext_controlnet_weights( |
| 165 | + controlnet_model_name_or_path, |
| 166 | + use_safetensors=True, |
| 167 | + torch_dtype=torch.float32, |
| 168 | + cache_dir=hf_cache_dir, |
| 169 | + ) |
99 | 170 | pipeline.set_progress_bar_config()
|
100 | 171 | pipeline = pipeline.to(device, dtype=torch.float16)
|
101 | 172 |
|
|
0 commit comments