Skip to content

Commit

Permalink
Merge pull request #76 from nateraw/fix-audio-alignment
Browse files Browse the repository at this point in the history
Fix audio alignment and add basic tests
  • Loading branch information
nateraw authored Oct 11, 2022
2 parents 90039cf + 6a949ec commit 5eb000e
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 6 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,5 @@ dmypy.json
dreams
images
run.py
-examples
+examples
+test_outputs
13 changes: 8 additions & 5 deletions stable_diffusion_videos/stable_diffusion_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ def walk(
if not resume and isinstance(num_interpolation_steps, int):
num_interpolation_steps = [num_interpolation_steps] * (len(prompts) - 1)

- if not resume and audio_filepath:
+ if not resume:
audio_start_sec = audio_start_sec or 0

# Save/reload prompt config
Expand Down Expand Up @@ -719,6 +719,9 @@ def walk(
continue
print(f"Resuming {save_path.name} from frame {skip}")

+ audio_offset = audio_start_sec + sum(num_interpolation_steps[:i]) / fps
+ audio_duration = num_step / fps

self.generate_interpolation_clip(
prompt_a,
prompt_b,
Expand All @@ -736,8 +739,8 @@ def walk(
skip=skip,
T=get_timesteps_arr(
audio_filepath,
- offset=audio_start_sec + (i * num_step / fps),
- duration=num_step / fps,
+ offset=audio_offset,
+ duration=audio_duration,
fps=fps,
margin=(1.0, 5.0),
)
Expand All @@ -750,8 +753,8 @@ def walk(
fps=fps,
output_filepath=step_output_filepath,
glob_pattern=f"*{image_file_ext}",
- audio_offset=audio_start_sec + (i * num_step / fps) if audio_start_sec else 0,
- audio_duration=num_step / fps,
+ audio_offset=audio_offset,
+ audio_duration=audio_duration,
sr=44100,
)

Expand Down
Binary file added tests/samples/choice.wav
Binary file not shown.
82 changes: 82 additions & 0 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Tests require GPU, so they will not be running on CI (unless someone
wants to figure that out for me).
We'll run these locally before pushing to the repo, or at the very least
before making a release.
"""

from stable_diffusion_videos import NoCheck, StableDiffusionWalkPipeline
import torch
from pathlib import Path
from shutil import rmtree

import pytest


# Directory (relative to the working directory) that the tests pass to
# `pipeline.walk(output_dir=...)`; each test writes into a sub-folder named
# after its run name.
TEST_OUTPUT_ROOT = "test_outputs"
# Folder next to this test module holding sample inputs such as `choice.wav`.
SAMPLES_DIR = Path(__file__).parent / "samples"

@pytest.fixture(scope="session")
def pipeline():
    """Build the fp16 Stable Diffusion walk pipeline on CUDA, once per test session.

    The scope must be given to the ``pytest.fixture`` decorator. Declaring it
    as a default argument of the fixture function (``def pipeline(scope=...)``)
    is ignored by pytest and leaves the fixture function-scoped, which reloads
    the model for every single test.

    Returns:
        The ``StableDiffusionWalkPipeline`` moved to the ``'cuda'`` device,
        with its ``safety_checker`` replaced by a ``NoCheck`` instance.
    """
    pipe = StableDiffusionWalkPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        use_auth_token=True,
        torch_dtype=torch.float16,
        revision="fp16",
    ).to('cuda')
    # NOTE(review): NoCheck is presumably a pass-through safety checker so test
    # frames are never filtered - confirm against its definition.
    pipe.safety_checker = NoCheck().cuda()
    return pipe


@pytest.fixture
def run_name(request):
    """Return a run name derived from the requesting test and clear stale output.

    The name is the test node's name with a leading ``'test_'`` prefix removed
    (names without that prefix are returned unchanged). If a directory with
    that name already exists under ``TEST_OUTPUT_ROOT`` it is deleted, so each
    run starts from an empty output folder.

    Args:
        request: pytest's built-in ``request`` fixture for the current test.

    Returns:
        The run name as a string.
    """
    prefix = 'test_'
    node_name = request.node.name
    # str.lstrip('test_') would strip any leading run of the characters
    # 't', 'e', 's', '_' (e.g. 'test_seeds' -> 'ds'), so remove the exact
    # prefix instead.
    fn_name = node_name[len(prefix):] if node_name.startswith(prefix) else node_name
    output_path = Path(TEST_OUTPUT_ROOT) / fn_name
    if output_path.exists():
        rmtree(output_path)
    # We could instead yield here and rm the dir after its written.
    # However, I like being able to view the files locally to see if they look right.
    return fn_name


def test_walk_basic(pipeline, run_name):
    """Walk across three prompts with no audio and check the returned path exists."""
    prompts = ['a cat', 'a dog', 'a horse']
    walk_options = dict(
        seeds=[42, 1337, 2022],
        num_interpolation_steps=[3, 3],
        output_dir=TEST_OUTPUT_ROOT,
        name=run_name,
        fps=3,
    )
    result_path = pipeline.walk(prompts, **walk_options)
    assert Path(result_path).exists(), "Video file was not created"


def test_walk_with_audio(pipeline, run_name):
    """Walk across four prompts driven by the sample audio and check the output exists.

    Each segment's step count is the gap between consecutive audio offsets
    (in seconds) multiplied by the frame rate.
    """
    fps = 6
    audio_offsets = [2, 4, 5, 8]

    # One entry per consecutive pair of offsets: (end - start) seconds * fps.
    num_interpolation_steps = []
    for segment_start, segment_end in zip(audio_offsets[:-1], audio_offsets[1:]):
        num_interpolation_steps.append((segment_end - segment_start) * fps)

    audio_file = Path(SAMPLES_DIR) / 'choice.wav'
    prompts = ['a cat', 'a dog', 'a horse', 'a cow']
    result_path = pipeline.walk(
        prompts,
        seeds=[42, 1337, 4321, 1234],
        num_interpolation_steps=num_interpolation_steps,
        output_dir=TEST_OUTPUT_ROOT,
        name=run_name,
        fps=fps,
        audio_filepath=str(audio_file),
        audio_start_sec=audio_offsets[0],
        batch_size=16,
    )
    assert Path(result_path).exists(), "Video file was not created"


def test_walk_with_upsampler(pipeline, run_name):
    """Walk across three prompts with ``upsample=True`` and check the returned path exists."""
    prompts = ['a cat', 'a dog', 'a horse']
    walk_options = dict(
        seeds=[42, 1337, 2022],
        num_interpolation_steps=[3, 3],
        output_dir=TEST_OUTPUT_ROOT,
        name=run_name,
        fps=3,
        upsample=True,
    )
    result_path = pipeline.walk(prompts, **walk_options)
    assert Path(result_path).exists(), "Video file was not created"

0 comments on commit 5eb000e

Please sign in to comment.