diff --git a/.gitignore b/.gitignore index 61bac16..1b87291 100644 --- a/.gitignore +++ b/.gitignore @@ -131,4 +131,5 @@ dmypy.json # Extra stuff to ignore dreams images -run.py \ No newline at end of file +run.py +examples \ No newline at end of file diff --git a/README.md b/README.md index 7b5703b..d1253bf 100644 --- a/README.md +++ b/README.md @@ -45,21 +45,69 @@ huggingface-cli login #### Programatic Usage ```python -from stable_diffusion_videos import walk - -walk( +from stable_diffusion_videos import StableDiffusionWalkPipeline +from diffusers.schedulers import LMSDiscreteScheduler +import torch + +pipeline = StableDiffusionWalkPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + use_auth_token=True, + torch_dtype=torch.float16, + revision="fp16", + scheduler=LMSDiscreteScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" + ) +).to("cuda") + +video_path = pipeline.walk( prompts=['a cat', 'a dog'], seeds=[42, 1337], + num_interpolation_steps=3, + height=512, # use multiples of 64 if > 512. Multiples of 8 if < 512. + width=512, # use multiples of 64 if > 512. Multiples of 8 if < 512. output_dir='dreams', # Where images/videos will be saved name='animals_test', # Subdirectory of output_dir where images/videos will be saved guidance_scale=8.5, # Higher adheres to prompt more, lower lets model take the wheel - num_interpolation_steps=5, # Change to 60-200 for better results...3-5 for testing - num_inference_steps=50, # Number of diffusion steps per image generated. 50 is good default. - scheduler='klms', # One of: "klms", "default", "ddim" - disable_tqdm=False, # Set to True to disable tqdm progress bar - make_video=True, # If false, just save images - use_lerp_for_text=True, # Use lerp for text embeddings instead of slerp - do_loop=False, # Change to True if you want last prompt to loop back to first prompt + num_inference_steps=50, # Number of diffusion steps per image generated. 50 is good default +) +``` + +*New!* Music can be added to the video by providing a path to an audio file. The audio will inform the rate of interpolation so the videos move to the beat 🎶 + +```python +from stable_diffusion_videos import StableDiffusionWalkPipeline +from diffusers.schedulers import LMSDiscreteScheduler +import torch + +pipeline = StableDiffusionWalkPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + use_auth_token=True, + torch_dtype=torch.float16, + revision="fp16", + scheduler=LMSDiscreteScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" + ) +).to("cuda") + + +# Seconds in the song. +audio_offsets = [146, 148] +fps = 30 # Use lower values for testing (5 or 10), higher values for better quality (30 or 60) + +# Convert seconds to frames +num_interpolation_steps = [(b-a) * fps for a, b in zip(audio_offsets, audio_offsets[1:])] + +video_path = pipeline.walk( + prompts=['a cat', 'a dog'], + seeds=[42, 1337], + num_interpolation_steps=num_interpolation_steps, + audio_filepath='audio.mp3', + audio_start_sec=audio_offsets[0], + height=512, # use multiples of 64 if > 512. Multiples of 8 if < 512. + width=512, # use multiples of 64 if > 512. Multiples of 8 if < 512. + output_dir='dreams', # Where images/videos will be saved + guidance_scale=7.5, # Higher adheres to prompt more, lower lets model take the wheel + num_inference_steps=50, # Number of diffusion steps per image generated. 
50 is good default ) ``` @@ -97,9 +145,7 @@ pip install realesrgan Then, you'll be able to use `upsample=True` in the `walk` function, like this: ```python -from stable_diffusion_videos import walk - -walk(['a cat', 'a dog'], [234, 345], upsample=True) +pipeline.walk(['a cat', 'a dog'], [234, 345], upsample=True) ``` The above may cause you to run out of VRAM. No problem, you can do upsampling separately. diff --git a/requirements.txt b/requirements.txt index decdfb2..6ff0184 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,7 @@ transformers -diffusers==0.3.0 +diffusers==0.4.0 scipy fire -gradio \ No newline at end of file +gradio +librosa +av \ No newline at end of file diff --git a/stable_diffusion_videos/__init__.py b/stable_diffusion_videos/__init__.py index e036ebb..cb747fe 100644 --- a/stable_diffusion_videos/__init__.py +++ b/stable_diffusion_videos/__init__.py @@ -61,9 +61,7 @@ def _attach(package_name, submodules=None, submod_attrs=None): else: submodules = set(submodules) - attr_to_modules = { - attr: mod for mod, attrs in submod_attrs.items() for attr in attrs - } + attr_to_modules = {attr: mod for mod, attrs in submod_attrs.items() for attr in attrs} __all__ = list(submodules | attr_to_modules.keys()) @@ -96,28 +94,22 @@ def __dir__(): return __getattr__, __dir__, list(__all__) - __getattr__, __dir__, __all__ = _attach( __name__, submodules=[], submod_attrs={ - "commands.user": ["notebook_login"], "app": [ "interface", + "pipeline", ], "stable_diffusion_pipeline": [ - "StableDiffusionPipeline", + "StableDiffusionWalkPipeline", "NoCheck", + "make_video_pyav", + "get_timesteps_arr", ], - "stable_diffusion_walk": [ - "walk", - "SCHEDULERS", - "pipeline", - ], - "upsampling": [ - "PipelineRealESRGAN" - ] + "upsampling": ["PipelineRealESRGAN"], }, ) -__version__ = "0.4.0" \ No newline at end of file +__version__ = "0.5.0" diff --git a/stable_diffusion_videos/app.py b/stable_diffusion_videos/app.py index 73bd1d6..04067c7 100644 --- a/stable_diffusion_videos/app.py +++ b/stable_diffusion_videos/app.py @@ -3,34 +3,38 @@ import gradio as gr import torch -from .stable_diffusion_walk import SCHEDULERS, pipeline, walk +from .stable_diffusion_pipeline import StableDiffusionWalkPipeline +from .upsampling import RealESRGANModel + +pipeline = StableDiffusionWalkPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + use_auth_token=True, + torch_dtype=torch.float16, + revision="fp16", +).to("cuda") def fn_images( prompt, seed, - scheduler, guidance_scale, num_inference_steps, - disable_tqdm, upsample, ): if upsample: - from .upsampling import PipelineRealESRGAN - - upsampling_pipeline = PipelineRealESRGAN.from_pretrained('nateraw/real-esrgan') + if getattr(pipeline, "upsampler", None) is None: + pipeline.upsampler = RealESRGANModel.from_pretrained("nateraw/real-esrgan") + pipeline.upsampler.to(pipeline.device) - pipeline.set_progress_bar_config(disable=disable_tqdm) - pipeline.scheduler = SCHEDULERS[scheduler] # klms, default, ddim with torch.autocast("cuda"): img = pipeline( prompt, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, generator=torch.Generator(device=pipeline.device).manual_seed(seed), - output_type='pil' if not upsample else 'numpy', + output_type="pil" if not upsample else "numpy", )["sample"][0] - return img if not upsample else upsampling_pipeline(img) + return img if not upsample else pipeline.upsampler(img) def fn_videos( @@ -38,13 +42,9 @@ def fn_videos( seed_1, prompt_2, seed_2, - scheduler, guidance_scale, num_inference_steps, 
num_interpolation_steps, - do_loop, - disable_tqdm, - use_lerp_for_text, output_dir, upsample, ): @@ -54,20 +54,15 @@ def fn_videos( prompts = [x for x in prompts if x.strip()] seeds = seeds[: len(prompts)] - video_path = walk( - do_loop=do_loop, - make_video=True, + video_path = pipeline.walk( guidance_scale=guidance_scale, prompts=prompts, seeds=seeds, num_interpolation_steps=num_interpolation_steps, num_inference_steps=num_inference_steps, - use_lerp_for_text=use_lerp_for_text, output_dir=output_dir, name=time.strftime("%Y%m%d-%H%M%S"), - scheduler=scheduler, - disable_tqdm=disable_tqdm, - upsample=upsample + upsample=upsample, ) return video_path @@ -76,21 +71,15 @@ def fn_videos( fn_videos, inputs=[ gr.Textbox("blueberry spaghetti"), - gr.Number(42, label='Seed 1', precision=0), + gr.Number(42, label="Seed 1", precision=0), gr.Textbox("strawberry spaghetti"), - gr.Number(42, label='Seed 2', precision=0), - gr.Dropdown(["klms", "ddim", "default"], value="klms"), + gr.Number(42, label="Seed 2", precision=0), gr.Slider(0.0, 20.0, 8.5), gr.Slider(1, 200, 50), gr.Slider(3, 240, 10), - gr.Checkbox(False), - gr.Checkbox(False), - gr.Checkbox(True), gr.Textbox( "dreams", - placeholder=( - "Folder where outputs will be saved. Each output will be saved in a new folder." - ), + placeholder=("Folder where outputs will be saved. Each output will be saved in a new folder."), ), gr.Checkbox(False), ], @@ -101,19 +90,15 @@ def fn_videos( fn_images, inputs=[ gr.Textbox("blueberry spaghetti"), - gr.Number(42, label='Seed', precision=0), - gr.Dropdown(["klms", "ddim", "default"], value="klms"), + gr.Number(42, label="Seed", precision=0), gr.Slider(0.0, 20.0, 8.5), gr.Slider(1, 200, 50), gr.Checkbox(False), - gr.Checkbox(False), ], outputs=gr.Image(type="pil"), ) -interface = gr.TabbedInterface( - [interface_images, interface_videos], ["Images!", "Videos!"] -) +interface = gr.TabbedInterface([interface_images, interface_videos], ["Images!", "Videos!"]) if __name__ == "__main__": interface.launch(debug=True) diff --git a/stable_diffusion_videos/stable_diffusion_pipeline.py b/stable_diffusion_videos/stable_diffusion_pipeline.py index b2ebb54..bb07264 100644 --- a/stable_diffusion_videos/stable_diffusion_pipeline.py +++ b/stable_diffusion_videos/stable_diffusion_pipeline.py @@ -1,20 +1,187 @@ import inspect -import warnings -from tqdm.auto import tqdm -from typing import List, Optional, Union +from typing import Callable, List, Optional, Union +from pathlib import Path +from torchvision.transforms.functional import pil_to_tensor +import librosa +from PIL import Image +from torchvision.io import write_video +import numpy as np +import time +import json import torch from diffusers import ModelMixin +from diffusers.configuration_utils import FrozenDict from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.stable_diffusion.safety_checker import \ - StableDiffusionSafetyChecker -from diffusers.schedulers import (DDIMScheduler, LMSDiscreteScheduler, - PNDMScheduler) +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.utils import deprecate, logging +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from torch import nn + +from .upsampling import RealESRGANModel + + +logger = 
logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_spec_norm(wav, sr, n_mels=512, hop_length=704): + """Obtain maximum value for each time-frame in Mel Spectrogram, + and normalize between 0 and 1 + + Borrowed from lucid sonic dreams repo. In there, they programatically determine hop length + but I really didn't understand what was going on so I removed it and hard coded the output. + """ + + # Generate Mel Spectrogram + spec_raw = librosa.feature.melspectrogram(y=wav, sr=sr, n_mels=n_mels, hop_length=hop_length) + + # Obtain maximum value per time-frame + spec_max = np.amax(spec_raw, axis=0) + + # Normalize all values between 0 and 1 + spec_norm = (spec_max - np.min(spec_max)) / np.ptp(spec_max) + + return spec_norm + + +def get_timesteps_arr(audio_filepath, offset, duration, fps=30, margin=(1.0, 5.0)): + """Get the array that will be used to determine how much to interpolate between images. + + Normally, this is just a linspace between 0 and 1 for the number of frames to generate. In this case, + we want to use the amplitude of the audio to determine how much to interpolate between images. + + So, here we: + 1. Load the audio file + 2. Split the audio into harmonic and percussive components + 3. Get the normalized amplitude of the percussive component, resized to the number of frames + 4. Get the cumulative sum of the amplitude array + 5. Normalize the cumulative sum between 0 and 1 + 6. Return the array + + I honestly have no clue what I'm doing here. Suggestions welcome. + """ + y, sr = librosa.load(audio_filepath, offset=offset, duration=duration) + wav_harmonic, wav_percussive = librosa.effects.hpss(y, margin=margin) + + # Apparently n_mels is supposed to be input shape but I don't think it matters here? + frame_duration = int(sr / fps) + wav_norm = get_spec_norm(wav_percussive, sr, n_mels=512, hop_length=frame_duration) + amplitude_arr = np.resize(wav_norm, int(duration * fps)) + T = np.cumsum(amplitude_arr) + T /= T[-1] + T[0] = 0.0 + return T + + +def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): + """helper function to spherically interpolate two arrays v1 v2""" + + if not isinstance(v0, np.ndarray): + inputs_are_torch = True + input_device = v0.device + v0 = v0.cpu().numpy() + v1 = v1.cpu().numpy() + + dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) + if np.abs(dot) > DOT_THRESHOLD: + v2 = (1 - t) * v0 + t * v1 + else: + theta_0 = np.arccos(dot) + sin_theta_0 = np.sin(theta_0) + theta_t = theta_0 * t + sin_theta_t = np.sin(theta_t) + s0 = np.sin(theta_0 - theta_t) / sin_theta_0 + s1 = sin_theta_t / sin_theta_0 + v2 = s0 * v0 + s1 * v1 + + if inputs_are_torch: + v2 = torch.from_numpy(v2).to(input_device) + + return v2 + + +def make_video_pyav( + frames_or_frame_dir: Union[str, Path, torch.Tensor], + audio_filepath: Union[str, Path] = None, + fps: int = 30, + audio_offset: int = 0, + audio_duration: int = 2, + sr: int = 22050, + output_filepath: Union[str, Path] = "output.mp4", + glob_pattern: str = "*.png", +): + """ + TODO - docstring here + + frames_or_frame_dir: (Union[str, Path, torch.Tensor]): + Either a directory of images, or a tensor of shape (T, C, H, W) in range [0, 255]. 
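+        audio_filepath (Union[str, Path], *optional*):
+            Optional audio file to mux into the resulting video.
+        fps (int, *optional*, defaults to 30):
+            Frames per second of the written video.
+        audio_offset (int, *optional*, defaults to 0):
+            Where to start reading the audio, in seconds.
+        audio_duration (int, *optional*, defaults to 2):
+            How many seconds of audio to include.
+        sr (int, *optional*, defaults to 22050):
+            Sample rate to load the audio at.
+        output_filepath (Union[str, Path], *optional*, defaults to "output.mp4"):
+            Where to write the resulting video.
+        glob_pattern (str, *optional*, defaults to "*.png"):
+            Glob pattern used to collect frames when `frames_or_frame_dir` is a directory.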
+ """ + + # Torchvision write_video doesn't support pathlib paths + output_filepath = str(output_filepath) + + if isinstance(frames_or_frame_dir, (str, Path)): + frames = None + for img in sorted(Path(frames_or_frame_dir).glob(glob_pattern)): + frame = pil_to_tensor(Image.open(img)).unsqueeze(0) + frames = frame if frames is None else torch.cat([frames, frame]) + else: + + frames = frames_or_frame_dir + + # TCHW -> THWC + frames = frames.permute(0, 2, 3, 1) + + if audio_filepath: + # Read audio, convert to tensor + audio, sr = librosa.load(audio_filepath, sr=sr, mono=True, offset=audio_offset, duration=audio_duration) + audio_tensor = torch.tensor(audio).unsqueeze(0) + + write_video( + output_filepath, + frames, + fps=fps, + audio_array=audio_tensor, + audio_fps=sr, + audio_codec="aac", + options={"crf": "10", "pix_fmt": "yuv420p"}, + ) + else: + write_video(output_filepath, frames, fps=fps, options={"crf": "10", "pix_fmt": "yuv420p"}) + + return output_filepath + +class StableDiffusionWalkPipeline(DiffusionPipeline): + r""" + Pipeline for generating videos by interpolating Stable Diffusion's latent space. + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ -class StableDiffusionPipeline(DiffusionPipeline): def __init__( self, vae: AutoencoderKL, @@ -26,7 +193,21 @@ def __init__( feature_extractor: CLIPFeatureExtractor, ): super().__init__() - scheduler = scheduler.set_format("pt") + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -36,14 +217,12 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): r""" Enable sliced attention computation. - When this option is enabled, the attention module will split the input tensor in slices, to compute attention in several steps. This is useful to save some memory in exchange for a small speed decrease. - Args: slice_size (`str` or `int`, *optional*, defaults to `"auto"`): When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If @@ -68,29 +247,79 @@ def disable_attention_slicing(self): def __call__( self, prompt: Optional[Union[str, List[str]]] = None, - height: Optional[int] = 512, - width: Optional[int] = 512, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - eta: Optional[float] = 0.0, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + eta: float = 0.0, generator: Optional[torch.Generator] = None, latents: Optional[torch.FloatTensor] = None, - text_embeddings: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", - **kwargs, + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + text_embeddings: Optional[torch.FloatTensor] = None, ): - if "torch_device" in kwargs: - device = kwargs.pop("torch_device") - warnings.warn( - "`torch_device` is deprecated as an input argument to `__call__` and" - " will be removed in v0.3.0. Consider using `pipe.to(torch_device)`" - " instead." - ) + r""" + Function invoked when calling the pipeline for generation. + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. 
+ latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + text_embeddings(`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ - # Set device as before (to be removed in 0.3.0) - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - self.to(device) + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) if text_embeddings is None: if isinstance(prompt, str): @@ -98,25 +327,25 @@ def __call__( elif isinstance(prompt, list): batch_size = len(prompt) else: - raise ValueError( - f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" - ) - - if height % 8 != 0 or width % 8 != 0: - raise ValueError( - "`height` and `width` have to be divisible by 8 but are" - f" {height} and {width}." 
- ) + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") # get prompt text embeddings - text_input = self.tokenizer( + text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, - truncation=True, return_tensors="pt", ) - text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + text_input_ids = text_inputs.input_ids + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] else: batch_size = text_embeddings.shape[0] @@ -126,17 +355,14 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: - # max_length = text_input.input_ids.shape[-1] - max_length = 77 # self.tokenizer.model_max_length + # HACK - Not setting text_input_ids here when walking, so hard coding to max length of tokenizer + # TODO - Determine if this is OK to do + # max_length = text_input_ids.shape[-1] + max_length = self.tokenizer.model_max_length uncond_input = self.tokenizer( - [""] * batch_size, - padding="max_length", - max_length=max_length, - return_tensors="pt", + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" ) - uncond_embeddings = self.text_encoder( - uncond_input.input_ids.to(self.device) - )[0] + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch @@ -144,96 +370,401 @@ def __call__( text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. 
+ latents_device = "cpu" if self.device.type == "mps" else self.device latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) if latents is None: latents = torch.randn( latents_shape, generator=generator, - device=self.device, + device=latents_device, + dtype=text_embeddings.dtype, ) else: if latents.shape != latents_shape: - raise ValueError( - f"Unexpected latents shape, got {latents.shape}, expected" - f" {latents_shape}" - ) - latents = latents.to(self.device) + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(latents_device) # set timesteps - accepts_offset = "offset" in set( - inspect.signature(self.scheduler.set_timesteps).parameters.keys() - ) - extra_set_kwargs = {} - if accepts_offset: - extra_set_kwargs["offset"] = 1 + self.scheduler.set_timesteps(num_inference_steps) - self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + # Some schedulers like PNDM have timesteps as arrays + # It's more optimized to move all timesteps to correct device beforehand + timesteps_tensor = self.scheduler.timesteps.to(self.device) - # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = latents * self.scheduler.sigmas[0] + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + for i, t in enumerate(self.progress_bar(timesteps_tensor)): # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * 2) if do_classifier_free_guidance else latents - ) - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[i] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet( - latent_model_input, t, encoder_hidden_states=text_embeddings - )["sample"] + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step( - noise_pred, i, latents, **extra_step_kwargs - )["prev_sample"] - else: - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs - )["prev_sample"] + latents = 
self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) - # scale and decode the image latents with vae latents = 1 / 0.18215 * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() - safety_cheker_input = self.feature_extractor( - self.numpy_to_pil(image), return_tensors="pt" - ).to(self.device) + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_cheker_input.pixel_values + images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) ) if output_type == "pil": image = self.numpy_to_pil(image) - return {"sample": image, "nsfw_content_detected": has_nsfw_concept} + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + def generate_inputs(self, prompt_a, prompt_b, seed_a, seed_b, noise_shape, T, batch_size): + embeds_a = self.embed_text(prompt_a) + embeds_b = self.embed_text(prompt_b) + latents_a = torch.randn( + noise_shape, + device=self.device, + generator=torch.Generator(device=self.device).manual_seed(seed_a), + ) + latents_b = torch.randn( + noise_shape, + device=self.device, + generator=torch.Generator(device=self.device).manual_seed(seed_b), + ) + + batch_idx = 0 + embeds_batch, noise_batch = None, None + for i, t in enumerate(T): + embeds = torch.lerp(embeds_a, embeds_b, t) + noise = slerp(float(t), latents_a, latents_b) + + embeds_batch = embeds if embeds_batch is None else torch.cat([embeds_batch, embeds]) + noise_batch = noise if noise_batch is None else torch.cat([noise_batch, noise]) + batch_is_ready = embeds_batch.shape[0] == batch_size or i + 1 == T.shape[0] + if not batch_is_ready: + continue + yield batch_idx, embeds_batch, noise_batch + batch_idx += 1 + del embeds_batch, noise_batch + torch.cuda.empty_cache() + embeds_batch, noise_batch = None, None + + def generate_interpolation_clip( + self, + prompt_a: str, + prompt_b: str, + seed_a: int, + seed_b: int, + num_interpolation_steps: int = 5, + save_path: Union[str, Path] = "outputs/", + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + eta: float = 0.0, + height: int = 512, + width: int = 512, + upsample: bool = False, + batch_size: int = 1, + image_file_ext: str = ".png", + T: np.ndarray = None, + skip: int = 0, + ): + save_path = Path(save_path) + save_path.mkdir(parents=True, exist_ok=True) + + T = T if T is not None else np.linspace(0.0, 1.0, num_interpolation_steps) + if T.shape[0] != num_interpolation_steps: + raise ValueError(f"Unexpected T shape, got {T.shape}, expected dim 0 to be {num_interpolation_steps}") + + if upsample: + if getattr(self, "upsampler", None) is None: + self.upsampler = RealESRGANModel.from_pretrained("nateraw/real-esrgan") + self.upsampler.to(self.device) + + batch_generator = self.generate_inputs( + prompt_a, + prompt_b, + seed_a, + seed_b, + (1, self.unet.in_channels, height // 8, width // 8), + T[skip:], + batch_size, + ) + + frame_index = skip + for _, embeds_batch, noise_batch in batch_generator: + with torch.autocast("cuda"): + outputs = self( + latents=noise_batch, + text_embeddings=embeds_batch, + height=height, + width=width, + guidance_scale=guidance_scale, + eta=eta, + 
num_inference_steps=num_inference_steps, + output_type="pil" if not upsample else "numpy", + )["sample"] + + for image in outputs: + frame_filepath = save_path / (f"frame%06d{image_file_ext}" % frame_index) + image = image if not upsample else self.upsampler(image) + image.save(frame_filepath) + frame_index += 1 + + def walk( + self, + prompts: Optional[List[str]] = None, + seeds: Optional[List[int]] = None, + num_interpolation_steps: Optional[Union[int, List[int]]] = 5, # int or list of int + output_dir: Optional[str] = "./dreams", + name: Optional[str] = None, + image_file_ext: Optional[str] = ".png", + fps: Optional[int] = 30, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + eta: Optional[float] = 0.0, + height: Optional[int] = 512, + width: Optional[int] = 512, + upsample: Optional[bool] = False, + batch_size: Optional[int] = 1, + resume: Optional[bool] = False, + audio_filepath: str = None, + audio_start_sec: Optional[Union[int, float]] = None, + ): + """Generate a video from a sequence of prompts and seeds. Optionally, add audio to the + video to interpolate to the intensity of the audio. + + Args: + prompts (Optional[List[str]], optional): + list of text prompts. Defaults to None. + seeds (Optional[List[int]], optional): + list of random seeds corresponding to prompts. Defaults to None. + num_interpolation_steps (Union[int, List[int]], *optional*): + How many interpolation steps between each prompt. Defaults to None. + output_dir (Optional[str], optional): + Where to save the video. Defaults to './dreams'. + name (Optional[str], optional): + Name of the subdirectory of output_dir. Defaults to None. + image_file_ext (Optional[str], *optional*, defaults to '.png'): + The extension to use when writing video frames. + fps (Optional[int], *optional*, defaults to 30): + The frames per second in the resulting output videos. + num_inference_steps (Optional[int], *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (Optional[float], *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + eta (Optional[float], *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + height (Optional[int], *optional*, defaults to 512): + height of the images to generate. + width (Optional[int], *optional*, defaults to 512): + width of the images to generate. + upsample (Optional[bool], *optional*, defaults to False): + When True, upsamples images with realesrgan. + batch_size (Optional[int], *optional*, defaults to 1): + Number of images to generate at once. + resume (Optional[bool], *optional*, defaults to False): + When True, resumes from the last frame in the output directory based + on available prompt config. Requires you to provide the `name` argument. + audio_filepath (str, *optional*, defaults to None): + Optional path to an audio file to influence the interpolation rate. 
+ audio_start_sec (Optional[Union[int, float]], *optional*, defaults to 0): + Global start time of the provided audio_filepath. + + This function will create sub directories for each prompt and seed pair. + + For example, if you provide the following prompts and seeds: + + ``` + prompts = ['a', 'b', 'c'] + seeds = [1, 2, 3] + num_interpolation_steps = 5 + output_dir = 'output_dir' + name = 'name' + fps = 5 + ``` + + Then the following directories will be created: + + ``` + output_dir + ├── name + │ ├── name_000000 + │ │ ├── frame000000.png + │ │ ├── ... + │ │ ├── frame000004.png + │ │ ├── name_000000.mp4 + │ ├── name_000001 + │ │ ├── frame000000.png + │ │ ├── ... + │ │ ├── frame000004.png + │ │ ├── name_000001.mp4 + │ ├── ... + │ ├── name.mp4 + | |── prompt_config.json + ``` + + Returns: + str: The resulting video filepath. This video includes all sub directories' video clips. + """ + + output_path = Path(output_dir) + + name = name or time.strftime("%Y%m%d-%H%M%S") + save_path_root = output_path / name + save_path_root.mkdir(parents=True, exist_ok=True) + + # Where the final video of all the clips combined will be saved + output_filepath = save_path_root / f"{name}.mp4" + + # If using same number of interpolation steps between, we turn into list + if not resume and isinstance(num_interpolation_steps, int): + num_interpolation_steps = [num_interpolation_steps] * (len(prompts) - 1) + + if not resume and audio_filepath: + audio_start_sec = audio_start_sec or 0 + + # Save/reload prompt config + prompt_config_path = save_path_root / "prompt_config.json" + if not resume: + prompt_config_path.write_text( + json.dumps( + dict( + prompts=prompts, + seeds=seeds, + num_interpolation_steps=num_interpolation_steps, + fps=fps, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + eta=eta, + upsample=upsample, + height=height, + width=width, + audio_filepath=audio_filepath, + audio_start_sec=audio_start_sec, + ), + indent=2, + sort_keys=False, + ) + ) + else: + data = json.load(open(prompt_config_path)) + prompts = data["prompts"] + seeds = data["seeds"] + num_interpolation_steps = data["num_interpolation_steps"] + fps = data["fps"] + num_inference_steps = data["num_inference_steps"] + guidance_scale = data["guidance_scale"] + eta = data["eta"] + upsample = data["upsample"] + height = data["height"] + width = data["width"] + audio_filepath = data["audio_filepath"] + audio_start_sec = data["audio_start_sec"] + + for i, (prompt_a, prompt_b, seed_a, seed_b, num_step) in enumerate( + zip(prompts, prompts[1:], seeds, seeds[1:], num_interpolation_steps) + ): + # {name}_000000 / {name}_000001 / ... 
+ save_path = save_path_root / f"{name}_{i:06d}" + + # Where the individual clips will be saved + step_output_filepath = save_path / f"{name}_{i:06d}.mp4" + + # Determine if we need to resume from a previous run + skip = 0 + if resume: + if step_output_filepath.exists(): + print(f"Skipping {save_path} because frames already exist") + continue + + existing_frames = sorted(save_path.glob(f"*{image_file_ext}")) + if existing_frames: + skip = int(existing_frames[-1].stem[-6:]) + 1 + if skip + 1 >= num_step: + print(f"Skipping {save_path} because frames already exist") + continue + print(f"Resuming {save_path.name} from frame {skip}") + + self.generate_interpolation_clip( + prompt_a, + prompt_b, + seed_a, + seed_b, + num_interpolation_steps=num_step, + save_path=save_path, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + eta=eta, + height=height, + width=width, + upsample=upsample, + batch_size=batch_size, + skip=skip, + T=get_timesteps_arr( + audio_filepath, + offset=audio_start_sec + (i * num_step / fps), + duration=num_step / fps, + fps=fps, + margin=(1.0, 5.0), + ) + if audio_filepath + else None, + ) + make_video_pyav( + save_path, + audio_filepath=audio_filepath, + fps=fps, + output_filepath=step_output_filepath, + glob_pattern=f"*{image_file_ext}", + audio_offset=audio_start_sec + (i * num_step / fps) if audio_start_sec else 0, + audio_duration=num_step / fps, + sr=44100, + ) + + return make_video_pyav( + save_path_root, + audio_filepath=audio_filepath, + fps=fps, + audio_offset=audio_start_sec, + audio_duration=sum(num_interpolation_steps) / fps, + output_filepath=output_filepath, + glob_pattern=f"**/*{image_file_ext}", + sr=44100, + ) def embed_text(self, text): """Helper to embed some text""" @@ -249,12 +780,42 @@ def embed_text(self, text): embed = self.text_encoder(text_input.input_ids.to(self.device))[0] return embed + @classmethod + def from_pretrained(cls, *args, tiled=False, **kwargs): + """Same as diffusers `from_pretrained` but with tiled option, which makes images tilable""" + if tiled: + + def patch_conv(**patch): + cls = nn.Conv2d + init = cls.__init__ + + def __init__(self, *args, **kwargs): + return init(self, *args, **kwargs, **patch) + + cls.__init__ = __init__ + + patch_conv(padding_mode="circular") + + return super().from_pretrained(*args, **kwargs) + class NoCheck(ModelMixin): """Can be used in place of safety checker. Use responsibly and at your own risk.""" + + def __init__(self): + super().__init__() + self.register_parameter(name="asdf", param=torch.nn.Parameter(torch.randn(3))) + + def forward(self, images=None, **kwargs): + return images, [False] + + +class NoCheck(ModelMixin): + """Can be used in place of safety checker. 
Use responsibly and at your own risk.""" + def __init__(self): super().__init__() - self.register_parameter(name='asdf', param=torch.nn.Parameter(torch.randn(3))) + self.register_parameter(name="asdf", param=torch.nn.Parameter(torch.randn(3))) def forward(self, images=None, **kwargs): return images, [False] diff --git a/stable_diffusion_videos/stable_diffusion_walk.py b/stable_diffusion_videos/stable_diffusion_walk.py deleted file mode 100644 index 6e0d0b9..0000000 --- a/stable_diffusion_videos/stable_diffusion_walk.py +++ /dev/null @@ -1,317 +0,0 @@ -import json -import subprocess -from pathlib import Path -from typing import List, Optional, Union -from warnings import warn -import numpy as np -import torch -from diffusers.schedulers import (DDIMScheduler, LMSDiscreteScheduler, - PNDMScheduler) -from diffusers import ModelMixin - -from .stable_diffusion_pipeline import StableDiffusionPipeline - -pipeline = StableDiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", - use_auth_token=True, - torch_dtype=torch.float16, - revision="fp16", -).to("cuda") - -default_scheduler = PNDMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" -) -ddim_scheduler = DDIMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - clip_sample=False, - set_alpha_to_one=False, -) -klms_scheduler = LMSDiscreteScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" -) -SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler) - - -def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): - """helper function to spherically interpolate two arrays v1 v2""" - - if not isinstance(v0, np.ndarray): - inputs_are_torch = True - input_device = v0.device - v0 = v0.cpu().numpy() - v1 = v1.cpu().numpy() - - dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) - if np.abs(dot) > DOT_THRESHOLD: - v2 = (1 - t) * v0 + t * v1 - else: - theta_0 = np.arccos(dot) - sin_theta_0 = np.sin(theta_0) - theta_t = theta_0 * t - sin_theta_t = np.sin(theta_t) - s0 = np.sin(theta_0 - theta_t) / sin_theta_0 - s1 = sin_theta_t / sin_theta_0 - v2 = s0 * v0 + s1 * v1 - - if inputs_are_torch: - v2 = torch.from_numpy(v2).to(input_device) - - return v2 - - -def make_video_ffmpeg(frame_dir, output_file_name='output.mp4', frame_filename="frame%06d.png", fps=30): - frame_ref_path = str(frame_dir / frame_filename) - video_path = str(frame_dir / output_file_name) - subprocess.call( - f"ffmpeg -r {fps} -i {frame_ref_path} -vcodec libx264 -crf 10 -pix_fmt yuv420p" - f" {video_path}".split() - ) - return video_path - - -def walk( - prompts: List[str] = ["blueberry spaghetti", "strawberry spaghetti"], - seeds: List[int] = [42, 123], - num_interpolation_steps: Union[int, List[int]] = 5, - output_dir: str = "dreams", - name: str = "berry_good_spaghetti", - height: int = 512, - width: int = 512, - guidance_scale: float = 7.5, - eta: float = 0.0, - num_inference_steps: int = 50, - do_loop: bool = False, - make_video: bool = False, - use_lerp_for_text: bool = False, - scheduler: str = "klms", # choices: default, ddim, klms - disable_tqdm: bool = False, - upsample: bool = False, - fps: int = 30, - less_vram: bool = False, - resume: bool = False, - batch_size: int = 1, - frame_filename_ext: str = '.png', - num_steps: Optional[int] = None -): - """Generate video frames/a video given a list of prompts and seeds. - - Args: - prompts (List[str], optional): List of . Defaults to ["blueberry spaghetti", "strawberry spaghetti"]. 
- seeds (List[int], optional): List of random seeds corresponding to given prompts. - num_interpolation_steps (Union[int, List[int]], optional): Number of steps to walk during each interpolation step. If int is provided, use same number of steps between each prompt. If a list is provided, the size of `num_interpolation_steps` should be `len(prompts) - 1`. Increase values to 60-200 for good results. Defaults to 5. - output_dir (str, optional): Root dir where images will be saved. Defaults to "dreams". - name (str, optional): Sub directory of output_dir to save this run's files. Defaults to "berry_good_spaghetti". - height (int, optional): Height of image to generate. Defaults to 512. - width (int, optional): Width of image to generate. Defaults to 512. - guidance_scale (float, optional): Higher = more adherance to prompt. Lower = let model take the wheel. Defaults to 7.5. - eta (float, optional): ETA. Defaults to 0.0. - num_inference_steps (int, optional): Number of diffusion steps. Defaults to 50. - do_loop (bool, optional): Whether to loop from last prompt back to first. Defaults to False. - make_video (bool, optional): Whether to make a video or just save the images. Defaults to False. - use_lerp_for_text (bool, optional): Use LERP instead of SLERP for text embeddings when walking. Defaults to True. - scheduler (str, optional): Which scheduler to use. Defaults to "klms". Choices are "default", "ddim", "klms". - disable_tqdm (bool, optional): Whether to turn off the tqdm progress bars. Defaults to False. - upsample (bool, optional): If True, uses Real-ESRGAN to upsample images 4x. Requires it to be installed - which you can do by running: `pip install git+https://github.com/xinntao/Real-ESRGAN.git`. Defaults to False. - fps (int, optional): The frames per second (fps) that you want the video to use. Does nothing if make_video is False. Defaults to 30. - less_vram (bool, optional): Allow higher resolution output on smaller GPUs. Yields same result at the expense of 10% speed. Defaults to False. - resume (bool, optional): When set to True, resume from provided '/' path. Useful if your run was terminated - part of the way through. - batch_size (int, optional): Number of examples per batch fed to pipeline. Increase this until you - run out of VRAM. Defaults to 1. - frame_filename_ext (str, optional): File extension to use when saving/resuming. Update this to - ".jpg" to save or resume generating jpg images instead. Defaults to ".png". - num_steps(int, optional): **Deprecated** Number of interpolation steps. Please use `num_interpolation_steps` instead. - - Returns: - str: Path to video file saved if make_video=True, else None. - """ - - if num_steps: - warn( - ( - "The `num_steps` kwarg of the `stable_diffusion_videos.walk` fn is deprecated in 0.4.0 and will be removed in 0.5.0. " - "Please use `num_interpolation_steps` instead. Setting provided num_interpolation_steps to provided num_steps for now." 
- ), - DeprecationWarning, - stacklevel=2 - ) - num_interpolation_steps = num_steps - - if upsample: - from .upsampling import PipelineRealESRGAN - - upsampling_pipeline = PipelineRealESRGAN.from_pretrained('nateraw/real-esrgan') - - if less_vram: - pipeline.enable_attention_slicing() - - output_path = Path(output_dir) / name - output_path.mkdir(exist_ok=True, parents=True) - prompt_config_path = output_path / 'prompt_config.json' - - if not resume: - # Write prompt info to file in output dir so we can keep track of what we did - prompt_config_path.write_text( - json.dumps( - dict( - prompts=prompts, - seeds=seeds, - num_interpolation_steps=num_interpolation_steps, - name=name, - guidance_scale=guidance_scale, - eta=eta, - num_inference_steps=num_inference_steps, - do_loop=do_loop, - make_video=make_video, - use_lerp_for_text=use_lerp_for_text, - scheduler=scheduler, - upsample=upsample, - fps=fps, - height=height, - width=width, - ), - indent=2, - sort_keys=False, - ) - ) - else: - # When resuming, we load all available info from existing prompt config, using kwargs passed in where necessary - if not prompt_config_path.exists(): - raise FileNotFoundError(f"You specified resume=True, but no prompt config file was found at {prompt_config_path}") - - data = json.load(open(prompt_config_path)) - prompts = data['prompts'] - seeds = data['seeds'] - # NOTE - num_steps was renamed to num_interpolation_steps. Including it here for backwards compatibility. - num_interpolation_steps = data.get('num_interpolation_steps') or data.get('num_steps') - height = data['height'] if 'height' in data else height - width = data['width'] if 'width' in data else width - guidance_scale = data['guidance_scale'] - eta = data['eta'] - num_inference_steps = data['num_inference_steps'] - do_loop = data['do_loop'] - make_video = data['make_video'] - use_lerp_for_text = data['use_lerp_for_text'] - scheduler = data['scheduler'] - disable_tqdm=disable_tqdm - upsample = data['upsample'] if 'upsample' in data else upsample - fps = data['fps'] if 'fps' in data else fps - - resume_step = int(sorted(output_path.glob(f"frame*{frame_filename_ext}"))[-1].stem[5:]) - print(f"\nResuming {output_path} from step {resume_step}...") - - - if upsample: - from .upsampling import PipelineRealESRGAN - - upsampling_pipeline = PipelineRealESRGAN.from_pretrained('nateraw/real-esrgan') - - pipeline.set_progress_bar_config(disable=disable_tqdm) - pipeline.scheduler = SCHEDULERS[scheduler] - - if isinstance(num_interpolation_steps, int): - num_interpolation_steps = [num_interpolation_steps] * (len(prompts)-1) - - assert len(prompts) == len(seeds) == len(num_interpolation_steps) +1 - - first_prompt, *prompts = prompts - embeds_a = pipeline.embed_text(first_prompt) - - first_seed, *seeds = seeds - - latents_a = torch.randn( - (1, pipeline.unet.in_channels, height // 8, width // 8), - device=pipeline.device, - generator=torch.Generator(device=pipeline.device).manual_seed(first_seed), - ) - - if do_loop: - prompts.append(first_prompt) - seeds.append(first_seed) - num_interpolation_steps.append(num_interpolation_steps[0]) - - - frame_index = 0 - total_frame_count = sum(num_interpolation_steps) - for prompt, seed, num_step in zip(prompts, seeds, num_interpolation_steps): - # Text - embeds_b = pipeline.embed_text(prompt) - - # Latent Noise - latents_b = torch.randn( - (1, pipeline.unet.in_channels, height // 8, width // 8), - device=pipeline.device, - generator=torch.Generator(device=pipeline.device).manual_seed(seed), - ) - - latents_batch, 
embeds_batch = None, None - for i, t in enumerate(np.linspace(0, 1, num_step)): - - frame_filepath = output_path / (f"frame%06d{frame_filename_ext}" % frame_index) - if resume and frame_filepath.is_file(): - frame_index += 1 - continue - - if use_lerp_for_text: - embeds = torch.lerp(embeds_a, embeds_b, float(t)) - else: - embeds = slerp(float(t), embeds_a, embeds_b) - latents = slerp(float(t), latents_a, latents_b) - - embeds_batch = embeds if embeds_batch is None else torch.cat([embeds_batch, embeds]) - latents_batch = latents if latents_batch is None else torch.cat([latents_batch, latents]) - - del embeds - del latents - torch.cuda.empty_cache() - - batch_is_ready = embeds_batch.shape[0] == batch_size or t == 1.0 - if not batch_is_ready: - continue - - do_print_progress = (i == 0) or ((frame_index) % 20 == 0) - if do_print_progress: - print(f"COUNT: {frame_index}/{total_frame_count}") - - with torch.autocast("cuda"): - outputs = pipeline( - latents=latents_batch, - text_embeddings=embeds_batch, - height=height, - width=width, - guidance_scale=guidance_scale, - eta=eta, - num_inference_steps=num_inference_steps, - output_type='pil' if not upsample else 'numpy' - )["sample"] - - del embeds_batch - del latents_batch - torch.cuda.empty_cache() - latents_batch, embeds_batch = None, None - - if upsample: - images = [] - for output in outputs: - images.append(upsampling_pipeline(output)) - else: - images = outputs - for image in images: - frame_filepath = output_path / (f"frame%06d{frame_filename_ext}" % frame_index) - image.save(frame_filepath) - frame_index += 1 - - embeds_a = embeds_b - latents_a = latents_b - - if make_video: - return make_video_ffmpeg(output_path, f"{name}.mp4", fps=fps, frame_filename=f"frame%06d{frame_filename_ext}") - - -if __name__ == "__main__": - import fire - - fire.Fire(walk) diff --git a/stable_diffusion_videos/upsampling.py b/stable_diffusion_videos/upsampling.py index c7f7651..73c1e45 100644 --- a/stable_diffusion_videos/upsampling.py +++ b/stable_diffusion_videos/upsampling.py @@ -3,6 +3,7 @@ import cv2 from PIL import Image from huggingface_hub import hf_hub_download +from torch import nn try: from realesrgan import RealESRGANer @@ -13,20 +14,25 @@ "pip install realesrgan" ) -class PipelineRealESRGAN: + +class RealESRGANModel(nn.Module): def __init__(self, model_path, tile=0, tile_pad=10, pre_pad=0, fp32=False): + super().__init__() + try: + from basicsr.archs.rrdbnet_arch import RRDBNet + from realesrgan import RealESRGANer + except ImportError as e: + raise ImportError( + "You tried to import realesrgan without having it installed properly. To install Real-ESRGAN, run:\n\n" + "pip install realesrgan" + ) + model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) self.upsampler = RealESRGANer( - scale=4, - model_path=model_path, - model=model, - tile=tile, - tile_pad=tile_pad, - pre_pad=pre_pad, - half=not fp32 + scale=4, model_path=model_path, model=model, tile=tile, tile_pad=tile_pad, pre_pad=pre_pad, half=not fp32 ) - def __call__(self, image, outscale=4, convert_to_pil=True): + def forward(self, image, outscale=4, convert_to_pil=True): """Upsample an image array or path. Args: @@ -53,7 +59,7 @@ def __call__(self, image, outscale=4, convert_to_pil=True): return image @classmethod - def from_pretrained(cls, model_name_or_path='nateraw/real-esrgan'): + def from_pretrained(cls, model_name_or_path="nateraw/real-esrgan"): """Initialize a pretrained Real-ESRGAN upsampler. 
Example: @@ -74,18 +80,17 @@ def from_pretrained(cls, model_name_or_path='nateraw/real-esrgan'): if Path(model_name_or_path).exists(): file = model_name_or_path else: - file = hf_hub_download(model_name_or_path, 'RealESRGAN_x4plus.pth') + file = hf_hub_download(model_name_or_path, "RealESRGAN_x4plus.pth") return cls(file) - - def upsample_imagefolder(self, in_dir, out_dir, suffix='out', outfile_ext='.png'): + def upsample_imagefolder(self, in_dir, out_dir, suffix="out", outfile_ext=".png"): in_dir, out_dir = Path(in_dir), Path(out_dir) if not in_dir.exists(): raise FileNotFoundError(f"Provided input directory {in_dir} does not exist") out_dir.mkdir(exist_ok=True, parents=True) - - image_paths = [x for x in in_dir.glob('*') if x.suffix.lower() in ['.png', '.jpg', '.jpeg']] + + image_paths = [x for x in in_dir.glob("*") if x.suffix.lower() in [".png", ".jpg", ".jpeg"]] for image in image_paths: im = self(str(image)) out_filepath = out_dir / (image.stem + suffix + outfile_ext)
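For reference, a minimal sketch of the "do upsampling separately" workflow that the README changes above allude to, using only the `RealESRGANModel` methods defined in `upsampling.py` (`from_pretrained` and `upsample_imagefolder`). The frame paths are hypothetical examples following the `output_dir/name/name_000000` layout documented in `walk`.

```python
from stable_diffusion_videos.upsampling import RealESRGANModel

# Load the pretrained Real-ESRGAN x4 upsampler from the Hub.
# Requires the optional dependency: `pip install realesrgan`.
upsampler = RealESRGANModel.from_pretrained("nateraw/real-esrgan")

# Hypothetical paths: `pipeline.walk(..., output_dir='dreams', name='animals_test')`
# writes frames to subfolders like dreams/animals_test/animals_test_000000.
frames_dir = "dreams/animals_test/animals_test_000000"
upsampled_dir = "dreams/animals_test/animals_test_000000_upsampled"

# Upsample every .png/.jpg frame in the folder and write the results
# (suffixed with "out" by default) to the output folder, which is created if missing.
upsampler.upsample_imagefolder(frames_dir, upsampled_dir)
```

This mirrors what `pipeline.walk(..., upsample=True)` does per generated frame, but lets the upsampling run as a separate pass when VRAM is tight.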