Merge pull request #85 from nateraw/new-interface
New interface
nateraw authored Oct 20, 2022
2 parents 5051ae1 + 249bdb4 commit 80f8f41
Showing 3 changed files with 107 additions and 97 deletions.
3 changes: 1 addition & 2 deletions stable_diffusion_videos/__init__.py
@@ -99,8 +99,7 @@ def __dir__():
     submodules=[],
     submod_attrs={
         "app": [
-            "interface",
-            "pipeline",
+            "Interface",
         ],
         "image_generation": [
             "generate_images",
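
For context, the submod_attrs mapping drives the package's lazy attribute loading: a name listed under a submodule is imported from that submodule the first time it is accessed on the package, so "from stable_diffusion_videos import Interface" now resolves to stable_diffusion_videos.app.Interface, while the old interface and pipeline attributes are no longer exported. A minimal sketch of that pattern, assuming a plain PEP 562 module __getattr__ (the repository uses its own lazy-loading helper rather than this exact code):

import importlib

# Hypothetical, simplified version of what the submod_attrs mapping configures:
# map each exported name to the submodule that defines it.
_submod_attrs = {
    "app": ["Interface"],
    "image_generation": ["generate_images"],
}
_attr_to_module = {attr: mod for mod, attrs in _submod_attrs.items() for attr in attrs}

def __getattr__(name):
    # Import the owning submodule lazily on first attribute access.
    if name in _attr_to_module:
        module = importlib.import_module(f"{__name__}.{_attr_to_module[name]}")
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
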
197 changes: 104 additions & 93 deletions stable_diffusion_videos/app.py
@@ -1,104 +1,115 @@
 import time
+from pathlib import Path
 
 import gradio as gr
 import torch
 
 from .stable_diffusion_pipeline import StableDiffusionWalkPipeline
 from .upsampling import RealESRGANModel
+from stable_diffusion_videos import generate_images
 
-pipeline = StableDiffusionWalkPipeline.from_pretrained(
-    "CompVis/stable-diffusion-v1-4",
-    use_auth_token=True,
-    torch_dtype=torch.float16,
-    revision="fp16",
-).to("cuda")
-
-
-def fn_images(
-    prompt,
-    seed,
-    guidance_scale,
-    num_inference_steps,
-    upsample,
-):
-    if upsample:
-        if getattr(pipeline, "upsampler", None) is None:
-            pipeline.upsampler = RealESRGANModel.from_pretrained("nateraw/real-esrgan")
-        pipeline.upsampler.to(pipeline.device)
-
-    with torch.autocast("cuda"):
-        img = pipeline(
-            prompt,
-            guidance_scale=guidance_scale,
-            num_inference_steps=num_inference_steps,
-            generator=torch.Generator(device=pipeline.device).manual_seed(seed),
-            output_type="pil" if not upsample else "numpy",
-        )["sample"][0]
-    return img if not upsample else pipeline.upsampler(img)
-
-
-def fn_videos(
-    prompt_1,
-    seed_1,
-    prompt_2,
-    seed_2,
-    guidance_scale,
-    num_inference_steps,
-    num_interpolation_steps,
-    output_dir,
-    upsample,
-):
-    prompts = [prompt_1, prompt_2]
-    seeds = [seed_1, seed_2]
-
-    prompts = [x for x in prompts if x.strip()]
-    seeds = seeds[: len(prompts)]
-
-    video_path = pipeline.walk(
-        guidance_scale=guidance_scale,
-        prompts=prompts,
-        seeds=seeds,
-        num_interpolation_steps=num_interpolation_steps,
-        num_inference_steps=num_inference_steps,
-        output_dir=output_dir,
-        name=time.strftime("%Y%m%d-%H%M%S"),
-        upsample=upsample,
-    )
-    return video_path
-
-
-interface_videos = gr.Interface(
-    fn_videos,
-    inputs=[
-        gr.Textbox("blueberry spaghetti"),
-        gr.Number(42, label="Seed 1", precision=0),
-        gr.Textbox("strawberry spaghetti"),
-        gr.Number(42, label="Seed 2", precision=0),
-        gr.Slider(0.0, 20.0, 8.5),
-        gr.Slider(1, 200, 50),
-        gr.Slider(3, 240, 10),
-        gr.Textbox(
-            "dreams",
-            placeholder=("Folder where outputs will be saved. Each output will be saved in a new folder."),
-        ),
-        gr.Checkbox(False),
-    ],
-    outputs=gr.Video(),
-)
-
-interface_images = gr.Interface(
-    fn_images,
-    inputs=[
-        gr.Textbox("blueberry spaghetti"),
-        gr.Number(42, label="Seed", precision=0),
-        gr.Slider(0.0, 20.0, 8.5),
-        gr.Slider(1, 200, 50),
-        gr.Checkbox(False),
-    ],
-    outputs=gr.Image(type="pil"),
-)
-
-interface = gr.TabbedInterface([interface_images, interface_videos], ["Images!", "Videos!"])
-
-if __name__ == "__main__":
-    interface.launch(debug=True)
+
+class Interface:
+    def __init__(self, pipeline):
+        self.pipeline = pipeline
+        self.interface_images = gr.Interface(
+            self.fn_images,
+            inputs=[
+                gr.Textbox("blueberry spaghetti", label='Prompt'),
+                gr.Slider(1, 24, 1, step=1, label='Batch size'),
+                gr.Slider(1, 16, 1, step=1, label='# Batches'),
+                gr.Slider(10, 100, 50, step=1, label='# Inference Steps'),
+                gr.Slider(5.0, 15.0, 7.5, step=0.5, label='Guidance Scale'),
+                gr.Slider(512, 1024, 512, step=64, label='Height'),
+                gr.Slider(512, 1024, 512, step=64, label='Width'),
+                gr.Checkbox(False, label='Upsample'),
+                gr.Textbox("./images", label='Output directory to save results to'),
+                # gr.Checkbox(False, label='Push results to Hugging Face Hub'),
+                # gr.Textbox("", label='Hugging Face Repo ID to push images to'),
+            ],
+            outputs=gr.Gallery(),
+        )
+
+        self.interface_videos = gr.Interface(
+            self.fn_videos,
+            inputs=[
+                gr.Textbox("blueberry spaghetti\nstrawberry spaghetti", lines=2, label='Prompts, separated by new line'),
+                gr.Textbox("42\n1337", lines=2, label='Seeds, separated by new line'),
+                gr.Slider(3, 1000, 5, step=1, label='# Interpolation Steps between prompts'),
+                gr.Slider(3, 60, 5, step=1, label='Output Video FPS'),
+                gr.Slider(1, 24, 1, step=1, label='Batch size'),
+                gr.Slider(10, 100, 50, step=1, label='# Inference Steps'),
+                gr.Slider(5.0, 15.0, 7.5, step=0.5, label='Guidance Scale'),
+                gr.Slider(512, 1024, 512, step=64, label='Height'),
+                gr.Slider(512, 1024, 512, step=64, label='Width'),
+                gr.Checkbox(False, label='Upsample'),
+                gr.Textbox("./dreams", label='Output directory to save results to'),
+            ],
+            outputs=gr.Video(),
+        )
+        self.interface = gr.TabbedInterface(
+            [self.interface_images, self.interface_videos],
+            ['Images!', 'Videos!'],
+        )
+
+    def fn_videos(
+        self,
+        prompts,
+        seeds,
+        num_interpolation_steps,
+        fps,
+        batch_size,
+        num_inference_steps,
+        guidance_scale,
+        height,
+        width,
+        upsample,
+        output_dir,
+    ):
+        prompts = [x.strip() for x in prompts.split('\n') if x.strip()]
+        seeds = [int(x.strip()) for x in seeds.split('\n') if x.strip()]
+
+        return self.pipeline.walk(
+            prompts=prompts,
+            seeds=seeds,
+            num_interpolation_steps=num_interpolation_steps,
+            fps=fps,
+            height=height,
+            width=width,
+            output_dir=output_dir,
+            guidance_scale=guidance_scale,
+            num_inference_steps=num_inference_steps,
+            upsample=upsample,
+            batch_size=batch_size
+        )
+
+    def fn_images(
+        self,
+        prompt,
+        batch_size,
+        num_batches,
+        num_inference_steps,
+        guidance_scale,
+        height,
+        width,
+        upsample,
+        output_dir,
+        repo_id=None,
+        push_to_hub=False,
+    ):
+        image_filepaths = generate_images(
+            self.pipeline,
+            prompt,
+            batch_size=batch_size,
+            num_batches=num_batches,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            output_dir=output_dir,
+            image_file_ext='.jpg',
+            upsample=upsample,
+            height=height,
+            width=width,
+            push_to_hub=push_to_hub,
+            repo_id=repo_id,
+            create_pr=False,
+        )
+        return [(x, Path(x).stem) for x in sorted(image_filepaths)]
+
+    def launch(self, *args, **kwargs):
+        self.interface.launch(*args, **kwargs)
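
With the new interface, the Gradio app is launched from user code rather than at import time. A rough usage sketch based on the class above (the model ID and fp16 settings mirror the pipeline construction shown in the old app.py; debug=True is just an ordinary Gradio launch option):

import torch

from stable_diffusion_videos import Interface
from stable_diffusion_videos.stable_diffusion_pipeline import StableDiffusionWalkPipeline

# Build the pipeline explicitly and hand it to the new Interface class,
# instead of relying on a module-level pipeline/interface object.
pipeline = StableDiffusionWalkPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=True,
    torch_dtype=torch.float16,
    revision="fp16",
).to("cuda")

interface = Interface(pipeline)

if __name__ == "__main__":
    # launch() forwards its arguments to the underlying gr.TabbedInterface.
    interface.launch(debug=True)
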
4 changes: 2 additions & 2 deletions stable_diffusion_videos/stable_diffusion_pipeline.py
@@ -524,7 +524,7 @@ def generate_inputs(self, prompt_a, prompt_b, seed_a, seed_b, noise_shape, T, ba
             torch.cuda.empty_cache()
             embeds_batch, noise_batch = None, None
 
-    def generate_interpolation_clip(
+    def make_clip_frames(
         self,
         prompt_a: str,
         prompt_b: str,
@@ -769,7 +769,7 @@ def walk(
             audio_offset = audio_start_sec + sum(num_interpolation_steps[:i]) / fps
             audio_duration = num_step / fps
 
-            self.generate_interpolation_clip(
+            self.make_clip_frames(
                 prompt_a,
                 prompt_b,
                 seed_a,
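
The pipeline change itself is a pure rename: generate_interpolation_clip becomes make_clip_frames, and the internal call site in walk() is updated to match, so the public walk() entry point the app uses is unchanged. For reference, a minimal programmatic call mirroring what Interface.fn_videos forwards to walk() (values are illustrative defaults taken from the Gradio inputs above):

import torch
from stable_diffusion_videos.stable_diffusion_pipeline import StableDiffusionWalkPipeline

# Same pipeline construction as in the sketch above.
pipeline = StableDiffusionWalkPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", use_auth_token=True, torch_dtype=torch.float16, revision="fp16"
).to("cuda")

# Interpolate between two prompts and write the result under ./dreams,
# using the same keyword arguments Interface.fn_videos passes to walk().
video_path = pipeline.walk(
    prompts=["blueberry spaghetti", "strawberry spaghetti"],
    seeds=[42, 1337],
    num_interpolation_steps=5,
    fps=5,
    height=512,
    width=512,
    output_dir="./dreams",
    guidance_scale=7.5,
    num_inference_steps=50,
    upsample=False,
    batch_size=1,
)
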
