From c8d79f6ad5626e3dadc5a928707211b4fad90861 Mon Sep 17 00:00:00 2001 From: Nathan Raw Date: Fri, 6 Jan 2023 13:58:52 -0500 Subject: [PATCH 1/3] Update stable_diffusion_pipeline.py --- .../stable_diffusion_pipeline.py | 63 +++++++++---------- 1 file changed, 31 insertions(+), 32 deletions(-) diff --git a/stable_diffusion_videos/stable_diffusion_pipeline.py b/stable_diffusion_videos/stable_diffusion_pipeline.py index 9d0e32f..a67ab61 100644 --- a/stable_diffusion_videos/stable_diffusion_pipeline.py +++ b/stable_diffusion_videos/stable_diffusion_pipeline.py @@ -549,9 +549,9 @@ def __call__( def generate_inputs(self, prompt_a, prompt_b, seed_a, seed_b, noise_shape, T, batch_size): embeds_a = self.embed_text(prompt_a) embeds_b = self.embed_text(prompt_b) - - latents_a = self.init_noise(seed_a, noise_shape) - latents_b = self.init_noise(seed_b, noise_shape) + latents_dtype = embeds_a.dtype + latents_a = self.init_noise(seed_a, noise_shape, latents_dtype) + latents_b = self.init_noise(seed_b, noise_shape, latents_dtype) batch_idx = 0 embeds_batch, noise_batch = None, None @@ -614,24 +614,23 @@ def make_clip_frames( frame_index = skip for _, embeds_batch, noise_batch in batch_generator: - with torch.autocast("cuda"): - outputs = self( - latents=noise_batch, - text_embeddings=embeds_batch, - height=height, - width=width, - guidance_scale=guidance_scale, - eta=eta, - num_inference_steps=num_inference_steps, - output_type="pil" if not upsample else "numpy", - negative_prompt=negative_prompt, - )["images"] - - for image in outputs: - frame_filepath = save_path / (f"frame%06d{image_file_ext}" % frame_index) - image = image if not upsample else self.upsampler(image) - image.save(frame_filepath) - frame_index += 1 + outputs = self( + latents=noise_batch, + text_embeddings=embeds_batch, + height=height, + width=width, + guidance_scale=guidance_scale, + eta=eta, + num_inference_steps=num_inference_steps, + output_type="pil" if not upsample else "numpy", + negative_prompt=negative_prompt, + )["images"] + + for image in outputs: + frame_filepath = save_path / (f"frame%06d{image_file_ext}" % frame_index) + image = image if not upsample else self.upsampler(image) + image.save(frame_filepath) + frame_index += 1 def walk( self, @@ -879,19 +878,18 @@ def walk( def embed_text(self, text, negative_prompt=None): """Helper to embed some text""" - with torch.autocast("cuda"): - text_input = self.tokenizer( - text, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - with torch.no_grad(): - embed = self.text_encoder(text_input.input_ids.to(self.device))[0] + text_input = self.tokenizer( + text, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + with torch.no_grad(): + embed = self.text_encoder(text_input.input_ids.to(self.device))[0] return embed - def init_noise(self, seed, noise_shape): + def init_noise(self, seed, noise_shape, dtype): """Helper to initialize noise""" # randn does not exist on mps, so we create noise on CPU here and move it to the device after initialization if self.device.type == "mps": @@ -905,6 +903,7 @@ def init_noise(self, seed, noise_shape): noise_shape, device=self.device, generator=torch.Generator(device=self.device).manual_seed(seed), + dtype=dtype, ) return noise From 3272b16796bd91fcdcad525eb3ff896aa16a0c3c Mon Sep 17 00:00:00 2001 From: Nathan Raw Date: Fri, 6 Jan 2023 14:19:28 -0500 Subject: [PATCH 2/3] update height and width to be optional --- .../stable_diffusion_pipeline.py | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/stable_diffusion_videos/stable_diffusion_pipeline.py b/stable_diffusion_videos/stable_diffusion_pipeline.py index a67ab61..ac7a60e 100644 --- a/stable_diffusion_videos/stable_diffusion_pipeline.py +++ b/stable_diffusion_videos/stable_diffusion_pipeline.py @@ -299,8 +299,8 @@ def disable_attention_slicing(self): def __call__( self, prompt: Optional[Union[str, List[str]]] = None, - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -371,6 +371,9 @@ def __call__( list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -581,8 +584,8 @@ def make_clip_frames( num_inference_steps: int = 50, guidance_scale: float = 7.5, eta: float = 0.0, - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, upsample: bool = False, batch_size: int = 1, image_file_ext: str = ".png", @@ -590,6 +593,10 @@ def make_clip_frames( skip: int = 0, negative_prompt: str = None, ): + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + save_path = Path(save_path) save_path.mkdir(parents=True, exist_ok=True) @@ -644,8 +651,8 @@ def walk( num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, eta: Optional[float] = 0.0, - height: Optional[int] = 512, - width: Optional[int] = 512, + height: Optional[int] = None, + width: Optional[int] = None, upsample: Optional[bool] = False, batch_size: Optional[int] = 1, resume: Optional[bool] = False, @@ -685,9 +692,9 @@ def walk( eta (Optional[float], *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. - height (Optional[int], *optional*, defaults to 512): + height (Optional[int], *optional*, defaults to None): height of the images to generate. - width (Optional[int], *optional*, defaults to 512): + width (Optional[int], *optional*, defaults to None): width of the images to generate. upsample (Optional[bool], *optional*, defaults to False): When True, upsamples images with realesrgan. @@ -743,6 +750,9 @@ def walk( Returns: str: The resulting video filepath. This video includes all sub directories' video clips. """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor output_path = Path(output_dir) From c2f0423c978413c351f6e87d94e42bf4a8e4e8b4 Mon Sep 17 00:00:00 2001 From: Nathan Raw Date: Fri, 6 Jan 2023 14:21:24 -0500 Subject: [PATCH 3/3] Update requirements.txt --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index b884399..6bfe69b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ transformers>=4.21.0 -diffusers==0.9.0 +diffusers==0.11.1 scipy fire gradio librosa av<10.0.0 -realesrgan==0.2.5.0 \ No newline at end of file +realesrgan==0.2.5.0