Add CLI script #153
base: main
Changes from all commits
b25eb57
3342302
961a8db
21542ae
6c261f7
ce34e15
@@ -0,0 +1,154 @@
import argparse
import random

import torch
import yaml

from diffusers import DPMSolverMultistepScheduler
from stable_diffusion_videos import StableDiffusionWalkPipeline


def init_arg_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument('--checkpoint_id',
                        default="stabilityai/stable-diffusion-2-1",
                        help="checkpoint id on huggingface")
    parser.add_argument('--prompts', nargs='+',
                        help='sequence of prompts')
    parser.add_argument('--seeds', type=int, nargs='+',
                        help='seed for each prompt')
    parser.add_argument('--num_interpolation_steps', type=int, nargs='+',
                        help='number of steps between each image')
    parser.add_argument('--output_dir', default="dreams",
                        help='output directory')
    parser.add_argument('--name',
                        help='output sub-directory')
    parser.add_argument('--fps', type=int, default=10,
                        help='frames per second')
    parser.add_argument('--guidance_scale', type=float, default=7.5,
                        help='diffusion guidance scale')
    parser.add_argument('--num_inference_steps', type=int, default=50,
                        help='number of diffusion inference steps')
    parser.add_argument('--height', type=int, default=512,
                        help='output image height')
    parser.add_argument('--width', type=int, default=512,
                        help='output image width')
    parser.add_argument('--upsample', action='store_true',
                        help='upscale x4 using Real-ESRGAN')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='batch size')
    parser.add_argument('--audio_filepath',
                        help='path to audio file')
    parser.add_argument('--audio_offsets', type=int, nargs='+',
                        help='audio offset for each prompt')
    parser.add_argument('--negative_prompt',
                        help='negative prompt (one for all images)')

    parser.add_argument('--cfg',
                        help='yaml config file (overwrites other options)')

    return parser


def parse_args(parser):
    args = parser.parse_args()

    # read config file
    if args.cfg is not None:
        with open(args.cfg) as f:
            cfg = yaml.safe_load(f)
        for key, val in cfg.items():
            if hasattr(args, key):
                setattr(args, key, val)
            else:
                raise ValueError(f'bad field in config file: {key}')

    # check for prompts
    if args.prompts is None:
        raise ValueError('no prompt provided')
    if args.seeds is None:
        args.seeds = [random.getrandbits(16) for _ in args.prompts]
Comment: I've been using randint instead in this scenario, kinda like this though :)
Comment: I think using multiple methods for random numbers seems like a good idea the more I think about it.
Comment: wdym @Atomic-Germ ?
Comment: I rescind my comment, it was a little half-baked.
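A minimal sketch of the `randint` alternative floated above (the prompt list is illustrative, not from the PR); both approaches draw one independent seed per prompt:

```python
import random

prompts = ['a cat', 'a dog', 'a bird']  # illustrative prompts

# random.randint(0, 2**16 - 1) covers the same range as random.getrandbits(16)
seeds = [random.randint(0, 2**16 - 1) for _ in prompts]
```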
    # check audio arguments
    if args.audio_filepath is not None and args.audio_offsets is None:
        raise ValueError('must provide audio_offsets when providing '
                         'audio_filepath')
    if args.audio_offsets is not None and args.audio_filepath is None:
        raise ValueError('must provide audio_filepath when providing '
                         'audio_offsets')

Comment: Makes me wonder if this should just be raised in the pipeline code itself instead of the parser (if it's not already). The same goes for many of the other raised errors in this script.
Comment: Maybe. That's a design question IMO. Do we want to raise errors to the unadvised CLI user as early as possible, while trusting that developers who write their own scripts know what they are doing? Or do we want to raise errors as close to the problematic code / as late as possible, but such that they propagate?
Comment: Agreed. I'm fine with the way you did it here :)
Comment: Reason I say it though is that …
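One way to reconcile the two positions in the thread above would be to factor the check into a helper that both the parser (fail fast for CLI users) and the pipeline (guard library users) could call. A sketch, with a hypothetical helper name:

```python
def validate_audio_args(audio_filepath, audio_offsets):
    """Raise if exactly one of the two audio arguments is provided.

    Hypothetical helper, not part of the PR: the same check could be
    invoked from both the CLI parser and the pipeline entry point.
    """
    if audio_filepath is not None and audio_offsets is None:
        raise ValueError('must provide audio_offsets when providing audio_filepath')
    if audio_offsets is not None and audio_filepath is None:
        raise ValueError('must provide audio_filepath when providing audio_offsets')


# both present or both absent is fine
validate_audio_args('song.mp3', [0, 5, 10])
validate_audio_args(None, None)
```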
    # check lengths
    if args.audio_offsets is not None:
        if not len(args.prompts) == len(args.seeds) == len(args.audio_offsets):
            raise ValueError('prompts, seeds and audio_offsets must have same '
                             f'length, got lengths {len(args.prompts)}, '
                             f'{len(args.seeds)} and '
                             f'{len(args.audio_offsets)} respectively')
    else:
        if not len(args.prompts) == len(args.seeds):
            raise ValueError('prompts and seeds must have same length, got '
                             f'lengths {len(args.prompts)} and '
                             f'{len(args.seeds)} respectively')

    # set num_interpolation_steps
    if args.audio_offsets is not None \
            and args.num_interpolation_steps is not None:
        raise ValueError('cannot provide both audio_offsets and '
                         'num_interpolation_steps')
    elif args.audio_offsets is not None:
        args.num_interpolation_steps = [
            (b - a) * args.fps for a, b in zip(
                args.audio_offsets, args.audio_offsets[1:]
            )
        ]
    elif args.num_interpolation_steps is None:
        args.num_interpolation_steps = args.fps * 10  # 10 second video
    elif len(args.num_interpolation_steps) != len(args.prompts) - 1:
        # note: checking length here (rather than in an else-default branch)
        # avoids silently overwriting a valid user-provided list
        raise ValueError('num_interpolation_steps must have length '
                         f'len(prompts)-1, got '
                         f'{len(args.num_interpolation_steps)} != '
                         f'{len(args.prompts)-1}')

    return args


def main():
    parser = init_arg_parser()
    args = parse_args(parser)

    pipe = StableDiffusionWalkPipeline.from_pretrained(
        args.checkpoint_id,
        torch_dtype=torch.float16,
Comment: I think hardcoding dtype here is also a no-no, I'm afraid. Let's think of a nicer way to infer this.
Comment: Has to support MPS/GPU/TPU.
Comment: On second thought, no TPU, as you'd have to use the other pipeline.
Comment: device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.backends.cuda.is_available() else "cpu", then use to(device) in place of to("cuda") and torch_dtype=torch_dtype.
Comment: Yep, will change that.
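A runnable sketch of the device/dtype inference suggested in the thread above. Note the comment uses `torch.backends.cuda.is_available()`; the standard call is `torch.cuda.is_available()`, which this sketch uses instead. The helper name is hypothetical.

```python
import torch


def pick_device_and_dtype():
    """Infer device and dtype instead of hardcoding cuda/float16 (sketch)."""
    mps = getattr(torch.backends, "mps", None)  # absent on older torch builds
    if torch.cuda.is_available():
        device = "cuda"
    elif mps is not None and mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    # half-precision weights are only worthwhile on an accelerator
    dtype = torch.float16 if device != "cpu" else torch.float32
    return device, dtype


device, dtype = pick_device_and_dtype()
```

In `main()` this would replace the hardcoded values: `torch_dtype=dtype` in `from_pretrained` and `.to(device)` in place of `.to("cuda")`.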
        revision="fp16",
Comment: I think guidance in diffusers these days is erring towards not specifying a revision. Need to check if that only applies to the newest versions, etc. Definitely hardcoding here is a no-no though.
Comment: Sure, will add this as an option.
        feature_extractor=None,
        safety_checker=None,
    ).to("cuda")
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        pipe.scheduler.config
    )

Comment: Hardcoding this is likely a bad idea too.
Comment: Oops, that slipped my mind, these should be options too.
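A sketch of exposing the scheduler as a CLI option instead of hardcoding `DPMSolverMultistepScheduler`. The flag name and mapping are hypothetical; the class names listed are real diffusers schedulers:

```python
import argparse

# hypothetical choice -> diffusers class-name mapping
SCHEDULER_CLASSES = {
    'dpm_multistep': 'DPMSolverMultistepScheduler',
    'ddim': 'DDIMScheduler',
    'pndm': 'PNDMScheduler',
}

parser = argparse.ArgumentParser()
parser.add_argument('--scheduler', default='dpm_multistep',
                    choices=sorted(SCHEDULER_CLASSES),
                    help='diffusion noise scheduler')

args = parser.parse_args(['--scheduler', 'ddim'])  # example invocation
```

`main()` could then resolve the class, e.g. `getattr(diffusers, SCHEDULER_CLASSES[args.scheduler]).from_config(pipe.scheduler.config)`.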
    pipe.walk(
        prompts=args.prompts,
        seeds=args.seeds,
        num_interpolation_steps=args.num_interpolation_steps,
        output_dir=args.output_dir,
        name=args.name,
        fps=args.fps,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        height=args.height,
        width=args.width,
        upsample=args.upsample,
        batch_size=args.batch_size,
        audio_filepath=args.audio_filepath,
        audio_start_sec=None if args.audio_offsets is None else args.audio_offsets[0],
        negative_prompt=args.negative_prompt,
    )


if __name__ == '__main__':
    main()
Comment: Since we're already installing fire with the requirements of the package, maybe let's just use that instead? I can update to do this so it's not a hassle for you :)
Comment: I am not too familiar with fire, but I can give it a try. Though after quickly skimming the docs, while it would considerably reduce boilerplate, I think I prefer the flexibility of argparse. E.g. I prefer calling the script with explicit flags over fire's calling convention. Moreover, I can feel some dirty hacking would be required to keep support for argument provision through a config file using the --cfg option, which is an important feature IMO. Let me know what you think. If this is something you really require then I will give it a shot.
Comment: Interesting. I think I agree with you! Will have a look when I can.
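For reference, the explicit-flags invocation style the thread is weighing can be sketched like this (prompt and seed values are illustrative):

```python
import argparse

# the argparse version accepts repeated values via nargs='+' flags,
# which is the invocation style preferred in the comment above
parser = argparse.ArgumentParser()
parser.add_argument('--prompts', nargs='+')
parser.add_argument('--seeds', type=int, nargs='+')

# equivalent to: script.py --prompts "a cat" "a dog" --seeds 42 1337
args = parser.parse_args(['--prompts', 'a cat', 'a dog', '--seeds', '42', '1337'])
```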