diff --git a/stable_diffusion_videos/__init__.py b/stable_diffusion_videos/__init__.py
index ba74814..6814d7e 100644
--- a/stable_diffusion_videos/__init__.py
+++ b/stable_diffusion_videos/__init__.py
@@ -102,6 +102,10 @@ def __dir__():
             "interface",
             "pipeline",
         ],
+        "image_generation": [
+            "generate_images",
+            "upload_folder_chunked",
+        ],
         "stable_diffusion_pipeline": [
             "StableDiffusionWalkPipeline",
             "NoCheck",
@@ -112,4 +116,4 @@ def __dir__():
     },
 )
 
-__version__ = "0.5.0"
+__version__ = "0.5.1dev"
diff --git a/stable_diffusion_videos/image_generation.py b/stable_diffusion_videos/image_generation.py
new file mode 100644
index 0000000..c212cb1
--- /dev/null
+++ b/stable_diffusion_videos/image_generation.py
@@ -0,0 +1,197 @@
+import time
+import random
+import json
+from pathlib import Path
+from huggingface_hub import CommitOperationAdd, create_repo, create_commit
+import torch
+
+from .upsampling import RealESRGANModel
+
+
+def get_all_files(root: Path):
+    """Yield every file under `root`, walking subdirectories iteratively."""
+    dirs = [root]
+    while dirs:
+        folder = dirs.pop()
+        for candidate in folder.iterdir():
+            if candidate.is_file():
+                yield candidate
+            if candidate.is_dir():
+                dirs.append(candidate)
+
+
+def get_groups_of_n(n: int, iterator):
+    """Yield lists of up to `n` consecutive items drawn from `iterator`."""
+    if n < 1:
+        raise ValueError("n must be a positive integer")
+    buffer = []
+    for elt in iterator:
+        if len(buffer) == n:
+            yield buffer
+            buffer = []
+        buffer.append(elt)
+    if buffer:
+        yield buffer
+
+
+def upload_folder_chunked(
+    repo_id: str,
+    upload_dir: Path,
+    n: int = 100,
+    private: bool = False,
+    create_pr: bool = False,
+):
+    """Upload a folder to the Hugging Face Hub in chunks of n files at a time.
+
+    Args:
+        repo_id (str): The repo id to upload to.
+        upload_dir (Path): The directory to upload.
+        n (int, *optional*, defaults to 100): The number of files to upload at a time.
+        private (bool, *optional*): Whether to upload the repo as private.
+        create_pr (bool, *optional*): Whether to create a PR after uploading instead of committing directly.
+    """
+
+    url = create_repo(repo_id, exist_ok=True, private=private, repo_type='dataset')
+    print(f"Uploading files to: {url}")
+
+    root = Path(upload_dir)
+    if not root.exists():
+        raise ValueError(f"Upload directory {root} does not exist.")
+
+    for i, file_paths in enumerate(get_groups_of_n(n, get_all_files(root))):
+        print(f"Committing {file_paths}")
+        operations = [
+            CommitOperationAdd(path_in_repo=f"{file_path.parent.name}/{file_path.name}", path_or_fileobj=str(file_path))
+            for file_path in file_paths
+        ]
+        create_commit(
+            repo_id=repo_id,
+            operations=operations,
+            commit_message=f"Upload part {i}",
+            repo_type='dataset',
+            create_pr=create_pr,
+        )
+
+def generate_input_batches(pipeline, prompts, seeds, batch_size, height, width):
+    """Yield (batch_idx, text-embedding batch, noise batch) tuples of size `batch_size`."""
+    if len(prompts) != len(seeds):
+        raise ValueError("Number of prompts and seeds must be equal.")
+
+    embeds_batch, noise_batch = None, None
+    batch_idx = 0
+    for i, (prompt, seed) in enumerate(zip(prompts, seeds)):
+        embeds = pipeline.embed_text(prompt)
+        noise = torch.randn(
+            (1, pipeline.unet.in_channels, height // 8, width // 8),
+            device=pipeline.device,
+            generator=torch.Generator(device=pipeline.device).manual_seed(seed),
+        )
+        embeds_batch = embeds if embeds_batch is None else torch.cat([embeds_batch, embeds])
+        noise_batch = noise if noise_batch is None else torch.cat([noise_batch, noise])
+        batch_is_ready = embeds_batch.shape[0] == batch_size or i + 1 == len(prompts)
+        if not batch_is_ready:
+            continue
+        yield batch_idx, embeds_batch, noise_batch
+        batch_idx += 1
+        del embeds_batch, noise_batch
+        torch.cuda.empty_cache()
+        embeds_batch, noise_batch = None, None
+
+
+def generate_images(
+    pipeline,
+    prompt,
+    batch_size=1,
+    num_batches=1,
+    num_inference_steps=50,
+    guidance_scale=7.5,
+    output_dir="./images",
+    image_file_ext='.jpg',
+    upsample=False,
+    height=512,
+    width=512,
+    eta=0.0,
+    push_to_hub=False,
+    repo_id=None,
+    private=False,
+    create_pr=False,
+):
+    """Generate images using the StableDiffusion pipeline.
+
+    Args:
+        pipeline (StableDiffusionWalkPipeline): The StableDiffusion pipeline instance.
+        prompt (str): The prompt to use for the image generation.
+        batch_size (int, *optional*, defaults to 1): The batch size to use for image generation.
+        num_batches (int, *optional*, defaults to 1): The number of batches to generate.
+        num_inference_steps (int, *optional*, defaults to 50): The number of inference steps to take.
+        guidance_scale (float, *optional*, defaults to 7.5): The guidance scale to use for image generation.
+        output_dir (str, *optional*, defaults to "./images"): The output directory to save the images to.
+        image_file_ext (str, *optional*, defaults to '.jpg'): The image file extension to use.
+        upsample (bool, *optional*, defaults to False): Whether to upsample the images.
+        height (int, *optional*, defaults to 512): The height of the images to generate.
+        width (int, *optional*, defaults to 512): The width of the images to generate.
+        eta (float, *optional*, defaults to 0.0): The eta parameter to use for image generation.
+        push_to_hub (bool, *optional*, defaults to False): Whether to push the generated images to the Hugging Face Hub.
+        repo_id (str, *optional*): The repo id to push the images to.
+        private (bool, *optional*): Whether to push the repo as private.
+        create_pr (bool, *optional*): Whether to create a PR after pushing instead of committing directly.
+    """
+    if push_to_hub:
+        if repo_id is None:
+            raise ValueError("Must provide repo_id if push_to_hub is True.")
+
+    name = time.strftime("%Y%m%d-%H%M%S")
+    save_path = Path(output_dir) / name
+    save_path.mkdir(exist_ok=False, parents=True)
+    prompt_config_path = save_path / "prompt_config.json"
+
+    num_images = batch_size * num_batches
+    seeds = [random.randrange(9999999) for _ in range(num_images)]
+
+    if upsample:
+        if getattr(pipeline, 'upsampler', None) is None:
+            pipeline.upsampler = RealESRGANModel.from_pretrained('nateraw/real-esrgan')
+        pipeline.upsampler.to(pipeline.device)
+
+    cfg = dict(
+        prompt=prompt,
+        guidance_scale=guidance_scale,
+        eta=eta,
+        num_inference_steps=num_inference_steps,
+        upsample=upsample,
+        height=height,
+        width=width,
+        scheduler=dict(pipeline.scheduler.config),
+        diffusers_version=pipeline.config._diffusers_version,
+        device_name=torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'unknown',
+    )
+    prompt_config_path.write_text(json.dumps(cfg, indent=2, sort_keys=False))
+
+    frame_index = 0
+    for batch_idx, embeds, noise in generate_input_batches(pipeline, [prompt] * num_images, seeds, batch_size, height, width):
+        print(f"Generating batch {batch_idx}")
+
+        with torch.autocast('cuda'):
+            outputs = pipeline(
+                text_embeddings=embeds,
+                latents=noise,
+                num_inference_steps=num_inference_steps,
+                guidance_scale=guidance_scale,
+                eta=eta,
+                height=height,
+                width=width,
+            )['sample']
+            if upsample:
+                images = []
+                for output in outputs:
+                    images.append(pipeline.upsampler(output))
+            else:
+                images = outputs
+
+        for image in images:
+            frame_filepath = save_path / f"{seeds[frame_index]}{image_file_ext}"
+            image.save(frame_filepath)
+            frame_index += 1
+
+    if push_to_hub:
+        upload_folder_chunked(repo_id, save_path, private=private, create_pr=create_pr)