diff --git a/requirements.txt b/requirements.txt index e22303037..e96c5e652 100644 --- a/requirements.txt +++ b/requirements.txt @@ -35,6 +35,7 @@ s3tokenizer conformer==0.3.2 spacy_pkuseg spacy==3.8.4 +gradio_rangeslider # Vision & segmentation opencv-python>=4.12.0.88 diff --git a/shared/gradio/ui_styles.css b/shared/gradio/ui_styles.css index f55edf64e..ec3d90cdc 100644 --- a/shared/gradio/ui_styles.css +++ b/shared/gradio/ui_styles.css @@ -364,3 +364,14 @@ user-select: text; overflow: visible !important; } .tabitem {padding-top:0px} +.rule-row { margin-bottom: 8px; align-items: center !important; display: flex; gap: 8px; } +.rule-card { background-color: var(--background-fill-secondary); padding: 8px 12px; border-radius: 6px; border: 1px solid var(--border-color-primary); flex-grow: 1; margin-bottom: 0 !important; } +.rule-card p { margin-bottom: 0 !important; } +.delete-btn { + min-width: 42px !important; + max-width: 42px !important; + height: 42px !important; + padding: 0 !important; + align-self: center; +} +#refiner-input-row { align-items: center; } \ No newline at end of file diff --git a/shared/utils/self_refiner.py b/shared/utils/self_refiner.py index 17ea4b489..9d64705e0 100644 --- a/shared/utils/self_refiner.py +++ b/shared/utils/self_refiner.py @@ -1,5 +1,6 @@ import torch import copy +import uuid from diffusers.utils.torch_utils import randn_tensor def is_int_string(s: str) -> bool: @@ -9,59 +10,49 @@ def is_int_string(s: str) -> bool: except ValueError: return False -def _normalize_single_self_refiner_plan(plan_str): - entries = [] - for chunk in plan_str.split(","): - chunk = chunk.strip() - if not chunk: - continue - if ":" not in chunk: - return "", "Self-refine plan entries must be in 'start-end:steps' format." - range_part, steps_part = chunk.split(":", 1) - range_part = range_part.strip() - steps_part = steps_part.strip() - if not steps_part: - return "", "Self-refine plan entries must include a step count." - if "-" in range_part: - start_s, end_s = range_part.split("-", 1) - else: - start_s = end_s = range_part - start_s = start_s.strip() - if not is_int_string(start_s): - return "", "Self-refine plan start position must be an integer." - end_s = end_s.strip() - if not is_int_string(end_s): - return "", "Self-refine plan end position must be an integer." - if not is_int_string(steps_part): - return "", "Self-refine plan steps part must be an integer." - - entries.append({ - "start": int(start_s), - "end": int(end_s), - "steps": int(steps_part), - }) - plan = entries - return plan, "" - - -def normalize_self_refiner_plan(plan_str, max_plans: int = 1): - if plan_str is None: - plan_str = "" - if max_plans is None or max_plans < 1: - max_plans = 1 - segments = [seg.strip() for seg in str(plan_str).split(";")] - if len(segments) > max_plans: - return [], f"Self-refiner supports up to {max_plans} plan(s); remove extra ';' separators." - plans = [] - for seg in segments: - if not seg: - plans.append([]) - continue - plan, error = _normalize_single_self_refiner_plan(seg) - if error: - return [], error - plans.append(plan) - return plans, "" +def normalize_self_refiner_plan(plan_input, max_plans: int = 1): + default_plan = [ + {"start": 1, "end": 5, "steps": 3}, + {"start": 6, "end": 13, "steps": 1}, + ] + if len(plan_input) > max_plans: + return [], f"Self-refiner supports up to {max_plans} plan(s); found {len(plan_input)}." + if not plan_input or not isinstance(plan_input, list): + return [default_plan], "" + + return [plan_input], "" + +def ensure_refiner_list(plan_data): + if not isinstance(plan_data, list): + return [] + for rule in plan_data: + if "id" not in rule: + rule["id"] = str(uuid.uuid4()) + return plan_data + +def add_refiner_rule(current_rules, range_val, steps_val): + new_start, new_end = int(range_val[0]), int(range_val[1]) + + if new_start >= new_end: + from gradio import Error + raise Error(f"Start step ({new_start}) must be smaller than End step ({new_end}).") + + for rule in current_rules: + if new_start <= rule['end'] and new_end >= rule['start']: + from gradio import Error + raise Error(f"Overlap detected! Steps {new_start}-{new_end} conflict with existing rule {rule['start']}-{rule['end']}.") + + new_rule = { + "id": str(uuid.uuid4()), + "start": new_start, + "end": new_end, + "steps": int(steps_val) + } + updated_list = current_rules + [new_rule] + return sorted(updated_list, key=lambda x: x['start']) + +def remove_refiner_rule(current_rules, rule_id): + return [r for r in current_rules if r["id"] != rule_id] class PnPHandler: def __init__(self, stochastic_plan, ths_uncertainty=0.0, p_norm=1, certain_percentage=0.999, channel_dim: int = 1): @@ -309,20 +300,8 @@ def restore_func(saved_state): return latents, sample_scheduler def create_self_refiner_handler(pnp_plan, pnp_f_uncertainty, pnp_p_norm, pnp_certain_percentage, channel_dim: int = 1): - stochastic_plan = None - if isinstance(pnp_plan, list): - stochastic_plan = pnp_plan - elif len(pnp_plan): - plans, _ = normalize_self_refiner_plan(pnp_plan, max_plans=1) - if plans: - stochastic_plan = plans[0] - - if not stochastic_plan: - # Default plan from paper/code - stochastic_plan = [ - {"start": 1, "end": 5, "steps": 3}, - {"start": 6, "end": 13, "steps": 1}, - ] + plans, _ = normalize_self_refiner_plan(pnp_plan, max_plans=1) + stochastic_plan = plans[0] return PnPHandler( stochastic_plan, diff --git a/wgp.py b/wgp.py index 115c9aa96..179ac5a27 100644 --- a/wgp.py +++ b/wgp.py @@ -51,6 +51,7 @@ from huggingface_hub import hf_hub_download, snapshot_download from shared.utils import files_locator as fl from shared.gradio.audio_gallery import AudioGallery +from shared.utils.self_refiner import normalize_self_refiner_plan, ensure_refiner_list, add_refiner_rule, remove_refiner_rule import torch import gc import traceback @@ -68,6 +69,7 @@ import glob import cv2 import html +from gradio_rangeslider import RangeSlider import re from transformers.utils import logging logging.set_verbosity_error @@ -765,13 +767,15 @@ def ret(): model_mode = None if server_config.get("fit_canvas", 0) == 2 and outpainting_dims is not None and any_letters(video_prompt_type, "VKF"): gr.Info("Output Resolution Cropping will be not used for this Generation as it is not compatible with Video Outpainting") - if self_refiner_setting != 0 and len(self_refiner_plan): - from shared.utils.self_refiner import normalize_self_refiner_plan - max_plans = model_def.get("self_refiner_max_plans", 1) - _, error = normalize_self_refiner_plan(self_refiner_plan, max_plans=max_plans) - if len(error): - gr.Info(error) - return ret() + if self_refiner_setting != 0: + if isinstance(self_refiner_plan, list): + max_plans = model_def.get("self_refiner_max_plans", 1) + _, error = normalize_self_refiner_plan(self_refiner_plan, max_plans=max_plans) + if len(error): + gr.Info(error) + return ret() + else: + self_refiner_plan = [] if not model_def.get("motion_amplitude", False): motion_amplitude = 1. if "vae" in spatial_upsampling: @@ -10394,8 +10398,38 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai with gr.Column(visible = model_def.get("self_refiner", False)) as self_refiner_col: gr.Markdown("Self-Refining Video Sampling (PnP) - should improve quality of Motion") - self_refiner_setting = gr.Dropdown( choices=[("Disabled", 0),("Enabled with P1-Norm", 1), ("Enabled with P2-Norm", 2), ], value=ui_get("self_refiner_setting", 0), scale = 1, label="Self Refiner", ) - self_refiner_plan = gr.Textbox( value=ui_get("self_refiner_plan", ""), label="P&P Plan (start-end:steps, comma-separated)", lines=1, placeholder="2-5:3,6-13:1" ) + self_refiner_setting = gr.Dropdown(choices=[("Disabled", 0),("Enabled with P1-Norm", 1), ("Enabled with P2-Norm", 2)], value=ui_get("self_refiner_setting", 0), scale=1, label="Self Refiner") + + refiner_val = ensure_refiner_list(ui_get("self_refiner_plan", [])) + self_refiner_plan = refiner_val if update_form else gr.State(value=refiner_val) + + with gr.Group(visible=(update_form and ui_get("self_refiner_setting", 0) > 0)) as self_refiner_rules_ui: + gr.Markdown("### Refiner Plan") + + with gr.Row(elem_id="refiner-input-row"): + refiner_range = RangeSlider(minimum=0, maximum=100, value=(0, 10), step=1, label="Step Range", info="Start - End", scale=3) + refiner_mult = gr.Slider(label="Iterations", value=3, minimum=1, maximum=5, step=1, scale=2) + refiner_add_btn = gr.Button("➕ Add", variant="primary", scale=0, min_width=100) + + if not update_form: + refiner_add_btn.click(fn=add_refiner_rule, inputs=[self_refiner_plan, refiner_range, refiner_mult], outputs=[self_refiner_plan]) + self_refiner_setting.change(fn=lambda s: gr.update(visible=s > 0), inputs=[self_refiner_setting], outputs=[self_refiner_rules_ui]) + + @gr.render(inputs=self_refiner_plan) + def render_refiner_plans(plans): + if not plans: + gr.Markdown("No plans defined. Using defaults: Steps 2-5 (3x), Steps 6-13 (1x).") + return + for plan in plans: + with gr.Row(elem_classes="rule-row"): + text_display = f"Steps **{plan['start']} - {plan['end']}** : **{plan['steps']}x** iterations" + gr.Markdown(text_display, elem_classes="rule-card") + gr.Button("✖", variant="stop", scale=0, elem_classes="delete-btn").click( + fn=remove_refiner_rule, + inputs=[self_refiner_plan, gr.State(plan["id"])], + outputs=[self_refiner_plan] + ) + with gr.Row(): self_refiner_f_uncertainty = gr.Slider(0.0, 1.0, value=ui_get("self_refiner_f_uncertainty", 0.1), step=0.01, label="Uncertainty Threshold", show_reset_button= False) self_refiner_certain_percentage = gr.Slider(0.0, 1.0, value=ui_get("self_refiner_certain_percentage", 0.999), step=0.001, label="Certainty Percentage Skip", show_reset_button= False)