Commit 0463e81 ("lora fix")
1 parent: 9fadc6a
7 files changed: +937, -348

.gitignore (+3)

@@ -51,3 +51,6 @@ temp/RD_decode_temp.png
 temp/RD_encode_temp.png
 temp/RD_mask.png
 venv
+models/base/BLIP/*
+models/base/LECO/*
+models/base/RDGenerationLog.txt

scripts/cldm_inference.py (+6, -5)

@@ -6,9 +6,7 @@
 from ldm.model_management import unload_all_models
 from ldm.lora import load_lora_for_models
 from ldm.sd import load_checkpoint_guess_config
-
-import copy
-from PIL import Image, ImageOps
+from PIL import ImageOps
 import numpy as np
 import torch
 
@@ -62,7 +60,7 @@ def load_controlnet(
         output_clip=False,
         output_clipvision=False,
     )
-
+
    model_patcher = out[0]
 
    # Apply loras
@@ -86,6 +84,9 @@ def load_controlnet(
 
    # Apply controlnet to conditioning
    (cldm_conditioning,) = apply_controlnet(cldm_conditioning, controlnet, image, controlnet_input["weight"])
+
+    # Patch the model
+    lora_model_patcher.patch_model()
 
    return lora_model_patcher, cldm_conditioning, cldm_negative_conditioning
 
@@ -131,4 +132,4 @@ def unload_cldm():
    # Unload the model
    unload_all_models()
 
-    return
+    return
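
Note on the fix: load_lora_for_models() only registers the LoRA weight deltas on the patcher; nothing touches the live weights until patch_model() runs, which is exactly what this hunk adds after the conditioning is built. Below is a minimal, self-contained sketch of that patch/unpatch lifecycle; the method names mirror the patcher used here, but the class and its bodies are illustrative stand-ins, not this repo's implementation.

import torch

class TinyPatcher:
    """Toy stand-in for a LoRA-style model patcher (hypothetical)."""

    def __init__(self, model):
        self.model = model
        self.patches = {}  # weight key -> list of (strength, delta) tuples
        self.backup = {}   # weight key -> original tensor, kept for unpatching

    def add_patches(self, patches, strength=1.0):
        # Registering deltas does NOT modify the model yet.
        for key, delta in patches.items():
            self.patches.setdefault(key, []).append((strength, delta))

    @torch.no_grad()
    def patch_model(self):
        # Bake the registered deltas into the live weights.
        sd = self.model.state_dict()
        for key, entries in self.patches.items():
            self.backup[key] = sd[key].clone()
            for strength, delta in entries:
                sd[key] += strength * delta

    @torch.no_grad()
    def unpatch_model(self):
        # Restore the pristine weights saved in patch_model().
        sd = self.model.state_dict()
        for key, original in self.backup.items():
            sd[key].copy_(original)
        self.backup.clear()

lin = torch.nn.Linear(4, 4)
patcher = TinyPatcher(lin)
patcher.add_patches({"weight": torch.full((4, 4), 0.01)}, strength=0.5)
patcher.patch_model()    # weights shift by 0.5 * 0.01; sampling would happen here
patcher.unpatch_model()  # weights are back to the originals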

scripts/ldm/model_base.py (+247, -3)

@@ -3,6 +3,9 @@
 import ldm.conds
 from enum import Enum
 
+import ldm.ops
+import ldm.model_management
+
 from ldm.cldm_models import UNetModel
 from . import utils
 
@@ -33,6 +36,247 @@ class ModelSampling(s, c):
 
 
 class BaseModel(torch.nn.Module):
+    def __init__(
+        self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel
+    ):
+        super().__init__()
+
+        unet_config = model_config.unet_config
+        self.latent_format = model_config.latent_format
+        self.model_config = model_config
+        self.manual_cast_dtype = model_config.manual_cast_dtype
+
+        if not unet_config.get("disable_unet_model_creation", False):
+            if self.manual_cast_dtype is not None:
+                operations = ldm.ops.manual_cast
+            else:
+                operations = ldm.ops.disable_weight_init
+            self.diffusion_model = unet_model(
+                **unet_config, device=device, operations=operations
+            )
+        self.model_type = model_type
+        self.model_sampling = model_sampling(model_config, model_type)
+
+        self.adm_channels = unet_config.get("adm_in_channels", None)
+        if self.adm_channels is None:
+            self.adm_channels = 0
+        self.inpaint_model = False
+        print("model_type", model_type.name)
+        print("adm", self.adm_channels)
+
+    def apply_model(
+        self,
+        x,
+        t,
+        c_concat=None,
+        c_crossattn=None,
+        control=None,
+        transformer_options={},
+        **kwargs
+    ):
+        sigma = t
+        xc = self.model_sampling.calculate_input(sigma, x)
+        if c_concat is not None:
+            xc = torch.cat([xc] + [c_concat], dim=1)
+
+        context = c_crossattn
+        dtype = self.get_dtype()
+
+        if self.manual_cast_dtype is not None:
+            dtype = self.manual_cast_dtype
+
+        xc = xc.to(dtype)
+        t = self.model_sampling.timestep(t).float()
+        context = context.to(dtype)
+        extra_conds = {}
+        for o in kwargs:
+            extra = kwargs[o]
+            if hasattr(extra, "dtype"):
+                if extra.dtype != torch.int and extra.dtype != torch.long:
+                    extra = extra.to(dtype)
+            extra_conds[o] = extra
+
+        model_output = self.diffusion_model(
+            xc,
+            t,
+            context=context,
+            control=control,
+            transformer_options=transformer_options,
+            **extra_conds
+        ).float()
+        return self.model_sampling.calculate_denoised(sigma, model_output, x)
+
+    def get_dtype(self):
+        return self.diffusion_model.dtype
+
+    def is_adm(self):
+        return self.adm_channels > 0
+
+    def encode_adm(self, **kwargs):
+        return None
+
+    def extra_conds(self, **kwargs):
+        out = {}
+        if self.inpaint_model:
+            concat_keys = ("mask", "masked_image")
+            cond_concat = []
+            denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+            concat_latent_image = kwargs.get("concat_latent_image", None)
+            if concat_latent_image is None:
+                concat_latent_image = kwargs.get("latent_image", None)
+            else:
+                concat_latent_image = self.process_latent_in(concat_latent_image)
+
+            noise = kwargs.get("noise", None)
+            device = kwargs["device"]
+
+            if concat_latent_image.shape[1:] != noise.shape[1:]:
+                concat_latent_image = utils.common_upscale(
+                    concat_latent_image,
+                    noise.shape[-1],
+                    noise.shape[-2],
+                    "bilinear",
+                    "center",
+                )
+
+            concat_latent_image = utils.resize_to_batch_size(
+                concat_latent_image, noise.shape[0]
+            )
+
+            if len(denoise_mask.shape) == len(noise.shape):
+                denoise_mask = denoise_mask[:, :1]
+
+            denoise_mask = denoise_mask.reshape(
+                (-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1])
+            )
+            if denoise_mask.shape[-2:] != noise.shape[-2:]:
+                denoise_mask = utils.common_upscale(
+                    denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center"
+                )
+            denoise_mask = utils.resize_to_batch_size(
+                denoise_mask.round(), noise.shape[0]
+            )
+
+            def blank_inpaint_image_like(latent_image):
+                blank_image = torch.ones_like(latent_image)
+                # these are the values for "zero" in pixel space translated to latent space
+                blank_image[:, 0] *= 0.8223
+                blank_image[:, 1] *= -0.6876
+                blank_image[:, 2] *= 0.6364
+                blank_image[:, 3] *= 0.1380
+                return blank_image
+
+            for ck in concat_keys:
+                if denoise_mask is not None:
+                    if ck == "mask":
+                        cond_concat.append(denoise_mask.to(device))
+                    elif ck == "masked_image":
+                        cond_concat.append(
+                            concat_latent_image.to(device)
+                        )  # NOTE: the latent_image should be masked by the mask in pixel space
+                else:
+                    if ck == "mask":
+                        cond_concat.append(torch.ones_like(noise)[:, :1])
+                    elif ck == "masked_image":
+                        cond_concat.append(blank_inpaint_image_like(noise))
+            data = torch.cat(cond_concat, dim=1)
+            out["c_concat"] = ldm.conds.CONDNoiseShape(data)
+
+        adm = self.encode_adm(**kwargs)
+        if adm is not None:
+            out["y"] = ldm.conds.CONDRegular(adm)
+
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out["c_crossattn"] = ldm.conds.CONDCrossAttn(cross_attn)
+
+        cross_attn_cnet = kwargs.get("cross_attn_controlnet", None)
+        if cross_attn_cnet is not None:
+            out["crossattn_controlnet"] = ldm.conds.CONDCrossAttn(cross_attn_cnet)
+
+        return out
+
+    def load_model_weights(self, sd, unet_prefix=""):
+        to_load = {}
+        keys = list(sd.keys())
+        for k in keys:
+            if k.startswith(unet_prefix):
+                to_load[k[len(unet_prefix) :]] = sd.pop(k)
+
+        to_load = self.model_config.process_unet_state_dict(to_load)
+        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
+        if len(m) > 0:
+            print("unet missing:", m)
+
+        if len(u) > 0:
+            print("unet unexpected:", u)
+        del to_load
+        return self
+
+    def process_latent_in(self, latent):
+        return self.latent_format.process_in(latent)
+
+    def process_latent_out(self, latent):
+        return self.latent_format.process_out(latent)
+
+    def state_dict_for_saving(
+        self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None
+    ):
+        extra_sds = []
+        if clip_state_dict is not None:
+            extra_sds.append(
+                self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
+            )
+        if vae_state_dict is not None:
+            extra_sds.append(
+                self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
+            )
+        if clip_vision_state_dict is not None:
+            extra_sds.append(
+                self.model_config.process_clip_vision_state_dict_for_saving(
+                    clip_vision_state_dict
+                )
+            )
+
+        unet_state_dict = self.diffusion_model.state_dict()
+        unet_state_dict = self.model_config.process_unet_state_dict_for_saving(
+            unet_state_dict
+        )
+
+        if self.get_dtype() == torch.float16:
+            extra_sds = map(
+                lambda sd: utils.convert_sd_to(sd, torch.float16), extra_sds
+            )
+
+        if self.model_type == ModelType.V_PREDICTION:
+            unet_state_dict["v_pred"] = torch.tensor([])
+
+        for sd in extra_sds:
+            unet_state_dict.update(sd)
+
+        return unet_state_dict
+
+    def set_inpaint(self):
+        self.inpaint_model = True
+
+    def memory_required(self, input_shape):
+        if (
+            ldm.model_management.xformers_enabled()
+            or ldm.model_management.pytorch_attention_flash_attention()
+        ):
+            dtype = self.get_dtype()
+            if self.manual_cast_dtype is not None:
+                dtype = self.manual_cast_dtype
+            # TODO: this needs to be tweaked
+            area = input_shape[0] * input_shape[2] * input_shape[3]
+            return (area * ldm.model_management.dtype_size(dtype) / 50) * (
+                1024 * 1024
+            )
+        else:
+            # TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
+            area = input_shape[0] * input_shape[2] * input_shape[3]
+            return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
+
     def __init__(self, model_config, model_type=ModelType.EPS, device=None):
         super().__init__()

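Note on the new __init__: the `operations` object selected there is a namespace of layer classes that the UNet instantiates instead of the stock torch.nn ones. A simplified sketch of the two flavors follows; the names manual_cast and disable_weight_init come from the diff, but these bodies are illustrative approximations of what ldm/ops.py likely does, not copies of it.

import torch.nn as nn
import torch.nn.functional as F

class disable_weight_init:
    class Linear(nn.Linear):
        def reset_parameters(self):
            # Skip the default random init: every weight is about to be
            # overwritten by the checkpoint anyway, so this just saves time.
            return None

class manual_cast:
    class Linear(disable_weight_init.Linear):
        def forward(self, x):
            # Cast weights to the activation dtype at call time, so storage
            # can stay in a smaller dtype than the compute dtype.
            weight = self.weight.to(x.dtype)
            bias = self.bias.to(x.dtype) if self.bias is not None else None
            return F.linear(x, weight, bias)

This pairs with apply_model() above, which casts xc and context to manual_cast_dtype before the forward pass whenever manual casting is active.
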
@@ -162,10 +406,10 @@ def set_inpaint(self):
 
     def memory_required(self, input_shape):
         if ldm.model_management.xformers_enabled() or ldm.model_management.pytorch_attention_flash_attention():
-            #TODO: this needs to be tweaked
+            # TODO: this needs to be tweaked
             area = input_shape[0] * input_shape[2] * input_shape[3]
             return (area * ldm.model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024)
         else:
-            #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
+            # TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
             area = input_shape[0] * input_shape[2] * input_shape[3]
-            return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
+            return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
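
Sanity check on the two memory_required() branches (my numbers, not from the commit): for a 1x4x64x64 latent in fp16 (dtype_size = 2), the flash/xformers branch budgets far less memory than the fallback:

area = 1 * 64 * 64                                        # batch * H * W
flash = (area * 2 / 50) * (1024 * 1024)                   # ~164 MiB
fallback = (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)  # ~3755 MiB
print(f"{flash / 2**20:.0f} MiB vs {fallback / 2**20:.0f} MiB")
# -> 164 MiB vs 3755 MiB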

scripts/ldm/model_management.py (+28, -18)

@@ -309,31 +309,41 @@ def model_load(self, lowvram_model_memory=0):
 
         if lowvram_model_memory > 0:
             print("loading in lowvram mode", lowvram_model_memory / (1024 * 1024))
-            device_map = accelerate.infer_auto_device_map(
-                self.real_model,
-                max_memory={
-                    0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)),
-                    "cpu": "16GiB",
-                },
-            )
-            accelerate.dispatch_model(
-                self.real_model, device_map=device_map, main_device=self.device
-            )
+            mem_counter = 0
+            for m in self.real_model.modules():
+                if hasattr(m, "comfy_cast_weights"):
+                    m.prev_comfy_cast_weights = m.comfy_cast_weights
+                    m.comfy_cast_weights = True
+                    module_mem = module_size(m)
+                    if mem_counter + module_mem < lowvram_model_memory:
+                        m.to(self.device)
+                        mem_counter += module_mem
+                elif hasattr(
+                    m, "weight"
+                ):  # only modules with comfy_cast_weights can be set to lowvram mode
+                    m.to(self.device)
+                    mem_counter += module_size(m)
+                    print("lowvram: loaded module regularly", m)
+
             self.model_accelerated = True
 
-        if is_intel_xpu() and not disable_ipex_optimize:
-            self.real_model = torch.xpu.optimize(
-                self.real_model.eval(),
-                inplace=True,
-                auto_kernel_selection=True,
-                graph_mode=True,
-            )
+        # if is_intel_xpu() and not args.disable_ipex_optimize:
+        #     self.real_model = torch.xpu.optimize(
+        #         self.real_model.eval(),
+        #         inplace=True,
+        #         auto_kernel_selection=True,
+        #         graph_mode=True,
+        #     )
 
         return self.real_model
 
     def model_unload(self):
         if self.model_accelerated:
-            accelerate.hooks.remove_hook_from_submodules(self.real_model)
+            for m in self.real_model.modules():
+                if hasattr(m, "prev_comfy_cast_weights"):
+                    m.comfy_cast_weights = m.prev_comfy_cast_weights
+                    del m.prev_comfy_cast_weights
+
         self.model_accelerated = False
 
         self.model.unpatch_model(self.model.offload_device)
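
Note on the lowvram rework: instead of accelerate's inferred device map, the loop above greedily moves modules to the GPU until module_size() sums past the lowvram budget; modules exposing comfy_cast_weights can stay behind and cast their weights at call time, while plain weighted modules are always loaded regularly. A sketch of the module_size() helper the loop depends on is below; the repo defines its own, and this version, which just sums the bytes of everything in the module's state dict, is an assumption about its shape.

import torch

def module_size(module: torch.nn.Module) -> int:
    # Total bytes of all parameters and buffers in the module's state dict.
    size = 0
    for tensor in module.state_dict().values():
        size += tensor.nelement() * tensor.element_size()
    return size

print(module_size(torch.nn.Linear(1024, 1024)))  # 4198400 bytes in fp32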
