Skip to content

Commit

Permalink
Merge pull request #21 from nateraw/real-esrgan
Browse files Browse the repository at this point in the history
Add upsampling with Real-ESRGAN
  • Loading branch information
nateraw authored Sep 12, 2022
2 parents 4a3b971 + f93f06b commit 60f6e12
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 4 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,8 @@ dmypy.json

# Pyre type checker
.pyre/

# Extra stuff to ignore
dreams
images
run.py
42 changes: 42 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,45 @@ This work built off of [a script](https://gist.github.com/karpathy/00103b0037c5a
You can file any issues/feature requests [here](https://github.com/nateraw/stable-diffusion-videos/issues)

Enjoy 🤗

## Extras

### Upsample with Real-ESRGAN

You can also 4x upsample your images with [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN)!

First, you'll need to install it...

```
pip install realesrgan
```

Then, you'll be able to use `upsample=True` in the `walk` function, like this:

```python
from stable_diffusion_videos import walk

walk(['a cat', 'a dog'], [234, 345], upsample=True)
```

The above may cause you to run out of VRAM. No problem, you can do upsampling separately.

To upsample an individual image:

```python
from stable_diffusion_videos import PipelineRealESRGAN

pipe = PipelineRealESRGAN.from_pretrained('nateraw/real-esrgan')
enhanced_image = pipe('your_file.jpg')
```

Or, to do a whole folder:

```python
from stable_diffusion_videos import PipelineRealESRGAN

pipe = PipelineRealESRGAN.from_pretrained('nateraw/real-esrgan')
pipe.upsample_imagefolder('path/to/images/', 'path/to/output_dir')
```


Binary file added faces_upsampled_1.mp4
Binary file not shown.
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ def get_version() -> str:
with open("requirements.txt", "r") as f:
requirements = f.read().splitlines()

extras = {}
extras['realesrgan'] = ['realesrgan==0.2.5.0']

setup(
name="stable_diffusion_videos",
version=get_version(),
Expand All @@ -24,5 +27,6 @@ def get_version() -> str:
),
license="Apache",
install_requires=requirements,
extras_require=extras,
packages=find_packages(),
)
3 changes: 3 additions & 0 deletions stable_diffusion_videos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ def __dir__():
"walk",
"SCHEDULERS",
"pipeline",
],
"upsampling": [
"PipelineRealESRGAN"
]
},
)
Expand Down
20 changes: 16 additions & 4 deletions stable_diffusion_videos/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,24 @@ def fn_images(
guidance_scale,
num_inference_steps,
disable_tqdm,
upsample,
):
if upsample:
from .upsampling import PipelineRealESRGAN

upsampling_pipeline = PipelineRealESRGAN.from_pretrained('nateraw/real-esrgan')

