diff --git a/.gitignore b/.gitignore index 1b87291..38a8e6f 100644 --- a/.gitignore +++ b/.gitignore @@ -132,4 +132,5 @@ dmypy.json dreams images run.py -examples \ No newline at end of file +examples +test_outputs \ No newline at end of file diff --git a/stable_diffusion_videos/stable_diffusion_pipeline.py b/stable_diffusion_videos/stable_diffusion_pipeline.py index 742aaa4..08a86fc 100644 --- a/stable_diffusion_videos/stable_diffusion_pipeline.py +++ b/stable_diffusion_videos/stable_diffusion_pipeline.py @@ -654,7 +654,7 @@ def walk( if not resume and isinstance(num_interpolation_steps, int): num_interpolation_steps = [num_interpolation_steps] * (len(prompts) - 1) - if not resume and audio_filepath: + if not resume: audio_start_sec = audio_start_sec or 0 # Save/reload prompt config @@ -719,6 +719,9 @@ def walk( continue print(f"Resuming {save_path.name} from frame {skip}") + audio_offset = audio_start_sec + sum(num_interpolation_steps[:i]) / fps + audio_duration = num_step / fps + self.generate_interpolation_clip( prompt_a, prompt_b, @@ -736,8 +739,8 @@ def walk( skip=skip, T=get_timesteps_arr( audio_filepath, - offset=audio_start_sec + (i * num_step / fps), - duration=num_step / fps, + offset=audio_offset, + duration=audio_duration, fps=fps, margin=(1.0, 5.0), ) @@ -750,8 +753,8 @@ def walk( fps=fps, output_filepath=step_output_filepath, glob_pattern=f"*{image_file_ext}", - audio_offset=audio_start_sec + (i * num_step / fps) if audio_start_sec else 0, - audio_duration=num_step / fps, + audio_offset=audio_offset, + audio_duration=audio_duration, sr=44100, ) diff --git a/tests/samples/choice.wav b/tests/samples/choice.wav new file mode 100644 index 0000000..29d3a59 Binary files /dev/null and b/tests/samples/choice.wav differ diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 0000000..1ceb7d7 --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,82 @@ +"""Tests require GPU, so they will not be running on CI (unless someone +wants to figure that out for me). + +We'll run these locally before pushing to the repo, or at the very least +before making a release. +""" + +from stable_diffusion_videos import NoCheck, StableDiffusionWalkPipeline +import torch +from pathlib import Path +from shutil import rmtree + +import pytest + + +TEST_OUTPUT_ROOT = "test_outputs" +SAMPLES_DIR = Path(__file__).parent / "samples" + +@pytest.fixture +def pipeline(scope="session"): + pipe = StableDiffusionWalkPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + use_auth_token=True, + torch_dtype=torch.float16, + revision="fp16", + ).to('cuda') + pipe.safety_checker = NoCheck().cuda() + return pipe + + +@pytest.fixture +def run_name(request): + fn_name = request.node.name.lstrip('test_') + output_path = Path(TEST_OUTPUT_ROOT) / fn_name + if output_path.exists(): + rmtree(output_path) + # We could instead yield here and rm the dir after its written. + # However, I like being able to view the files locally to see if they look right. + return fn_name + + +def test_walk_basic(pipeline, run_name): + video_path = pipeline.walk( + ['a cat', 'a dog', 'a horse'], + seeds=[42, 1337, 2022], + num_interpolation_steps=[3, 3], + output_dir=TEST_OUTPUT_ROOT, + name=run_name, + fps=3, + ) + assert Path(video_path).exists(), "Video file was not created" + + +def test_walk_with_audio(pipeline, run_name): + fps = 6 + audio_offsets = [2, 4, 5, 8] + num_interpolation_steps = [(b - a) * fps for a, b in zip(audio_offsets, audio_offsets[1:])] + video_path = pipeline.walk( + ['a cat', 'a dog', 'a horse', 'a cow'], + seeds=[42, 1337, 4321, 1234], + num_interpolation_steps=num_interpolation_steps, + output_dir=TEST_OUTPUT_ROOT, + name=run_name, + fps=fps, + audio_filepath=str(Path(SAMPLES_DIR) / 'choice.wav'), + audio_start_sec=audio_offsets[0], + batch_size=16, + ) + assert Path(video_path).exists(), "Video file was not created" + + +def test_walk_with_upsampler(pipeline, run_name): + video_path = pipeline.walk( + ['a cat', 'a dog', 'a horse'], + seeds=[42, 1337, 2022], + num_interpolation_steps=[3, 3], + output_dir=TEST_OUTPUT_ROOT, + name=run_name, + fps=3, + upsample=True, + ) + assert Path(video_path).exists(), "Video file was not created"