Skip to content

Commit

Permalink
✅ add some very basic tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nateraw committed Oct 11, 2022
1 parent bf77bd9 commit 6a949ec
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 1 deletion.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,5 @@ dmypy.json
dreams
images
run.py
examples
examples
test_outputs
Binary file added tests/samples/choice.wav
Binary file not shown.
82 changes: 82 additions & 0 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Tests require GPU, so they will not be running on CI (unless someone
wants to figure that out for me).
We'll run these locally before pushing to the repo, or at the very least
before making a release.
"""

from stable_diffusion_videos import NoCheck, StableDiffusionWalkPipeline
import torch
from pathlib import Path
from shutil import rmtree

import pytest


TEST_OUTPUT_ROOT = "test_outputs"
SAMPLES_DIR = Path(__file__).parent / "samples"

@pytest.fixture
def pipeline(scope="session"):
pipe = StableDiffusionWalkPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
use_auth_token=True,
torch_dtype=torch.float16,
revision="fp16",
).to('cuda')
pipe.safety_checker = NoCheck().cuda()
return pipe


@pytest.fixture
def run_name(request):
fn_name = request.node.name.lstrip('test_')
output_path = Path(TEST_OUTPUT_ROOT) / fn_name
if output_path.exists():
rmtree(output_path)
# We could instead yield here and rm the dir after its written.
# However, I like being able to view the files locally to see if they look right.
return fn_name


def test_walk_basic(pipeline, run_name):
video_path = pipeline.walk(
['a cat', 'a dog', 'a horse'],
seeds=[42, 1337, 2022],
num_interpolation_steps=[3, 3],
output_dir=TEST_OUTPUT_ROOT,
name=run_name,
fps=3,
)
assert Path(video_path).exists(), "Video file was not created"


def test_walk_with_audio(pipeline, run_name):
fps = 6
audio_offsets = [2, 4, 5, 8]
num_interpolation_steps = [(b - a) * fps for a, b in zip(audio_offsets, audio_offsets[1:])]
video_path = pipeline.walk(
['a cat', 'a dog', 'a horse', 'a cow'],
seeds=[42, 1337, 4321, 1234],
num_interpolation_steps=num_interpolation_steps,
output_dir=TEST_OUTPUT_ROOT,
name=run_name,
fps=fps,
audio_filepath=str(Path(SAMPLES_DIR) / 'choice.wav'),
audio_start_sec=audio_offsets[0],
batch_size=16,
)
assert Path(video_path).exists(), "Video file was not created"


def test_walk_with_upsampler(pipeline, run_name):
video_path = pipeline.walk(
['a cat', 'a dog', 'a horse'],
seeds=[42, 1337, 2022],
num_interpolation_steps=[3, 3],
output_dir=TEST_OUTPUT_ROOT,
name=run_name,
fps=3,
upsample=True,
)
assert Path(video_path).exists(), "Video file was not created"

0 comments on commit 6a949ec

Please sign in to comment.