pipeline.set_progress_bar_config(disable=disable_tqdm)
pipeline.scheduler = SCHEDULERS[scheduler] # klms, default, ddim
with torch.autocast("cuda"):
return pipeline(
img = pipeline(
prompt,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator=torch.Generator(device=pipeline.device).manual_seed(seed),
output_type='pil' if not upsample else 'numpy',
)["sample"][0]
return img if not upsample else upsampling_pipeline(img)


def fn_videos(
Expand All @@ -38,6 +46,7 @@ def fn_videos(
disable_tqdm,
use_lerp_for_text,
output_dir,
upsample,
):
prompts = [prompt_1, prompt_2]
seeds = [seed_1, seed_2]
Expand All @@ -58,6 +67,7 @@ def fn_videos(
name=time.strftime("%Y%m%d-%H%M%S"),
scheduler=scheduler,
disable_tqdm=disable_tqdm,
upsample=upsample
)
return video_path

Expand All @@ -66,9 +76,9 @@ def fn_videos(
fn_videos,
inputs=[
gr.Textbox("blueberry spaghetti"),
gr.Slider(0, 1000, 553, step=1),
gr.Number(42, label='Seed 1', precision=0),
gr.Textbox("strawberry spaghetti"),
gr.Slider(0, 1000, 234, step=1),
gr.Number(42, label='Seed 2', precision=0),
gr.Dropdown(["klms", "ddim", "default"], value="klms"),
gr.Slider(0.0, 20.0, 8.5),
gr.Slider(1, 200, 50),
Expand All @@ -82,6 +92,7 @@ def fn_videos(
"Folder where outputs will be saved. Each output will be saved in a new folder."
),
),
gr.Checkbox(False),
],
outputs=gr.Video(),
)
Expand All @@ -90,11 +101,12 @@ def fn_videos(
fn_images,
inputs=[
gr.Textbox("blueberry spaghetti"),
gr.Slider(0, 1000, 553, step=1),
gr.Number(42, label='Seed', precision=0),
gr.Dropdown(["klms", "ddim", "default"], value="klms"),
gr.Slider(0.0, 20.0, 8.5),
gr.Slider(1, 200, 50),
gr.Checkbox(False),
gr.Checkbox(False),
],
outputs=gr.Image(type="pil"),
)
Expand Down
12 changes: 12 additions & 0 deletions stable_diffusion_videos/stable_diffusion_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def walk(
use_lerp_for_text=False,
scheduler="klms", # choices: default, ddim, klms
disable_tqdm=False,
upsample=False,
):
"""Generate video frames/a video given a list of prompts and seeds.
Expand All @@ -105,10 +106,17 @@ def walk(
use_lerp_for_text (bool, optional): Use LERP instead of SLERP for text embeddings when walking. Defaults to False.
scheduler (str, optional): Which scheduler to use. Defaults to "klms". Choices are "default", "ddim", "klms".
disable_tqdm (bool, optional): Whether to turn off the tqdm progress bars. Defaults to False.
upsample(bool, optional): If True, uses Real-ESRGAN to upsample images 4x. Requires it to be installed
which you can do by running: `pip install git+https://github.com/xinntao/Real-ESRGAN.git`. Defaults to False.
Returns:
str: Path to video file saved if make_video=True, else None.
"""
if upsample:
from .upsampling import PipelineRealESRGAN

upsampling_pipeline = PipelineRealESRGAN.from_pretrained('nateraw/real-esrgan')

pipeline.set_progress_bar_config(disable=disable_tqdm)

pipeline.scheduler = SCHEDULERS[scheduler]
Expand Down Expand Up @@ -186,8 +194,12 @@ def walk(
guidance_scale=guidance_scale,
eta=eta,
num_inference_steps=num_inference_steps,
output_type='pil' if not upsample else 'numpy'
)["sample"][0]

if upsample:
im = upsampling_pipeline(im)

im.save(output_path / ("frame%06d.jpg" % frame_index))
frame_index += 1

Expand Down
92 changes: 92 additions & 0 deletions stable_diffusion_videos/upsampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from pathlib import Path

import cv2
from PIL import Image
from huggingface_hub import hf_hub_download

try:
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
except ImportError as e:
raise ImportError(
"You tried to import realesrgan without having it installed properly. To install Real-ESRGAN, run:\n\n"
"pip install realesrgan"
)

class PipelineRealESRGAN:
    """Image upsampler that wraps a ``RealESRGANer`` built around an RRDBNet model."""

    def __init__(self, model_path, tile=0, tile_pad=10, pre_pad=0, fp32=False):
        """Build the underlying ``RealESRGANer`` for a 4x RRDBNet checkpoint.

        Args:
            model_path (str): Path to the model weights, forwarded to ``RealESRGANer``.
            tile (int, optional): Forwarded unchanged to ``RealESRGANer``. Defaults to 0.
            tile_pad (int, optional): Forwarded unchanged to ``RealESRGANer``. Defaults to 10.
            pre_pad (int, optional): Forwarded unchanged to ``RealESRGANer``. Defaults to 0.
            fp32 (bool, optional): If True, ``half=False`` is passed to ``RealESRGANer``;
                otherwise ``half=True``. Defaults to False.
        """
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        self.upsampler = RealESRGANer(
            scale=4,
            model_path=model_path,
            model=model,
            tile=tile,
            tile_pad=tile_pad,
            pre_pad=pre_pad,
            half=not fp32,
        )

    @staticmethod
    def _rgb_float_to_bgr_uint8(array):
        """Scale an RGB array to 0-255 ``uint8`` and flip its channels to BGR."""
        if array.dtype.kind == "f":
            # Values outside [0, 1] would otherwise wrap around in the uint8 cast.
            array = array.clip(0, 1)
        array = (array * 255).round().astype("uint8")
        return array[:, :, ::-1]

    @staticmethod
    def _bgr_to_rgb_order(array):
        """Reorder an OpenCV-style array (gray, BGR or BGRA) into PIL channel order."""
        if array.ndim == 2:
            # Single-channel image: there are no colour channels to swap.
            return array
        if array.shape[2] == 4:
            # BGRA -> RGBA: swap only the colour channels, keep alpha last.
            return array[:, :, [2, 1, 0, 3]]
        return array[:, :, ::-1]

    def __call__(self, image, outscale=4, convert_to_pil=True):
        """Upsample an image array or path.

        Args:
            image (Union[np.ndarray, str, Path]): Either a np array or an image path. A np array is
                assumed to be in RGB format with values in [0, 1]; it is scaled to uint8 and
                converted to BGR. A path is read with OpenCV using ``cv2.IMREAD_UNCHANGED``.
            outscale (int, optional): Amount to upscale the image. Defaults to 4.
            convert_to_pil (bool, optional): If True, return PIL image. Otherwise, return the numpy
                array produced by the upsampler (BGR channel order). Defaults to True.

        Returns:
            Union[np.ndarray, PIL.Image.Image]: An upsampled version of the input image.

        Raises:
            OSError: If ``image`` is a path and OpenCV could not read an image from it.
        """
        if isinstance(image, (str, Path)):
            # cv2.imread wants a plain string and signals failure by returning None.
            img = cv2.imread(str(image), cv2.IMREAD_UNCHANGED)
            if img is None:
                raise OSError(f"Could not read an image from {image}")
        else:
            img = self._rgb_float_to_bgr_uint8(image)

        image, _ = self.upsampler.enhance(img, outscale=outscale)

        if convert_to_pil:
            image = Image.fromarray(self._bgr_to_rgb_order(image))

        return image

    @classmethod
    def from_pretrained(cls, model_name_or_path='nateraw/real-esrgan'):
        """Initialize a pretrained Real-ESRGAN upsampler.

        Example:
            ```python
            >>> from stable_diffusion_videos import PipelineRealESRGAN
            >>> pipe = PipelineRealESRGAN.from_pretrained('nateraw/real-esrgan')
            >>> im_out = pipe('input_img.jpg')
            ```

        Args:
            model_name_or_path (str, optional): The Hugging Face repo ID or path to local model.
                Defaults to 'nateraw/real-esrgan'.

        Returns:
            stable_diffusion_videos.PipelineRealESRGAN: An instance of `PipelineRealESRGAN`
                instantiated from pretrained model.
        """
        # reuploaded from official ones mentioned here:
        # https://github.com/xinntao/Real-ESRGAN
        if Path(model_name_or_path).exists():
            # Hand the upsampler a plain string even if a Path object was given.
            file = str(model_name_or_path)
        else:
            file = hf_hub_download(model_name_or_path, 'RealESRGAN_x4plus.pth')
        return cls(file)

    def upsample_imagefolder(self, in_dir, out_dir, suffix='out', outfile_ext='.jpg'):
        """Upsample every .png/.jpg/.jpeg image directly inside ``in_dir`` and save it to ``out_dir``.

        Args:
            in_dir (Union[str, Path]): Directory containing the input images (not searched recursively).
            out_dir (Union[str, Path]): Directory for the results; created if it does not exist.
            suffix (str, optional): Text appended directly to each input file's stem (no separator
                is inserted). Defaults to 'out'.
            outfile_ext (str, optional): Extension, including the dot, of the saved files.
                Defaults to '.jpg'.

        Raises:
            FileNotFoundError: If ``in_dir`` does not exist.
        """
        in_dir, out_dir = Path(in_dir), Path(out_dir)
        if not in_dir.exists():
            raise FileNotFoundError(f"Provided input directory {in_dir} does not exist")

        out_dir.mkdir(exist_ok=True, parents=True)

        saving_jpeg = outfile_ext.lower() in ('.jpg', '.jpeg')
        image_paths = [x for x in in_dir.glob('*') if x.suffix.lower() in ['.png', '.jpg', '.jpeg']]
        for image in image_paths:
            im = self(str(image))
            if saving_jpeg and im.mode == 'RGBA':
                # JPEG has no alpha channel, so PIL refuses to write RGBA images as JPEG.
                im = im.convert('RGB')
            out_filepath = out_dir / (image.stem + suffix + outfile_ext)
            im.save(out_filepath)

0 comments on commit 60f6e12

Please sign in to comment.