diff --git a/requirements.txt b/requirements.txt
index e22303037..e96c5e652 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -35,6 +35,7 @@ s3tokenizer
conformer==0.3.2
spacy_pkuseg
spacy==3.8.4
+gradio_rangeslider
# Vision & segmentation
opencv-python>=4.12.0.88
diff --git a/shared/gradio/ui_styles.css b/shared/gradio/ui_styles.css
index f55edf64e..ec3d90cdc 100644
--- a/shared/gradio/ui_styles.css
+++ b/shared/gradio/ui_styles.css
@@ -364,3 +364,14 @@ user-select: text;
overflow: visible !important;
}
.tabitem {padding-top:0px}
+.rule-row { margin-bottom: 8px; align-items: center !important; display: flex; gap: 8px; }
+.rule-card { background-color: var(--background-fill-secondary); padding: 8px 12px; border-radius: 6px; border: 1px solid var(--border-color-primary); flex-grow: 1; margin-bottom: 0 !important; }
+.rule-card p { margin-bottom: 0 !important; }
+.delete-btn {
+ min-width: 42px !important;
+ max-width: 42px !important;
+ height: 42px !important;
+ padding: 0 !important;
+ align-self: center;
+}
+#refiner-input-row { align-items: center; }
\ No newline at end of file
diff --git a/shared/utils/self_refiner.py b/shared/utils/self_refiner.py
index 17ea4b489..9d64705e0 100644
--- a/shared/utils/self_refiner.py
+++ b/shared/utils/self_refiner.py
@@ -1,5 +1,6 @@
import torch
import copy
+import uuid
from diffusers.utils.torch_utils import randn_tensor
def is_int_string(s: str) -> bool:
@@ -9,59 +10,49 @@ def is_int_string(s: str) -> bool:
except ValueError:
return False
-def _normalize_single_self_refiner_plan(plan_str):
- entries = []
- for chunk in plan_str.split(","):
- chunk = chunk.strip()
- if not chunk:
- continue
- if ":" not in chunk:
- return "", "Self-refine plan entries must be in 'start-end:steps' format."
- range_part, steps_part = chunk.split(":", 1)
- range_part = range_part.strip()
- steps_part = steps_part.strip()
- if not steps_part:
- return "", "Self-refine plan entries must include a step count."
- if "-" in range_part:
- start_s, end_s = range_part.split("-", 1)
- else:
- start_s = end_s = range_part
- start_s = start_s.strip()
- if not is_int_string(start_s):
- return "", "Self-refine plan start position must be an integer."
- end_s = end_s.strip()
- if not is_int_string(end_s):
- return "", "Self-refine plan end position must be an integer."
- if not is_int_string(steps_part):
- return "", "Self-refine plan steps part must be an integer."
-
- entries.append({
- "start": int(start_s),
- "end": int(end_s),
- "steps": int(steps_part),
- })
- plan = entries
- return plan, ""
-
-
-def normalize_self_refiner_plan(plan_str, max_plans: int = 1):
- if plan_str is None:
- plan_str = ""
- if max_plans is None or max_plans < 1:
- max_plans = 1
- segments = [seg.strip() for seg in str(plan_str).split(";")]
- if len(segments) > max_plans:
- return [], f"Self-refiner supports up to {max_plans} plan(s); remove extra ';' separators."
- plans = []
- for seg in segments:
- if not seg:
- plans.append([])
- continue
- plan, error = _normalize_single_self_refiner_plan(seg)
- if error:
- return [], error
- plans.append(plan)
- return plans, ""
+def normalize_self_refiner_plan(plan_input, max_plans: int = 1):
+ default_plan = [
+ {"start": 1, "end": 5, "steps": 3},
+ {"start": 6, "end": 13, "steps": 1},
+ ]
+ if len(plan_input) > max_plans:
+ return [], f"Self-refiner supports up to {max_plans} plan(s); found {len(plan_input)}."
+ if not plan_input or not isinstance(plan_input, list):
+ return [default_plan], ""
+
+ return [plan_input], ""
+
+def ensure_refiner_list(plan_data):
+ if not isinstance(plan_data, list):
+ return []
+ for rule in plan_data:
+ if "id" not in rule:
+ rule["id"] = str(uuid.uuid4())
+ return plan_data
+
+def add_refiner_rule(current_rules, range_val, steps_val):
+ new_start, new_end = int(range_val[0]), int(range_val[1])
+
+ if new_start >= new_end:
+ from gradio import Error
+ raise Error(f"Start step ({new_start}) must be smaller than End step ({new_end}).")
+
+ for rule in current_rules:
+ if new_start <= rule['end'] and new_end >= rule['start']:
+ from gradio import Error
+ raise Error(f"Overlap detected! Steps {new_start}-{new_end} conflict with existing rule {rule['start']}-{rule['end']}.")
+
+ new_rule = {
+ "id": str(uuid.uuid4()),
+ "start": new_start,
+ "end": new_end,
+ "steps": int(steps_val)
+ }
+ updated_list = current_rules + [new_rule]
+ return sorted(updated_list, key=lambda x: x['start'])
+
+def remove_refiner_rule(current_rules, rule_id):
+ return [r for r in current_rules if r["id"] != rule_id]
class PnPHandler:
def __init__(self, stochastic_plan, ths_uncertainty=0.0, p_norm=1, certain_percentage=0.999, channel_dim: int = 1):
@@ -309,20 +300,8 @@ def restore_func(saved_state):
return latents, sample_scheduler
def create_self_refiner_handler(pnp_plan, pnp_f_uncertainty, pnp_p_norm, pnp_certain_percentage, channel_dim: int = 1):
- stochastic_plan = None
- if isinstance(pnp_plan, list):
- stochastic_plan = pnp_plan
- elif len(pnp_plan):
- plans, _ = normalize_self_refiner_plan(pnp_plan, max_plans=1)
- if plans:
- stochastic_plan = plans[0]
-
- if not stochastic_plan:
- # Default plan from paper/code
- stochastic_plan = [
- {"start": 1, "end": 5, "steps": 3},
- {"start": 6, "end": 13, "steps": 1},
- ]
+ plans, _ = normalize_self_refiner_plan(pnp_plan, max_plans=1)
+ stochastic_plan = plans[0]
return PnPHandler(
stochastic_plan,
diff --git a/wgp.py b/wgp.py
index 115c9aa96..179ac5a27 100644
--- a/wgp.py
+++ b/wgp.py
@@ -51,6 +51,7 @@
from huggingface_hub import hf_hub_download, snapshot_download
from shared.utils import files_locator as fl
from shared.gradio.audio_gallery import AudioGallery
+from shared.utils.self_refiner import normalize_self_refiner_plan, ensure_refiner_list, add_refiner_rule, remove_refiner_rule
import torch
import gc
import traceback
@@ -68,6 +69,7 @@
import glob
import cv2
import html
+from gradio_rangeslider import RangeSlider
import re
from transformers.utils import logging
logging.set_verbosity_error
@@ -765,13 +767,15 @@ def ret():
model_mode = None
if server_config.get("fit_canvas", 0) == 2 and outpainting_dims is not None and any_letters(video_prompt_type, "VKF"):
gr.Info("Output Resolution Cropping will be not used for this Generation as it is not compatible with Video Outpainting")
- if self_refiner_setting != 0 and len(self_refiner_plan):
- from shared.utils.self_refiner import normalize_self_refiner_plan
- max_plans = model_def.get("self_refiner_max_plans", 1)
- _, error = normalize_self_refiner_plan(self_refiner_plan, max_plans=max_plans)
- if len(error):
- gr.Info(error)
- return ret()
+ if self_refiner_setting != 0:
+ if isinstance(self_refiner_plan, list):
+ max_plans = model_def.get("self_refiner_max_plans", 1)
+ _, error = normalize_self_refiner_plan(self_refiner_plan, max_plans=max_plans)
+ if len(error):
+ gr.Info(error)
+ return ret()
+ else:
+ self_refiner_plan = []
if not model_def.get("motion_amplitude", False): motion_amplitude = 1.
if "vae" in spatial_upsampling:
@@ -10394,8 +10398,38 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
with gr.Column(visible = model_def.get("self_refiner", False)) as self_refiner_col:
gr.Markdown("Self-Refining Video Sampling (PnP) - should improve quality of Motion")
- self_refiner_setting = gr.Dropdown( choices=[("Disabled", 0),("Enabled with P1-Norm", 1), ("Enabled with P2-Norm", 2), ], value=ui_get("self_refiner_setting", 0), scale = 1, label="Self Refiner", )
- self_refiner_plan = gr.Textbox( value=ui_get("self_refiner_plan", ""), label="P&P Plan (start-end:steps, comma-separated)", lines=1, placeholder="2-5:3,6-13:1" )
+ self_refiner_setting = gr.Dropdown(choices=[("Disabled", 0),("Enabled with P1-Norm", 1), ("Enabled with P2-Norm", 2)], value=ui_get("self_refiner_setting", 0), scale=1, label="Self Refiner")
+
+ refiner_val = ensure_refiner_list(ui_get("self_refiner_plan", []))
+ self_refiner_plan = refiner_val if update_form else gr.State(value=refiner_val)
+
+ with gr.Group(visible=(update_form and ui_get("self_refiner_setting", 0) > 0)) as self_refiner_rules_ui:
+ gr.Markdown("### Refiner Plan")
+
+ with gr.Row(elem_id="refiner-input-row"):
+ refiner_range = RangeSlider(minimum=0, maximum=100, value=(0, 10), step=1, label="Step Range", info="Start - End", scale=3)
+ refiner_mult = gr.Slider(label="Iterations", value=3, minimum=1, maximum=5, step=1, scale=2)
+ refiner_add_btn = gr.Button("➕ Add", variant="primary", scale=0, min_width=100)
+
+ if not update_form:
+ refiner_add_btn.click(fn=add_refiner_rule, inputs=[self_refiner_plan, refiner_range, refiner_mult], outputs=[self_refiner_plan])
+ self_refiner_setting.change(fn=lambda s: gr.update(visible=s > 0), inputs=[self_refiner_setting], outputs=[self_refiner_rules_ui])
+
+ @gr.render(inputs=self_refiner_plan)
+ def render_refiner_plans(plans):
+ if not plans:
+ gr.Markdown("No plans defined. Using defaults: Steps 2-5 (3x), Steps 6-13 (1x).")
+ return
+ for plan in plans:
+ with gr.Row(elem_classes="rule-row"):
+ text_display = f"Steps **{plan['start']} - {plan['end']}** : **{plan['steps']}x** iterations"
+ gr.Markdown(text_display, elem_classes="rule-card")
+ gr.Button("✖", variant="stop", scale=0, elem_classes="delete-btn").click(
+ fn=remove_refiner_rule,
+ inputs=[self_refiner_plan, gr.State(plan["id"])],
+ outputs=[self_refiner_plan]
+ )
+
with gr.Row():
self_refiner_f_uncertainty = gr.Slider(0.0, 1.0, value=ui_get("self_refiner_f_uncertainty", 0.1), step=0.01, label="Uncertainty Threshold", show_reset_button= False)
self_refiner_certain_percentage = gr.Slider(0.0, 1.0, value=ui_get("self_refiner_certain_percentage", 0.999), step=0.001, label="Certainty Percentage Skip", show_reset_button= False)