diff --git a/.gitignore b/.gitignore index b6e4761..61bac16 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,8 @@ dmypy.json # Pyre type checker .pyre/ + +# Extra stuff to ignore +dreams +images +run.py \ No newline at end of file diff --git a/README.md b/README.md index a6061d0..9df3568 100644 --- a/README.md +++ b/README.md @@ -81,3 +81,45 @@ This work built off of [a script](https://gist.github.com/karpathy/00103b0037c5a You can file any issues/feature requests [here](https://github.com/nateraw/stable-diffusion-videos/issues) Enjoy 🤗 + +## Extras + +### Upsample with Real-ESRGAN + +You can also 4x upsample your images with [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN)! + +First, you'll need to install it... + +``` +pip install realesrgan +``` + +Then, you'll be able to use `upsample=True` in the `walk` function, like this: + +```python +from stable_diffusion_videos import walk + +walk(['a cat', 'a dog'], [234, 345], upsample=True) +``` + +The above may cause you to run out of VRAM. No problem, you can do upsampling separately. 
+ +To upsample an individual image: + +```python +from stable_diffusion_videos import PipelineRealESRGAN + +pipe = PipelineRealESRGAN.from_pretrained('nateraw/real-esrgan') +enhanced_image = pipe('your_file.jpg') +``` + +Or, to do a whole folder: + +```python +from stable_diffusion_videos import PipelineRealESRGAN + +pipe = PipelineRealESRGAN.from_pretrained('nateraw/real-esrgan') +pipe.upsample_imagefolder('path/to/images/', 'path/to/output_dir') +``` + + diff --git a/faces_upsampled_1.mp4 b/faces_upsampled_1.mp4 new file mode 100644 index 0000000..0cd3be9 Binary files /dev/null and b/faces_upsampled_1.mp4 differ diff --git a/setup.py b/setup.py index f9c0261..d172edb 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,9 @@ def get_version() -> str: with open("requirements.txt", "r") as f: requirements = f.read().splitlines() +extras = {} +extras['realesrgan'] = ['realesrgan==0.2.5.0'] + setup( name="stable_diffusion_videos", version=get_version(), @@ -24,5 +27,6 @@ def get_version() -> str: ), license="Apache", install_requires=requirements, + extras_require=extras, packages=find_packages(), ) diff --git a/stable_diffusion_videos/__init__.py b/stable_diffusion_videos/__init__.py index b163f6e..c368f59 100644 --- a/stable_diffusion_videos/__init__.py +++ b/stable_diffusion_videos/__init__.py @@ -113,6 +113,9 @@ def __dir__(): "walk", "SCHEDULERS", "pipeline", + ], + "upsampling": [ + "PipelineRealESRGAN" ] }, ) diff --git a/stable_diffusion_videos/app.py b/stable_diffusion_videos/app.py index 725e3a2..43aff1d 100644 --- a/stable_diffusion_videos/app.py +++ b/stable_diffusion_videos/app.py @@ -13,16 +13,24 @@ def fn_images( guidance_scale, num_inference_steps, disable_tqdm, + upsample, ): + if upsample: + from .upsampling import PipelineRealESRGAN + + upsampling_pipeline = PipelineRealESRGAN.from_pretrained('nateraw/real-esrgan') + pipeline.set_progress_bar_config(disable=disable_tqdm) pipeline.scheduler = SCHEDULERS[scheduler] # klms, default, ddim with 
torch.autocast("cuda"): - return pipeline( + img = pipeline( prompt, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, generator=torch.Generator(device=pipeline.device).manual_seed(seed), + output_type='pil' if not upsample else 'numpy', )["sample"][0] + return img if not upsample else upsampling_pipeline(img) def fn_videos( @@ -38,6 +46,7 @@ def fn_videos( disable_tqdm, use_lerp_for_text, output_dir, + upsample, ): prompts = [prompt_1, prompt_2] seeds = [seed_1, seed_2] @@ -58,6 +67,7 @@ def fn_videos( name=time.strftime("%Y%m%d-%H%M%S"), scheduler=scheduler, disable_tqdm=disable_tqdm, + upsample=upsample ) return video_path @@ -66,9 +76,9 @@ def fn_videos( fn_videos, inputs=[ gr.Textbox("blueberry spaghetti"), - gr.Slider(0, 1000, 553, step=1), + gr.Number(42, label='Seed 1', precision=0), gr.Textbox("strawberry spaghetti"), - gr.Slider(0, 1000, 234, step=1), + gr.Number(42, label='Seed 2', precision=0), gr.Dropdown(["klms", "ddim", "default"], value="klms"), gr.Slider(0.0, 20.0, 8.5), gr.Slider(1, 200, 50), @@ -82,6 +92,7 @@ def fn_videos( "Folder where outputs will be saved. Each output will be saved in a new folder." 
), ), + gr.Checkbox(False), ], outputs=gr.Video(), ) @@ -90,11 +101,12 @@ def fn_videos( fn_images, inputs=[ gr.Textbox("blueberry spaghetti"), - gr.Slider(0, 1000, 553, step=1), + gr.Number(42, label='Seed', precision=0), gr.Dropdown(["klms", "ddim", "default"], value="klms"), gr.Slider(0.0, 20.0, 8.5), gr.Slider(1, 200, 50), gr.Checkbox(False), + gr.Checkbox(False), ], outputs=gr.Image(type="pil"), ) diff --git a/stable_diffusion_videos/stable_diffusion_walk.py b/stable_diffusion_videos/stable_diffusion_walk.py index f2a22fd..e1d57ec 100644 --- a/stable_diffusion_videos/stable_diffusion_walk.py +++ b/stable_diffusion_videos/stable_diffusion_walk.py @@ -86,6 +86,7 @@ def walk( use_lerp_for_text=False, scheduler="klms", # choices: default, ddim, klms disable_tqdm=False, + upsample=False, ): """Generate video frames/a video given a list of prompts and seeds. @@ -105,10 +106,17 @@ def walk( use_lerp_for_text (bool, optional): Use LERP instead of SLERP for text embeddings when walking. Defaults to False. scheduler (str, optional): Which scheduler to use. Defaults to "klms". Choices are "default", "ddim", "klms". disable_tqdm (bool, optional): Whether to turn off the tqdm progress bars. Defaults to False. + upsample (bool, optional): If True, uses Real-ESRGAN to upsample images 4x. Requires Real-ESRGAN to be installed, + which you can do by running: `pip install realesrgan`. Defaults to False. Returns: str: Path to video file saved if make_video=True, else None. 
""" + if upsample: + from .upsampling import PipelineRealESRGAN + + upsampling_pipeline = PipelineRealESRGAN.from_pretrained('nateraw/real-esrgan') + pipeline.set_progress_bar_config(disable=disable_tqdm) pipeline.scheduler = SCHEDULERS[scheduler] @@ -186,8 +194,12 @@ def walk( guidance_scale=guidance_scale, eta=eta, num_inference_steps=num_inference_steps, + output_type='pil' if not upsample else 'numpy' )["sample"][0] + if upsample: + im = upsampling_pipeline(im) + im.save(output_path / ("frame%06d.jpg" % frame_index)) frame_index += 1 diff --git a/stable_diffusion_videos/upsampling.py b/stable_diffusion_videos/upsampling.py new file mode 100644 index 0000000..8497dd6 --- /dev/null +++ b/stable_diffusion_videos/upsampling.py @@ -0,0 +1,92 @@ +from pathlib import Path + +import cv2 +from PIL import Image +from huggingface_hub import hf_hub_download + +try: + from realesrgan import RealESRGANer + from basicsr.archs.rrdbnet_arch import RRDBNet +except ImportError as e: + raise ImportError( + "You tried to import realesrgan without having it installed properly. To install Real-ESRGAN, run:\n\n" + "pip install realesrgan" + ) + +class PipelineRealESRGAN: + def __init__(self, model_path, tile=0, tile_pad=10, pre_pad=0, fp32=False): + model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) + self.upsampler = RealESRGANer( + scale=4, + model_path=model_path, + model=model, + tile=tile, + tile_pad=tile_pad, + pre_pad=pre_pad, + half=not fp32 + ) + + def __call__(self, image, outscale=4, convert_to_pil=True): + """Upsample an image array or path. + + Args: + image (Union[np.ndarray, str]): Either a np array or an image path. np array is assumed to be in RGB format, + and we convert it to BGR. + outscale (int, optional): Amount to upscale the image. Defaults to 4. + convert_to_pil (bool, optional): If True, return PIL image. Otherwise, return numpy array (BGR). Defaults to True. 
+ + Returns: + Union[np.ndarray, PIL.Image.Image]: An upsampled version of the input image. + """ + if isinstance(image, (str, Path)): + img = cv2.imread(image, cv2.IMREAD_UNCHANGED) + else: + img = image + img = (img * 255).round().astype("uint8") + img = img[:, :, ::-1] + + image, _ = self.upsampler.enhance(img, outscale=outscale) + + if convert_to_pil: + image = Image.fromarray(image[:, :, ::-1]) + + return image + + @classmethod + def from_pretrained(cls, model_name_or_path='nateraw/real-esrgan'): + """Initialize a pretrained Real-ESRGAN upsampler. + + Example: + ```python + >>> from stable_diffusion_videos import PipelineRealESRGAN + >>> pipe = PipelineRealESRGAN.from_pretrained('nateraw/real-esrgan') + >>> im_out = pipe('input_img.jpg') + ``` + + Args: + model_name_or_path (str, optional): The Hugging Face repo ID or path to local model. Defaults to 'nateraw/real-esrgan'. + + Returns: + stable_diffusion_videos.PipelineRealESRGAN: An instance of `PipelineRealESRGAN` instantiated from pretrained model. + """ + # reuploaded from official ones mentioned here: + # https://github.com/xinntao/Real-ESRGAN + if Path(model_name_or_path).exists(): + file = model_name_or_path + else: + file = hf_hub_download(model_name_or_path, 'RealESRGAN_x4plus.pth') + return cls(file) + + + def upsample_imagefolder(self, in_dir, out_dir, suffix='out', outfile_ext='.jpg'): + in_dir, out_dir = Path(in_dir), Path(out_dir) + if not in_dir.exists(): + raise FileNotFoundError(f"Provided input directory {in_dir} does not exist") + + out_dir.mkdir(exist_ok=True, parents=True) + + image_paths = [x for x in in_dir.glob('*') if x.suffix.lower() in ['.png', '.jpg', '.jpeg']] + for image in image_paths: + im = self(str(image)) + out_filepath = out_dir / (image.stem + suffix + outfile_ext) + im.save(out_filepath)