Skip to content

Commit

Permalink
Merge pull request #84 from nateraw/cleanup-and-fix-tests
Browse files Browse the repository at this point in the history
Cleanup and fix tests
  • Loading branch information
nateraw authored Oct 20, 2022
2 parents 63a23fb + cff9b90 commit 5051ae1
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 16 deletions.
34 changes: 21 additions & 13 deletions stable_diffusion_videos/stable_diffusion_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import json

import torch
from diffusers import ModelMixin
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipeline_utils import DiffusionPipeline
Expand Down Expand Up @@ -128,7 +127,6 @@ def make_video_pyav(
frame = pil_to_tensor(Image.open(img)).unsqueeze(0)
frames = frame if frames is None else torch.cat([frames, frame])
else:

frames = frames_or_frame_dir

# TCHW -> THWC
Expand Down Expand Up @@ -505,16 +503,9 @@ def __call__(
def generate_inputs(self, prompt_a, prompt_b, seed_a, seed_b, noise_shape, T, batch_size):
embeds_a = self.embed_text(prompt_a)
embeds_b = self.embed_text(prompt_b)
latents_a = torch.randn(
noise_shape,
device=self.device,
generator=torch.Generator(device=self.device).manual_seed(seed_a),
)
latents_b = torch.randn(
noise_shape,
device=self.device,
generator=torch.Generator(device=self.device).manual_seed(seed_b),
)

latents_a = self.init_noise(seed_a, noise_shape)
latents_b = self.init_noise(seed_b, noise_shape)

batch_idx = 0
embeds_batch, noise_batch = None, None
Expand Down Expand Up @@ -665,7 +656,7 @@ def walk(
For example, if you provide the following prompts and seeds:
```
prompts = ['a', 'b', 'c']
prompts = ['a dog', 'a cat', 'a bird']
seeds = [1, 2, 3]
num_interpolation_steps = 5
output_dir = 'output_dir'
Expand Down Expand Up @@ -839,6 +830,23 @@ def embed_text(self, text):
embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
return embed

def init_noise(self, seed, noise_shape):
    """Create seeded Gaussian noise on the pipeline's device.

    Args:
        seed: Integer passed to ``torch.Generator.manual_seed`` so the
            returned noise is reproducible for a given seed and shape.
        noise_shape: Shape of the tensor handed to ``torch.randn``.

    Returns:
        A tensor of shape ``noise_shape`` drawn with ``torch.randn`` and
        located on ``self.device``.
    """
    # Seeded randn was not usable directly on mps when this helper was
    # written, so for mps the noise is drawn on CPU and moved afterwards.
    # Every other device samples in place, with a generator on that device.
    on_mps = self.device.type == "mps"
    sample_device = "cpu" if on_mps else self.device

    noise = torch.randn(
        noise_shape,
        device=sample_device,
        generator=torch.Generator(device=sample_device).manual_seed(seed),
    )
    # Only the CPU-sampled (mps) noise needs relocating to the pipeline device.
    return noise.to(self.device) if on_mps else noise

@classmethod
def from_pretrained(cls, *args, tiled=False, **kwargs):
"""Same as diffusers `from_pretrained` but with tiled option, which makes images tilable"""
Expand Down
5 changes: 2 additions & 3 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
before making a release.
"""

from stable_diffusion_videos import NoCheck, StableDiffusionWalkPipeline
from stable_diffusion_videos import StableDiffusionWalkPipeline
import torch
from pathlib import Path
from shutil import rmtree
Expand All @@ -20,11 +20,10 @@
def pipeline(scope="session"):
pipe = StableDiffusionWalkPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
use_auth_token=True,
torch_dtype=torch.float16,
revision="fp16",
safety_checker=None,
).to('cuda')
pipe.safety_checker = NoCheck().cuda()
return pipe


Expand Down

0 comments on commit 5051ae1

Please sign in to comment.