Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ s3tokenizer
conformer==0.3.2
spacy_pkuseg
spacy==3.8.4
gradio_rangeslider

# Vision & segmentation
opencv-python>=4.12.0.88
Expand Down
11 changes: 11 additions & 0 deletions shared/gradio/ui_styles.css
Original file line number Diff line number Diff line change
Expand Up @@ -364,3 +364,14 @@ user-select: text;
overflow: visible !important;
}
.tabitem {padding-top:0px}
/* Self-refiner plan editor: one flex row per rule, holding the rule text and its delete button. */
.rule-row { margin-bottom: 8px; align-items: center !important; display: flex; gap: 8px; }
/* Card that shows the rule text; flex-grow lets it take the width left over by the delete button. */
.rule-card { background-color: var(--background-fill-secondary); padding: 8px 12px; border-radius: 6px; border: 1px solid var(--border-color-primary); flex-grow: 1; margin-bottom: 0 !important; }
/* Drop the paragraph bottom margin inside the card so the text is not pushed off-center. */
.rule-card p { margin-bottom: 0 !important; }
/* Fixed 42x42 square button used to delete a rule. */
.delete-btn {
min-width: 42px !important;
max-width: 42px !important;
height: 42px !important;
padding: 0 !important;
align-self: center;
}
/* Row holding the range slider, iterations slider and Add button: align them on the cross axis. */
#refiner-input-row { align-items: center; }
113 changes: 46 additions & 67 deletions shared/utils/self_refiner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import copy
import uuid
from diffusers.utils.torch_utils import randn_tensor

def is_int_string(s: str) -> bool:
Expand All @@ -9,59 +10,49 @@ def is_int_string(s: str) -> bool:
except ValueError:
return False

def _normalize_single_self_refiner_plan(plan_str):
entries = []
for chunk in plan_str.split(","):
chunk = chunk.strip()
if not chunk:
continue
if ":" not in chunk:
return "", "Self-refine plan entries must be in 'start-end:steps' format."
range_part, steps_part = chunk.split(":", 1)
range_part = range_part.strip()
steps_part = steps_part.strip()
if not steps_part:
return "", "Self-refine plan entries must include a step count."
if "-" in range_part:
start_s, end_s = range_part.split("-", 1)
else:
start_s = end_s = range_part
start_s = start_s.strip()
if not is_int_string(start_s):
return "", "Self-refine plan start position must be an integer."
end_s = end_s.strip()
if not is_int_string(end_s):
return "", "Self-refine plan end position must be an integer."
if not is_int_string(steps_part):
return "", "Self-refine plan steps part must be an integer."

entries.append({
"start": int(start_s),
"end": int(end_s),
"steps": int(steps_part),
})
plan = entries
return plan, ""


def normalize_self_refiner_plan(plan_str, max_plans: int = 1):
    """Parse a ';'-separated string of self-refiner plans.

    Args:
        plan_str: Plan text; ``None`` is treated as an empty string and any
            other value is converted with ``str``.
        max_plans: Maximum number of ';'-separated plans accepted. ``None``
            or values below 1 are treated as 1.

    Returns:
        ``(plans, "")`` on success, where ``plans`` holds one list of rules
        per segment (an empty segment yields an empty list), or
        ``([], message)`` when there are too many segments or a segment
        fails to parse.
    """
    text = "" if plan_str is None else str(plan_str)
    limit = max_plans if (max_plans is not None and max_plans >= 1) else 1

    pieces = [piece.strip() for piece in text.split(";")]
    if len(pieces) > limit:
        return [], f"Self-refiner supports up to {limit} plan(s); remove extra ';' separators."

    normalized = []
    for piece in pieces:
        if piece:
            single_plan, problem = _normalize_single_self_refiner_plan(piece)
            if problem:
                return [], problem
            normalized.append(single_plan)
        else:
            # Keep a placeholder so plan positions stay aligned with segments.
            normalized.append([])
    return normalized, ""
