From 19f30e7ca77716138105225f18abf6422b056629 Mon Sep 17 00:00:00 2001 From: glx Date: Mon, 26 Jun 2023 16:39:28 +0800 Subject: [PATCH 1/4] build a download_model.bat script for windows user --- .gitignore | 2 +- scripts/download_model.bat | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) create mode 100644 scripts/download_model.bat diff --git a/.gitignore b/.gitignore index 40fba4c..950aa8b 100644 --- a/.gitignore +++ b/.gitignore @@ -98,7 +98,7 @@ ENV/ [Ll]ib [Ll]ib64 [Ll]ocal -[Ss]cripts +# [Ss]cripts pyvenv.cfg .venv pip-selfcheck.json diff --git a/scripts/download_model.bat b/scripts/download_model.bat new file mode 100644 index 0000000..67101d4 --- /dev/null +++ b/scripts/download_model.bat @@ -0,0 +1,23 @@ +@echo off +mkdir checkpoints +cd checkpoints + +powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://storage.googleapis.com/self-distilled-stylegan/lions_512_pytorch.pkl', 'lions_512_pytorch.pkl')" +ren lions_512_pytorch.pkl stylegan2_lions_512_pytorch.pkl + +powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://storage.googleapis.com/self-distilled-stylegan/dogs_1024_pytorch.pkl', 'dogs_1024_pytorch.pkl')" +ren dogs_1024_pytorch.pkl stylegan2_dogs_1024_pytorch.pkl + +powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://storage.googleapis.com/self-distilled-stylegan/horses_256_pytorch.pkl', 'horses_256_pytorch.pkl')" +ren horses_256_pytorch.pkl stylegan2_horses_256_pytorch.pkl + +powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://storage.googleapis.com/self-distilled-stylegan/elephants_512_pytorch.pkl', 'elephants_512_pytorch.pkl')" +ren elephants_512_pytorch.pkl stylegan2_elephants_512_pytorch.pkl + +powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-512x512.pkl', 'stylegan2-ffhq-512x512.pkl')" +powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqcat-512x512.pkl', 'stylegan2-afhqcat-512x512.pkl')" +powershell -Command "(New-Object System.Net.WebClient).DownloadFile('http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-car-config-f.pkl', 'stylegan2-car-config-f.pkl')" +powershell -Command "(New-Object System.Net.WebClient).DownloadFile('http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-cat-config-f.pkl', 'stylegan2-cat-config-f.pkl')" + +echo "Done" +pause From 133776b16303865431146abaef637596d0377979 Mon Sep 17 00:00:00 2001 From: glx Date: Mon, 26 Jun 2023 16:43:59 +0800 Subject: [PATCH 2/4] modified .gitignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 950aa8b..24b5204 100644 --- a/.gitignore +++ b/.gitignore @@ -98,7 +98,8 @@ ENV/ [Ll]ib [Ll]ib64 [Ll]ocal -# [Ss]cripts +[Ss]cripts +!scripts\download_model.bat pyvenv.cfg .venv pip-selfcheck.json From 2eb2d29153e31179cae0e542300842ae30a92028 Mon Sep 17 00:00:00 2001 From: glx Date: Mon, 26 Jun 2023 16:50:05 +0800 Subject: [PATCH 3/4] fix init pkl bug --- visualizer_drag_gradio.py | 506 ++++++++++++++++++++------------------ 1 file changed, 261 insertions(+), 245 deletions(-) diff --git a/visualizer_drag_gradio.py b/visualizer_drag_gradio.py index 699787d..023ce16 100644 --- a/visualizer_drag_gradio.py +++ b/visualizer_drag_gradio.py @@ -9,19 +9,24 @@ from PIL import Image import dnnlib -from gradio_utils import (ImageMask, draw_mask_on_image, draw_points_on_image, - get_latest_points_pair, get_valid_mask, - on_change_single_global_state) +from gradio_utils import ( + ImageMask, + draw_mask_on_image, + draw_points_on_image, + get_latest_points_pair, + get_valid_mask, + on_change_single_global_state, +) from viz.renderer import Renderer, add_watermark_np parser = ArgumentParser() -parser.add_argument('--share', action='store_true') -parser.add_argument('--cache-dir', type=str, default='./checkpoints') +parser.add_argument("--share", action="store_true") +parser.add_argument("--cache-dir", type=str, default="./checkpoints") args = parser.parse_args() cache_dir = args.cache_dir -device = 'cuda' +device = "cuda" def reverse_point_pairs(points): @@ -38,17 +43,18 @@ def clear_state(global_state, target=None): 2. set global_state['mask'] as full-one mask. """ if target is None: - target = ['point', 'mask'] + target = ["point", "mask"] if not isinstance(target, list): target = [target] - if 'point' in target: - global_state['points'] = dict() - print('Clear Points State!') - if 'mask' in target: + if "point" in target: + global_state["points"] = dict() + print("Clear Points State!") + if "mask" in target: image_raw = global_state["images"]["image_raw"] - global_state['mask'] = np.ones((image_raw.size[1], image_raw.size[0]), - dtype=np.uint8) - print('Clear mask State!') + global_state["mask"] = np.ones( + (image_raw.size[1], image_raw.size[0]), dtype=np.uint8 + ) + print("Clear mask State!") return global_state @@ -67,43 +73,46 @@ def init_images(global_state): else: state = global_state - state['renderer'].init_network( - state['generator_params'], # res - valid_checkpoints_dict[state['pretrained_weight']], # pkl - state['params']['seed'], # w0_seed, + state["renderer"].init_network( + state["generator_params"], # res + valid_checkpoints_dict[state["pretrained_weight"]], # pkl + state["params"]["seed"], # w0_seed, None, # w_load - state['params']['latent_space'] == 'w+', # w_plus - 'const', - state['params']['trunc_psi'], # trunc_psi, - state['params']['trunc_cutoff'], # trunc_cutoff, + state["params"]["latent_space"] == "w+", # w_plus + "const", + state["params"]["trunc_psi"], # trunc_psi, + state["params"]["trunc_cutoff"], # trunc_cutoff, None, # input_transform - state['params']['lr'] # lr, + state["params"]["lr"], # lr, + ) + + state["renderer"]._render_drag_impl( + state["generator_params"], is_drag=False, to_pil=True ) - state['renderer']._render_drag_impl(state['generator_params'], - is_drag=False, - to_pil=True) - - init_image = state['generator_params'].image - state['images']['image_orig'] = init_image - state['images']['image_raw'] = init_image - state['images']['image_show'] = Image.fromarray( - add_watermark_np(np.array(init_image))) - state['mask'] = np.ones((init_image.size[1], init_image.size[0]), - dtype=np.uint8) + init_image = state["generator_params"].image + state["images"]["image_orig"] = init_image + state["images"]["image_raw"] = init_image + state["images"]["image_show"] = Image.fromarray( + add_watermark_np(np.array(init_image)) + ) + state["mask"] = np.ones((init_image.size[1], init_image.size[0]), dtype=np.uint8) return global_state def update_image_draw(image, points, mask, show_mask, global_state=None): - image_draw = draw_points_on_image(image, points) - if show_mask and mask is not None and not (mask == 0).all() and not ( - mask == 1).all(): + if ( + show_mask + and mask is not None + and not (mask == 0).all() + and not (mask == 1).all() + ): image_draw = draw_mask_on_image(image_draw, mask) image_draw = Image.fromarray(add_watermark_np(np.array(image_draw))) if global_state is not None: - global_state['images']['image_show'] = image_draw + global_state["images"]["image_show"] = image_draw return image_draw @@ -115,102 +124,97 @@ def preprocess_mask_info(global_state, image): 2.2 global_state is add_mask: """ if isinstance(image, dict): - last_mask = get_valid_mask(image['mask']) + last_mask = get_valid_mask(image["mask"]) else: last_mask = None - mask = global_state['mask'] + mask = global_state["mask"] # mask in global state is a placeholder with all 1. if (mask == 1).all(): mask = last_mask # last_mask = global_state['last_mask'] - editing_mode = global_state['editing_state'] + editing_mode = global_state["editing_state"] if last_mask is None: return global_state - if editing_mode == 'remove_mask': + if editing_mode == "remove_mask": updated_mask = np.clip(mask - last_mask, 0, 1) - print(f'Last editing_state is {editing_mode}, do remove.') - elif editing_mode == 'add_mask': + print(f"Last editing_state is {editing_mode}, do remove.") + elif editing_mode == "add_mask": updated_mask = np.clip(mask + last_mask, 0, 1) - print(f'Last editing_state is {editing_mode}, do add.') + print(f"Last editing_state is {editing_mode}, do add.") else: updated_mask = mask - print(f'Last editing_state is {editing_mode}, ' - 'do nothing to mask.') + print(f"Last editing_state is {editing_mode}, " "do nothing to mask.") - global_state['mask'] = updated_mask + global_state["mask"] = updated_mask # global_state['last_mask'] = None # clear buffer return global_state valid_checkpoints_dict = { - f.split('/')[-1].split('.')[0]: osp.join(cache_dir, f) + f.split("/")[-1].split(".")[0]: osp.join(cache_dir, f) for f in os.listdir(cache_dir) - if (f.endswith('pkl') and osp.exists(osp.join(cache_dir, f))) + if (f.endswith("pkl") and osp.exists(osp.join(cache_dir, f))) } -print(f'File under cache_dir ({cache_dir}):') +print(f"File under cache_dir ({cache_dir}):") print(os.listdir(cache_dir)) -print('Valid checkpoint file:') +print("Valid checkpoint file:") print(valid_checkpoints_dict) -init_pkl = 'stylegan_human_v2_512' +init_pkl = list(valid_checkpoints_dict.keys())[4] with gr.Blocks() as app: - # renderer = Renderer() - global_state = gr.State({ - "images": { - # image_orig: the original image, change with seed/model is changed - # image_raw: image with mask and points, change durning optimization - # image_show: image showed on screen - }, - "temporal_params": { - # stop - }, - 'mask': - None, # mask for visualization, 1 for editing and 0 for unchange - 'last_mask': None, # last edited mask - 'show_mask': True, # add button - "generator_params": dnnlib.EasyDict(), - "params": { - "seed": 0, - "motion_lambda": 20, - "r1_in_pixels": 3, - "r2_in_pixels": 12, - "magnitude_direction_in_pixels": 1.0, - "latent_space": "w+", - "trunc_psi": 0.7, - "trunc_cutoff": None, - "lr": 0.001, - }, - "device": device, - "draw_interval": 1, - "renderer": Renderer(disable_timing=True), - "points": {}, - "curr_point": None, - "curr_type_point": "start", - 'editing_state': 'add_points', - 'pretrained_weight': init_pkl - }) + global_state = gr.State( + { + "images": { + # image_orig: the original image, change with seed/model is changed + # image_raw: image with mask and points, change durning optimization + # image_show: image showed on screen + }, + "temporal_params": { + # stop + }, + "mask": None, # mask for visualization, 1 for editing and 0 for unchange + "last_mask": None, # last edited mask + "show_mask": True, # add button + "generator_params": dnnlib.EasyDict(), + "params": { + "seed": 0, + "motion_lambda": 20, + "r1_in_pixels": 3, + "r2_in_pixels": 12, + "magnitude_direction_in_pixels": 1.0, + "latent_space": "w+", + "trunc_psi": 0.7, + "trunc_cutoff": None, + "lr": 0.001, + }, + "device": device, + "draw_interval": 1, + "renderer": Renderer(disable_timing=True), + "points": {}, + "curr_point": None, + "curr_type_point": "start", + "editing_state": "add_points", + "pretrained_weight": init_pkl, + } + ) # init image global_state = init_images(global_state) with gr.Row(): - with gr.Row(): - # Left --> tools with gr.Column(scale=3): - # Pickle with gr.Row(): - with gr.Column(scale=1, min_width=10): - gr.Markdown(value='Pickle', show_label=False) + gr.Markdown(value="Pickle", show_label=False) with gr.Column(scale=4, min_width=10): form_pretrained_dropdown = gr.Dropdown( @@ -222,71 +226,71 @@ def preprocess_mask_info(global_state, image): # Latent with gr.Row(): with gr.Column(scale=1, min_width=10): - gr.Markdown(value='Latent', show_label=False) + gr.Markdown(value="Latent", show_label=False) with gr.Column(scale=4, min_width=10): form_seed_number = gr.Number( - value=global_state.value['params']['seed'], + value=global_state.value["params"]["seed"], interactive=True, label="Seed", ) form_lr_number = gr.Number( value=global_state.value["params"]["lr"], interactive=True, - label="Step Size") + label="Step Size", + ) with gr.Row(): with gr.Column(scale=2, min_width=10): form_reset_image = gr.Button("Reset Image") with gr.Column(scale=3, min_width=10): form_latent_space = gr.Radio( - ['w', 'w+'], - value=global_state.value['params'] - ['latent_space'], + ["w", "w+"], + value=global_state.value["params"]["latent_space"], interactive=True, - label='Latent space to optimize', + label="Latent space to optimize", show_label=False, ) # Drag with gr.Row(): with gr.Column(scale=1, min_width=10): - gr.Markdown(value='Drag', show_label=False) + gr.Markdown(value="Drag", show_label=False) with gr.Column(scale=4, min_width=10): with gr.Row(): with gr.Column(scale=1, min_width=10): - enable_add_points = gr.Button('Add Points') + enable_add_points = gr.Button("Add Points") with gr.Column(scale=1, min_width=10): - undo_points = gr.Button('Reset Points') + undo_points = gr.Button("Reset Points") with gr.Row(): with gr.Column(scale=1, min_width=10): form_start_btn = gr.Button("Start") with gr.Column(scale=1, min_width=10): form_stop_btn = gr.Button("Stop") - form_steps_number = gr.Number(value=0, - label="Steps", - interactive=False) + form_steps_number = gr.Number( + value=0, label="Steps", interactive=False + ) # Mask with gr.Row(): with gr.Column(scale=1, min_width=10): - gr.Markdown(value='Mask', show_label=False) + gr.Markdown(value="Mask", show_label=False) with gr.Column(scale=4, min_width=10): - enable_add_mask = gr.Button('Edit Flexible Area') + enable_add_mask = gr.Button("Edit Flexible Area") with gr.Row(): with gr.Column(scale=1, min_width=10): form_reset_mask_btn = gr.Button("Reset mask") with gr.Column(scale=1, min_width=10): show_mask = gr.Checkbox( - label='Show Mask', - value=global_state.value['show_mask'], - show_label=False) + label="Show Mask", + value=global_state.value["show_mask"], + show_label=False, + ) with gr.Row(): form_lambda_number = gr.Number( - value=global_state.value["params"] - ["motion_lambda"], + value=global_state.value["params"]["motion_lambda"], interactive=True, label="Lambda", ) @@ -295,16 +299,18 @@ def preprocess_mask_info(global_state, image): value=global_state.value["draw_interval"], label="Draw Interval (steps)", interactive=True, - visible=False) + visible=False, + ) # Right --> Image with gr.Column(scale=8): form_image = ImageMask( - value=global_state.value['images']['image_show'], - brush_radius=20).style( - width=768, - height=768) # NOTE: hard image size code here. - gr.Markdown(""" + value=global_state.value["images"]["image_show"], brush_radius=20 + ).style( + width=768, height=768 + ) # NOTE: hard image size code here. + gr.Markdown( + """ ## Quick Start 1. Select desired `Pretrained Model` and adjust `Seed` to generate an @@ -323,8 +329,10 @@ def preprocess_mask_info(global_state, image): mask (this has the same effect as `Reset Image` button). 3. Click `Edit Flexible Area` to create a mask and constrain the unmasked region to remain unchanged. - """) - gr.HTML(""" + """ + ) + gr.HTML( + """