Merge pull request #17 from nateraw/update-0.3.0
Add prompt config, various utils and update to diffusers==0.3.0
nateraw authored Sep 9, 2022
2 parents e091032 + d04de81 commit 00d9f72
Showing 4 changed files with 70 additions and 19 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
 transformers
-diffusers==0.2.4
+diffusers==0.3.0
 scipy
 fire
 gradio
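Note: diffusers 0.3.0 introduced output classes on models and pipelines, which is what the `.sample` change in stable_diffusion_pipeline.py below accommodates. As a hedged aside (not part of this commit), a caller can guard against an incompatible install at runtime:

    import diffusers

    # Fail fast if the installed version does not match the new pin.
    assert diffusers.__version__ == "0.3.0", diffusers.__version__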
2 changes: 2 additions & 0 deletions stable_diffusion_videos/__init__.py
@@ -107,10 +107,12 @@ def __dir__():
         ],
         "stable_diffusion_pipeline": [
             "StableDiffusionPipeline",
+            "NoCheck",
         ],
         "stable_diffusion_walk": [
             "walk",
             "SCHEDULERS",
+            "pipeline",
         ]
     },
 )
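With these two entries added, both symbols resolve lazily from the package root. A minimal sketch of the new import surface (note that `pipeline` is created at module import time in stable_diffusion_walk.py below, so importing it downloads and loads the model; assumes an authenticated Hugging Face token):

    # New in this commit: NoCheck and the module-level pipeline are exported.
    from stable_diffusion_videos import NoCheck, pipeline, walk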
20 changes: 9 additions & 11 deletions stable_diffusion_videos/stable_diffusion_pipeline.py
@@ -4,6 +4,7 @@
 from typing import List, Optional, Union
 
 import torch
+from diffusers import ModelMixin
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion.safety_checker import \
@@ -190,7 +191,7 @@ def __call__(

         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents)
+        image = self.vae.decode(latents).sample
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
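This hunk tracks the diffusers 0.3.0 API: `vae.decode` now returns an output object rather than the raw tensor. A minimal sketch of the difference, assuming a loaded `AutoencoderKL` as `vae` and a `latents` tensor:

    decoded = vae.decode(latents)
    # diffusers 0.2.x: `decoded` was the image tensor itself.
    # diffusers 0.3.0: `decoded` is an output object; the tensor lives on `.sample`.
    image = decoded.sample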
@@ -221,15 +222,12 @@ def embed_text(self, text):
         embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
         return embed
 
-    def progress_bar(self, iterable):
-        if not hasattr(self, "_progress_bar_config"):
-            self._progress_bar_config = {}
-        elif not isinstance(self._progress_bar_config, dict):
-            raise ValueError(
-                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
-            )
 
-        return tqdm(iterable, **self._progress_bar_config)
+class NoCheck(ModelMixin):
+    """Can be used in place of safety checker. Use responsibly and at your own risk."""
+    def __init__(self):
+        super().__init__()
+        self.register_parameter(name='asdf', param=torch.nn.Parameter(torch.randn(3)))
 
-    def set_progress_bar_config(self, **kwargs):
-        self._progress_bar_config = kwargs
+    def forward(self, images=None, **kwargs):
+        return images, [False]
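A hedged usage sketch for the new class (this usage is not shown in the commit itself; swapping out the safety checker disables NSFW filtering, so use responsibly):

    from stable_diffusion_videos import NoCheck, pipeline

    # NoCheck.forward returns the images unchanged plus a [False] nsfw flag,
    # matching the (images, has_nsfw_concepts) contract of the real checker.
    pipeline.safety_checker = NoCheck().cuda()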
65 changes: 58 additions & 7 deletions stable_diffusion_videos/stable_diffusion_walk.py
@@ -1,17 +1,20 @@
+import json
 import subprocess
 from pathlib import Path
 
 import numpy as np
 import torch
 from diffusers.schedulers import (DDIMScheduler, LMSDiscreteScheduler,
                                   PNDMScheduler)
+from diffusers import ModelMixin
 
 from .stable_diffusion_pipeline import StableDiffusionPipeline
 
 pipeline = StableDiffusionPipeline.from_pretrained(
     "CompVis/stable-diffusion-v1-4",
     use_auth_token=True,
     torch_dtype=torch.float16,
     revision="fp16",
 ).to("cuda")
 
 default_scheduler = PNDMScheduler(
@@ -57,6 +60,16 @@ def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
     return v2
 
 
+def make_video_ffmpeg(frame_dir, output_file_name='output.mp4', frame_filename="frame%06d.jpg", fps=30):
+    frame_ref_path = str(frame_dir / frame_filename)
+    video_path = str(frame_dir / output_file_name)
+    subprocess.call(
+        f"ffmpeg -r {fps} -i {frame_ref_path} -vcodec libx264 -crf 10 -pix_fmt yuv420p"
+        f" {video_path}".split()
+    )
+    return video_path
+
+
 def walk(
     prompts=["blueberry spaghetti", "strawberry spaghetti"],
     seeds=[42, 123],
@@ -74,13 +87,57 @@ def walk(
     scheduler="klms",  # choices: default, ddim, klms
     disable_tqdm=False,
 ):
+    """Generate video frames/a video given a list of prompts and seeds.
+
+    Args:
+        prompts (List[str], optional): List of text prompts to interpolate between. Defaults to ["blueberry spaghetti", "strawberry spaghetti"].
+        seeds (List[int], optional): List of random seeds corresponding to the given prompts.
+        num_steps (int, optional): Number of steps to walk. Increase this value to 60-200 for good results. Defaults to 5.
+        output_dir (str, optional): Root dir where images will be saved. Defaults to "dreams".
+        name (str, optional): Subdirectory of output_dir in which to save this run's files. Defaults to "berry_good_spaghetti".
+        height (int, optional): Height of image to generate. Defaults to 512.
+        width (int, optional): Width of image to generate. Defaults to 512.
+        guidance_scale (float, optional): Higher = more adherence to prompt. Lower = let the model take the wheel. Defaults to 7.5.
+        eta (float, optional): DDIM eta noise parameter (only used by the "ddim" scheduler). Defaults to 0.0.
+        num_inference_steps (int, optional): Number of diffusion steps per image. Defaults to 50.
+        do_loop (bool, optional): Whether to loop from the last prompt back to the first. Defaults to False.
+        make_video (bool, optional): Whether to make a video or just save the images. Defaults to False.
+        use_lerp_for_text (bool, optional): Use LERP instead of SLERP for text embeddings when walking. Defaults to False.
+        scheduler (str, optional): Which scheduler to use. Defaults to "klms". Choices are "default", "ddim", "klms".
+        disable_tqdm (bool, optional): Whether to turn off the tqdm progress bars. Defaults to False.
+
+    Returns:
+        str: Path to the saved video file if make_video=True, else None.
+    """
     pipeline.set_progress_bar_config(disable=disable_tqdm)
 
     pipeline.scheduler = SCHEDULERS[scheduler]
 
     output_path = Path(output_dir) / name
     output_path.mkdir(exist_ok=True, parents=True)
 
+    # Write prompt info to a file in the output dir so we can keep track of what we did
+    prompt_config_path = output_path / 'prompt_config.json'
+    prompt_config_path.write_text(
+        json.dumps(
+            dict(
+                prompts=prompts,
+                seeds=seeds,
+                num_steps=num_steps,
+                name=name,
+                guidance_scale=guidance_scale,
+                eta=eta,
+                num_inference_steps=num_inference_steps,
+                do_loop=do_loop,
+                make_video=make_video,
+                use_lerp_for_text=use_lerp_for_text,
+                scheduler=scheduler
+            ),
+            indent=2,
+            sort_keys=False,
+        )
+    )
+
     assert len(prompts) == len(seeds)
 
     first_prompt, *prompts = prompts
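Because the run's settings are now persisted to prompt_config.json, a later session can reload them to reproduce a walk. A small sketch, assuming the default output_dir and name:

    import json
    from pathlib import Path

    # Rehydrate the settings that walk() recorded for a finished run.
    config = json.loads(Path("dreams/berry_good_spaghetti/prompt_config.json").read_text())
    print(config["prompts"], config["seeds"])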
@@ -138,13 +195,7 @@ def walk(
         latents_a = latents_b
 
     if make_video:
-        frame_ref_path = str(output_path / "frame%06d.jpg")
-        video_path = str(output_path / f"{name}.mp4")
-        subprocess.call(
-            f"ffmpeg -r 30 -i {frame_ref_path} -vcodec libx264 -crf 10 -pix_fmt yuv420p"
-            f" {video_path}".split()
-        )
-        return video_path
+        return make_video_ffmpeg(output_path, f"{name}.mp4")
 
 
 if __name__ == "__main__":
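The extracted make_video_ffmpeg helper can also be driven directly against a directory of already-rendered frames. A usage sketch with hypothetical paths (assumes ffmpeg is on PATH and frames are named frame000000.jpg, frame000001.jpg, ...; note that importing this module loads the Stable Diffusion pipeline onto the GPU):

    from pathlib import Path

    from stable_diffusion_videos.stable_diffusion_walk import make_video_ffmpeg

    # Stitch existing frames into an H.264 video at the default 30 fps.
    video_path = make_video_ffmpeg(Path("dreams/berry_good_spaghetti"))
    print(video_path)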