def normalize_self_refiner_plan(plan_input, max_plans: int = 1):
    """Normalize a self-refiner plan into a list of plans.

    Args:
        plan_input: Either a flat list of rule dicts (one plan), or a list
            whose entries are all lists (several plans). Anything empty or
            not a list falls back to the built-in default plan.
        max_plans: Maximum number of plans accepted. ``None`` or values
            below 1 are treated as 1.

    Returns:
        ``(plans, "")`` on success, where ``plans`` is a list of plans, or
        ``([], message)`` when more plans than ``max_plans`` were supplied.
    """
    default_plan = [
        {"start": 1, "end": 5, "steps": 3},
        {"start": 6, "end": 13, "steps": 1},
    ]
    if max_plans is None or max_plans < 1:
        max_plans = 1

    # Check the type before touching len(): None or a number must fall back
    # to the default plan instead of raising TypeError.
    if not plan_input or not isinstance(plan_input, list):
        return [default_plan], ""

    # A flat list of rules is ONE plan, however many rules it holds; only a
    # list of lists carries several plans. Counting rules against max_plans
    # would reject any valid plan with more than one rule.
    if all(isinstance(entry, list) for entry in plan_input):
        plans = plan_input
    else:
        plans = [plan_input]

    if len(plans) > max_plans:
        return [], f"Self-refiner supports up to {max_plans} plan(s); found {len(plans)}."
    return plans, ""

def ensure_refiner_list(plan_data):
    """Return ``plan_data`` as a rule list in which every rule has an ``"id"``.

    Anything that is not a list yields a new empty list. For a list, rules
    lacking an ``"id"`` key get a fresh UUID4 string; the rules are updated
    in place and the same list object is returned.
    """
    if isinstance(plan_data, list):
        for entry in plan_data:
            if "id" in entry:
                continue
            entry["id"] = str(uuid.uuid4())
        return plan_data
    return []

def add_refiner_rule(current_rules, range_val, steps_val):
    """Return a new rule list with one extra rule, sorted by start step.

    Args:
        current_rules: Existing list of rule dicts; it is not modified.
        range_val: Two-item sequence ``(start, end)``; both are cast to int.
        steps_val: Iteration count for the new rule; cast to int.

    Raises:
        gradio.Error: If start is not strictly smaller than end, or the new
            range overlaps (inclusively) an existing rule.
    """
    def _reject(message):
        # gradio is imported only on the failure path.
        from gradio import Error
        raise Error(message)

    lo, hi = int(range_val[0]), int(range_val[1])
    if not lo < hi:
        _reject(f"Start step ({lo}) must be smaller than End step ({hi}).")

    for existing in current_rules:
        # Two inclusive ranges are disjoint only if one ends before the other starts.
        disjoint = lo > existing['end'] or hi < existing['start']
        if not disjoint:
            _reject(f"Overlap detected! Steps {lo}-{hi} conflict with existing rule {existing['start']}-{existing['end']}.")

    candidate = {
        "id": str(uuid.uuid4()),
        "start": lo,
        "end": hi,
        "steps": int(steps_val)
    }
    combined = current_rules + [candidate]
    combined.sort(key=lambda rule: rule['start'])
    return combined

def remove_refiner_rule(current_rules, rule_id):
    """Return a new list holding every rule whose ``"id"`` differs from ``rule_id``."""
    kept = []
    for rule in current_rules:
        if rule["id"] != rule_id:
            kept.append(rule)
    return kept

class PnPHandler:
def __init__(self, stochastic_plan, ths_uncertainty=0.0, p_norm=1, certain_percentage=0.999, channel_dim: int = 1):
Expand Down Expand Up @@ -309,20 +300,8 @@ def restore_func(saved_state):
return latents, sample_scheduler

def create_self_refiner_handler(pnp_plan, pnp_f_uncertainty, pnp_p_norm, pnp_certain_percentage, channel_dim: int = 1):
stochastic_plan = None
if isinstance(pnp_plan, list):
stochastic_plan = pnp_plan
elif len(pnp_plan):
plans, _ = normalize_self_refiner_plan(pnp_plan, max_plans=1)
if plans:
stochastic_plan = plans[0]

if not stochastic_plan:
# Default plan from paper/code
stochastic_plan = [
{"start": 1, "end": 5, "steps": 3},
{"start": 6, "end": 13, "steps": 1},
]
plans, _ = normalize_self_refiner_plan(pnp_plan, max_plans=1)
stochastic_plan = plans[0]

