diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7af113d --- /dev/null +++ b/.gitignore @@ -0,0 +1,162 @@ +*.safetensors + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..13ee2b0 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "nuxt.isNuxtApp": false +} \ No newline at end of file diff --git a/SystemCode/src/main/model/__init__.py b/SystemCode/src/main/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/SystemCode/src/main/model/attention.py b/SystemCode/src/main/model/attention.py new file mode 100644 index 0000000..0336039 --- /dev/null +++ b/SystemCode/src/main/model/attention.py @@ -0,0 +1,275 @@ +import math +from einops import rearrange +from torch import einsum + +import torch +import diffusers + +# Script adapted from https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py + +# Paper: https://arxiv.org/abs/2205.14135 + +# constants + +EPSILON = 1e-6 + + +# helper functions + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +# flash attention forwards and backwards +class FlashAttentionFunction(torch.autograd.Function): + @staticmethod + @torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """Algorithm 2 in the paper""" + + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) + + scale = q.shape[-1] ** -0.5 + + if not exists(mask): + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, "b n -> b 1 1 n") + mask = mask.split(q_bucket_size, dim=-1) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale + + if exists(row_mask): + attn_weights.masked_fill_(~row_mask, max_neg_value) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) + + if exists(row_mask): + exp_weights.masked_fill_(~row_mask, 0.0) + + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) + + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + + exp_values = einsum("... i j, ... j d -> ... 
i d", exp_weights, vc) + + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) + + new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums + + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) + + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) + + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + + return o + + @staticmethod + @torch.no_grad() + def backward(ctx, do): + """Algorithm 4 in the paper""" + + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors + + device = q.device + + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + exp_attn_weights = torch.exp(attn_weights - mc) + + if exists(row_mask): + exp_attn_weights.masked_fill_(~row_mask, 0.0) + + p = exp_attn_weights / lc + + dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc) + dp = einsum("... i d, ... j d -> ... i j", doc, vc) + + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) + + dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc) + dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc) + + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) + + return dq, dk, dv, None, None, None, None + + +""" +Replace UNet Module with faster, memory efficient transformers. +xformers works with Linux for now. Currently, unavailable for Windows. 
+""" +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): + if mem_eff_attn: + replace_unet_cross_attn_to_memory_efficient() + + # XFormers is recommended for Linux machines, there is no XFormers on conda for Windows + elif xformers: + replace_unet_cross_attn_to_xformers() + + +def replace_unet_cross_attn_to_memory_efficient(): + print("Replace CrossAttention.forward to use NAI style Hypernetwork and FlashAttention") + flash_func = FlashAttentionFunction + + def forward_flash_attn(self, x, context=None, mask=None): + q_bucket_size = 512 + k_bucket_size = 1024 + + h = self.heads + q = self.to_q(x) + + context = context if context is not None else x + context = context.to(x.dtype) + + if hasattr(self, "hypernetwork") and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context + + k = self.to_k(context_k) + v = self.to_v(context_v) + del context, x + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, "b h n d -> b n (h d)") + + # diffusers 0.7.0~ + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + + diffusers.models.attention.CrossAttention.forward = forward_flash_attn + + + +def replace_unet_cross_attn_to_xformers(): + print("Replace CrossAttention.forward to use xformers") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") + + def forward_xformers(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) + + context = default(context, x) + context = context.to(x.dtype) + + if hasattr(self, "hypernetwork") and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context + + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる + + out = rearrange(out, "b n h d -> b n (h d)", h=h) + + # diffusers 0.7.0~ + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + + diffusers.models.attention.CrossAttention.forward = forward_xformers \ No newline at end of file diff --git a/SystemCode/src/main/model/generate.py b/SystemCode/src/main/model/generate.py new file mode 100644 index 0000000..115a5ac --- /dev/null +++ b/SystemCode/src/main/model/generate.py @@ -0,0 +1,1409 @@ +from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable +import glob +import importlib +import time + +import argparse +import math +import os +import random +import re + +import diffusers +import numpy as np +import torch +import torchvision +from diffusers import ( + AutoencoderKL, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + StableDiffusionPipeline, +) + +from transformers import CLIPModel +import PIL +from PIL import Image +from PIL.PngImagePlugin import PngInfo + + +from attention import FlashAttentionFunction, replace_unet_modules +import model as model_util +import train as train_util +from pipeline import 
PipelineLike +import tools.original_control_net as original_control_net +from tools.original_control_net import ControlNetInfo + +TOKENIZER_PATH = "openai/clip-vit-large-patch14" +DEFAULT_TOKEN_LENGTH = 75 + +# scheduler: +SCHEDULER_LINEAR_START = 0.00085 +SCHEDULER_LINEAR_END = 0.0120 +SCHEDULER_TIMESTEPS = 1000 +SCHEDLER_SCHEDULE = "scaled_linear" +LATENT_CHANNELS = 4 +DOWNSAMPLING_FACTOR = 8 + +# CLIP +CLIP_MODEL_PATH = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" +FEATURE_EXTRACTOR_SIZE = (224, 224) +FEATURE_EXTRACTOR_IMAGE_MEAN = [0.48145466, 0.4578275, 0.40821073] +FEATURE_EXTRACTOR_IMAGE_STD = [0.26862954, 0.26130258, 0.27577711] + +VGG16_IMAGE_MEAN = [0.485, 0.456, 0.406] +VGG16_IMAGE_STD = [0.229, 0.224, 0.225] +VGG16_INPUT_RESIZE_DIV = 4 +NUM_CUTOUTS = 4 +USE_CUTOUTS = False + + + + + +class MakeCutouts(torch.nn.Module): + def __init__(self, cut_size, cut_power=1.0): + super().__init__() + + self.cut_size = cut_size + self.cut_power = cut_power + + def forward(self, pixel_values, num_cutouts): + sideY, sideX = pixel_values.shape[2:4] + max_size = min(sideX, sideY) + min_size = min(sideX, sideY, self.cut_size) + cutouts = [] + for _ in range(num_cutouts): + size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size) + offsetx = torch.randint(0, sideX - size + 1, ()) + offsety = torch.randint(0, sideY - size + 1, ()) + cutout = pixel_values[:, :, offsety : offsety + size, offsetx : offsetx + size] + cutouts.append(torch.nn.functional.adaptive_avg_pool2d(cutout, self.cut_size)) + return torch.cat(cutouts) + + +def spherical_dist_loss(x, y): + x = torch.nn.functional.normalize(x, dim=-1) + y = torch.nn.functional.normalize(y, dim=-1) + return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) + + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. 
+ Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: int, layer=None): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + No padding, starting or ending token is included. + """ + tokens = [] + weights = [] + truncated = False + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = pipe.tokenizer(word).input_ids[1:-1] + + token = pipe.replace_token(token, layer=layer) + + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + print("warning: Prompt was truncated. 
Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_unweighted_text_embeddings( + pipe: PipelineLike, + text_input: torch.Tensor, + chunk_length: int, + clip_skip: int, + eos: int, + pad: int, + no_boseos_middle: Optional[bool] = True, +): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. + """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + if pad == eos: # v1 + text_input_chunk[:, -1] = text_input[0, -1] + else: # v2 + for j in range(len(text_input_chunk)): + if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある + text_input_chunk[j, -1] = eos + if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD + text_input_chunk[j, 1] = eos + + if clip_skip is None or clip_skip == 1: + text_embedding = pipe.text_encoder(text_input_chunk)[0] + else: + enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) + text_embedding = enc_out["hidden_states"][-clip_skip] + text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeddings = torch.concat(text_embeddings, axis=1) + else: + if clip_skip is None or clip_skip == 1: + text_embeddings = pipe.text_encoder(text_input)[0] + else: + enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True) + text_embeddings = enc_out["hidden_states"][-clip_skip] + text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings) + return text_embeddings + + +def get_weighted_text_embeddings( + pipe: PipelineLike, + prompt: Union[str, List[str]], + uncond_prompt: Optional[Union[str, List[str]]] = None, + max_embeddings_multiples: Optional[int] = 1, + 
no_boseos_middle: Optional[bool] = False, + skip_parsing: Optional[bool] = False, + skip_weighting: Optional[bool] = False, + clip_skip=None, + layer=None, + **kwargs, +): + r""" + Prompts can be assigned with local weights using brackets. For example, + prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', + and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. + Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. + Args: + pipe (`DiffusionPipeline`): + Pipe to provide access to the tokenizer and the text encoder. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + uncond_prompt (`str` or `List[str]`): + The unconditional prompt or prompts for guide the image generation. If unconditional prompt + is provided, the embeddings of prompt and uncond_prompt are concatenated. + max_embeddings_multiples (`int`, *optional*, defaults to `1`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + no_boseos_middle (`bool`, *optional*, defaults to `False`): + If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and + ending token in each of the chunk in the middle. + skip_parsing (`bool`, *optional*, defaults to `False`): + Skip the parsing of brackets. + skip_weighting (`bool`, *optional*, defaults to `False`): + Skip the weighting. When the parsing is skipped, it is forced True. + """ + max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + if isinstance(prompt, str): + prompt = [prompt] + + # split the prompts with "AND". each prompt must have the same number of splits + new_prompts = [] + for p in prompt: + new_prompts.extend(p.split(" AND ")) + prompt = new_prompts + + if not skip_parsing: + prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2, layer=layer) + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2, layer=layer) + else: + prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids] + prompt_weights = [[1.0] * len(token) for token in prompt_tokens] + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens = [ + token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids + ] + uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + if uncond_prompt is not None: + max_length = max(max_length, max([len(token) for token in uncond_tokens])) + + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = pipe.tokenizer.bos_token_id + eos = pipe.tokenizer.eos_token_id + pad = pipe.tokenizer.pad_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + 
chunk_length=pipe.tokenizer.model_max_length, + ) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device) + if uncond_prompt is not None: + uncond_tokens, uncond_weights = pad_tokens_and_weights( + uncond_tokens, + uncond_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.tokenizer.model_max_length, + ) + uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) + + # get the embeddings + text_embeddings = get_unweighted_text_embeddings( + pipe, + prompt_tokens, + pipe.tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) + if uncond_prompt is not None: + uncond_embeddings = get_unweighted_text_embeddings( + pipe, + uncond_tokens, + pipe.tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) + + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? + # →全体でいいんじゃないかな + if (not skip_parsing) and (not skip_weighting): + previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + if uncond_prompt is not None: + return text_embeddings, uncond_embeddings, prompt_tokens + return text_embeddings, None, prompt_tokens + + +def preprocess_guide_image(image): + image = image.resize(FEATURE_EXTRACTOR_SIZE, resample=Image.NEAREST) # Combine with cond_fn + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) # nchw + image = torch.from_numpy(image) + return image # 0 to 1 + + +# VGG16 input can be any size, so resize the input image appropriately +def preprocess_vgg16_guide_image(image, size): + image = image.resize(size, resample=Image.NEAREST) # Combine with cond_fn + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) # nchw + image = torch.from_numpy(image) + return image # 0 to 1 + + +def preprocess_image(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask): + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? 
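+    # mask[None] adds a batch dimension, giving shape (1, 4, h//8, w//8) to match the latents;
+    # transpose(0, 1, 2, 3) is an identity permutation and leaves the array unchanged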
+ mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + +class BatchDataBase(NamedTuple): + step: int + prompt: str + seed: int + init_image: Any + mask_image: Any + clip_prompt: str + guide_image: Any + + +class BatchDataExt(NamedTuple): + # Data requiring batch splitting + width: int + height: int + steps: int + scale: float + negative_scale: float + strength: float + network_muls: Tuple[float] + num_sub_prompts: int + + +class BatchData(NamedTuple): + return_latents: bool + base: BatchDataBase + ext: BatchDataExt + + +def main(args): + if args.fp16: + dtype = torch.float16 + elif args.bf16: + dtype = torch.bfloat16 + else: + dtype = torch.float32 + + if not os.path.isfile(args.ckpt): + files = glob.glob(args.ckpt) + if len(files) == 1: + args.ckpt = files[0] + + use_stable_diffusion_format = os.path.isfile(args.ckpt) + if use_stable_diffusion_format: + print("load StableDiffusion checkpoint") + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) + else: + print("load Diffusers pretrained models") + loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) + text_encoder = loading_pipe.text_encoder + vae = loading_pipe.vae + unet = loading_pipe.unet + tokenizer = loading_pipe.tokenizer + del loading_pipe + + if args.vae is not None: + vae = model_util.load_vae(args.vae, dtype) + print("additional VAE loaded") + + if args.clip_guidance_scale > 0.0 or args.clip_image_guidance_scale: + print("prepare clip model") + clip_model = CLIPModel.from_pretrained(CLIP_MODEL_PATH, torch_dtype=dtype) + else: + clip_model = None + + if args.vgg16_guidance_scale > 0.0: + print("prepare resnet model") + vgg16_model = torchvision.models.vgg16(torchvision.models.VGG16_Weights.IMAGENET1K_V1) + else: + vgg16_model = None + + # xformers、Hypernetwork (for Linux, not Windows) + if not args.diffusers_xformers: + replace_unet_modules(unet, not args.xformers, args.xformers) + + print("loading tokenizer") + if use_stable_diffusion_format: + tokenizer = train_util.load_tokenizer(args) + + sched_init_args = {} + scheduler_num_noises_per_step = 1 + if args.sampler == "euler" or args.sampler == "k_euler": + scheduler_cls = EulerDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_euler_discrete + elif args.sampler == "euler_a" or args.sampler == "k_euler_a": + scheduler_cls = EulerAncestralDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete + else: + scheduler_cls = EulerAncestralDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete + + if args.v_parameterization: + sched_init_args["prediction_type"] = "v_prediction" + + # replace randn + class NoiseManager: + def __init__(self): + self.sampler_noises = None + self.sampler_noise_index = 0 + + def reset_sampler_noises(self, noises): + self.sampler_noise_index = 0 + self.sampler_noises = noises + + def randn(self, shape, device=None, dtype=None, layout=None, generator=None): + # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) + if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): + noise = self.sampler_noises[self.sampler_noise_index] + if shape != noise.shape: + noise = None + else: + noise = None + + if noise == None: + print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") + noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) 
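+            # each call consumes the next pre-generated per-step noise (set via reset_sampler_noises),
+            # so sampling stays reproducible regardless of a sample's position in the batch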
+ + self.sampler_noise_index += 1 + return noise + + class TorchRandReplacer: + def __init__(self, noise_manager): + self.noise_manager = noise_manager + + def __getattr__(self, item): + if item == "randn": + return self.noise_manager.randn + if hasattr(torch, item): + return getattr(torch, item) + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) + + noise_manager = NoiseManager() + if scheduler_module is not None: + scheduler_module.torch = TorchRandReplacer(noise_manager) + + scheduler = scheduler_cls( + num_train_timesteps=SCHEDULER_TIMESTEPS, + beta_start=SCHEDULER_LINEAR_START, + beta_end=SCHEDULER_LINEAR_END, + beta_schedule=SCHEDLER_SCHEDULE, + **sched_init_args, + ) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: + print("set clip_sample to True") + scheduler.config.clip_sample = True + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # copy custom pipeline + vae.to(dtype).to(device) + text_encoder.to(dtype).to(device) + unet.to(dtype).to(device) + if clip_model is not None: + clip_model.to(dtype).to(device) + if vgg16_model is not None: + vgg16_model.to(dtype).to(device) + + # import network modules + if args.network_module: + networks = [] + network_default_muls = [] + for i, network_module in enumerate(args.network_module): + print("import network module:", network_module) + imported_module = importlib.import_module(network_module) + + network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] + network_default_muls.append(network_mul) + + net_kwargs = {} + if args.network_args and i < len(args.network_args): + network_args = args.network_args[i] + network_args = network_args.split(";") + for net_arg in network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value + + if args.network_weights and i < len(args.network_weights): + network_weight = args.network_weights[i] + print("load network weights from:", network_weight) + + if model_util.is_safetensors(network_weight) and args.network_show_meta: + from safetensors.torch import safe_open + + with safe_open(network_weight, framework="pt") as f: + metadata = f.metadata() + if metadata is not None: + print(f"metadata for: {network_weight}: {metadata}") + + network, weights_sd = imported_module.create_network_from_weights( + network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs + ) + else: + raise ValueError("No weight. Weight is required.") + if network is None: + return + + mergeable = hasattr(network, "merge_to") + if args.network_merge and not mergeable: + print("network is not mergeable. 
ignore merge option.") + + if not args.network_merge or not mergeable: + network.apply_to(text_encoder, unet) + info = network.load_state_dict(weights_sd, False) + print(f"weights are loaded: {info}") + + if args.opt_channels_last: + network.to(memory_format=torch.channels_last) + network.to(dtype).to(device) + + networks.append(network) + else: + network.merge_to(text_encoder, unet, weights_sd, dtype, device) + + else: + networks = [] + + # ControlNet + control_nets: List[ControlNetInfo] = [] + if args.control_net_models: + for i, model in enumerate(args.control_net_models): + prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] + weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) + prep = original_control_net.load_preprocess(prep_type) + control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + + if args.opt_channels_last: + print(f"set optimizing: channels last") + text_encoder.to(memory_format=torch.channels_last) + vae.to(memory_format=torch.channels_last) + unet.to(memory_format=torch.channels_last) + if clip_model is not None: + clip_model.to(memory_format=torch.channels_last) + if networks: + for network in networks: + network.to(memory_format=torch.channels_last) + if vgg16_model is not None: + vgg16_model.to(memory_format=torch.channels_last) + + for cn in control_nets: + cn.unet.to(memory_format=torch.channels_last) + cn.net.to(memory_format=torch.channels_last) + + pipe = PipelineLike( + device, + vae, + text_encoder, + tokenizer, + unet, + scheduler, + args.clip_skip, + clip_model, + args.clip_guidance_scale, + args.clip_image_guidance_scale, + vgg16_model, + args.vgg16_guidance_scale, + args.vgg16_guidance_layer, + ) + pipe.set_control_nets(control_nets) + print("pipeline is ready.") + + if args.diffusers_xformers: + pipe.enable_xformers_memory_efficient_attention() + + if args.textual_inversion_embeddings: + token_ids_embeds = [] + for embeds_file in args.textual_inversion_embeddings: + if model_util.is_safetensors(embeds_file): + from safetensors.torch import load_file + + data = load_file(embeds_file) + else: + data = torch.load(embeds_file, map_location="cpu") + + if "string_to_param" in data: + data = data["string_to_param"] + embeds = next(iter(data.values())) + + if type(embeds) != torch.Tensor: + raise ValueError(f"weight file does not contains Tensor") + + num_vectors_per_token = embeds.size()[0] + token_string = os.path.splitext(os.path.basename(embeds_file))[0] + token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] + + # add new word to tokenizer, count is num_vectors_per_token + num_added_tokens = tokenizer.add_tokens(token_strings) + assert ( + num_added_tokens == num_vectors_per_token + ), f"tokenizer has same word to token string (filename). please rename the file" + + token_ids = tokenizer.convert_tokens_to_ids(token_strings) + print(f"Textual Inversion embeddings `{token_string}` loaded. 
Tokens are added: {token_ids}") + assert ( + min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1 + ), f"token ids is not ordered" + assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" + + if num_vectors_per_token > 1: + pipe.add_token_replacement(token_ids[0], token_ids) + + token_ids_embeds.append((token_ids, embeds)) + + text_encoder.resize_token_embeddings(len(tokenizer)) + token_embeds = text_encoder.get_input_embeddings().weight.data + for token_ids, embeds in token_ids_embeds: + for token_id, embed in zip(token_ids, embeds): + token_embeds[token_id] = embed + + if args.XTI_embeddings: + XTI_layers = [ + "IN01", + "IN02", + "IN04", + "IN05", + "IN07", + "IN08", + "MID", + "OUT03", + "OUT04", + "OUT05", + "OUT06", + "OUT07", + "OUT08", + "OUT09", + "OUT10", + "OUT11", + ] + token_ids_embeds_XTI = [] + for embeds_file in args.XTI_embeddings: + if model_util.is_safetensors(embeds_file): + from safetensors.torch import load_file + + data = load_file(embeds_file) + else: + data = torch.load(embeds_file, map_location="cpu") + if set(data.keys()) != set(XTI_layers): + raise ValueError("NOT XTI") + embeds = torch.concat(list(data.values())) + num_vectors_per_token = data["MID"].size()[0] + + token_string = os.path.splitext(os.path.basename(embeds_file))[0] + token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] + + # add new word to tokenizer, count is num_vectors_per_token + num_added_tokens = tokenizer.add_tokens(token_strings) + assert ( + num_added_tokens == num_vectors_per_token + ), f"tokenizer has same word to token string (filename). please rename the file" + + token_ids = tokenizer.convert_tokens_to_ids(token_strings) + print(f"XTI embeddings `{token_string}` loaded. 
Tokens are added: {token_ids}") + + # if num_vectors_per_token > 1: + pipe.add_token_replacement(token_ids[0], token_ids) + + token_strings_XTI = [] + for layer_name in XTI_layers: + token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings] + tokenizer.add_tokens(token_strings_XTI) + token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI) + token_ids_embeds_XTI.append((token_ids_XTI, embeds)) + for t in token_ids: + t_XTI_dic = {} + for i, layer_name in enumerate(XTI_layers): + t_XTI_dic[layer_name] = t + (i + 1) * num_added_tokens + pipe.add_token_replacement_XTI(t, t_XTI_dic) + + text_encoder.resize_token_embeddings(len(tokenizer)) + token_embeds = text_encoder.get_input_embeddings().weight.data + for token_ids, embeds in token_ids_embeds_XTI: + for token_id, embed in zip(token_ids, embeds): + token_embeds[token_id] = embed + + # promptを取得する + if args.from_file is not None: + print(f"reading prompts from {args.from_file}") + with open(args.from_file, "r", encoding="utf-8") as f: + prompt_list = f.read().splitlines() + prompt_list = [d for d in prompt_list if len(d.strip()) > 0] + elif args.prompt is not None: + prompt_list = [args.prompt] + else: + prompt_list = [] + + if args.interactive: + args.n_iter = 1 + + # img2imgの前処理、画像の読み込みなど + def load_images(path): + if os.path.isfile(path): + paths = [path] + else: + paths = ( + glob.glob(os.path.join(path, "*.png")) + + glob.glob(os.path.join(path, "*.jpg")) + + glob.glob(os.path.join(path, "*.jpeg")) + + glob.glob(os.path.join(path, "*.webp")) + ) + paths.sort() + + images = [] + for p in paths: + image = Image.open(p) + if image.mode != "RGB": + print(f"convert image to RGB from {image.mode}: {p}") + image = image.convert("RGB") + images.append(image) + + return images + + def resize_images(imgs, size): + resized = [] + for img in imgs: + r_img = img.resize(size, Image.Resampling.LANCZOS) + if hasattr(img, "filename"): # filename属性がない場合があるらしい + r_img.filename = img.filename + resized.append(r_img) + return resized + + if args.image_path is not None: + print(f"load image for img2img: {args.image_path}") + init_images = load_images(args.image_path) + assert len(init_images) > 0, f"No image" + print(f"loaded {len(init_images)} images for img2img") + else: + init_images = None + + if args.mask_path is not None: + print(f"load mask for inpainting: {args.mask_path}") + mask_images = load_images(args.mask_path) + assert len(mask_images) > 0, f"No mask image" + print(f"loaded {len(mask_images)} mask images for inpainting") + else: + mask_images = None + + # promptがないとき、画像のPngInfoから取得する + if init_images is not None and len(prompt_list) == 0 and not args.interactive: + print("get prompts from images' meta data") + for img in init_images: + if "prompt" in img.text: + prompt = img.text["prompt"] + if "negative-prompt" in img.text: + prompt += " --n " + img.text["negative-prompt"] + prompt_list.append(prompt) + + # プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する) + l = [] + for im in init_images: + l.extend([im] * args.images_per_prompt) + init_images = l + + if mask_images is not None: + l = [] + for im in mask_images: + l.extend([im] * args.images_per_prompt) + mask_images = l + + # 画像サイズにオプション指定があるときはリサイズする + if args.W is not None and args.H is not None: + if init_images is not None: + print(f"resize img2img source images to {args.W}*{args.H}") + init_images = resize_images(init_images, (args.W, args.H)) + if mask_images is not None: + print(f"resize img2img mask images to {args.W}*{args.H}") + mask_images = resize_images(mask_images, (args.W, 
args.H)) + + regional_network = False + if networks and mask_images: + # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 + regional_network = True + print("use mask as region") + + size = None + for i, network in enumerate(networks): + if i < 3: + np_mask = np.array(mask_images[0]) + np_mask = np_mask[:, :, i] + size = np_mask.shape + else: + np_mask = np.full(size, 255, dtype=np.uint8) + mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0) + network.set_region(i, i == len(networks) - 1, mask) + mask_images = None + + prev_image = None # for VGG16 guided + if args.guide_image_path is not None: + print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}") + guide_images = [] + for p in args.guide_image_path: + guide_images.extend(load_images(p)) + + print(f"loaded {len(guide_images)} guide images for guidance") + if len(guide_images) == 0: + print(f"No guide image, use previous generated image.") + guide_images = None + else: + guide_images = None + + # seed指定時はseedを決めておく + if args.seed is not None: + random.seed(args.seed) + predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] + if len(predefined_seeds) == 1: + predefined_seeds[0] = args.seed + else: + predefined_seeds = None + + if args.W is None: + args.W = 512 + if args.H is None: + args.H = 512 + + os.makedirs(args.outdir, exist_ok=True) + max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples + + for gen_iter in range(args.n_iter): + print(f"iteration {gen_iter+1}/{args.n_iter}") + iter_seed = random.randint(0, 0x7FFFFFFF) + + def process_batch(batch: List[BatchData]): + batch_size = len(batch) + ( + return_latents, + (step_first, _, _, _, init_image, mask_image, _, guide_image), + (width, height, steps, scale, negative_scale, strength, network_muls, num_sub_prompts), + ) = batch[0] + noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) + + prompts = [] + start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + noises = [ + torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + for _ in range(steps * scheduler_num_noises_per_step) + ] + seeds = [] + clip_prompts = [] + + # Generate a random number here to use the same random number regardless of the position in the batch. Also check if image/mask is the same in batch. 
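+        # if every item in the batch turns out to share the same init/mask/guide image,
+        # only a single copy is handed to the pipeline below so per-item preprocessing is skipped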
+ all_images_are_same = True + all_masks_are_same = True + all_guide_images_are_same = True + for i, (_, (_, prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): + prompts.append(prompt) + seeds.append(seed) + clip_prompts.append(clip_prompt) + + if init_image is not None: + init_images.append(init_image) + if i > 0 and all_images_are_same: + all_images_are_same = init_images[-2] is init_image + + if mask_image is not None: + mask_images.append(mask_image) + if i > 0 and all_masks_are_same: + all_masks_are_same = mask_images[-2] is mask_image + + if guide_image is not None: + if type(guide_image) is list: + guide_images.extend(guide_image) + all_guide_images_are_same = False + else: + guide_images.append(guide_image) + if i > 0 and all_guide_images_are_same: + all_guide_images_are_same = guide_images[-2] is guide_image + + # make start code + torch.manual_seed(seed) + start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + # make each noises + for j in range(steps * scheduler_num_noises_per_step): + noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype) + + noise_manager.reset_sampler_noises(noises) + + # すべての画像が同じなら1枚だけpipeに渡すことでpipe側で処理を高速化する + if init_images is not None and all_images_are_same: + init_images = init_images[0] + if mask_images is not None and all_masks_are_same: + mask_images = mask_images[0] + if guide_images is not None and all_guide_images_are_same: + guide_images = guide_images[0] + + # ControlNet使用時はguide imageをリサイズする + if control_nets: + # TODO resampleのメソッド + guide_images = guide_images if type(guide_images) == list else [guide_images] + guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images] + if len(guide_images) == 1: + guide_images = guide_images[0] + + # generate + if networks: + shared = {} + for n, m in zip(networks, network_muls if network_muls else network_default_muls): + n.set_multiplier(m) + if regional_network: + n.set_current_generation(batch_size, num_sub_prompts, width, height, shared) + + images = pipe( + prompts, + negative_prompts, + init_images, + mask_images, + height, + width, + steps, + scale, + negative_scale, + strength, + latents=start_code, + output_type="pil", + max_embeddings_multiples=max_embeddings_multiples, + vae_batch_size=args.vae_batch_size, + return_latents=return_latents, + clip_prompts=clip_prompts, + clip_guide_images=guide_images, + )[0] + # if highres_1st and not args.highres_fix_save_1st: # return images or latents + # return images + + # save image + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts) + ): + metadata = PngInfo() + metadata.add_text("prompt", prompt) + metadata.add_text("seed", str(seed)) + metadata.add_text("sampler", args.sampler) + metadata.add_text("steps", str(steps)) + metadata.add_text("scale", str(scale)) + if negative_scale is not None: + metadata.add_text("negative-scale", str(negative_scale)) + if clip_prompt is not None: + metadata.add_text("clip-prompt", clip_prompt) + + if args.use_original_file_name and init_images is not None: + if type(init_images) is list: + fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" + else: + fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" + elif args.sequential_file_name: + fln = f"im_{step_first + i + 1:06d}.png" + else: + fln = 
f"im_{ts_str}_{i:03d}_{seed}.png" + + image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + + if not args.no_preview and args.interactive: + try: + import cv2 + + for prompt, image in zip(prompts, images): + cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) + cv2.waitKey() + cv2.destroyAllWindows() + except ImportError: + print("opencv-python is not installed, cannot preview") + + return images + + # Loop of image generation prompts + prompt_index = 0 + global_step = 0 + batch_data = [] + while args.interactive or prompt_index < len(prompt_list): + if len(prompt_list) == 0: + # interactive + valid = False + while not valid: + print("\nType prompt:") + try: + prompt = input() + except EOFError: + break + + valid = len(prompt.strip().split(" --")[0].strip()) > 0 + if not valid: # EOF, end app + break + else: + prompt = prompt_list[prompt_index] + + prompt_index += 1 + + if len(batch_data) > 0: + process_batch(batch_data) + batch_data.clear() + + print("done!") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model") + parser.add_argument( + "--v_parameterization", action="store_true", help="enable v-parameterization training" + ) + parser.add_argument("--prompt", type=str, default=None, help="prompt") + parser.add_argument( + "--from_file", type=str, default=None, help="if specified, load prompts from this file" + ) + parser.add_argument( + "--interactive", action="store_true", help="interactive mode (generates one image)" + ) + parser.add_argument( + "--no_preview", action="store_true", help="do not show generated image in interactive mode" + ) + parser.add_argument( + "--image_path", type=str, default=None, help="image to inpaint or to generate from" + ) + parser.add_argument("--mask_path", type=str, default=None, help="mask in inpainting") + parser.add_argument("--strength", type=float, default=None, help="img2img strength") + parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt") + parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to") + parser.add_argument("--sequential_file_name", action="store_true", help="sequential output file name") + parser.add_argument( + "--use_original_file_name", + action="store_true", + help="prepend original file name in img2img", + ) + # parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling", ) + parser.add_argument("--n_iter", type=int, default=1, help="sample this often") + parser.add_argument("--H", type=int, default=None, help="image height, in pixel space") + parser.add_argument("--W", type=int, default=None, help="image width, in pixel space") + parser.add_argument("--batch_size", type=int, default=1, help="batch size") + parser.add_argument( + "--vae_batch_size", + type=float, + default=None, + help="batch size for VAE, < 1.0 for ratio", + ) + parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps") + parser.add_argument( + "--sampler", + type=str, + default="ddim", + choices=[ + "ddim", + "pndm", + "lms", + "euler", + "euler_a", + "heun", + "dpm_2", + "dpm_2_a", + "dpmsolver", + "dpmsolver++", + "dpmsingle", + "k_lms", + "k_euler", + "k_euler_a", + "k_dpm_2", + "k_dpm_2_a", + ], + help=f"sampler (scheduler) type", + ) + parser.add_argument( + "--scale", + type=float, + default=7.5, + help="unconditional guidance scale: eps = eps(x, 
empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model") + parser.add_argument( + "--vae", type=str, default=None, help="path to checkpoint of vae to replace" + ) + parser.add_argument( + "--tokenizer_cache_dir", + type=str, + default=None, + help="directory for caching Tokenizer (for offline training)", + ) + # parser.add_argument("--replace_clip_l14_336", action='store_true', + # help="Replace CLIP (Text Encoder) to l/14@336") + parser.add_argument( + "--seed", + type=int, + default=None, + help="seed, or seed of seeds in multiple generation", + ) + parser.add_argument( + "--iter_same_seed", + action="store_true", + help="use same seed for all prompts in iteration if no seed specified", + ) + parser.add_argument("--fp16", action="store_true", help="use fp16") + parser.add_argument("--bf16", action="store_true", help="use bfloat16") + parser.add_argument("--xformers", action="store_true", help="use xformers") + parser.add_argument( + "--diffusers_xformers", + action="store_true", + help="use xformers by diffusers (Hypernetworks doesn't work)", + ) + parser.add_argument( + "--opt_channels_last", action="store_true", help="set channels last option to model" + ) + parser.add_argument( + "--network_module", type=str, default=None, nargs="*", help="additional network module to use" + ) + parser.add_argument( + "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load" + ) + parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier") + parser.add_argument( + "--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value)" + ) + parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model") + parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model") + parser.add_argument( + "--textual_inversion_embeddings", + type=str, + default=None, + nargs="*", + help="Embeddings files of Textual Inversion", + ) + parser.add_argument( + "--XTI_embeddings", + type=str, + default=None, + nargs="*", + help="Embeddings files of Extended Textual Inversion", + ) + parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP") + parser.add_argument( + "--max_embeddings_multiples", + type=int, + default=None, + help="max embeding multiples, max token length is 75 * multiples", + ) + parser.add_argument( + "--clip_guidance_scale", + type=float, + default=0.0, + help="enable CLIP guided SD, scale for guidance (DDIM, PNDM, LMS samplers only)", + ) + parser.add_argument( + "--clip_image_guidance_scale", + type=float, + default=0.0, + help="enable CLIP guided SD by image, scale for guidance", + ) + parser.add_argument( + "--vgg16_guidance_scale", + type=float, + default=0.0, + help="enable VGG16 guided SD by image, scale for guidance", + ) + parser.add_argument( + "--vgg16_guidance_layer", + type=int, + default=20, + help="layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2)", + ) + parser.add_argument( + "--guide_image_path", type=str, default=None, nargs="*", help="image to CLIP guidance" + ) + parser.add_argument( + "--highres_fix_scale", + type=float, + default=None, + help="enable highres fix, reso scale for 1st stage", + ) + parser.add_argument( + "--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix" + ) + 
parser.add_argument( + "--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix" + ) + parser.add_argument( + "--highres_fix_latents_upscaling", + action="store_true", + help="use latents upscaling for highres fix", + ) + parser.add_argument( + "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt" + ) + + parser.add_argument( + "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use" + ) + parser.add_argument( + "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use" + ) + parser.add_argument("--control_net_weights", type=float, default=None, nargs="*", help="ControlNet weights") + parser.add_argument( + "--control_net_ratios", + type=float, + default=None, + nargs="*", + help="ControlNet guidance ratio for steps", + ) + # parser.add_argument( + # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance" + # ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + main(args) diff --git a/SystemCode/src/main/model/model.py b/SystemCode/src/main/model/model.py new file mode 100644 index 0000000..9a8a2b1 --- /dev/null +++ b/SystemCode/src/main/model/model.py @@ -0,0 +1,306 @@ +import math +import os +import torch +from transformers import CLIPTextModel, CLIPTokenizer, logging +from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from utils.conversions import convert_ldm_vae_checkpoint, convert_ldm_clip_checkpoint_v1, convert_vae_state_dict +from safetensors.torch import load_file, save_file + +# Utils +from utils.conversions import convert_unet_state_dict_to_sd, convert_text_encoder_key, convert_ldm_unet_checkpoint +from utils.common import get_file_type + +UNET_PARAMS_MODEL_CHANNELS = 320 +UNET_PARAMS_CHANNEL_MULTI = [1, 2, 4, 4] +UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1] +UNET_PARAMS_IMAGE_SIZE = 64 +UNET_PARAMS_IN_CHANNELS = 4 +UNET_PARAMS_OUT_CHANNELS = 4 +UNET_PARAMS_NUM_RES_BLOCKS = 2 +UNET_PARAMS_CONTEXT_DIM = 768 +UNET_PARAMS_NUM_HEADS = 8 + +VAE_PARAMS_Z_CHANNELS = 4 +VAE_PARAMS_RESOLUTION = 256 +VAE_PARAMS_IN_CHANNELS = 3 +VAE_PARAMS_OUT_CH = 3 +VAE_PARAMS_CH = 128 +VAE_PARAMS_CH_MULT = [1, 2, 4, 4] +VAE_PARAMS_NUM_RES_BLOCKS = 2 + +MODEL_PATH = "checkpoints/v1-5-pruned-emaonly.safetensors" + + +def create_unet_diffusers_config(): + """ + Creates a config for the diffusers based on the config of the LDM model. + Uses diffusers' UNet2DModel. 
Documentation: https://huggingface.co/docs/diffusers/api/models/unet2d + """ + + block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULTI] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + config = dict( + sample_size=UNET_PARAMS_IMAGE_SIZE, + in_channels=UNET_PARAMS_IN_CHANNELS, + out_channels=UNET_PARAMS_OUT_CHANNELS, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, + cross_attention_dim=UNET_PARAMS_CONTEXT_DIM, + attention_head_dim=UNET_PARAMS_NUM_HEADS, + ) + + return config + + +def create_vae_diffusers_config(): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = dict( + sample_size=VAE_PARAMS_RESOLUTION, + in_channels=VAE_PARAMS_IN_CHANNELS, + out_channels=VAE_PARAMS_OUT_CH, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + latent_channels=VAE_PARAMS_Z_CHANNELS, + layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, + ) + return config + + +def is_safetensors(path): + return get_file_type(path) == ".safetensors" + +def load_checkpoint(checkpoint_path, device): + if is_safetensors(checkpoint_path): + checkpoint = None + state_dict = load_file(checkpoint_path) + else: + checkpoint = torch.load(checkpoint_path, map_location=device) + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + checkpoint = None + return checkpoint, state_dict + + + +def load_models_from_stable_diffusion_checkpoint(checkpoint_path, device="cpu", dtype=None): + _, state_dict = load_checkpoint(checkpoint_path, device) + state_dict = convert_text_encoder_key(state_dict) + + # UNet model. + unet_config = create_unet_diffusers_config() + converted_unet_checkpoint = convert_ldm_unet_checkpoint(state_dict, unet_config) + + unet = UNet2DConditionModel(**unet_config).to(device) + info = unet.load_state_dict(converted_unet_checkpoint) + print("load u-net:", info) + + # VAE model. 
+ vae_config = create_vae_diffusers_config() + converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config) + + vae = AutoencoderKL(**vae_config).to(device) + info = vae.load_state_dict(converted_vae_checkpoint) + print("load vae:", info) + + converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) + + text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device) + logging.set_verbosity_warning() + + info = text_model.load_state_dict(converted_text_encoder_checkpoint) + print("load text encoder:", info) + + return text_model, vae, unet + + + +def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, checkpoint_path, epochs, steps, save_dtype=None, vae=None): + if checkpoint_path is not None: + checkpoint, state_dict = load_checkpoint(checkpoint_path) + state_dict = convert_text_encoder_key(state_dict) + + if checkpoint is None: + checkpoint = {} + strict = False + else: + strict = True + if "state_dict" in state_dict: + del state_dict["state_dict"] + else: + assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint" + checkpoint = {} + state_dict = {} + strict = False + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + assert not strict or key in state_dict, f"Illegal key in save SD: {key}" + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + # Reconvert the checkpoint from diffusers format to Stable Diffusion LDM format + # Convert UNet + unet_state_dict = convert_unet_state_dict_to_sd(unet.state_dict()) + update_sd("model.diffusion_model.", unet_state_dict) + + text_enc_dict = text_encoder.state_dict() + update_sd("cond_stage_model.transformer.", text_enc_dict) + + # Convert VAE + if vae is not None: + vae_dict = convert_vae_state_dict(vae.state_dict()) + update_sd("first_stage_model.", vae_dict) + + # Create new checkpoint + key_count = len(state_dict.keys()) + new_ckpt = {"state_dict": state_dict} + + try: + if "epoch" in checkpoint: + epochs += checkpoint["epoch"] + if "global_step" in checkpoint: + steps += checkpoint["global_step"] + except: + pass + + new_ckpt["epoch"] = epochs + new_ckpt["global_step"] = steps + + if is_safetensors(output_file): + save_file(state_dict, output_file) + else: + torch.save(new_ckpt, output_file) + + return key_count + + +def save_diffusers_checkpoint(output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False): + if pretrained_model_name_or_path is None: + pretrained_model_name_or_path = MODEL_PATH + + scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer") + if vae is None: + vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") + + pipeline = StableDiffusionPipeline( + unet=unet, + text_encoder=text_encoder, + vae=vae, + scheduler=scheduler, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=None, + ) + pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) + + +VAE_PREFIX = "first_stage_model." 
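# A minimal round-trip sketch of the helpers above, relying on this module's existing imports.
# The output path, dtype, and epoch/step counters are placeholder assumptions, not part of the
# original code; in practice this would sit under the __main__ guard rather than at module level.
text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(MODEL_PATH, device="cpu")
save_stable_diffusion_checkpoint(
    "checkpoints/roundtrip.safetensors",  # placeholder output file (.safetensors -> saved via save_file)
    text_encoder,
    unet,
    checkpoint_path=None,  # no base checkpoint, so the VAE must be supplied explicitly
    epochs=0,
    steps=0,
    save_dtype=torch.float16,
    vae=vae,
)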
+ + +def load_vae(vae_id, dtype): + print(f"load VAE: {vae_id}") + if os.path.isdir(vae_id) or not os.path.isfile(vae_id): + # Diffusers local/remote + try: + vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype) + except EnvironmentError as e: + print(f"exception occurs in loading vae: {e}") + print("retry with subfolder='vae'") + vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype) + return vae + + # local + vae_config = create_vae_diffusers_config() + + if vae_id.endswith(".bin"): + # SD 1.5 VAE on Huggingface + converted_vae_checkpoint = torch.load(vae_id, map_location="cpu") + else: + # StableDiffusion + vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu") + vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model + + # vae only or full model + full_model = False + for vae_key in vae_sd: + if vae_key.startswith(VAE_PREFIX): + full_model = True + break + if not full_model: + sd = {} + for key, value in vae_sd.items(): + sd[VAE_PREFIX + key] = value + vae_sd = sd + del sd + + # Convert the VAE model. + converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + return vae + + +def make_bucket_resolutions(max_reso, min_size=256, max_size=512, divisible=64): + max_width, max_height = max_reso + max_area = (max_width // divisible) * (max_height // divisible) + + resos = set() + + size = int(math.sqrt(max_area)) * divisible + resos.add((size, size)) + + size = min_size + while size <= max_size: + width = size + height = min(max_size, (max_area // (width // divisible)) * divisible) + resos.add((width, height)) + resos.add((height, width)) + size += divisible + + resos = list(resos) + resos.sort() + return resos + + +if __name__ == "__main__": + resos = make_bucket_resolutions((512, 512)) + print(len(resos)) + print(resos) + aspect_ratios = [w / h for w, h in resos] + print(aspect_ratios) + + ars = set() + for ar in aspect_ratios: + if ar in ars: + print("error! 
duplicate ar:", ar) + ars.add(ar) diff --git a/SystemCode/src/main/model/networks/lora.py b/SystemCode/src/main/model/networks/lora.py new file mode 100644 index 0000000..227b36a --- /dev/null +++ b/SystemCode/src/main/model/networks/lora.py @@ -0,0 +1,938 @@ +# LoRA +# Adapted from: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py +# https://github.com/Linaqruf/kohya-trainer/blob/main/networks/lora.py + +# Guided by explanation in this presentation: https://www.youtube.com/watch?v=PXWYUTMt-AU + +import math +import os +from typing import List, Tuple, Union +import numpy as np +import torch +import re + + +RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original LoraInjectedLinear + """ + + def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + self.lora_dim = lora_dim + + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) + + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + + self.multiplier = multiplier + self.org_module = org_module # remove in applying + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def merge_to(self, sd, dtype, device): + # get up/down weight + up_weight = sd["lora_up.weight"].to(torch.float).to(device) + down_weight = sd["lora_down.weight"].to(torch.float).to(device) + + # extract weight from original module, org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"].to(torch.float) + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # print(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + + def set_region(self, region): + 
self.region = region + self.region_mask = None + + def forward(self, x): + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + + +class LoRAInferenceModule(LoRAModule): + def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1): + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) + + # check regional or not by lora_name + self.text_encoder = False + if lora_name.startswith("lora_te_"): + self.regional = False + self.use_sub_prompt = True + self.text_encoder = True + elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name: + self.regional = False + self.use_sub_prompt = True + elif "time_emb" in lora_name: + self.regional = False + self.use_sub_prompt = False + else: + self.regional = True + self.use_sub_prompt = False + + self.network: LoRANetwork = None + + def set_network(self, network): + self.network = network + + def default_forward(self, x): + # print("default_forward", self.lora_name, x.size()) + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + + def forward(self, x): + if self.network is None or self.network.sub_prompt_index is None: + return self.default_forward(x) + if not self.regional and not self.use_sub_prompt: + return self.default_forward(x) + + if self.regional: + return self.regional_forward(x) + else: + return self.sub_prompt_forward(x) + + def get_mask_for_x(self, x): + # calculate size from shape of x + if len(x.size()) == 4: + h, w = x.size()[2:4] + area = h * w + else: + area = x.size()[1] + + mask = self.network.mask_dic[area] + if mask is None: + raise ValueError(f"mask is None for resolution {area}") + if len(x.size()) != 4: + mask = torch.reshape(mask, (1, -1, 1)) + return mask + + def regional_forward(self, x): + if "attn2_to_out" in self.lora_name: + return self.to_out_forward(x) + + if self.network.mask_dic is None: # sub_prompt_index >= 3 + return self.default_forward(x) + + # apply mask for LoRA result + lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + mask = self.get_mask_for_x(lx) + # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) + lx = lx * mask + + x = self.org_forward(x) + x = x + lx + + if "attn2_to_q" in self.lora_name and self.network.is_last_network: + x = self.postp_to_q(x) + + return x + + def postp_to_q(self, x): + # repeat x to num_sub_prompts + has_real_uncond = x.size()[0] // self.network.batch_size == 3 + qc = self.network.batch_size # uncond + qc += self.network.batch_size * self.network.num_sub_prompts # cond + if has_real_uncond: + qc += self.network.batch_size # real_uncond + + query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype) + query[: self.network.batch_size] = x[: self.network.batch_size] + + for i in range(self.network.batch_size): + qi = self.network.batch_size + i * self.network.num_sub_prompts + query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i] + + if has_real_uncond: + query[-self.network.batch_size :] = x[-self.network.batch_size :] + + # print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts) + return query + + def sub_prompt_forward(self, x): + if x.size()[0] == self.network.batch_size: # if uncond in text_encoder, do not apply LoRA + return self.org_forward(x) + + emb_idx = self.network.sub_prompt_index + if not self.text_encoder: + emb_idx += self.network.batch_size + + # apply sub prompt of X + lx = x[emb_idx :: 
self.network.num_sub_prompts] + lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale + + # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx) + + x = self.org_forward(x) + x[emb_idx :: self.network.num_sub_prompts] += lx + + return x + + def to_out_forward(self, x): + # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network) + + if self.network.is_last_network: + masks = [None] * self.network.num_sub_prompts + self.network.shared[self.lora_name] = (None, masks) + else: + lx, masks = self.network.shared[self.lora_name] + + # call own LoRA + x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts] + lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale + + if self.network.is_last_network: + lx = torch.zeros( + (self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype + ) + self.network.shared[self.lora_name] = (lx, masks) + + # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts) + lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1 + masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1) + + # if not last network, return x and masks + x = self.org_forward(x) + if not self.network.is_last_network: + return x + + lx, masks = self.network.shared.pop(self.lora_name) + + # if last network, combine separated x with mask weighted sum + has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2 + + out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype) + out[: self.network.batch_size] = x[: self.network.batch_size] # uncond + if has_real_uncond: + out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond + + mask = torch.cat(masks) + mask_sum = torch.sum(mask, dim=0) + 1e-4 + for i in range(self.network.batch_size): + # Process each image individually + lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts] + lx1 = lx1 * mask + lx1 = torch.sum(lx1, dim=0) + + xi = self.network.batch_size + i * self.network.num_sub_prompts + x1 = x[xi : xi + self.network.num_sub_prompts] + x1 = x1 * mask + x1 = torch.sum(x1, dim=0) + x1 = x1 / mask_sum + + x1 = x1 + lx1 + out[self.network.batch_size + i] = x1 + + # print("to_out_forward", x.size(), out.size(), has_real_uncond) + return out + + +def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # block dim/alpha/lr + block_dims = kwargs.get("block_dims", None) + down_lr_weight = kwargs.get("down_lr_weight", None) + mid_lr_weight = kwargs.get("mid_lr_weight", None) + up_lr_weight = kwargs.get("up_lr_weight", None) + + # If any of the above is specified, enable dim (rank) for each block. 
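    # (Illustrative note, added for clarity.) block_dims / block_alphas are comma-separated lists
    # with one value per block, LoRANetwork.NUM_OF_BLOCKS * 2 + 1 = 25 entries in total
    # (12 down blocks, 1 mid block, 12 up blocks), e.g. passed from the generation script as
    #   --network_args "block_dims=2,2,2,2,4,4,4,4,8,8,8,8,8,8,8,8,8,4,4,4,4,2,2,2,2"
    # down_lr_weight / up_lr_weight also accept the presets handled by get_block_lr_weight below
    # ("cosine", "sine", "linear", "reverse_linear", "zeros"), optionally with a "+<base>" offset
    # such as "cosine+.2".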
+ if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None: + block_alphas = kwargs.get("block_alphas", None) + conv_block_dims = kwargs.get("conv_block_dims", None) + conv_block_alphas = kwargs.get("conv_block_alphas", None) + + block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas( + block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha + ) + + # extract learning rate weight for each block + if down_lr_weight is not None: + # if some parameters are not set, use zero + if "," in down_lr_weight: + down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")] + + if mid_lr_weight is not None: + mid_lr_weight = float(mid_lr_weight) + + if up_lr_weight is not None: + if "," in up_lr_weight: + up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")] + + down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight( + down_lr_weight, mid_lr_weight, up_lr_weight, float(kwargs.get("block_lr_zero_threshold", 0.0)) + ) + + # remove block dim/alpha without learning rate + block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas( + block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight + ) + + else: + block_alphas = None + conv_block_dims = None + conv_block_alphas = None + + network = LoRANetwork( + text_encoder, + unet, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + block_dims=block_dims, + block_alphas=block_alphas, + conv_block_dims=conv_block_dims, + conv_block_alphas=conv_block_alphas, + varbose=True, + ) + + if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: + network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) + + return network + + +def get_block_dims_and_alphas( + block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha +): + num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1 + + def parse_ints(s): + return [int(i) for i in s.split(",")] + + def parse_floats(s): + return [float(i) for i in s.split(",")] + + if block_dims is not None: + block_dims = parse_ints(block_dims) + assert ( + len(block_dims) == num_total_blocks + ), f"block_dims must have {num_total_blocks} elements" + else: + print(f"block_dims is not specified. all dims are set to {network_dim}") + block_dims = [network_dim] * num_total_blocks + + if block_alphas is not None: + block_alphas = parse_floats(block_alphas) + assert ( + len(block_alphas) == num_total_blocks + ), f"block_alphas must have {num_total_blocks} elements" + else: + print( + f"block_alphas is not specified. all alphas are set to {network_alpha}" + ) + block_alphas = [network_alpha] * num_total_blocks + + # conv_block_dims and conv_block_alphas only if specified. 
If not specified, use conv_dim and conv_alpha + if conv_block_dims is not None: + conv_block_dims = parse_ints(conv_block_dims) + assert ( + len(conv_block_dims) == num_total_blocks + ), f"conv_block_dims must have {num_total_blocks} elements" + + if conv_block_alphas is not None: + conv_block_alphas = parse_floats(conv_block_alphas) + assert ( + len(conv_block_alphas) == num_total_blocks + ), f"conv_block_alphas must have {num_total_blocks} elements" + else: + if conv_alpha is None: + conv_alpha = 1.0 + print( + f"conv_block_alphas is not specified. all alphas are set to {conv_alpha}" + ) + conv_block_alphas = [conv_alpha] * num_total_blocks + else: + if conv_dim is not None: + print( + f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha}" + ) + conv_block_dims = [conv_dim] * num_total_blocks + conv_block_alphas = [conv_alpha] * num_total_blocks + else: + conv_block_dims = None + conv_block_alphas = None + + return block_dims, block_alphas, conv_block_dims, conv_block_alphas + + +# Define the multiplier for the learning rate for each layer for layered learning rate, taking into account the possibility of being called externally. +def get_block_lr_weight( + down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold +) -> Tuple[List[float], List[float], List[float]]: + # If no parameters are specified, do nothing and behave as before. + if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None: + return None, None, None + + max_len = LoRANetwork.NUM_OF_BLOCKS # Number of up and down layers equivalent to full model + + def get_list(name_with_suffix) -> List[float]: + import math + + tokens = name_with_suffix.split("+") + name = tokens[0] + base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0 + + if name == "cosine": + return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))] + elif name == "sine": + return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)] + elif name == "linear": + return [i / (max_len - 1) + base_lr for i in range(max_len)] + elif name == "reverse_linear": + return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))] + elif name == "zeros": + return [0.0 + base_lr] * max_len + else: + print( + "Unknown lr_weight argument %s is used. Valid arguments: \n\tcosine, sine, linear, reverse_linear, zeros" + % (name) + ) + return None + + if type(down_lr_weight) == str: + down_lr_weight = get_list(down_lr_weight) + if type(up_lr_weight) == str: + up_lr_weight = get_list(up_lr_weight) + + if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): + print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) + up_lr_weight = up_lr_weight[:max_len] + down_lr_weight = down_lr_weight[:max_len] + + if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): + print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." 
% max_len) + + if down_lr_weight != None and len(down_lr_weight) < max_len: + down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) + if up_lr_weight != None and len(up_lr_weight) < max_len: + up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) + + if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): + print("apply block learning rate") + if down_lr_weight != None: + down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] + print("down_lr_weight (shallower -> deeper):", down_lr_weight) + else: + print("down_lr_weight: all 1.0") + + if mid_lr_weight != None: + mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 + print("mid_lr_weight:", mid_lr_weight) + else: + print("mid_lr_weight: 1.0") + + if up_lr_weight != None: + up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] + print("up_lr_weight (deeper -> shallower):", up_lr_weight) + else: + print("up_lr_weight: all 1.0") + + return down_lr_weight, mid_lr_weight, up_lr_weight + +def remove_block_dims_and_alphas( + block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight +): + if down_lr_weight != None: + for i, lr in enumerate(down_lr_weight): + if lr == 0: + block_dims[i] = 0 + if conv_block_dims is not None: + conv_block_dims[i] = 0 + if mid_lr_weight != None: + if mid_lr_weight == 0: + block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0 + if conv_block_dims is not None: + conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0 + if up_lr_weight != None: + for i, lr in enumerate(up_lr_weight): + if lr == 0: + block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0 + if conv_block_dims is not None: + conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0 + + return block_dims, block_alphas, conv_block_dims, conv_block_alphas + + +# 外部から呼び出す可能性を考慮しておく +def get_block_index(lora_name: str) -> int: + block_idx = -1 # invalid lora name + + m = RE_UPDOWN.search(lora_name) + if m: + g = m.groups() + i = int(g[1]) + j = int(g[3]) + if g[2] == "resnets": + idx = 3 * i + j + elif g[2] == "attentions": + idx = 3 * i + j + elif g[2] == "upsamplers" or g[2] == "downsamplers": + idx = 3 * i + 2 + + if g[0] == "down": + block_idx = 1 + idx # There is no LoRA corresponding to 0. + elif g[0] == "up": + block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx + + elif "mid_block_" in lora_name: + block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12 + + return block_idx + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping + modules_dim = {} + modules_alpha = {} + for key, value in weights_sd.items(): + if "." 
not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # print(lora_name, value.size(), dim) + + # support old LoRA without alpha + for key in modules_dim.keys(): + if key not in modules_alpha: + modules_alpha = modules_dim[key] + + module_class = LoRAInferenceModule if for_inference else LoRAModule + + network = LoRANetwork( + text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class + ) + return network, weights_sd + + +class LoRANetwork(torch.nn.Module): + NUM_OF_BLOCKS = 12 # Number of up and down layers equivalent to full model + + # is it possible to apply conv_in and conv_out + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + LORA_PREFIX_UNET = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + + def __init__( + self, + text_encoder, + unet, + multiplier=1.0, + lora_dim=4, + alpha=1, + conv_lora_dim=None, + conv_alpha=None, + block_dims=None, + block_alphas=None, + conv_block_dims=None, + conv_block_alphas=None, + modules_dim=None, + modules_alpha=None, + module_class=LoRAModule, + varbose=False, + ) -> None: + """ + LoRA network: There are a lot of arguments, but the pattern is as follows + 1. Specify lora_dim and alpha + 2. Specify lora_dim, alpha, conv_lora_dim, conv_alpha + 3. Specify block_dims and block_alphas: Not applicable to Conv2d3x3 + 4. Specify block_dims, block_alphas, conv_block_dims, conv_block_alphas: Also applies to Conv2d3x3 + 5. Specify modules_dim and modules_alpha (for inference) + """ + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + + if modules_dim is not None: + print(f"create LoRA network from weights") + elif block_dims is not None: + print(f"create LoRA network from block_dims") + print(f"block_dims: {block_dims}") + print(f"block_alphas: {block_alphas}") + if conv_block_dims is not None: + print(f"conv_block_dims: {conv_block_dims}") + print(f"conv_block_alphas: {conv_block_alphas}") + else: + print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + if self.conv_lora_dim is not None: + print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + + # create module instances + def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]: + prefix = LoRANetwork.LORA_PREFIX_UNET if is_unet else LoRANetwork.LORA_PREFIX_TEXT_ENCODER + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + name + "." 
+ child_name + lora_name = lora_name.replace(".", "_") + + dim = None + alpha = None + if modules_dim is not None: + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + elif is_unet and block_dims is not None: + block_idx = get_block_index(lora_name) + if is_linear or is_conv2d_1x1: + dim = block_dims[block_idx] + alpha = block_alphas[block_idx] + elif conv_block_dims is not None: + dim = conv_block_dims[block_idx] + alpha = conv_block_alphas[block_idx] + else: + if is_linear or is_conv2d_1x1: + dim = self.lora_dim + alpha = self.alpha + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + + if dim is None or dim == 0: + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None): + skipped.append(lora_name) + continue + + lora = module_class(lora_name, child_module, self.multiplier, dim, alpha) + loras.append(lora) + return loras, skipped + + self.text_encoder_loras, skipped_te = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None: + target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + + self.unet_loras, skipped_un = create_modules(True, unet, target_modules) + print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + + skipped = skipped_te + skipped_un + if varbose and len(skipped) > 0: + print( + f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped" + ) + for name in skipped: + print(f"\t{name}") + + self.up_lr_weight: List[float] = None + self.down_lr_weight: List[float] = None + self.mid_lr_weight: float = None + self.block_lr = False + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + print("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + print("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + def merge_to(self, text_encoder, unet, weights_sd, dtype, device): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_UNET): + apply_unet = True + + if apply_text_encoder: + print("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + print("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in 
self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + print(f"weights are merged") + + # Define the multiplier for the learning rate per layer for layered learning rate + def set_block_lr_weight( + self, + up_lr_weight: List[float] = None, + mid_lr_weight: float = None, + down_lr_weight: List[float] = None, + ): + self.block_lr = True + self.down_lr_weight = down_lr_weight + self.mid_lr_weight = mid_lr_weight + self.up_lr_weight = up_lr_weight + + def get_lr_weight(self, lora: LoRAModule) -> float: + lr_weight = 1.0 + block_idx = get_block_index(lora.lora_name) + if block_idx < 0: + return lr_weight + + if block_idx < LoRANetwork.NUM_OF_BLOCKS: + if self.down_lr_weight != None: + lr_weight = self.down_lr_weight[block_idx] + elif block_idx == LoRANetwork.NUM_OF_BLOCKS: + if self.mid_lr_weight != None: + lr_weight = self.mid_lr_weight + elif block_idx > LoRANetwork.NUM_OF_BLOCKS: + if self.up_lr_weight != None: + lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1] + + return lr_weight + + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + self.requires_grad_(True) + all_params = [] + + def enumerate_params(loras): + params = [] + for lora in loras: + params.extend(lora.parameters()) + return params + + if self.text_encoder_loras: + param_data = {"params": enumerate_params(self.text_encoder_loras)} + if text_encoder_lr is not None: + param_data["lr"] = text_encoder_lr + all_params.append(param_data) + + if self.unet_loras: + if self.block_lr: + # classify lora for each block so that its learning rate can be graphed. + block_idx_to_lora = {} + for lora in self.unet_loras: + idx = get_block_index(lora.lora_name) + if idx not in block_idx_to_lora: + block_idx_to_lora[idx] = [] + block_idx_to_lora[idx].append(lora) + + # Set parameters for each block + for idx, block_loras in block_idx_to_lora.items(): + param_data = {"params": enumerate_params(block_loras)} + + if unet_lr is not None: + param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0]) + elif default_lr is not None: + param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0]) + if ("lr" in param_data) and (param_data["lr"] == 0): + continue + all_params.append(param_data) + + else: + param_data = {"params": enumerate_params(self.unet_loras)} + if unet_lr is not None: + param_data["lr"] = unet_lr + all_params.append(param_data) + + return all_params + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + import model.train as train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + 
metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + # mask is a tensor with values from 0 to 1 + def set_region(self, sub_prompt_index, is_last_network, mask): + if mask.max() == 0: + mask = torch.ones_like(mask) + + self.mask = mask + self.sub_prompt_index = sub_prompt_index + self.is_last_network = is_last_network + + for lora in self.text_encoder_loras + self.unet_loras: + lora.set_network(self) + + def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared): + self.batch_size = batch_size + self.num_sub_prompts = num_sub_prompts + self.current_size = (height, width) + self.shared = shared + + # create masks + mask = self.mask + mask_dic = {} + mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w + ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight + dtype = ref_weight.dtype + device = ref_weight.device + + def resize_add(mh, mw): + # print(mh, mw, mh * mw) + m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16 + m = m.to(device, dtype=dtype) + mask_dic[mh * mw] = m + + h = height // 8 + w = width // 8 + for _ in range(4): + resize_add(h, w) + if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2 + resize_add(h + h % 2, w + w % 2) + h = (h + 1) // 2 + w = (w + 1) // 2 + + self.mask_dic = mask_dic diff --git a/SystemCode/src/main/model/pipeline.py b/SystemCode/src/main/model/pipeline.py new file mode 100644 index 0000000..7d7a24f --- /dev/null +++ b/SystemCode/src/main/model/pipeline.py @@ -0,0 +1,1179 @@ +from networks.lora import LoRANetwork +# import tools.original_control_net as original_control_net +# from tools.original_control_net import ControlNetInfo + +import inspect +from typing import Callable, List, Optional, Union +from einops import rearrange +from tqdm import tqdm +import torchvision +from torchvision import transforms +from diffusers.utils import deprecate +from diffusers.configuration_utils import FrozenDict + +from diffusers import ( + AutoencoderKL, + SchedulerMixin, + UNet2DConditionModel +) +from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPTextConfig + +# CLIP +FEATURE_EXTRACTOR_SIZE = (224, 224) +FEATURE_EXTRACTOR_IMAGE_MEAN = [0.48145466, 0.4578275, 0.40821073] +FEATURE_EXTRACTOR_IMAGE_STD = [0.26862954, 0.26130258, 0.27577711] + +# VGG +VGG16_IMAGE_MEAN = [0.485, 0.456, 0.406] +VGG16_IMAGE_STD = [0.229, 0.224, 0.225] +VGG16_INPUT_RESIZE_DIV = 4 + +""" +Script has been modified from the base template provided by +Hugging Face Diffusers library: https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py +""" +class PipelineLike: + r""" + Pipeline for text-to-image generation using Stable Diffusion. + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`] + """ + + def __init__( + self, + device, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + clip_skip: int, + clip_model: CLIPModel, + clip_guidance_scale: float, + clip_image_guidance_scale: float, + vgg16_model: torchvision.models.VGG, + vgg16_guidance_scale: float, + vgg16_layer_no: int + ): + super().__init__() + self.device = device + self.clip_skip = clip_skip + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + self.vae = vae + self.text_encoder = text_encoder + self.tokenizer = tokenizer + self.unet = unet + self.scheduler = scheduler + self.safety_checker = None + + # CLIP guidance + self.clip_guidance_scale = clip_guidance_scale + self.clip_image_guidance_scale = clip_image_guidance_scale + self.clip_model = clip_model + self.normalize = transforms.Normalize(mean=FEATURE_EXTRACTOR_IMAGE_MEAN, std=FEATURE_EXTRACTOR_IMAGE_STD) + self.make_cutouts = MakeCutouts(FEATURE_EXTRACTOR_SIZE) + + # VGG16 guidance + self.vgg16_guidance_scale = vgg16_guidance_scale + if self.vgg16_guidance_scale > 0.0: + return_layers = {f"{vgg16_layer_no}": "feat"} + self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter( + vgg16_model.features, return_layers=return_layers + ) + self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD) + + # ControlNet + self.control_nets: List[ControlNetInfo] = [] + + def set_control_nets(self, ctrl_nets): + self.control_nets = ctrl_nets + + # region xformersとか使う部分:独自に書き換えるので関係なし + + def enable_xformers_memory_efficient_attention(self): + r""" + Currently available for Linux only. 
+ Enable memory efficient attention as implemented in xformers. + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. 
+ """ + # accelerateが必要になるのでとりあえず省略 + raise NotImplementedError("cpu_offload is omitted.") + # if is_accelerate_available(): + # from accelerate import cpu_offload + # else: + # raise ImportError("Please install accelerate via `pip install accelerate`") + + # device = self.device + + # for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: + # if cpu_offloaded_model is not None: + # cpu_offload(cpu_offloaded_model, device) + + # endregion + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + init_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_scale: float = None, + strength: float = 0.8, + # num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + vae_batch_size: float = None, + return_latents: bool = False, + # return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: Optional[int] = 1, + img2img_noise=None, + clip_prompts=None, + clip_guide_images=None, + networks: Optional[List[LoRANetwork]] = None, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. 
+ strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. + `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + `None` if cancelled by `is_cancelled_callback`, + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. 
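        Example:
            A minimal invocation sketch. The argument values are illustrative, and the pipeline
            object is assumed to have been fully constructed and moved to the target device
            beforehand:

                result = pipe(
                    prompt="a watercolor painting of a lighthouse at dawn",
                    negative_prompt="low quality, blurry",
                    height=512,
                    width=512,
                    num_inference_steps=30,
                    guidance_scale=7.5,
                )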
+ """ + num_images_per_prompt = 1 # fixed + + if isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + reginonal_network = " AND " in prompt[0] + + vae_batch_size = ( + batch_size + if vae_batch_size is None + else (int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))) + ) + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." + ) + + # get prompt text embeddings + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if not do_classifier_free_guidance and negative_scale is not None: + print(f"negative_scale is ignored if guidance scalle <= 1.0") + negative_scale = None + + # get unconditional embeddings for classifier free guidance + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + if not self.token_replacements_XTI: + text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + **kwargs, + ) + + if negative_scale is not None: + _, real_uncond_embeddings, _ = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須 + uncond_prompt=[""] * batch_size, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + **kwargs, + ) + + if self.token_replacements_XTI: + text_embeddings_concat = [] + for layer in [ + "IN01", + "IN02", + "IN04", + "IN05", + "IN07", + "IN08", + "MID", + "OUT03", + "OUT04", + "OUT05", + "OUT06", + "OUT07", + "OUT08", + "OUT09", + "OUT10", + "OUT11", + ]: + text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + layer=layer, + **kwargs, + ) + if do_classifier_free_guidance: + if negative_scale is None: + text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings])) + else: + text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])) + text_embeddings = torch.stack(text_embeddings_concat) + else: + if do_classifier_free_guidance: + if negative_scale is None: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + else: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) + + # CLIP guidanceで使用するembeddingsを取得する + if self.clip_guidance_scale > 0: + clip_text_input = prompt_tokens + if clip_text_input.shape[1] > self.tokenizer.model_max_length: + # TODO 75文字を超えたら警告を出す? 
+ print("trim text input", clip_text_input.shape) + clip_text_input = torch.cat( + [clip_text_input[:, : self.tokenizer.model_max_length - 1], clip_text_input[:, -1].unsqueeze(1)], dim=1 + ) + print("trimmed", clip_text_input.shape) + + for i, clip_prompt in enumerate(clip_prompts): + if clip_prompt is not None: # clip_promptがあれば上書きする + clip_text_input[i] = self.tokenizer( + clip_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ).input_ids.to(self.device) + + text_embeddings_clip = self.clip_model.get_text_features(clip_text_input) + text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK + + if ( + self.clip_image_guidance_scale > 0 + or self.vgg16_guidance_scale > 0 + and clip_guide_images is not None + or self.control_nets + ): + if isinstance(clip_guide_images, PIL.Image.Image): + clip_guide_images = [clip_guide_images] + + if self.clip_image_guidance_scale > 0: + clip_guide_images = [preprocess_guide_image(im) for im in clip_guide_images] + clip_guide_images = torch.cat(clip_guide_images, dim=0) + + clip_guide_images = self.normalize(clip_guide_images).to(self.device).to(text_embeddings.dtype) + image_embeddings_clip = self.clip_model.get_image_features(clip_guide_images) + image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True) + if len(image_embeddings_clip) == 1: + image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1)) + elif self.vgg16_guidance_scale > 0: + size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に(小さいか?) + clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images] + clip_guide_images = torch.cat(clip_guide_images, dim=0) + + clip_guide_images = self.vgg16_normalize(clip_guide_images).to(self.device).to(text_embeddings.dtype) + image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)["feat"] + if len(image_embeddings_vgg16) == 1: + image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1)) + else: + # ControlNetのhintにguide imageを流用する + # 前処理はControlNet側で行う + pass + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps, self.device) + + latents_dtype = text_embeddings.dtype + init_latents_orig = None + mask = None + + if init_image is None: + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. 
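+            # latents are 1/8 of the pixel resolution with `unet.in_channels` channels (4 for Stable Diffusion),
+            # e.g. a 512x512 image corresponds to a 4x64x64 latent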
+ latents_shape = ( + batch_size * num_images_per_prompt, + self.unet.in_channels, + height // 8, + width // 8, + ) + + if latents is None: + if self.device.type == "mps": + # randn does not exist on mps + latents = torch.randn( + latents_shape, + generator=generator, + device="cpu", + dtype=latents_dtype, + ).to(self.device) + else: + latents = torch.randn( + latents_shape, + generator=generator, + device=self.device, + dtype=latents_dtype, + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) + + timesteps = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + else: + # image to tensor + if isinstance(init_image, PIL.Image.Image): + init_image = [init_image] + if isinstance(init_image[0], PIL.Image.Image): + init_image = [preprocess_image(im) for im in init_image] + init_image = torch.cat(init_image) + if isinstance(init_image, list): + init_image = torch.stack(init_image) + + # mask image to tensor + if mask_image is not None: + if isinstance(mask_image, PIL.Image.Image): + mask_image = [mask_image] + if isinstance(mask_image[0], PIL.Image.Image): + mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint + + # encode the init image into latents and scale the latents + init_image = init_image.to(device=self.device, dtype=latents_dtype) + if init_image.size()[2:] == (height // 8, width // 8): + init_latents = init_image + else: + if vae_batch_size >= batch_size: + init_latent_dist = self.vae.encode(init_image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + init_latents = [] + for i in tqdm(range(0, batch_size, vae_batch_size)): + init_latent_dist = self.vae.encode( + init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0) + ).latent_dist + init_latents.append(init_latent_dist.sample(generator=generator)) + init_latents = torch.cat(init_latents) + + init_latents = 0.18215 * init_latents + + if len(init_latents) == 1: + init_latents = init_latents.repeat((batch_size, 1, 1, 1)) + init_latents_orig = init_latents + + # preprocess mask + if mask_image is not None: + mask = mask_image.to(device=self.device, dtype=latents_dtype) + if len(mask) == 1: + mask = mask.repeat((batch_size, 1, 1, 1)) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and init_image should be the same size!") + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + + # add noise to latents using the timesteps + latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(self.device) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 + + if self.control_nets: + guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + + for i, t in enumerate(tqdm(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + if self.control_nets: + if reginonal_network: + num_sub_and_neg_prompts = len(text_embeddings) // batch_size + text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2::num_sub_and_neg_prompts] # last subprompt + else: + text_emb_last = text_embeddings + noise_pred = original_control_net.call_unet_and_control_net( + i, + num_latent_input, + self.unet, + self.control_nets, + guided_hints, + i / len(timesteps), + latent_model_input, + t, + text_emb_last, + ).sample + else: + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + if negative_scale is None: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk( + num_latent_input + ) # uncond is real uncond + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_uncond) + - negative_scale * (noise_pred_negative - noise_pred_uncond) + ) + + # perform clip guidance + if self.clip_guidance_scale > 0 or self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0: + text_embeddings_for_guidance = ( + text_embeddings.chunk(num_latent_input)[1] if do_classifier_free_guidance else text_embeddings + ) + + if self.clip_guidance_scale > 0: + noise_pred, latents = self.cond_fn( + latents, + t, + i, + text_embeddings_for_guidance, + noise_pred, + text_embeddings_clip, + self.clip_guidance_scale, + NUM_CUTOUTS, + USE_CUTOUTS, + ) + if self.clip_image_guidance_scale > 0 and clip_guide_images is not None: + noise_pred, latents = self.cond_fn( + latents, + t, + i, + text_embeddings_for_guidance, + noise_pred, + image_embeddings_clip, + self.clip_image_guidance_scale, + NUM_CUTOUTS, + USE_CUTOUTS, + ) + if self.vgg16_guidance_scale > 0 and clip_guide_images is not None: + noise_pred, latents = self.cond_fn_vgg16( + latents, t, i, text_embeddings_for_guidance, noise_pred, image_embeddings_vgg16, self.vgg16_guidance_scale + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if mask is not None: + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t])) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + + if return_latents: + return (latents, False) + 
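+        # decode: undo the VAE scaling factor (0.18215) and run the VAE decoder, splitting the batch into
+        # `vae_batch_size` chunks when it is smaller than `batch_size` to limit peak VRAM usage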
+ latents = 1 / 0.18215 * latents + if vae_batch_size >= batch_size: + image = self.vae.decode(latents).sample + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + images = [] + for i in tqdm(range(0, batch_size, vae_batch_size)): + images.append( + self.vae.decode(latents[i : i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample + ) + image = torch.cat(images) + + image = (image / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + # image = self.numpy_to_pil(image) + image = (image * 255).round().astype("uint8") + image = [Image.fromarray(im) for im in image] + + # if not return_dict: + return (image,) + + # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + def text2img( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function for text-to-image generation. + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. 
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + **kwargs, + ) + + def img2img( + self, + init_image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function for image-to-image generation. + Args: + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. + `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. 
A value of 1, therefore, essentially ignores `init_image`.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference. This parameter will be modulated by `strength`.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
+                `prompt`, usually at the expense of lower image quality.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
+            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                The max multiple length of prompt embeddings compared to the max output length of the text encoder.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is a list with the generated images, and the second element is a
+            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content, according to the `safety_checker`.
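+        Example:
+            An illustrative sketch only (it assumes `pipe` is an instance of this pipeline and `init_img` is a
+            `PIL.Image` loaded by the caller; this method returns a plain tuple whose first element is the list of images):
+
+            >>> images = pipe.img2img(init_image=init_img, prompt="a watercolor landscape", strength=0.6)[0]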
+ """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + init_image=init_image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + strength=strength, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + **kwargs, + ) + + def inpaint( + self, + init_image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function for inpaint. + Args: + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` + is 1, the denoising process will be run on the masked area for the full number of iterations specified + in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more + noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The reference number of denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. This parameter will be modulated by `strength`, as explained above. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. 
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
+            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                The max multiple length of prompt embeddings compared to the max output length of the text encoder.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is a list with the generated images, and the second element is a
+            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content, according to the `safety_checker`.
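+        Example:
+            An illustrative sketch only (it assumes `pipe` is an instance of this pipeline and `init_img` / `mask_img`
+            are `PIL.Image` objects prepared by the caller; white mask pixels are repainted, black ones are kept):
+
+            >>> images = pipe.inpaint(init_image=init_img, mask_image=mask_img, prompt="a red brick wall", strength=0.75)[0]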
+ """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + init_image=init_image, + mask_image=mask_image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + strength=strength, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + **kwargs, + ) + + # CLIP guidance StableDiffusion + # copy from https://github.com/huggingface/diffusers/blob/main/examples/community/clip_guided_stable_diffusion.py + + # バッチを分解して1件ずつ処理する + def cond_fn( + self, + latents, + timestep, + index, + text_embeddings, + noise_pred_original, + guide_embeddings_clip, + clip_guidance_scale, + num_cutouts, + use_cutouts=True, + ): + if len(latents) == 1: + return self.cond_fn1( + latents, + timestep, + index, + text_embeddings, + noise_pred_original, + guide_embeddings_clip, + clip_guidance_scale, + num_cutouts, + use_cutouts, + ) + + noise_pred = [] + cond_latents = [] + for i in range(len(latents)): + lat1 = latents[i].unsqueeze(0) + tem1 = text_embeddings[i].unsqueeze(0) + npo1 = noise_pred_original[i].unsqueeze(0) + gem1 = guide_embeddings_clip[i].unsqueeze(0) + npr1, cla1 = self.cond_fn1(lat1, timestep, index, tem1, npo1, gem1, clip_guidance_scale, num_cutouts, use_cutouts) + noise_pred.append(npr1) + cond_latents.append(cla1) + + noise_pred = torch.cat(noise_pred) + cond_latents = torch.cat(cond_latents) + return noise_pred, cond_latents + + @torch.enable_grad() + def cond_fn1( + self, + latents, + timestep, + index, + text_embeddings, + noise_pred_original, + guide_embeddings_clip, + clip_guidance_scale, + num_cutouts, + use_cutouts=True, + ): + latents = latents.detach().requires_grad_() + + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[index] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latents / ((sigma**2 + 1) ** 0.5) + else: + latent_model_input = latents + + # predict the noise residual + noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample + + if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)): + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] + beta_prod_t = 1 - alpha_prod_t + # compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) + + fac = torch.sqrt(beta_prod_t) + sample = pred_original_sample * (fac) + latents * (1 - fac) + elif isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[index] + sample = latents - sigma * noise_pred + else: + raise ValueError(f"scheduler type {type(self.scheduler)} not supported") + + sample = 1 / 0.18215 * sample + image = self.vae.decode(sample).sample + image = (image / 2 + 0.5).clamp(0, 1) + + if use_cutouts: + image = self.make_cutouts(image, num_cutouts) + else: + image = transforms.Resize(FEATURE_EXTRACTOR_SIZE)(image) + image = self.normalize(image).to(latents.dtype) + + image_embeddings_clip = self.clip_model.get_image_features(image) + image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True) + + if use_cutouts: + dists = spherical_dist_loss(image_embeddings_clip, guide_embeddings_clip) + dists = 
dists.view([num_cutouts, sample.shape[0], -1]) + loss = dists.sum(2).mean(0).sum() * clip_guidance_scale + else: + # バッチサイズが複数だと正しく動くかわからない + loss = spherical_dist_loss(image_embeddings_clip, guide_embeddings_clip).mean() * clip_guidance_scale + + grads = -torch.autograd.grad(loss, latents)[0] + + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents.detach() + grads * (sigma**2) + noise_pred = noise_pred_original + else: + noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads + return noise_pred, latents + + # バッチを分解して一件ずつ処理する + def cond_fn_vgg16(self, latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale): + if len(latents) == 1: + return self.cond_fn_vgg16_b1( + latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale + ) + + noise_pred = [] + cond_latents = [] + for i in range(len(latents)): + lat1 = latents[i].unsqueeze(0) + tem1 = text_embeddings[i].unsqueeze(0) + npo1 = noise_pred_original[i].unsqueeze(0) + gem1 = guide_embeddings[i].unsqueeze(0) + npr1, cla1 = self.cond_fn_vgg16_b1(lat1, timestep, index, tem1, npo1, gem1, guidance_scale) + noise_pred.append(npr1) + cond_latents.append(cla1) + + noise_pred = torch.cat(noise_pred) + cond_latents = torch.cat(cond_latents) + return noise_pred, cond_latents + + # 1件だけ処理する + @torch.enable_grad() + def cond_fn_vgg16_b1(self, latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale): + latents = latents.detach().requires_grad_() + + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[index] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latents / ((sigma**2 + 1) ** 0.5) + else: + latent_model_input = latents + + # predict the noise residual + noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample + + if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)): + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] + beta_prod_t = 1 - alpha_prod_t + # compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) + + fac = torch.sqrt(beta_prod_t) + sample = pred_original_sample * (fac) + latents * (1 - fac) + elif isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[index] + sample = latents - sigma * noise_pred + else: + raise ValueError(f"scheduler type {type(self.scheduler)} not supported") + + sample = 1 / 0.18215 * sample + image = self.vae.decode(sample).sample + image = (image / 2 + 0.5).clamp(0, 1) + image = transforms.Resize((image.shape[-2] // VGG16_INPUT_RESIZE_DIV, image.shape[-1] // VGG16_INPUT_RESIZE_DIV))(image) + image = self.vgg16_normalize(image).to(latents.dtype) + + image_embeddings = self.vgg16_feat_model(image)["feat"] + + # バッチサイズが複数だと正しく動くかわからない + loss = ((image_embeddings - guide_embeddings) ** 2).mean() * guidance_scale # MSE style transferでコンテンツの損失はMSEなので + + grads = -torch.autograd.grad(loss, latents)[0] + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents.detach() + grads * (sigma**2) + noise_pred = noise_pred_original + else: + noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads + return noise_pred, latents \ No newline at end of file diff --git 
a/SystemCode/src/main/model/tools/original_control_net.py b/SystemCode/src/main/model/tools/original_control_net.py new file mode 100644 index 0000000..015a52b --- /dev/null +++ b/SystemCode/src/main/model/tools/original_control_net.py @@ -0,0 +1,294 @@ +from typing import List, NamedTuple, Any +import numpy as np +import cv2 +import torch +from safetensors.torch import load_file + +from diffusers import UNet2DConditionModel +from diffusers.models.unet_2d_condition import UNet2DConditionOutput + +import model.model as model_util + +""" +From Hugging Face Diffusers +https://github.com/huggingface/diffusers +""" +class ControlNetInfo(NamedTuple): + unet: Any + net: Any + prep: Any + weight: float + ratio: float + + +class ControlNet(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + # make control model + self.control_model = torch.nn.Module() + + dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280] + zero_convs = torch.nn.ModuleList() + for i, dim in enumerate(dims): + sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)]) + zero_convs.append(sub_list) + self.control_model.add_module("zero_convs", zero_convs) + + middle_block_out = torch.nn.Conv2d(1280, 1280, 1) + self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out])) + + dims = [16, 16, 32, 32, 96, 96, 256, 320] + strides = [1, 1, 2, 1, 2, 1, 2, 1] + prev_dim = 3 + input_hint_block = torch.nn.Sequential() + for i, (dim, stride) in enumerate(zip(dims, strides)): + input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1)) + if i < len(dims) - 1: + input_hint_block.append(torch.nn.SiLU()) + prev_dim = dim + self.control_model.add_module("input_hint_block", input_hint_block) + + +def load_control_net(v2, unet, model): + device = unet.device + + # control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む + # state dictを読み込む + print(f"ControlNet: loading control SD model : {model}") + + if model_util.is_safetensors(model): + ctrl_sd_sd = load_file(model) + else: + ctrl_sd_sd = torch.load(model, map_location='cpu') + ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd) + + # 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む + is_difference = "difference" in ctrl_sd_sd + print("ControlNet: loading difference") + + # ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく + # またTransfer Controlの元weightとなる + ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict()) + + # 元のU-Netに影響しないようにコピーする。またprefixが付いていないので付ける + for key in list(ctrl_unet_sd_sd.keys()): + ctrl_unet_sd_sd["model.diffusion_model." 
+ key] = ctrl_unet_sd_sd.pop(key).clone() + + zero_conv_sd = {} + for key in list(ctrl_sd_sd.keys()): + if key.startswith("control_"): + unet_key = "model.diffusion_" + key[len("control_"):] + if unet_key not in ctrl_unet_sd_sd: # zero conv + zero_conv_sd[key] = ctrl_sd_sd[key] + continue + if is_difference: # Transfer Control + ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype) + else: + ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype) + + unet_config = model_util.create_unet_diffusers_config(v2) + ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config) # DiffUsers version of ControlNet state dict + + # ControlNet's U-Net + ctrl_unet = UNet2DConditionModel(**unet_config) + info = ctrl_unet.load_state_dict(ctrl_unet_du_sd) + print("ControlNet: loading Control U-Net:", info) + + # Create a ControlNet other than U-Net + ctrl_net = ControlNet() + info = ctrl_net.load_state_dict(zero_conv_sd) + print("ControlNet: loading ControlNet:", info) + + ctrl_unet.to(unet.device, dtype=unet.dtype) + ctrl_net.to(unet.device, dtype=unet.dtype) + return ctrl_unet, ctrl_net + + +def load_preprocess(prep_type: str): + if prep_type is None or prep_type.lower() == "none": + return None + + if prep_type.startswith("canny"): + args = prep_type.split("_") + th1 = int(args[1]) if len(args) >= 2 else 63 + th2 = int(args[2]) if len(args) >= 3 else 191 + + def canny(img): + img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + return cv2.Canny(img, th1, th2) + return canny + + print("Unsupported prep type:", prep_type) + return None + + +def preprocess_ctrl_net_hint_image(image): + image = np.array(image).astype(np.float32) / 255.0 + image = image[:, :, ::-1].copy() # rgb to bgr + image = image[None].transpose(0, 3, 1, 2) # nchw + image = torch.from_numpy(image) + return image # 0 to 1 + + +def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_size, hints): + guided_hints = [] + for i, cnet_info in enumerate(control_nets): + # The hints should be arranged as cnet1 in the first image, cnet2 in the first image, cnet3 in the first image, cnet1 in the second image, cnet2 in the second image, etc. 
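+        # e.g. with two ControlNets and batch size 2, hints is expected as [cn1_img1, cn2_img1, cn1_img2, cn2_img2];
+        # a single hint is broadcast to every image in the batch instead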
+ b_hints = [] + if len(hints) == 1: + # Use all the same images as hints + hint = hints[0] + if cnet_info.prep is not None: + hint = cnet_info.prep(hint) + hint = preprocess_ctrl_net_hint_image(hint) + b_hints = [hint for _ in range(b_size)] + else: + for bi in range(b_size): + hint = hints[(bi * len(control_nets) + i) % len(hints)] + if cnet_info.prep is not None: + hint = cnet_info.prep(hint) + hint = preprocess_ctrl_net_hint_image(hint) + b_hints.append(hint) + b_hints = torch.cat(b_hints, dim=0) + b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype) + + guided_hint = cnet_info.net.control_model.input_hint_block(b_hints) + guided_hints.append(guided_hint) + return guided_hints + + +def call_unet_and_control_net(step, num_latent_input, original_unet, control_nets: List[ControlNetInfo], guided_hints, current_ratio, sample, timestep, encoder_hidden_states): + # ControlNet + # For multiple ControlNets, apply outputs alternately instead of merging + cnet_cnt = len(control_nets) + cnet_idx = step % cnet_cnt + cnet_info = control_nets[cnet_idx] + + if cnet_info.ratio < current_ratio: + return original_unet(sample, timestep, encoder_hidden_states) + + guided_hint = guided_hints[cnet_idx] + guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1)) + outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states) + outs = [o * cnet_info.weight for o in outs] + + # U-Net + return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states) + + +def unet_forward(is_control_net, control_net: ControlNet, unet: UNet2DConditionModel, guided_hint, ctrl_outs, sample, timestep, encoder_hidden_states): + # copy from UNet2DConditionModel + default_overall_up_factor = 2**unet.num_upsamplers + + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + print("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 0. center input if necessary + if unet.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = unet.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=unet.dtype) + emb = unet.time_embedding(t_emb) + + outs = [] # output of ControlNet + zc_idx = 0 + + # 2. pre-process + sample = unet.conv_in(sample) + if is_control_net: + sample += guided_hint + outs.append(control_net.control_model.zero_convs[zc_idx][0](sample)) # , emb, encoder_hidden_states)) + zc_idx += 1 + + # 3. 
down + down_block_res_samples = (sample,) + for downsample_block in unet.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if is_control_net: + for rs in res_samples: + outs.append(control_net.control_model.zero_convs[zc_idx][0](rs)) # , emb, encoder_hidden_states)) + zc_idx += 1 + + down_block_res_samples += res_samples + + # 4. mid + sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + if is_control_net: + outs.append(control_net.control_model.middle_block_out[0](sample)) + return outs + + if not is_control_net: + sample += ctrl_outs.pop() + + # 5. up + for i, upsample_block in enumerate(unet.up_blocks): + is_final_block = i == len(unet.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets):] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + if not is_control_net and len(ctrl_outs) > 0: + res_samples = list(res_samples) + apply_ctrl_outs = ctrl_outs[-len(res_samples):] + ctrl_outs = ctrl_outs[:-len(res_samples)] + for j in range(len(res_samples)): + res_samples[j] = res_samples[j] + apply_ctrl_outs[j] + res_samples = tuple(res_samples) + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + # 6. 
post-process + sample = unet.conv_norm_out(sample) + sample = unet.conv_act(sample) + sample = unet.conv_out(sample) + + return UNet2DConditionOutput(sample=sample) diff --git a/SystemCode/src/main/model/train.py b/SystemCode/src/main/model/train.py new file mode 100644 index 0000000..81f9370 --- /dev/null +++ b/SystemCode/src/main/model/train.py @@ -0,0 +1,3057 @@ +# common functions for training + +import argparse +import ast +import asyncio +import importlib +import json +import pathlib +import re +import shutil +import time +from typing import ( + Dict, + List, + NamedTuple, + Optional, + Sequence, + Tuple, + Union, +) +from accelerate import Accelerator +import glob +import math +import os +import random +import hashlib +import subprocess +from io import BytesIO +import toml + +from tqdm import tqdm +import torch +from torch.optim import Optimizer +from torchvision import transforms +from transformers import CLIPTokenizer +import transformers +import diffusers +from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION +from diffusers import ( + StableDiffusionPipeline, + DDPMScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDIMScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler, +) +from huggingface_hub import hf_hub_download +import albumentations as albu +import numpy as np +from PIL import Image +import cv2 +from einops import rearrange +from torch import einsum +import safetensors.torch +import model as model_util + +TOKENIZER_PATH = "openai/clip-vit-large-patch14" + +EPOCH_STATE_NAME = "{}-{:06d}-state" +EPOCH_FILE_NAME = "{}-{:06d}" +EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}" +LAST_STATE_NAME = "{}-state" +DEFAULT_EPOCH_NAME = "epoch" +DEFAULT_LAST_OUTPUT_NAME = "last" + +# region dataset + +IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] + + +class ImageInfo: + def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: + self.image_key: str = image_key + self.num_repeats: int = num_repeats + self.caption: str = caption + self.is_reg: bool = is_reg + self.absolute_path: str = absolute_path + self.image_size: Tuple[int, int] = None + self.resized_size: Tuple[int, int] = None + self.bucket_reso: Tuple[int, int] = None + self.latents: torch.Tensor = None + self.latents_flipped: torch.Tensor = None + self.latents_npz: str = None + self.latents_npz_flipped: str = None + + +class BucketManager: + def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None: + self.no_upscale = no_upscale + if max_reso is None: + self.max_reso = None + self.max_area = None + else: + self.max_reso = max_reso + self.max_area = max_reso[0] * max_reso[1] + self.min_size = min_size + self.max_size = max_size + self.reso_steps = reso_steps + + self.resos = [] + self.reso_to_id = {} + self.buckets = [] + + def add_image(self, reso, image): + bucket_id = self.reso_to_id[reso] + self.buckets[bucket_id].append(image) + + def shuffle(self): + for bucket in self.buckets: + random.shuffle(bucket) + + def sort(self): + sorted_resos = self.resos.copy() + sorted_resos.sort() + + sorted_buckets = [] + sorted_reso_to_id = {} + for i, reso in enumerate(sorted_resos): + bucket_id = self.reso_to_id[reso] + sorted_buckets.append(self.buckets[bucket_id]) + sorted_reso_to_id[reso] = i + + self.resos = sorted_resos + 
self.buckets = sorted_buckets + self.reso_to_id = sorted_reso_to_id + + def make_buckets(self): + resos = model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps) + self.set_predefined_resos(resos) + + def set_predefined_resos(self, resos): + self.predefined_resos = resos.copy() + self.predefined_resos_set = set(resos) + self.predefined_aspect_ratios = np.array([w / h for w, h in resos]) + + def add_if_new_reso(self, reso): + if reso not in self.reso_to_id: + bucket_id = len(self.resos) + self.reso_to_id[reso] = bucket_id + self.resos.append(reso) + self.buckets.append([]) + # print(reso, bucket_id, len(self.buckets)) + + def round_to_steps(self, x): + x = int(x + 0.5) + return x - x % self.reso_steps + + def select_bucket(self, image_width, image_height): + aspect_ratio = image_width / image_height + if not self.no_upscale: + # 同じaspect ratioがあるかもしれないので(fine tuningで、no_upscale=Trueで前処理した場合)、解像度が同じものを優先する + reso = (image_width, image_height) + if reso in self.predefined_resos_set: + pass + else: + ar_errors = self.predefined_aspect_ratios - aspect_ratio + predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの + reso = self.predefined_resos[predefined_bucket_id] + + ar_reso = reso[0] / reso[1] + if aspect_ratio > ar_reso: # 横が長い→縦を合わせる + scale = reso[1] / image_height + else: + scale = reso[0] / image_width + + resized_size = (int(image_width * scale + 0.5), int(image_height * scale + 0.5)) + # print("use predef", image_width, image_height, reso, resized_size) + else: + if image_width * image_height > self.max_area: + # 画像が大きすぎるのでアスペクト比を保ったまま縮小することを前提にbucketを決める + resized_width = math.sqrt(self.max_area * aspect_ratio) + resized_height = self.max_area / resized_width + assert abs(resized_width" + + # リサイズ後の短辺または長辺をreso_steps単位にする:aspect ratioの差が少ないほうを選ぶ + # 元のbucketingと同じロジック + b_width_rounded = self.round_to_steps(resized_width) + b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio) + ar_width_rounded = b_width_rounded / b_height_in_wr + + b_height_rounded = self.round_to_steps(resized_height) + b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio) + ar_height_rounded = b_width_in_hr / b_height_rounded + + # print(b_width_rounded, b_height_in_wr, ar_width_rounded) + # print(b_width_in_hr, b_height_rounded, ar_height_rounded) + + if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio): + resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + 0.5)) + else: + resized_size = (int(b_height_rounded * aspect_ratio + 0.5), b_height_rounded) + # print(resized_size) + else: + resized_size = (image_width, image_height) # リサイズは不要 + + # 画像のサイズ未満をbucketのサイズとする(paddingせずにcroppingする) + bucket_width = resized_size[0] - resized_size[0] % self.reso_steps + bucket_height = resized_size[1] - resized_size[1] % self.reso_steps + # print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height) + + reso = (bucket_width, bucket_height) + + self.add_if_new_reso(reso) + + ar_error = (reso[0] / reso[1]) - aspect_ratio + return reso, resized_size, ar_error + + +class BucketBatchIndex(NamedTuple): + bucket_index: int + bucket_batch_size: int + batch_index: int + + +class AugHelper: + def __init__(self): + # prepare all possible augmentators + color_aug_method = albu.OneOf( + [ + albu.HueSaturationValue(8, 0, 0, p=0.5), + albu.RandomGamma((95, 105), p=0.5), + ], + p=0.33, + ) + flip_aug_method = albu.HorizontalFlip(p=0.5) + + # key: (use_color_aug, 
use_flip_aug) + self.augmentors = { + (True, True): albu.Compose( + [ + color_aug_method, + flip_aug_method, + ], + p=1.0, + ), + (True, False): albu.Compose( + [ + color_aug_method, + ], + p=1.0, + ), + (False, True): albu.Compose( + [ + flip_aug_method, + ], + p=1.0, + ), + (False, False): None, + } + + def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]: + return self.augmentors[(use_color_aug, use_flip_aug)] + + +class BaseSubset: + def __init__( + self, + image_dir: Optional[str], + num_repeats: int, + shuffle_caption: bool, + keep_tokens: int, + color_aug: bool, + flip_aug: bool, + face_crop_aug_range: Optional[Tuple[float, float]], + random_crop: bool, + caption_dropout_rate: float, + caption_dropout_every_n_epochs: int, + caption_tag_dropout_rate: float, + token_warmup_min: int, + token_warmup_step: Union[float, int], + ) -> None: + self.image_dir = image_dir + self.num_repeats = num_repeats + self.shuffle_caption = shuffle_caption + self.keep_tokens = keep_tokens + self.color_aug = color_aug + self.flip_aug = flip_aug + self.face_crop_aug_range = face_crop_aug_range + self.random_crop = random_crop + self.caption_dropout_rate = caption_dropout_rate + self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs + self.caption_tag_dropout_rate = caption_tag_dropout_rate + + self.token_warmup_min = token_warmup_min + self.token_warmup_step = token_warmup_step + + self.img_count = 0 + + +class DreamBoothSubset(BaseSubset): + def __init__( + self, + image_dir: str, + is_reg: bool, + class_tokens: Optional[str], + caption_extension: str, + num_repeats, + shuffle_caption, + keep_tokens, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, + ) -> None: + assert image_dir is not None, + + super().__init__( + image_dir, + num_repeats, + shuffle_caption, + keep_tokens, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, + ) + + self.is_reg = is_reg + self.class_tokens = class_tokens + self.caption_extension = caption_extension + + def __eq__(self, other) -> bool: + if not isinstance(other, DreamBoothSubset): + return NotImplemented + return self.image_dir == other.image_dir + + +class FineTuningSubset(BaseSubset): + def __init__( + self, + image_dir, + metadata_file: str, + num_repeats, + shuffle_caption, + keep_tokens, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, + ) -> None: + assert metadata_file is not None, + + super().__init__( + image_dir, + num_repeats, + shuffle_caption, + keep_tokens, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, + ) + + self.metadata_file = metadata_file + + def __eq__(self, other) -> bool: + if not isinstance(other, FineTuningSubset): + return NotImplemented + return self.metadata_file == other.metadata_file + + +class BaseDataset(torch.utils.data.Dataset): + def __init__( + self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool + ) -> None: + super().__init__() + self.tokenizer = tokenizer + 
self.max_token_length = max_token_length + # width/height is used when enable_bucket==False + self.width, self.height = (None, None) if resolution is None else resolution + self.debug_dataset = debug_dataset + + self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = [] + + self.token_padding_disabled = False + self.tag_frequency = {} + self.XTI_layers = None + self.token_strings = None + + self.enable_bucket = False + self.bucket_manager: BucketManager = None # not initialized + self.min_bucket_reso = None + self.max_bucket_reso = None + self.bucket_reso_steps = None + self.bucket_no_upscale = None + self.bucket_info = None # for metadata + + self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 + + self.current_epoch: int = 0 + + self.current_step: int = 0 + self.max_train_steps: int = 0 + self.seed: int = 0 + + # augmentation + self.aug_helper = AugHelper() + + self.image_transforms = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + self.image_data: Dict[str, ImageInfo] = {} + self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} + + self.replacements = {} + + def set_seed(self, seed): + self.seed = seed + + def set_current_epoch(self, epoch): + if not self.current_epoch == epoch: + self.shuffle_buckets() + self.current_epoch = epoch + + def set_current_step(self, step): + self.current_step = step + + def set_max_train_steps(self, max_train_steps): + self.max_train_steps = max_train_steps + + def set_tag_frequency(self, dir_name, captions): + frequency_for_dir = self.tag_frequency.get(dir_name, {}) + self.tag_frequency[dir_name] = frequency_for_dir + for caption in captions: + for tag in caption.split(","): + tag = tag.strip() + if tag: + tag = tag.lower() + frequency = frequency_for_dir.get(tag, 0) + frequency_for_dir[tag] = frequency + 1 + + def disable_token_padding(self): + self.token_padding_disabled = True + + # def enable_XTI(self, layers=None, token_strings=None): + # self.XTI_layers = layers + # self.token_strings = token_strings + + def add_replacement(self, str_from, str_to): + self.replacements[str_from] = str_to + + def process_caption(self, subset: BaseSubset, caption): + is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate + is_drop_out = ( + is_drop_out + or subset.caption_dropout_every_n_epochs > 0 + and self.current_epoch % subset.caption_dropout_every_n_epochs == 0 + ) + + if is_drop_out: + caption = "" + else: + if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: + tokens = [t.strip() for t in caption.strip().split(",")] + if subset.token_warmup_step < 1: + subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) + if subset.token_warmup_step and self.current_step < subset.token_warmup_step: + tokens_len = ( + math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step))) + + subset.token_warmup_min + ) + tokens = tokens[:tokens_len] + + def dropout_tags(tokens): + if subset.caption_tag_dropout_rate <= 0: + return tokens + l = [] + for token in tokens: + if random.random() >= subset.caption_tag_dropout_rate: + l.append(token) + return l + + fixed_tokens = [] + flex_tokens = tokens[:] + if subset.keep_tokens > 0: + fixed_tokens = flex_tokens[: subset.keep_tokens] + flex_tokens = tokens[subset.keep_tokens :] + + if subset.shuffle_caption: + random.shuffle(flex_tokens) + 
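+                # tag dropout is applied only to the shuffled "flex" tokens; the first `keep_tokens` tags are always kept verbatim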
+ flex_tokens = dropout_tags(flex_tokens) + + caption = ", ".join(fixed_tokens + flex_tokens) + + # textual inversion correspondence + for str_from, str_to in self.replacements.items(): + if str_from == "": + # replace all + if type(str_to) == list: + caption = random.choice(str_to) + else: + caption = str_to + else: + caption = caption.replace(str_from, str_to) + + return caption + + def get_input_ids(self, caption): + input_ids = self.tokenizer( + caption, padding="max_length", truncation=True, max_length=self.tokenizer_max_length, return_tensors="pt" + ).input_ids + + if self.tokenizer_max_length > self.tokenizer.model_max_length: + input_ids = input_ids.squeeze(0) + iids_list = [] + if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: + # v1 + # If tokens are of length 77 or above, the total is 227 with " .... " + for i in range( + 1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2 + ): # (1, 152, 75) + ids_chunk = ( + input_ids[0].unsqueeze(0), + input_ids[i : i + self.tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) + ids_chunk = torch.cat(ids_chunk) + iids_list.append(ids_chunk) + else: + # v2 + for i in range( + 1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2 + ): + ids_chunk = ( + input_ids[0].unsqueeze(0), # BOS + input_ids[i : i + self.tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) # PAD or EOS + ids_chunk = torch.cat(ids_chunk) + + # # If the end is or , do nothing + # If the ending is x , change the ending to (if it is x , there will be no change) + if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id: + ids_chunk[-1] = self.tokenizer.eos_token_id + # If the beginning is ..., change it to ... + if ids_chunk[1] == self.tokenizer.pad_token_id: + ids_chunk[1] = self.tokenizer.eos_token_id + + iids_list.append(ids_chunk) + + input_ids = torch.stack(iids_list) # 3,77 + return input_ids + + def register_image(self, info: ImageInfo, subset: BaseSubset): + self.image_data[info.image_key] = info + self.image_to_subset[info.image_key] = subset + + def make_buckets(self): + """ + Required to call even if bucketing is not performed (create only one bucket) + min_size and max_size are ignored when enable_bucket is False + """ + print("loading image sizes.") + for info in tqdm(self.image_data.values()): + if info.image_size is None: + info.image_size = self.get_image_size(info.absolute_path) + + if self.enable_bucket: + print("make buckets") + else: + print("prepare dataset") + + # Create a bucket and distribute images to the bucket + if self.enable_bucket: + if self.bucket_manager is None: + # In case of fine tuning, if there is a definition in metadata, it has already been initialized. 
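+                # (self.width, self.height) acts as the bucket max_reso; the bucket resolutions are generated from it
+                # below unless bucket_no_upscale is set, in which case buckets are derived from the image sizes themselves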
+ self.bucket_manager = BucketManager( + self.bucket_no_upscale, + (self.width, self.height), + self.min_bucket_reso, + self.max_bucket_reso, + self.bucket_reso_steps, + ) + if not self.bucket_no_upscale: + self.bucket_manager.make_buckets() + else: + print( + "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically" + ) + + img_ar_errors = [] + for image_info in self.image_data.values(): + image_width, image_height = image_info.image_size + image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket( + image_width, image_height + ) + + # print(image_info.image_key, image_info.bucket_reso) + img_ar_errors.append(abs(ar_error)) + + self.bucket_manager.sort() + else: + self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None) + self.bucket_manager.set_predefined_resos([(self.width, self.height)]) # Only one fixed size bucket + for image_info in self.image_data.values(): + image_width, image_height = image_info.image_size + image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height) + + for image_info in self.image_data.values(): + for _ in range(image_info.num_repeats): + self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key) + + # Display and store bucket information + if self.enable_bucket: + self.bucket_info = {"buckets": {}} + print("number of images (including repeats)") + for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)): + count = len(bucket) + if count > 0: + self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)} + print(f"bucket {i}: resolution {reso}, count: {len(bucket)}") + + img_ar_errors = np.array(img_ar_errors) + mean_img_ar_error = np.mean(np.abs(img_ar_errors)) + self.bucket_info["mean_img_ar_error"] = mean_img_ar_error + print(f"mean ar error (without repeats): {mean_img_ar_error}") + + # Create an index for data reference. 
This index is used to shuffle the dataset + self.buckets_indices: List(BucketBatchIndex) = [] + for bucket_index, bucket in enumerate(self.bucket_manager.buckets): + batch_count = int(math.ceil(len(bucket) / self.batch_size)) + for batch_index in range(batch_count): + self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index)) + # num_of_image_types = len(set(bucket)) + # bucket_batch_size = min(self.batch_size, num_of_image_types) + # batch_count = int(math.ceil(len(bucket) / bucket_batch_size)) + # # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count) + # for batch_index in range(batch_count): + # self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index)) + + self.shuffle_buckets() + self._length = len(self.buckets_indices) + + def shuffle_buckets(self): + # set random seed for this epoch + random.seed(self.seed + self.current_epoch) + + random.shuffle(self.buckets_indices) + self.bucket_manager.shuffle() + + def load_image(self, image_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + img = np.array(image, np.uint8) + return img + + def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size): + image_height, image_width = image.shape[0:2] + + if image_width != resized_size[0] or image_height != resized_size[1]: + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) + + image_height, image_width = image.shape[0:2] + if image_width > reso[0]: + trim_size = image_width - reso[0] + p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) + # print("w", trim_size, p) + image = image[:, p : p + reso[0]] + if image_height > reso[1]: + trim_size = image_height - reso[1] + p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) + # print("h", trim_size, p) + image = image[p : p + reso[1]] + + assert ( + image.shape[0] == reso[1] and image.shape[1] == reso[0] + ), f"internal error, illegal trimmed size: {image.shape}, {reso}" + return image + + def is_latent_cacheable(self): + return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) + + def cache_latents(self, vae, vae_batch_size=1): + print("caching latents.") + + image_infos = list(self.image_data.values()) + + # sort by resolution + image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) + + # split by resolution + batches = [] + batch = [] + for info in image_infos: + subset = self.image_to_subset[info.image_key] + + if info.latents_npz is not None: + info.latents = self.load_latents_from_npz(info, False) + info.latents = torch.FloatTensor(info.latents) + info.latents_flipped = self.load_latents_from_npz(info, True) # might be None + if info.latents_flipped is not None: + info.latents_flipped = torch.FloatTensor(info.latents_flipped) + continue + + # if last member of batch has different resolution, flush the batch + if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso: + batches.append(batch) + batch = [] + + batch.append(info) + + # if number of data in batch is enough, flush the batch + if len(batch) >= vae_batch_size: + batches.append(batch) + batch = [] + + if len(batch) > 0: + batches.append(batch) + + # iterate batches + for batch in tqdm(batches, smoothing=1, total=len(batches)): + images = [] + for info in batch: + image = self.load_image(info.absolute_path) + image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size) + 
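+                # At this point the image has been cropped/resized to this batch's bucket
+                # resolution; image_transforms (ToTensor + Normalize([0.5], [0.5])) below
+                # converts it to a float tensor in [-1, 1] before it is encoded by the VAE.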
image = self.image_transforms(image) + images.append(image) + + img_tensors = torch.stack(images, dim=0) + img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype) + + latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + for info, latent in zip(batch, latents): + info.latents = latent + + if subset.flip_aug: + img_tensors = torch.flip(img_tensors, dims=[3]) + latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + for info, latent in zip(batch, latents): + info.latents_flipped = latent + + def get_image_size(self, image_path): + image = Image.open(image_path) + return image.size + + def load_image_with_face_info(self, subset: BaseSubset, image_path: str): + img = self.load_image(image_path) + + face_cx = face_cy = face_w = face_h = 0 + if subset.face_crop_aug_range is not None: + tokens = os.path.splitext(os.path.basename(image_path))[0].split("_") + if len(tokens) >= 5: + face_cx = int(tokens[-4]) + face_cy = int(tokens[-3]) + face_w = int(tokens[-2]) + face_h = int(tokens[-1]) + + return img, face_cx, face_cy, face_w, face_h + + # いい感じに切り出す + def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h): + height, width = image.shape[0:2] + if height == self.height and width == self.width: + return image + + # 画像サイズはsizeより大きいのでリサイズする + face_size = max(face_w, face_h) + min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率) + min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ + max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ + if min_scale >= max_scale: # range指定がmin==max + scale = min_scale + else: + scale = random.uniform(min_scale, max_scale) + + nh = int(height * scale + 0.5) + nw = int(width * scale + 0.5) + assert nh >= self.height and nw >= self.width, f"internal error. 
small scale {scale}, {width}*{height}" + image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) + face_cx = int(face_cx * scale + 0.5) + face_cy = int(face_cy * scale + 0.5) + height, width = nh, nw + + # 顔を中心として448*640とかへ切り出す + for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))): + p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置 + + if subset.random_crop: + # 背景も含めるために顔を中心に置く確率を高めつつずらす + range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう + p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数 + else: + # range指定があるときのみ、すこしだけランダムに(わりと適当) + if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]: + if face_size > self.size // 10 and face_size >= 40: + p1 = p1 + random.randint(-face_size // 20, +face_size // 20) + + p1 = max(0, min(p1, length - target_size)) + + if axis == 0: + image = image[p1 : p1 + target_size, :] + else: + image = image[:, p1 : p1 + target_size] + + return image + + def load_latents_from_npz(self, image_info: ImageInfo, flipped): + npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz + if npz_file is None: + return None + return np.load(npz_file)["arr_0"] + + def __len__(self): + return self._length + + def __getitem__(self, index): + bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index] + bucket_batch_size = self.buckets_indices[index].bucket_batch_size + image_index = self.buckets_indices[index].batch_index * bucket_batch_size + + loss_weights = [] + captions = [] + input_ids_list = [] + latents_list = [] + images = [] + + for image_key in bucket[image_index : image_index + bucket_batch_size]: + image_info = self.image_data[image_key] + subset = self.image_to_subset[image_key] + loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) + + # image/latentsを処理する + if image_info.latents is not None: + latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped + image = None + elif image_info.latents_npz is not None: + latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5) + latents = torch.FloatTensor(latents) + image = None + else: + # Load image and crop if necessary + img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path) + im_h, im_w = img.shape[0:2] + + if self.enable_bucket: + img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size) + else: + if face_cx > 0: + # face location information + img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h) + elif im_h > self.height or im_w > self.width: + assert ( + subset.random_crop + ), f"image too large, but cropping and bucketing are disabled" + if im_h > self.height: + p = random.randint(0, im_h - self.height) + img = img[p : p + self.height] + if im_w > self.width: + p = random.randint(0, im_w - self.width) + img = img[:, p : p + self.width] + + im_h, im_w = img.shape[0:2] + assert ( + im_h == self.height and im_w == self.width + ), f"image size is small" + + # augmentation + aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug) + if aug is not None: + img = aug(image=img)["image"] + + latents = None + image = self.image_transforms(img) # -1.0~1.0 torch.Tensorになる + + images.append(image) + latents_list.append(latents) + + caption = self.process_caption(subset, image_info.caption) + # if 
self.XTI_layers: + # caption_layer = [] + # for layer in self.XTI_layers: + # token_strings_from = " ".join(self.token_strings) + # token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) + # caption_ = caption.replace(token_strings_from, token_strings_to) + # caption_layer.append(caption_) + # captions.append(caption_layer) + # else: + captions.append(caption) + + example = {} + example["loss_weights"] = torch.FloatTensor(loss_weights) + + if self.token_padding_disabled: + # padding=True means pad in the batch + example["input_ids"] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids + else: + # batch processing seems to be good + example["input_ids"] = torch.stack(input_ids_list) + + if images[0] is not None: + images = torch.stack(images) + images = images.to(memory_format=torch.contiguous_format).float() + else: + images = None + example["images"] = images + + example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None + + if self.debug_dataset: + example["image_keys"] = bucket[image_index : image_index + self.batch_size] + example["captions"] = captions + return example + + +class DreamBoothDataset(BaseDataset): + def __init__( + self, + subsets: Sequence[DreamBoothSubset], + batch_size: int, + tokenizer, + max_token_length, + resolution, + enable_bucket: bool, + min_bucket_reso: int, + max_bucket_reso: int, + bucket_reso_steps: int, + bucket_no_upscale: bool, + prior_loss_weight: float, + debug_dataset, + ) -> None: + super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + + assert resolution is not None, f"resolution is required" + + self.batch_size = batch_size + self.size = min(self.width, self.height) + self.prior_loss_weight = prior_loss_weight + self.latents_cache = None + + self.enable_bucket = enable_bucket + if self.enable_bucket: + assert ( + min(resolution) >= min_bucket_reso + ), f"min_bucket_reso must be equal or less than resolution" + assert ( + max(resolution) <= max_bucket_reso + ), f"max_bucket_reso must be equal or greater than resolution" + self.min_bucket_reso = min_bucket_reso + self.max_bucket_reso = max_bucket_reso + self.bucket_reso_steps = bucket_reso_steps + self.bucket_no_upscale = bucket_no_upscale + else: + self.min_bucket_reso = None + self.max_bucket_reso = None + self.bucket_reso_steps = None + self.bucket_no_upscale = False + + def read_caption(img_path, caption_extension): + base_name = os.path.splitext(img_path)[0] + base_name_face_det = base_name + tokens = base_name.split("_") + if len(tokens) >= 5: + base_name_face_det = "_".join(tokens[:-4]) + cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension] + + caption = None + for cap_path in cap_paths: + if os.path.isfile(cap_path): + with open(cap_path, "rt", encoding="utf-8") as f: + try: + lines = f.readlines() + except UnicodeDecodeError as e: + print(f"illegal char in file (not UTF-8)") + raise e + assert len(lines) > 0, f"caption file is empty" + caption = lines[0].strip() + break + return caption + + def load_dreambooth_dir(subset: DreamBoothSubset): + if not os.path.isdir(subset.image_dir): + print(f"not directory: {subset.image_dir}") + return [], [] + + img_paths = glob_images(subset.image_dir, "*") + print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") + + # Read captions + captions = [] + for img_path in img_paths: + cap_for_img = read_caption(img_path, subset.caption_extension) + if cap_for_img is None and subset.class_tokens is 
None: + print(f"neither caption file nor class tokens are found. use empty caption for {img_path}") + captions.append("") + else: + captions.append(subset.class_tokens if cap_for_img is None else cap_for_img) + + self.set_tag_frequency(os.path.basename(subset.image_dir), captions) + + return img_paths, captions + + print("prepare images.") + num_train_images = 0 + num_reg_images = 0 + reg_infos: List[ImageInfo] = [] + for subset in subsets: + if subset.num_repeats < 1: + print( + f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1" + ) + continue + + if subset in self.subsets: + print( + f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one" + ) + continue + + img_paths, captions = load_dreambooth_dir(subset) + if len(img_paths) < 1: + print(f"ignore subset with image_dir='{subset.image_dir}': no images found") + continue + + if subset.is_reg: + num_reg_images += subset.num_repeats * len(img_paths) + else: + num_train_images += subset.num_repeats * len(img_paths) + + for img_path, caption in zip(img_paths, captions): + info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) + if subset.is_reg: + reg_infos.append(info) + else: + self.register_image(info, subset) + + subset.img_count = len(img_paths) + self.subsets.append(subset) + + print(f"{num_train_images} train images with repeating.") + self.num_train_images = num_train_images + + print(f"{num_reg_images} reg images.") + if num_train_images < num_reg_images: + print("some of reg images are not used") + + if num_reg_images == 0: + print("no regularization images") + else: + # Calculate num_repeats + n = 0 + first_loop = True + while n < num_train_images: + for info in reg_infos: + if first_loop: + self.register_image(info, subset) + n += info.num_repeats + else: + info.num_repeats += 1 # rewrite registered info + n += 1 + if n >= num_train_images: + break + first_loop = False + + self.num_reg_images = num_reg_images + + +class FineTuningDataset(BaseDataset): + def __init__( + self, + subsets: Sequence[FineTuningSubset], + batch_size: int, + tokenizer, + max_token_length, + resolution, + enable_bucket: bool, + min_bucket_reso: int, + max_bucket_reso: int, + bucket_reso_steps: int, + bucket_no_upscale: bool, + debug_dataset, + ) -> None: + super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + + self.batch_size = batch_size + + self.num_train_images = 0 + self.num_reg_images = 0 + + for subset in subsets: + if subset.num_repeats < 1: + print( + f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1" + ) + continue + + if subset in self.subsets: + print( + f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one" + ) + continue + + if os.path.exists(subset.metadata_file): + print(f"loading existing metadata: {subset.metadata_file}") + with open(subset.metadata_file, "rt", encoding="utf-8") as f: + metadata = json.load(f) + else: + raise ValueError(f"no metadata") + + if len(metadata) < 1: + print(f"ignore subset with '{subset.metadata_file}': no image entries found") + continue + + tags_list = [] + for image_key, img_md in metadata.items(): + if os.path.exists(image_key): + abs_path = image_key + elif os.path.exists(os.path.splitext(image_key)[0] + ".npz"): + abs_path = os.path.splitext(image_key)[0] + ".npz" + else: + npz_path = os.path.join(subset.image_dir, image_key + ".npz") + if os.path.exists(npz_path): + abs_path = npz_path + else: + abs_path = 
glob_images(subset.image_dir, image_key) + assert len(abs_path) >= 1, f"no image" + abs_path = abs_path[0] + + caption = img_md.get("caption") + tags = img_md.get("tags") + if caption is None: + caption = tags + elif tags is not None and len(tags) > 0: + caption = caption + ", " + tags + tags_list.append(tags) + + if caption is None: + caption = "" + + image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path) + image_info.image_size = img_md.get("train_resolution") + + if not subset.color_aug and not subset.random_crop: + # if npz exists, use them + image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key) + + self.register_image(image_info, subset) + + self.num_train_images += len(metadata) * subset.num_repeats + + # TODO do not record tag freq when no tag + self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list) + subset.img_count = len(metadata) + self.subsets.append(subset) + + # check existence of all npz files + use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets]) + if use_npz_latents: + flip_aug_in_subset = False + npz_any = False + npz_all = True + + for image_info in self.image_data.values(): + subset = self.image_to_subset[image_info.image_key] + + has_npz = image_info.latents_npz is not None + npz_any = npz_any or has_npz + + if subset.flip_aug: + has_npz = has_npz and image_info.latents_npz_flipped is not None + flip_aug_in_subset = True + npz_all = npz_all and has_npz + + if npz_any and not npz_all: + break + + if not npz_any: + use_npz_latents = False + print(f"npz file does not exist. ignore npz files") + elif not npz_all: + use_npz_latents = False + print(f"some of npz file does not exist. ignore npz files") + if flip_aug_in_subset: + print("maybe no flipped files") + # else: + # print("npz files are not used with color_aug and/or random_crop") + + # check min/max bucket size + sizes = set() + resos = set() + for image_info in self.image_data.values(): + if image_info.image_size is None: + sizes = None # not calculated + break + sizes.add(image_info.image_size[0]) + sizes.add(image_info.image_size[1]) + resos.add(tuple(image_info.image_size)) + + if sizes is None: + if use_npz_latents: + use_npz_latents = False + print(f"npz files exist, but no bucket info in metadata. 
ignore npz files") + + assert ( + resolution is not None + ), "if metadata doesn't have bucket info, resolution is required" + + self.enable_bucket = enable_bucket + if self.enable_bucket: + self.min_bucket_reso = min_bucket_reso + self.max_bucket_reso = max_bucket_reso + self.bucket_reso_steps = bucket_reso_steps + self.bucket_no_upscale = bucket_no_upscale + else: + if not enable_bucket: + print("metadata has bucket info, enable bucketing") + print("using bucket info in metadata") + self.enable_bucket = True + + assert ( + not bucket_no_upscale + ), "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used" + + # Initialize bucket information, do not recreate it with make_buckets + self.bucket_manager = BucketManager(False, None, None, None, None) + self.bucket_manager.set_predefined_resos(resos) + + # Keep npz information clean + if not use_npz_latents: + for image_info in self.image_data.values(): + image_info.latents_npz = image_info.latents_npz_flipped = None + + def image_key_to_npz_file(self, subset: FineTuningSubset, image_key): + base_name = os.path.splitext(image_key)[0] + npz_file_norm = base_name + ".npz" + + if os.path.exists(npz_file_norm): + # image_key is full path + npz_file_flip = base_name + "_flip.npz" + if not os.path.exists(npz_file_flip): + npz_file_flip = None + return npz_file_norm, npz_file_flip + + # if not full path, check image_dir. if image_dir is None, return None + if subset.image_dir is None: + return None, None + + # image_key is relative path + npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz") + npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz") + + if not os.path.exists(npz_file_norm): + npz_file_norm = None + npz_file_flip = None + elif not os.path.exists(npz_file_flip): + npz_file_flip = None + + return npz_file_norm, npz_file_flip + + +# behave as Dataset mock +class DatasetGroup(torch.utils.data.ConcatDataset): + def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]): + self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]] + + super().__init__(datasets) + + self.image_data = {} + self.num_train_images = 0 + self.num_reg_images = 0 + + # simply concat together + # TODO: handling image_data key duplication among dataset + # In practical, this is not the big issue because image_data is accessed from outside of dataset only for debug_dataset. 
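# A minimal sketch of the latent-cache naming convention resolved by image_key_to_npz_file
# above: cached latents are expected next to the image (or under subset.image_dir) as
# "<basename>.npz", with an optional "<basename>_flip.npz" holding the horizontally
# flipped latents. The path below is a hypothetical example.
import os

def npz_names_for(image_path: str):
    base = os.path.splitext(image_path)[0]
    norm, flip = base + ".npz", base + "_flip.npz"
    return (norm if os.path.exists(norm) else None,
            flip if os.path.exists(flip) else None)

print(npz_names_for("/data/train/0001.png"))  # (None, None) unless the caches exist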
+ for dataset in datasets: + self.image_data.update(dataset.image_data) + self.num_train_images += dataset.num_train_images + self.num_reg_images += dataset.num_reg_images + + def add_replacement(self, str_from, str_to): + for dataset in self.datasets: + dataset.add_replacement(str_from, str_to) + + # def make_buckets(self): + # for dataset in self.datasets: + # dataset.make_buckets() + + def enable_XTI(self, *args, **kwargs): + for dataset in self.datasets: + dataset.enable_XTI(*args, **kwargs) + + def cache_latents(self, vae, vae_batch_size=1): + for i, dataset in enumerate(self.datasets): + print(f"[Dataset {i}]") + dataset.cache_latents(vae, vae_batch_size) + + def is_latent_cacheable(self) -> bool: + return all([dataset.is_latent_cacheable() for dataset in self.datasets]) + + def set_current_epoch(self, epoch): + for dataset in self.datasets: + dataset.set_current_epoch(epoch) + + def set_current_step(self, step): + for dataset in self.datasets: + dataset.set_current_step(step) + + def set_max_train_steps(self, max_train_steps): + for dataset in self.datasets: + dataset.set_max_train_steps(max_train_steps) + + def disable_token_padding(self): + for dataset in self.datasets: + dataset.disable_token_padding() + + +def debug_dataset(train_dataset, show_input_ids=False): + print(f"Total dataset length (steps)") + print("`S` for next step, `E` for next epoch no. , Escape for exit.") + + epoch = 1 + while True: + print(f"epoch: {epoch}") + + steps = (epoch - 1) * len(train_dataset) + 1 + indices = list(range(len(train_dataset))) + random.shuffle(indices) + + k = 0 + for i, idx in enumerate(indices): + train_dataset.set_current_epoch(epoch) + train_dataset.set_current_step(steps) + print(f"steps: {steps} ({i + 1}/{len(train_dataset)})") + + example = train_dataset[idx] + if example["latents"] is not None: + print(f"sample has latents from npz file: {example['latents'].size()}") + for j, (ik, cap, lw, iid) in enumerate( + zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"]) + ): + print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"') + if show_input_ids: + print(f"input ids: {iid}") + if example["images"] is not None: + im = example["images"][j] + print(f"image size: {im.size()}") + im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) + im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c + im = im[:, :, ::-1] # RGB -> BGR (OpenCV) + if os.name == "nt": # only windows + cv2.imshow("img", im) + k = cv2.waitKey() + cv2.destroyAllWindows() + if k == 27 or k == ord("s") or k == ord("e"): + break + steps += 1 + + if k == ord("e"): + break + if k == 27 or (example["images"] is None and i >= 8): + k = 27 + break + if k == 27: + break + + epoch += 1 + + +def glob_images(directory, base="*"): + img_paths = [] + for ext in IMAGE_EXTENSIONS: + if base == "*": + img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext))) + else: + img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext)))) + img_paths = list(set(img_paths)) # 重複を排除 + img_paths.sort() + return img_paths + + +def glob_images_pathlib(dir_path, recursive): + image_paths = [] + if recursive: + for ext in IMAGE_EXTENSIONS: + image_paths += list(dir_path.rglob("*" + ext)) + else: + for ext in IMAGE_EXTENSIONS: + image_paths += list(dir_path.glob("*" + ext)) + image_paths = list(set(image_paths)) # 重複を排除 + image_paths.sort() + return image_paths + + +# endregion + +# region モジュール入れ替え部 +""" +Module replacement for faster speed 
+""" + +# FlashAttentionを使うCrossAttention +# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py +# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE + +# constants + +EPSILON = 1e-6 + +# helper functions + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +def model_hash(filename): + """Old model hash used by stable-diffusion-webui""" + try: + with open(filename, "rb") as file: + m = hashlib.sha256() + + file.seek(0x100000) + m.update(file.read(0x10000)) + return m.hexdigest()[0:8] + except FileNotFoundError: + return "NOFILE" + except IsADirectoryError: # Linux? + return "IsADirectory" + except PermissionError: # Windows + return "IsADirectory" + + +def calculate_sha256(filename): + """New model hash used by stable-diffusion-webui""" + try: + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(blksize), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + except FileNotFoundError: + return "NOFILE" + except IsADirectoryError: # Linux? + return "IsADirectory" + except PermissionError: # Windows + return "IsADirectory" + + +def precalculate_safetensors_hashes(tensors, metadata): + """Precalculate the model hashes needed by sd-webui-additional-networks to + save time on indexing the model later.""" + + # Because writing user metadata to the file can change the result of + # sd_models.model_hash(), only retain the training metadata for purposes of + # calculating the hash, as they are meant to be immutable + metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")} + + bytes = safetensors.torch.save(tensors, metadata) + b = BytesIO(bytes) + + model_hash = addnet_hash_safetensors(b) + legacy_hash = addnet_hash_legacy(b) + return model_hash, legacy_hash + + +def addnet_hash_legacy(b): + """Old model hash used by sd-webui-additional-networks for .safetensors format files""" + m = hashlib.sha256() + + b.seek(0x100000) + m.update(b.read(0x10000)) + return m.hexdigest()[0:8] + + +def addnet_hash_safetensors(b): + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") + + offset = n + 8 + b.seek(offset) + for chunk in iter(lambda: b.read(blksize), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + + +def get_git_revision_hash() -> str: + try: + return subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=os.path.dirname(__file__)).decode("ascii").strip() + except: + return "(unknown)" + + +# flash attention forwards and backwards + +# https://arxiv.org/abs/2205.14135 + + +class FlashAttentionFunction(torch.autograd.function.Function): + @staticmethod + @torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """Algorithm 2 in the paper""" + + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) + + scale = q.shape[-1] ** -0.5 + + if not exists(mask): + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, "b n -> b 1 1 n") + mask = mask.split(q_bucket_size, dim=-1) + + 
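+        # Tiling step: q, the output buffer o and the running softmax statistics are split
+        # into q_bucket_size row blocks; for each row block the inner loop streams k/v in
+        # k_bucket_size column blocks, updating a running row max and row sum so the
+        # softmax is computed blockwise without materializing the full attention matrix
+        # (Algorithm 2 of the FlashAttention paper).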
row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale + + if exists(row_mask): + attn_weights.masked_fill_(~row_mask, max_neg_value) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) + + if exists(row_mask): + exp_weights.masked_fill_(~row_mask, 0.0) + + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) + + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + + exp_values = einsum("... i j, ... j d -> ... i d", exp_weights, vc) + + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) + + new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums + + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) + + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) + + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + + return o + + @staticmethod + @torch.no_grad() + def backward(ctx, do): + """Algorithm 4 in the paper""" + + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors + + device = q.device + + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + exp_attn_weights = torch.exp(attn_weights - mc) + + if exists(row_mask): + exp_attn_weights.masked_fill_(~row_mask, 0.0) + + p = exp_attn_weights / lc + + dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc) + dp = einsum("... i d, ... j d -> ... 
i j", doc, vc) + + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) + + dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc) + dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc) + + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) + + return dq, dk, dv, None, None, None, None + + +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): + if mem_eff_attn: + replace_unet_cross_attn_to_memory_efficient() + elif xformers: + replace_unet_cross_attn_to_xformers() + + +def replace_unet_cross_attn_to_memory_efficient(): + print("Replace CrossAttention.forward to use FlashAttention (not xformers)") + flash_func = FlashAttentionFunction + + def forward_flash_attn(self, x, context=None, mask=None): + q_bucket_size = 512 + k_bucket_size = 1024 + + h = self.heads + q = self.to_q(x) + + context = context if context is not None else x + context = context.to(x.dtype) + + if hasattr(self, "hypernetwork") and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context + + k = self.to_k(context_k) + v = self.to_v(context_v) + del context, x + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, "b h n d -> b n (h d)") + + # diffusers 0.7.0~ わざわざ変えるなよ (;´Д`) + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + + diffusers.models.attention.CrossAttention.forward = forward_flash_attn + + +def replace_unet_cross_attn_to_xformers(): + print("Replace CrossAttention.forward to use xformers") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers") + + def forward_xformers(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) + + context = default(context, x) + context = context.to(x.dtype) + + if hasattr(self, "hypernetwork") and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context + + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) + + out = rearrange(out, "b n h d -> b n (h d)", h=h) + + # diffusers 0.7.0~ + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + + diffusers.models.attention.CrossAttention.forward = forward_xformers + + +# endregion + + +# region arguments + + +def add_sd_models_arguments(parser: argparse.ArgumentParser): + # for pretrained models + parser.add_argument( + "--v_parameterization", action="store_true", help="enable v-parameterization training" + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint", + ) + parser.add_argument( + "--tokenizer_cache_dir", + type=str, + default=None, + help="directory for caching Tokenizer (for offline training)", + ) + + +def add_optimizer_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--optimizer_type", + type=str, + 
default="", + help="Optimizer to use: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor", + ) + + # backward compatibility + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="use 8bit AdamW optimizer (requires bitsandbytes)", + ) + parser.add_argument( + "--use_lion_optimizer", + action="store_true", + help="use Lion optimizer (requires lion-pytorch)", + ) + + parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate") + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm, 0 for no clipping " + ) + + parser.add_argument( + "--optimizer_args", + type=str, + default=None, + nargs="*", + help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...")', + ) + + parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module") + parser.add_argument( + "--lr_scheduler_args", + type=str, + default=None, + nargs="*", + help='additional arguments for scheduler (like "T_max=100")', + ) + + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help="scheduler to use for learning rate: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor", + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=0, + help="Number of steps for the warmup in the lr scheduler (default is 0)", + ) + parser.add_argument( + "--lr_scheduler_num_cycles", + type=int, + default=1, + help="Number of restarts for cosine scheduler with restarts", + ) + parser.add_argument( + "--lr_scheduler_power", + type=float, + default=1, + help="Polynomial power for polynomial scheduler", + ) + + +def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): + parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model") + parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file") + parser.add_argument( + "--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload " + ) + parser.add_argument( + "--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload" + ) + parser.add_argument( + "--huggingface_path_in_repo", + type=str, + default=None, + help="huggingface model path to upload files", + ) + parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token") + parser.add_argument( + "--huggingface_repo_visibility", + type=str, + default=None, + help="huggingface repository visibility ('public' for public, 'private' or None for private)", + ) + parser.add_argument( + "--save_state_to_huggingface", action="store_true", help="save state to huggingface" + ) + parser.add_argument( + "--resume_from_huggingface", + action="store_true", + help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})", + ) + parser.add_argument( + "--async_upload", + action="store_true", + help="upload to huggingface asynchronously", + ) + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving", + ) + parser.add_argument( + "--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs" + ) + parser.add_argument( + "--save_n_epoch_ratio", + type=int, + default=None, + help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total)", + ) + 
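# A small sketch of how the "key=value" strings passed via --optimizer_args and
# --lr_scheduler_args are decomposed later in this file (see get_optimizer /
# get_scheduler_fix): each entry is split on "=" and the value parsed with
# ast.literal_eval, so numbers, booleans and comma-separated tuples all work.
import ast

def parse_extra_kwargs(pairs):
    kwargs = {}
    for pair in pairs or []:
        key, value = pair.split("=")
        kwargs[key] = ast.literal_eval(value)
    return kwargs

print(parse_extra_kwargs(["weight_decay=0.01", "betas=0.9,0.999", "relative_step=False"]))
# {'weight_decay': 0.01, 'betas': (0.9, 0.999), 'relative_step': False}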
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints") + parser.add_argument( + "--save_last_n_epochs_state", + type=int, + default=None, + help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)", + ) + parser.add_argument( + "--save_state", + action="store_true", + help="save training state additionally (including optimizer states etc.) ", + ) + parser.add_argument("--resume", type=str, default=None, help="saved state to resume training") + + parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training") + parser.add_argument( + "--max_token_length", + type=int, + default=None, + choices=[None, 150, 225], + help="max token length of text encoder (default for 75, 150 or 225)", + ) + parser.add_argument( + "--mem_eff_attn", + action="store_true", + help="use memory efficient attention for CrossAttention", + ) + parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention") + parser.add_argument( + "--vae", type=str, default=None, help="path to checkpoint of vae to replace" + ) + + parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps") + parser.add_argument( + "--max_train_epochs", + type=int, + default=None, + help="training epochs (overrides max_train_steps)", + ) + parser.add_argument( + "--max_data_loader_n_workers", + type=int, + default=8, + help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading)", + ) + parser.add_argument( + "--persistent_data_loader_workers", + action="store_true", + help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory)", + ) + parser.add_argument("--seed", type=int, default=None, help="random seed for training") + parser.add_argument( + "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing" + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass", + ) + parser.add_argument( + "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision" + ) + parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients") + parser.add_argument( + "--clip_skip", + type=int, + default=None, + help="use output of nth layer from back of text encoder (n>=1)", + ) + parser.add_argument( + "--logging_dir", + type=str, + default=None, + help="enable logging and output TensorBoard log to this directory", + ) + parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory") + parser.add_argument( + "--noise_offset", + type=float, + default=None, + help="enable noise offset with this value (if enabled, around 0.1 is recommended)", + ) + parser.add_argument( + "--lowram", + action="store_true", + help="enable low RAM optimization. e.g. 
load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle)", + ) + + parser.add_argument( + "--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps" + ) + parser.add_argument( + "--sample_every_n_epochs", + type=int, + default=None, + help="generate sample images every N epochs (overwrites n_steps)", + ) + parser.add_argument( + "--sample_prompts", type=str, default=None, help="file for prompts to generate sample images" + ) + parser.add_argument( + "--sample_sampler", + type=str, + default="ddim", + choices=[ + "ddim", + "pndm", + "lms", + "euler", + "euler_a", + "heun", + "dpm_2", + "dpm_2_a", + "dpmsolver", + "dpmsolver++", + "dpmsingle", + "k_lms", + "k_euler", + "k_euler_a", + "k_dpm_2", + "k_dpm_2_a", + ], + help=f"sampler (scheduler) type for sample images", + ) + + parser.add_argument( + "--config_file", + type=str, + default=None, + help="using .toml instead of args to pass hyperparameter", + ) + parser.add_argument( + "--output_config", action="store_true", help="output command line args to given .toml file" + ) + + if support_dreambooth: + # DreamBooth training + parser.add_argument( + "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images" + ) + + +def verify_training_args(args: argparse.Namespace): + if args.v_parameterization and not args.v2: + print("v_parameterization should be with v2") + if args.v2 and args.clip_skip is not None: + print("v2 with clip_skip will be unexpected") + + +def add_dataset_arguments( + parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool +): + # dataset common + parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images") + parser.add_argument( + "--shuffle_caption", action="store_true", help="shuffle comma-separated caption" + ) + parser.add_argument( + "--caption_extension", type=str, default=".caption", help="extension of caption files" + ) + parser.add_argument( + "--caption_extention", + type=str, + default=None, + help="extension of caption files (backward compatibility)", + ) + parser.add_argument( + "--keep_tokens", + type=int, + default=0, + help="keep heading N tokens when shuffling caption tokens (token means comma separated strings)", + ) + parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation") + parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation") + parser.add_argument( + "--face_crop_aug_range", + type=str, + default=None, + help="enable face-centered crop augmentation and its range (e.g. 
2.0,4.0)", + ) + parser.add_argument( + "--random_crop", + action="store_true", + help="enable random crop (for style training in face-centered crop augmentation)", + ) + parser.add_argument( + "--debug_dataset", action="store_true", help="show images for debugging (do not train)" + ) + parser.add_argument( + "--resolution", + type=str, + default=None, + help="resolution in training ('size' or 'width,height')", + ) + parser.add_argument( + "--cache_latents", + action="store_true", + help="cache latents to reduce memory (augmentations must be disabled)", + ) + parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents") + parser.add_argument( + "--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training" + ) + parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets") + parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets") + parser.add_argument( + "--bucket_reso_steps", + type=int, + default=64, + help="steps of resolution for buckets, divisible by 8 is recommended", + ) + parser.add_argument( + "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling" + ) + + parser.add_argument( + "--token_warmup_min", + type=int, + default=1, + help="start learning at N tags (token means comma separated strinfloatgs)", + ) + + parser.add_argument( + "--token_warmup_step", + type=float, + default=0, + help="tag length reaches maximum on N steps (or N*max_train_steps if N<1)", + ) + + if support_caption_dropout: + # Textual Inversion はcaptionのdropoutをsupportしない + # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに + parser.add_argument( + "--caption_dropout_rate", type=float, default=0.0, help="Rate out dropout caption(0.0~1.0)" + ) + parser.add_argument( + "--caption_dropout_every_n_epochs", + type=int, + default=0, + help="Dropout all captions every N epochs", + ) + parser.add_argument( + "--caption_tag_dropout_rate", + type=float, + default=0.0, + help="Rate out dropout comma separated tokens(0.0~1.0)", + ) + + if support_dreambooth: + # DreamBooth dataset + parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images") + + if support_caption: + # caption dataset + parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset") + parser.add_argument( + "--dataset_repeats", type=int, default=1, help="repeat dataset when training with captions" + ) + + +def add_sd_saving_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--save_model_as", + type=str, + default=None, + choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"], + help="format to save the model (default is same to original)", + ) + parser.add_argument( + "--use_safetensors", + action="store_true", + help="use safetensors format to save (if save_model_as is not specified)", + ) + + +def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser): + if not args.config_file: + return args + + config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file + + if args.output_config: + # check if config file exists + if os.path.exists(config_path): + print(f"Config file already exists. 
Aborting...") + exit(1) + + # convert args to dictionary + args_dict = vars(args) + + # remove unnecessary keys + for key in ["config_file", "output_config"]: + if key in args_dict: + del args_dict[key] + + # get default args from parser + default_args = vars(parser.parse_args([])) + + # remove default values: cannot use args_dict.items directly because it will be changed during iteration + for key, value in list(args_dict.items()): + if key in default_args and value == default_args[key]: + del args_dict[key] + + # convert Path to str in dictionary + for key, value in args_dict.items(): + if isinstance(value, pathlib.Path): + args_dict[key] = str(value) + + # convert to toml and output to file + with open(config_path, "w") as f: + toml.dump(args_dict, f) + + print(f"Saved config file") + exit(0) + + if not os.path.exists(config_path): + print(f"{config_path} not found.") + exit(1) + + print(f"Loading settings from {config_path}...") + with open(config_path, "r") as f: + config_dict = toml.load(f) + + # combine all sections into one + ignore_nesting_dict = {} + for section_name, section_dict in config_dict.items(): + # if value is not dict, save key and value as is + if not isinstance(section_dict, dict): + ignore_nesting_dict[section_name] = section_dict + continue + + # if value is dict, save all key and value into one dict + for key, value in section_dict.items(): + ignore_nesting_dict[key] = value + + config_args = argparse.Namespace(**ignore_nesting_dict) + args = parser.parse_args(namespace=config_args) + args.config_file = os.path.splitext(args.config_file)[0] + print(args.config_file) + + return args + + +def get_optimizer(args, trainable_params): + # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor" + + optimizer_type = args.optimizer_type + if args.use_8bit_adam: + assert ( + not args.use_lion_optimizer + ), "both option use_8bit_adam and use_lion_optimizer are specified" + assert ( + optimizer_type is None or optimizer_type == "" + ), "both option use_8bit_adam and optimizer_type are specified" + optimizer_type = "AdamW8bit" + + elif args.use_lion_optimizer: + assert ( + optimizer_type is None or optimizer_type == "" + ), "both option use_lion_optimizer and optimizer_type are specified" + optimizer_type = "Lion" + + if optimizer_type is None or optimizer_type == "": + optimizer_type = "AdamW" + optimizer_type = optimizer_type.lower() + + # decompose the arguments + optimizer_kwargs = {} + if args.optimizer_args is not None and len(args.optimizer_args) > 0: + for arg in args.optimizer_args: + key, value = arg.split("=") + value = ast.literal_eval(value) + + # value = value.split(",") + # for i in range(len(value)): + # if value[i].lower() == "true" or value[i].lower() == "false": + # value[i] = value[i].lower() == "true" + # else: + # value[i] = ast.float(value[i]) + # if len(value) == 1: + # value = value[0] + # else: + # value = tuple(value) + + optimizer_kwargs[key] = value + # print("optkwargs:", optimizer_kwargs) + + lr = args.learning_rate + + if optimizer_type == "AdamW8bit".lower(): + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsand bytes") + print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") + optimizer_class = bnb.optim.AdamW8bit + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "SGDNesterov8bit".lower(): + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsand bytes") + print(f"use 8-bit SGD with 
Nesterov optimizer | {optimizer_kwargs}") + if "momentum" not in optimizer_kwargs: + print( + f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9" + ) + optimizer_kwargs["momentum"] = 0.9 + + optimizer_class = bnb.optim.SGD8bit + optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) + + elif optimizer_type == "Lion".lower(): + try: + import lion_pytorch + except ImportError: + raise ImportError("No lion_pytorch") + print(f"use Lion optimizer | {optimizer_kwargs}") + optimizer_class = lion_pytorch.Lion + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "SGDNesterov".lower(): + print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}") + if "momentum" not in optimizer_kwargs: + print(f"SGD with Nesterov must be with momentum, set momentum to 0.9") + optimizer_kwargs["momentum"] = 0.9 + + optimizer_class = torch.optim.SGD + optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) + + elif optimizer_type == "DAdaptation".lower(): + try: + import dadaptation + except ImportError: + raise ImportError("No dadaptation") + print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") + + actual_lr = lr + lr_count = 1 + if type(trainable_params) == list and type(trainable_params[0]) == dict: + lrs = set() + actual_lr = trainable_params[0].get("lr", actual_lr) + for group in trainable_params: + lrs.add(group.get("lr", actual_lr)) + lr_count = len(lrs) + + if actual_lr <= 0.1: + print( + f"learning rate is too low. If using dadaptation, set learning rate around 1.0" + ) + print("recommend option: lr=1.0") + if lr_count > 1: + print( + f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect" + ) + + optimizer_class = dadaptation.DAdaptAdam + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "Adafactor".lower(): + # Check the arguments and correct them as appropriate + if "relative_step" not in optimizer_kwargs: + optimizer_kwargs["relative_step"] = True # default + if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False): + print(f"set relative_step to True because warmup_init is True") + optimizer_kwargs["relative_step"] = True + print(f"use Adafactor optimizer | {optimizer_kwargs}") + + if optimizer_kwargs["relative_step"]: + print(f"relative_step is true") + if lr != 0.0: + print(f"learning rate is used as initial_lr") + args.learning_rate = None + + # Processing when trainable_params is group + if type(trainable_params) == list and type(trainable_params[0]) == dict: + has_group_lr = False + for group in trainable_params: + p = group.pop("lr", None) + has_group_lr = has_group_lr or (p is not None) + + if has_group_lr: + # Disable args just in case TODO This is not very desirable as the dependencies are reversed. + print(f"unet_lr and text_encoder_lr are ignored") + args.unet_lr = None + args.text_encoder_lr = None + + if args.lr_scheduler != "adafactor": + print(f"use adafactor_scheduler") + args.lr_scheduler = f"adafactor:{lr}" + + lr = None + else: + if args.max_grad_norm != 0.0: + print( + f"because max_grad_norm is set, clip_grad_norm is enabled. 
consider set to 0" + ) + if args.lr_scheduler != "constant_with_warmup": + print(f"constant_with_warmup will be good") + if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0: + print(f"clip_threshold=1.0 will be good") + + optimizer_class = transformers.optimization.Adafactor + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "AdamW".lower(): + print(f"use AdamW optimizer | {optimizer_kwargs}") + optimizer_class = torch.optim.AdamW + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + else: + optimizer_type = args.optimizer_type + print(f"use {optimizer_type} | {optimizer_kwargs}") + if "." not in optimizer_type: + optimizer_module = torch.optim + else: + values = optimizer_type.split(".") + optimizer_module = importlib.import_module(".".join(values[:-1])) + optimizer_type = values[-1] + + optimizer_class = getattr(optimizer_module, optimizer_type) + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ + optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) + + return optimizer_name, optimizer_args, optimizer + + +# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler +# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6 +# Which is a newer release of diffusers than currently packaged with sd-scripts +# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts + + +def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): + """ + Unified API to get any scheduler from its name. + """ + name = args.lr_scheduler + num_warmup_steps: Optional[int] = args.lr_warmup_steps + num_training_steps = args.max_train_steps * num_processes * args.gradient_accumulation_steps + num_cycles = args.lr_scheduler_num_cycles + power = args.lr_scheduler_power + + lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs + if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0: + for arg in args.lr_scheduler_args: + key, value = arg.split("=") + + value = ast.literal_eval(value) + + lr_scheduler_kwargs[key] = value + + def wrap_check_needless_num_warmup_steps(return_vals): + if num_warmup_steps is not None and num_warmup_steps != 0: + raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.") + return return_vals + + # using any lr_scheduler from other library + if args.lr_scheduler_type: + lr_scheduler_type = args.lr_scheduler_type + print(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler") + if "." 
not in lr_scheduler_type: # default to use torch.optim + lr_scheduler_module = torch.optim.lr_scheduler + else: + values = lr_scheduler_type.split(".") + lr_scheduler_module = importlib.import_module(".".join(values[:-1])) + lr_scheduler_type = values[-1] + lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type) + lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs) + return wrap_check_needless_num_warmup_steps(lr_scheduler) + + if name.startswith("adafactor"): + assert ( + type(optimizer) == transformers.optimization.Adafactor + ), f"adafactor scheduler must be used with Adafactor optimizer" + initial_lr = float(name.split(":")[1]) + # print("adafactor scheduler init lr", initial_lr) + return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr)) + + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + if name == SchedulerType.CONSTANT: + return wrap_check_needless_num_warmup_steps(schedule_func(optimizer)) + + # All other schedulers require `num_warmup_steps` + if num_warmup_steps is None: + raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") + + if name == SchedulerType.CONSTANT_WITH_WARMUP: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) + + # All other schedulers require `num_training_steps` + if num_training_steps is None: + raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + + if name == SchedulerType.COSINE_WITH_RESTARTS: + return schedule_func( + optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles + ) + + if name == SchedulerType.POLYNOMIAL: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power) + + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) + + +def load_tokenizer(args: argparse.Namespace): + print("prepare tokenizer") + original_path = TOKENIZER_PATH + + tokenizer: CLIPTokenizer = None + if args.tokenizer_cache_dir: + local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) + if os.path.exists(local_tokenizer_path): + print(f"load tokenizer from cache: {local_tokenizer_path}") + tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) + + if tokenizer is None: + if args.v2: + tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer") + else: + tokenizer = CLIPTokenizer.from_pretrained(original_path) + + if hasattr(args, "max_token_length") and args.max_token_length is not None: + print(f"update token length: {args.max_token_length}") + + if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): + print(f"save Tokenizer to cache: {local_tokenizer_path}") + tokenizer.save_pretrained(local_tokenizer_path) + + return tokenizer + + +def prepare_accelerator(args: argparse.Namespace): + if args.logging_dir is None: + log_with = None + logging_dir = None + else: + log_with = "tensorboard" + log_prefix = "" if args.log_prefix is None else args.log_prefix + logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime()) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=log_with, + logging_dir=logging_dir, + ) + + # accelerate + accelerator_0_15 = True + try: + accelerator.unwrap_model("dummy", True) + 
print("Using accelerator 0.15.0 or above.") + except TypeError: + accelerator_0_15 = False + + def unwrap_model(model): + if accelerator_0_15: + return accelerator.unwrap_model(model, True) + return accelerator.unwrap_model(model) + + return accelerator, unwrap_model + + +def prepare_dtype(args: argparse.Namespace): + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + save_dtype = None + if args.save_precision == "fp16": + save_dtype = torch.float16 + elif args.save_precision == "bf16": + save_dtype = torch.bfloat16 + elif args.save_precision == "float": + save_dtype = torch.float32 + + return weight_dtype, save_dtype + + +def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"): + name_or_path = args.pretrained_model_name_or_path + name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path + load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers + if load_stable_diffusion_format: + print("load StableDiffusion checkpoint") + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device) + else: + # Diffusers model is loaded to CPU + print("load Diffusers pretrained models") + try: + pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None) + except EnvironmentError as ex: + print( + f"model is not found as a file or in Hugging Face, perhaps file name is wrong?" + ) + text_encoder = pipe.text_encoder + vae = pipe.vae + unet = pipe.unet + del pipe + + # VAE + if args.vae is not None: + vae = model_util.load_vae(args.vae, weight_dtype) + print("additional VAE loaded") + + return text_encoder, vae, unet, load_stable_diffusion_format + + +def patch_accelerator_for_fp16_training(accelerator): + org_unscale_grads = accelerator.scaler._unscale_grads_ + + def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): + return org_unscale_grads(optimizer, inv_scale, found_inf, True) + + accelerator.scaler._unscale_grads_ = _unscale_grads_replacer + + +def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None): + # with no_token_padding, the length is not max length, return result immediately + if input_ids.size()[-1] != tokenizer.model_max_length: + return text_encoder(input_ids)[0] + + b_size = input_ids.size()[0] + input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77 + + if args.clip_skip is None: + encoder_hidden_states = text_encoder(input_ids)[0] + else: + enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) + encoder_hidden_states = enc_out["hidden_states"][-args.clip_skip] + encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) + + # bs*3, 77, 768 or 1024 + encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) + + if args.max_token_length is not None: + if args.v2: + # v2: Return the triplet of ... ... to ... + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + chunk = encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2] # After until before the last + if i > 0: + for j in range(len(chunk)): + if input_ids[j, 1] == tokenizer.eos_token: # Empty, i.e. ... 
pattern + chunk[j, 0] = chunk[j, 1] # Copy the value of the next + states_list.append(chunk) # After and before + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # Either or + encoder_hidden_states = torch.cat(states_list, dim=1) + else: + # v1: Return triplet of ... to ... + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + states_list.append(encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2]) # After and before + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # + encoder_hidden_states = torch.cat(states_list, dim=1) + + if weight_dtype is not None: + # this is required for additional network training + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + + return encoder_hidden_states + + +def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch): + model_name = DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + (".safetensors" if use_safetensors else ".ckpt") + return model_name, ckpt_name + + +def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int): + saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs + if saving: + os.makedirs(args.output_dir, exist_ok=True) + save_func() + + if args.save_last_n_epochs is not None: + remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs + remove_old_func(remove_epoch_no) + return saving + + +def save_sd_model_on_epoch_end( + args: argparse.Namespace, + accelerator, + src_path: str, + save_stable_diffusion_format: bool, + use_safetensors: bool, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + text_encoder, + unet, + vae, +): + epoch_no = epoch + 1 + model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no) + + if save_stable_diffusion_format: + + def save_sd(): + ckpt_file = os.path.join(args.output_dir, ckpt_name) + print(f"saving checkpoint: {ckpt_file}") + model_util.save_stable_diffusion_checkpoint( + args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae + ) + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name) + + def remove_sd(old_epoch_no): + _, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no) + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + save_func = save_sd + remove_old_func = remove_sd + else: + + def save_du(): + out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no)) + print(f"saving model: {out_dir}") + os.makedirs(out_dir, exist_ok=True) + model_util.save_diffusers_checkpoint( + args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors + ) + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, out_dir, "/" + model_name) + + def remove_du(old_epoch_no): + out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no)) + if os.path.exists(out_dir_old): + print(f"removing old model: {out_dir_old}") + shutil.rmtree(out_dir_old) + + save_func = save_du + remove_old_func = remove_du + + saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs) + if 
saving and args.save_state: + save_state_on_epoch_end(args, accelerator, model_name, epoch_no) + + +def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no): + print("saving state.") + state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)) + accelerator.save_state(state_dir) + if args.save_state_to_huggingface: + print("uploading state to huggingface.") + huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no)) + + last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs + if last_n_epochs is not None: + remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs + state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) + if os.path.exists(state_dir_old): + print(f"removing old state: {state_dir_old}") + shutil.rmtree(state_dir_old) + + +def save_state_on_train_end(args: argparse.Namespace, accelerator): + print("saving last state.") + os.makedirs(args.output_dir, exist_ok=True) + model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)) + accelerator.save_state(state_dir) + if args.save_state_to_huggingface: + print("uploading last state to huggingface.") + huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name)) + + +def save_sd_model_on_train_end( + args: argparse.Namespace, + src_path: str, + save_stable_diffusion_format: bool, + use_safetensors: bool, + save_dtype: torch.dtype, + epoch: int, + global_step: int, + text_encoder, + unet, + vae, +): + model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + + if save_stable_diffusion_format: + os.makedirs(args.output_dir, exist_ok=True) + + ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt") + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") + model_util.save_stable_diffusion_checkpoint( + args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae + ) + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True) + else: + out_dir = os.path.join(args.output_dir, model_name) + os.makedirs(out_dir, exist_ok=True) + + print(f"save trained model as Diffusers to {out_dir}") + model_util.save_diffusers_checkpoint( + args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors + ) + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) + + +# scheduler: +SCHEDULER_LINEAR_START = 0.00085 +SCHEDULER_LINEAR_END = 0.0120 +SCHEDULER_TIMESTEPS = 1000 +SCHEDLER_SCHEDULE = "scaled_linear" + + +def sample_images( + accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None +): + """ + Modified version of StableDiffusionLongPromptWeightingPipeline, so it supports clip skip and prompt weighting. 
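+
+    Prompts are read one per line from args.sample_prompts; blank lines and lines
+    starting with "#" are skipped. A prompt may carry options separated by " --",
+    e.g. "a cat --w 512 --h 512 --d 42 --s 30 --l 7.5 --n lowres": w/h set width and
+    height (rounded down to multiples of 8, minimum 64), d the seed, s the sampling
+    steps (clamped to 1..1000), l the guidance scale, and n the negative prompt.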
+ """ + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: + return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return + + print(f"generating sample images at step: {steps}") + if not os.path.isfile(args.sample_prompts): + print(f"Missing prompt file: {args.sample_prompts}") + return + + org_vae_device = vae.device # CPUにいるはず + vae.to(device) + + # read prompts + with open(args.sample_prompts, "rt", encoding="utf-8") as f: + prompts = f.readlines() + + # scheduler + sched_init_args = {} + + if args.sample_sampler == "euler" or args.sample_sampler == "k_euler": + scheduler_cls = EulerDiscreteScheduler + elif args.sample_sampler == "euler_a" or args.sample_sampler == "k_euler_a": + scheduler_cls = EulerAncestralDiscreteScheduler + else: + scheduler_cls = EulerAncestralDiscreteScheduler + + if args.v_parameterization: + sched_init_args["prediction_type"] = "v_prediction" + + scheduler = scheduler_cls( + num_train_timesteps=SCHEDULER_TIMESTEPS, + beta_start=SCHEDULER_LINEAR_START, + beta_end=SCHEDULER_LINEAR_END, + beta_schedule=SCHEDLER_SCHEDULE, + **sched_init_args, + ) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: + scheduler.config.clip_sample = True + + pipeline = StableDiffusionLongPromptWeightingPipeline( + text_encoder=text_encoder, + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=scheduler, + clip_skip=args.clip_skip, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + pipeline.to(device) + + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + rng_state = torch.get_rng_state() + cuda_rng_state = torch.cuda.get_rng_state() + + with torch.no_grad(): + with accelerator.autocast(): + for i, prompt in enumerate(prompts): + if not accelerator.is_main_process: + continue + prompt = prompt.strip() + if len(prompt) == 0 or prompt[0] == "#": + continue + + # subset of gen_img_diffusers + prompt_args = prompt.split(" --") + prompt = prompt_args[0] + negative_prompt = None + sample_steps = 30 + width = height = 512 + scale = 7.5 + seed = None + for parg in prompt_args: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + continue + + m = re.match(r"d (\d+)", parg, re.IGNORECASE) + if m: + seed = int(m.group(1)) + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + sample_steps = max(1, min(1000, int(m.group(1)))) + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + continue + + except ValueError as ex: + print(f"Exception in parsing") + print(ex) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + print(f"prompt: 
{prompt}") + print(f"negative_prompt: {negative_prompt}") + print(f"height: {height}") + print(f"width: {width}") + print(f"sample_steps: {sample_steps}") + print(f"scale: {scale}") + image = pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=sample_steps, + guidance_scale=scale, + negative_prompt=negative_prompt, + ).images[0] + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + img_filename = ( + f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png" + ) + + image.save(os.path.join(save_dir, img_filename)) + + # clear pipeline and cache to reduce vram usage + del pipeline + torch.cuda.empty_cache() + + torch.set_rng_state(rng_state) + torch.cuda.set_rng_state(cuda_rng_state) + vae.to(org_vae_device) + + +# endregion + +# region + + +class ImageLoadingDataset(torch.utils.data.Dataset): + def __init__(self, image_paths): + self.images = image_paths + + def __len__(self): + return len(self.images) + + def __getitem__(self, idx): + img_path = self.images[idx] + + try: + image = Image.open(img_path).convert("RGB") + # convert to tensor temporarily so dataloader will accept it + tensor_pil = transforms.functional.pil_to_tensor(image) + except Exception as e: + print(f"Could not load image path") + return None + + return (tensor_pil, img_path) + + +# endregion + + +# collate_fn用 epoch,stepはmultiprocessing.Value +class collater_class: + def __init__(self, epoch, step, dataset): + self.current_epoch = epoch + self.current_step = step + self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing + + def __call__(self, examples): + worker_info = torch.utils.data.get_worker_info() + # worker_info is None in the main process + if worker_info is not None: + dataset = worker_info.dataset + else: + dataset = self.dataset + + # set epoch and step + dataset.set_current_epoch(self.current_epoch.value) + dataset.set_current_step(self.current_step.value) + return examples[0] diff --git a/SystemCode/src/main/model/utils/common.py b/SystemCode/src/main/model/utils/common.py new file mode 100644 index 0000000..4514b3b --- /dev/null +++ b/SystemCode/src/main/model/utils/common.py @@ -0,0 +1,4 @@ +import os + +def get_file_type(path): + return os.path.splitext(path)[1].lower() \ No newline at end of file diff --git a/SystemCode/src/main/model/utils/conversions.py b/SystemCode/src/main/model/utils/conversions.py new file mode 100644 index 0000000..2136e9c --- /dev/null +++ b/SystemCode/src/main/model/utils/conversions.py @@ -0,0 +1,610 @@ +# Key conversions +""" +Utils to convert Stable Diffusion LDM checkpoints to Diffusers format. 
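+The helpers below cover the text encoder, UNet and VAE; convert_unet_state_dict_to_sd
+and convert_vae_state_dict additionally map Diffusers weights back to SD key names.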
+Adapted from the diffusers module, see: https://github.com/huggingface/diffusers/blob/main/scripts/convert_ldm_original_checkpoint_to_diffusers.py +""" + + +""" +Text encoder conversions +""" +def convert_text_encoder_key(state_dict): + """Replace the text model keys for different types of checpoints""" + TEXT_ENCODER_KEY_REPLACEMENTS = [ + ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."), + ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."), + ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."), + ] + + key_reps = [] + for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: + for key in state_dict.keys(): + if key.startswith(rep_from): + new_key = rep_to + key[len(rep_from) :] + key_reps.append((key, new_key)) + + for key, new_key in key_reps: + state_dict[new_key] = state_dict[key] + del state_dict[key] + return state_dict + +def convert_ldm_clip_checkpoint_v1(checkpoint): + keys = list(checkpoint.keys()) + text_model_dict = {} + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + return text_model_dict + + +""" +ResNet conversions +""" +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming + to them. It splits attention layers, and takes into account additional replacements + that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. 
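+    # The fused qkv tensor stacks query, key and value along dim 0 (3 * channels);
+    # it is regrouped per attention head, split into three equal chunks of
+    # channels // num_heads along dim 1, and each chunk is reshaped to target_shape
+    # before being written to the new checkpoint under the mapped key names.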
+ if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") + + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") + + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") + + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, 
"new": new_item}) + + return mapping + +def convert_ldm_unet_checkpoint(checkpoint, config): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + unet_key = "model.diffusion_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." 
in key] for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias") + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + for l in output_block_list.values(): + l.sort() + + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + + # Clear attentions as they have been attributed above. 
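+                # A length of 2 means "output_blocks.{i}.1" held only the upsampler's
+                # conv.weight/conv.bias (assigned just above), not an attention block.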
+ if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + return new_checkpoint + + +""" +VAE conversions +""" +def reshape_weight_for_sd(w): + # convert HF linear weights to SD conv2d weights + return w.reshape(*w.shape, 1, 1) + + +def convert_vae_state_dict(vae_state_dict): + vae_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("nin_shortcut", "conv_shortcut"), + ("norm_out", "conv_norm_out"), + ("mid.attn_1.", "mid_block.attentions.0."), + ] + + for i in range(4): + # down_blocks have two resnets + for j in range(2): + hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." + sd_down_prefix = f"encoder.down.{i}.block.{j}." + vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) + + if i < 3: + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." + sd_downsample_prefix = f"down.{i}.downsample." + vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) + + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"up.{3-i}.upsample." + vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) + + # up_blocks have three resnets + # also, up blocks in hf are numbered in reverse from sd + for j in range(3): + hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." + sd_up_prefix = f"decoder.up.{3-i}.block.{j}." + vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) + + # this part accounts for mid blocks in both the encoder and the decoder + for i in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{i}." + sd_mid_res_prefix = f"mid.block_{i+1}." + vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + vae_conversion_map_attn = [ + # (stable-diffusion, HF Diffusers) + ("norm.", "group_norm."), + ("q.", "query."), + ("k.", "key."), + ("v.", "value."), + ("proj_out.", "proj_attn."), + ] + + mapping = {k: k for k in vae_state_dict.keys()} + for k, v in mapping.items(): + for sd_part, hf_part in vae_conversion_map: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + if "attentions" in k: + for sd_part, hf_part in vae_conversion_map_attn: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} + weights_to_convert = ["q", "k", "v", "proj_out"] + for k, v in new_state_dict.items(): + for weight_name in weights_to_convert: + if f"mid.attn_1.{weight_name}.weight" in k: + # print(f"Reshaping {k} for SD format") + new_state_dict[k] = reshape_weight_for_sd(v) + + return new_state_dict + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + vae_key = "first_stage_model." 
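+    # Collect every "first_stage_model.*" entry and strip the prefix so the keys
+    # below match the raw VAE naming (encoder.*, decoder.*, quant_conv, post_quant_conv).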
+ keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + # if len(vae_state_dict) == 0: + # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict + # vae_state_dict = checkpoint + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)} + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)} + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = 
renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def convert_unet_state_dict_to_sd(unet_state_dict): + unet_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("out.0.weight", "conv_norm_out.weight"), + ("out.0.bias", "conv_norm_out.bias"), + ("out.2.weight", "conv_out.weight"), + ("out.2.bias", "conv_out.bias"), + ] + + unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0", "norm1"), + ("in_layers.2", "conv1"), + ("out_layers.0", "norm2"), + ("out_layers.3", "conv2"), + ("emb_layers.1", "time_emb_proj"), + ("skip_connection", "conv_shortcut"), + ] + + unet_conversion_map_layer = [] + for i in range(4): + # loop over downblocks/upblocks + + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." 
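+            # SD numbers output blocks sequentially with 3 resnets per up block,
+            # so up_blocks.{i}.resnets.{j} maps to output_blocks.{3*i + j}.0.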
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + if i > 0: + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + # buyer beware: this is a *brittle* function, + # and correct output requires that all of these pieces interact in + # the exact order in which I have arranged them. + mapping = {k: k for k in unet_state_dict.keys()} + for sd_name, hf_name in unet_conversion_map: + mapping[hf_name] = sd_name + for k, v in mapping.items(): + if "resnets" in k: + for sd_part, hf_part in unet_conversion_map_resnet: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + for sd_part, hf_part in unet_conversion_map_layer: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} + + return new_state_dict +
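+
+
+# A minimal usage sketch (illustrative only, not called by the training code):
+# chain the helpers above to turn an SD v1 LDM checkpoint into Diffusers-style
+# state dicts. The file name and the config dict are assumptions for the example;
+# the SD v1 UNet uses layers_per_block=2, and no other config key is read on this path.
+if __name__ == "__main__":
+    import torch
+
+    ckpt = torch.load("model.ckpt", map_location="cpu")  # hypothetical checkpoint path
+    state_dict = ckpt.get("state_dict", ckpt)  # .ckpt files usually nest under "state_dict"
+
+    text_encoder_sd = convert_ldm_clip_checkpoint_v1(convert_text_encoder_key(state_dict))
+    unet_sd = convert_ldm_unet_checkpoint(state_dict, config={"layers_per_block": 2})
+    vae_sd = convert_ldm_vae_checkpoint(state_dict, config={"layers_per_block": 2})
+
+    print(f"text encoder: {len(text_encoder_sd)} keys, unet: {len(unet_sd)} keys, vae: {len(vae_sd)} keys")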