return PnPHandler(
stochastic_plan,
Expand Down
52 changes: 43 additions & 9 deletions wgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from huggingface_hub import hf_hub_download, snapshot_download
from shared.utils import files_locator as fl
from shared.gradio.audio_gallery import AudioGallery
from shared.utils.self_refiner import normalize_self_refiner_plan, ensure_refiner_list, add_refiner_rule, remove_refiner_rule
import torch
import gc
import traceback
Expand All @@ -68,6 +69,7 @@
import glob
import cv2
import html
from gradio_rangeslider import RangeSlider
import re
from transformers.utils import logging
# Call the function: a bare `logging.set_verbosity_error` only references it
# and leaves the transformers log level unchanged.
logging.set_verbosity_error()
Expand Down Expand Up @@ -765,13 +767,15 @@ def ret():
model_mode = None
if server_config.get("fit_canvas", 0) == 2 and outpainting_dims is not None and any_letters(video_prompt_type, "VKF"):
gr.Info("Output Resolution Cropping will be not used for this Generation as it is not compatible with Video Outpainting")
if self_refiner_setting != 0 and len(self_refiner_plan):
from shared.utils.self_refiner import normalize_self_refiner_plan
max_plans = model_def.get("self_refiner_max_plans", 1)
_, error = normalize_self_refiner_plan(self_refiner_plan, max_plans=max_plans)
if len(error):
gr.Info(error)
return ret()
if self_refiner_setting != 0:
if isinstance(self_refiner_plan, list):
max_plans = model_def.get("self_refiner_max_plans", 1)
_, error = normalize_self_refiner_plan(self_refiner_plan, max_plans=max_plans)
if len(error):
gr.Info(error)
return ret()
else:
self_refiner_plan = []

if not model_def.get("motion_amplitude", False): motion_amplitude = 1.
if "vae" in spatial_upsampling:
Expand Down Expand Up @@ -10394,8 +10398,38 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai

with gr.Column(visible = model_def.get("self_refiner", False)) as self_refiner_col:
gr.Markdown("<B>Self-Refining Video Sampling (PnP) - should improve quality of Motion</B>")
self_refiner_setting = gr.Dropdown( choices=[("Disabled", 0),("Enabled with P1-Norm", 1), ("Enabled with P2-Norm", 2), ], value=ui_get("self_refiner_setting", 0), scale = 1, label="Self Refiner", )
self_refiner_plan = gr.Textbox( value=ui_get("self_refiner_plan", ""), label="P&P Plan (start-end:steps, comma-separated)", lines=1, placeholder="2-5:3,6-13:1" )
self_refiner_setting = gr.Dropdown(choices=[("Disabled", 0),("Enabled with P1-Norm", 1), ("Enabled with P2-Norm", 2)], value=ui_get("self_refiner_setting", 0), scale=1, label="Self Refiner")

refiner_val = ensure_refiner_list(ui_get("self_refiner_plan", []))
self_refiner_plan = refiner_val if update_form else gr.State(value=refiner_val)

with gr.Group(visible=(update_form and ui_get("self_refiner_setting", 0) > 0)) as self_refiner_rules_ui:
gr.Markdown("### Refiner Plan")

with gr.Row(elem_id="refiner-input-row"):
refiner_range = RangeSlider(minimum=0, maximum=100, value=(0, 10), step=1, label="Step Range", info="Start - End", scale=3)
refiner_mult = gr.Slider(label="Iterations", value=3, minimum=1, maximum=5, step=1, scale=2)
refiner_add_btn = gr.Button("➕ Add", variant="primary", scale=0, min_width=100)

if not update_form:
refiner_add_btn.click(fn=add_refiner_rule, inputs=[self_refiner_plan, refiner_range, refiner_mult], outputs=[self_refiner_plan])
self_refiner_setting.change(fn=lambda s: gr.update(visible=s > 0), inputs=[self_refiner_setting], outputs=[self_refiner_rules_ui])

@gr.render(inputs=self_refiner_plan)
def render_refiner_plans(plans):
if not plans:
gr.Markdown("<I style='color:grey; padding: 8px;'>No plans defined. Using defaults: Steps 2-5 (3x), Steps 6-13 (1x).</I>")
return
for plan in plans:
with gr.Row(elem_classes="rule-row"):
text_display = f"Steps **{plan['start']} - {plan['end']}** : **{plan['steps']}x** iterations"
gr.Markdown(text_display, elem_classes="rule-card")
gr.Button("✖", variant="stop", scale=0, elem_classes="delete-btn").click(
fn=remove_refiner_rule,
inputs=[self_refiner_plan, gr.State(plan["id"])],
outputs=[self_refiner_plan]
)

with gr.Row():
self_refiner_f_uncertainty = gr.Slider(0.0, 1.0, value=ui_get("self_refiner_f_uncertainty", 0.1), step=0.01, label="Uncertainty Threshold", show_reset_button= False)
self_refiner_certain_percentage = gr.Slider(0.0, 1.0, value=ui_get("self_refiner_certain_percentage", 0.999), step=0.001, label="Certainty Percentage Skip", show_reset_button= False)
Expand Down